Spaces:
Runtime error
Runtime error
Commit Β·
514015e
1
Parent(s): b6b5d48
Create utils.py
Browse files
utils.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
from lvdm.models.modules.lora import net_load_lora
|
| 6 |
+
from lvdm.utils.common_utils import instantiate_from_config
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# ------------------------------------------------------------------------------------------
|
| 10 |
+
def load_model(config, ckpt_path, gpu_id=None, inject_lora=False, lora_scale=1.0, lora_path=''):
|
| 11 |
+
print(f"Loading model from {ckpt_path}")
|
| 12 |
+
|
| 13 |
+
# load sd
|
| 14 |
+
pl_sd = torch.load(ckpt_path, map_location="cpu")
|
| 15 |
+
try:
|
| 16 |
+
global_step = pl_sd["global_step"]
|
| 17 |
+
epoch = pl_sd["epoch"]
|
| 18 |
+
except:
|
| 19 |
+
global_step = -1
|
| 20 |
+
epoch = -1
|
| 21 |
+
|
| 22 |
+
# load sd to model
|
| 23 |
+
try:
|
| 24 |
+
sd = pl_sd["state_dict"]
|
| 25 |
+
except:
|
| 26 |
+
sd = pl_sd
|
| 27 |
+
model = instantiate_from_config(config.model)
|
| 28 |
+
model.load_state_dict(sd, strict=True)
|
| 29 |
+
|
| 30 |
+
if inject_lora:
|
| 31 |
+
net_load_lora(model, lora_path, alpha=lora_scale)
|
| 32 |
+
|
| 33 |
+
# move to device & eval
|
| 34 |
+
if gpu_id is not None:
|
| 35 |
+
model.to(f"cuda:{gpu_id}")
|
| 36 |
+
else:
|
| 37 |
+
model.cuda()
|
| 38 |
+
model.eval()
|
| 39 |
+
|
| 40 |
+
return model, global_step, epoch
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ------------------------------------------------------------------------------------------
|
| 44 |
+
@torch.no_grad()
|
| 45 |
+
def get_conditions(prompts, model, batch_size, cond_fps=None,):
|
| 46 |
+
|
| 47 |
+
if isinstance(prompts, str) or isinstance(prompts, int):
|
| 48 |
+
prompts = [prompts]
|
| 49 |
+
if isinstance(prompts, list):
|
| 50 |
+
if len(prompts) == 1:
|
| 51 |
+
prompts = prompts * batch_size
|
| 52 |
+
elif len(prompts) == batch_size:
|
| 53 |
+
pass
|
| 54 |
+
else:
|
| 55 |
+
raise ValueError(f"invalid prompts length: {len(prompts)}")
|
| 56 |
+
else:
|
| 57 |
+
raise ValueError(f"invalid prompts: {prompts}")
|
| 58 |
+
assert(len(prompts) == batch_size)
|
| 59 |
+
|
| 60 |
+
# content condition: text / class label
|
| 61 |
+
c = model.get_learned_conditioning(prompts)
|
| 62 |
+
key = 'c_concat' if model.conditioning_key == 'concat' else 'c_crossattn'
|
| 63 |
+
c = {key: [c]}
|
| 64 |
+
|
| 65 |
+
# temporal condition: fps
|
| 66 |
+
if getattr(model, 'cond_stage2_config', None) is not None:
|
| 67 |
+
if model.cond_stage2_key == "temporal_context":
|
| 68 |
+
assert(cond_fps is not None)
|
| 69 |
+
batch = {'fps': torch.tensor([cond_fps] * batch_size).long().to(model.device)}
|
| 70 |
+
fps_embd = model.cond_stage2_model(batch)
|
| 71 |
+
c[model.cond_stage2_key] = fps_embd
|
| 72 |
+
|
| 73 |
+
return c
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ------------------------------------------------------------------------------------------
|
| 77 |
+
def make_model_input_shape(model, batch_size, T=None):
|
| 78 |
+
image_size = [model.image_size, model.image_size] if isinstance(model.image_size, int) else model.image_size
|
| 79 |
+
C = model.model.diffusion_model.in_channels
|
| 80 |
+
if T is None:
|
| 81 |
+
T = model.model.diffusion_model.temporal_length
|
| 82 |
+
shape = [batch_size, C, T, *image_size]
|
| 83 |
+
return shape
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ------------------------------------------------------------------------------------------
|
| 87 |
+
def custom_to_pil(x):
|
| 88 |
+
x = x.detach().cpu()
|
| 89 |
+
x = torch.clamp(x, -1., 1.)
|
| 90 |
+
x = (x + 1.) / 2.
|
| 91 |
+
x = x.permute(1, 2, 0).numpy()
|
| 92 |
+
x = (255 * x).astype(np.uint8)
|
| 93 |
+
x = Image.fromarray(x)
|
| 94 |
+
if not x.mode == "RGB":
|
| 95 |
+
x = x.convert("RGB")
|
| 96 |
+
return x
|
| 97 |
+
|
| 98 |
+
def torch_to_np(x):
|
| 99 |
+
# saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
|
| 100 |
+
sample = x.detach().cpu()
|
| 101 |
+
sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
|
| 102 |
+
if sample.dim() == 5:
|
| 103 |
+
sample = sample.permute(0, 2, 3, 4, 1)
|
| 104 |
+
else:
|
| 105 |
+
sample = sample.permute(0, 2, 3, 1)
|
| 106 |
+
sample = sample.contiguous()
|
| 107 |
+
return sample
|
| 108 |
+
|
| 109 |
+
def make_sample_dir(opt, global_step=None, epoch=None):
|
| 110 |
+
if not getattr(opt, 'not_automatic_logdir', False):
|
| 111 |
+
gs_str = f"globalstep{global_step:09}" if global_step is not None else "None"
|
| 112 |
+
e_str = f"epoch{epoch:06}" if epoch is not None else "None"
|
| 113 |
+
ckpt_dir = os.path.join(opt.logdir, f"{gs_str}_{e_str}")
|
| 114 |
+
|
| 115 |
+
# subdir name
|
| 116 |
+
if opt.prompt_file is not None:
|
| 117 |
+
subdir = f"prompts_{os.path.splitext(os.path.basename(opt.prompt_file))[0]}"
|
| 118 |
+
else:
|
| 119 |
+
subdir = f"prompt_{opt.prompt[:10]}"
|
| 120 |
+
subdir += "_DDPM" if opt.vanilla_sample else f"_DDIM{opt.custom_steps}steps"
|
| 121 |
+
subdir += f"_CfgScale{opt.scale}"
|
| 122 |
+
if opt.cond_fps is not None:
|
| 123 |
+
subdir += f"_fps{opt.cond_fps}"
|
| 124 |
+
if opt.seed is not None:
|
| 125 |
+
subdir += f"_seed{opt.seed}"
|
| 126 |
+
|
| 127 |
+
return os.path.join(ckpt_dir, subdir)
|
| 128 |
+
else:
|
| 129 |
+
return opt.logdir
|