File size: 5,927 Bytes
af758d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from tqdm import tqdm
import time

from src.eval.metrics import get_lpips, compute_lpips, compute_ssim, compute_psnr, resize_and_crop_video, read_mp4_to_tensor, plot_average_metric_per_frame, compute_std
from src.models.btimer.core.utils.data import write_dict_to_json, read_json_to_dict

def _compute_video_metrics(path_video_pred: str, path_video_gt: str, device: torch.device, H_target: int = None, W_target: int = None):
    """Compute per-frame PSNR, SSIM and LPIPS for one prediction / ground-truth pair.

    The ground truth is cut to the prediction's frame count and resized/cropped to
    the prediction's resolution. If ``H_target`` and ``W_target`` are both given,
    both videos are then resized/cropped to that resolution. When the ground truth
    has fewer frames than the prediction, only the overlapping frames are scored
    and each metric is zero-padded back to the prediction's frame count.

    Returns:
        ``(psnr, ssim, lpips)``, each indexed by frame of the predicted video.
    """
    video_pred = read_mp4_to_tensor(path_video_pred, device)
    video_gt = read_mp4_to_tensor(path_video_gt, device)
    # Videos are 4-D, unpacked as (T, C, H, W).
    T, _, H, W = video_pred.shape
    video_gt = resize_and_crop_video(video_gt[:T], H, W)
    if H_target is not None and W_target is not None:
        video_pred = resize_and_crop_video(video_pred, H_target, W_target)
        video_gt = resize_and_crop_video(video_gt, H_target, W_target)
    # The ground truth may be shorter than the prediction; score only the overlap.
    T_gt = video_gt.shape[0]
    video_pred = video_pred[:T_gt]
    metrics = (
        compute_psnr(video_gt, video_pred),
        compute_ssim(video_gt, video_pred),
        compute_lpips(video_gt, video_pred),
    )
    if T_gt == T:
        return metrics
    # NOTE(review): frames missing from the ground truth are filled with 0 and are
    # later averaged like real scores. This biases the per-frame mean (0 is a very
    # poor PSNR/SSIM but a perfect LPIPS). Consider tracking a per-frame valid count.
    padded = []
    for values in metrics:
        full = torch.zeros(T, device=device)
        full[:T_gt] = values
        padded.append(full)
    return tuple(padded)


def compute_metrics(path_data_pred: str, path_data_gt: str, out_path: str = None, H_target: int = None, W_target: int = None, num_scenes: int = None, num_frames_eval: int = None, wait_seconds: float = 30, device: str = "cuda:0"):
    """Evaluate predicted videos against ground truth and plot per-frame metric curves.

    Every file in ``path_data_pred`` is processed in sorted name order, optionally
    limited to the first ``num_scenes``. The file with the same name in
    ``path_data_gt`` is its ground truth. Per-frame PSNR, SSIM and LPIPS are
    computed, or loaded from ``<out_path>/metrics/<name>.json`` when that file
    already exists. They are summed across videos and averaged per frame. The
    averages are printed and plotted to ``<out_path>/{psnr,ssim,lpips}.png``.

    Args:
        path_data_pred: Directory of predicted videos.
        path_data_gt: Directory of ground-truth videos with matching file names.
        out_path: Output directory for cached per-video metrics and plots.
            Required; the ``None`` default is kept only for signature compatibility.
            Cached JSONs are reused regardless of ``H_target``/``W_target``, so use
            a distinct ``out_path`` per evaluation setting.
        H_target, W_target: If both are given, videos are additionally
            resized/cropped to this resolution before scoring.
        num_scenes: Evaluate only the first ``num_scenes`` files.
        num_frames_eval: If given, also print each metric averaged over the first
            ``num_frames_eval`` frames.
        wait_seconds: Delay before evaluation starts (previously a fixed 30 s).
        device: Torch device used for decoding videos and computing metrics.

    Raises:
        ValueError: If ``out_path`` is None, no files are selected, or videos
            produce per-frame metrics of different lengths.
    """
    if out_path is None:
        raise ValueError("out_path is required: cached metrics and plots are written under it")
    file_names = sorted(os.listdir(path_data_pred))
    if num_scenes is not None:
        file_names = file_names[:num_scenes]
    # Fail fast instead of dividing by a zero count after the wait below.
    if not file_names:
        raise ValueError(f"No videos selected for evaluation in {path_data_pred!r}")

    device = torch.device(device)
    # The return value is not used here; presumably this initializes the LPIPS
    # network that compute_lpips relies on -- TODO confirm.
    get_lpips(device)

    out_path_metrics_main = os.path.join(out_path, 'metrics')
    os.makedirs(out_path_metrics_main, exist_ok=True)
    # NOTE(review): delay carried over from the original script; its purpose is not
    # visible here (possibly waiting for prediction videos to be fully written).
    time.sleep(wait_seconds)

    metric_names = ('psnr', 'ssim', 'lpips')
    sums = None
    count = 0
    for file_name in tqdm(file_names):
        path_metrics_out = os.path.join(out_path_metrics_main, file_name.replace('.mp4', '.json'))
        if os.path.isfile(path_metrics_out):
            # Reuse per-video metrics cached by a previous run.
            cached = read_json_to_dict(path_metrics_out)
            metrics = {k: torch.tensor(cached[k], device=device) for k in metric_names}
        else:
            values = _compute_video_metrics(
                os.path.join(path_data_pred, file_name),
                os.path.join(path_data_gt, file_name),
                device, H_target, W_target,
            )
            metrics = dict(zip(metric_names, values))
            write_dict_to_json({k: v.tolist() for k, v in metrics.items()}, path_metrics_out)

        if sums is None:
            # Per-frame accumulators, sized from the first video's metrics.
            num_frames = metrics['psnr'].shape[0]
            sums = {k: torch.zeros(num_frames, device=device) for k in metric_names}
        for k in metric_names:
            # Guard against silent broadcasting (e.g. a 1-frame video) or opaque errors.
            if metrics[k].shape != sums[k].shape:
                raise ValueError(
                    f"{file_name}: {k} has shape {tuple(metrics[k].shape)}, expected "
                    f"{tuple(sums[k].shape)}; all videos must have the same number of frames"
                )
            sums[k] += metrics[k]
        count += 1

    # Per-frame mean across videos.
    avgs = {k: sums[k] / count for k in metric_names}
    # NOTE(review): a standard deviation cannot be derived from a running sum and mean
    # alone; confirm compute_std's contract (it may expect a sum of squares).
    stds = {k: compute_std(sums[k], avgs[k], count) for k in metric_names}

    print("psnr_avg", avgs['psnr'], "ssim_avg", avgs['ssim'], "lpips_avg", avgs['lpips'])
    if num_frames_eval is not None:
        for k in metric_names:
            print(f"{k}_avg for first {num_frames_eval} frames: {avgs[k][:num_frames_eval].mean()}")

    # Plot per-frame average curves (with spread from compute_std).
    plot_labels = {'psnr': "PSNR", 'ssim': "SSIM", 'lpips': "LPIPS"}
    for k in metric_names:
        plot_average_metric_per_frame(avgs[k], os.path.join(out_path, f"{k}.png"), stds[k], metric_name=plot_labels[k])
    

if __name__ == '__main__':
    import argparse

    # Command-line options; defaults reproduce the previously hard-coded settings.
    parser = argparse.ArgumentParser(description="Compute per-frame PSNR/SSIM/LPIPS between predicted and ground-truth videos.")
    parser.add_argument('--pred', default='/path/to/rgb', help="Directory of predicted videos.")
    parser.add_argument('--gt', default='/path/to/gt_rgb/', help="Directory of ground-truth videos with matching file names.")
    parser.add_argument('--out', default='outputs/eval', help="Base output directory.")
    parser.add_argument('--height', type=int, default=None, help="Optional target height for an extra resize/crop (needs --width).")
    parser.add_argument('--width', type=int, default=None, help="Optional target width for an extra resize/crop (needs --height).")
    parser.add_argument('--num-scenes', type=int, default=1000, help="Evaluate only the first N videos (sorted by name).")
    parser.add_argument('--num-frames-eval', type=int, default=121, help="Also report averages over the first N frames.")
    args = parser.parse_args()

    # Keep results for different resolutions / scene counts in separate directories,
    # since per-video metric JSONs are cached under the output path.
    out_path = os.path.join(args.out, f"img_res_{args.height}_{args.width}", f"num_scenes_{args.num_scenes}")
    compute_metrics(args.pred, args.gt, out_path, args.height, args.width, args.num_scenes, args.num_frames_eval)