acmyu committed on
Commit
28c5a09
·
1 Parent(s): ac1ffdc

transparent background

Browse files
Files changed (1) hide show
  1. main.py +19 -8
main.py CHANGED
@@ -129,7 +129,7 @@ def save_temp_imgs(imgs):
129
 
130
  def getThumbnails(imgs):
131
  thumbs = []
132
- thumb_size = (256, 256)
133
  for img in imgs:
134
  th = img.copy()
135
  th.thumbnail(thumb_size)
@@ -248,14 +248,19 @@ def extract_frames(video_path, fps):
248
  return frames
249
 
250
 
251
- def removebg(img, rembg_session):
252
- result = Image.new("RGB", img.size, "#ffffff")
 
 
 
 
253
  out = rembg.remove(img, session=rembg_session)
254
  result.paste(out, mask=out)
255
  return result
256
 
257
 
258
  def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
 
259
  if bg_remove:
260
  images = [removebg(img, rembg_session) for img in images]
261
 
@@ -270,7 +275,7 @@ def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
270
  return in_img, in_pose, train_imgs, train_poses
271
 
272
 
273
- def prepare_inputs_inference(in_img, in_vid, fps, dwpose, resize='target', is_app=False):
274
  progress=gr.Progress(track_tqdm=True)
275
 
276
  print("prepare_inputs_inference")
@@ -278,7 +283,10 @@ def prepare_inputs_inference(in_img, in_vid, fps, dwpose, resize='target', is_ap
278
  in_pose = get_pose(in_img, dwpose, "in_pose.png")
279
 
280
  frames = extract_frames(in_vid, fps)
281
- #frames = [removebg(img, rembg_session) for img in frames]
 
 
 
282
  if debug:
283
  for i, frame in enumerate(frames):
284
  frame.save("out/frame_"+str(i)+".png")
@@ -317,14 +325,14 @@ def prepare_inputs_inference(in_img, in_vid, fps, dwpose, resize='target', is_ap
317
  tpose.save("out/"+"tar_pose"+str(len(target_poses_cropped))+".png")
318
  target_poses_cropped.append(tpose)
319
 
320
- return target_poses_cropped, in_pose
321
 
322
 
323
  def prepare_inputs(images, in_vid, fps, bg_remove, dwpose, rembg_session, resize='target', is_app=False):
324
 
325
  in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
326
 
327
- target_poses_cropped, _ = prepare_inputs_inference(in_img, in_vid, fps, dwpose, resize, is_app)
328
 
329
 
330
  return in_img, in_pose, train_imgs, train_poses, target_poses_cropped
@@ -1110,7 +1118,7 @@ def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=1
1110
  images = [img[0] for img in images]
1111
  in_img = images[0]
1112
 
1113
- target_poses, in_pose = prepare_inputs_inference(in_img, video_path, fps, dwpose, 'target', is_app)
1114
 
1115
  results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1116
  #urls = save_temp_imgs(results)
@@ -1121,6 +1129,9 @@ def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=1
1121
  else:
1122
  gen_vid(results, out_vid+'.webm', fps, 'webm')
1123
 
 
 
 
1124
  print("Done!")
1125
 
1126
  return out_vid+'.webm', results, getThumbnails(results)
 
129
 
130
  def getThumbnails(imgs):
131
  thumbs = []
132
+ thumb_size = (512, 512)
133
  for img in imgs:
134
  th = img.copy()
135
  th.thumbnail(thumb_size)
 
248
  return frames
249
 
250
 
251
def removebg(img, rembg_session, transparent=False):
    """Remove the background from an image using rembg.

    Args:
        img: source PIL image.
        rembg_session: pre-created rembg session (reused so the model is
            not reloaded per call).
        transparent: if True, return an RGBA image on a fully transparent
            background; otherwise flatten the cutout onto white RGB.

    Returns:
        A new PIL image with the background removed.
    """
    out = rembg.remove(img, session=rembg_session)
    if transparent:
        result = Image.new('RGBA', img.size, (0, 0, 0, 0))
        # alpha_composite keeps the cutout's partial alpha intact;
        # paste(out, mask=out) would blend semi-transparent edge pixels
        # toward the black backing color, darkening soft edges.
        result.alpha_composite(out.convert('RGBA'))
    else:
        result = Image.new("RGB", img.size, "#ffffff")
        result.paste(out, mask=out)
    return result
260
 
261
 
262
  def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
263
+ print("remove background", bg_remove)
264
  if bg_remove:
265
  images = [removebg(img, rembg_session) for img in images]
266
 
 
275
  return in_img, in_pose, train_imgs, train_poses
276
 
277
 
278
+ def prepare_inputs_inference(in_img, in_vid, fps, dwpose, rembg_session, bg_remove, resize='target', is_app=False):
279
  progress=gr.Progress(track_tqdm=True)
280
 
281
  print("prepare_inputs_inference")
 
283
  in_pose = get_pose(in_img, dwpose, "in_pose.png")
284
 
285
  frames = extract_frames(in_vid, fps)
286
+ print("remove background", bg_remove)
287
+ if bg_remove:
288
+ in_img = removebg(in_img, rembg_session)
289
+ #frames = [removebg(img, rembg_session) for img in frames]
290
  if debug:
291
  for i, frame in enumerate(frames):
292
  frame.save("out/frame_"+str(i)+".png")
 
325
  tpose.save("out/"+"tar_pose"+str(len(target_poses_cropped))+".png")
326
  target_poses_cropped.append(tpose)
327
 
328
+ return in_img, target_poses_cropped, in_pose
329
 
330
 
331
def prepare_inputs(images, in_vid, fps, bg_remove, dwpose, rembg_session, resize='target', is_app=False):
    """Prepare all training and inference inputs in one pass.

    First runs the training-side preparation (reference image, its pose,
    training frames and their poses), then the inference-side preparation
    on the reference image and the driving video.

    NOTE(review): when bg_remove is set, prepare_inputs_train removes the
    background from the training images and prepare_inputs_inference
    removes it again from in_img — confirm the double pass is intended.
    """
    train_data = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
    in_img, in_pose, train_imgs, train_poses = train_data

    in_img, target_poses_cropped, _ = prepare_inputs_inference(
        in_img, in_vid, fps, dwpose, rembg_session, bg_remove, resize, is_app
    )

    return in_img, in_pose, train_imgs, train_poses, target_poses_cropped
 
1118
  images = [img[0] for img in images]
1119
  in_img = images[0]
1120
 
1121
+ in_img, target_poses, in_pose = prepare_inputs_inference(in_img, video_path, fps, dwpose, rembg_session, bg_remove, 'target', is_app)
1122
 
1123
  results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1124
  #urls = save_temp_imgs(results)
 
1129
  else:
1130
  gen_vid(results, out_vid+'.webm', fps, 'webm')
1131
 
1132
+
1133
+ results = [removebg(img, rembg_session, True) for img in results]
1134
+
1135
  print("Done!")
1136
 
1137
  return out_vid+'.webm', results, getThumbnails(results)