| |
# --- Inference configuration ---
num_workers: int = 1  # DataLoader worker processes
DDIM_STEPS = 50  # number of DDIM sampling steps
BATCH_SIZE = 1
FIXED_CODE = False  # if True, reuse one fixed initial-noise tensor (x_T) across batches

SAVE_INTERMEDIATES = True  # also dump per-step pred_x0 / noised latents as PNGs
NUM_grid_in_a_column = 5  # how many per-sample grids get stacked into one column image
| |
| import sys |
| import os |
| from pathlib import Path |
|
|
| cur_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
| from confs import * |
| import torch |
| import numpy as np |
| from omegaconf import OmegaConf |
| from PIL import Image |
| from tqdm import tqdm |
| from einops import rearrange |
| from torchvision.utils import make_grid |
| from my_py_lib.image_util import imgs_2_grid_A,img_paths_2_grid_A |
| from pytorch_lightning import seed_everything |
| from torch import autocast |
| from contextlib import nullcontext |
| import torchvision |
|
|
| from ldm.models.diffusion.ddpm import LatentDiffusion |
| from ldm.util import instantiate_from_config |
| from ldm.models.diffusion.ddim import DDIMSampler |
| from Dataset_custom import Dataset_custom |
| from MoE import offload_unused_tasks__LD |
| from ldm.models.diffusion.ddpm import LandmarkExtractor |
| from my_py_lib.torch_util import cleanup_gpu_memory |
| from gen_lmk_and_mask import gen_lmk_and_mask |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| |
DDIM_ETA = 0.0  # eta = 0.0 -> deterministic DDIM sampling
SCALE = 3.0  # classifier-free guidance scale (1.0 disables unconditional guidance below)
PRECISION = "full"  # "autocast" would enable mixed-precision via torch.autocast
H = 512  # output image height in pixels
W = 512  # output image width in pixels
C = 4  # latent channels
F = 8  # VAE downsampling factor (latent is H//F x W//F)
| |
|
|
|
|
def load_first_stage_from_sd14(model: LatentDiffusion, sd14_path: Path) -> None:
    """Overwrite ``model.first_stage_model`` (the VAE) with SD v1-4 weights.

    Args:
        model: latent-diffusion model whose first stage is replaced in place.
        sd14_path: path to the Stable Diffusion v1-4 checkpoint file.

    Raises:
        RuntimeError: if the checkpoint contains no first-stage weights.
    """
    print(f"Loading first_stage_model from {sd14_path}")
    checkpoint = torch.load(str(sd14_path), map_location="cpu")
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        full_state = checkpoint["state_dict"]
    else:
        full_state = checkpoint

    # Depending on how the checkpoint was exported, the VAE weights may live
    # under either prefix; take the first prefix that matches anything.
    vae_state = {}
    for prefix in ("first_stage_model.", "model.first_stage_model."):
        vae_state = {
            key[len(prefix):]: value
            for key, value in full_state.items()
            if key.startswith(prefix)
        }
        if vae_state:
            break

    if not vae_state:
        raise RuntimeError("Could not find first_stage_model weights in SD v1-4 checkpoint.")

    model.first_stage_model.load_state_dict(vae_state, strict=True)
|
|
|
|
def save_sample_by_decode(x, model, base_path, segment_id, intermediate_num):
    """Decode latent ``x`` through the model's first stage and save it as
    ``<base_path>/<segment_id>/<intermediate_num>.png``.

    Args:
        x: latent tensor of shape (B, C, h, w).
        model: latent-diffusion model providing ``decode_first_stage``.
        base_path: root output directory.
        segment_id: per-sample subdirectory name.
        intermediate_num: sampling-step index, used as the filename stem.
    """
    decoded = model.decode_first_stage(x)
    # Map from [-1, 1] to [0, 1] and convert to HWC numpy for PIL.
    decoded = ((decoded + 1.0) / 2.0).clamp(0.0, 1.0)
    frames = decoded.cpu().permute(0, 2, 3, 1).numpy()
    out_dir = Path(base_path) / segment_id
    out_dir.mkdir(parents=True, exist_ok=True)
    # NOTE: every frame in the batch is written to the same filename, so only
    # the last one survives — same behavior as the original loop.
    for frame in frames:
        image = Image.fromarray((frame * 255).astype(np.uint8))
        image.save(out_dir / f"{intermediate_num}.png")
|
|
|
|
def get_tensor_clip(normalize=True, toTensor=True):
    """Build the CLIP image preprocessing pipeline.

    Args:
        normalize: apply CLIP's published per-channel mean/std normalization.
        toTensor: convert a PIL image / ndarray to a float tensor first.

    Returns:
        A ``torchvision.transforms.Compose`` of the selected steps.
    """
    steps = []
    if toTensor:
        steps.append(torchvision.transforms.ToTensor())
    if normalize:
        # CLIP's RGB statistics from the official release.
        clip_mean = (0.48145466, 0.4578275, 0.40821073)
        clip_std = (0.26862954, 0.26130258, 0.27577711)
        steps.append(torchvision.transforms.Normalize(clip_mean, clip_std))
    return torchvision.transforms.Compose(steps)
|
|
|
|
def load_model_from_config(ckpt, verbose=1):
    """Build the MoE latent-diffusion model and load checkpoint weights into it.

    Removes the ``if 1:`` debug scaffolding from the original: the dead
    ``else`` branch ("DEBUG_skip_load_ckpt") was unreachable and, if ever
    re-enabled, would have left ``sd`` undefined.

    Args:
        ckpt: path to the fine-tuned checkpoint.
        verbose: truthy to print missing/unexpected state-dict keys.

    Returns:
        The model on GPU in eval mode, with unused task experts removed.
    """
    ckpt = Path(ckpt)
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load on an untrusted checkpoint can execute arbitrary
    # code via pickle — only load checkpoints from trusted sources.
    pl_sd = torch.load(str(ckpt), map_location="cpu")
    if isinstance(pl_sd, dict) and "state_dict" in pl_sd:
        sd = pl_sd["state_dict"]
    else:
        sd = pl_sd

    # Local import kept: init_model is only needed (and importable) here.
    from init_model import get_moe
    model: LatentDiffusion = get_moe()
    model.ptsM_Generator = LandmarkExtractor(include_visualizer=True, img_256_mode=False)
    cleanup_gpu_memory()

    # Non-strict load: the fine-tuned checkpoint may not cover every module.
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        pretty_print_torch_module_keys(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        pretty_print_torch_module_keys(u)
    # Replace the VAE with the clean SD v1-4 first stage regardless of what the
    # fine-tuned checkpoint contained.
    load_first_stage_from_sd14(model, SD14_localpath)

    # Drop experts that the current TASK does not use, then move to GPU.
    offload_unused_tasks__LD(model, TASK, method="del")
    model.cuda()
    model.eval()
    return model
|
|
|
|
|
|
|
|
def load_pairs(pair_list, tgt, ref):
    """Collect (target, reference) image-path pairs for inference.

    Args:
        pair_list: optional path to a text file with one ``tgt ref`` pair per
            line; blank lines and ``#`` comments are skipped.
        tgt: single target image path (used together with ``ref``).
        ref: single reference image path (used together with ``tgt``).

    Returns:
        List of ``(tgt, ref)`` tuples. An explicit tgt/ref pair takes
        precedence over ``pair_list``.

    Raises:
        ValueError: on a malformed pair-list line, or when no input is given.
    """
    if tgt and ref:
        pairs = [(tgt, ref), ]
    elif pair_list:
        pairs = []
        with open(pair_list, "r") as f:
            for line_num, line in enumerate(f, start=1):
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                # BUGFIX: split on any whitespace run (tabs, multiple spaces),
                # matching the "white-space-separated" contract stated in the
                # error message; str.split(" ") produced empty fields for
                # consecutive spaces and rejected valid lines.
                parts = line.split()
                if len(parts) != 2:
                    raise ValueError(f"Invalid pair list line {line_num}: expected white-space-separated tgt/ref. got {parts=}")
                pairs.append((parts[0], parts[1]))
    else:
        raise ValueError("No input pairs provided. Use --tgt/--ref or --pair-list.")
    print(f"{pairs=}")
    return pairs
|
|
|
|
def un_norm(x):
    """Map values from the model's [-1, 1] range back to [0, 1]."""
    shifted = x + 1.0
    return shifted / 2.0
|
|
|
|
def un_norm_clip(x1):
    """Invert CLIP normalization: per-channel ``x * std + mean``.

    Accepts a (C, H, W) or (B, C, H, W) tensor and returns the same shape.
    Works on a copy — the caller's tensor is left untouched.
    """
    out = x1 * 1.0  # copy so the in-place channel updates don't hit the input
    squeeze_back = len(out.shape) == 3
    if squeeze_back:
        out = out.unsqueeze(0)
    # (std, mean) per RGB channel — CLIP's published statistics.
    clip_stats = (
        (0.26862954, 0.48145466),
        (0.26130258, 0.4578275),
        (0.27577711, 0.40821073),
    )
    for channel, (std, mean) in enumerate(clip_stats):
        out[:, channel, :, :] = out[:, channel, :, :] * std + mean
    if squeeze_back:
        out = out.squeeze(0)
    return out
|
|
|
|
| if __name__ == "__main__": |
| pairs = load_pairs(args.pair_list, args.tgt, args.ref) |
|
|
| out_dir = Path(args.out_dir) |
| result_path = out_dir / "results" |
| grid_path = out_dir / "grid" |
| inter_path = out_dir / "intermediates" |
| inter_pred_path = inter_path / "pred_x0" |
| inter_noised_path = inter_path / "noised" |
| out_dir.mkdir(parents=False, exist_ok=True) |
| result_path.mkdir(parents=False, exist_ok=True) |
| grid_path.mkdir(parents=False, exist_ok=True) |
| inter_path.mkdir(parents=False, exist_ok=True) |
| if SAVE_INTERMEDIATES: |
| inter_pred_path.mkdir(parents=False, exist_ok=True) |
| inter_noised_path.mkdir(parents=False, exist_ok=True) |
| paths_tgt = [p[0] for p in pairs] |
| paths_ref = [p[1] for p in pairs] |
| gen_lmk_and_mask(paths_tgt + paths_ref) |
|
|
| seed_everything(42) |
|
|
| model: LatentDiffusion = load_model_from_config(PRETRAIN_CKPT_PATH, ) |
| device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") |
| model = model.to(device) |
| sampler = DDIMSampler(model) |
|
|
| dataset = Dataset_custom( |
| "test", |
| task=TASK, |
| paths_tgt=paths_tgt, |
| paths_ref=paths_ref, |
| ) |
| dataloader = torch.utils.data.DataLoader( |
| dataset, |
| batch_size=BATCH_SIZE, |
| num_workers=num_workers, |
| pin_memory=True, |
| shuffle=False, |
| drop_last=False, |
| ) |
|
|
| start_code = None |
| if FIXED_CODE: |
| start_code = torch.randn([BATCH_SIZE, C, H // F, W // F], device=device) |
|
|
| precision_scope = autocast if PRECISION == "autocast" else nullcontext |
| grids = [] |
| grid_stems = [] |
|
|
| with torch.no_grad(): |
| with precision_scope("cuda"): |
| with model.ema_scope(): |
| for test_batch, prior, test_model_kwargs, out_stem_batch in tqdm(dataloader): |
| model.set_task(test_model_kwargs) |
| bs = test_batch.shape[0] |
|
|
| batch_ = { |
| **test_model_kwargs, |
| "GT": torch.zeros_like(test_model_kwargs["inpaint_image"]), |
| } |
| batch_, c = model.get_input_and_conditioning(batch_, device=device) |
| z_inpaint = batch_["z4_inpaint"] |
| z_inpaint_mask = batch_["tgt_mask_64"] |
| z_ref = batch_["z_ref"] |
| z9 = batch_["z9"] |
|
|
| uc = None |
| if SCALE != 1.0: |
| uc = model.learnable_vector[TASK].repeat(bs, 1, 1) |
|
|
| shape = [C, H // F, W // F] |
| local_start_code = start_code |
| if FIXED_CODE and (local_start_code is None or local_start_code.shape[0] != bs): |
| local_start_code = torch.randn([bs, C, H // F, W // F], device=device) |
| samples_ddim, intermediates = sampler.sample( |
| S=DDIM_STEPS, |
| conditioning=c, |
| batch_size=bs, |
| shape=shape, |
| verbose=False, |
| unconditional_guidance_scale=SCALE, |
| unconditional_conditioning=uc, |
| eta=DDIM_ETA, |
| x_T=local_start_code, |
| log_every_t=100, |
| z_inpaint=z_inpaint, |
| z_inpaint_mask=z_inpaint_mask, |
| z_ref=z_ref, |
| z9=z9, |
| ) |
|
|
| if SAVE_INTERMEDIATES: |
| intermediate_pred_x0 = intermediates["pred_x0"] |
| intermediate_noised = intermediates["x_inter"] |
| for i in range(len(intermediate_pred_x0)): |
| for j in range(bs): |
| stem = f"{out_stem_batch[j]}" |
| save_sample_by_decode( |
| intermediate_pred_x0[i][j : j + 1], |
| model, |
| inter_pred_path, |
| stem, |
| i, |
| ) |
| save_sample_by_decode( |
| intermediate_noised[i][j : j + 1], |
| model, |
| inter_noised_path, |
| stem, |
| i, |
| ) |
|
|
| x_samples_ddim = model.decode_first_stage(samples_ddim) |
| x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) |
| x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() |
|
|
| x_checked_image_torch = torch.from_numpy(x_samples_ddim).permute(0, 3, 1, 2) |
| for i, x_sample in enumerate(x_checked_image_torch): |
| stem = f"{out_stem_batch[i]}" |
| out_path = result_path / f"{stem}.png" |
| img = Image.fromarray((x_sample.permute(1, 2, 0).numpy() * 255).astype(np.uint8)) |
| img.save(out_path) |
| print(f"{out_path=}") |
|
|
| for i, x_sample in enumerate(x_checked_image_torch): |
| all_img = [] |
| all_img.append(un_norm(test_batch[i]).cpu()) |
| if TASK != 2: |
| ref_img = test_model_kwargs["ref_imgs"].squeeze(1) |
| ref_img = torchvision.transforms.Resize([512, 512])(ref_img) |
| ref_img = un_norm_clip(ref_img[i]).cpu() |
| else: |
| ref_img = un_norm(test_model_kwargs["ref512"].squeeze(1)[i]).cpu() |
| all_img.append(ref_img) |
| all_img.append(x_sample) |
|
|
| grid = torch.stack(all_img, 0) |
| grid = make_grid(grid) |
| grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() |
| img = Image.fromarray(grid.astype(np.uint8)) |
| stem = f"{out_stem_batch[i]}" |
| path_save_img = grid_path / f"grid-{stem}.jpg" |
| img.save(path_save_img) |
| print(f"{path_save_img=}") |
| grids.append(img) |
| grid_stems.append(stem) |
| if len(grids) >= NUM_grid_in_a_column: |
| stem_start = grid_stems[0] |
| stem_end = grid_stems[-1] |
| grid_column = imgs_2_grid_A( |
| grids, |
| grid_layout='column', |
| grid_path=os.path.join(grid_path, f"{stem_start}--{stem_end}.jpg"), |
| ) |
| grids = [] |
| grid_stems = [] |
|
|
| model.unset_task() |
|
|
| print(f"Your samples are ready and waiting for you here: {out_dir}") |
|
|
|
|
|
|