Spaces:
Paused
Paused
fix eval script
Browse files- evaluate.py +25 -20
- main.py +7 -4
evaluate.py
CHANGED
|
@@ -56,10 +56,9 @@ def compute_fid(img1, img2):
|
|
| 56 |
return fid
|
| 57 |
|
| 58 |
|
| 59 |
-
with open('metrics.json', 'r') as file:
|
| 60 |
-
metrics = json.load(file)
|
| 61 |
|
| 62 |
-
|
|
|
|
| 63 |
|
| 64 |
images = []
|
| 65 |
for path in image_paths:
|
|
@@ -67,6 +66,9 @@ def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10
|
|
| 67 |
images.append([img])
|
| 68 |
|
| 69 |
gt_frames = extract_frames(video_path, fps)
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
os.makedirs('out/'+item, exist_ok=True)
|
| 72 |
|
|
@@ -76,8 +78,6 @@ def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10
|
|
| 76 |
|
| 77 |
#results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
|
| 78 |
results, results_base = run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False)
|
| 79 |
-
|
| 80 |
-
print(results)
|
| 81 |
|
| 82 |
for i, result in enumerate(results):
|
| 83 |
result.save("out/"+item+"/result_"+str(i)+".png")
|
|
@@ -134,6 +134,8 @@ def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10
|
|
| 134 |
metrics[item]['base']['lpips'] = {'avg': sum(lpips2)/len(lpips2), 'vals': lpips2}
|
| 135 |
metrics[item]['base']['fid'] = {'avg': sum(fid2)/len(fid2), 'vals': fid2}
|
| 136 |
|
|
|
|
|
|
|
| 137 |
with open('metrics.json', "w", encoding="utf-8") as json_file:
|
| 138 |
json.dump(metrics, json_file, ensure_ascii=False, indent=4)
|
| 139 |
|
|
@@ -154,30 +156,33 @@ def get_files(directory_path):
|
|
| 154 |
def run_evaluate():
|
| 155 |
snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")
|
| 156 |
|
|
|
|
|
|
|
|
|
|
| 157 |
items = os.listdir('test')
|
| 158 |
items = [it for it in items if not it[0]=='.' and not os.path.isfile('test/'+it)]
|
| 159 |
print(items)
|
| 160 |
-
items = ['sidewalk'] #['sidewalk', 'aaa', 'azri', 'dead', 'frankgirl', 'kobold', 'ramona', 'renee', 'walk', 'woody']
|
| 161 |
|
| 162 |
for item in items:
|
| 163 |
if item in metrics:
|
| 164 |
continue
|
| 165 |
print(item)
|
| 166 |
|
| 167 |
-
try:
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
except:
|
| 180 |
-
|
| 181 |
|
| 182 |
|
| 183 |
ssim = []
|
|
|
|
| 56 |
return fid
|
| 57 |
|
| 58 |
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
|
| 61 |
+
def get_score(item, image_paths, video_path, metrics, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
|
| 62 |
|
| 63 |
images = []
|
| 64 |
for path in image_paths:
|
|
|
|
| 66 |
images.append([img])
|
| 67 |
|
| 68 |
gt_frames = extract_frames(video_path, fps)
|
| 69 |
+
#gt_frames = gt_frames[:2]
|
| 70 |
+
for f in gt_frames:
|
| 71 |
+
f.thumbnail((512,512))
|
| 72 |
|
| 73 |
os.makedirs('out/'+item, exist_ok=True)
|
| 74 |
|
|
|
|
| 78 |
|
| 79 |
#results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
|
| 80 |
results, results_base = run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False)
|
|
|
|
|
|
|
| 81 |
|
| 82 |
for i, result in enumerate(results):
|
| 83 |
result.save("out/"+item+"/result_"+str(i)+".png")
|
|
|
|
| 134 |
metrics[item]['base']['lpips'] = {'avg': sum(lpips2)/len(lpips2), 'vals': lpips2}
|
| 135 |
metrics[item]['base']['fid'] = {'avg': sum(fid2)/len(fid2), 'vals': fid2}
|
| 136 |
|
| 137 |
+
#print(metrics)
|
| 138 |
+
|
| 139 |
with open('metrics.json', "w", encoding="utf-8") as json_file:
|
| 140 |
json.dump(metrics, json_file, ensure_ascii=False, indent=4)
|
| 141 |
|
|
|
|
| 156 |
def run_evaluate():
|
| 157 |
snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")
|
| 158 |
|
| 159 |
+
with open('metrics.json', 'r') as file:
|
| 160 |
+
metrics = json.load(file)
|
| 161 |
+
|
| 162 |
items = os.listdir('test')
|
| 163 |
items = [it for it in items if not it[0]=='.' and not os.path.isfile('test/'+it)]
|
| 164 |
print(items)
|
| 165 |
+
#items = ['sidewalk'] #['sidewalk', 'aaa', 'azri', 'dead', 'frankgirl', 'kobold', 'ramona', 'renee', 'walk', 'woody']
|
| 166 |
|
| 167 |
for item in items:
|
| 168 |
if item in metrics:
|
| 169 |
continue
|
| 170 |
print(item)
|
| 171 |
|
| 172 |
+
#try:
|
| 173 |
+
files = get_files('test/'+item)
|
| 174 |
+
images = list(filter(lambda x: not x.endswith('.mp4'), files))
|
| 175 |
+
images = ['test/'+item+'/'+img for img in images]
|
| 176 |
+
videos = [x for x in files if x.endswith('.mp4')]
|
| 177 |
+
print(images, videos)
|
| 178 |
+
|
| 179 |
+
if len(videos) == 1:
|
| 180 |
+
get_score(item, images, 'test/'+item+'/'+videos[0], metrics)
|
| 181 |
+
#get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
|
| 182 |
+
else:
|
| 183 |
+
print('Error: mp4 not found')
|
| 184 |
+
#except:
|
| 185 |
+
# print("Error", item)
|
| 186 |
|
| 187 |
|
| 188 |
ssim = []
|
main.py
CHANGED
|
@@ -1271,14 +1271,17 @@ def run_eval(images_orig, video_path, train_steps=100, inference_steps=10, fps=1
|
|
| 1271 |
in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
|
| 1272 |
in_img, target_poses, in_pose, _, _ = prepare_inputs_inference(in_img, video_path, [], fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
|
| 1273 |
|
| 1274 |
-
|
| 1275 |
-
|
| 1276 |
-
|
| 1277 |
-
|
| 1278 |
finetune = False
|
| 1279 |
train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
|
| 1280 |
results_base = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1282 |
gc.collect()
|
| 1283 |
torch.cuda.empty_cache()
|
| 1284 |
|
|
|
|
| 1271 |
in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
|
| 1272 |
in_img, target_poses, in_pose, _, _ = prepare_inputs_inference(in_img, video_path, [], fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
|
| 1273 |
|
| 1274 |
+
#target_poses = target_poses[:2]
|
| 1275 |
+
#train_steps = 3
|
| 1276 |
+
|
|
|
|
| 1277 |
finetune = False
|
| 1278 |
train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
|
| 1279 |
results_base = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1280 |
|
| 1281 |
+
finetune = True
|
| 1282 |
+
train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
|
| 1283 |
+
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1284 |
+
|
| 1285 |
gc.collect()
|
| 1286 |
torch.cuda.empty_cache()
|
| 1287 |
|