File size: 3,474 Bytes
604568a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import csv
import torch
import numpy as np
from skimage.metrics import structural_similarity as compare_ssim
import math

class ImageQualityMetrics:
    """Static helpers that score one image/volume against a reference.

    All methods are stateless; call them directly on the class, e.g.
    ``ImageQualityMetrics.psnr(reference, reconstruction)``.
    """

    # NOTE: decorated exactly once.  Stacking ``@staticmethod`` twice wraps a
    # staticmethod object, which is not callable on Python < 3.10.
    @staticmethod
    def ssim_3d(img1, img2):
        """Return the mean SSIM over the 2D slices of two volumes.

        Both inputs are indexed as ``img[0, i, :, :]``: index 0 is taken on
        the first axis and the second axis is iterated as depth.  Presumably
        the layout is (channel, depth, height, width) with one channel --
        confirm against the caller.

        Args:
            img1: Reference volume.  Its per-slice min/max define the
                ``data_range`` handed to the SSIM routine.
            img2: Volume compared against ``img1``; same indexing.

        Returns:
            ``np.mean`` of the per-slice SSIM values.
        """
        ssim_vals = []
        for i in range(img1.shape[1]):  # Depth
            slice1 = img1[0, i, :, :]
            slice2 = img2[0, i, :, :]
            # The dynamic range is taken from the reference slice only.
            ssim_val = compare_ssim(
                slice1, slice2, data_range=slice1.max() - slice1.min())
            ssim_vals.append(ssim_val)
        return np.mean(ssim_vals)

    @staticmethod
    def psnr(img1, img2):
        """Return the PSNR (in dB) between two tensors as a Python float.

        The peak signal is the dynamic range of ``img1``
        (``img1.max() - img1.min()``), not a fixed constant.

        Args:
            img1: Reference tensor; defines the dynamic range.
            img2: Tensor compared against ``img1``.

        Returns:
            ``math.inf`` when the tensors are identical (zero MSE),
            ``-math.inf`` when ``img1`` has zero dynamic range but the
            tensors differ, otherwise
            ``20 * log10(range) - 10 * log10(mse)``.
        """
        mse = float(torch.mean((img1 - img2) ** 2))
        if mse == 0:
            return math.inf
        data_range = float(img1.max() - img1.min())
        if data_range == 0:
            # log10(0) would raise ValueError; return the limiting value so a
            # constant reference does not abort the caller.
            return -math.inf
        return 20 * math.log10(data_range) - 10 * math.log10(mse)

    @staticmethod
    def mse(img1, img2):
        """Return the mean squared error between two tensors.

        Returns:
            The result of ``torch.mean`` over the squared element-wise
            difference (a tensor; call ``.item()`` for a Python number).
        """
        return torch.mean((img1 - img2) ** 2)


class ValidationRecorder:
    """Run validation passes and log their averaged metrics to a CSV file."""

    def __init__(self, csv_file_path):
        """Store the path of the CSV file that metrics are written to.

        Args:
            csv_file_path: Path passed to ``open`` by the other methods.
        """
        self.csv_file_path = csv_file_path

    def initialize_csv(self):
        """Create (or truncate) the CSV file and write the header row."""
        with open(self.csv_file_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Epoch', 'Loss', 'SSIM', 'PSNR', 'MSE'])

    def validate_and_record(self, epoch, dataloader, device, generator,
                            criterion_g):
        """Evaluate ``generator`` on ``dataloader`` and append one CSV row.

        The generator is switched to eval mode and is left in eval mode on
        return; callers that continue training must call ``train()`` again.

        Args:
            epoch: Value written to the ``Epoch`` column.
            dataloader: Iterable yielding ``(low_res, high_res)`` tensor
                pairs, one pair per batch.
            device: Device both tensors are moved to before the forward pass.
            generator: Model called as ``generator(low_res)``.
            criterion_g: Loss called as ``criterion_g(fake, high_res)``; its
                result must support ``.item()``.

        The loss is averaged over batches; SSIM, PSNR and MSE are averaged
        over the samples actually processed.  An empty ``dataloader`` yields
        NaN for every metric instead of raising ``ZeroDivisionError``.
        """
        generator.eval()
        total_loss, total_ssim, total_psnr, total_mse = 0.0, 0.0, 0.0, 0.0
        # Count what is really seen.  ``len(dataloader) * batch_size``
        # over-counts when the final batch is short and fails outright when
        # ``batch_size`` is None (custom batch sampler).
        num_batches = 0
        num_samples = 0

        with torch.no_grad():
            for low_res, high_res in dataloader:
                low_res, high_res = low_res.to(device), high_res.to(device)
                fake_images = generator(low_res)

                loss = criterion_g(fake_images, high_res)
                total_loss += loss.item()
                num_batches += 1

                for j in range(high_res.size(0)):
                    # SSIM runs on NumPy arrays; PSNR/MSE stay on tensors.
                    total_ssim += ImageQualityMetrics.ssim_3d(
                        high_res[j].cpu().numpy(),
                        fake_images[j].cpu().numpy())
                    total_psnr += ImageQualityMetrics.psnr(
                        high_res[j], fake_images[j])
                    total_mse += ImageQualityMetrics.mse(
                        high_res[j], fake_images[j]).item()
                    num_samples += 1

        avg_loss = total_loss / num_batches if num_batches else math.nan
        if num_samples:
            avg_ssim = total_ssim / num_samples
            avg_psnr = total_psnr / num_samples
            avg_mse = total_mse / num_samples
        else:
            avg_ssim = avg_psnr = avg_mse = math.nan

        self._write_to_csv(epoch, avg_loss, avg_ssim, avg_psnr, avg_mse)

    def _write_to_csv(self, epoch, avg_loss, avg_ssim, avg_psnr, avg_mse):
        """Append one row of metrics to the CSV file.

        The column order matches the header written by ``initialize_csv``:
        Epoch, Loss, SSIM, PSNR, MSE.
        """
        with open(self.csv_file_path, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch, avg_loss, avg_ssim, avg_psnr, avg_mse])