File size: 5,690 Bytes
199f9c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torchmetrics
import os
import torch
from PIL import Image
import numpy as np
import csv
import sys

# Number of focal positions; the metric grid is num_positions x num_positions
# (row = position directory, column = frame index within a burst).
num_positions = 9

output_dir_path = "/datasets/sai/focal-burst-learning/metrics_output"

# Expected invocation: <script> <model_dir_name> <device>
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} <model_dir_name> <device>")

gt = "gt"
model = sys.argv[1]

gt_path = os.path.join(output_dir_path, gt)
model_path = os.path.join(output_dir_path, model)

device = sys.argv[2]


def _make_metrics(device):
    """Return a fresh dict of metric objects, all moved to ``device``.

    Every metric is configured for float images scaled to [0, 1]. SSIM gets an
    explicit ``data_range=1.0`` (matching PSNR) instead of letting torchmetrics
    re-derive the range from each image's min/max, which would make SSIM
    depend on per-frame contrast.
    """
    return {
        "psnr": torchmetrics.image.PeakSignalNoiseRatio(data_range=1.0).to(device),
        "ssim": torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=1.0).to(device),
        "lpips": torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type='vgg', normalize=True).to(device),
        "fid": torchmetrics.image.fid.FrechetInceptionDistance(normalize=True).to(device),
        "vif": torchmetrics.image.VisualInformationFidelity().to(device),
    }


# NOTE(review): every cell gets its own metric instances, so the LPIPS and FID
# backbone networks are presumably instantiated num_positions**2 times on
# `device` -- watch device memory.
metrics_grid = []
for i in range(num_positions):
    metrics_grid.append([_make_metrics(device) for _ in range(num_positions)])
    print("Created metrics for position", i)

# Number of bursts (each of num_positions frames) stored per position directory.
num_images_per_position = 164


def _load_frames(frame_dir, frame_names, device):
    """Load image files as float32 tensors of shape (1, 3, H, W) in [0, 1].

    Images are converted to RGB, so an alpha channel is dropped and
    grayscale/palette images still yield three channels. Files are closed
    deterministically via the context manager.
    """
    tensors = []
    for name in frame_names:
        with Image.open(os.path.join(frame_dir, name)) as img:
            pixels = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
        tensors.append(torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device))
    return tensors


# Position directories under gt_path, sorted by name; at most num_positions.
position_dirs = sorted(
    name for name in os.listdir(gt_path)
    if os.path.isdir(os.path.join(gt_path, name))
)[:num_positions]

for gt_dir in position_dirs:
    # Directory names are expected to look like "<prefix>_<index>"; the index
    # selects the row of metrics_grid.
    position_number = int(gt_dir.split("_")[1])
    if not 0 <= position_number < num_positions:
        raise ValueError(
            f"position index {position_number} from {gt_dir!r} is outside 0..{num_positions - 1}"
        )

    gt_images_dir = os.path.join(gt_path, gt_dir, "images")
    model_images_dir = os.path.join(model_path, gt_dir, "images")

    gt_pngs = sorted(os.listdir(gt_images_dir))
    expected_count = num_images_per_position * num_positions
    if len(gt_pngs) != expected_count:
        raise ValueError(
            f"{gt_images_dir} holds {len(gt_pngs)} files, expected {expected_count}"
        )

    for i in range(num_images_per_position):
        # Consecutive (sorted) files form one burst of num_positions frames;
        # the model output uses the same file names under model_path.
        frame_names = gt_pngs[i * num_positions:(i + 1) * num_positions]
        gt_frames = _load_frames(gt_images_dir, frame_names, device)
        model_frames = _load_frames(model_images_dir, frame_names, device)

        for j, (gt_frame, model_frame) in enumerate(zip(gt_frames, model_frames)):
            for key, metric in metrics_grid[position_number][j].items():
                if key == "fid":
                    metric.update(model_frame, real=False)
                    metric.update(gt_frame, real=True)
                else:
                    # torchmetrics convention is update(preds, target). VIF is
                    # asymmetric (target is the reference), so the ground truth
                    # must be the second argument.
                    metric.update(model_frame, gt_frame)

        print("Computed metrics for position", position_number, "frame", i)

#write the metrics to a csv (each metric as a csv)

def write_metrics_to_csv(metrics_grid, metric_names, formatting_options=None, output_dir="metrics_output"):
    """
    Write each metric in ``metrics_grid`` to its own CSV file ``<metric_name>.csv``.

    Row ``i`` of the grid becomes the CSV row labelled ``Position i+1``. Column
    ``j`` becomes the column headed ``Position j+1``. Labels are derived from
    the grid's own shape.

    Args:
        metrics_grid (list): Grid (list of rows) of dicts mapping a metric name
            to either an object with a ``compute()`` method returning a
            single-element tensor (e.g. a torchmetrics metric) or a plain number.
        metric_names (list): Metric names to export (e.g. ["psnr", "lpips", "fid"]).
        formatting_options (dict, optional): Maps a metric name to a callable
            turning a float into the cell string. Metrics not listed use f"{x}".
        output_dir (str): Directory for the CSV files; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    def default_format(value):
        return f"{value}"

    def metric_value(metric):
        # Metric objects are reduced via compute(); plain numbers are used as-is
        # rather than being silently replaced by a placeholder.
        if hasattr(metric, "compute"):
            return metric.compute().item()
        return float(metric)

    num_cols = max((len(row) for row in metrics_grid), default=0)
    header = ["Starting Position/End Position"] + [f"Position {k}" for k in range(1, num_cols + 1)]
    formats = formatting_options or {}

    for metric_name in metric_names:
        output_file = os.path.join(output_dir, f"{metric_name}.csv")
        format_fn = formats.get(metric_name, default_format)

        with open(output_file, mode='w', newline='', encoding='utf-8') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(header)

            for position, row in enumerate(metrics_grid, start=1):
                csv_row = [f"Position {position}"]
                csv_row.extend(format_fn(metric_value(cell[metric_name])) for cell in row)
                writer.writerow(csv_row)
                print(f"Wrote row for position {position} with metric {metric_name}")

        print(f"Saved {metric_name} metrics to {output_file}")

def _fixed_point(decimals):
    """Return a formatter that renders a number with ``decimals`` fractional digits."""
    return lambda value: f"{value:.{decimals}f}"


# Per-metric CSV precision: two decimals for PSNR and FID, four for the rest.
formatting_options = {
    "psnr": _fixed_point(2),
    "lpips": _fixed_point(4),
    "fid": _fixed_point(2),
    "ssim": _fixed_point(4),
    "vif": _fixed_point(4),
}

reported_metrics = ["psnr", "ssim", "lpips", "fid", "vif"]
# NOTE(review): output_dir_path already ends in "metrics_output", so the CSVs
# land in .../metrics_output/metrics_output/<model>, separate from the
# <model> image directory -- confirm this nesting is intended.
csv_output_dir = f"{output_dir_path}/metrics_output/{model}"
write_metrics_to_csv(
    metrics_grid,
    reported_metrics,
    formatting_options=formatting_options,
    output_dir=csv_output_dir,
)