Spaces:
Paused
Paused
fix eval script
Browse files- evaluate.py +23 -18
- main.py +25 -11
evaluate.py
CHANGED
|
@@ -60,12 +60,11 @@ with open('metrics.json', 'r') as file:
|
|
| 60 |
metrics = json.load(file)
|
| 61 |
|
| 62 |
def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
|
| 63 |
-
|
| 64 |
-
|
| 65 |
images = []
|
| 66 |
for path in image_paths:
|
| 67 |
img = Image.open(path)
|
| 68 |
-
images.append(img)
|
| 69 |
|
| 70 |
gt_frames = extract_frames(video_path, fps)
|
| 71 |
|
|
@@ -76,13 +75,13 @@ def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10
|
|
| 76 |
frame.save("out/"+item+"/frame_"+str(i)+".png")
|
| 77 |
|
| 78 |
#results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
|
| 79 |
-
results = run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False)
|
| 80 |
|
|
|
|
|
|
|
| 81 |
for i, result in enumerate(results):
|
| 82 |
result.save("out/"+item+"/result_"+str(i)+".png")
|
| 83 |
|
| 84 |
-
results_base = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=False)
|
| 85 |
-
|
| 86 |
for i, result in enumerate(results_base):
|
| 87 |
result.save("out/"+item+"/base_"+str(i)+".png")
|
| 88 |
|
|
@@ -156,23 +155,29 @@ def run_evaluate():
|
|
| 156 |
snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")
|
| 157 |
|
| 158 |
items = os.listdir('test')
|
| 159 |
-
items = [
|
|
|
|
|
|
|
| 160 |
|
| 161 |
for item in items:
|
| 162 |
if item in metrics:
|
| 163 |
continue
|
| 164 |
-
|
| 165 |
-
name = os.path.basename(os.path.dirname(item))
|
| 166 |
-
print(name)
|
| 167 |
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
|
| 178 |
ssim = []
|
|
|
|
| 60 |
metrics = json.load(file)
|
| 61 |
|
| 62 |
def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
|
| 63 |
+
|
|
|
|
| 64 |
images = []
|
| 65 |
for path in image_paths:
|
| 66 |
img = Image.open(path)
|
| 67 |
+
images.append([img])
|
| 68 |
|
| 69 |
gt_frames = extract_frames(video_path, fps)
|
| 70 |
|
|
|
|
| 75 |
frame.save("out/"+item+"/frame_"+str(i)+".png")
|
| 76 |
|
| 77 |
#results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
|
| 78 |
+
results, results_base = run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False)
|
| 79 |
|
| 80 |
+
print(results)
|
| 81 |
+
|
| 82 |
for i, result in enumerate(results):
|
| 83 |
result.save("out/"+item+"/result_"+str(i)+".png")
|
| 84 |
|
|
|
|
|
|
|
| 85 |
for i, result in enumerate(results_base):
|
| 86 |
result.save("out/"+item+"/base_"+str(i)+".png")
|
| 87 |
|
|
|
|
| 155 |
snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")
|
| 156 |
|
| 157 |
items = os.listdir('test')
|
| 158 |
+
items = [it for it in items if not it[0]=='.' and not os.path.isfile('test/'+it)]
|
| 159 |
+
print(items)
|
| 160 |
+
items = ['sidewalk'] #['sidewalk', 'aaa', 'azri', 'dead', 'frankgirl', 'kobold', 'ramona', 'renee', 'walk', 'woody']
|
| 161 |
|
| 162 |
for item in items:
|
| 163 |
if item in metrics:
|
| 164 |
continue
|
| 165 |
+
print(item)
|
|
|
|
|
|
|
| 166 |
|
| 167 |
+
try:
|
| 168 |
+
files = get_files('test/'+item)
|
| 169 |
+
images = list(filter(lambda x: not x.endswith('.mp4'), files))
|
| 170 |
+
images = ['test/'+item+'/'+img for img in images]
|
| 171 |
+
videos = [x for x in files if x.endswith('.mp4')]
|
| 172 |
+
print(images, videos)
|
| 173 |
+
|
| 174 |
+
if len(videos) == 1:
|
| 175 |
+
get_score(item, images, 'test/'+item+'/'+videos[0])
|
| 176 |
+
#get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
|
| 177 |
+
else:
|
| 178 |
+
print('Error: mp4 not found')
|
| 179 |
+
except:
|
| 180 |
+
print("Error", item)
|
| 181 |
|
| 182 |
|
| 183 |
ssim = []
|
main.py
CHANGED
|
@@ -708,7 +708,18 @@ def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pc
|
|
| 708 |
if not finetune or train_steps == 0:
|
| 709 |
accelerator.wait_for_everyone()
|
| 710 |
accelerator.end_training()
|
| 711 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 712 |
|
| 713 |
|
| 714 |
it = range(starting_epoch, args.num_train_epochs)
|
|
@@ -1117,7 +1128,7 @@ def run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remo
|
|
| 1117 |
return results
|
| 1118 |
|
| 1119 |
|
| 1120 |
-
def run_train_impl(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
|
| 1121 |
finetune=True
|
| 1122 |
is_app=True
|
| 1123 |
images = [img[0] for img in images]
|
|
@@ -1250,25 +1261,28 @@ def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_
|
|
| 1250 |
return out_vid+'.webm', results
|
| 1251 |
|
| 1252 |
|
| 1253 |
-
def run_eval(
|
| 1254 |
-
finetune=True
|
| 1255 |
is_app=False
|
| 1256 |
|
| 1257 |
dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
|
| 1258 |
|
| 1259 |
-
|
| 1260 |
-
|
| 1261 |
-
images = [img[0] for img in images]
|
| 1262 |
-
in_img = images[0]
|
| 1263 |
|
| 1264 |
-
in_img,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1265 |
|
| 1266 |
-
|
|
|
|
|
|
|
| 1267 |
|
| 1268 |
gc.collect()
|
| 1269 |
torch.cuda.empty_cache()
|
| 1270 |
|
| 1271 |
-
return results
|
| 1272 |
|
| 1273 |
|
| 1274 |
|
|
|
|
| 708 |
if not finetune or train_steps == 0:
|
| 709 |
accelerator.wait_for_everyone()
|
| 710 |
accelerator.end_training()
|
| 711 |
+
|
| 712 |
+
checkpoint_state_dict = {
|
| 713 |
+
"epoch": 0,
|
| 714 |
+
"module": {k: v.cpu() for k, v in sd_model.state_dict().items()}, #sd_model.state_dict(),
|
| 715 |
+
}
|
| 716 |
+
torch.save(checkpoint_state_dict, modelId+".pt")
|
| 717 |
+
|
| 718 |
+
del sd_model
|
| 719 |
+
gc.collect()
|
| 720 |
+
torch.cuda.empty_cache()
|
| 721 |
+
return
|
| 722 |
+
#return {k: v.cpu() for k, v in sd_model.state_dict().items()}
|
| 723 |
|
| 724 |
|
| 725 |
it = range(starting_epoch, args.num_train_epochs)
|
|
|
|
| 1128 |
return results
|
| 1129 |
|
| 1130 |
|
| 1131 |
+
def run_train_impl(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True, finetune=True):
|
| 1132 |
finetune=True
|
| 1133 |
is_app=True
|
| 1134 |
images = [img[0] for img in images]
|
|
|
|
| 1261 |
return out_vid+'.webm', results
|
| 1262 |
|
| 1263 |
|
| 1264 |
+
def run_eval(images_orig, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False):
|
|
|
|
| 1265 |
is_app=False
|
| 1266 |
|
| 1267 |
dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
|
| 1268 |
|
| 1269 |
+
images = [img[0] for img in images_orig]
|
|
|
|
|
|
|
|
|
|
| 1270 |
|
| 1271 |
+
in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
|
| 1272 |
+
in_img, target_poses, in_pose, _, _ = prepare_inputs_inference(in_img, video_path, [], fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
|
| 1273 |
+
|
| 1274 |
+
finetune = True
|
| 1275 |
+
train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
|
| 1276 |
+
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1277 |
|
| 1278 |
+
finetune = False
|
| 1279 |
+
train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
|
| 1280 |
+
results_base = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1281 |
|
| 1282 |
gc.collect()
|
| 1283 |
torch.cuda.empty_cache()
|
| 1284 |
|
| 1285 |
+
return results, results_base
|
| 1286 |
|
| 1287 |
|
| 1288 |
|