ArianatorQualquer committed
Commit 0c8b8fb · 1 Parent(s): a78dfde

Upload streamlit_helpers.py

Files changed (1):
  1. streamlit_helpers.py +887 -0
streamlit_helpers.py ADDED
@@ -0,0 +1,887 @@
+ import copy
+ import math
+ import os
+ from glob import glob
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import cv2
+ import numpy as np
+ import streamlit as st
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as TT
+ from einops import rearrange, repeat
+ from imwatermark import WatermarkEncoder
+ from omegaconf import ListConfig, OmegaConf
+ from PIL import Image
+ from safetensors.torch import load_file as load_safetensors
+ from torch import autocast
+ from torchvision import transforms
+ from torchvision.utils import make_grid, save_image
+
+ from scripts.demo.discretization import (
+     Img2ImgDiscretizationWrapper,
+     Txt2NoisyDiscretizationWrapper,
+ )
+ from scripts.util.detection.nsfw_and_watermark_dectection import (
+     DeepFloydDataFiltering,
+ )
+ from sgm.inference.helpers import embed_watermark
+ from sgm.modules.diffusionmodules.guiders import (
+     LinearPredictionGuider,
+     VanillaCFG,
+ )
+ from sgm.modules.diffusionmodules.sampling import (
+     DPMPP2MSampler,
+     DPMPP2SAncestralSampler,
+     EulerAncestralSampler,
+     EulerEDMSampler,
+     HeunEDMSampler,
+     LinearMultistepSampler,
+ )
+ from sgm.util import append_dims, default, instantiate_from_config
+
+
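+ # Builds and caches the demo's model state; st.cache_resource keeps the
+ # loaded weights in memory across Streamlit reruns instead of reloading them
+ # on every widget interaction.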
+ @st.cache_resource()
+ def init_st(version_dict, load_ckpt=True, load_filter=True):
+     state = dict()
+     if "model" not in state:
+         config = version_dict["config"]
+         ckpt = version_dict["ckpt"]
+
+         config = OmegaConf.load(config)
+         model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
+
+         state["msg"] = msg
+         state["model"] = model
+         state["ckpt"] = ckpt if load_ckpt else None
+         state["config"] = config
+         if load_filter:
+             state["filter"] = DeepFloydDataFiltering(verbose=False)
+     return state
+
+
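+ # Low-VRAM handling: with lowvram_mode enabled, the diffusion core is kept in
+ # fp16 and submodules live on the CPU, getting moved to the GPU only around
+ # the stage that needs them; unload_model moves a module back and frees the
+ # CUDA cache.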
+ def load_model(model):
+     model.cuda()
+
+
+ lowvram_mode = False
+
+
+ def set_lowvram_mode(mode):
+     global lowvram_mode
+     lowvram_mode = mode
+
+
+ def initial_model_load(model):
+     global lowvram_mode
+     if lowvram_mode:
+         model.model.half()
+     else:
+         model.cuda()
+     return model
+
+
+ def unload_model(model):
+     global lowvram_mode
+     if lowvram_mode:
+         model.cpu()
+         torch.cuda.empty_cache()
+
+
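+ # Loads weights either from a Lightning-style .ckpt (state dict nested under
+ # "state_dict") or from a .safetensors file; strict=False tolerates missing
+ # and unexpected keys, which are printed when verbose is set.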
+ def load_model_from_config(config, ckpt=None, verbose=True):
+     model = instantiate_from_config(config.model)
+
+     msg = None
+     if ckpt is not None:
+         print(f"Loading model from {ckpt}")
+         if ckpt.endswith("ckpt"):
+             pl_sd = torch.load(ckpt, map_location="cpu")
+             if "global_step" in pl_sd:
+                 global_step = pl_sd["global_step"]
+                 st.info(f"loaded ckpt from global step {global_step}")
+                 print(f"Global Step: {pl_sd['global_step']}")
+             sd = pl_sd["state_dict"]
+         elif ckpt.endswith("safetensors"):
+             sd = load_safetensors(ckpt)
+         else:
+             raise NotImplementedError
+
+         m, u = model.load_state_dict(sd, strict=False)
+
+         if len(m) > 0 and verbose:
+             print("missing keys:")
+             print(m)
+         if len(u) > 0 and verbose:
+             print("unexpected keys:")
+             print(u)
+
+     model = initial_model_load(model)
+     model.eval()
+     return model, msg
+
+
+ def get_unique_embedder_keys_from_conditioner(conditioner):
+     return list({x.input_key for x in conditioner.embedders})
+
+
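+ # Renders one widget per conditioning key expected by the model's embedders
+ # (prompt, original/target size, crop coordinates, fps, motion bucket, pool
+ # image, ...) and collects the chosen values into a single value_dict.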
+ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
+     # Hardcoded demo settings; might undergo some changes in the future
+
+     value_dict = {}
+     for key in keys:
+         if key == "txt":
+             if prompt is None:
+                 prompt = "A professional photograph of an astronaut riding a pig"
+             if negative_prompt is None:
+                 negative_prompt = ""
+
+             prompt = st.text_input("Prompt", prompt)
+             negative_prompt = st.text_input("Negative prompt", negative_prompt)
+
+             value_dict["prompt"] = prompt
+             value_dict["negative_prompt"] = negative_prompt
+
+         if key == "original_size_as_tuple":
+             orig_width = st.number_input(
+                 "orig_width",
+                 value=init_dict["orig_width"],
+                 min_value=16,
+             )
+             orig_height = st.number_input(
+                 "orig_height",
+                 value=init_dict["orig_height"],
+                 min_value=16,
+             )
+
+             value_dict["orig_width"] = orig_width
+             value_dict["orig_height"] = orig_height
+
+         if key == "crop_coords_top_left":
+             crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
+             crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
+
+             value_dict["crop_coords_top"] = crop_coord_top
+             value_dict["crop_coords_left"] = crop_coord_left
+
+         if key == "aesthetic_score":
+             value_dict["aesthetic_score"] = 6.0
+             value_dict["negative_aesthetic_score"] = 2.5
+
+         if key == "target_size_as_tuple":
+             value_dict["target_width"] = init_dict["target_width"]
+             value_dict["target_height"] = init_dict["target_height"]
+
+         if key in ["fps_id", "fps"]:
+             fps = st.number_input("fps", value=6, min_value=1)
+
+             value_dict["fps"] = fps
+             value_dict["fps_id"] = fps - 1
+
+         if key == "motion_bucket_id":
+             mb_id = st.number_input("motion bucket id", 0, 511, value=127)
+             value_dict["motion_bucket_id"] = mb_id
+
+         if key == "pool_image":
+             st.text("Image for pool conditioning")
+             image = load_img(
+                 key="pool_image_input",
+                 size=224,
+                 center_crop=True,
+             )
+             if image is None:
+                 st.info("Need an image here")
+                 image = torch.zeros(1, 3, 224, 224)
+             value_dict["pool_image"] = image
+
+     return value_dict
+
+
+ def perform_save_locally(save_path, samples):
+     os.makedirs(save_path, exist_ok=True)
+     base_count = len(os.listdir(save_path))
+     samples = embed_watermark(samples)
+     for sample in samples:
+         sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
+         Image.fromarray(sample.astype(np.uint8)).save(
+             os.path.join(save_path, f"{base_count:09}.png")
+         )
+         base_count += 1
+
+
+ def init_save_locally(_dir, init_value: bool = False):
+     save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
+     if save_locally:
+         save_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
+     else:
+         save_path = None
+
+     return save_locally, save_path
+
+
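+ # Builds the guider config for classifier-free guidance: IdentityGuider
+ # applies no guidance, VanillaCFG uses one fixed cfg-scale, and
+ # LinearPredictionGuider (used by the video models) ramps the scale from
+ # min_scale to max_scale across the num_frames axis.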
+ def get_guider(options, key):
+     guider = st.sidebar.selectbox(
+         f"Guider #{key}",
+         [
+             "VanillaCFG",
+             "IdentityGuider",
+             "LinearPredictionGuider",
+         ],
+         options.get("guider", 0),
+     )
+
+     additional_guider_kwargs = options.pop("additional_guider_kwargs", {})
+
+     if guider == "IdentityGuider":
+         guider_config = {
+             "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
+         }
+     elif guider == "VanillaCFG":
+         scale = st.number_input(
+             f"cfg-scale #{key}",
+             value=options.get("cfg", 5.0),
+             min_value=0.0,
+         )
+
+         guider_config = {
+             "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
+             "params": {
+                 "scale": scale,
+                 **additional_guider_kwargs,
+             },
+         }
+     elif guider == "LinearPredictionGuider":
+         max_scale = st.number_input(
+             f"max-cfg-scale #{key}",
+             value=options.get("cfg", 1.5),
+             min_value=1.0,
+         )
+         min_scale = st.number_input(
+             f"min guidance scale #{key}",
+             value=options.get("min_cfg", 1.0),
+             min_value=1.0,
+             max_value=10.0,
+         )
+
+         guider_config = {
+             "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
+             "params": {
+                 "max_scale": max_scale,
+                 "min_scale": min_scale,
+                 "num_frames": options["num_frames"],
+                 **additional_guider_kwargs,
+             },
+         }
+     else:
+         raise NotImplementedError
+     return guider_config
+
+
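+ # Assembles the sampler from the sidebar choices (steps, sampler class,
+ # discretization, guider). For img2img the discretization is wrapped so only
+ # the last `strength` fraction of the noise schedule is traversed;
+ # stage2strength applies the analogous wrapper for two-stage (refiner)
+ # pipelines.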
+ def init_sampling(
+     key=1,
+     img2img_strength: Optional[float] = None,
+     specify_num_samples: bool = True,
+     stage2strength: Optional[float] = None,
+     options: Optional[Dict[str, int]] = None,
+ ):
+     options = {} if options is None else options
+
+     num_rows, num_cols = 1, 1
+     if specify_num_samples:
+         num_cols = st.number_input(
+             f"num cols #{key}", value=num_cols, min_value=1, max_value=10
+         )
+
+     steps = st.sidebar.number_input(
+         f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000
+     )
+     sampler = st.sidebar.selectbox(
+         f"Sampler #{key}",
+         [
+             "EulerEDMSampler",
+             "HeunEDMSampler",
+             "EulerAncestralSampler",
+             "DPMPP2SAncestralSampler",
+             "DPMPP2MSampler",
+             "LinearMultistepSampler",
+         ],
+         options.get("sampler", 0),
+     )
+     discretization = st.sidebar.selectbox(
+         f"Discretization #{key}",
+         [
+             "LegacyDDPMDiscretization",
+             "EDMDiscretization",
+         ],
+         options.get("discretization", 0),
+     )
+
+     discretization_config = get_discretization(discretization, options=options, key=key)
+
+     guider_config = get_guider(options=options, key=key)
+
+     sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
+     if img2img_strength is not None:
+         st.warning(
+             f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
+         )
+         sampler.discretization = Img2ImgDiscretizationWrapper(
+             sampler.discretization, strength=img2img_strength
+         )
+     if stage2strength is not None:
+         sampler.discretization = Txt2NoisyDiscretizationWrapper(
+             sampler.discretization, strength=stage2strength, original_steps=steps
+         )
+     return sampler, num_rows, num_cols
+
+
+ def get_discretization(discretization, options, key=1):
+     if discretization == "LegacyDDPMDiscretization":
+         discretization_config = {
+             "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
+         }
+     elif discretization == "EDMDiscretization":
+         sigma_min = st.number_input(
+             f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
+         )  # 0.0292
+         sigma_max = st.number_input(
+             f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
+         )  # 14.6146
+         rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0))
+         discretization_config = {
+             "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
+             "params": {
+                 "sigma_min": sigma_min,
+                 "sigma_max": sigma_max,
+                 "rho": rho,
+             },
+         }
+     else:
+         raise ValueError(f"unknown discretization {discretization}!")
+
+     return discretization_config
+
+
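+ # Instantiates the selected sampler; the EDM samplers additionally expose the
+ # stochastic churn parameters (s_churn, s_tmin, s_tmax, s_noise) from the
+ # Karras et al. EDM formulation.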
+ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
+     if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
+         s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
+         s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
+         s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
+         s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
+
+         if sampler_name == "EulerEDMSampler":
+             sampler = EulerEDMSampler(
+                 num_steps=steps,
+                 discretization_config=discretization_config,
+                 guider_config=guider_config,
+                 s_churn=s_churn,
+                 s_tmin=s_tmin,
+                 s_tmax=s_tmax,
+                 s_noise=s_noise,
+                 verbose=True,
+             )
+         elif sampler_name == "HeunEDMSampler":
+             sampler = HeunEDMSampler(
+                 num_steps=steps,
+                 discretization_config=discretization_config,
+                 guider_config=guider_config,
+                 s_churn=s_churn,
+                 s_tmin=s_tmin,
+                 s_tmax=s_tmax,
+                 s_noise=s_noise,
+                 verbose=True,
+             )
+     elif (
+         sampler_name == "EulerAncestralSampler"
+         or sampler_name == "DPMPP2SAncestralSampler"
+     ):
+         s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
+         eta = st.sidebar.number_input(f"eta #{key}", value=1.0, min_value=0.0)
+
+         if sampler_name == "EulerAncestralSampler":
+             sampler = EulerAncestralSampler(
+                 num_steps=steps,
+                 discretization_config=discretization_config,
+                 guider_config=guider_config,
+                 eta=eta,
+                 s_noise=s_noise,
+                 verbose=True,
+             )
+         elif sampler_name == "DPMPP2SAncestralSampler":
+             sampler = DPMPP2SAncestralSampler(
+                 num_steps=steps,
+                 discretization_config=discretization_config,
+                 guider_config=guider_config,
+                 eta=eta,
+                 s_noise=s_noise,
+                 verbose=True,
+             )
+     elif sampler_name == "DPMPP2MSampler":
+         sampler = DPMPP2MSampler(
+             num_steps=steps,
+             discretization_config=discretization_config,
+             guider_config=guider_config,
+             verbose=True,
+         )
+     elif sampler_name == "LinearMultistepSampler":
+         order = st.sidebar.number_input(f"order #{key}", value=4, min_value=1)
+         sampler = LinearMultistepSampler(
+             num_steps=steps,
+             discretization_config=discretization_config,
+             guider_config=guider_config,
+             order=order,
+             verbose=True,
+         )
+     else:
+         raise ValueError(f"unknown sampler {sampler_name}!")
+
+     return sampler
+
+
+ def get_interactive_image(key=None) -> Optional[Image.Image]:
+     image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
+     if image is not None:
+         image = Image.open(image)
+         if image.mode != "RGB":
+             image = image.convert("RGB")
+         return image
+
+
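+ # Loads an uploaded image as a 1xCxHxW tensor scaled to [-1, 1], optionally
+ # resizing and center-cropping it first.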
+ def load_img(
+     display: bool = True,
+     size: Union[None, int, Tuple[int, int]] = None,
+     center_crop: bool = False,
+     key=None,
+ ):
+     image = get_interactive_image(key=key)
+     if image is None:
+         return None
+     if display:
+         st.image(image)
+     w, h = image.size
+     print(f"loaded input image of size ({w}, {h})")
+
+     transform = []
+     if size is not None:
+         transform.append(transforms.Resize(size))
+     if center_crop:
+         transform.append(transforms.CenterCrop(size))
+     transform.append(transforms.ToTensor())
+     transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))
+
+     transform = transforms.Compose(transform)
+     img = transform(image)[None, ...]
+     st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
+     return img
+
+
+ def get_init_img(batch_size=1, key=None):
+     init_image = load_img(key=key).cuda()
+     init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
+     return init_image
+
+
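+ # Core txt2img / txt2video sampling loop: build the conditional and
+ # unconditional batches, sample latents of shape
+ # (prod(num_samples), C, H // F, W // F), decode them with the first-stage
+ # autoencoder, and display the result as an image grid (or as per-sample
+ # frame grids when T video frames are requested).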
+ def do_sample(
+     model,
+     sampler,
+     value_dict,
+     num_samples,
+     H,
+     W,
+     C,
+     F,
+     force_uc_zero_embeddings: Optional[List] = None,
+     force_cond_zero_embeddings: Optional[List] = None,
+     batch2model_input: Optional[List] = None,
+     return_latents=False,
+     filter=None,
+     T=None,
+     additional_batch_uc_fields=None,
+     decoding_t=None,
+ ):
+     force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
+     batch2model_input = default(batch2model_input, [])
+     additional_batch_uc_fields = default(additional_batch_uc_fields, [])
+
+     st.text("Sampling")
+
+     outputs = st.empty()
+     precision_scope = autocast
+     with torch.no_grad():
+         with precision_scope("cuda"):
+             with model.ema_scope():
+                 if T is not None:
+                     num_samples = [num_samples, T]
+                 else:
+                     num_samples = [num_samples]
+
+                 load_model(model.conditioner)
+                 batch, batch_uc = get_batch(
+                     get_unique_embedder_keys_from_conditioner(model.conditioner),
+                     value_dict,
+                     num_samples,
+                     T=T,
+                     additional_batch_uc_fields=additional_batch_uc_fields,
+                 )
+
+                 c, uc = model.conditioner.get_unconditional_conditioning(
+                     batch,
+                     batch_uc=batch_uc,
+                     force_uc_zero_embeddings=force_uc_zero_embeddings,
+                     force_cond_zero_embeddings=force_cond_zero_embeddings,
+                 )
+                 unload_model(model.conditioner)
+
+                 for k in c:
+                     if k != "crossattn":
+                         c[k], uc[k] = map(
+                             lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
+                         )
+                     if k in ["crossattn", "concat"] and T is not None:
+                         uc[k] = repeat(uc[k], "b ... -> b t ...", t=T)
+                         uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T)
+                         c[k] = repeat(c[k], "b ... -> b t ...", t=T)
+                         c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T)
+
+                 additional_model_inputs = {}
+                 for k in batch2model_input:
+                     if k == "image_only_indicator":
+                         assert T is not None
+
+                         if isinstance(
+                             sampler.guider, (VanillaCFG, LinearPredictionGuider)
+                         ):
+                             additional_model_inputs[k] = torch.zeros(
+                                 num_samples[0] * 2, num_samples[1]
+                             ).to("cuda")
+                         else:
+                             additional_model_inputs[k] = torch.zeros(num_samples).to(
+                                 "cuda"
+                             )
+                     else:
+                         additional_model_inputs[k] = batch[k]
+
+                 shape = (math.prod(num_samples), C, H // F, W // F)
+                 randn = torch.randn(shape).to("cuda")
+
+                 def denoiser(input, sigma, c):
+                     return model.denoiser(
+                         model.model, input, sigma, c, **additional_model_inputs
+                     )
+
+                 load_model(model.denoiser)
+                 load_model(model.model)
+                 samples_z = sampler(denoiser, randn, cond=c, uc=uc)
+                 unload_model(model.model)
+                 unload_model(model.denoiser)
+
+                 load_model(model.first_stage_model)
+                 model.en_and_decode_n_samples_a_time = decoding_t  # Decode n frames at a time
+                 samples_x = model.decode_first_stage(samples_z)
+                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+                 unload_model(model.first_stage_model)
+
+                 if filter is not None:
+                     samples = filter(samples)
+
+                 if T is None:
+                     grid = torch.stack([samples])
+                     grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
+                     outputs.image(grid.cpu().numpy())
+                 else:
+                     as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T)
+                     for i, vid in enumerate(as_vids):
+                         grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c")
+                         st.image(
+                             grid.cpu().numpy(),
+                             f"Sample #{i} as image",
+                         )
+
+                 if return_latents:
+                     return samples, samples_z
+                 return samples
+
+
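+ # Turns value_dict into the conditional batch and its unconditional
+ # counterpart batch_uc; tensor entries missing from batch_uc are cloned from
+ # batch so both sides carry the same keys.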
+ def get_batch(
+     keys,
+     value_dict: dict,
+     N: Union[List, ListConfig],
+     device: str = "cuda",
+     T: Optional[int] = None,
+     additional_batch_uc_fields: Optional[List[str]] = None,
+ ):
+     # Hardcoded demo setups; might undergo some changes in the future
+     additional_batch_uc_fields = default(additional_batch_uc_fields, [])
+
+     batch = {}
+     batch_uc = {}
+
+     for key in keys:
+         if key == "txt":
+             batch["txt"] = [value_dict["prompt"]] * math.prod(N)
+             batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N)
+         elif key == "original_size_as_tuple":
+             batch["original_size_as_tuple"] = (
+                 torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
+                 .to(device)
+                 .repeat(math.prod(N), 1)
+             )
+         elif key == "crop_coords_top_left":
+             batch["crop_coords_top_left"] = (
+                 torch.tensor(
+                     [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
+                 )
+                 .to(device)
+                 .repeat(math.prod(N), 1)
+             )
+         elif key == "aesthetic_score":
+             batch["aesthetic_score"] = (
+                 torch.tensor([value_dict["aesthetic_score"]])
+                 .to(device)
+                 .repeat(math.prod(N), 1)
+             )
+             batch_uc["aesthetic_score"] = (
+                 torch.tensor([value_dict["negative_aesthetic_score"]])
+                 .to(device)
+                 .repeat(math.prod(N), 1)
+             )
+         elif key == "target_size_as_tuple":
+             batch["target_size_as_tuple"] = (
+                 torch.tensor([value_dict["target_height"], value_dict["target_width"]])
+                 .to(device)
+                 .repeat(math.prod(N), 1)
+             )
+         elif key == "fps":
+             batch[key] = (
+                 torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
+             )
+         elif key == "fps_id":
+             batch[key] = (
+                 torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
+             )
+         elif key == "motion_bucket_id":
+             batch[key] = (
+                 torch.tensor([value_dict["motion_bucket_id"]])
+                 .to(device)
+                 .repeat(math.prod(N))
+             )
+         elif key == "pool_image":
+             batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
+                 device, dtype=torch.half
+             )
+         elif key == "cond_aug":
+             batch[key] = repeat(
+                 torch.tensor([value_dict["cond_aug"]]).to("cuda"),
+                 "1 -> b",
+                 b=math.prod(N),
+             )
+         elif key == "cond_frames":
+             batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
+         elif key == "cond_frames_without_noise":
+             batch[key] = repeat(
+                 value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
+             )
+         else:
+             batch[key] = value_dict[key]
+
+     if T is not None:
+         batch["num_video_frames"] = T
+
+     for key in batch.keys():
+         if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+             batch_uc[key] = torch.clone(batch[key])
+         elif key in additional_batch_uc_fields and key not in batch_uc:
+             batch_uc[key] = copy.copy(batch[key])
+     return batch, batch_uc
+
+
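+ # img2img: encode the input to a latent z, noise it up to the first sigma of
+ # the (possibly strength-wrapped) discretization, then denoise from there.
+ # The division by sqrt(1 + sigma^2) matches the DDPM-like latent scaling
+ # flagged in the inline comment below.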
+ @torch.no_grad()
+ def do_img2img(
+     img,
+     model,
+     sampler,
+     value_dict,
+     num_samples,
+     force_uc_zero_embeddings: Optional[List] = None,
+     force_cond_zero_embeddings: Optional[List] = None,
+     additional_kwargs: Optional[Dict] = None,
+     offset_noise_level: float = 0.0,
+     return_latents=False,
+     skip_encode=False,
+     filter=None,
+     add_noise=True,
+ ):
+     additional_kwargs = default(additional_kwargs, {})
+
+     st.text("Sampling")
+
+     outputs = st.empty()
+     precision_scope = autocast
+     with precision_scope("cuda"):
+         with model.ema_scope():
+             load_model(model.conditioner)
+             batch, batch_uc = get_batch(
+                 get_unique_embedder_keys_from_conditioner(model.conditioner),
+                 value_dict,
+                 [num_samples],
+             )
+             c, uc = model.conditioner.get_unconditional_conditioning(
+                 batch,
+                 batch_uc=batch_uc,
+                 force_uc_zero_embeddings=force_uc_zero_embeddings,
+                 force_cond_zero_embeddings=force_cond_zero_embeddings,
+             )
+             unload_model(model.conditioner)
+             for k in c:
+                 c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
+
+             for k in additional_kwargs:
+                 c[k] = uc[k] = additional_kwargs[k]
+             if skip_encode:
+                 z = img
+             else:
+                 load_model(model.first_stage_model)
+                 z = model.encode_first_stage(img)
+                 unload_model(model.first_stage_model)
+
+             noise = torch.randn_like(z)
+
+             sigmas = sampler.discretization(sampler.num_steps).cuda()
+             sigma = sigmas[0]
+
+             st.info(f"all sigmas: {sigmas}")
+             st.info(f"noising sigma: {sigma}")
+             if offset_noise_level > 0.0:
+                 noise = noise + offset_noise_level * append_dims(
+                     torch.randn(z.shape[0], device=z.device), z.ndim
+                 )
+             if add_noise:
+                 noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
+                 noised_z = noised_z / torch.sqrt(
+                     1.0 + sigmas[0] ** 2.0
+                 )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
+             else:
+                 noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
+
+             def denoiser(x, sigma, c):
+                 return model.denoiser(model.model, x, sigma, c)
+
+             load_model(model.denoiser)
+             load_model(model.model)
+             samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
+             unload_model(model.model)
+             unload_model(model.denoiser)
+
+             load_model(model.first_stage_model)
+             samples_x = model.decode_first_stage(samples_z)
+             unload_model(model.first_stage_model)
+             samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+             if filter is not None:
+                 samples = filter(samples)
+
+             grid = rearrange(torch.stack([samples]), "n b c h w -> (n h) (b w) c")
+             outputs.image(grid.cpu().numpy())
+             if return_latents:
+                 return samples, samples_z
+             return samples
+
+
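+ # Returns the scale factor that makes current_shape cover desired_shape while
+ # preserving aspect ratio (the excess is cropped off afterwards), handling
+ # the portrait/landscape combinations case by case.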
+ def get_resizing_factor(
+     desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
+ ) -> float:
+     r_bound = desired_shape[1] / desired_shape[0]
+     aspect_r = current_shape[1] / current_shape[0]
+     if r_bound >= 1.0:
+         if aspect_r >= r_bound:
+             factor = min(desired_shape) / min(current_shape)
+         elif aspect_r < 1.0:
+             factor = max(desired_shape) / min(current_shape)
+         else:
+             factor = max(desired_shape) / max(current_shape)
+     else:
+         if aspect_r <= r_bound:
+             factor = min(desired_shape) / min(current_shape)
+         elif aspect_r > 1:
+             factor = max(desired_shape) / min(current_shape)
+         else:
+             factor = max(desired_shape) / max(current_shape)
+
+     return factor
+
+
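+ # Resizes (area interpolation) and center-crops the upload to exactly H x W,
+ # then maps it to [-1, 1] on the target device, the input range the video
+ # models expect.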
+ def load_img_for_prediction(
+     W: int, H: int, display=True, key=None, device="cuda"
+ ) -> Optional[torch.Tensor]:
+     image = get_interactive_image(key=key)
+     if image is None:
+         return None
+     if display:
+         st.image(image)
+     w, h = image.size
+
+     image = np.array(image).transpose(2, 0, 1)
+     image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
+     image = image.unsqueeze(0)
+
+     rfs = get_resizing_factor((H, W), (h, w))
+     resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
+     top = (resize_size[0] - H) // 2
+     left = (resize_size[1] - W) // 2
+
+     image = torch.nn.functional.interpolate(
+         image, resize_size, mode="area", antialias=False
+     )
+     image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
+
+     if display:
+         numpy_img = np.transpose(image[0].numpy(), (1, 2, 0))
+         pil_image = Image.fromarray((numpy_img * 255).astype(np.uint8))
+         st.image(pil_image)
+     return image.to(device) * 2.0 - 1.0
+
+
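+ # For each video: save a PNG contact sheet, write an MP4V-encoded .mp4 with
+ # OpenCV, then re-encode it to H.264 via the ffmpeg CLI so Streamlit can play
+ # it inline. Assumes ffmpeg is available on PATH.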
+ def save_video_as_grid_and_mp4(
+     video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
+ ):
+     os.makedirs(save_path, exist_ok=True)
+     base_count = len(glob(os.path.join(save_path, "*.mp4")))
+
+     video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
+     video_batch = embed_watermark(video_batch)
+     for vid in video_batch:
+         save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)
+
+         video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
+
+         writer = cv2.VideoWriter(
+             video_path,
+             cv2.VideoWriter_fourcc(*"MP4V"),
+             fps,
+             (vid.shape[-1], vid.shape[-2]),
+         )
+
+         vid = (
+             (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
+         )
+         for frame in vid:
+             frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+             writer.write(frame)
+
+         writer.release()
+
+         video_path_h264 = video_path[:-4] + "_h264.mp4"
+         os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}")
+
+         with open(video_path_h264, "rb") as f:
+             video_bytes = f.read()
+         st.video(video_bytes)
+
+         base_count += 1