Spaces:
Paused
Paused
api to eval single
Browse files- app.py +8 -2
- evaluate.py +8 -4
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from main import run_app, run_train, run_inference, run_generate_frame, run_interpolate_frames
|
| 2 |
-
from evaluate import run_evaluate
|
| 3 |
|
| 4 |
import spaces
|
| 5 |
from PIL import Image
|
|
@@ -25,12 +25,14 @@ with gr.Blocks() as demo:
|
|
| 25 |
interp_frame1 = gr.Image(type="pil", label="Interpolation Start Frame")
|
| 26 |
interp_frame2 = gr.Image(type="pil", label="Interpolation End Frame")
|
| 27 |
times_to_interp = gr.Number(label="Times to Interpolate", value=1)
|
|
|
|
| 28 |
train_btn = gr.Button(value="Train")
|
| 29 |
inference_btn = gr.Button(value="Inference")
|
| 30 |
generate_frame_btn = gr.Button(value="Generate Frame")
|
| 31 |
submit_btn = gr.Button(value="Generate")
|
| 32 |
interp_btn = gr.Button(value="Interpolate Frames")
|
| 33 |
-
eval_btn = gr.Button(value="Evaluate")
|
|
|
|
| 34 |
with gr.Column():
|
| 35 |
animation = gr.Video(label="Result")
|
| 36 |
frames = gr.Gallery(type="pil", label="Frames", format="png")
|
|
@@ -62,6 +64,10 @@ with gr.Blocks() as demo:
|
|
| 62 |
eval_btn.click(
|
| 63 |
run_evaluate, inputs=[], outputs=[eval_scores]
|
| 64 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
demo.launch(share=True)
|
|
|
|
| 1 |
from main import run_app, run_train, run_inference, run_generate_frame, run_interpolate_frames
|
| 2 |
+
from evaluate import run_evaluate, get_score
|
| 3 |
|
| 4 |
import spaces
|
| 5 |
from PIL import Image
|
|
|
|
| 25 |
interp_frame1 = gr.Image(type="pil", label="Interpolation Start Frame")
|
| 26 |
interp_frame2 = gr.Image(type="pil", label="Interpolation End Frame")
|
| 27 |
times_to_interp = gr.Number(label="Times to Interpolate", value=1)
|
| 28 |
+
name = gr.Text(label="Name", value="")
|
| 29 |
train_btn = gr.Button(value="Train")
|
| 30 |
inference_btn = gr.Button(value="Inference")
|
| 31 |
generate_frame_btn = gr.Button(value="Generate Frame")
|
| 32 |
submit_btn = gr.Button(value="Generate")
|
| 33 |
interp_btn = gr.Button(value="Interpolate Frames")
|
| 34 |
+
eval_btn = gr.Button(value="Evaluate All")
|
| 35 |
+
eval_btn2 = gr.Button(value="Evaluate")
|
| 36 |
with gr.Column():
|
| 37 |
animation = gr.Video(label="Result")
|
| 38 |
frames = gr.Gallery(type="pil", label="Frames", format="png")
|
|
|
|
| 64 |
eval_btn.click(
|
| 65 |
run_evaluate, inputs=[], outputs=[eval_scores]
|
| 66 |
)
|
| 67 |
+
|
| 68 |
+
eval_btn2.click(
|
| 69 |
+
get_score, inputs=[name, char_imgs, mocap, tr_steps, inf_steps, fps, remove_bg], outputs=[eval_scores]
|
| 70 |
+
)
|
| 71 |
|
| 72 |
|
| 73 |
demo.launch(share=True)
|
evaluate.py
CHANGED
|
@@ -111,7 +111,7 @@ def compute_fid(img1, img2):
|
|
| 111 |
return fid
|
| 112 |
|
| 113 |
|
| 114 |
-
def get_score(item, image_paths, video_path,
|
| 115 |
images = []
|
| 116 |
for path in image_paths:
|
| 117 |
img = Image.open(path)
|
|
@@ -196,6 +196,7 @@ def get_score(item, image_paths, video_path, metrics, train_steps=100, inference
|
|
| 196 |
print("FID:", sum(fid2)/len(fid2))
|
| 197 |
#print("FVD:", fvd2)
|
| 198 |
|
|
|
|
| 199 |
metrics[item] = {'ft': {}, 'base': {}, 'n_frames': len(gt_frames), 'complexity': len(images)}
|
| 200 |
metrics[item]['ft']['ssim'] = {'avg': sum(ssim)/len(ssim), 'vals': ssim}
|
| 201 |
metrics[item]['ft']['psnr'] = {'avg': sum(psnr)/len(psnr), 'vals': psnr}
|
|
@@ -209,9 +210,9 @@ def get_score(item, image_paths, video_path, metrics, train_steps=100, inference
|
|
| 209 |
#metrics[item]['base']['fvd'] = fvd2
|
| 210 |
|
| 211 |
#print(metrics)
|
|
|
|
| 212 |
|
| 213 |
-
|
| 214 |
-
json.dump(metrics, json_file, ensure_ascii=False, indent=4)
|
| 215 |
|
| 216 |
|
| 217 |
|
|
@@ -252,8 +253,11 @@ def run_evaluate():
|
|
| 252 |
print(images, videos)
|
| 253 |
|
| 254 |
if len(videos) == 1:
|
| 255 |
-
get_score(item, images, 'test/'+item+'/'+videos[0]
|
| 256 |
#get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
|
|
|
|
|
|
|
|
|
|
| 257 |
else:
|
| 258 |
print('Error: mp4 not found')
|
| 259 |
except Exception as e:
|
|
|
|
| 111 |
return fid
|
| 112 |
|
| 113 |
|
| 114 |
+
def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
|
| 115 |
images = []
|
| 116 |
for path in image_paths:
|
| 117 |
img = Image.open(path)
|
|
|
|
| 196 |
print("FID:", sum(fid2)/len(fid2))
|
| 197 |
#print("FVD:", fvd2)
|
| 198 |
|
| 199 |
+
metrics = {}
|
| 200 |
metrics[item] = {'ft': {}, 'base': {}, 'n_frames': len(gt_frames), 'complexity': len(images)}
|
| 201 |
metrics[item]['ft']['ssim'] = {'avg': sum(ssim)/len(ssim), 'vals': ssim}
|
| 202 |
metrics[item]['ft']['psnr'] = {'avg': sum(psnr)/len(psnr), 'vals': psnr}
|
|
|
|
| 210 |
#metrics[item]['base']['fvd'] = fvd2
|
| 211 |
|
| 212 |
#print(metrics)
|
| 213 |
+
return metrics[item]
|
| 214 |
|
| 215 |
+
|
|
|
|
| 216 |
|
| 217 |
|
| 218 |
|
|
|
|
| 253 |
print(images, videos)
|
| 254 |
|
| 255 |
if len(videos) == 1:
|
| 256 |
+
metrics[item] = get_score(item, images, 'test/'+item+'/'+videos[0])
|
| 257 |
#get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
|
| 258 |
+
|
| 259 |
+
with open('/data/metrics.json', "w", encoding="utf-8") as json_file:
|
| 260 |
+
json.dump(metrics, json_file, ensure_ascii=False, indent=4)
|
| 261 |
else:
|
| 262 |
print('Error: mp4 not found')
|
| 263 |
except Exception as e:
|