acmyu committed
Commit fb5b0b1 · 1 Parent(s): 28c5a09

pad results to specified resolution

Files changed (2):
  1. app.py +1178 -26
  2. main.py +15 -6
app.py CHANGED
@@ -1,44 +1,1196 @@
- from main import run_app, run_train, run_inference

  import spaces
  from PIL import Image
  import cv2
  import os
  import gradio as gr

- with gr.Blocks() as demo:
-     with gr.Row():
-         with gr.Column():
-             char_imgs = gr.Gallery(type="pil", label="Images of the Character")
-             mocap = gr.Video(label="Motion-Capture Video")
-             tr_steps = gr.Number(label="Training steps", value=10)
-             inf_steps = gr.Number(label="Inference steps", value=10)
-             fps = gr.Number(label="Output frame rate", value=12)
-             modelId = gr.Text(label="Model Id", value="fine_tuned_pcdms")
-             remove_bg = gr.Checkbox(label="Remove background", value=False)
-             resize_inputs = gr.Checkbox(label="Resize images to match video", value=True)
-             train_btn = gr.Button(value="Train")
-             inference_btn = gr.Button(value="Inference")
-             submit_btn = gr.Button(value="Generate")
-         with gr.Column():
-             animation = gr.Video(label="Result")
-             frames = gr.Gallery(type="pil", label="Frames", format="png")
-             frames_thumb = gr.Gallery(type="pil", label="Thumbnails", format="png")

-     submit_btn.click(
-         run_app, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, remove_bg, resize_inputs], outputs=[animation, frames]
      )

-     train_btn.click(
-         run_train, inputs=[char_imgs, tr_steps, modelId, remove_bg, resize_inputs], outputs=[]
      )

-     inference_btn.click(
-         run_inference, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, modelId, remove_bg, resize_inputs], outputs=[animation, frames, frames_thumb]
      )


- demo.launch(share=True)


+ import logging
+ import math
+ import os
+ from typing import Any, Dict, List, Optional, Tuple, Union
+ from diffusers.models.controlnet import ControlNetConditioningEmbedding
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ import transformers
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import ProjectConfiguration, set_seed
+
+ from tqdm.auto import tqdm
+ from src.configs.stage2_config import args
+
+ import diffusers
+ from diffusers import (
+     AutoencoderKL,
+     DDPMScheduler,
+ )
+ from diffusers.optimization import get_scheduler
+ from diffusers.utils import check_min_version, is_wandb_available
+ from src.dataset.stage2_dataset import InpaintDataset, InpaintCollate_fn
+ from transformers import CLIPVisionModelWithProjection
+ from transformers import Dinov2Model
+ from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel
+
+
+ import glob
+ import os
+ import torch
+ from torch import nn
+ from PIL import Image, ImageOps
+ import numpy as np
+ from diffusers import UniPCMultistepScheduler
+ from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel
+
+ from torchvision import transforms
+ from diffusers.models.controlnet import ControlNetConditioningEmbedding
+ from transformers import CLIPImageProcessor
+ from transformers import Dinov2Model
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, ControlNetModel, DDIMScheduler
+ from src.pipelines.PCDMs_pipeline import PCDMsPipeline
+ #from single_extract_pose import inference_pose
+

  import spaces
+ from easy_dwpose import DWposeDetector
  from PIL import Image
  import cv2
  import os
  import gradio as gr
+ import rembg
+ import uuid
+ import gc
+ from numba import cuda
+ import requests
+
+ from huggingface_hub import hf_hub_download, HfApi
+
+
+ # Inputs ===================================================================================================
+
+ input_img = "sm.png"
+ train_imgs = ["target.png"]
+ in_vid = "walk.mp4"
+ out_vid = 'out.mp4'
+
+ """
+ train_steps = 100
+ inference_steps = 10
+ fps = 12
+ """
+
+ debug = False
+ save_model = True
+ should_gen_vid = False
+ max_batch_size = 8
+
+ def save_temp_imgs(imgs):
+     os.makedirs('temp', exist_ok=True)
+     results = []
+
+     api = HfApi()
+
+     for i, img in enumerate(imgs):
+         #img_name = 'temp/'+str(uuid.uuid4())+'.png'
+         img_name = 'temp/'+str(i)+'.png'
+         img.save(img_name)
+
+         """
+         url = 'https://tmpfiles.org/api/v1/upload'
+
+         try:
+             response = requests.post(url, files={'file': open(img_name, 'rb')})
+
+             # Check for successful response (status code 200)
+             response.raise_for_status()
+
+             # Print the server's response
+             print("Status Code:", response.status_code)
+
+             data = response.json()
+             print("Response JSON:", data)
+             results.append(data['data']['url'])
+
+         except requests.exceptions.RequestException as e:
+             print(f"An error occurred: {e}")
+         """
+
+         results.append('https://huggingface.co/datasets/acmyu/KeyframesAIFiles/resolve/main/'+img_name)
+
+     api.upload_file(
+         path_or_fileobj='temp',
+         path_in_repo='temp',
+         repo_id="acmyu/KeyframesAIFiles",
+         repo_type="dataset",
+     )
+
+     return results
+
+ def getThumbnails(imgs):
+     thumbs = []
+     thumb_size = (512, 512)
+     for img in imgs:
+         th = img.copy()
+         th.thumbnail(thumb_size)
+         thumbs.append(th)
+     return thumbs
+
+ # Pose detection ==============================================================================================
+
+ def load_models():
+     dwpose = DWposeDetector(device="cpu")
+     rembg_session = rembg.new_session("u2netp")
+
+     pcdms_model = hf_hub_download(repo_id="acmyu/PCDMs", filename="pcdms_ckpt.pt")
+
+     # Load scheduler
+     noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
+
+     # Load model
+     image_encoder_p = Dinov2Model.from_pretrained('facebook/dinov2-giant')
+     image_encoder_g = CLIPVisionModelWithProjection.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')  #("openai/clip-vit-base-patch32")
+
+     vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae")
+     unet = Stage2_InapintUNet2DConditionModel.from_pretrained(
+         "stabilityai/stable-diffusion-2-1-base",
+         torch_dtype=torch.float16,
+         subfolder="unet",
+         in_channels=9,
+         low_cpu_mem_usage=False,
+         ignore_mismatched_sizes=True)
+
+     return dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet
+
+
+ #load_models()
+
+ def img_pad(img, tw, th, transparent=False):
+     img.thumbnail((tw, th))
+     if transparent:
+         new_img = Image.new('RGBA', (tw, th), (0, 0, 0, 0))
+     else:
+         new_img = Image.new("RGB", (tw, th), (0, 0, 0))
+     left = (tw - img.width) // 2
+     top = (th - img.height) // 2
+     new_img.paste(img, (left, top))
+     return new_img
+
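img_pad is the helper the commit title refers to: it letterboxes a frame onto a fixed canvas. A minimal usage sketch, assuming only Pillow and the img_pad above (the 1920x1080 size mirrors the new run_inference defaults):

    from PIL import Image

    frame = Image.new("RGB", (640, 480), "red")             # stand-in for a generated frame
    padded = img_pad(frame, 1920, 1080, transparent=True)
    # Image.thumbnail only ever shrinks, so the 640x480 frame keeps its size
    # and is centered on a 1920x1080 transparent RGBA canvas.
    print(padded.size)                                      # (1920, 1080)

Note that img_pad mutates its argument via Image.thumbnail, so pass a copy if the original frame is still needed.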
+ def resize_and_pad(img, target_img):
+     tw, th = target_img.size
+     w, h = img.size
+
+     if tw/th > w/h:
+         tw = int(th * w/h)
+     elif tw/th < w/h:
+         th = int(tw * h/w)
+
+     img = img.resize((tw, th), Image.BICUBIC)
+
+     tw, th = target_img.size
+
+     return img_pad(img, tw, th)
+
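A worked example of the aspect-fit arithmetic above, assuming a 100x200 source and a 512x512 target: tw/th = 1.0 is greater than w/h = 0.5, so tw becomes int(512 * 0.5) = 256; the image is resized to 256x512, and img_pad then centers it on the full 512x512 canvas.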
+ def remove_zero_pad(image):
+     image = np.array(image)
+     dummy = np.argwhere(image != 0)  # assume background is zero
+     max_y = dummy[:, 0].max()
+     min_y = dummy[:, 0].min()
+     min_x = dummy[:, 1].min()
+     max_x = dummy[:, 1].max()
+     crop_image = image[min_y:max_y + 1, min_x:max_x + 1]  # slice ends are exclusive
+
+     return Image.fromarray(crop_image)
+
+ def get_pose(img, dwpose, outfile, crop=False):
+     #pil_image = Image.open("imgs/"+img).convert("RGB")
+     #skeleton = dwpose(pil_image, output_type="np", include_hands=True, include_face=False)
+
+     #img.thumbnail((512,512))
+     out_img = dwpose(img, include_hands=True, include_face=False)
+
+     #print(pose['bodies'])
+
+     if crop:
+         bbox = out_img.getbbox()
+         out_img = out_img.crop(bbox)
+         out_img = ImageOps.expand(out_img, border=int(out_img.width*0.2), fill=(0,0,0))
+
+     return out_img
+
+ def extract_frames(video_path, fps):
+     video_capture = cv2.VideoCapture(video_path)
+     frame_count = 0
+     frames = []
+
+     fps_in = video_capture.get(cv2.CAP_PROP_FPS)
+     fps_out = fps
+
+     index_in = -1
+     index_out = -1
+
+     while True:
+         success = video_capture.grab()
+         if not success: break
+         index_in += 1
+
+         out_due = int(index_in / fps_in * fps_out)
+         if out_due > index_out:
+             success, frame = video_capture.retrieve()
+             if not success:
+                 break
+             index_out += 1
+
+             frame_count += 1
+             frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
+
+     video_capture.release()
+     print(f"Extracted {frame_count} frames")
+     return frames
+
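The frame-rate resampling above decodes only the frames it keeps (grab() skips cheaply; retrieve() does the actual decode). A self-contained sketch of the selection rule, assuming a 24 fps input downsampled to 12 fps:

    fps_in, fps_out = 24.0, 12
    kept, index_out = [], -1
    for index_in in range(8):                    # first 8 input frames
        out_due = int(index_in / fps_in * fps_out)
        if out_due > index_out:                  # same test as extract_frames
            index_out += 1
            kept.append(index_in)
    print(kept)                                  # [0, 2, 4, 6] -> every 2nd frame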
+ def removebg(img, rembg_session, transparent=False):
+
+     if transparent:
+         result = Image.new('RGBA', img.size, (0, 0, 0, 0))
+     else:
+         result = Image.new("RGB", img.size, "#ffffff")
+     out = rembg.remove(img, session=rembg_session)
+     result.paste(out, mask=out)
+     return result
+

+ def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
+     print("remove background", bg_remove)
+     if bg_remove:
+         images = [removebg(img, rembg_session) for img in images]
+
+     in_img = images[0]
+     in_pose = get_pose(in_img, dwpose, "in_pose.png")
+     train_poses = []
+     train_imgs = [resize_and_pad(img, in_img) for img in images[1:]]
+
+     for i, img in enumerate(train_imgs):
+         train_poses.append(get_pose(img, dwpose, "tr_pose"+str(i)+".png"))
+
+     return in_img, in_pose, train_imgs, train_poses
+
+
+ def prepare_inputs_inference(in_img, in_vid, fps, dwpose, rembg_session, bg_remove, resize='target', is_app=False):
+     progress = gr.Progress(track_tqdm=True)
+
+     print("prepare_inputs_inference")
+
+     in_pose = get_pose(in_img, dwpose, "in_pose.png")
+
+     frames = extract_frames(in_vid, fps)
+     print("remove background", bg_remove)
+     if bg_remove:
+         in_img = removebg(in_img, rembg_session)
+         #frames = [removebg(img, rembg_session) for img in frames]
+     if debug:
+         for i, frame in enumerate(frames):
+             frame.save("out/frame_"+str(i)+".png")
+
+     print("vid: ", in_vid, fps)
+
+     progress_bar = tqdm(range(len(frames)), initial=0, desc="Frames")
+     target_poses = []
+     max_left = max_top = 999999
+     max_right = max_bottom = 0
+     it = frames
+     if is_app:
+         it = progress.tqdm(frames, desc="Pose Detection")
+     for f in it:
+         tpose = get_pose(f, dwpose, "tar_pose"+str(len(target_poses))+".png")
+         target_poses.append(tpose)
+         progress_bar.update(1)
+
+         bbox = tpose.getbbox()
+         left, top, right, bottom = bbox
+         max_left = min(max_left, left)
+         max_top = min(max_top, top)
+         max_right = max(max_right, right)
+         max_bottom = max(max_bottom, bottom)
+
+     target_poses_cropped = []
+     for tpose in target_poses:
+         if resize == 'target':
+             tpose = tpose.crop((max_left, max_top, max_right, max_bottom))
+             tpose = ImageOps.expand(tpose, border=int(tpose.width*0.2), fill=(0,0,0))
+
+             tpose = resize_and_pad(tpose, in_img)
+
+         if debug:
+             tpose.save("out/"+"tar_pose"+str(len(target_poses_cropped))+".png")
+         target_poses_cropped.append(tpose)
+
+     return in_img, target_poses_cropped, in_pose
+
+ def prepare_inputs(images, in_vid, fps, bg_remove, dwpose, rembg_session, resize='target', is_app=False):
+
+     in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
+
+     in_img, target_poses_cropped, _ = prepare_inputs_inference(in_img, in_vid, fps, dwpose, rembg_session, bg_remove, resize, is_app)
+
+     return in_img, in_pose, train_imgs, train_poses, target_poses_cropped
+
+
+ # Training ===================================================================================================
+
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+ check_min_version("0.18.0.dev0")
+
+ logger = get_logger(__name__)
+
+
+ class ImageProjModel_p(torch.nn.Module):
+     """SD model with image prompt"""
+
+     def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.):
+         super().__init__()
+
+         self.net = nn.Sequential(
+             nn.Linear(in_dim, hidden_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.LayerNorm(hidden_dim),
+             nn.Linear(hidden_dim, out_dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+ class ImageProjModel_g(torch.nn.Module):
+     """SD model with image prompt"""
+
+     def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.):
+         super().__init__()
+
+         self.net = nn.Sequential(
+             nn.Linear(in_dim, hidden_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.LayerNorm(hidden_dim),
+             nn.Linear(hidden_dim, out_dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):  # b, 257, 1280
+         return self.net(x)
+
+
+ class SDModel(torch.nn.Module):
+     """SD model with image prompt"""
+     def __init__(self, unet) -> None:
+         super().__init__()
+         self.image_proj_model_p = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024)
+
+         self.unet = unet
+         self.pose_proj = ControlNetConditioningEmbedding(
+             conditioning_embedding_channels=320,
+             block_out_channels=(16, 32, 96, 256),
+             conditioning_channels=3)
+
+     def forward(self, noisy_latents, timesteps, simg_f_p, timg_f_g, pose_f):
+
+         extra_image_embeddings_p = self.image_proj_model_p(simg_f_p)
+         extra_image_embeddings_g = timg_f_g
+
+         print(extra_image_embeddings_p.size())
+         print(extra_image_embeddings_g.size())
+
+         encoder_image_hidden_states = torch.cat([extra_image_embeddings_p, extra_image_embeddings_g], dim=1)
+         pose_cond = self.pose_proj(pose_f)
+
+         pred_noise = self.unet(noisy_latents, timesteps, class_labels=timg_f_g, encoder_hidden_states=encoder_image_hidden_states, my_pose_cond=pose_cond).sample
+         return pred_noise
+
+ def load_training_checkpoint(model, pcdms_model, tag=None, **kwargs):
+     #model_sd = torch.load(load_dir, map_location="cpu")["module"]
+     model_sd = torch.load(
+         pcdms_model,
+         map_location="cpu"
+     )["module"]
+
+     image_proj_model_dict = {}
+     pose_proj_dict = {}
+     unet_dict = {}
+     for k in model_sd.keys():
+         if k.startswith("pose_proj"):
+             pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
+
+         elif k.startswith("image_proj_model_p"):
+             image_proj_model_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
+
+         elif k.startswith("image_proj_model."):
+             image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
+
+         elif k.startswith("unet"):
+             unet_dict[k.replace("unet.", "")] = model_sd[k]
+         else:
+             print(k)
+
+     model.pose_proj.load_state_dict(pose_proj_dict)
+     model.image_proj_model_p.load_state_dict(image_proj_model_dict)
+     model.unet.load_state_dict(unet_dict)
+
+     return model, 0, 0
+
+
+ def checkpoint_model(checkpoint_folder, ckpt_id, model, epoch, last_global_step, **kwargs):
+     """Utility function for checkpointing model + optimizer dictionaries
+     The main purpose for this is to be able to resume training from that instant again
+     """
+     checkpoint_state_dict = {
+         "epoch": epoch,
+         "last_global_step": last_global_step,
+     }
+     # Add extra kwargs too
+     checkpoint_state_dict.update(kwargs)
+
+     success = model.save_checkpoint(checkpoint_folder, ckpt_id, checkpoint_state_dict)
+     status_msg = f"checkpointing: checkpoint_folder={checkpoint_folder}, ckpt_id={ckpt_id}"
+     if success:
+         logging.info(f"Success {status_msg}")
+     else:
+         logging.warning(f"Failure {status_msg}")
+     return
+
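load_training_checkpoint assumes the checkpoint's "module" dict stores every submodule under a flat key prefix (pose_proj., image_proj_model_p., unet.). A toy sketch of the same prefix splitting, with made-up keys:

    model_sd = {"pose_proj.conv_in.weight": 1,
                "image_proj_model_p.net.0.weight": 2,
                "unet.conv_in.weight": 3}
    pose_proj_dict = {k[len("pose_proj."):]: v for k, v in model_sd.items() if k.startswith("pose_proj.")}
    unet_dict = {k[len("unet."):]: v for k, v in model_sd.items() if k.startswith("unet.")}
    print(pose_proj_dict)   # {'conv_in.weight': 1}
    print(unet_dict)        # {'conv_in.weight': 3}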
+ @spaces.GPU(duration=600)
+ def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune=True, is_app=False):
+     logging_dir = 'outputs/logging'
+     print('start train')
+
+     progress = gr.Progress(track_tqdm=True)
+
+     accelerator = Accelerator(
+         log_with=args.report_to,
+         project_dir=logging_dir,
+         mixed_precision=args.mixed_precision,
+         gradient_accumulation_steps=args.gradient_accumulation_steps
      )
+
+     # Make one log on every process with the configuration for debugging.
+     #logging.basicConfig(
+     #    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+     #    datefmt="%m/%d/%Y %H:%M:%S",
+     #    level=logging.INFO, )
+
+     print(accelerator.state)
+     if accelerator.is_local_main_process:
+         transformers.utils.logging.set_verbosity_warning()
+         diffusers.utils.logging.set_verbosity_info()
+     else:
+         transformers.utils.logging.set_verbosity_error()
+         diffusers.utils.logging.set_verbosity_error()
+
+     # If passed along, set the training seed now.
+     set_seed(42)
+
+     # Handle the repository creation
+     if accelerator.is_main_process:
+         os.makedirs('outputs', exist_ok=True)

+
+     """
+     unet = Stage2_InapintUNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet",
+         in_channels=9, class_embed_type="projection", projection_class_embeddings_input_dim=1024,
+         low_cpu_mem_usage=False, ignore_mismatched_sizes=True)
+     """
+     image_encoder_p.requires_grad_(False)
+     image_encoder_g.requires_grad_(False)
+     vae.requires_grad_(False)
+
+     sd_model = SDModel(unet=unet)
+     sd_model.train()
+
+     if args.gradient_checkpointing:
+         sd_model.enable_gradient_checkpointing()
+
+     # Enable TF32 for faster training on Ampere GPUs,
+     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+     if args.allow_tf32:
+         torch.backends.cuda.matmul.allow_tf32 = True
+
+     learning_rate = 1e-4
+     train_batch_size = min(len(train_images), max_batch_size)  #len(train_images) % 16
+
+     # Optimizer creation
+     params_to_optimize = sd_model.parameters()
+     optimizer = torch.optim.AdamW(
+         params_to_optimize,
+         lr=learning_rate,
+         betas=(args.adam_beta1, args.adam_beta2),
+         weight_decay=args.adam_weight_decay,
+         eps=args.adam_epsilon,
      )

+     inputs = [{
+         "source_image": in_image,
+         "source_pose": in_pose,
+         "target_image": timg,
+         "target_pose": tpose,
+     } for timg, tpose in zip(train_images, train_poses)]
+
+     """
+     inputs = [{
+         "source_image": Image.open('imgs/sm.png'),
+         "source_pose": Image.open('imgs/sm_pose.jpg'),
+         "target_image": Image.open('imgs/target.png'),
+         "target_pose": Image.open('imgs/target_pose.jpg'),
+     }]
+     """
+
+     #print(inputs)
+
+     dataset = InpaintDataset(
+         inputs,
+         'imgs/',
+         size=(args.img_width, args.img_height),  # w h
+         imgp_drop_rate=0.1,
+         imgg_drop_rate=0.1,
+     )
+
+     """
+     dataset = InpaintDataset(
+         args.json_path,
+         args.image_root_path,
+         size=(args.img_width, args.img_height),  # w h
+         imgp_drop_rate=0.1,
+         imgg_drop_rate=0.1,
      )
+     """
+
+     train_sampler = torch.utils.data.distributed.DistributedSampler(
+         dataset, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=True)
+
+     train_dataloader = torch.utils.data.DataLoader(
+         dataset,
+         sampler=train_sampler,
+         collate_fn=InpaintCollate_fn,
+         batch_size=train_batch_size,
+         num_workers=0,)
+
+     # Scheduler and math around the number of training steps.
+     overrode_max_train_steps = False
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if args.max_train_steps is None:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+         overrode_max_train_steps = True
+     args.max_train_steps = train_steps
+
+     lr_scheduler = get_scheduler(
+         args.lr_scheduler,
+         optimizer=optimizer,
+         num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+         num_training_steps=args.max_train_steps * accelerator.num_processes,
+         num_cycles=args.lr_num_cycles,
+         power=args.lr_power,
+     )
+
+     # Prepare everything with our `accelerator`.
+     sd_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(sd_model, optimizer, train_dataloader, lr_scheduler)
+
+     # For mixed precision training we cast the text_encoder and vae weights to half-precision
+     # as these models are only used for inference, keeping weights in full precision is not required.
+     weight_dtype = torch.float32
+     """
+     if accelerator.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif accelerator.mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+     """
+
+     # Move vae, unet and text_encoder to device and cast to weight_dtype
+     vae.to(accelerator.device, dtype=weight_dtype)
+     sd_model.unet.to(accelerator.device, dtype=weight_dtype)
+     image_encoder_p.to(accelerator.device, dtype=weight_dtype)
+     image_encoder_g.to(accelerator.device, dtype=weight_dtype)
+
+     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if overrode_max_train_steps:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+     # Afterwards we recalculate our number of training epochs
+     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+     args.num_train_epochs = train_steps
+
+     # Train!
+     total_batch_size = (
+         train_batch_size
+         * accelerator.num_processes
+         * args.gradient_accumulation_steps
+     )
+
+     print("***** Running training *****")
+     print(f"  Num batches each epoch = {len(train_dataloader)}")
+     print(f"  Num Epochs = {args.num_train_epochs}")
+     print(f"  Instantaneous batch size per device = {train_batch_size}")
+     print(
+         f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
+     )
+     print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+     print(f"  Total optimization steps = {args.max_train_steps}")
+
+     if args.resume_from_checkpoint:
+         # New Code #
+         # Loads the DeepSpeed checkpoint from the specified path
+         prior_model, last_epoch, last_global_step = load_training_checkpoint(
+             sd_model,
+             pcdms_model,
+             **{"load_optimizer_states": True, "load_lr_scheduler_states": True},
+         )
+         print(f"Resumed from checkpoint: {args.resume_from_checkpoint}, global step: {last_global_step}")
+         starting_epoch = last_epoch
+         global_steps = last_global_step
+     else:
+         global_steps = 0
+         starting_epoch = 0
+
+     progress_bar = tqdm(range(global_steps, args.max_train_steps), initial=global_steps, desc="Steps",
+                         # Only show the progress bar once on each machine.
+                         disable=not accelerator.is_local_main_process, )
+
+     bsz = train_batch_size
+
+     if not finetune or train_steps == 0:
+         accelerator.wait_for_everyone()
+         accelerator.end_training()
+         return {k: v.cpu() for k, v in sd_model.state_dict().items()}
+
+     it = range(starting_epoch, args.num_train_epochs)
+     if is_app:
+         it = progress.tqdm(it, desc="Fine-tuning")
+     for epoch in it:
+         for step, batch in enumerate(train_dataloader):
+             with accelerator.accumulate(sd_model):
+                 with torch.no_grad():
+                     # Convert images to latent space
+                     latents = vae.encode(batch["source_target_image"].to(dtype=weight_dtype)).latent_dist.sample()
+                     latents = latents * vae.config.scaling_factor
+
+                     # Get the masked image latents
+                     masked_latents = vae.encode(batch["vae_source_mask_image"].to(dtype=weight_dtype)).latent_dist.sample()
+                     masked_latents = masked_latents * vae.config.scaling_factor
+
+                     bsz = batch["target_image"].size(dim=0)
+
+                     # mask
+                     mask1 = torch.ones((bsz, 1, int(args.img_height / 8), int(args.img_width / 8))).to(accelerator.device, dtype=weight_dtype)
+                     mask0 = torch.zeros((bsz, 1, int(args.img_height / 8), int(args.img_width / 8))).to(accelerator.device, dtype=weight_dtype)
+                     mask = torch.cat([mask1, mask0], dim=3)
+                     # Get the image embedding for conditioning
+                     cond_image_feature_p = image_encoder_p(batch["source_image"].to(accelerator.device, dtype=weight_dtype))
+                     cond_image_feature_p = (cond_image_feature_p.last_hidden_state)
+
+                     cond_image_feature_g = image_encoder_g(batch["target_image"].to(accelerator.device, dtype=weight_dtype), ).image_embeds
+                     cond_image_feature_g = cond_image_feature_g.unsqueeze(1)
+
+                     # Sample noise that we'll add to the latents
+                     noise = torch.randn_like(latents)
+                     if args.noise_offset:
+                         # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+                         noise += args.noise_offset * torch.randn(
+                             (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+                         )
+
+                     # Sample a random timestep for each image
+                     #timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (train_batch_size,), device=latents.device, )
+                     timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, )
+                     timesteps = timesteps.long()
+
+                     # Add noise to the latents according to the noise magnitude at each timestep (this is the forward diffusion process)
+                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                     #print(noisy_latents.size(), mask.size(), masked_latents.size())
+
+                     noisy_latents = torch.cat([noisy_latents, mask, masked_latents], dim=1)
+                     # Get the text embedding for conditioning
+
+                     cond_pose = batch["source_target_pose"].to(dtype=weight_dtype)
+
+                     #print(noisy_latents.size())
+                     #print(cond_image_feature_p.size())
+                     #print(cond_image_feature_g.size())
+                     #print(cond_pose.size())
+
+                 # Predict the noise residual
+                 model_pred = sd_model(noisy_latents, timesteps, cond_image_feature_p, cond_image_feature_g, cond_pose, )
+
+                 # Get the target for loss depending on the prediction type
+                 if noise_scheduler.config.prediction_type == "epsilon":
+                     target = noise
+                 elif noise_scheduler.config.prediction_type == "v_prediction":
+                     target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                 else:
+                     raise ValueError(
+                         f"Unknown prediction type {noise_scheduler.config.prediction_type}"
+                     )
+
+                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                 accelerator.backward(loss)
+                 if accelerator.sync_gradients:
+                     params_to_clip = sd_model.parameters()
+                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+             # Checks if the accelerator has performed an optimization step behind the scenes
+             if accelerator.sync_gradients:
+                 global_steps += 1
+
+                 if global_steps >= args.max_train_steps:
+                     break
+
+             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+             print(logs)
+             progress_bar.set_postfix(**logs)
+
+             progress_bar.update(1)
+
+     # Create the pipeline using the trained modules and save it.
+     accelerator.wait_for_everyone()
+     accelerator.end_training()
+
+     sd_model.unet.cpu()
+     sd_model.cpu()
+     del vae
+     del image_encoder_p
+     del image_encoder_g
+
+     if save_model:  #if global_steps % args.checkpointing_steps == 0 or global_steps == args.max_train_steps:
+         print('saving', modelId)
+
+         checkpoint_state_dict = {
+             "epoch": 0,
+             "module": {k: v.cpu() for k, v in sd_model.state_dict().items()},  #sd_model.state_dict(),
+         }
+         print(list(sd_model.state_dict().keys())[:20])
+         torch.save(checkpoint_state_dict, modelId+".pt")
+
+         del sd_model
+         gc.collect()
+         torch.cuda.empty_cache()
+         print('done train')
+         print(torch.cuda.memory_allocated()/1024**2)
+         return
+
+     # collect the weights on CPU before freeing the model
+     state_dict = {k: v.cpu() for k, v in sd_model.state_dict().items()}
+     del sd_model
+     gc.collect()
+     torch.cuda.empty_cache()
+     return state_dict
+
+
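The training step above frames pose transfer as inpainting on a doubled-width (source|target) canvas: the UNet's 9 input channels are the 4-channel noisy latents, a 1-channel mask that keeps the source half and inpaints the target half, and the 4-channel masked-image latents. A shape-only sketch, assuming args.img_width = args.img_height = 512:

    import torch

    bsz, h8, w8 = 2, 512 // 8, 512 // 8
    noisy_latents = torch.randn(bsz, 4, h8, 2 * w8)         # latents of the (source|target) canvas
    mask = torch.cat([torch.ones(bsz, 1, h8, w8),           # 1 = keep the source half
                      torch.zeros(bsz, 1, h8, w8)], dim=3)  # 0 = generate the target half
    masked_latents = torch.randn(bsz, 4, h8, 2 * w8)        # latents of (source|black)
    unet_in = torch.cat([noisy_latents, mask, masked_latents], dim=1)
    print(unet_in.shape)                                    # torch.Size([2, 9, 64, 128]), matching in_channels=9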
+ # Pose-transfer ===================================================================================================
+
+ device = "cuda"
+
+ class ImageProjModel(torch.nn.Module):
+     """SD model with image prompt"""
+     def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.):
+         super().__init__()
+
+         self.net = nn.Sequential(
+             nn.Linear(in_dim, hidden_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.LayerNorm(hidden_dim),
+             nn.Linear(hidden_dim, out_dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+ def image_grid(imgs, rows, cols):
+     assert len(imgs) == rows * cols
+     w, h = imgs[0].size
+     print(w, h)
+     grid = Image.new("RGB", size=(cols * w, rows * h))
+     grid_w, grid_h = grid.size
+
+     for i, img in enumerate(imgs):
+         grid.paste(img, box=(i % cols * w, i // cols * h))
+     return grid
+
+ def load_mydict(modelId, finetuned_model):
+     if save_model:
+         model_ckpt_path = modelId+'.pt'
+         model_sd = torch.load(model_ckpt_path, map_location="cpu")["module"]
+     else:
+         model_sd = finetuned_model  #torch.load(model_ckpt_path, map_location="cpu")["module"]
+
+     image_proj_model_dict = {}
+     pose_proj_dict = {}
+     unet_dict = {}
+     for k in model_sd.keys():
+         if k.startswith("pose_proj"):
+             pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
+
+         elif k.startswith("image_proj_model_p"):
+             image_proj_model_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
+         elif k.startswith("image_proj_model"):
+             image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
+
+         elif k.startswith("unet"):
+             unet_dict[k.replace("unet.", "")] = model_sd[k]
+         else:
+             print(k)
+     return image_proj_model_dict, pose_proj_dict, unet_dict
+
+
+ @spaces.GPU(duration=600)
+ def inference(modelId, in_image, in_pose, target_poses, inference_steps, finetuned_model, vae, unet, image_encoder, is_app=False):
+     print('start inference')
+     progress = gr.Progress(track_tqdm=True)
+
+     if not save_model:
+         finetuned_model = {k: v.cuda() for k, v in finetuned_model.items()}
+
+     device = "cuda"
+     pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1-base"
+     image_encoder_path = "facebook/dinov2-giant"
+     #model_ckpt_path = "./pcdms_ckpt.pt"  # ckpt path
+     model_ckpt_path = modelId+'.pt'
+
+     clip_image_processor = CLIPImageProcessor()
+     img_transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize([0.5], [0.5]),
+     ])
+
+     generator = torch.Generator(device=device).manual_seed(42)
+
+     """
+     unet = Stage2_InapintUNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, subfolder="unet", in_channels=9, low_cpu_mem_usage=False, ignore_mismatched_sizes=True).to(device)
+     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae").to(device, dtype=torch.float16)
+     image_encoder = Dinov2Model.from_pretrained(image_encoder_path).to(device, dtype=torch.float16)
+     """
+     noise_scheduler = DDIMScheduler(
+         num_train_timesteps=1000,
+         beta_start=0.00085,
+         beta_end=0.012,
+         beta_schedule="scaled_linear",
+         clip_sample=False,
+         set_alpha_to_one=False,
+         steps_offset=1,
+     )
+
+     unet = unet.to(device, dtype=torch.float16)
+     vae = vae.to(device, dtype=torch.float16)
+     image_encoder = image_encoder.to(device, dtype=torch.float16)
+
+     image_proj_model = ImageProjModel(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).to(dtype=torch.float16)
+     pose_proj_model = ControlNetConditioningEmbedding(
+         conditioning_embedding_channels=320,
+         block_out_channels=(16, 32, 96, 256),
+         conditioning_channels=3).to(device).to(dtype=torch.float16)
+
+     # load weight
+     print('loading', modelId)
+     image_proj_model_dict, pose_proj_dict, unet_dict = load_mydict(modelId, finetuned_model)
+     print('loaded', modelId)
+     image_proj_model.load_state_dict(image_proj_model_dict)
+     pose_proj_model.load_state_dict(pose_proj_dict)
+     unet.load_state_dict(unet_dict)
+
+     pipe = PCDMsPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", unet=unet, torch_dtype=torch.float16, scheduler=noise_scheduler, feature_extractor=None, safety_checker=None).to(device)
+
+     print('====================== model load finish ===================')
+
+     results = []
+     progress_bar = tqdm(range(len(target_poses)), initial=0, desc="Frames")
+
+     it = target_poses
+     if is_app:
+         it = progress.tqdm(it, desc="Pose Transfer")
+     for pose in it:
+
+         num_samples = 1
+         image_size = (512, 512)
+         s_img_path = 'imgs/'+input_img  # input image 1
+         #target_pose_img = 'imgs/pose_'+str(n)+'.png'  # input image 2
+
+         #t_pose = inference_pose(target_pose_img, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
+         #t_pose = Image.open(target_pose_img).convert("RGB").resize(image_size, Image.BICUBIC)
+         t_pose = pose.convert("RGB").resize(image_size, Image.BICUBIC)
+         #t_pose = resize_and_pad(pose.convert("RGB"))
+
+         #s_img = Image.open(s_img_path)
+         width_orig, height_orig = in_image.size
+         s_img = in_image.convert("RGB").resize(image_size, Image.BICUBIC)
+         #s_img = resize_and_pad(in_image.convert("RGB"))
+         black_image = Image.new("RGB", s_img.size, (0, 0, 0)).resize(image_size, Image.BICUBIC)
+
+         s_img_t_mask = Image.new("RGB", (s_img.width * 2, s_img.height))
+         s_img_t_mask.paste(s_img, (0, 0))
+         s_img_t_mask.paste(black_image, (s_img.width, 0))
+
+         #s_pose = inference_pose(s_img_path, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
+         #s_pose = Image.open('imgs/sm_pose.jpg').convert("RGB").resize(image_size, Image.BICUBIC)
+         s_pose = in_pose.convert("RGB").resize(image_size, Image.BICUBIC)
+         #s_pose = resize_and_pad(in_pose.convert("RGB"))
+         print('source image width: {}, height: {}'.format(s_pose.width, s_pose.height))
+         #t_pose = Image.open(target_pose_img).convert("RGB").resize(image_size, Image.BICUBIC)
+
+         st_pose = Image.new("RGB", (s_pose.width * 2, s_pose.height))
+         st_pose.paste(s_pose, (0, 0))
+         st_pose.paste(t_pose, (s_pose.width, 0))
+
+         clip_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
+         vae_image = torch.unsqueeze(img_transform(s_img_t_mask), 0)
+         cond_st_pose = torch.unsqueeze(img_transform(st_pose), 0)
+
+         mask1 = torch.ones((1, 1, int(image_size[0] / 8), int(image_size[1] / 8))).to(device, dtype=torch.float16)
+         mask0 = torch.zeros((1, 1, int(image_size[0] / 8), int(image_size[1] / 8))).to(device, dtype=torch.float16)
+         mask = torch.cat([mask1, mask0], dim=3)
+
+         with torch.inference_mode():
+             cond_pose = pose_proj_model(cond_st_pose.to(dtype=torch.float16, device=device))
+             simg_mask_latents = pipe.vae.encode(vae_image.to(device, dtype=torch.float16)).latent_dist.sample()
+             simg_mask_latents = simg_mask_latents * 0.18215
+
+             images_embeds = image_encoder(clip_s_img.to(device, dtype=torch.float16)).last_hidden_state
+             image_prompt_embeds = image_proj_model(images_embeds)
+             uncond_image_prompt_embeds = image_proj_model(torch.zeros_like(images_embeds))
+
+             bs_embed, seq_len, _ = image_prompt_embeds.shape
+             image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+             image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+             uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+             uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+             output, _ = pipe(
+                 simg_mask_latents=simg_mask_latents,
+                 mask=mask,
+                 cond_pose=cond_pose,
+                 prompt_embeds=image_prompt_embeds,
+                 negative_prompt_embeds=uncond_image_prompt_embeds,
+                 height=image_size[1],
+                 width=image_size[0]*2,
+                 num_images_per_prompt=num_samples,
+                 guidance_scale=2.0,
+                 generator=generator,
+                 num_inference_steps=inference_steps,
+             )
+
+         output = output.images[-1]
+
+         result = output.crop((image_size[0], 0, image_size[0] * 2, image_size[1]))
+         result = result.resize((width_orig, height_orig), Image.BICUBIC)
+         #result = remove_zero_pad(result)
+
+         if debug:
+             result.save('out/'+str(len(results))+'.png')
+         results.append(result)
+         progress_bar.update(1)
+
+     del unet
+     del vae
+     del image_encoder
+     del image_proj_model
+     del pose_proj_model
+
+     if not save_model:
+         del finetuned_model
+
+     gc.collect()
+     torch.cuda.empty_cache()
+     print(torch.cuda.memory_allocated()/1024**2)
+
+     return results
+
+
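inference mirrors the training layout: it renders the doubled-width (source|result) canvas, keeps only the right pane, and restores the character's original resolution (the padding to the requested output size happens later, in run_inference). A minimal sketch of that crop-and-restore step, assuming the 512x512 working size and a 640x480 character image:

    from PIL import Image

    image_size = (512, 512)
    width_orig, height_orig = 640, 480
    output = Image.new("RGB", (image_size[0] * 2, image_size[1]))  # stand-in for the pipeline output
    result = output.crop((image_size[0], 0, image_size[0] * 2, image_size[1]))  # right pane only
    result = result.resize((width_orig, height_orig), Image.BICUBIC)
    print(result.size)   # (640, 480)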
+ def gen_vid(frames, video_name, fps, codec):
+     progress = gr.Progress(track_tqdm=True)
+
+     frame = cv2.cvtColor(np.array(frames[0]), cv2.COLOR_RGB2BGR)
+     height, width, layers = frame.shape
+
+     #video = cv2.VideoWriter(video_name, 0, 1, (width, height))
+     if codec == 'mp4':
+         video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
+     else:
+         video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'VP90'), fps, (width, height))
+
+     for r in progress.tqdm(frames, desc="Creating video"):
+         image = cv2.cvtColor(np.array(r), cv2.COLOR_RGB2BGR)
+         video.write(image)
+
+     #cv2.destroyAllWindows()
+     video.release()
+
+
+ def run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True, finetune=True, is_app=False):
+     print("==== Load Models ====")
+     dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
+
+     print("==== Pose Detection ====")
+     if resize_inputs:
+         resize = 'target'
+     else:
+         resize = 'none'
+     in_img, in_pose, train_imgs, train_poses, target_poses = prepare_inputs(images, video_path, fps, bg_remove, dwpose, rembg_session, resize=resize, is_app=is_app)
+
+     if save_model:
+         train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
+         print('next')
+         results = inference("fine_tuned_pcdms", in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
+
+     else:
+         print("==== Finetuning ====")
+         finetuned_model = train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
+
+         print("==== Pose Transfer ====")
+         results = inference("fine_tuned_pcdms", in_img, in_pose, target_poses, inference_steps, finetuned_model, vae, unet, image_encoder_p, is_app)
+
+     return results
+
+
+ def run_train(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
+     finetune = True
+     is_app = True
+     images = [img[0] for img in images]
+
+     dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
+
+     if resize_inputs:
+         resize = 'target'
+     else:
+         resize = 'none'
+
+     in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
+
+     train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
+
+
+ def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
+     finetune = True
+     is_app = True
+
+     dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
+
+     if not os.path.exists(modelId+".pt"):
+         run_train(images, train_steps, modelId, bg_remove, resize_inputs)
+
+     images = [img[0] for img in images]
+     in_img = images[0]
+
+     in_img, target_poses, in_pose = prepare_inputs_inference(in_img, video_path, fps, dwpose, rembg_session, bg_remove, 'target', is_app)
+
+     results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
+     #urls = save_temp_imgs(results)
+
+     if should_gen_vid:
+         if debug:
+             gen_vid(results, out_vid+'.mp4', fps, 'mp4')
+         else:
+             gen_vid(results, out_vid+'.webm', fps, 'webm')
+
+     # postprocessing
+     results = [removebg(img, rembg_session, True) for img in results]
+     results = [img_pad(img, img_width, img_height, True) for img in results]
+
+     print("Done!")
+
+     return out_vid+'.webm', results, getThumbnails(results)
+
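The postprocessing tail above is the behavior the commit message names: each generated frame is matted out and then padded to the caller's img_width x img_height. A self-contained sketch of the padding step alone (skipping the rembg matting), with stand-in frames:

    from PIL import Image

    results = [Image.new("RGB", (512, 512), "red") for _ in range(3)]
    img_width, img_height = 1920, 1080         # the new run_inference defaults
    results = [img_pad(img, img_width, img_height, True) for img in results]
    print(results[0].size, results[0].mode)    # (1920, 1080) RGBA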
+ def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True):
+
+     images = [img[0] for img in images]
+
+     results = run(images, video_path, train_steps, inference_steps, fps, bg_remove, resize_inputs, finetune=True, is_app=True)
+
+     print("==== Video generation ====")
+     out_vid = f"out_{uuid.uuid4()}"
+
+     if debug:
+         gen_vid(results, out_vid+'.mp4', fps, 'mp4')
+     else:
+         gen_vid(results, out_vid+'.webm', fps, 'webm')
+
+     print("Done!")
+
+     return out_vid+'.webm', results
+
+
+ """
+ train_steps = 100
+ inference_steps = 10
+ fps = 12
+ """

+ """
+ iface = gr.Interface(
+     fn=run,
+     inputs=[
+         gr.Gallery(type="pil", label="Images of the Character"),
+         gr.Video(label="Motion-Capture Video"),
+         gr.Number(label="Training steps", value=100),
+         gr.Number(label="Inference steps", value=10),
+         gr.Number(label="Output frame rate", value=12),
+         gr.Checkbox(label="Remove background", value=False),
+     ],
+     outputs=[gr.Video(label="Result"), gr.Gallery(type="pil", label="Frames")],
+     title="Keyframes AI",
+     description="Upload images of your character and a motion-capture video to generate an animation of the character.",
+ )
+ """

main.py CHANGED
@@ -167,6 +167,17 @@ def load_models():

  #load_models()

+ def img_pad(img, tw, th, transparent=False):
+     img.thumbnail((tw, th))
+     if transparent:
+         new_img = Image.new('RGBA', (tw, th), (0, 0, 0, 0))
+     else:
+         new_img = Image.new("RGB", (tw, th), (0, 0, 0))
+     left = (tw - img.width) // 2
+     top = (th - img.height) // 2
+     new_img.paste(img, (left, top))
+     return new_img
+

  def resize_and_pad(img, target_img):
      tw, th = target_img.size
@@ -180,12 +191,8 @@ def resize_and_pad(img, target_img):
      img = img.resize((tw, th), Image.BICUBIC)

      tw, th = target_img.size
-     new_img = Image.new("RGB", (tw, th), (0, 0, 0))
-     left = (tw - img.width) // 2
-     top = (th - img.height) // 2
-     new_img.paste(img, (left, top))

-     return new_img
+     return img_pad(img, tw, th)


  def remove_zero_pad(image):
@@ -1105,7 +1112,7 @@ def run_train(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
      train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)


- def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
+ def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
      finetune=True
      is_app=True

@@ -1130,7 +1137,9 @@ def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
          gen_vid(results, out_vid+'.webm', fps, 'webm')


+     # postprocessing
      results = [removebg(img, rembg_session, True) for img in results]
+     results = [img_pad(img, img_width, img_height, True) for img in results]

      print("Done!")
