acmyu commited on
Commit
4e76f1b
·
1 Parent(s): af05866

fix eval script

Browse files
Files changed (2) hide show
  1. evaluate.py +23 -18
  2. main.py +25 -11
evaluate.py CHANGED
@@ -60,12 +60,11 @@ with open('metrics.json', 'r') as file:
60
  metrics = json.load(file)
61
 
62
  def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
63
- print(item)
64
-
65
  images = []
66
  for path in image_paths:
67
  img = Image.open(path)
68
- images.append(img)
69
 
70
  gt_frames = extract_frames(video_path, fps)
71
 
@@ -76,13 +75,13 @@ def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10
76
  frame.save("out/"+item+"/frame_"+str(i)+".png")
77
 
78
  #results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
79
- results = run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False)
80
 
 
 
81
  for i, result in enumerate(results):
82
  result.save("out/"+item+"/result_"+str(i)+".png")
83
 
84
- results_base = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=False)
85
-
86
  for i, result in enumerate(results_base):
87
  result.save("out/"+item+"/base_"+str(i)+".png")
88
 
@@ -156,23 +155,29 @@ def run_evaluate():
156
  snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")
157
 
158
  items = os.listdir('test')
159
- items = ['test/woody'] #['sidewalk', 'aaa', 'azri', 'dead', 'frankgirl', 'kobold', 'ramona', 'renee', 'walk', 'woody']
 
 
160
 
161
  for item in items:
162
  if item in metrics:
163
  continue
164
-
165
- name = os.path.basename(os.path.dirname(item))
166
- print(name)
167
 
168
- files = get_files(item)
169
- videos = (x for x in files if x.endswith('.mp4'))
170
- if len(videos) == 1:
171
- get_score(name, list(filter(lambda x: not x.endswith('.mp4'), files)), videos[0])
172
- #get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
173
- else:
174
- print('Error: mp4 not found')
175
-
 
 
 
 
 
 
176
 
177
 
178
  ssim = []
 
60
  metrics = json.load(file)
61
 
62
  def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
63
+
 
64
  images = []
65
  for path in image_paths:
66
  img = Image.open(path)
67
+ images.append([img])
68
 
69
  gt_frames = extract_frames(video_path, fps)
70
 
 
75
  frame.save("out/"+item+"/frame_"+str(i)+".png")
76
 
77
  #results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
78
+ results, results_base = run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False)
79
 
80
+ print(results)
81
+
82
  for i, result in enumerate(results):
83
  result.save("out/"+item+"/result_"+str(i)+".png")
84
 
 
 
85
  for i, result in enumerate(results_base):
86
  result.save("out/"+item+"/base_"+str(i)+".png")
87
 
 
155
  snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")
156
 
157
  items = os.listdir('test')
158
+ items = [it for it in items if not it[0]=='.' and not os.path.isfile('test/'+it)]
159
+ print(items)
160
+ items = ['sidewalk'] #['sidewalk', 'aaa', 'azri', 'dead', 'frankgirl', 'kobold', 'ramona', 'renee', 'walk', 'woody']
161
 
162
  for item in items:
163
  if item in metrics:
164
  continue
165
+ print(item)
 
 
166
 
167
+ try:
168
+ files = get_files('test/'+item)
169
+ images = list(filter(lambda x: not x.endswith('.mp4'), files))
170
+ images = ['test/'+item+'/'+img for img in images]
171
+ videos = [x for x in files if x.endswith('.mp4')]
172
+ print(images, videos)
173
+
174
+ if len(videos) == 1:
175
+ get_score(item, images, 'test/'+item+'/'+videos[0])
176
+ #get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
177
+ else:
178
+ print('Error: mp4 not found')
179
+ except:
180
+ print("Error", item)
181
 
182
 
183
  ssim = []
main.py CHANGED
@@ -708,7 +708,18 @@ def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pc
708
  if not finetune or train_steps == 0:
709
  accelerator.wait_for_everyone()
710
  accelerator.end_training()
711
- return {k: v.cpu() for k, v in sd_model.state_dict().items()}
 
 
 
 
 
 
 
 
 
 
 
712
 
713
 
714
  it = range(starting_epoch, args.num_train_epochs)
@@ -1117,7 +1128,7 @@ def run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remo
1117
  return results
1118
 
1119
 
1120
- def run_train_impl(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
1121
  finetune=True
1122
  is_app=True
1123
  images = [img[0] for img in images]
@@ -1250,25 +1261,28 @@ def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_
1250
  return out_vid+'.webm', results
1251
 
1252
 
1253
- def run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False):
1254
- finetune=True
1255
  is_app=False
1256
 
1257
  dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1258
 
1259
- run_train_impl(images, train_steps, modelId, bg_remove, resize_inputs)
1260
-
1261
- images = [img[0] for img in images]
1262
- in_img = images[0]
1263
 
1264
- in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, video_path, [], fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
 
 
 
 
 
1265
 
1266
- _, results, _, _, _ = run_inference_impl(images, video_path, frames, train_steps, inference_steps, fps, modelId, img_width, img_height, bg_remove, resize_inputs)
 
 
1267
 
1268
  gc.collect()
1269
  torch.cuda.empty_cache()
1270
 
1271
- return results
1272
 
1273
 
1274
 
 
708
  if not finetune or train_steps == 0:
709
  accelerator.wait_for_everyone()
710
  accelerator.end_training()
711
+
712
+ checkpoint_state_dict = {
713
+ "epoch": 0,
714
+ "module": {k: v.cpu() for k, v in sd_model.state_dict().items()}, #sd_model.state_dict(),
715
+ }
716
+ torch.save(checkpoint_state_dict, modelId+".pt")
717
+
718
+ del sd_model
719
+ gc.collect()
720
+ torch.cuda.empty_cache()
721
+ return
722
+ #return {k: v.cpu() for k, v in sd_model.state_dict().items()}
723
 
724
 
725
  it = range(starting_epoch, args.num_train_epochs)
 
1128
  return results
1129
 
1130
 
1131
+ def run_train_impl(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True, finetune=True):
1132
  finetune=True
1133
  is_app=True
1134
  images = [img[0] for img in images]
 
1261
  return out_vid+'.webm', results
1262
 
1263
 
1264
+ def run_eval(images_orig, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False):
 
1265
  is_app=False
1266
 
1267
  dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1268
 
1269
+ images = [img[0] for img in images_orig]
 
 
 
1270
 
1271
+ in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
1272
+ in_img, target_poses, in_pose, _, _ = prepare_inputs_inference(in_img, video_path, [], fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
1273
+
1274
+ finetune = True
1275
+ train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1276
+ results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1277
 
1278
+ finetune = False
1279
+ train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1280
+ results_base = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1281
 
1282
  gc.collect()
1283
  torch.cuda.empty_cache()
1284
 
1285
+ return results, results_base
1286
 
1287
 
1288