Spaces:
Paused
Paused
transparent background
Browse files
main.py
CHANGED
|
@@ -129,7 +129,7 @@ def save_temp_imgs(imgs):
|
|
| 129 |
|
| 130 |
def getThumbnails(imgs):
|
| 131 |
thumbs = []
|
| 132 |
-
thumb_size = (
|
| 133 |
for img in imgs:
|
| 134 |
th = img.copy()
|
| 135 |
th.thumbnail(thumb_size)
|
|
@@ -248,14 +248,19 @@ def extract_frames(video_path, fps):
|
|
| 248 |
return frames
|
| 249 |
|
| 250 |
|
| 251 |
-
def removebg(img, rembg_session):
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
out = rembg.remove(img, session=rembg_session)
|
| 254 |
result.paste(out, mask=out)
|
| 255 |
return result
|
| 256 |
|
| 257 |
|
| 258 |
def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
|
|
|
|
| 259 |
if bg_remove:
|
| 260 |
images = [removebg(img, rembg_session) for img in images]
|
| 261 |
|
|
@@ -270,7 +275,7 @@ def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
|
|
| 270 |
return in_img, in_pose, train_imgs, train_poses
|
| 271 |
|
| 272 |
|
| 273 |
-
def prepare_inputs_inference(in_img, in_vid, fps, dwpose, resize='target', is_app=False):
|
| 274 |
progress=gr.Progress(track_tqdm=True)
|
| 275 |
|
| 276 |
print("prepare_inputs_inference")
|
|
@@ -278,7 +283,10 @@ def prepare_inputs_inference(in_img, in_vid, fps, dwpose, resize='target', is_ap
|
|
| 278 |
in_pose = get_pose(in_img, dwpose, "in_pose.png")
|
| 279 |
|
| 280 |
frames = extract_frames(in_vid, fps)
|
| 281 |
-
|
|
|
|
|
|
|
|
|
|
| 282 |
if debug:
|
| 283 |
for i, frame in enumerate(frames):
|
| 284 |
frame.save("out/frame_"+str(i)+".png")
|
|
@@ -317,14 +325,14 @@ def prepare_inputs_inference(in_img, in_vid, fps, dwpose, resize='target', is_ap
|
|
| 317 |
tpose.save("out/"+"tar_pose"+str(len(target_poses_cropped))+".png")
|
| 318 |
target_poses_cropped.append(tpose)
|
| 319 |
|
| 320 |
-
return target_poses_cropped, in_pose
|
| 321 |
|
| 322 |
|
| 323 |
def prepare_inputs(images, in_vid, fps, bg_remove, dwpose, rembg_session, resize='target', is_app=False):
    """Prepare training and inference inputs from the user-supplied media.

    Runs the training-side preprocessing (reference image/pose plus the
    per-image training pairs), then the inference-side preprocessing
    (cropped target poses from the driving video), and returns everything
    the caller needs in one tuple.
    """
    # Training-side inputs derived from the uploaded images.
    in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(
        images, bg_remove, dwpose, rembg_session
    )

    # Inference-side inputs: only the cropped target poses are kept here
    # (the returned input pose duplicates in_pose and is discarded).
    target_poses_cropped, _ = prepare_inputs_inference(in_img, in_vid, fps, dwpose, resize, is_app)

    return in_img, in_pose, train_imgs, train_poses, target_poses_cropped
|
|
@@ -1110,7 +1118,7 @@ def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=1
|
|
| 1110 |
images = [img[0] for img in images]
|
| 1111 |
in_img = images[0]
|
| 1112 |
|
| 1113 |
-
target_poses, in_pose = prepare_inputs_inference(in_img, video_path, fps, dwpose, 'target', is_app)
|
| 1114 |
|
| 1115 |
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1116 |
#urls = save_temp_imgs(results)
|
|
@@ -1121,6 +1129,9 @@ def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=1
|
|
| 1121 |
else:
|
| 1122 |
gen_vid(results, out_vid+'.webm', fps, 'webm')
|
| 1123 |
|
|
|
|
|
|
|
|
|
|
| 1124 |
print("Done!")
|
| 1125 |
|
| 1126 |
return out_vid+'.webm', results, getThumbnails(results)
|
|
|
|
| 129 |
|
| 130 |
def getThumbnails(imgs):
|
| 131 |
thumbs = []
|
| 132 |
+
thumb_size = (512, 512)
|
| 133 |
for img in imgs:
|
| 134 |
th = img.copy()
|
| 135 |
th.thumbnail(thumb_size)
|
|
|
|
| 248 |
return frames
|
| 249 |
|
| 250 |
|
| 251 |
+
def removebg(img, rembg_session, transparent=False):
    """Remove the background from *img* with rembg.

    Args:
        img: input PIL image.
        rembg_session: rembg session object, reused across calls.
        transparent: if True, return the cutout on a transparent
            background (RGBA); otherwise flatten it onto white (RGB).

    Returns:
        A new PIL image with the background removed.
    """
    # rembg returns an RGBA cutout whose alpha channel masks the foreground.
    out = rembg.remove(img, session=rembg_session)
    if transparent:
        # BUGFIX: the previous code pasted `out` onto a fully transparent
        # canvas using `out` itself as the mask. Image.paste applies the
        # mask to every band, including alpha, so soft edges ended up with
        # alpha = a*a (squared) — darker, harder fringes. The rembg output
        # is already exactly the transparent-background result we want.
        return out
    # Opaque variant: composite the cutout onto a white RGB background.
    result = Image.new("RGB", img.size, "#ffffff")
    result.paste(out, mask=out)
    return result
|
| 260 |
|
| 261 |
|
| 262 |
def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
|
| 263 |
+
print("remove background", bg_remove)
|
| 264 |
if bg_remove:
|
| 265 |
images = [removebg(img, rembg_session) for img in images]
|
| 266 |
|
|
|
|
| 275 |
return in_img, in_pose, train_imgs, train_poses
|
| 276 |
|
| 277 |
|
| 278 |
+
def prepare_inputs_inference(in_img, in_vid, fps, dwpose, rembg_session, bg_remove, resize='target', is_app=False):
|
| 279 |
progress=gr.Progress(track_tqdm=True)
|
| 280 |
|
| 281 |
print("prepare_inputs_inference")
|
|
|
|
| 283 |
in_pose = get_pose(in_img, dwpose, "in_pose.png")
|
| 284 |
|
| 285 |
frames = extract_frames(in_vid, fps)
|
| 286 |
+
print("remove background", bg_remove)
|
| 287 |
+
if bg_remove:
|
| 288 |
+
in_img = removebg(in_img, rembg_session)
|
| 289 |
+
#frames = [removebg(img, rembg_session) for img in frames]
|
| 290 |
if debug:
|
| 291 |
for i, frame in enumerate(frames):
|
| 292 |
frame.save("out/frame_"+str(i)+".png")
|
|
|
|
| 325 |
tpose.save("out/"+"tar_pose"+str(len(target_poses_cropped))+".png")
|
| 326 |
target_poses_cropped.append(tpose)
|
| 327 |
|
| 328 |
+
return in_img, target_poses_cropped, in_pose
|
| 329 |
|
| 330 |
|
| 331 |
def prepare_inputs(images, in_vid, fps, bg_remove, dwpose, rembg_session, resize='target', is_app=False):
    """Prepare every input needed for training and inference.

    First runs the training-side preprocessing, then the inference-side
    preprocessing — which may replace the reference image (e.g. after
    background removal) — and bundles the results for the caller.
    """
    # Training-side inputs: reference image/pose plus training image/pose pairs.
    train_prep = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
    in_img, in_pose, train_imgs, train_poses = train_prep

    # Inference-side inputs: cropped target poses; the possibly
    # background-removed reference image supersedes the training one.
    in_img, target_poses_cropped, _ = prepare_inputs_inference(
        in_img, in_vid, fps, dwpose, rembg_session, bg_remove, resize, is_app
    )

    return in_img, in_pose, train_imgs, train_poses, target_poses_cropped
|
|
|
|
| 1118 |
images = [img[0] for img in images]
|
| 1119 |
in_img = images[0]
|
| 1120 |
|
| 1121 |
+
in_img, target_poses, in_pose = prepare_inputs_inference(in_img, video_path, fps, dwpose, rembg_session, bg_remove, 'target', is_app)
|
| 1122 |
|
| 1123 |
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1124 |
#urls = save_temp_imgs(results)
|
|
|
|
| 1129 |
else:
|
| 1130 |
gen_vid(results, out_vid+'.webm', fps, 'webm')
|
| 1131 |
|
| 1132 |
+
|
| 1133 |
+
results = [removebg(img, rembg_session, True) for img in results]
|
| 1134 |
+
|
| 1135 |
print("Done!")
|
| 1136 |
|
| 1137 |
return out_vid+'.webm', results, getThumbnails(results)
|