Spaces:
Paused
Paused
updates
Browse files
app.py
CHANGED
|
@@ -22,23 +22,21 @@ with gr.Blocks() as demo:
|
|
| 22 |
submit_btn = gr.Button(value="Generate")
|
| 23 |
with gr.Column():
|
| 24 |
animation = gr.Video(label="Result")
|
| 25 |
-
frames = gr.Gallery(type="pil", label="Frames")
|
| 26 |
|
| 27 |
submit_btn.click(
|
| 28 |
run_app, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, remove_bg, resize_inputs], outputs=[animation, frames]
|
| 29 |
)
|
| 30 |
|
| 31 |
train_btn.click(
|
| 32 |
-
run_train, inputs=[char_imgs, tr_steps, remove_bg, resize_inputs
|
| 33 |
)
|
| 34 |
|
| 35 |
inference_btn.click(
|
| 36 |
-
run_inference, inputs=[char_imgs, mocap, inf_steps, fps, remove_bg, resize_inputs
|
| 37 |
)
|
| 38 |
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
demo.launch(share=True)
|
| 43 |
|
| 44 |
|
|
|
|
| 22 |
submit_btn = gr.Button(value="Generate")
|
| 23 |
with gr.Column():
|
| 24 |
animation = gr.Video(label="Result")
|
| 25 |
+
frames = gr.Gallery(type="pil", label="Frames", format="png")
|
| 26 |
|
| 27 |
submit_btn.click(
|
| 28 |
run_app, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, remove_bg, resize_inputs], outputs=[animation, frames]
|
| 29 |
)
|
| 30 |
|
| 31 |
train_btn.click(
|
| 32 |
+
run_train, inputs=[char_imgs, tr_steps, modelId, remove_bg, resize_inputs], outputs=[]
|
| 33 |
)
|
| 34 |
|
| 35 |
inference_btn.click(
|
| 36 |
+
run_inference, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, modelId, remove_bg, resize_inputs], outputs=[animation, frames]
|
| 37 |
)
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
| 40 |
demo.launch(share=True)
|
| 41 |
|
| 42 |
|
main.py
CHANGED
|
@@ -57,6 +57,8 @@ import rembg
|
|
| 57 |
import uuid
|
| 58 |
import gc
|
| 59 |
from numba import cuda
|
|
|
|
|
|
|
| 60 |
|
| 61 |
from huggingface_hub import hf_hub_download
|
| 62 |
|
|
@@ -76,8 +78,34 @@ fps = 12
|
|
| 76 |
|
| 77 |
debug = False
|
| 78 |
save_model = True
|
|
|
|
| 79 |
max_batch_size = 8
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
# Pose detection ==============================================================================================
|
| 82 |
|
| 83 |
def load_models():
|
|
@@ -712,7 +740,11 @@ def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pc
|
|
| 712 |
accelerator.wait_for_everyone()
|
| 713 |
accelerator.end_training()
|
| 714 |
|
| 715 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 716 |
|
| 717 |
if save_model: #if global_steps % args.checkpointing_steps == 0 or global_steps == args.max_train_steps:
|
| 718 |
print('saving', modelId)
|
|
@@ -724,13 +756,14 @@ def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pc
|
|
| 724 |
print(list(sd_model.state_dict().keys())[:20])
|
| 725 |
torch.save(checkpoint_state_dict, modelId+".pt")
|
| 726 |
|
|
|
|
| 727 |
gc.collect()
|
| 728 |
torch.cuda.empty_cache()
|
| 729 |
-
#device = cuda.get_current_device()
|
| 730 |
-
#device.reset()
|
| 731 |
print('done train')
|
|
|
|
| 732 |
return
|
| 733 |
|
|
|
|
| 734 |
gc.collect()
|
| 735 |
torch.cuda.empty_cache()
|
| 736 |
return {k: v.cpu() for k, v in sd_model.state_dict().items()}
|
|
@@ -953,8 +986,18 @@ def inference(modelId, in_image, in_pose, target_poses, inference_steps, finetun
|
|
| 953 |
results.append(result)
|
| 954 |
progress_bar.update(1)
|
| 955 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 956 |
gc.collect()
|
| 957 |
torch.cuda.empty_cache()
|
|
|
|
| 958 |
|
| 959 |
return results
|
| 960 |
|
|
@@ -1006,7 +1049,7 @@ def run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remo
|
|
| 1006 |
return results
|
| 1007 |
|
| 1008 |
|
| 1009 |
-
def run_train(images, train_steps=100,
|
| 1010 |
finetune=True
|
| 1011 |
is_app=True
|
| 1012 |
images = [img[0] for img in images]
|
|
@@ -1023,21 +1066,29 @@ def run_train(images, train_steps=100, bg_remove=False, resize_inputs=True, mode
|
|
| 1023 |
train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
|
| 1024 |
|
| 1025 |
|
| 1026 |
-
def run_inference(images, video_path, inference_steps=10, fps=12,
|
|
|
|
| 1027 |
is_app=True
|
| 1028 |
-
|
| 1029 |
-
in_img = images[0]
|
| 1030 |
|
| 1031 |
dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
|
| 1032 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1033 |
target_poses, in_pose = prepare_inputs_inference(in_img, video_path, fps, dwpose, 'target', is_app)
|
| 1034 |
|
| 1035 |
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1036 |
-
|
| 1037 |
-
|
| 1038 |
-
|
| 1039 |
-
|
| 1040 |
-
|
|
|
|
|
|
|
| 1041 |
|
| 1042 |
print("Done!")
|
| 1043 |
|
|
|
|
| 57 |
import uuid
|
| 58 |
import gc
|
| 59 |
from numba import cuda
|
| 60 |
+
import requests
|
| 61 |
+
import uuid
|
| 62 |
|
| 63 |
from huggingface_hub import hf_hub_download
|
| 64 |
|
|
|
|
| 78 |
|
| 79 |
debug = False
|
| 80 |
save_model = True
|
| 81 |
+
should_gen_vid = False
|
| 82 |
max_batch_size = 8
|
| 83 |
|
| 84 |
+
|
| 85 |
+
def save_temp_imgs(imgs):
|
| 86 |
+
for img in imgs:
|
| 87 |
+
|
| 88 |
+
img_name = str(uuid.uuid4())+'.png'
|
| 89 |
+
img.save(img_name)
|
| 90 |
+
print(img_name)
|
| 91 |
+
|
| 92 |
+
url = 'https://tmpfiles.org/api/v1/upload'
|
| 93 |
+
data_payload = {'file': img_name}
|
| 94 |
+
|
| 95 |
+
try:
|
| 96 |
+
response = requests.post(url, data=data_payload)
|
| 97 |
+
|
| 98 |
+
# Check for successful response (status code 200)
|
| 99 |
+
response.raise_for_status()
|
| 100 |
+
|
| 101 |
+
# Print the server's response
|
| 102 |
+
print("Status Code:", response.status_code)
|
| 103 |
+
print("Response JSON:", response.json())
|
| 104 |
+
|
| 105 |
+
except requests.exceptions.RequestException as e:
|
| 106 |
+
print(f"An error occurred: {e}")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
# Pose detection ==============================================================================================
|
| 110 |
|
| 111 |
def load_models():
|
|
|
|
| 740 |
accelerator.wait_for_everyone()
|
| 741 |
accelerator.end_training()
|
| 742 |
|
| 743 |
+
sd_model.unet.cpu()
|
| 744 |
+
sd_model.cpu()
|
| 745 |
+
del vae
|
| 746 |
+
del image_encoder_p
|
| 747 |
+
del image_encoder_g
|
| 748 |
|
| 749 |
if save_model: #if global_steps % args.checkpointing_steps == 0 or global_steps == args.max_train_steps:
|
| 750 |
print('saving', modelId)
|
|
|
|
| 756 |
print(list(sd_model.state_dict().keys())[:20])
|
| 757 |
torch.save(checkpoint_state_dict, modelId+".pt")
|
| 758 |
|
| 759 |
+
del sd_model
|
| 760 |
gc.collect()
|
| 761 |
torch.cuda.empty_cache()
|
|
|
|
|
|
|
| 762 |
print('done train')
|
| 763 |
+
print(torch.cuda.memory_allocated()/1024**2)
|
| 764 |
return
|
| 765 |
|
| 766 |
+
del sd_model
|
| 767 |
gc.collect()
|
| 768 |
torch.cuda.empty_cache()
|
| 769 |
return {k: v.cpu() for k, v in sd_model.state_dict().items()}
|
|
|
|
| 986 |
results.append(result)
|
| 987 |
progress_bar.update(1)
|
| 988 |
|
| 989 |
+
del unet
|
| 990 |
+
del vae
|
| 991 |
+
del image_encoder
|
| 992 |
+
del image_proj_model
|
| 993 |
+
del pose_proj_model
|
| 994 |
+
|
| 995 |
+
if not save_model:
|
| 996 |
+
del finetuned_model
|
| 997 |
+
|
| 998 |
gc.collect()
|
| 999 |
torch.cuda.empty_cache()
|
| 1000 |
+
print(torch.cuda.memory_allocated()/1024**2)
|
| 1001 |
|
| 1002 |
return results
|
| 1003 |
|
|
|
|
| 1049 |
return results
|
| 1050 |
|
| 1051 |
|
| 1052 |
+
def run_train(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
|
| 1053 |
finetune=True
|
| 1054 |
is_app=True
|
| 1055 |
images = [img[0] for img in images]
|
|
|
|
| 1066 |
train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
|
| 1067 |
|
| 1068 |
|
| 1069 |
+
def run_inference(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
|
| 1070 |
+
finetune=True
|
| 1071 |
is_app=True
|
| 1072 |
+
|
|
|
|
| 1073 |
|
| 1074 |
dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
|
| 1075 |
|
| 1076 |
+
if not os.path.exists(modelId+".pt"):
|
| 1077 |
+
run_train(images, train_steps, modelId, bg_remove, resize_inputs)
|
| 1078 |
+
|
| 1079 |
+
images = [img[0] for img in images]
|
| 1080 |
+
in_img = images[0]
|
| 1081 |
+
|
| 1082 |
target_poses, in_pose = prepare_inputs_inference(in_img, video_path, fps, dwpose, 'target', is_app)
|
| 1083 |
|
| 1084 |
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1085 |
+
save_temp_imgs(results)
|
| 1086 |
+
|
| 1087 |
+
if should_gen_vid:
|
| 1088 |
+
if debug:
|
| 1089 |
+
gen_vid(results, out_vid+'.mp4', fps, 'mp4')
|
| 1090 |
+
else:
|
| 1091 |
+
gen_vid(results, out_vid+'.webm', fps, 'webm')
|
| 1092 |
|
| 1093 |
print("Done!")
|
| 1094 |
|