acmyu commited on
Commit
ce2fb73
·
1 Parent(s): 6696eda

api to eval single

Browse files
Files changed (2) hide show
  1. app.py +8 -2
  2. evaluate.py +8 -4
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from main import run_app, run_train, run_inference, run_generate_frame, run_interpolate_frames
2
- from evaluate import run_evaluate
3
 
4
  import spaces
5
  from PIL import Image
@@ -25,12 +25,14 @@ with gr.Blocks() as demo:
25
  interp_frame1 = gr.Image(type="pil", label="Interpolation Start Frame")
26
  interp_frame2 = gr.Image(type="pil", label="Interpolation End Frame")
27
  times_to_interp = gr.Number(label="Times to Interpolate", value=1)
 
28
  train_btn = gr.Button(value="Train")
29
  inference_btn = gr.Button(value="Inference")
30
  generate_frame_btn = gr.Button(value="Generate Frame")
31
  submit_btn = gr.Button(value="Generate")
32
  interp_btn = gr.Button(value="Interpolate Frames")
33
- eval_btn = gr.Button(value="Evaluate")
 
34
  with gr.Column():
35
  animation = gr.Video(label="Result")
36
  frames = gr.Gallery(type="pil", label="Frames", format="png")
@@ -62,6 +64,10 @@ with gr.Blocks() as demo:
62
  eval_btn.click(
63
  run_evaluate, inputs=[], outputs=[eval_scores]
64
  )
 
 
 
 
65
 
66
 
67
  demo.launch(share=True)
 
1
  from main import run_app, run_train, run_inference, run_generate_frame, run_interpolate_frames
2
+ from evaluate import run_evaluate, get_score
3
 
4
  import spaces
5
  from PIL import Image
 
25
  interp_frame1 = gr.Image(type="pil", label="Interpolation Start Frame")
26
  interp_frame2 = gr.Image(type="pil", label="Interpolation End Frame")
27
  times_to_interp = gr.Number(label="Times to Interpolate", value=1)
28
+ name = gr.Text(label="Name", value="")
29
  train_btn = gr.Button(value="Train")
30
  inference_btn = gr.Button(value="Inference")
31
  generate_frame_btn = gr.Button(value="Generate Frame")
32
  submit_btn = gr.Button(value="Generate")
33
  interp_btn = gr.Button(value="Interpolate Frames")
34
+ eval_btn = gr.Button(value="Evaluate All")
35
+ eval_btn2 = gr.Button(value="Evaluate")
36
  with gr.Column():
37
  animation = gr.Video(label="Result")
38
  frames = gr.Gallery(type="pil", label="Frames", format="png")
 
64
  eval_btn.click(
65
  run_evaluate, inputs=[], outputs=[eval_scores]
66
  )
67
+
68
+ eval_btn2.click(
69
+ get_score, inputs=[name, char_imgs, mocap, tr_steps, inf_steps, fps, remove_bg], outputs=[eval_scores]
70
+ )
71
 
72
 
73
  demo.launch(share=True)
evaluate.py CHANGED
@@ -111,7 +111,7 @@ def compute_fid(img1, img2):
111
  return fid
112
 
113
 
114
- def get_score(item, image_paths, video_path, metrics, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
115
  images = []
116
  for path in image_paths:
117
  img = Image.open(path)
@@ -196,6 +196,7 @@ def get_score(item, image_paths, video_path, metrics, train_steps=100, inference
196
  print("FID:", sum(fid2)/len(fid2))
197
  #print("FVD:", fvd2)
198
 
 
199
  metrics[item] = {'ft': {}, 'base': {}, 'n_frames': len(gt_frames), 'complexity': len(images)}
200
  metrics[item]['ft']['ssim'] = {'avg': sum(ssim)/len(ssim), 'vals': ssim}
201
  metrics[item]['ft']['psnr'] = {'avg': sum(psnr)/len(psnr), 'vals': psnr}
@@ -209,9 +210,9 @@ def get_score(item, image_paths, video_path, metrics, train_steps=100, inference
209
  #metrics[item]['base']['fvd'] = fvd2
210
 
211
  #print(metrics)
 
212
 
213
- with open('/data/metrics.json', "w", encoding="utf-8") as json_file:
214
- json.dump(metrics, json_file, ensure_ascii=False, indent=4)
215
 
216
 
217
 
@@ -252,8 +253,11 @@ def run_evaluate():
252
  print(images, videos)
253
 
254
  if len(videos) == 1:
255
- get_score(item, images, 'test/'+item+'/'+videos[0], metrics)
256
  #get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
 
 
 
257
  else:
258
  print('Error: mp4 not found')
259
  except Exception as e:
 
111
  return fid
112
 
113
 
114
+ def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
115
  images = []
116
  for path in image_paths:
117
  img = Image.open(path)
 
196
  print("FID:", sum(fid2)/len(fid2))
197
  #print("FVD:", fvd2)
198
 
199
+ metrics = {}
200
  metrics[item] = {'ft': {}, 'base': {}, 'n_frames': len(gt_frames), 'complexity': len(images)}
201
  metrics[item]['ft']['ssim'] = {'avg': sum(ssim)/len(ssim), 'vals': ssim}
202
  metrics[item]['ft']['psnr'] = {'avg': sum(psnr)/len(psnr), 'vals': psnr}
 
210
  #metrics[item]['base']['fvd'] = fvd2
211
 
212
  #print(metrics)
213
+ return metrics[item]
214
 
215
+
 
216
 
217
 
218
 
 
253
  print(images, videos)
254
 
255
  if len(videos) == 1:
256
+ metrics[item] = get_score(item, images, 'test/'+item+'/'+videos[0])
257
  #get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
258
+
259
+ with open('/data/metrics.json', "w", encoding="utf-8") as json_file:
260
+ json.dump(metrics, json_file, ensure_ascii=False, indent=4)
261
  else:
262
  print('Error: mp4 not found')
263
  except Exception as e: