# UniBioTransfer/infer.py
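"""Inference entry point: builds the MoE LatentDiffusion model, swaps in the SD v1-4
VAE, and runs DDIM sampling over (target, reference) image pairs, saving final
results, comparison grids, and (optionally) sampling intermediates under args.out_dir."""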
# --------------------------------------------------------- Config -------------------------------------------------
NUM_WORKERS: int = 1        # DataLoader worker processes
DDIM_STEPS = 50             # number of DDIM sampling steps
BATCH_SIZE = 1
FIXED_CODE = False          # reuse one fixed start noise x_T for every batch
# for vis
SAVE_INTERMEDIATES = True   # dump the pred_x0 / noised latents logged during sampling
NUM_GRID_PER_COLUMN = 5     # per-sample grids stacked into one column image
# ------------------------------------------------------------------------------------------------------------------------
import sys
import os
from pathlib import Path
cur_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, cur_dir)  # make sibling modules (confs, Dataset_custom, ...) importable
from confs import *  # supplies TASK, args, SD14_localpath, PRETRAIN_CKPT_PATH, pretty_print_torch_module_keys, ...
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
from einops import rearrange
from torchvision.utils import make_grid
from my_py_lib.image_util import imgs_2_grid_A
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
import torchvision
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from Dataset_custom import Dataset_custom
from MoE import offload_unused_tasks__LD
from ldm.models.diffusion.ddpm import LandmarkExtractor
from my_py_lib.torch_util import cleanup_gpu_memory
from gen_lmk_and_mask import gen_lmk_and_mask
# ------------------------------------------------------------------------------------------------------------------------
DDIM_ETA = 0.0          # eta = 0 -> deterministic DDIM sampling
SCALE = 3.0             # classifier-free guidance scale
PRECISION = "full"      # "full" (fp32) or "autocast" (mixed precision)
H = 512                 # output image height
W = 512                 # output image width
C = 4                   # latent channels
F = 8                   # VAE downsampling factor: latents are C x H//F x W//F
# ------------------------------------------------------------------------------------------------------------------------
def load_first_stage_from_sd14(model: LatentDiffusion, sd14_path: Path) -> None:
    """Replace the model's VAE (first_stage_model) with weights from an SD v1-4 checkpoint."""
    print(f"Loading first_stage_model from {sd14_path}")
    sd14 = torch.load(str(sd14_path), map_location="cpu")
    sd14_sd = sd14["state_dict"] if isinstance(sd14, dict) and "state_dict" in sd14 else sd14
    # The VAE keys may sit under either prefix depending on how the checkpoint was saved.
    prefixes = ["first_stage_model.", "model.first_stage_model."]
    fs_sd = {}
    for prefix in prefixes:
        for k, v in sd14_sd.items():
            if k.startswith(prefix):
                fs_sd[k[len(prefix):]] = v
        if fs_sd:
            break
    if not fs_sd:
        raise RuntimeError("Could not find first_stage_model weights in the SD v1-4 checkpoint.")
    model.first_stage_model.load_state_dict(fs_sd, strict=True)
def save_sample_by_decode(x, model, base_path, segment_id, intermediate_num):
    """Decode a latent batch to images and save them as base_path/segment_id/<intermediate_num>.png."""
    x = model.decode_first_stage(x)
    x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
    x = x.cpu().permute(0, 2, 3, 1).numpy()
    save_path = Path(base_path) / segment_id
    save_path.mkdir(parents=True, exist_ok=True)
    # Note: every sample in the batch shares the same filename; callers pass single-sample slices.
    for i in range(len(x)):
        img = Image.fromarray((x[i] * 255).astype(np.uint8))
        img.save(save_path / f"{intermediate_num}.png")
def get_tensor_clip(normalize=True, toTensor=True):
    """Build the CLIP image-encoder preprocessing (ToTensor + the standard CLIP mean/std)."""
transform_list = []
if toTensor:
transform_list += [torchvision.transforms.ToTensor()]
if normalize:
transform_list += [
torchvision.transforms.Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
)
]
return torchvision.transforms.Compose(transform_list)
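# Illustrative usage (not called in this file): get_tensor_clip()(pil_img) yields a
# normalized CHW float tensor suitable for the CLIP image encoder.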
def load_model_from_config(ckpt, verbose=1):
    """Build the MoE LatentDiffusion model and load its weights from `ckpt`."""
    ckpt = Path(ckpt)
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(str(ckpt), map_location="cpu")
    sd = pl_sd["state_dict"] if isinstance(pl_sd, dict) and "state_dict" in pl_sd else pl_sd
    from init_model import get_moe
    model: LatentDiffusion = get_moe()
    model.ptsM_Generator = LandmarkExtractor(include_visualizer=True, img_256_mode=False)
    cleanup_gpu_memory()
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        pretty_print_torch_module_keys(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        pretty_print_torch_module_keys(u)
    load_first_stage_from_sd14(model, SD14_localpath)
    offload_unused_tasks__LD(model, TASK, method="del")  # drop experts for other tasks to save CUDA memory
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
def load_pairs(pair_list, tgt, ref):
    """Collect (target, reference) path pairs from --tgt/--ref or from a pair-list file."""
    if tgt and ref:
        pairs = [(tgt, ref)]
    elif pair_list:
        pairs = []
        with open(pair_list, "r") as f:
            for line_num, line in enumerate(f, start=1):
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                parts = line.split()  # split on any whitespace, matching the error message below
                if len(parts) != 2:
                    raise ValueError(f"Invalid pair-list line {line_num}: expected two whitespace-separated paths (tgt ref), got {parts=}")
                pairs.append((parts[0], parts[1]))
    else:
        raise ValueError("No input pairs provided. Use --tgt/--ref or --pair-list.")
    print(f"{pairs=}")
    return pairs
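# Example pair-list file (hypothetical paths; one pair per line, '#' starts a comment):
#   imgs/tgt_001.png  imgs/ref_001.png
#   imgs/tgt_002.png  imgs/ref_002.png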
def un_norm(x):
    """Map a [-1, 1] image tensor back to [0, 1]."""
    return (x + 1.0) / 2.0
def un_norm_clip(x1):
    """Invert the per-channel CLIP normalization applied by get_tensor_clip."""
    x = x1 * 1.0  # copy so the caller's tensor is not modified in place
    reduce = False
    if len(x.shape) == 3:
        x = x.unsqueeze(0)
        reduce = True
    x[:, 0, :, :] = x[:, 0, :, :] * 0.26862954 + 0.48145466
    x[:, 1, :, :] = x[:, 1, :, :] * 0.26130258 + 0.4578275
    x[:, 2, :, :] = x[:, 2, :, :] * 0.27577711 + 0.40821073
    if reduce:
        x = x.squeeze(0)
    return x
if __name__ == "__main__":
pairs = load_pairs(args.pair_list, args.tgt, args.ref)
out_dir = Path(args.out_dir)
result_path = out_dir / "results"
grid_path = out_dir / "grid"
inter_path = out_dir / "intermediates"
inter_pred_path = inter_path / "pred_x0"
inter_noised_path = inter_path / "noised"
out_dir.mkdir(parents=False, exist_ok=True)
result_path.mkdir(parents=False, exist_ok=True)
grid_path.mkdir(parents=False, exist_ok=True)
inter_path.mkdir(parents=False, exist_ok=True)
if SAVE_INTERMEDIATES:
inter_pred_path.mkdir(parents=False, exist_ok=True)
inter_noised_path.mkdir(parents=False, exist_ok=True)
paths_tgt = [p[0] for p in pairs]
paths_ref = [p[1] for p in pairs]
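    # Precompute landmarks and masks for every target and reference image up front.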
gen_lmk_and_mask(paths_tgt + paths_ref)
seed_everything(42)
    model: LatentDiffusion = load_model_from_config(PRETRAIN_CKPT_PATH)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
sampler = DDIMSampler(model)
dataset = Dataset_custom(
"test",
task=TASK,
paths_tgt=paths_tgt,
paths_ref=paths_ref,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
pin_memory=True,
shuffle=False,
drop_last=False,
)
start_code = None
if FIXED_CODE:
start_code = torch.randn([BATCH_SIZE, C, H // F, W // F], device=device)
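    # "autocast" runs sampling under torch.autocast (mixed precision); "full" uses a
    # no-op context and keeps everything in fp32.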
precision_scope = autocast if PRECISION == "autocast" else nullcontext
grids = []
grid_stems = []
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
                for test_batch, prior, test_model_kwargs, out_stem_batch in tqdm(dataloader):  # `prior` is unused at inference
model.set_task(test_model_kwargs)
bs = test_batch.shape[0]
batch_ = {
**test_model_kwargs,
"GT": torch.zeros_like(test_model_kwargs["inpaint_image"]),
}
batch_, c = model.get_input_and_conditioning(batch_, device=device)
z_inpaint = batch_["z4_inpaint"]
z_inpaint_mask = batch_["tgt_mask_64"]
z_ref = batch_["z_ref"]
z9 = batch_["z9"]
                    # Classifier-free guidance: when enabled (SCALE != 1.0), use the task's
                    # learnable null embedding as the unconditional conditioning.
                    uc = None
                    if SCALE != 1.0:
                        uc = model.learnable_vector[TASK].repeat(bs, 1, 1)
                    shape = [C, H // F, W // F]
                    local_start_code = start_code
                    # Re-draw the start noise if the final partial batch changes the batch size.
                    if FIXED_CODE and (local_start_code is None or local_start_code.shape[0] != bs):
                        local_start_code = torch.randn([bs, C, H // F, W // F], device=device)
samples_ddim, intermediates = sampler.sample(
S=DDIM_STEPS,
conditioning=c,
batch_size=bs,
shape=shape,
verbose=False,
unconditional_guidance_scale=SCALE,
unconditional_conditioning=uc,
eta=DDIM_ETA,
x_T=local_start_code,
log_every_t=100,
z_inpaint=z_inpaint,
z_inpaint_mask=z_inpaint_mask,
z_ref=z_ref,
z9=z9,
)
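                    # `intermediates` contains latents logged every `log_every_t` steps:
                    # "pred_x0" (denoised estimates) and "x_inter" (partially noised latents).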
if SAVE_INTERMEDIATES:
intermediate_pred_x0 = intermediates["pred_x0"]
intermediate_noised = intermediates["x_inter"]
for i in range(len(intermediate_pred_x0)):
for j in range(bs):
stem = f"{out_stem_batch[j]}"
save_sample_by_decode(
intermediate_pred_x0[i][j : j + 1],
model,
inter_pred_path,
stem,
i,
)
save_sample_by_decode(
intermediate_noised[i][j : j + 1],
model,
inter_noised_path,
stem,
i,
)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
x_checked_image_torch = torch.from_numpy(x_samples_ddim).permute(0, 3, 1, 2)
for i, x_sample in enumerate(x_checked_image_torch):
stem = f"{out_stem_batch[i]}"
out_path = result_path / f"{stem}.png"
img = Image.fromarray((x_sample.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
img.save(out_path)
print(f"{out_path=}")
for i, x_sample in enumerate(x_checked_image_torch):
all_img = []
all_img.append(un_norm(test_batch[i]).cpu())
                        if TASK != 2:
                            # ref_imgs is stored CLIP-normalized; resize to 512 and undo the normalization for display
                            ref_img = test_model_kwargs["ref_imgs"].squeeze(1)
                            ref_img = torchvision.transforms.Resize([512, 512])(ref_img)
                            ref_img = un_norm_clip(ref_img[i]).cpu()
                        else:
                            # task 2 carries a plain [-1, 1] 512px reference instead
                            ref_img = un_norm(test_model_kwargs["ref512"].squeeze(1)[i]).cpu()
all_img.append(ref_img)
all_img.append(x_sample)
grid = torch.stack(all_img, 0)
grid = make_grid(grid)
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
img = Image.fromarray(grid.astype(np.uint8))
stem = f"{out_stem_batch[i]}"
path_save_img = grid_path / f"grid-{stem}.jpg"
img.save(path_save_img)
print(f"{path_save_img=}")
grids.append(img)
grid_stems.append(stem)
                        if len(grids) >= NUM_GRID_PER_COLUMN:
                            # Stack the accumulated per-sample grids into one tall column image.
                            stem_start = grid_stems[0]
                            stem_end = grid_stems[-1]
                            imgs_2_grid_A(
                                grids,
                                grid_layout='column',
                                grid_path=str(grid_path / f"{stem_start}--{stem_end}.jpg"),
                            )
                            grids = []
                            grid_stems = []
model.unset_task()
print(f"Your samples are ready and waiting for you here: {out_dir}")