"""Compare fine-tuned and baseline generations against ground-truth frames.

For every test item the script runs the generation pipeline twice (with and
without fine-tuning), scores each output frame against the matching video
frame with SSIM, PSNR, LPIPS and FID, stores the numbers in ``metrics.json``
and finally prints per-item and overall averages.
"""

import json
import os

import lpips
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from pytorch_fid.fid_score import calculate_fid_given_paths
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

from main import extract_frames, run

# File used to cache results between runs; items already present are skipped.
METRICS_PATH = 'metrics.json'


# Convert PIL to numpy
def pil_to_np(img):
    """Return ``img`` as a float32 array with values divided by 255."""
    return np.array(img).astype(np.float32) / 255.0


# SSIM
def compute_ssim(img1, img2):
    """Return the structural similarity between two images.

    The window size is the largest odd number that does not exceed the
    smaller side of ``img1``, capped at 7, so small images do not request a
    window larger than themselves.
    """
    img1_np = pil_to_np(img1)
    img2_np = pil_to_np(img2)
    h, w = img1_np.shape[:2]
    min_dim = min(h, w)
    win_size = min(7, min_dim if min_dim % 2 == 1 else min_dim - 1)  # ensure odd
    return ssim(img1_np, img2_np, win_size=win_size, channel_axis=-1, data_range=1.0)


# PSNR
def compute_psnr(img1, img2):
    """Return the peak signal-to-noise ratio between two images (data range 1.0)."""
    img1_np = pil_to_np(img1)
    img2_np = pil_to_np(img2)
    return psnr(img1_np, img2_np, data_range=1.0)


# LPIPS
lpips_model = lpips.LPIPS(net='alex')
# Resize to 256x256, convert to a tensor and normalise each of the three
# channels with mean 0.5 / std 0.5 before feeding the LPIPS network.
lpips_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])


def compute_lpips(img1, img2):
    """Return the LPIPS distance between two images as a Python float."""
    img1_tensor = lpips_transform(img1).unsqueeze(0)
    img2_tensor = lpips_transform(img2).unsqueeze(0)
    # Only the scalar value is needed, so skip building the autograd graph.
    with torch.no_grad():
        return lpips_model(img1_tensor, img2_tensor).item()


# FID: Save images to temp folders for FID calculation
def compute_fid(img1, img2):
    """Return the FID between two single images.

    Each image is written to its own folder under ``temp/`` (overwriting the
    file left by the previous call) because ``calculate_fid_given_paths``
    works on directories.

    NOTE(review): FID compares feature distributions; with one image per
    folder the statistics may be degenerate - confirm this is intended.
    """
    os.makedirs('temp/img1', exist_ok=True)
    os.makedirs('temp/img2', exist_ok=True)
    img1.save('temp/img1/0.png')
    img2.save('temp/img2/0.png')
    fid = calculate_fid_given_paths(['temp/img1', 'temp/img2'], batch_size=1, device='cpu', dims=2048)
    return fid


# Metric key -> scoring function, in the order they are printed and stored.
_METRIC_FUNCS = (
    ('ssim', compute_ssim),
    ('psnr', compute_psnr),
    ('lpips', compute_lpips),
    ('fid', compute_fid),
)
_METRIC_NAMES = tuple(name for name, _ in _METRIC_FUNCS)
# Result variants: fine-tuned output first, baseline second.
_VARIANTS = ('ft', 'base')


def _mean(values):
    """Return the arithmetic mean of ``values`` (ZeroDivisionError if empty)."""
    return sum(values) / len(values)


def _empty_table():
    """Return ``{variant: {metric: []}}`` with one empty list per combination."""
    return {variant: {name: [] for name in _METRIC_NAMES} for variant in _VARIANTS}


def _print_means(per_metric):
    """Print one ``NAME: mean`` line per metric for a ``{metric: values}`` mapping."""
    for name in _METRIC_NAMES:
        print(f"{name.upper()}:", _mean(per_metric[name]))


def _load_metrics(path):
    """Load cached metrics from ``path``; start empty when the file is absent."""
    try:
        with open(path, 'r') as file:
            return json.load(file)
    except FileNotFoundError:
        # First run: nothing has been scored yet.
        return {}


metrics = _load_metrics(METRICS_PATH)


def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
    """Generate, score and record results for one test item.

    Runs ``run`` twice (``finetune=True`` then ``finetune=False``), saves the
    ground-truth frames and both sets of outputs under ``out/<item>/``,
    scores every output frame against its ground-truth frame, prints the
    averages, stores them in the module-level ``metrics`` dict and rewrites
    ``metrics.json``.

    Args:
        item: Name of the test item; used as the output folder and metrics key.
        image_paths: Paths of the input images passed to ``run``.
        video_path: Path of the ground-truth video.
        train_steps: Forwarded to ``run``.
        inference_steps: Forwarded to ``run``.
        fps: Frame rate used for ``extract_frames`` and forwarded to ``run``.
        bg_remove: Forwarded to ``run``.

    Raises:
        ZeroDivisionError: If no frame triple is available to average.
    """
    print(item)
    images = [Image.open(path) for path in image_paths]

    gt_frames = extract_frames(video_path, fps)
    out_dir = f"out/{item}"
    os.makedirs(out_dir, exist_ok=True)
    for i, frame in enumerate(gt_frames):
        frame.save(f"{out_dir}/frame_{i}.png")

    # Forward the caller's settings instead of hard-coded values so that
    # overriding the defaults actually takes effect.
    run_kwargs = {
        'train_steps': train_steps,
        'inference_steps': inference_steps,
        'fps': fps,
        'bg_remove': bg_remove,
    }

    results = run(images, video_path, finetune=True, **run_kwargs)
    for i, result in enumerate(results):
        result.save(f"{out_dir}/result_{i}.png")

    results_base = run(images, video_path, finetune=False, **run_kwargs)
    for i, result in enumerate(results_base):
        result.save(f"{out_dir}/base_{i}.png")

    # zip() stops at the shortest sequence, so only frames present in the
    # ground truth and in both runs are scored.
    scores = _empty_table()
    for gt, result, base in zip(gt_frames, results, results_base):
        for variant, candidate in (('ft', result), ('base', base)):
            for name, func in _METRIC_FUNCS:
                scores[variant][name].append(float(func(gt, candidate)))

    _print_means(scores['ft'])
    print('baseline:')
    _print_means(scores['base'])

    metrics[item] = {
        variant: {
            name: {'avg': _mean(scores[variant][name]), 'vals': scores[variant][name]}
            for name in _METRIC_NAMES
        }
        for variant in _VARIANTS
    }

    with open(METRICS_PATH, "w", encoding="utf-8") as json_file:
        json.dump(metrics, json_file, ensure_ascii=False, indent=4)


items = ['sidewalk', 'aaa', 'azri', 'dead', 'frankgirl', 'kobold', 'ramona', 'renee', 'walk', 'woody']

for item in items:
    # Skip items whose scores are already cached in metrics.json.
    if item in metrics:
        continue
    get_score(
        item,
        [f'test/{item}/1.jpg', f'test/{item}/2.jpg', f'test/{item}/3.jpg'],
        f'test/{item}/v.mp4',
    )

# Aggregate the per-item averages. The accumulator deliberately does not
# reuse the names ``ssim`` / ``psnr`` / ``lpips`` so the imported metric
# functions and module stay usable after this point.
overall = _empty_table()
for item in metrics:
    entry = metrics[item]
    for variant in _VARIANTS:
        for name in _METRIC_NAMES:
            overall[variant][name].append(entry[variant][name]['avg'])
    print(item)
    for name in _METRIC_NAMES:
        print(f"{name.upper()}:", entry['ft'][name]['avg'], entry['base'][name]['avg'])

print('Results:')
_print_means(overall['ft'])
print('baseline:')
_print_means(overall['base'])