acmyu committed on
Commit
12a7e75
·
1 Parent(s): 4e76f1b

fix eval script

Browse files
Files changed (2) hide show
  1. evaluate.py +25 -20
  2. main.py +7 -4
evaluate.py CHANGED
@@ -56,10 +56,9 @@ def compute_fid(img1, img2):
56
  return fid
57
 
58
 
59
- with open('metrics.json', 'r') as file:
60
- metrics = json.load(file)
61
 
62
- def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
 
63
 
64
  images = []
65
  for path in image_paths:
@@ -67,6 +66,9 @@ def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10
67
  images.append([img])
68
 
69
  gt_frames = extract_frames(video_path, fps)
 
 
 
70
 
71
  os.makedirs('out/'+item, exist_ok=True)
72
 
@@ -76,8 +78,6 @@ def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10
76
 
77
  #results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
78
  results, results_base = run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False)
79
-
80
- print(results)
81
 
82
  for i, result in enumerate(results):
83
  result.save("out/"+item+"/result_"+str(i)+".png")
@@ -134,6 +134,8 @@ def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10
134
  metrics[item]['base']['lpips'] = {'avg': sum(lpips2)/len(lpips2), 'vals': lpips2}
135
  metrics[item]['base']['fid'] = {'avg': sum(fid2)/len(fid2), 'vals': fid2}
136
 
 
 
137
  with open('metrics.json', "w", encoding="utf-8") as json_file:
138
  json.dump(metrics, json_file, ensure_ascii=False, indent=4)
139
 
@@ -154,30 +156,33 @@ def get_files(directory_path):
154
  def run_evaluate():
155
  snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")
156
 
 
 
 
157
  items = os.listdir('test')
158
  items = [it for it in items if not it[0]=='.' and not os.path.isfile('test/'+it)]
159
  print(items)
160
- items = ['sidewalk'] #['sidewalk', 'aaa', 'azri', 'dead', 'frankgirl', 'kobold', 'ramona', 'renee', 'walk', 'woody']
161
 
162
  for item in items:
163
  if item in metrics:
164
  continue
165
  print(item)
166
 
167
- try:
168
- files = get_files('test/'+item)
169
- images = list(filter(lambda x: not x.endswith('.mp4'), files))
170
- images = ['test/'+item+'/'+img for img in images]
171
- videos = [x for x in files if x.endswith('.mp4')]
172
- print(images, videos)
173
-
174
- if len(videos) == 1:
175
- get_score(item, images, 'test/'+item+'/'+videos[0])
176
- #get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
177
- else:
178
- print('Error: mp4 not found')
179
- except:
180
- print("Error", item)
181
 
182
 
183
  ssim = []
 
56
  return fid
57
 
58
 
 
 
59
 
60
+
61
+ def get_score(item, image_paths, video_path, metrics, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
62
 
63
  images = []
64
  for path in image_paths:
 
66
  images.append([img])
67
 
68
  gt_frames = extract_frames(video_path, fps)
69
+ #gt_frames = gt_frames[:2]
70
+ for f in gt_frames:
71
+ f.thumbnail((512,512))
72
 
73
  os.makedirs('out/'+item, exist_ok=True)
74
 
 
78
 
79
  #results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
80
  results, results_base = run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False)
 
 
81
 
82
  for i, result in enumerate(results):
83
  result.save("out/"+item+"/result_"+str(i)+".png")
 
134
  metrics[item]['base']['lpips'] = {'avg': sum(lpips2)/len(lpips2), 'vals': lpips2}
135
  metrics[item]['base']['fid'] = {'avg': sum(fid2)/len(fid2), 'vals': fid2}
136
 
137
+ #print(metrics)
138
+
139
  with open('metrics.json', "w", encoding="utf-8") as json_file:
140
  json.dump(metrics, json_file, ensure_ascii=False, indent=4)
141
 
 
156
  def run_evaluate():
157
  snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")
158
 
159
+ with open('metrics.json', 'r') as file:
160
+ metrics = json.load(file)
161
+
162
  items = os.listdir('test')
163
  items = [it for it in items if not it[0]=='.' and not os.path.isfile('test/'+it)]
164
  print(items)
165
+ #items = ['sidewalk'] #['sidewalk', 'aaa', 'azri', 'dead', 'frankgirl', 'kobold', 'ramona', 'renee', 'walk', 'woody']
166
 
167
  for item in items:
168
  if item in metrics:
169
  continue
170
  print(item)
171
 
172
+ #try:
173
+ files = get_files('test/'+item)
174
+ images = list(filter(lambda x: not x.endswith('.mp4'), files))
175
+ images = ['test/'+item+'/'+img for img in images]
176
+ videos = [x for x in files if x.endswith('.mp4')]
177
+ print(images, videos)
178
+
179
+ if len(videos) == 1:
180
+ get_score(item, images, 'test/'+item+'/'+videos[0], metrics)
181
+ #get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
182
+ else:
183
+ print('Error: mp4 not found')
184
+ #except:
185
+ # print("Error", item)
186
 
187
 
188
  ssim = []
main.py CHANGED
@@ -1271,14 +1271,17 @@ def run_eval(images_orig, video_path, train_steps=100, inference_steps=10, fps=1
1271
  in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
1272
  in_img, target_poses, in_pose, _, _ = prepare_inputs_inference(in_img, video_path, [], fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
1273
 
1274
- finetune = True
1275
- train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1276
- results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1277
-
1278
  finetune = False
1279
  train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1280
  results_base = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1281
 
 
 
 
 
1282
  gc.collect()
1283
  torch.cuda.empty_cache()
1284
 
 
1271
  in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
1272
  in_img, target_poses, in_pose, _, _ = prepare_inputs_inference(in_img, video_path, [], fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
1273
 
1274
+ #target_poses = target_poses[:2]
1275
+ #train_steps = 3
1276
+
 
1277
  finetune = False
1278
  train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1279
  results_base = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1280
 
1281
+ finetune = True
1282
+ train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1283
+ results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1284
+
1285
  gc.collect()
1286
  torch.cuda.empty_cache()
1287