acmyu commited on
Commit
317a82a
·
1 Parent(s): fb5b0b1
Files changed (1) hide show
  1. app.py +28 -1178
app.py CHANGED
@@ -1,1196 +1,46 @@
1
- import logging
2
- import math
3
- import os
4
- from typing import Any, Dict, List, Optional, Tuple, Union
5
- from diffusers.models.controlnet import ControlNetConditioningEmbedding
6
- import torch
7
- from torch import nn
8
- import torch.nn.functional as F
9
- import torch.utils.checkpoint
10
- import transformers
11
- from accelerate import Accelerator
12
- from accelerate.logging import get_logger
13
- from accelerate.utils import ProjectConfiguration, set_seed
14
-
15
- from tqdm.auto import tqdm
16
- from src.configs.stage2_config import args
17
-
18
- import diffusers
19
- from diffusers import (
20
- AutoencoderKL,
21
- DDPMScheduler,
22
- )
23
- from diffusers.optimization import get_scheduler
24
- from diffusers.utils import check_min_version, is_wandb_available
25
- from src.dataset.stage2_dataset import InpaintDataset, InpaintCollate_fn
26
- from transformers import CLIPVisionModelWithProjection
27
- from transformers import Dinov2Model
28
- from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel
29
-
30
-
31
-
32
- import glob
33
- import os
34
- import torch
35
- from torch import nn
36
- from PIL import Image, ImageOps
37
- import numpy as np
38
- from diffusers import UniPCMultistepScheduler
39
- from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel
40
-
41
- from torchvision import transforms
42
- from diffusers.models.controlnet import ControlNetConditioningEmbedding
43
- from transformers import CLIPImageProcessor
44
- from transformers import Dinov2Model
45
- from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel,ControlNetModel,DDIMScheduler
46
- from src.pipelines.PCDMs_pipeline import PCDMsPipeline
47
- #from single_extract_pose import inference_pose
48
-
49
 
50
  import spaces
51
- from easy_dwpose import DWposeDetector
52
  from PIL import Image
53
  import cv2
54
  import os
55
  import gradio as gr
56
- import rembg
57
- import uuid
58
- import gc
59
- from numba import cuda
60
- import requests
61
- import uuid
62
-
63
- from huggingface_hub import hf_hub_download, HfApi
64
-
65
-
66
- # Inputs ===================================================================================================
67
-
68
- input_img = "sm.png"
69
- train_imgs = ["target.png"]
70
- in_vid = "walk.mp4"
71
- out_vid = 'out.mp4'
72
-
73
- """
74
- train_steps = 100
75
- inference_steps = 10
76
- fps = 12
77
- """
78
-
79
- debug = False
80
- save_model = True
81
- should_gen_vid = False
82
- max_batch_size = 8
83
-
84
-
85
- def save_temp_imgs(imgs):
86
- os.makedirs('temp', exist_ok=True)
87
- results = []
88
-
89
- api = HfApi()
90
-
91
-
92
- for i, img in enumerate(imgs):
93
-
94
- #img_name = 'temp/'+str(uuid.uuid4())+'.png'
95
- img_name = 'temp/'+str(i)+'.png'
96
- img.save(img_name)
97
-
98
- """
99
- url = 'https://tmpfiles.org/api/v1/upload'
100
-
101
- try:
102
- response = requests.post(url, files={'file': open(img_name, 'rb')})
103
-
104
- # Check for successful response (status code 200)
105
- response.raise_for_status()
106
-
107
- # Print the server's response
108
- print("Status Code:", response.status_code)
109
-
110
- data = response.json()
111
- print("Response JSON:", data)
112
- results.append(data['data']['url'])
113
-
114
- except requests.exceptions.RequestException as e:
115
- print(f"An error occurred: {e}")
116
- """
117
-
118
- results.append('https://huggingface.co/datasets/acmyu/KeyframesAIFiles/resolve/main/'+img_name)
119
-
120
- api.upload_file(
121
- path_or_fileobj='temp',
122
- path_in_repo='temp',
123
- repo_id="acmyu/KeyframesAIFiles",
124
- repo_type="dataset",
125
- )
126
-
127
- return results
128
-
129
-
130
- def getThumbnails(imgs):
131
- thumbs = []
132
- thumb_size = (512, 512)
133
- for img in imgs:
134
- th = img.copy()
135
- th.thumbnail(thumb_size)
136
- thumbs.append(th)
137
- return thumbs
138
-
139
-
140
- # Pose detection ==============================================================================================
141
-
142
- def load_models():
143
- dwpose = DWposeDetector(device="cpu")
144
- rembg_session = rembg.new_session("u2netp")
145
-
146
- pcdms_model = hf_hub_download(repo_id="acmyu/PCDMs", filename="pcdms_ckpt.pt")
147
-
148
- # Load scheduler
149
- noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
150
-
151
- # Load model
152
- image_encoder_p = Dinov2Model.from_pretrained('facebook/dinov2-giant')
153
- image_encoder_g = CLIPVisionModelWithProjection.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')#("openai/clip-vit-base-patch32")
154
-
155
- vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae")
156
- unet = Stage2_InapintUNet2DConditionModel.from_pretrained(
157
- "stabilityai/stable-diffusion-2-1-base",
158
- torch_dtype=torch.float16,
159
- subfolder="unet",
160
- in_channels=9,
161
- low_cpu_mem_usage=False,
162
- ignore_mismatched_sizes=True)
163
-
164
-
165
- return dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet
166
-
167
-
168
- #load_models()
169
-
170
- def img_pad(img, tw, th, transparent=False):
171
- img.thumbnail((tw, th))
172
- if transparent:
173
- new_img = Image.new('RGBA', (tw, th), (0, 0, 0, 0))
174
- else:
175
- new_img = Image.new("RGB", (tw, th), (0, 0, 0))
176
- left = (tw - img.width) // 2
177
- top = (th - img.height) // 2
178
- new_img.paste(img, (left, top))
179
- return new_img
180
-
181
-
182
- def resize_and_pad(img, target_img):
183
- tw, th = target_img.size
184
- w, h = img.size
185
-
186
- if tw/th > w/h:
187
- tw = int(th * w/h)
188
- elif tw/th < w/h:
189
- th = int(tw * h/w)
190
-
191
- img = img.resize((tw, th), Image.BICUBIC)
192
-
193
- tw, th = target_img.size
194
-
195
- return img_pad(img, tw, th)
196
-
197
-
198
- def remove_zero_pad(image):
199
- image = np.array(image)
200
- dummy = np.argwhere(image != 0) # assume blackground is zero
201
- max_y = dummy[:, 0].max()
202
- min_y = dummy[:, 0].min()
203
- min_x = dummy[:, 1].min()
204
- max_x = dummy[:, 1].max()
205
- crop_image = image[min_y:max_y, min_x:max_x]
206
-
207
- return Image.fromarray(crop_image)
208
-
209
-
210
- def get_pose(img, dwpose, outfile, crop=False):
211
- #pil_image = Image.open("imgs/"+img).convert("RGB")
212
- #skeleton = dwpose(pil_image, output_type="np", include_hands=True, include_face=False)
213
-
214
- #img.thumbnail((512,512))
215
- out_img = dwpose(img, include_hands=True, include_face=False)
216
-
217
- #print(pose['bodies'])
218
-
219
- if crop:
220
- bbox = out_img.getbbox()
221
- out_img = out_img.crop(bbox)
222
- out_img = ImageOps.expand(out_img, border=int(out_img.width*0.2), fill=(0,0,0))
223
-
224
- return out_img
225
-
226
-
227
- def extract_frames(video_path, fps):
228
- video_capture = cv2.VideoCapture(video_path)
229
- frame_count = 0
230
- frames = []
231
-
232
- fps_in = video_capture.get(cv2.CAP_PROP_FPS)
233
- fps_out = fps
234
-
235
- index_in = -1
236
- index_out = -1
237
-
238
- while True:
239
- success = video_capture.grab()
240
- if not success: break
241
- index_in += 1
242
-
243
- out_due = int(index_in / fps_in * fps_out)
244
- if out_due > index_out:
245
- success, frame = video_capture.retrieve()
246
- if not success:
247
- break
248
- index_out += 1
249
-
250
- frame_count += 1
251
- frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
252
-
253
- video_capture.release()
254
- print(f"Extracted {frame_count} frames")
255
- return frames
256
-
257
-
258
- def removebg(img, rembg_session, transparent=False):
259
-
260
- if transparent:
261
- result = Image.new('RGBA', img.size, (0, 0, 0, 0))
262
- else:
263
- result = Image.new("RGB", img.size, "#ffffff")
264
- out = rembg.remove(img, session=rembg_session)
265
- result.paste(out, mask=out)
266
- return result
267
-
268
 
269
- def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
270
- print("remove background", bg_remove)
271
- if bg_remove:
272
- images = [removebg(img, rembg_session) for img in images]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
- in_img = images[0]
275
- in_pose = get_pose(in_img, dwpose, "in_pose.png")
276
- train_poses = []
277
- train_imgs = [resize_and_pad(img, in_img) for img in images[1:]]
278
-
279
- for i, img in enumerate(train_imgs):
280
- train_poses.append(get_pose(img, dwpose, "tr_pose"+str(i)+".png"))
281
-
282
- return in_img, in_pose, train_imgs, train_poses
283
-
284
-
285
- def prepare_inputs_inference(in_img, in_vid, fps, dwpose, rembg_session, bg_remove, resize='target', is_app=False):
286
- progress=gr.Progress(track_tqdm=True)
287
-
288
- print("prepare_inputs_inference")
289
-
290
- in_pose = get_pose(in_img, dwpose, "in_pose.png")
291
-
292
- frames = extract_frames(in_vid, fps)
293
- print("remove background", bg_remove)
294
- if bg_remove:
295
- in_img = removebg(in_img, rembg_session)
296
- #frames = [removebg(img, rembg_session) for img in frames]
297
- if debug:
298
- for i, frame in enumerate(frames):
299
- frame.save("out/frame_"+str(i)+".png")
300
-
301
- print("vid: ", in_vid, fps)
302
-
303
- progress_bar = tqdm(range(len(frames)), initial=0, desc="Frames")
304
- target_poses = []
305
- max_left = max_top = 999999
306
- max_right = max_bottom = 0
307
- it = frames
308
- if is_app:
309
- it = progress.tqdm(frames, desc="Pose Detection")
310
- for f in it:
311
- tpose = get_pose(f, dwpose, "tar_pose"+str(len(target_poses))+".png")
312
- target_poses.append(tpose)
313
- progress_bar.update(1)
314
-
315
- bbox = tpose.getbbox()
316
- left, top, right, bottom = bbox
317
- max_left = min(max_left, left)
318
- max_top = min(max_top, top)
319
- max_right = max(max_right, right)
320
- max_bottom = max(max_bottom, bottom)
321
-
322
- target_poses_cropped = []
323
- for tpose in target_poses:
324
- if resize=='target':
325
- tpose = tpose.crop((max_left, max_top, max_right, max_bottom))
326
- tpose = ImageOps.expand(tpose, border=int(tpose.width*0.2), fill=(0,0,0))
327
-
328
- tpose = resize_and_pad(tpose, in_img)
329
-
330
-
331
- if debug:
332
- tpose.save("out/"+"tar_pose"+str(len(target_poses_cropped))+".png")
333
- target_poses_cropped.append(tpose)
334
-
335
- return in_img, target_poses_cropped, in_pose
336
-
337
-
338
- def prepare_inputs(images, in_vid, fps, bg_remove, dwpose, rembg_session, resize='target', is_app=False):
339
-
340
- in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
341
-
342
- in_img, target_poses_cropped, _ = prepare_inputs_inference(in_img, in_vid, fps, dwpose, rembg_session, bg_remove, resize, is_app)
343
-
344
-
345
- return in_img, in_pose, train_imgs, train_poses, target_poses_cropped
346
-
347
-
348
- # Training ===================================================================================================
349
-
350
- # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
351
- check_min_version("0.18.0.dev0")
352
-
353
- logger = get_logger(__name__)
354
-
355
-
356
- class ImageProjModel_p(torch.nn.Module):
357
- """SD model with image prompt"""
358
-
359
- def __init__(self, in_dim, hidden_dim, out_dim, dropout = 0.):
360
- super().__init__()
361
-
362
- self.net = nn.Sequential(
363
- nn.Linear(in_dim, hidden_dim),
364
- nn.GELU(),
365
- nn.Dropout(dropout),
366
- nn.LayerNorm(hidden_dim),
367
- nn.Linear(hidden_dim, out_dim),
368
- nn.Dropout(dropout)
369
- )
370
-
371
- def forward(self, x):
372
- return self.net(x)
373
-
374
- class ImageProjModel_g(torch.nn.Module):
375
- """SD model with image prompt"""
376
-
377
- def __init__(self, in_dim, hidden_dim, out_dim, dropout = 0.):
378
- super().__init__()
379
-
380
- self.net = nn.Sequential(
381
- nn.Linear(in_dim, hidden_dim),
382
- nn.GELU(),
383
- nn.Dropout(dropout),
384
- nn.LayerNorm(hidden_dim),
385
- nn.Linear(hidden_dim, out_dim),
386
- nn.Dropout(dropout)
387
- )
388
-
389
- def forward(self, x): # b, 257,1280
390
- return self.net(x)
391
-
392
-
393
- class SDModel(torch.nn.Module):
394
- """SD model with image prompt"""
395
- def __init__(self, unet) -> None:
396
- super().__init__()
397
- self.image_proj_model_p = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024)
398
-
399
- self.unet = unet
400
- self.pose_proj = ControlNetConditioningEmbedding(
401
- conditioning_embedding_channels=320,
402
- block_out_channels=(16, 32, 96, 256),
403
- conditioning_channels=3)
404
-
405
-
406
- def forward(self, noisy_latents, timesteps, simg_f_p, timg_f_g, pose_f):
407
-
408
- extra_image_embeddings_p = self.image_proj_model_p(simg_f_p)
409
- extra_image_embeddings_g = timg_f_g
410
-
411
- print(extra_image_embeddings_p.size())
412
- print(extra_image_embeddings_g.size())
413
-
414
- encoder_image_hidden_states = torch.cat([extra_image_embeddings_p ,extra_image_embeddings_g], dim=1)
415
- pose_cond = self.pose_proj(pose_f)
416
-
417
- pred_noise = self.unet(noisy_latents, timesteps, class_labels=timg_f_g, encoder_hidden_states=encoder_image_hidden_states,my_pose_cond=pose_cond).sample
418
- return pred_noise
419
-
420
- def load_training_checkpoint(model, pcdms_model, tag=None, **kwargs):
421
- #model_sd = torch.load(load_dir, map_location="cpu")["module"]
422
- model_sd = torch.load(
423
- pcdms_model,
424
- map_location="cpu"
425
- )["module"]
426
-
427
-
428
- image_proj_model_dict = {}
429
- pose_proj_dict = {}
430
- unet_dict = {}
431
- for k in model_sd.keys():
432
- if k.startswith("pose_proj"):
433
- pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
434
-
435
- elif k.startswith("image_proj_model_p"):
436
- image_proj_model_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
437
-
438
- elif k.startswith("image_proj_model."):
439
- image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
440
-
441
-
442
- elif k.startswith("unet"):
443
- unet_dict[k.replace("unet.", "")] = model_sd[k]
444
- else:
445
- print(k)
446
-
447
- model.pose_proj.load_state_dict(pose_proj_dict)
448
- model.image_proj_model_p.load_state_dict(image_proj_model_dict)
449
- model.unet.load_state_dict(unet_dict)
450
-
451
- return model, 0, 0
452
-
453
-
454
- def checkpoint_model(checkpoint_folder, ckpt_id, model, epoch, last_global_step, **kwargs):
455
- """Utility function for checkpointing model + optimizer dictionaries
456
- The main purpose for this is to be able to resume training from that instant again
457
- """
458
- checkpoint_state_dict = {
459
- "epoch": epoch,
460
- "last_global_step": last_global_step,
461
- }
462
- # Add extra kwargs too
463
- checkpoint_state_dict.update(kwargs)
464
-
465
- success = model.save_checkpoint(checkpoint_folder, ckpt_id, checkpoint_state_dict)
466
- status_msg = f"checkpointing: checkpoint_folder={checkpoint_folder}, ckpt_id={ckpt_id}"
467
- if success:
468
- logging.info(f"Success {status_msg}")
469
- else:
470
- logging.warning(f"Failure {status_msg}")
471
- return
472
-
473
-
474
- @spaces.GPU(duration=600)
475
- def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune=True, is_app=False):
476
- logging_dir = 'outputs/logging'
477
- print('start train')
478
-
479
-
480
- progress=gr.Progress(track_tqdm=True)
481
-
482
- accelerator = Accelerator(
483
- log_with=args.report_to,
484
- project_dir=logging_dir,
485
- mixed_precision=args.mixed_precision,
486
- gradient_accumulation_steps=args.gradient_accumulation_steps
487
  )
488
-
489
- # Make one log on every process with the configuration for debugging.
490
- #logging.basicConfig(
491
- # format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
492
- # datefmt="%m/%d/%Y %H:%M:%S",
493
- # level=logging.INFO, )
494
-
495
- print(accelerator.state)
496
- if accelerator.is_local_main_process:
497
- transformers.utils.logging.set_verbosity_warning()
498
- diffusers.utils.logging.set_verbosity_info()
499
- else:
500
- transformers.utils.logging.set_verbosity_error()
501
- diffusers.utils.logging.set_verbosity_error()
502
-
503
- # If passed along, set the training seed now.
504
- set_seed(42)
505
-
506
- # Handle the repository creation
507
- if accelerator.is_main_process:
508
- os.makedirs('outputs', exist_ok=True)
509
 
510
-
511
- """
512
- unet = Stage2_InapintUNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet",
513
- in_channels=9, class_embed_type="projection" ,projection_class_embeddings_input_dim=1024,
514
- low_cpu_mem_usage=False, ignore_mismatched_sizes=True)
515
- """
516
- image_encoder_p.requires_grad_(False)
517
- image_encoder_g.requires_grad_(False)
518
- vae.requires_grad_(False)
519
-
520
- sd_model = SDModel(unet=unet)
521
- sd_model.train()
522
-
523
-
524
- if args.gradient_checkpointing:
525
- sd_model.enable_gradient_checkpointing()
526
-
527
-
528
- # Enable TF32 for faster training on Ampere GPUs,
529
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
530
- if args.allow_tf32:
531
- torch.backends.cuda.matmul.allow_tf32 = True
532
-
533
- learning_rate = 1e-4
534
- train_batch_size = min(len(train_images), max_batch_size) #len(train_images) % 16
535
-
536
-
537
- # Optimizer creation
538
- params_to_optimize = sd_model.parameters()
539
- optimizer = torch.optim.AdamW(
540
- params_to_optimize,
541
- lr=learning_rate,
542
- betas=(args.adam_beta1, args.adam_beta2),
543
- weight_decay=args.adam_weight_decay,
544
- eps=args.adam_epsilon,
545
  )
546
 
547
- inputs = [{
548
- "source_image": in_image,
549
- "source_pose": in_pose,
550
- "target_image": timg,
551
- "target_pose": tpose,
552
- } for timg, tpose in zip(train_images, train_poses)]
553
-
554
- """
555
- inputs = {[
556
- "source_image": Image.open('imgs/sm.png'),
557
- "source_pose": Image.open('imgs/sm_pose.jpg'),
558
- "target_image": Image.open('imgs/target.png'),
559
- "target_pose": Image.open('imgs/target_pose.jpg'),
560
- ]}
561
- """
562
-
563
- #print(inputs)
564
-
565
- dataset = InpaintDataset(
566
- inputs,
567
- 'imgs/',
568
- size=(args.img_width, args.img_height), # w h
569
- imgp_drop_rate=0.1,
570
- imgg_drop_rate=0.1,
571
- )
572
-
573
- """
574
- dataset = InpaintDataset(
575
- args.json_path,
576
- args.image_root_path,
577
- size=(args.img_width, args.img_height), # w h
578
- imgp_drop_rate=0.1,
579
- imgg_drop_rate=0.1,
580
  )
581
- """
582
-
583
- train_sampler = torch.utils.data.distributed.DistributedSampler(
584
- dataset, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=True)
585
-
586
- train_dataloader = torch.utils.data.DataLoader(
587
- dataset,
588
- sampler=train_sampler,
589
- collate_fn=InpaintCollate_fn,
590
- batch_size=train_batch_size,
591
- num_workers=0,)
592
-
593
-
594
- # Scheduler and math around the number of training steps.
595
- overrode_max_train_steps = False
596
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
597
- if args.max_train_steps is None:
598
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
599
- overrode_max_train_steps = True
600
- args.max_train_steps = train_steps
601
-
602
- lr_scheduler = get_scheduler(
603
- args.lr_scheduler,
604
- optimizer=optimizer,
605
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
606
- num_training_steps=args.max_train_steps * accelerator.num_processes,
607
- num_cycles=args.lr_num_cycles,
608
- power=args.lr_power,
609
- )
610
-
611
- # Prepare everything with our `accelerator`.
612
- sd_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(sd_model, optimizer, train_dataloader, lr_scheduler)
613
-
614
- # For mixed precision training we cast the text_encoder and vae weights to half-precision
615
- # as these models are only used for inference, keeping weights in full precision is not required.
616
- weight_dtype = torch.float32
617
- """
618
- if accelerator.mixed_precision == "fp16":
619
- weight_dtype = torch.float16
620
- elif accelerator.mixed_precision == "bf16":
621
- weight_dtype = torch.bfloat16
622
- """
623
-
624
- # Move vae, unet and text_encoder to device and cast to weight_dtype
625
- vae.to(accelerator.device, dtype=weight_dtype)
626
- sd_model.unet.to(accelerator.device, dtype=weight_dtype)
627
- image_encoder_p.to(accelerator.device, dtype=weight_dtype)
628
- image_encoder_g.to(accelerator.device, dtype=weight_dtype)
629
-
630
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
631
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
632
- if overrode_max_train_steps:
633
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
634
- # Afterwards we recalculate our number of training epochs
635
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
636
-
637
-
638
- args.num_train_epochs = train_steps
639
-
640
-
641
- # Train!
642
- total_batch_size = (
643
- train_batch_size
644
- * accelerator.num_processes
645
- * args.gradient_accumulation_steps
646
- )
647
-
648
- print("***** Running training *****")
649
- print(f" Num batches each epoch = {len(train_dataloader)}")
650
- print(f" Num Epochs = {args.num_train_epochs}")
651
- print(f" Instantaneous batch size per device = {train_batch_size}")
652
- print(
653
- f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
654
- )
655
- print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
656
- print(f" Total optimization steps = {args.max_train_steps}")
657
-
658
-
659
- if args.resume_from_checkpoint:
660
- # New Code #
661
- # Loads the DeepSpeed checkpoint from the specified path
662
- prior_model, last_epoch, last_global_step = load_training_checkpoint(
663
- sd_model,
664
- pcdms_model,
665
- **{"load_optimizer_states": True, "load_lr_scheduler_states": True},
666
- )
667
- print(f"Resumed from checkpoint: {args.resume_from_checkpoint}, global step: {last_global_step}")
668
- starting_epoch = last_epoch
669
- global_steps = last_global_step
670
- sd_model = sd_model
671
- else:
672
- global_steps = 0
673
- starting_epoch = 0
674
- sd_model = sd_model
675
-
676
- progress_bar = tqdm(range(global_steps, args.max_train_steps), initial=global_steps, desc="Steps",
677
- # Only show the progress bar once on each machine.
678
- disable=not accelerator.is_local_main_process, )
679
-
680
- bsz = train_batch_size
681
-
682
- if not finetune or train_steps == 0:
683
- accelerator.wait_for_everyone()
684
- accelerator.end_training()
685
- return {k: v.cpu() for k, v in sd_model.state_dict().items()}
686
-
687
-
688
- it = range(starting_epoch, args.num_train_epochs)
689
- if is_app:
690
- it = progress.tqdm(it, desc="Fine-tuning")
691
- for epoch in it:
692
- for step, batch in enumerate(train_dataloader):
693
- with accelerator.accumulate(sd_model):
694
- with torch.no_grad():
695
- # Convert images to latent space
696
- latents = vae.encode(batch["source_target_image"].to(dtype=weight_dtype)).latent_dist.sample()
697
- latents = latents * vae.config.scaling_factor
698
-
699
- # Get the masked image latents
700
- masked_latents = vae.encode(batch["vae_source_mask_image"].to(dtype=weight_dtype)).latent_dist.sample()
701
- masked_latents = masked_latents * vae.config.scaling_factor
702
-
703
- bsz = batch["target_image"].size(dim=0)
704
-
705
- # mask
706
- mask1 = torch.ones((bsz, 1, int(args.img_height / 8), int(args.img_width / 8))).to(accelerator.device, dtype=weight_dtype)
707
- mask0 = torch.zeros((bsz, 1, int(args.img_height / 8), int(args.img_width / 8))).to(accelerator.device, dtype=weight_dtype)
708
- mask = torch.cat([mask1, mask0], dim=3)
709
- # Get the image embedding for conditioning
710
- cond_image_feature_p = image_encoder_p(batch["source_image"].to(accelerator.device, dtype=weight_dtype))
711
- cond_image_feature_p = (cond_image_feature_p.last_hidden_state)
712
-
713
-
714
- cond_image_feature_g = image_encoder_g(batch["target_image"].to(accelerator.device, dtype=weight_dtype), ).image_embeds
715
- cond_image_feature_g =cond_image_feature_g.unsqueeze(1)
716
-
717
- # Sample noise that we'll add to the latents
718
- noise = torch.randn_like(latents)
719
- if args.noise_offset:
720
- # https://www.crosslabs.org//blog/diffusion-with-offset-noise
721
- noise += args.noise_offset * torch.randn(
722
- (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
723
- )
724
-
725
- # Sample a random timestep for each image
726
- #timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (train_batch_size,),device=latents.device, )
727
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,),device=latents.device, )
728
- timesteps = timesteps.long()
729
-
730
-
731
-
732
- # Add noise to the latents according to the noise magnitude at each timestep (this is the forward diffusion process)
733
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
734
-
735
- #print(noisy_latents.size(), mask.size(), masked_latents.size())
736
-
737
- noisy_latents = torch.cat([noisy_latents, mask, masked_latents], dim=1)
738
- # Get the text embedding for conditioning
739
-
740
-
741
- cond_pose = batch["source_target_pose"].to(dtype=weight_dtype)
742
-
743
- #print(noisy_latents.size())
744
- #print(cond_image_feature_p.size())
745
- #print(cond_image_feature_g.size())
746
- #print(cond_pose.size())
747
-
748
- # Predict the noise residual
749
- model_pred = sd_model(noisy_latents, timesteps, cond_image_feature_p,cond_image_feature_g, cond_pose, )
750
-
751
- # Get the target for loss depending on the prediction type
752
- if noise_scheduler.config.prediction_type == "epsilon":
753
- target = noise
754
- elif noise_scheduler.config.prediction_type == "v_prediction":
755
- target = noise_scheduler.get_velocity(latents, noise, timesteps)
756
- else:
757
- raise ValueError(
758
- f"Unknown prediction type {noise_scheduler.config.prediction_type}"
759
- )
760
-
761
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
762
-
763
- accelerator.backward(loss)
764
- if accelerator.sync_gradients:
765
- params_to_clip = sd_model.parameters()
766
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
767
- optimizer.step()
768
- lr_scheduler.step()
769
- optimizer.zero_grad(set_to_none=args.set_grads_to_none)
770
-
771
- # Checks if the accelerator has performed an optimization step behind the scenes
772
- if accelerator.sync_gradients:
773
- global_steps += 1
774
-
775
- if global_steps >= args.max_train_steps:
776
- break
777
-
778
-
779
- logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
780
- print(logs)
781
- progress_bar.set_postfix(**logs)
782
-
783
- progress_bar.update(1)
784
-
785
- # Create the pipeline using the trained modules and save it.
786
- accelerator.wait_for_everyone()
787
- accelerator.end_training()
788
-
789
- sd_model.unet.cpu()
790
- sd_model.cpu()
791
- del vae
792
- del image_encoder_p
793
- del image_encoder_g
794
-
795
- if save_model: #if global_steps % args.checkpointing_steps == 0 or global_steps == args.max_train_steps:
796
- print('saving', modelId)
797
-
798
- checkpoint_state_dict = {
799
- "epoch": 0,
800
- "module": {k: v.cpu() for k, v in sd_model.state_dict().items()}, #sd_model.state_dict(),
801
- }
802
- print(list(sd_model.state_dict().keys())[:20])
803
- torch.save(checkpoint_state_dict, modelId+".pt")
804
-
805
- del sd_model
806
- gc.collect()
807
- torch.cuda.empty_cache()
808
- print('done train')
809
- print(torch.cuda.memory_allocated()/1024**2)
810
- return
811
-
812
- del sd_model
813
- gc.collect()
814
- torch.cuda.empty_cache()
815
- return {k: v.cpu() for k, v in sd_model.state_dict().items()}
816
-
817
-
818
-
819
-
820
- # Pose-transfer ===================================================================================================
821
-
822
-
823
- device = "cuda"
824
-
825
- class ImageProjModel(torch.nn.Module):
826
- """SD model with image prompt"""
827
- def __init__(self, in_dim, hidden_dim, out_dim, dropout = 0.):
828
- super().__init__()
829
-
830
- self.net = nn.Sequential(
831
- nn.Linear(in_dim, hidden_dim),
832
- nn.GELU(),
833
- nn.Dropout(dropout),
834
- nn.LayerNorm(hidden_dim),
835
- nn.Linear(hidden_dim, out_dim),
836
- nn.Dropout(dropout)
837
- )
838
-
839
- def forward(self, x):
840
- return self.net(x)
841
-
842
- def image_grid(imgs, rows, cols):
843
- assert len(imgs) == rows * cols
844
- w, h = imgs[0].size
845
- print(w, h)
846
- grid = Image.new("RGB", size=(cols * w, rows * h))
847
- grid_w, grid_h = grid.size
848
-
849
- for i, img in enumerate(imgs):
850
- grid.paste(img, box=(i % cols * w, i // cols * h))
851
- return grid
852
-
853
- def load_mydict(modelId, finetuned_model):
854
- if save_model:
855
- model_ckpt_path = modelId+'.pt'
856
- model_sd = torch.load(model_ckpt_path, map_location="cpu")["module"]
857
- else:
858
- model_sd = finetuned_model #torch.load(model_ckpt_path, map_location="cpu")["module"]
859
-
860
- image_proj_model_dict = {}
861
- pose_proj_dict = {}
862
- unet_dict = {}
863
- for k in model_sd.keys():
864
- if k.startswith("pose_proj"):
865
- pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
866
-
867
- elif k.startswith("image_proj_model_p"):
868
- image_proj_model_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
869
- elif k.startswith("image_proj_model"):
870
- image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
871
-
872
-
873
- elif k.startswith("unet"):
874
- unet_dict[k.replace("unet.", "")] = model_sd[k]
875
- else:
876
- print(k)
877
- return image_proj_model_dict, pose_proj_dict, unet_dict
878
-
879
-
880
-
881
- @spaces.GPU(duration=600)
882
- def inference(modelId, in_image, in_pose, target_poses, inference_steps, finetuned_model, vae, unet, image_encoder, is_app=False):
883
- print('start inference')
884
- progress=gr.Progress(track_tqdm=True)
885
-
886
- if not save_model:
887
- finetuned_model = {k: v.cuda() for k, v in finetuned_model.items()}
888
-
889
- device = "cuda"
890
- pretrained_model_name_or_path ="stabilityai/stable-diffusion-2-1-base"
891
- image_encoder_path = "facebook/dinov2-giant"
892
- #model_ckpt_path = "./pcdms_ckpt.pt" # ckpt path
893
- model_ckpt_path = modelId+'.pt'
894
-
895
-
896
- clip_image_processor = CLIPImageProcessor()
897
- img_transform = transforms.Compose([
898
- transforms.ToTensor(),
899
- transforms.Normalize([0.5], [0.5]),
900
- ])
901
-
902
- generator = torch.Generator(device=device).manual_seed(42)
903
-
904
- """
905
- unet = Stage2_InapintUNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16,subfolder="unet",in_channels=9, low_cpu_mem_usage=False, ignore_mismatched_sizes=True).to(device)
906
- vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path,subfolder="vae").to(device, dtype=torch.float16)
907
- image_encoder = Dinov2Model.from_pretrained(image_encoder_path).to(device, dtype=torch.float16)
908
- """
909
- noise_scheduler = DDIMScheduler(
910
- num_train_timesteps=1000,
911
- beta_start=0.00085,
912
- beta_end=0.012,
913
- beta_schedule="scaled_linear",
914
- clip_sample=False,
915
- set_alpha_to_one=False,
916
- steps_offset=1,
917
- )
918
-
919
- unet = unet.to(device, dtype=torch.float16)
920
- vae = vae.to(device, dtype=torch.float16)
921
- image_encoder = image_encoder.to(device, dtype=torch.float16)
922
-
923
-
924
- image_proj_model = ImageProjModel(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).to(dtype=torch.float16)
925
- pose_proj_model = ControlNetConditioningEmbedding(
926
- conditioning_embedding_channels=320,
927
- block_out_channels=(16, 32, 96, 256),
928
- conditioning_channels=3).to(device).to(dtype=torch.float16)
929
-
930
-
931
- # load weight
932
- print('loading', modelId)
933
- image_proj_model_dict, pose_proj_dict, unet_dict = load_mydict(modelId, finetuned_model)
934
- print('loaded', modelId)
935
- image_proj_model.load_state_dict(image_proj_model_dict)
936
- pose_proj_model.load_state_dict(pose_proj_dict)
937
- unet.load_state_dict(unet_dict)
938
-
939
-
940
- pipe = PCDMsPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", unet=unet, torch_dtype=torch.float16, scheduler=noise_scheduler,feature_extractor=None,safety_checker=None).to(device)
941
-
942
- print('====================== model load finish ===================')
943
-
944
- results = []
945
- progress_bar = tqdm(range(len(target_poses)), initial=0, desc="Frames")
946
-
947
-
948
- it = target_poses
949
- if is_app:
950
- it = progress.tqdm(it, desc="Pose Transfer")
951
- for pose in it:
952
-
953
- num_samples = 1
954
- image_size = (512, 512)
955
- s_img_path = 'imgs/'+input_img # input image 1
956
- #target_pose_img = 'imgs/pose_'+str(n)+'.png' # input image 2
957
-
958
- #t_pose = inference_pose(target_pose_img, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
959
- #t_pose = Image.open(target_pose_img).convert("RGB").resize((image_size), Image.BICUBIC)
960
- t_pose = pose.convert("RGB").resize((image_size), Image.BICUBIC)
961
- #t_pose = resize_and_pad(pose.convert("RGB"))
962
-
963
-
964
- #s_img = Image.open(s_img_path)
965
- width_orig, height_orig = in_image.size
966
- s_img = in_image.convert("RGB").resize(image_size, Image.BICUBIC)
967
- #s_img = resize_and_pad(in_image.convert("RGB"))
968
- black_image = Image.new("RGB", s_img.size, (0, 0, 0)).resize(image_size, Image.BICUBIC)
969
-
970
- s_img_t_mask = Image.new("RGB", (s_img.width * 2, s_img.height))
971
- s_img_t_mask.paste(s_img, (0, 0))
972
- s_img_t_mask.paste(black_image, (s_img.width, 0))
973
-
974
- #s_pose = inference_pose(s_img_path, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
975
- #s_pose = Image.open('imgs/sm_pose.jpg').convert("RGB").resize(image_size, Image.BICUBIC)
976
- s_pose = in_pose.convert("RGB").resize(image_size, Image.BICUBIC)
977
- #s_pose = resize_and_pad(in_pose.convert("RGB"))
978
- print('source image width: {}, height: {}'.format(s_pose.width, s_pose.height))
979
- #t_pose = Image.open(target_pose_img).convert("RGB").resize((image_size), Image.BICUBIC)
980
-
981
- st_pose = Image.new("RGB", (s_pose.width * 2, s_pose.height))
982
- st_pose.paste(s_pose, (0, 0))
983
- st_pose.paste(t_pose, (s_pose.width, 0))
984
-
985
-
986
- clip_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
987
- vae_image = torch.unsqueeze(img_transform(s_img_t_mask), 0)
988
- cond_st_pose = torch.unsqueeze(img_transform(st_pose), 0)
989
-
990
- mask1 = torch.ones((1, 1, int(image_size[0] / 8), int(image_size[1] / 8))).to(device, dtype=torch.float16)
991
- mask0 = torch.zeros((1, 1, int(image_size[0] / 8), int(image_size[1] / 8))).to(device, dtype=torch.float16)
992
- mask = torch.cat([mask1, mask0], dim=3)
993
-
994
-
995
- with torch.inference_mode():
996
- cond_pose = pose_proj_model(cond_st_pose.to(dtype=torch.float16, device=device))
997
- simg_mask_latents = pipe.vae.encode(vae_image.to(device, dtype=torch.float16)).latent_dist.sample()
998
- simg_mask_latents = simg_mask_latents * 0.18215
999
-
1000
- images_embeds = image_encoder(clip_s_img.to(device, dtype=torch.float16)).last_hidden_state
1001
- image_prompt_embeds = image_proj_model(images_embeds)
1002
- uncond_image_prompt_embeds = image_proj_model(torch.zeros_like(images_embeds))
1003
-
1004
- bs_embed, seq_len, _ = image_prompt_embeds.shape
1005
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
1006
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
1007
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
1008
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
1009
-
1010
- output, _ = pipe(
1011
- simg_mask_latents= simg_mask_latents,
1012
- mask = mask,
1013
- cond_pose = cond_pose,
1014
- prompt_embeds=image_prompt_embeds,
1015
- negative_prompt_embeds=uncond_image_prompt_embeds,
1016
- height=image_size[1],
1017
- width=image_size[0]*2,
1018
- num_images_per_prompt=num_samples,
1019
- guidance_scale=2.0,
1020
- generator=generator,
1021
- num_inference_steps=inference_steps,
1022
- )
1023
-
1024
- output = output.images[-1]
1025
-
1026
- result = output.crop((image_size[0], 0, image_size[0] * 2, image_size[1]))
1027
- result = result.resize((width_orig, height_orig), Image.BICUBIC)
1028
- #result = remove_zero_pad(result)
1029
-
1030
- if debug:
1031
- result.save('out/'+str(len(results))+'.png')
1032
- results.append(result)
1033
- progress_bar.update(1)
1034
-
1035
- del unet
1036
- del vae
1037
- del image_encoder
1038
- del image_proj_model
1039
- del pose_proj_model
1040
-
1041
- if not save_model:
1042
- del finetuned_model
1043
-
1044
- gc.collect()
1045
- torch.cuda.empty_cache()
1046
- print(torch.cuda.memory_allocated()/1024**2)
1047
-
1048
- return results
1049
-
1050
-
1051
- def gen_vid(frames, video_name, fps, codec):
1052
- progress=gr.Progress(track_tqdm=True)
1053
-
1054
- frame = cv2.cvtColor(np.array(frames[0]), cv2.COLOR_RGB2BGR)
1055
- height, width, layers = frame.shape
1056
-
1057
- #video = cv2.VideoWriter(video_name, 0, 1, (width,height))
1058
- if codec == 'mp4':
1059
- video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
1060
- else:
1061
- video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'VP90'), fps, (width, height))
1062
-
1063
- for r in progress.tqdm(frames, desc="Creating video"):
1064
- image = cv2.cvtColor(np.array(r), cv2.COLOR_RGB2BGR)
1065
- video.write(image)
1066
-
1067
- #cv2.destroyAllWindows()
1068
- #video.release()
1069
-
1070
-
1071
-
1072
- def run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True, finetune=True, is_app=False):
1073
- print("==== Load Models ====")
1074
- dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1075
-
1076
- print("==== Pose Detection ====")
1077
- if resize_inputs:
1078
- resize = 'target'
1079
- else:
1080
- resize = 'none'
1081
- in_img, in_pose, train_imgs, train_poses, target_poses = prepare_inputs(images, video_path, fps, bg_remove, dwpose, rembg_session, resize=resize, is_app=is_app)
1082
-
1083
- if save_model:
1084
- train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1085
- print('next')
1086
- results = inference("fine_tuned_pcdms", in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1087
-
1088
- else:
1089
- print("==== Finetuning ====")
1090
- finetuned_model = train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1091
-
1092
- print("==== Pose Transfer ====")
1093
- results = inference("fine_tuned_pcdms", in_img, in_pose, target_poses, inference_steps, finetuned_model, vae, unet, image_encoder_p, is_app)
1094
-
1095
- return results
1096
-
1097
-
1098
- def run_train(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
1099
- finetune=True
1100
- is_app=True
1101
- images = [img[0] for img in images]
1102
-
1103
- dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1104
-
1105
- if resize_inputs:
1106
- resize = 'target'
1107
- else:
1108
- resize = 'none'
1109
-
1110
- in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
1111
-
1112
- train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1113
-
1114
-
1115
- def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
1116
- finetune=True
1117
- is_app=True
1118
-
1119
-
1120
- dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1121
-
1122
- if not os.path.exists(modelId+".pt"):
1123
- run_train(images, train_steps, modelId, bg_remove, resize_inputs)
1124
-
1125
- images = [img[0] for img in images]
1126
- in_img = images[0]
1127
-
1128
- in_img, target_poses, in_pose = prepare_inputs_inference(in_img, video_path, fps, dwpose, rembg_session, bg_remove, 'target', is_app)
1129
-
1130
- results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1131
- #urls = save_temp_imgs(results)
1132
-
1133
- if should_gen_vid:
1134
- if debug:
1135
- gen_vid(results, out_vid+'.mp4', fps, 'mp4')
1136
- else:
1137
- gen_vid(results, out_vid+'.webm', fps, 'webm')
1138
-
1139
-
1140
- # postprocessing
1141
- results = [removebg(img, rembg_session, True) for img in results]
1142
- results = [img_pad(img, img_width, img_height, True) for img in results]
1143
-
1144
- print("Done!")
1145
-
1146
- return out_vid+'.webm', results, getThumbnails(results)
1147
-
1148
-
1149
- def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True):
1150
-
1151
- images = [img[0] for img in images]
1152
-
1153
- results = run(images, video_path, train_steps, inference_steps, fps, bg_remove, resize_inputs, finetune=True, is_app=True)
1154
-
1155
-
1156
- print("==== Video generation ====")
1157
- out_vid = f"out_{uuid.uuid4()}"
1158
-
1159
- if debug:
1160
- gen_vid(results, out_vid+'.mp4', fps, 'mp4')
1161
- else:
1162
- gen_vid(results, out_vid+'.webm', fps, 'webm')
1163
-
1164
-
1165
-
1166
- print("Done!")
1167
-
1168
- return out_vid+'.webm', results
1169
-
1170
-
1171
 
1172
- """
1173
- train_steps = 100
1174
- inference_steps = 10
1175
- fps = 12
1176
- """
1177
 
1178
- """
1179
- iface = gr.Interface(
1180
- fn=run,
1181
- inputs=[
1182
- gr.Gallery(type="pil", label="Images of the Character"),
1183
- gr.Video(label="Motion-Capture Video"),
1184
- gr.Number(label="Training steps", value=100),
1185
- gr.Number(label="Inference steps", value=10),
1186
- gr.Number(label="Output frame rate", value=12),
1187
- gr.Checkbox(label="Remove background", value=False),
1188
- ],
1189
- outputs=[gr.Video(label="Result"), gr.Gallery(type="pil", label="Frames")],
1190
- title="Keyframes AI",
1191
- description="Upload images of your character and a motion-capture video to generate an animation of the character.",
1192
- )
1193
- """
1194
 
1195
 
1196
 
 
1
+ from main import run_app, run_train, run_inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import spaces
 
4
  from PIL import Image
5
  import cv2
6
  import os
7
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ with gr.Blocks() as demo:
10
+ with gr.Row():
11
+ with gr.Column():
12
+ char_imgs = gr.Gallery(type="pil", label="Images of the Character")
13
+ mocap = gr.Video(label="Motion-Capture Video")
14
+ tr_steps = gr.Number(label="Training steps", value=10)
15
+ inf_steps = gr.Number(label="Inference steps", value=10)
16
+ fps = gr.Number(label="Output frame rate", value=12)
17
+ modelId = gr.Text(label="Model Id", value="fine_tuned_pcdms")
18
+ remove_bg = gr.Checkbox(label="Remove background", value=False)
19
+ resize_inputs = gr.Checkbox(label="Resize images to match video", value=True)
20
+ img_width = gr.Number(label="Output width", value=1920)
21
+ img_height = gr.Number(label="Output height", value=1080)
22
+ train_btn = gr.Button(value="Train")
23
+ inference_btn = gr.Button(value="Inference")
24
+ submit_btn = gr.Button(value="Generate")
25
+ with gr.Column():
26
+ animation = gr.Video(label="Result")
27
+ frames = gr.Gallery(type="pil", label="Frames", format="png")
28
+ frames_thumb = gr.Gallery(type="pil", label="Thumbnails", format="png")
29
 
30
+ submit_btn.click(
31
+ run_app, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, remove_bg, resize_inputs], outputs=[animation, frames]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ train_btn.click(
35
+ run_train, inputs=[char_imgs, tr_steps, modelId, remove_bg, resize_inputs], outputs=[]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  )
37
 
38
+ inference_btn.click(
39
+ run_inference, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, modelId, img_width, img_height, remove_bg, resize_inputs], outputs=[animation, frames, frames_thumb]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
 
 
 
 
 
42
 
43
+ demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46