# Lyra/src/eval/compute_metrics_datasets.py
# Muhammad Taqi Raza
# adding lyra files
# af758d1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from tqdm import tqdm
import time
from src.eval.metrics import get_lpips, compute_lpips, compute_ssim, compute_psnr, resize_and_crop_video, read_mp4_to_tensor, plot_average_metric_per_frame, compute_std
from src.models.btimer.core.utils.data import write_dict_to_json, read_json_to_dict
def _compute_video_metrics(path_video_pred, path_video_gt, device, H_target=None, W_target=None):
    """Compute per-frame PSNR, SSIM and LPIPS for one predicted / ground-truth video pair.

    Args:
        path_video_pred: Path of the predicted video file.
        path_video_gt: Path of the ground-truth video file.
        device: Torch device handed to ``read_mp4_to_tensor`` and used for padding.
        H_target: Optional height for an extra resize/crop of both videos.
        W_target: Optional width for an extra resize/crop of both videos.

    Returns:
        Dict with keys ``'psnr'``, ``'ssim'`` and ``'lpips'``. When the ground
        truth has fewer frames than the prediction, each entry is zero-padded
        to the prediction's frame count.
    """
    video_pred = read_mp4_to_tensor(path_video_pred, device)
    video_gt = read_mp4_to_tensor(path_video_gt, device)
    # The prediction's shape unpacks as (T, C, H, W); it defines the reference
    # frame count and resolution.
    T, _, H, W = video_pred.shape

    # Cut the ground truth to the prediction's frame count, then bring it to
    # the prediction's resolution.
    video_gt = resize_and_crop_video(video_gt[:T], H, W)

    # Optional extra resize/crop, applied only when both targets are given.
    if H_target is not None and W_target is not None:
        video_pred = resize_and_crop_video(video_pred, H_target, W_target)
        video_gt = resize_and_crop_video(video_gt, H_target, W_target)

    # The ground truth may still be shorter than the prediction; compare only
    # the frames both videos have.
    T_gt = video_gt.shape[0]
    is_shortened = T_gt != T
    if is_shortened:
        video_pred = video_pred[:T_gt]

    metrics = {
        'psnr': compute_psnr(video_gt, video_pred),
        'ssim': compute_ssim(video_gt, video_pred),
        'lpips': compute_lpips(video_gt, video_pred),
    }

    if is_shortened:
        # Pad back to T entries so every video yields a vector of the same
        # length. NOTE(review): padded frames are stored as 0 and therefore
        # enter the dataset averages as zeros — confirm this is intended.
        padded_metrics = {}
        for name, values in metrics.items():
            padded = torch.zeros(T, device=device)
            padded[:T_gt] = values
            padded_metrics[name] = padded
        metrics = padded_metrics

    return metrics


def compute_metrics(path_data_pred: str, path_data_gt: str, out_path: str = None, H_target: int = None, W_target: int = None, num_scenes: int = None, num_frames_eval: int = None):
    """Evaluate predicted videos against ground truth and plot per-frame averages.

    Every file name found in ``path_data_pred`` is paired with the file of the
    same name in ``path_data_gt``. Per-video metrics are cached as JSON files
    under ``<out_path>/metrics``; an existing cache file is loaded instead of
    recomputing the video. The per-frame averages are printed and plotted to
    ``psnr.png``, ``ssim.png`` and ``lpips.png`` inside ``out_path``.

    Args:
        path_data_pred: Directory containing the predicted videos.
        path_data_gt: Directory containing the ground-truth videos.
        out_path: Output directory for cached metrics and plots. Required,
            despite the ``None`` default kept for signature compatibility.
        H_target: Optional height for an extra resize/crop of both videos.
        W_target: Optional width for an extra resize/crop of both videos.
        num_scenes: If given, only the first ``num_scenes`` files (in sorted
            order) are evaluated.
        num_frames_eval: If given, also print the mean of each averaged metric
            over the first ``num_frames_eval`` frames.

    Raises:
        TypeError: If ``out_path`` is ``None``.
        ValueError: If there are no prediction files to evaluate.
    """
    # Fail fast, before loading any model or sleeping.
    if out_path is None:
        raise TypeError("out_path must be a directory path, got None")

    file_names = sorted(os.listdir(path_data_pred))
    if num_scenes is not None:
        file_names = file_names[:num_scenes]
    if not file_names:
        raise ValueError(f"No prediction files to evaluate in {path_data_pred!r}")

    device = torch.device("cuda:0")
    # NOTE(review): the returned module is not passed to compute_lpips below;
    # presumably get_lpips prepares/caches the network — TODO confirm.
    lpips_module = get_lpips(device)

    out_path_metrics_main = os.path.join(out_path, 'metrics')
    os.makedirs(out_path_metrics_main, exist_ok=True)
    # NOTE(review): fixed 30 s pause kept from the original implementation;
    # its purpose is not visible here — confirm before removing.
    time.sleep(30)

    psnr_sum, ssim_sum, lpips_sum = None, None, None
    count = 0
    for file_name in tqdm(file_names):
        path_video_pred = os.path.join(path_data_pred, file_name)
        path_video_gt = os.path.join(path_data_gt, file_name)
        path_metrics_out = os.path.join(out_path_metrics_main, file_name.replace('.mp4', '.json'))

        if os.path.isfile(path_metrics_out):
            # Reuse cached per-video metrics.
            cached = read_json_to_dict(path_metrics_out)
            metrics = {k: torch.tensor(v, device=device) for k, v in cached.items()}
        else:
            metrics = _compute_video_metrics(path_video_pred, path_video_gt, device, H_target, W_target)
            write_dict_to_json({k: v.tolist() for k, v in metrics.items()}, path_metrics_out)

        psnr = metrics['psnr']
        ssim = metrics['ssim']
        lpips = metrics['lpips']

        # Size the accumulators from the first metric vector; this avoids
        # decoding the first prediction video a second time just for its length.
        if psnr_sum is None:
            num_frames = psnr.shape[0]
            psnr_sum = torch.zeros(num_frames, device=device)
            ssim_sum = torch.zeros(num_frames, device=device)
            lpips_sum = torch.zeros(num_frames, device=device)

        psnr_sum += psnr
        ssim_sum += ssim
        lpips_sum += lpips
        count += 1

    # Per-frame averages over all evaluated videos.
    psnr_avg = psnr_sum / count
    ssim_avg = ssim_sum / count
    lpips_avg = lpips_sum / count
    # NOTE(review): compute_std only receives the running sum, the mean and the
    # count; confirm it can derive a spread without a sum of squares.
    psnr_std = compute_std(psnr_sum, psnr_avg, count)
    ssim_std = compute_std(ssim_sum, ssim_avg, count)
    lpips_std = compute_std(lpips_sum, lpips_avg, count)

    print("psnr_avg", psnr_avg, "ssim_avg", ssim_avg, "lpips_avg", lpips_avg)
    if num_frames_eval is not None:
        print(f"psnr_avg for first {num_frames_eval} frames: {psnr_avg[:num_frames_eval].mean()}")
        print(f"ssim_avg for first {num_frames_eval} frames: {ssim_avg[:num_frames_eval].mean()}")
        print(f"lpips_avg for first {num_frames_eval} frames: {lpips_avg[:num_frames_eval].mean()}")

    # Save one per-frame plot for each metric.
    plot_average_metric_per_frame(psnr_avg, os.path.join(out_path, "psnr.png"), psnr_std, metric_name="PSNR")
    plot_average_metric_per_frame(ssim_avg, os.path.join(out_path, "ssim.png"), ssim_std, metric_name="SSIM")
    plot_average_metric_per_frame(lpips_avg, os.path.join(out_path, "lpips.png"), lpips_std, metric_name="LPIPS")
if __name__ == '__main__':
    # Placeholder locations: point these at the predicted and ground-truth video folders.
    path_data_pred = '/path/to/rgb'
    path_data_gt = '/path/to/gt_rgb/'

    # With both targets left as None, compute_metrics skips its extra resize/crop step.
    H_target = W_target = None
    num_scenes, num_frames_eval = 1000, 121

    # Results are grouped by evaluation resolution and by the number of scenes.
    out_path = os.path.join(
        'outputs/eval',
        f"img_res_{H_target}_{W_target}",
        f"num_scenes_{num_scenes}",
    )

    compute_metrics(
        path_data_pred=path_data_pred,
        path_data_gt=path_data_gt,
        out_path=out_path,
        H_target=H_target,
        W_target=W_target,
        num_scenes=num_scenes,
        num_frames_eval=num_frames_eval,
    )