|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
import time |
|
|
|
|
|
from src.eval.metrics import get_lpips, compute_lpips, compute_ssim, compute_psnr, resize_and_crop_video, read_mp4_to_tensor, plot_average_metric_per_frame, compute_std |
|
|
from src.models.btimer.core.utils.data import write_dict_to_json, read_json_to_dict |
|
|
|
|
|
def _compute_video_metrics(path_video_pred, path_video_gt, device, H_target=None, W_target=None):
    """Compute per-frame PSNR, SSIM and LPIPS for one predicted / ground-truth video pair.

    The ground-truth video is truncated to the predicted length and resized/cropped to
    the prediction's resolution. If both ``H_target`` and ``W_target`` are given, both
    videos are then resized/cropped to that resolution.

    Returns:
        ``(metrics, num_valid)``. ``metrics`` maps ``'psnr'``, ``'ssim'`` and ``'lpips'``
        to float tensors of length T, the number of predicted frames. Only the first
        ``num_valid`` entries are real scores. If the ground truth is shorter than the
        prediction, the remaining entries are zero placeholders.
    """
    video_pred = read_mp4_to_tensor(path_video_pred, device)
    video_gt = read_mp4_to_tensor(path_video_gt, device)
    T, _, H, W = video_pred.shape

    # Align the GT with the prediction: same number of frames (at most) and same resolution.
    video_gt = resize_and_crop_video(video_gt[:T], H, W)
    if H_target is not None and W_target is not None:
        video_pred = resize_and_crop_video(video_pred, H_target, W_target)
        video_gt = resize_and_crop_video(video_gt, H_target, W_target)

    # The GT may have fewer frames than the prediction: score only the overlapping frames.
    num_valid = video_gt.shape[0]
    video_pred = video_pred[:num_valid]

    metrics = {}
    for name, metric_fn in (('psnr', compute_psnr), ('ssim', compute_ssim), ('lpips', compute_lpips)):
        padded = torch.zeros(T, device=device)
        padded[:num_valid] = metric_fn(video_gt, video_pred)
        metrics[name] = padded
    return metrics, num_valid


def _load_cached_metrics(path_metrics_json, device):
    """Load per-frame metrics previously written by ``compute_metrics``.

    Returns ``(metrics, num_valid)`` in the same form as ``_compute_video_metrics``.
    Caches written before ``'num_valid_frames'`` was recorded are treated as fully
    valid, because any zero-padded tail in them cannot be recovered.
    """
    data = read_json_to_dict(path_metrics_json)
    metrics = {k: torch.tensor(data[k], device=device) for k in ('psnr', 'ssim', 'lpips')}
    num_valid = int(data.get('num_valid_frames', metrics['psnr'].shape[0]))
    return metrics, num_valid


def compute_metrics(path_data_pred: str, path_data_gt: str, out_path: str = None, H_target: int = None, W_target: int = None, num_scenes: int = None, num_frames_eval: int = None, device=None, wait_seconds: float = 30):
    """Evaluate predicted videos against ground truth and report per-frame averages.

    For every file in ``path_data_pred``, the file of the same name in ``path_data_gt``
    is used as ground truth. Per-frame PSNR/SSIM/LPIPS are cached as JSON under
    ``<out_path>/metrics/``, and existing caches are reused. The per-frame averages
    across scenes are printed and plotted to ``<out_path>/{psnr,ssim,lpips}.png``.

    Args:
        path_data_pred: directory of predicted videos (files are processed in sorted order).
        path_data_gt: directory of ground-truth videos with matching file names.
        out_path: output directory (required).
        H_target, W_target: if both are given, both videos are resized/cropped to this size.
        num_scenes: evaluate only the first ``num_scenes`` files (sorted order).
        num_frames_eval: additionally print the mean over the first ``num_frames_eval`` frames.
        device: torch device used for evaluation; defaults to ``cuda:0``.
        wait_seconds: delay before evaluation starts (kept from the original script).

    Raises:
        ValueError: if ``out_path`` is None or there are no files to evaluate.
    """
    if out_path is None:
        raise ValueError("out_path is required: per-scene metrics and plots are written there")
    device = torch.device("cuda:0") if device is None else torch.device(device)

    # The returned module is not used here; get_lpips is presumably called for a side effect
    # that compute_lpips relies on. NOTE(review): confirm against src.eval.metrics.
    get_lpips(device)

    file_names = sorted(os.listdir(path_data_pred))
    if num_scenes is not None:
        file_names = file_names[:num_scenes]
    if not file_names:
        raise ValueError(f"No files to evaluate in {path_data_pred!r}")

    out_path_metrics_main = os.path.join(out_path, 'metrics')
    os.makedirs(out_path_metrics_main, exist_ok=True)

    # NOTE(review): fixed delay inherited from the original script with no stated reason
    # (perhaps waiting for another process to finish writing videos). The default of 30 s
    # is kept for backward compatibility; pass wait_seconds=0 to skip it.
    if wait_seconds:
        time.sleep(wait_seconds)

    metric_names = ('psnr', 'ssim', 'lpips')
    sums = None          # metric name -> per-frame sum over scenes
    frame_counts = None  # per-frame number of scenes whose GT covers that frame
    count = 0            # number of scenes evaluated

    with torch.no_grad():  # evaluation only; no autograd graphs needed
        for file_name in tqdm(file_names):
            path_metrics_out = os.path.join(out_path_metrics_main, os.path.splitext(file_name)[0] + '.json')

            if os.path.isfile(path_metrics_out):
                metrics, num_valid = _load_cached_metrics(path_metrics_out, device)
            else:
                metrics, num_valid = _compute_video_metrics(
                    os.path.join(path_data_pred, file_name),
                    os.path.join(path_data_gt, file_name),
                    device,
                    H_target,
                    W_target,
                )
                metrics_dict_out = {k: v.tolist() for k, v in metrics.items()}
                # Record how many leading frames are real so cached runs can skip the padding.
                metrics_dict_out['num_valid_frames'] = num_valid
                write_dict_to_json(metrics_dict_out, path_metrics_out)

            T = metrics['psnr'].shape[0]
            if sums is None:
                sums = {k: torch.zeros(T, device=device) for k in metric_names}
                frame_counts = torch.zeros(T, device=device)
            elif T > frame_counts.shape[0]:
                # A longer video than any seen so far: grow the accumulators with zeros.
                extra = T - frame_counts.shape[0]
                sums = {k: torch.cat([v, v.new_zeros(extra)]) for k, v in sums.items()}
                frame_counts = torch.cat([frame_counts, frame_counts.new_zeros(extra)])

            for k in metric_names:
                sums[k][:T] += metrics[k]
            frame_counts[:min(num_valid, T)] += 1
            count += 1

    # Divide each frame by the number of scenes that actually have GT for it, so the zero
    # placeholders of short GT videos do not drag the per-frame mean toward 0. Frames with
    # no GT in any scene stay at 0 (the clamp avoids 0/0).
    denom = frame_counts.clamp(min=1)
    psnr_avg = sums['psnr'] / denom
    ssim_avg = sums['ssim'] / denom
    lpips_avg = sums['lpips'] / denom

    # NOTE(review): compute_std receives the per-frame sum and mean plus the scene count,
    # exactly as in the original. A standard deviation normally needs per-scene values or
    # a sum of squares, so verify compute_std's contract.
    psnr_std = compute_std(sums['psnr'], psnr_avg, count)
    ssim_std = compute_std(sums['ssim'], ssim_avg, count)
    lpips_std = compute_std(sums['lpips'], lpips_avg, count)

    print("psnr_avg", psnr_avg, "ssim_avg", ssim_avg, "lpips_avg", lpips_avg)

    if num_frames_eval is not None:
        # Only frames that at least one scene has GT for contribute to the summary mean.
        has_gt = frame_counts[:num_frames_eval] > 0
        for name, avg in (('psnr', psnr_avg), ('ssim', ssim_avg), ('lpips', lpips_avg)):
            print(f"{name}_avg for first {num_frames_eval} frames: {avg[:num_frames_eval][has_gt].mean()}")

    plot_average_metric_per_frame(psnr_avg, os.path.join(out_path, "psnr.png"), psnr_std, metric_name="PSNR")
    plot_average_metric_per_frame(ssim_avg, os.path.join(out_path, "ssim.png"), ssim_std, metric_name="SSIM")
    plot_average_metric_per_frame(lpips_avg, os.path.join(out_path, "lpips.png"), lpips_std, metric_name="LPIPS")
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Example configuration: replace the placeholder directories before running.
    pred_dir = '/path/to/rgb'
    gt_dir = '/path/to/gt_rgb/'

    # Leave both as None to evaluate at the prediction's native resolution.
    target_h = target_w = None
    scene_limit = 1000
    frames_to_report = 121

    # Results are grouped by evaluation resolution and scene budget.
    eval_dir = os.path.join(
        'outputs/eval',
        f"img_res_{target_h}_{target_w}",
        f"num_scenes_{scene_limit}",
    )

    compute_metrics(
        pred_dir,
        gt_dir,
        eval_dir,
        target_h,
        target_w,
        scene_limit,
        frames_to_report,
    )