Spaces:
Runtime error
Runtime error
| import os | |
| import cv2 | |
| import torch | |
| from basicsr.utils import tensor2img | |
| from pytorch_lightning import seed_everything | |
| from torch import autocast | |
| from ldm.inference_base import (diffusion_inference, get_adapters, get_base_argument_parser, get_sd_models) | |
| from ldm.modules.extra_condition import api | |
| from ldm.modules.extra_condition.api import (ExtraCondition, get_adapter_feature, get_cond_model) | |
# Turn autograd off process-wide: this script only runs inference and never
# calls backward, so tracking gradients would just waste memory.
torch.set_grad_enabled(mode=False)
def _read_prompt_file(path):
    """Parse a batch-test file into parallel lists of condition paths and prompts.

    Each non-blank line must look like ``<cond_path>; <prompt>``. Only the
    first ``'; '`` separates the two fields, so a prompt may itself contain
    ``'; '``. Blank lines are ignored.

    Args:
        path: Path to the ``.txt`` batch file.

    Returns:
        A ``(image_paths, prompts)`` tuple of equal-length lists.

    Raises:
        ValueError: If a non-blank line has no ``'; '`` separator.
    """
    image_paths = []
    prompts = []
    with open(path, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Skip blank lines (e.g. an extra newline at end of file).
                continue
            cond_path, sep, prompt = line.partition('; ')
            if not sep:
                raise ValueError(
                    f"{path}:{line_no}: expected '<cond_path>; <prompt>', got {line!r}"
                )
            image_paths.append(cond_path)
            prompts.append(prompt)
    return image_paths, prompts


def main():
    """Run adapter-conditioned diffusion inference for one condition modality.

    Command-line options come from ``get_base_argument_parser()`` plus a
    required ``--which_cond`` flag. Its choices are the member names of
    ``ExtraCondition``.

    Two test modes are supported:
      * single: ``--prompt`` is a text prompt and ``--cond_path`` is the
        condition input;
      * batch: ``--prompt`` ends in ``.txt`` and names a file whose lines are
        ``<cond_path>; <prompt>``.

    For every (condition, prompt) pair the RNG is re-seeded with
    ``opt.seed``. Then ``opt.n_samples`` samples are generated. Each sample
    writes two PNGs into ``opt.outdir``: ``{idx:05}_{which_cond}.png`` (the
    processed condition) and ``{idx:05}_result.png`` (the generated image).
    """
    supported_cond = [e.name for e in ExtraCondition]
    parser = get_base_argument_parser()
    parser.add_argument(
        '--which_cond',
        type=str,
        required=True,
        choices=supported_cond,
        help='which condition modality you want to test',
    )
    opt = parser.parse_args()
    which_cond = opt.which_cond
    cond_type = getattr(ExtraCondition, which_cond)

    if opt.outdir is None:
        opt.outdir = f'outputs/test-{which_cond}'
    os.makedirs(opt.outdir, exist_ok=True)
    if opt.resize_short_edge is None:
        print(
            "you don't specify the resize_short_edge, so the maximum resolution is set to "
            f"{opt.max_resolution}"
        )
    opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Two test modes: a single image/prompt pair, or a batch described by a
    # .txt file of "<cond_path>; <prompt>" lines.
    if opt.prompt.endswith('.txt'):
        image_paths, prompts = _read_prompt_file(opt.prompt)
    else:
        image_paths = [opt.cond_path]
        prompts = [opt.prompt]
    print(image_paths)

    # prepare models
    sd_model, sampler = get_sd_models(opt)
    adapter = get_adapters(opt, cond_type)
    cond_model = None
    if opt.cond_inp_type == 'image':
        # A condition model is only loaded for image inputs. Presumably it
        # derives the condition map from the raw image — confirm in
        # get_cond_model.
        cond_model = get_cond_model(opt, cond_type)
    process_cond_module = getattr(api, f'get_cond_{which_cond}')

    # inference
    with torch.inference_mode(), sd_model.ema_scope(), autocast('cuda'):
        for cond_path, prompt in zip(image_paths, prompts):
            seed_everything(opt.seed)
            for _ in range(opt.n_samples):
                cond = process_cond_module(opt, cond_path, opt.cond_inp_type, cond_model)
                # Every sample writes exactly two files, so half the current
                # file count is the next free index. This assumes outdir
                # holds only this script's outputs.
                base_count = len(os.listdir(opt.outdir)) // 2
                cv2.imwrite(
                    os.path.join(opt.outdir, f'{base_count:05}_{which_cond}.png'),
                    tensor2img(cond),
                )
                adapter_features, append_to_context = get_adapter_feature(cond, adapter)
                # diffusion_inference takes no prompt argument; presumably it
                # reads opt.prompt — confirm in ldm.inference_base.
                opt.prompt = prompt
                result = diffusion_inference(
                    opt, sd_model, sampler, adapter_features, append_to_context
                )
                cv2.imwrite(
                    os.path.join(opt.outdir, f'{base_count:05}_result.png'),
                    tensor2img(result),
                )
if __name__ == '__main__':
    # Only launch the CLI when run as a script, not when imported.
    main()