# Copied from a Hugging Face Spaces page (Space status: Paused).
| from main import extract_frames, run_eval #run | |
| from PIL import Image | |
| import numpy as np | |
| from skimage.metrics import structural_similarity as ssim | |
| from skimage.metrics import peak_signal_noise_ratio as psnr | |
| import torch | |
| import torchvision.transforms as transforms | |
| import lpips | |
| from pytorch_fid.fid_score import calculate_fid_given_paths | |
| from cdfvd import fvd | |
| import os | |
| import json | |
| import cv2 | |
| from huggingface_hub import snapshot_download | |
# Root folder where per-item evaluation frames are cached (commented-out alternative: '/data/out/').
outdir = 'outputs/'
def pil_to_np(img):
    """Convert a PIL image (or array-like) to a float32 NumPy array scaled from 0..255 to 0..1."""
    as_float = np.asarray(img, dtype=np.float32)
    return as_float / 255.0
def save_mp4(images, name, fps=12):
    """Encode a sequence of PIL images as an MP4 file using the 'mp4v' codec.

    Args:
        images: non-empty sequence of PIL images; the first image fixes the frame size.
        name: output file path.
        fps: frame rate written into the video (default 12, the previously hard-coded value).

    Raises:
        ValueError: if ``images`` is empty.
        OSError: if OpenCV cannot open a video writer for ``name``.
    """
    if not images:
        raise ValueError("save_mp4 needs at least one image")
    width, height = images[0].size
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4
    video = cv2.VideoWriter(name, fourcc, fps, (width, height))
    if not video.isOpened():
        raise OSError("could not open video writer for " + name)
    try:
        for image in images:
            # COLOR_RGB2BGR needs exactly 3 channels, so RGBA/greyscale/palette frames are
            # normalised to RGB first.
            frame = image.convert('RGB')
            # Every frame must match the size the writer was opened with.
            if frame.size != (width, height):
                frame = frame.resize((width, height))
            video.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    finally:
        # Always finalise the file, even if a frame fails to convert.
        video.release()
# SSIM
def compute_ssim(img1, img2):
    """Return the SSIM between two PIL images, compared as RGB values in [0, 1].

    Both images are converted to RGB first. This makes palette images compare colours
    rather than palette indices. It also guarantees a real channel axis for
    ``channel_axis=-1``; a 2-D greyscale array would otherwise have its width treated
    as channels.

    The window is the largest odd size <= 7 that fits the smaller image side, clamped to
    at least 3. A 1-pixel window makes the sample-covariance normalisation divide by zero.
    """
    img1_np = pil_to_np(img1.convert('RGB'))
    img2_np = pil_to_np(img2.convert('RGB'))
    min_dim = min(img1_np.shape[:2])
    odd_dim = min_dim if min_dim % 2 == 1 else min_dim - 1
    # For images smaller than 3px, skimage is expected to reject the window as too large
    # for the image, which is clearer than a ZeroDivisionError.
    win_size = max(3, min(7, odd_dim))
    return ssim(img1_np, img2_np, win_size=win_size, channel_axis=-1, data_range=1.0)
# PSNR
def compute_psnr(img1, img2):
    """Return the PSNR between two PIL images, compared as RGB values in [0, 1].

    Both images are converted to RGB first. Palette-mode images are then compared on
    colours rather than palette indices, and mixes of RGBA/greyscale/RGB inputs do not
    fail on a channel-count mismatch.

    NOTE(review): identical images give zero error, so expect an infinite PSNR. Confirm
    that downstream averaging and JSON output handle it.
    """
    img1_np = pil_to_np(img1.convert('RGB'))
    img2_np = pil_to_np(img2.convert('RGB'))
    return psnr(img1_np, img2_np, data_range=1.0)
# LPIPS: perceptual-distance model built with net='alex', plus the preprocessing that
# every image goes through before it is scored.
lpips_model = lpips.LPIPS(net='alex')
lpips_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),  # fixed 256x256 input regardless of source size
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
    ]
)
def compute_lpips(img1, img2):
    """Return the LPIPS score between two PIL images as a Python float.

    Both images are converted to RGB before ``lpips_transform``, because its Normalize step
    uses 3-element mean/std and so needs exactly three channels. The forward pass runs
    under ``torch.no_grad()``: only the scalar score is used, so no autograd graph is built.
    """
    img1_tensor = lpips_transform(img1.convert('RGB')).unsqueeze(0)  # add batch dim
    img2_tensor = lpips_transform(img2.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        return lpips_model(img1_tensor, img2_tensor).item()
def trans(x):
    """Reorder a video batch from (B, T, C, H, W) to (B, C, T, H, W).

    A single-channel (greyscale) input is tiled to three channels first. Not called in the
    visible code; presumably kept for feeding clips to a video feature extractor.
    """
    frames = x.repeat(1, 1, 3, 1, 1) if x.shape[-3] == 1 else x
    # Swap the time and channel axes.
    return frames.permute(0, 2, 1, 3, 4)
def compute_fvd(item, gt_imgs, results):
    """Compute FVD between one ground-truth clip and one generated clip with cdfvd ('i3d').

    Both frame lists are encoded to MP4 under temp/gt and temp/result. cdfvd then loads
    them back as video folders and compares their statistics. ``item`` is accepted but not
    used. The score is printed and returned. Not called in the visible code (its call
    sites are commented out).
    """
    real_dir = 'temp/gt'
    fake_dir = 'temp/result'
    for folder in (real_dir, fake_dir):
        os.makedirs(folder, exist_ok=True)
    save_mp4(gt_imgs, "temp/gt/gt.mp4")
    save_mp4(results, "temp/result/result.mp4")

    # n_real / n_fake = 1: presumably one video per side — confirm against cdfvd docs.
    evaluator = fvd.cdfvd('i3d', ckpt_path=None, device='cuda', n_real=1, n_fake=1)
    real_videos = evaluator.load_videos(real_dir, data_type='video_folder')
    evaluator.compute_real_stats(real_videos)
    fake_videos = evaluator.load_videos(fake_dir, data_type='video_folder')
    evaluator.compute_fake_stats(fake_videos)
    score = evaluator.compute_fvd_from_stats()
    # The evaluator was created on CUDA; move its model to the CPU once scoring is done.
    evaluator.offload_model_to_cpu()
    print(score)
    return score
def compute_fidx(item, gt_imgs, results):
    """Compute FID between a whole set of ground-truth frames and a set of generated frames.

    ``gt_imgs`` are written to temp/<item>_gt and ``results`` to temp/<item> as 0.png,
    1.png, .... The two folders are then compared with pytorch_fid (batch_size=8, CUDA,
    dims=2048). PNGs already present in either folder are deleted first, so leftovers from
    an earlier, longer call are not mixed into the statistics.

    Returns:
        The value returned by ``calculate_fid_given_paths``.
    """
    gt_dir = 'temp/' + item + '_gt'
    result_dir = 'temp/' + item
    # Each image set goes into its own folder; writing gt_imgs into both would make the
    # score compare the ground truth with itself.
    for folder, frames in ((gt_dir, gt_imgs), (result_dir, results)):
        os.makedirs(folder, exist_ok=True)
        for name in os.listdir(folder):
            if name.endswith('.png'):
                os.remove(os.path.join(folder, name))
        for i, img in enumerate(frames):
            img.save(folder + '/' + str(i) + '.png')
    return calculate_fid_given_paths([gt_dir, result_dir], batch_size=8, device='cuda', dims=2048)
# FID: Save images to temp folders for FID calculation
def compute_fid(img1, img2):
    """Compute FID between two single images via pytorch_fid.

    Each image is written as 0.png into its own folder (temp/img1, temp/img2). The two
    folders are compared with batch_size=1 on CUDA at dims=2048.

    NOTE(review): each folder holds exactly one image, so the feature statistics FID
    compares rest on a single sample per side. Confirm that per-frame FID is meaningful
    here (compute_fidx computes it over whole frame sets).
    """
    folders = ('temp/img1', 'temp/img2')
    for folder in folders:
        os.makedirs(folder, exist_ok=True)
    img1.save('temp/img1/0.png')
    img2.save('temp/img2/0.png')
    return calculate_fid_given_paths(list(folders), batch_size=1, device='cuda', dims=2048)
def _frame_index(filename):
    """Return the integer index in a saved frame name like 'result_12.png' (-1 if none)."""
    suffix = os.path.splitext(filename)[0].rsplit('_', 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


def _load_cached_frames(item_dir):
    """Reload the frames an earlier get_score run saved in item_dir.

    Returns (results, results_base, gt_frames). Each list is sorted by the numeric index in
    its file names, so position i refers to the same frame in all three lists.
    os.listdir order is arbitrary, and plain name sorting puts 'result_10' before
    'result_2'. Files without a known prefix are skipped instead of being decoded.
    """
    groups = {'result_': [], 'base_': [], 'frame_': []}
    for filename in os.listdir(item_dir):
        for prefix, names in groups.items():
            if filename.startswith(prefix):
                names.append(filename)
                break
    loaded = {}
    for prefix, names in groups.items():
        frames = []
        for filename in sorted(names, key=_frame_index):
            # copy() decodes the pixels, so the file can be closed right away instead of
            # keeping one open handle per frame.
            with Image.open(item_dir + '/' + filename) as img:
                frames.append(img.copy())
        loaded[prefix] = frames
    return loaded['result_'], loaded['base_'], loaded['frame_']


def _generate_frames(item, image_paths, video_path, train_steps, inference_steps, fps,
                     bg_remove, max_frame_count):
    """Run the model for one item and save its frames under outdir.

    Returns (results, results_base, gt_frames).
    """
    if not isinstance(image_paths[0], str):
        images = image_paths  # presumably already-loaded keyframes
    else:
        images = []
        for path in image_paths:
            print(path)
            images.append([Image.open(path)])
    gt_frames = extract_frames(video_path, fps)[:max_frame_count]
    for frame in gt_frames:
        frame.thumbnail((512, 512))  # modifies the frame in place
    results, results_base = run_eval(
        images, video_path,
        train_steps=train_steps, inference_steps=inference_steps, fps=fps,
        modelId="fine_tuned_pcdms", img_width=1920, img_height=1080,
        bg_remove=bg_remove, resize_inputs=False,
    )
    item_dir = outdir + item
    os.makedirs(item_dir, exist_ok=True)
    for prefix, frames in (("frame_", gt_frames), ("result_", results), ("base_", results_base)):
        for i, frame in enumerate(frames):
            frame.save(item_dir + "/" + prefix + str(i) + ".png")
    return results, results_base, gt_frames


def _mean(values):
    """Arithmetic mean of a non-empty list."""
    return sum(values) / len(values)


def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
    """Score fine-tuned ('ft') and baseline ('base') frames for one item against ground truth.

    If outdir/<item> exists, the frames are reloaded from it. Otherwise the model is run
    through run_eval with the given train_steps / inference_steps / fps / bg_remove (the
    defaults are the values that were previously hard-coded), and the frames are saved
    there for next time.

    SSIM, PSNR and LPIPS are computed for every aligned (gt, ft, base) triple. FID is
    computed only for the first 50 triples. Averages are printed and returned.

    Returns:
        A dict with:
        - 'ft' and 'base': each maps 'ssim' / 'psnr' / 'lpips' / 'fid' to
          {'avg': float, 'vals': list}.
        - 'n_frames': the number of ground-truth frames.
        - 'complexity': the number of input keyframes (len(image_paths)).

    Raises:
        ValueError: if there is no frame triple to compare.
    """
    max_frame_count = 200
    max_fid_frames = 50  # FID runs once per frame pair, so it is capped to this many pairs
    metric_names = (('SSIM', 'ssim'), ('PSNR', 'psnr'), ('LPIPS', 'lpips'), ('FID', 'fid'))

    item_dir = outdir + item
    if os.path.isdir(item_dir):
        results, results_base, gt_frames = _load_cached_frames(item_dir)
    else:
        results, results_base, gt_frames = _generate_frames(
            item, image_paths, video_path, train_steps, inference_steps, fps, bg_remove,
            max_frame_count)

    if min(len(gt_frames), len(results), len(results_base)) == 0:
        raise ValueError('no frames to compare for ' + item)

    scores = {variant: {key: [] for _, key in metric_names} for variant in ('ft', 'base')}
    for index, (gt, result, base) in enumerate(zip(gt_frames, results, results_base)):
        for variant, candidate in (('ft', result), ('base', base)):
            scores[variant]['ssim'].append(float(compute_ssim(gt, candidate)))
            scores[variant]['psnr'].append(float(compute_psnr(gt, candidate)))
            scores[variant]['lpips'].append(float(compute_lpips(gt, candidate)))
        if index < max_fid_frames:
            print(index)
            scores['ft']['fid'].append(float(compute_fid(gt, result)))
            scores['base']['fid'].append(float(compute_fid(gt, base)))

    # complexity is taken from the inputs, so it is correct on the cached path too.
    complexity = len(image_paths) if image_paths is not None else 0
    summary = {'ft': {}, 'base': {}, 'n_frames': len(gt_frames), 'complexity': complexity}
    for variant in ('ft', 'base'):
        if variant == 'base':
            print('baseline:')
        for label, key in metric_names:
            vals = scores[variant][key]
            avg = _mean(vals)
            print(label + ':', avg)
            summary[variant][key] = {'avg': avg, 'vals': vals}
    return summary
def get_files(directory_path):
    """Return the names (not full paths) of the regular files directly inside directory_path.

    Subdirectories are left out. The order is whatever os.listdir yields.
    """
    return [
        entry
        for entry in os.listdir(directory_path)
        if os.path.isfile(os.path.join(directory_path, entry))
    ]
def _load_metrics(metrics_path):
    """Return the metrics dict stored at metrics_path, or {} if the file does not exist yet."""
    try:
        with open(metrics_path, 'r', encoding='utf-8') as file:
            return json.load(file)
    except FileNotFoundError:
        return {}


def _save_metrics(metrics_path, metrics):
    """Write metrics as JSON to metrics_path.

    The data goes to a temp file first and is then moved into place with os.replace, so a
    crash mid-write cannot leave a truncated metrics file behind.
    """
    tmp_path = metrics_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as json_file:
        json.dump(metrics, json_file, ensure_ascii=False, indent=4)
    os.replace(tmp_path, metrics_path)


def _evaluate_item(item, test_dir, metrics, metrics_path):
    """Score one test item folder, store the result in metrics, and persist it.

    Any error is reported with its traceback, and the item is skipped. This is a
    deliberate best-effort: one bad item must not abort the whole run.
    """
    import traceback

    item_dir = test_dir + '/' + item
    try:
        files = get_files(item_dir)
        videos = [name for name in files if name.endswith('.mp4')]
        # Sorted so the keyframes reach get_score in a deterministic (name) order; hidden
        # files are skipped, matching how item folders are filtered.
        images = [
            item_dir + '/' + name
            for name in sorted(files)
            if not name.endswith('.mp4') and not name.startswith('.')
        ]
        print(images, videos)
        if len(videos) != 1:
            print('Error: expected exactly one mp4 in', item_dir, 'found', videos)
            return
        metrics[item] = get_score(item, images, item_dir + '/' + videos[0])
        _save_metrics(metrics_path, metrics)
    except Exception as e:
        print("Error", item, e)
        traceback.print_exc()


def _print_summary(metrics):
    """Print each item's ft/base averages, then the mean of those averages over all items."""
    labels = (('SSIM', 'ssim'), ('PSNR', 'psnr'), ('LPIPS', 'lpips'), ('FID', 'fid'))
    for item, entry in metrics.items():
        print(item)
        for label, key in labels:
            print(label + ':', entry['ft'][key]['avg'], entry['base'][key]['avg'])
    print('Results:')
    if not metrics:
        print('No metrics to summarize.')
        return
    for variant in ('ft', 'base'):
        if variant == 'base':
            print('baseline:')
        for label, key in labels:
            avgs = [entry[variant][key]['avg'] for entry in metrics.values()]
            print(label + ':', sum(avgs) / len(avgs))


def run_evaluate(metrics_path='/data/metrics.json', test_dir='test'):
    """Evaluate every test item that has no stored metrics yet, then print a summary.

    Steps:
    1. Download the acmyu/KeyframesAI-eval dataset into test_dir.
    2. Load existing metrics from metrics_path, starting empty if the file is missing.
    3. Score each unscored, non-hidden item subfolder (in sorted order) with get_score.
    4. Save the metrics after every item, so progress survives a crash.
    """
    print("run_evaluate")
    snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir=test_dir, repo_type="dataset")
    metrics = _load_metrics(metrics_path)
    items = sorted(
        entry for entry in os.listdir(test_dir)
        if not entry.startswith('.') and not os.path.isfile(test_dir + '/' + entry)
    )
    print(items)
    for item in items:
        if item in metrics:
            continue  # already scored in an earlier run
        print(item)
        _evaluate_item(item, test_dir, metrics, metrics_path)
    _print_summary(metrics)