acmyu committed on
Commit
d3ec508
·
1 Parent(s): a3692c7

api for generating frames from poses

Browse files
Files changed (2) hide show
  1. app.py +7 -1
  2. main.py +27 -0
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from main import run_app, run_train, run_inference
2
 
3
  import spaces
4
  from PIL import Image
@@ -11,6 +11,7 @@ with gr.Blocks() as demo:
11
  with gr.Column():
12
  char_imgs = gr.Gallery(type="pil", label="Images of the Character")
13
  mocap = gr.Video(label="Motion-Capture Video")
 
14
  tr_steps = gr.Number(label="Training steps", value=10)
15
  inf_steps = gr.Number(label="Inference steps", value=10)
16
  fps = gr.Number(label="Output frame rate", value=12)
@@ -21,6 +22,7 @@ with gr.Blocks() as demo:
21
  img_height = gr.Number(label="Output height", value=1080)
22
  train_btn = gr.Button(value="Train")
23
  inference_btn = gr.Button(value="Inference")
 
24
  submit_btn = gr.Button(value="Generate")
25
  with gr.Column():
26
  animation = gr.Video(label="Result")
@@ -41,6 +43,10 @@ with gr.Blocks() as demo:
41
  run_inference, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, modelId, img_width, img_height, remove_bg, resize_inputs], outputs=[animation, frames, frames_thumb, pose_coords, reference]
42
  )
43
 
 
 
 
 
44
 
45
  demo.launch(share=True)
46
 
 
1
+ from main import run_app, run_train, run_inference, run_generate_frame
2
 
3
  import spaces
4
  from PIL import Image
 
11
  with gr.Column():
12
  char_imgs = gr.Gallery(type="pil", label="Images of the Character")
13
  mocap = gr.Video(label="Motion-Capture Video")
14
+ poses = gr.JSON(label="Pose Coordinates")
15
  tr_steps = gr.Number(label="Training steps", value=10)
16
  inf_steps = gr.Number(label="Inference steps", value=10)
17
  fps = gr.Number(label="Output frame rate", value=12)
 
22
  img_height = gr.Number(label="Output height", value=1080)
23
  train_btn = gr.Button(value="Train")
24
  inference_btn = gr.Button(value="Inference")
25
+ generate_frame_btn = gr.Button(value="Generate Frame")
26
  submit_btn = gr.Button(value="Generate")
27
  with gr.Column():
28
  animation = gr.Video(label="Result")
 
43
  run_inference, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, modelId, img_width, img_height, remove_bg, resize_inputs], outputs=[animation, frames, frames_thumb, pose_coords, reference]
44
  )
45
 
46
+ generate_frame_btn.click(
47
+ run_generate_frame, inputs=[char_imgs, poses, tr_steps, inf_steps, modelId, img_width, img_height, remove_bg, resize_inputs], outputs=[frames, frames_thumb]
48
+ )
49
+
50
 
51
  demo.launch(share=True)
52
 
main.py CHANGED
@@ -1155,6 +1155,33 @@ def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=1
1155
  return out_vid+'.webm', results, getThumbnails(results), target_poses_coords, getThumbnails(orig_frames)
1156
 
1157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1158
  def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True):
1159
 
1160
  images = [img[0] for img in images]
 
1155
  return out_vid+'.webm', results, getThumbnails(results), target_poses_coords, getThumbnails(orig_frames)
1156
 
1157
 
1158
def run_generate_frame(images, target_poses, train_steps=100, inference_steps=10, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
    """Generate character frames for the given target poses.

    The first character image is used as the reference: its pose is
    extracted with ``get_pose`` and passed, together with ``target_poses``,
    to ``inference``. If no ``<modelId>.pt`` checkpoint exists on disk,
    ``run_train`` is invoked first.

    Args:
        images: Character images; each entry is indexed with ``[0]`` to get
            the image itself (presumably Gradio Gallery ``(image, caption)``
            pairs -- TODO confirm against the caller).
        target_poses: Pose data forwarded unchanged to ``inference``.
        train_steps: Forwarded to ``run_train`` when training is needed.
        inference_steps: Forwarded to ``inference``.
        modelId: Checkpoint name; ``modelId + ".pt"`` is the file checked.
        img_width: Currently unused (the padding step is disabled).
        img_height: Currently unused (the padding step is disabled).
        bg_remove: Forwarded to ``run_train`` only; it does not gate the
            background removal applied to the generated frames below.
        resize_inputs: Forwarded to ``run_train``.

    Returns:
        A ``(results, thumbnails)`` tuple: the background-removed frames and
        ``getThumbnails(results)``.

    Raises:
        ValueError: If ``images`` is empty or ``None``.
    """
    # Fail fast, before the expensive model load / training, instead of
    # hitting an opaque IndexError or TypeError at ``images[0]`` later on.
    if not images:
        raise ValueError("At least one character image is required to generate frames.")

    is_app = True

    print(target_poses)

    dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()

    # Train only when no checkpoint for this model id exists yet. Note that
    # run_train receives the raw gallery entries, not the unpacked images.
    if not os.path.exists(modelId + ".pt"):
        run_train(images, train_steps, modelId, bg_remove, resize_inputs)

    images = [img[0] for img in images]
    in_img = images[0]
    in_pose, _ = get_pose(in_img, dwpose, "in_pose.png")

    results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)

    # Post-processing: background removal is always applied to the outputs.
    # NOTE(review): padding to (img_width, img_height) via img_pad is
    # intentionally disabled here, so those two parameters have no effect.
    results = [removebg(img, rembg_session, True) for img in results]

    print("Done!")

    return results, getThumbnails(results)
1183
+
1184
+
1185
  def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True):
1186
 
1187
  images = [img[0] for img in images]