evaluation script

Files changed:
- app.py  +7 -0
- evaluate.py  +69 -41
- libs/film/eval/interpolator.py  +1 -3
- main.py  +25 -2
app.py
CHANGED

@@ -1,4 +1,5 @@
 from main import run_app, run_train, run_inference, run_generate_frame, run_interpolate_frames
+from evaluate import run_evaluate
 
 import spaces
 from PIL import Image
@@ -29,12 +30,14 @@ with gr.Blocks() as demo:
         generate_frame_btn = gr.Button(value="Generate Frame")
         submit_btn = gr.Button(value="Generate")
         interp_btn = gr.Button(value="Interpolate Frames")
+        eval_btn = gr.Button(value="Evaluate")
     with gr.Column():
         animation = gr.Video(label="Result")
         frames = gr.Gallery(type="pil", label="Frames", format="png")
         frames_thumb = gr.Gallery(type="pil", label="Thumbnails", format="png")
         pose_coords = gr.JSON(label="Pose Coordinates")
         reference = gr.Gallery(type="pil", label="Reference Images", format="png")
+        eval_scores = gr.JSON(label="Evaluation Scores")
 
     submit_btn.click(
         run_app, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, remove_bg, resize_inputs], outputs=[animation, frames]
@@ -56,6 +59,10 @@ with gr.Blocks() as demo:
         run_interpolate_frames, inputs=[interp_frame1, interp_frame2, times_to_interp], outputs=[frames, frames_thumb]
     )
 
+    eval_btn.click(
+        run_evaluate, inputs=[], outputs=[eval_scores]  # was wired to run_interpolate_frames, which does not match this signature; run_evaluate is the handler imported above
+    )
+
 demo.launch(share=True)
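The new Evaluate button follows the same click-wiring pattern as the other buttons: a handler, an inputs list, and an outputs list. For the gr.JSON panel to populate, the handler must return a JSON-serializable value; as committed, run_evaluate only prints its metrics, so the panel stays empty until it returns something. A minimal sketch of that contract (demo_scores is a hypothetical stand-in, not code from this repo):

import gradio as gr

def demo_scores():
    # any JSON-serializable return value is rendered by gr.JSON
    return {"ssim": 0.91, "psnr": 27.3, "lpips": 0.12, "fid": 41.0}

with gr.Blocks() as demo:
    eval_btn = gr.Button(value="Evaluate")
    eval_scores = gr.JSON(label="Evaluation Scores")
    eval_btn.click(demo_scores, inputs=[], outputs=[eval_scores])

demo.launch()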
evaluate.py
CHANGED

@@ -1,4 +1,4 @@
-from main import extract_frames, run
+from main import extract_frames, run_eval #run
 
 from PIL import Image
 import numpy as np
@@ -10,6 +10,7 @@ import lpips
 from pytorch_fid.fid_score import calculate_fid_given_paths
 import os
 import json
+from huggingface_hub import snapshot_download
 
 # Convert PIL to numpy
 def pil_to_np(img):
@@ -74,7 +75,8 @@ def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10
     for i, frame in enumerate(gt_frames):
         frame.save("out/"+item+"/frame_"+str(i)+".png")
 
-    results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
+    #results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
+    results = run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False)
 
     for i, result in enumerate(results):
         result.save("out/"+item+"/result_"+str(i)+".png")
@@ -138,49 +140,75 @@ def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10
 
 
-ssim = []
-psnr = []
-lpips = []
-fid = []
-ssim2 = []
-psnr2 = []
-lpips2 = []
-fid2 = []
-for item in metrics.keys():
-    ssim.append(metrics[item]['ft']['ssim']['avg'])
-    psnr.append(metrics[item]['ft']['psnr']['avg'])
-    lpips.append(metrics[item]['ft']['lpips']['avg'])
-    fid.append(metrics[item]['ft']['fid']['avg'])
-
-    ssim2.append(metrics[item]['base']['ssim']['avg'])
-    psnr2.append(metrics[item]['base']['psnr']['avg'])
-    lpips2.append(metrics[item]['base']['lpips']['avg'])
-    fid2.append(metrics[item]['base']['fid']['avg'])
-
-    print(
+def get_files(directory_path):
+    """
+    Returns a list of all files in the specified directory.
+    """
+    files = []
+    for entry in os.listdir(directory_path):
+        full_path = os.path.join(directory_path, entry)
+        if os.path.isfile(full_path):
+            files.append(entry)
+    return files
+
+
+def run_evaluate():
+    snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test")
+
+    items = os.listdir('test')
+    items = ['test/woody'] #['sidewalk', 'aaa', 'azri', 'dead', 'frankgirl', 'kobold', 'ramona', 'renee', 'walk', 'woody']
+
+    for item in items:
+        if item in metrics:
+            continue
+
+        name = os.path.basename(item)  # was os.path.basename(os.path.dirname(item)), which yields 'test' instead of the item name
+        print(name)
+
+        files = get_files(item)
+        videos = [x for x in files if x.endswith('.mp4')]  # was a generator expression, which breaks len() and indexing below
+        if len(videos) == 1:
+            get_score(name, list(filter(lambda x: not x.endswith('.mp4'), files)), videos[0])
+            #get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
+        else:
+            print('Error: mp4 not found')
+
+
+    ssim = []
+    psnr = []
+    lpips = []
+    fid = []
+    ssim2 = []
+    psnr2 = []
+    lpips2 = []
+    fid2 = []
+    for item in metrics.keys():
+        ssim.append(metrics[item]['ft']['ssim']['avg'])
+        psnr.append(metrics[item]['ft']['psnr']['avg'])
+        lpips.append(metrics[item]['ft']['lpips']['avg'])
+        fid.append(metrics[item]['ft']['fid']['avg'])
+
+        ssim2.append(metrics[item]['base']['ssim']['avg'])
+        psnr2.append(metrics[item]['base']['psnr']['avg'])
+        lpips2.append(metrics[item]['base']['lpips']['avg'])
+        fid2.append(metrics[item]['base']['fid']['avg'])
+
+        print(item)
+        print("SSIM:", metrics[item]['ft']['ssim']['avg'], metrics[item]['base']['ssim']['avg'])
+        print("PSNR:", metrics[item]['ft']['psnr']['avg'], metrics[item]['base']['psnr']['avg'])
+        print("LPIPS:", metrics[item]['ft']['lpips']['avg'], metrics[item]['base']['lpips']['avg'])
+        print("FID:", metrics[item]['ft']['fid']['avg'], metrics[item]['base']['fid']['avg'])
+
+    print('Results:')
+    print("SSIM:", sum(ssim)/len(ssim))
+    print("PSNR:", sum(psnr)/len(psnr))
+    print("LPIPS:", sum(lpips)/len(lpips))
+    print("FID:", sum(fid)/len(fid))
+    print('baseline:')
+    print("SSIM:", sum(ssim2)/len(ssim2))
+    print("PSNR:", sum(psnr2)/len(psnr2))
+    print("LPIPS:", sum(lpips2)/len(lpips2))
+    print("FID:", sum(fid2)/len(fid2))
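For context on the metrics get_score averages: SSIM and PSNR are per-frame similarity scores, LPIPS is a learned perceptual distance, and FID compares the distributions of two image folders, which is why get_score writes frame_*.png and result_*.png to disk. A sketch of how the imported libraries are typically called (the skimage import and exact preprocessing are assumptions; the repo's get_score may differ):

import numpy as np
import torch
import lpips
from skimage.metrics import structural_similarity, peak_signal_noise_ratio  # assumed import
from pytorch_fid.fid_score import calculate_fid_given_paths

loss_fn = lpips.LPIPS(net='alex')  # LPIPS expects NCHW tensors scaled to [-1, 1]

def frame_metrics(gt, pred):
    """gt, pred: HxWx3 uint8 numpy arrays of equal size."""
    ssim = structural_similarity(gt, pred, channel_axis=-1)
    psnr = peak_signal_noise_ratio(gt, pred, data_range=255)
    to_t = lambda a: torch.from_numpy(a).permute(2, 0, 1).float().unsqueeze(0) / 127.5 - 1.0
    lp = loss_fn(to_t(gt), to_t(pred)).item()
    return ssim, psnr, lp

# FID compares two directories of saved images (e.g. the ground-truth frames vs the results):
# fid = calculate_fid_given_paths(["out/item/gt", "out/item/pred"], batch_size=8, device="cpu", dims=2048)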
libs/film/eval/interpolator.py
CHANGED

@@ -149,9 +149,7 @@ class Interpolator:
     self._align = align or None
     self._block_shape = block_shape or None
 
-
-    tf.keras.backend.clear_session()
-
+
   def interpolate(self, x0: np.ndarray, x1: np.ndarray,
                   dt: np.ndarray) -> np.ndarray:
     """Generates an interpolated frame between given two batches of frames.
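Dropping tf.keras.backend.clear_session() from the constructor is presumably defensive: clearing the global Keras session tears down graph state for every TensorFlow model in the process, which is risky now that the FILM interpolator is built alongside the Space's other models. Based on the signature above, usage looks roughly like this (value range and shapes follow the FILM repo's conventions and are assumptions here):

import numpy as np

# x0, x1: batches of frames, shape (B, H, W, 3), float32 in [0, 1]; dt: blend times in (0, 1)
x0 = np.zeros((1, 256, 256, 3), dtype=np.float32)
x1 = np.ones((1, 256, 256, 3), dtype=np.float32)
dt = np.array([0.5], dtype=np.float32)

# interpolator = Interpolator(saved_model_path, align, block_shape)  # paths per the FILM repo
# mid_frame = interpolator.interpolate(x0, x1, dt)                   # -> (B, H, W, 3) frame at t=dt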
main.py
CHANGED

@@ -1163,7 +1163,7 @@ def run_inference_impl(images, video_path, frames, train_steps=100, inference_st
     frames = [img[0] for img in frames]
 
     in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, video_path, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
-    target_poses[0].save('inf_pose.png')
+    #target_poses[0].save('inf_pose.png')
 
     results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
     #urls = save_temp_imgs(results)
@@ -1207,7 +1207,7 @@ def generate_frame(images, target_poses, train_steps=100, inference_steps=10, mo
     target_poses = [Image.fromarray(draw_openpose(pose, height=img_height, width=img_width, include_hands=True, include_face=False)) for pose in target_poses]
 
     in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, None, [], 12, dwpose, rembg_session, bg_remove, resize_inputs, is_app, target_poses)
-    target_poses[0].save('gen_pose.png')
+    #target_poses[0].save('gen_pose.png')
 
     results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
     #urls = save_temp_imgs(results)
@@ -1249,6 +1249,29 @@ def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_
 
     return out_vid+'.webm', results
 
+
+def run_eval(images, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False):
+    finetune = True
+    is_app = False
+
+    dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
+
+    run_train(images, train_steps, modelId, bg_remove, resize_inputs)
+
+    images = [img[0] for img in images]
+    in_img = images[0]
+
+    in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, video_path, [], fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
+
+    _, results, _, _, _ = run_inference_impl(images, video_path, [], train_steps, inference_steps, fps, modelId, img_width, img_height, bg_remove, resize_inputs)  # [] for frames, which was an undefined name here
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return results
+
+
 @spaces.GPU(duration=30)
 def interpolate_frames(frame1, frame2, times_to_interp):
     film = Predictor()