acmyu commited on
Commit
3dcc757
·
1 Parent(s): 604a6a7

try multiprocessing

Browse files
Files changed (1) hide show
  1. main.py +23 -32
main.py CHANGED
@@ -65,7 +65,8 @@ import json
65
 
66
  from huggingface_hub import hf_hub_download, HfApi
67
  from numba import cuda
68
- from multiprocessing import Pool
 
69
 
70
  # Inputs ===================================================================================================
71
 
@@ -1111,7 +1112,7 @@ def run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remo
1111
  return results
1112
 
1113
 
1114
- def run_train(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
1115
  finetune=True
1116
  is_app=True
1117
  images = [img[0] for img in images]
@@ -1127,8 +1128,18 @@ def run_train(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=Tru
1127
 
1128
  train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1129
 
 
 
1130
 
1131
- def run_inference(images, video_path, frames, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
 
 
 
 
 
 
 
 
1132
  finetune=True
1133
  is_app=True
1134
 
@@ -1160,16 +1171,17 @@ def run_inference(images, video_path, frames, train_steps=100, inference_steps=1
1160
  #results = [img_pad(img, img_width, img_height, True) for img in results]
1161
 
1162
  print("Done!")
1163
-
1164
  return out_vid+'.webm', results, getThumbnails(results), target_poses_coords, getThumbnails(orig_frames)
1165
-
1166
 
1167
- def run_generate_frame(images, target_poses, train_steps=100, inference_steps=10, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
 
 
 
 
1168
  finetune=True
1169
  is_app=True
1170
 
1171
-
1172
-
1173
  dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1174
 
1175
  if not os.path.exists(modelId+".pt"):
@@ -1198,6 +1210,9 @@ def run_generate_frame(images, target_poses, train_steps=100, inference_steps=10
1198
 
1199
  return results, getThumbnails(results)
1200
 
 
 
 
1201
 
1202
  def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True):
1203
 
@@ -1246,33 +1261,9 @@ def interpolate_frames(frame1, frame2, times_to_interp):
1246
  def run_interpolate_frames(frame1, frame2, times_to_interp):
1247
  with Pool() as pool:
1248
  results = pool.starmap(interpolate_frames, [(frame1, frame2, times_to_interp)])
1249
- print(results[0])
1250
  return results[0]
1251
 
1252
 
1253
- """
1254
- train_steps = 100
1255
- inference_steps = 10
1256
- fps = 12
1257
- """
1258
-
1259
- """
1260
- iface = gr.Interface(
1261
- fn=run,
1262
- inputs=[
1263
- gr.Gallery(type="pil", label="Images of the Character"),
1264
- gr.Video(label="Motion-Capture Video"),
1265
- gr.Number(label="Training steps", value=100),
1266
- gr.Number(label="Inference steps", value=10),
1267
- gr.Number(label="Output frame rate", value=12),
1268
- gr.Checkbox(label="Remove background", value=False),
1269
- ],
1270
- outputs=[gr.Video(label="Result"), gr.Gallery(type="pil", label="Frames")],
1271
- title="Keyframes AI",
1272
- description="Upload images of your character and a motion-capture video to generate an animation of the character.",
1273
- )
1274
- """
1275
-
1276
 
1277
 
1278
 
 
65
 
66
  from huggingface_hub import hf_hub_download, HfApi
67
  from numba import cuda
68
+ from multiprocessing import Pool, Process, Queue
69
+ import torch.multiprocessing as mp
70
 
71
  # Inputs ===================================================================================================
72
 
 
1112
  return results
1113
 
1114
 
1115
+ def run_train_impl(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
1116
  finetune=True
1117
  is_app=True
1118
  images = [img[0] for img in images]
 
1128
 
1129
  train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1130
 
1131
+ def run_train(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
1132
+ run_train_impl(images, train_steps, modelId, bg_remove, resize_inputs)
1133
 
1134
+ """
1135
+ mp.set_start_method('spawn', force=True)
1136
+ p = mp.Process(target=run_train_impl, args=(images, train_steps, modelId, bg_remove, resize_inputs))
1137
+ p.start()
1138
+ p.join()
1139
+ """
1140
+
1141
+
1142
+ def run_inference_impl(images, video_path, frames, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
1143
  finetune=True
1144
  is_app=True
1145
 
 
1171
  #results = [img_pad(img, img_width, img_height, True) for img in results]
1172
 
1173
  print("Done!")
1174
+
1175
  return out_vid+'.webm', results, getThumbnails(results), target_poses_coords, getThumbnails(orig_frames)
 
1176
 
1177
+ def run_inference(images, video_path, frames, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
1178
+ return run_inference_impl(images, video_path, frames, train_steps, inference_steps, fps, modelId, img_width, img_height, bg_remove, resize_inputs)
1179
+
1180
+
1181
+ def generate_frame(images, target_poses, train_steps=100, inference_steps=10, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
1182
  finetune=True
1183
  is_app=True
1184
 
 
 
1185
  dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1186
 
1187
  if not os.path.exists(modelId+".pt"):
 
1210
 
1211
  return results, getThumbnails(results)
1212
 
1213
+ def run_generate_frame(images, target_poses, train_steps=100, inference_steps=10, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
1214
+ return generate_frame(images, target_poses, train_steps, inference_steps, modelId, img_width, img_height, bg_remove, resize_inputs)
1215
+
1216
 
1217
  def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True):
1218
 
 
1261
  def run_interpolate_frames(frame1, frame2, times_to_interp):
1262
  with Pool() as pool:
1263
  results = pool.starmap(interpolate_frames, [(frame1, frame2, times_to_interp)])
 
1264
  return results[0]
1265
 
1266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1267
 
1268
 
1269