acmyu committed on
Commit
fc6eedb
·
1 Parent(s): 3366cca
Files changed (2) hide show
  1. app.py +3 -5
  2. main.py +63 -12
app.py CHANGED
@@ -22,23 +22,21 @@ with gr.Blocks() as demo:
22
  submit_btn = gr.Button(value="Generate")
23
  with gr.Column():
24
  animation = gr.Video(label="Result")
25
- frames = gr.Gallery(type="pil", label="Frames")
26
 
27
  submit_btn.click(
28
  run_app, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, remove_bg, resize_inputs], outputs=[animation, frames]
29
  )
30
 
31
  train_btn.click(
32
- run_train, inputs=[char_imgs, tr_steps, remove_bg, resize_inputs, modelId], outputs=[]
33
  )
34
 
35
  inference_btn.click(
36
- run_inference, inputs=[char_imgs, mocap, inf_steps, fps, remove_bg, resize_inputs, modelId], outputs=[animation, frames]
37
  )
38
 
39
 
40
-
41
-
42
  demo.launch(share=True)
43
 
44
 
 
22
  submit_btn = gr.Button(value="Generate")
23
  with gr.Column():
24
  animation = gr.Video(label="Result")
25
+ frames = gr.Gallery(type="pil", label="Frames", format="png")
26
 
27
  submit_btn.click(
28
  run_app, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, remove_bg, resize_inputs], outputs=[animation, frames]
29
  )
30
 
31
  train_btn.click(
32
+ run_train, inputs=[char_imgs, tr_steps, modelId, remove_bg, resize_inputs], outputs=[]
33
  )
34
 
35
  inference_btn.click(
36
+ run_inference, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, modelId, remove_bg, resize_inputs], outputs=[animation, frames]
37
  )
38
 
39
 
 
 
40
  demo.launch(share=True)
41
 
42
 
main.py CHANGED
@@ -57,6 +57,8 @@ import rembg
57
  import uuid
58
  import gc
59
  from numba import cuda
 
 
60
 
61
  from huggingface_hub import hf_hub_download
62
 
@@ -76,8 +78,34 @@ fps = 12
76
 
77
  debug = False
78
  save_model = True
 
79
  max_batch_size = 8
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # Pose detection ==============================================================================================
82
 
83
  def load_models():
@@ -712,7 +740,11 @@ def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pc
712
  accelerator.wait_for_everyone()
713
  accelerator.end_training()
714
 
715
-
 
 
 
 
716
 
717
  if save_model: #if global_steps % args.checkpointing_steps == 0 or global_steps == args.max_train_steps:
718
  print('saving', modelId)
@@ -724,13 +756,14 @@ def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pc
724
  print(list(sd_model.state_dict().keys())[:20])
725
  torch.save(checkpoint_state_dict, modelId+".pt")
726
 
 
727
  gc.collect()
728
  torch.cuda.empty_cache()
729
- #device = cuda.get_current_device()
730
- #device.reset()
731
  print('done train')
 
732
  return
733
 
 
734
  gc.collect()
735
  torch.cuda.empty_cache()
736
  return {k: v.cpu() for k, v in sd_model.state_dict().items()}
@@ -953,8 +986,18 @@ def inference(modelId, in_image, in_pose, target_poses, inference_steps, finetun
953
  results.append(result)
954
  progress_bar.update(1)
955
 
 
 
 
 
 
 
 
 
 
956
  gc.collect()
957
  torch.cuda.empty_cache()
 
958
 
959
  return results
960
 
@@ -1006,7 +1049,7 @@ def run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remo
1006
  return results
1007
 
1008
 
1009
- def run_train(images, train_steps=100, bg_remove=False, resize_inputs=True, modelId="fine_tuned_pcdms"):
1010
  finetune=True
1011
  is_app=True
1012
  images = [img[0] for img in images]
@@ -1023,21 +1066,29 @@ def run_train(images, train_steps=100, bg_remove=False, resize_inputs=True, mode
1023
  train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1024
 
1025
 
1026
- def run_inference(images, video_path, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True, modelId="fine_tuned_pcdms"):
 
1027
  is_app=True
1028
- images = [img[0] for img in images]
1029
- in_img = images[0]
1030
 
1031
  dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1032
 
 
 
 
 
 
 
1033
  target_poses, in_pose = prepare_inputs_inference(in_img, video_path, fps, dwpose, 'target', is_app)
1034
 
1035
  results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1036
-
1037
- if debug:
1038
- gen_vid(results, out_vid+'.mp4', fps, 'mp4')
1039
- else:
1040
- gen_vid(results, out_vid+'.webm', fps, 'webm')
 
 
1041
 
1042
  print("Done!")
1043
 
 
57
  import uuid
58
  import gc
59
  from numba import cuda
60
+ import requests
61
+ import uuid
62
 
63
  from huggingface_hub import hf_hub_download
64
 
 
78
 
79
  debug = False
80
  save_model = True
81
+ should_gen_vid = False
82
  max_batch_size = 8
83
 
84
+
85
def save_temp_imgs(imgs):
    """Save each image to a uniquely named local PNG and upload it to tmpfiles.org.

    Best-effort: upload failures are printed and skipped, never raised, so a
    flaky network cannot abort the surrounding pipeline.

    Args:
        imgs: iterable of PIL.Image-like objects (anything exposing ``.save(path)``).
    """
    url = 'https://tmpfiles.org/api/v1/upload'

    for img in imgs:
        # uuid4 name avoids collisions between concurrent runs writing to cwd
        img_name = str(uuid.uuid4()) + '.png'
        img.save(img_name)
        print(img_name)

        try:
            # Upload the file CONTENTS as multipart form data. The previous code
            # passed data={'file': img_name}, which only transmits the filename
            # string as a plain form field -- the image bytes were never sent.
            # tmpfiles.org expects the file under the 'file' multipart key.
            with open(img_name, 'rb') as fh:
                response = requests.post(url, files={'file': fh})

            # Check for successful response (status code 200)
            response.raise_for_status()

            # Print the server's response
            print("Status Code:", response.status_code)
            print("Response JSON:", response.json())

        except requests.exceptions.RequestException as e:
            print(f"An error occurred: {e}")
107
+
108
+
109
  # Pose detection ==============================================================================================
110
 
111
  def load_models():
 
740
  accelerator.wait_for_everyone()
741
  accelerator.end_training()
742
 
743
+ sd_model.unet.cpu()
744
+ sd_model.cpu()
745
+ del vae
746
+ del image_encoder_p
747
+ del image_encoder_g
748
 
749
  if save_model: #if global_steps % args.checkpointing_steps == 0 or global_steps == args.max_train_steps:
750
  print('saving', modelId)
 
756
  print(list(sd_model.state_dict().keys())[:20])
757
  torch.save(checkpoint_state_dict, modelId+".pt")
758
 
759
+ del sd_model
760
  gc.collect()
761
  torch.cuda.empty_cache()
 
 
762
  print('done train')
763
+ print(torch.cuda.memory_allocated()/1024**2)
764
  return
765
 
766
+ del sd_model
767
  gc.collect()
768
  torch.cuda.empty_cache()
769
  return {k: v.cpu() for k, v in sd_model.state_dict().items()}
 
986
  results.append(result)
987
  progress_bar.update(1)
988
 
989
+ del unet
990
+ del vae
991
+ del image_encoder
992
+ del image_proj_model
993
+ del pose_proj_model
994
+
995
+ if not save_model:
996
+ del finetuned_model
997
+
998
  gc.collect()
999
  torch.cuda.empty_cache()
1000
+ print(torch.cuda.memory_allocated()/1024**2)
1001
 
1002
  return results
1003
 
 
1049
  return results
1050
 
1051
 
1052
+ def run_train(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
1053
  finetune=True
1054
  is_app=True
1055
  images = [img[0] for img in images]
 
1066
  train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1067
 
1068
 
1069
+ def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
1070
+ finetune=True
1071
  is_app=True
1072
+
 
1073
 
1074
  dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1075
 
1076
+ if not os.path.exists(modelId+".pt"):
1077
+ run_train(images, train_steps, modelId, bg_remove, resize_inputs)
1078
+
1079
+ images = [img[0] for img in images]
1080
+ in_img = images[0]
1081
+
1082
  target_poses, in_pose = prepare_inputs_inference(in_img, video_path, fps, dwpose, 'target', is_app)
1083
 
1084
  results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1085
+ save_temp_imgs(results)
1086
+
1087
+ if should_gen_vid:
1088
+ if debug:
1089
+ gen_vid(results, out_vid+'.mp4', fps, 'mp4')
1090
+ else:
1091
+ gen_vid(results, out_vid+'.webm', fps, 'webm')
1092
 
1093
  print("Done!")
1094