acmyu commited on
Commit
05511bd
·
1 Parent(s): 781071b

remove input resize

Browse files
Files changed (1) hide show
  1. main.py +6 -10
main.py CHANGED
@@ -283,7 +283,7 @@ def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
283
  return in_img, in_pose, train_imgs, train_poses
284
 
285
 
286
- def prepare_inputs_inference(in_img, in_vid, fps, dwpose, rembg_session, bg_remove, resize='target', is_app=False):
287
  progress=gr.Progress(track_tqdm=True)
288
 
289
  print("prepare_inputs_inference")
@@ -334,7 +334,7 @@ def prepare_inputs_inference(in_img, in_vid, fps, dwpose, rembg_session, bg_remo
334
 
335
  target_poses_cropped = []
336
  for tpose in target_poses:
337
- if resize=='target':
338
  tpose = tpose.crop((max_left, max_top, max_right, max_bottom))
339
  tpose = ImageOps.expand(tpose, border=int(tpose.width*0.2), fill=(0,0,0))
340
 
@@ -348,11 +348,11 @@ def prepare_inputs_inference(in_img, in_vid, fps, dwpose, rembg_session, bg_remo
348
  return in_img, target_poses_cropped, in_pose, target_poses_coords, frames
349
 
350
 
351
- def prepare_inputs(images, in_vid, fps, bg_remove, dwpose, rembg_session, resize='target', is_app=False):
352
 
353
  in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
354
 
355
- in_img, target_poses_cropped, _, _, _ = prepare_inputs_inference(in_img, in_vid, fps, dwpose, rembg_session, bg_remove, resize, is_app)
356
 
357
 
358
  return in_img, in_pose, train_imgs, train_poses, target_poses_cropped
@@ -1087,11 +1087,7 @@ def run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remo
1087
  dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1088
 
1089
  print("==== Pose Detection ====")
1090
- if resize_inputs:
1091
- resize = 'target'
1092
- else:
1093
- resize = 'none'
1094
- in_img, in_pose, train_imgs, train_poses, target_poses = prepare_inputs(images, video_path, fps, bg_remove, dwpose, rembg_session, resize=resize, is_app=is_app)
1095
 
1096
  if save_model:
1097
  train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
@@ -1138,7 +1134,7 @@ def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=1
1138
  images = [img[0] for img in images]
1139
  in_img = images[0]
1140
 
1141
- in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, video_path, fps, dwpose, rembg_session, bg_remove, 'target', is_app)
1142
 
1143
  results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1144
  #urls = save_temp_imgs(results)
 
283
  return in_img, in_pose, train_imgs, train_poses
284
 
285
 
286
+ def prepare_inputs_inference(in_img, in_vid, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app=False):
287
  progress=gr.Progress(track_tqdm=True)
288
 
289
  print("prepare_inputs_inference")
 
334
 
335
  target_poses_cropped = []
336
  for tpose in target_poses:
337
+ if resize_inputs:
338
  tpose = tpose.crop((max_left, max_top, max_right, max_bottom))
339
  tpose = ImageOps.expand(tpose, border=int(tpose.width*0.2), fill=(0,0,0))
340
 
 
348
  return in_img, target_poses_cropped, in_pose, target_poses_coords, frames
349
 
350
 
351
+ def prepare_inputs(images, in_vid, fps, bg_remove, dwpose, rembg_session, resize_inputs, is_app=False):
352
 
353
  in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
354
 
355
+ in_img, target_poses_cropped, _, _, _ = prepare_inputs_inference(in_img, in_vid, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
356
 
357
 
358
  return in_img, in_pose, train_imgs, train_poses, target_poses_cropped
 
1087
  dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1088
 
1089
  print("==== Pose Detection ====")
1090
+ in_img, in_pose, train_imgs, train_poses, target_poses = prepare_inputs(images, video_path, fps, bg_remove, dwpose, rembg_session, resize_inputs, is_app=is_app)
 
 
 
 
1091
 
1092
  if save_model:
1093
  train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
 
1134
  images = [img[0] for img in images]
1135
  in_img = images[0]
1136
 
1137
+ in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, video_path, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
1138
 
1139
  results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1140
  #urls = save_temp_imgs(results)