JensLundsgaard committed on
Commit
59e08fd
·
verified ·
1 Parent(s): 0c00839

Upload raffael_losses.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. raffael_losses.py +180 -0
raffael_losses.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ High-Quality Loss Functions
3
+ - MS-SSIM Loss (Multi-Scale Structural Similarity)
4
+ - L1 Loss
5
+ - Combined Reconstruction Loss
6
+ - Classification Loss
7
+ """
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
def gaussian_window(window_size, sigma, channels, device):
    """Build a normalized 2-D Gaussian kernel replicated per channel.

    Args:
        window_size: side length of the square window (e.g. 11).
        sigma: standard deviation of the Gaussian.
        channels: number of image channels to replicate the kernel over.
        device: torch device for the returned tensor.

    Returns:
        A (channels, 1, window_size, window_size) tensor usable as the
        weight of a depthwise (groups=channels) conv2d; each 2-D slice
        sums to 1.
    """
    # 1-D Gaussian centred on the window midpoint.
    offsets = torch.arange(window_size, dtype=torch.float32, device=device) - window_size // 2
    kernel_1d = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
    kernel_1d = kernel_1d / kernel_1d.sum()
    # Outer product of the normalized 1-D kernel with itself gives the
    # separable, normalized 2-D kernel.
    kernel_2d = torch.outer(kernel_1d, kernel_1d)
    return kernel_2d.expand(channels, 1, window_size, window_size)
22
+
23
+
24
def _ssim_and_mcs(img1, img2, window_size=11, sigma=1.5, data_range=1.0, size_average=True):
    """Compute SSIM and the contrast-structure (MCS) term at one scale.

    Uses the standard SSIM decomposition: luminance * contrast-structure.
    MS-SSIM uses the contrast-structure ("cs"/"mcs") term alone for all
    scales except the coarsest.

    Args:
        img1, img2: (B, C, H, W) tensors on the same device, values
            expected in [0, data_range].
        window_size, sigma: Gaussian window parameters.
        data_range: dynamic range of the inputs (1.0 for [0, 1] images).
        size_average: if True return scalar means; otherwise per-sample
            means over (C, H, W).

    Returns:
        (ssim, mcs): scalars when size_average is True, else (B,) tensors.
    """
    assert img1.shape == img2.shape
    channels = img1.shape[1]
    kernel = gaussian_window(window_size, sigma, channels, img1.device)
    pad = window_size // 2

    def local_mean(t):
        # Depthwise Gaussian smoothing: per-channel windowed average.
        return F.conv2d(t, kernel, padding=pad, groups=channels)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    # Windowed (co)variances via E[xy] - E[x]E[y].
    var1 = local_mean(img1 * img1) - mu1_sq
    var2 = local_mean(img2 * img2) - mu2_sq
    cov = local_mean(img1 * img2) - mu1_mu2

    # Stabilizing constants from the SSIM paper, scaled by the data range.
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    # Luminance and contrast-structure terms (tiny epsilon guards the
    # division beyond C1/C2 themselves).
    luminance = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1 + 1e-12)
    contrast_structure = (2 * cov + C2) / (var1 + var2 + C2 + 1e-12)

    ssim_map = luminance * contrast_structure

    if size_average:
        return ssim_map.mean(), contrast_structure.mean()
    # Per-sample reduction over channel and spatial dims.
    return ssim_map.mean(dim=[1, 2, 3]), contrast_structure.mean(dim=[1, 2, 3])
65
+
66
+
67
def ms_ssim(
    img1,
    img2,
    window_size=11,
    sigma=1.5,
    data_range=1.0,
    weights=None,
    levels=5,
    size_average=True
):
    """Multi-Scale SSIM.

    MS-SSIM = (SSIM_M)^{w_M} * prod_{j=1}^{M-1} (MCS_j)^{w_j}, where scale
    j+1 is obtained from scale j by 2x average pooling.

    Args:
        img1, img2: (B, C, H, W) tensors in [0, data_range]; H and W must
            survive (levels - 1) halvings with room for the window.
        window_size, sigma, data_range: forwarded to the per-scale SSIM.
        weights: optional per-scale exponents of length >= levels; defaults
            to the common 5-scale weights from the MS-SSIM paper.
        levels: number of scales M.
        size_average: scalar result if True, per-sample otherwise.

    Returns:
        MS-SSIM value(s): scalar, or (B,) when size_average is False.
    """
    assert img1.shape == img2.shape
    if weights is None:
        scale_weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], device=img1.device)
    else:
        scale_weights = torch.as_tensor(weights, device=img1.device, dtype=torch.float32)

    # Truncate to the number of scales actually used, then normalize.
    scale_weights = scale_weights[:levels]
    scale_weights = scale_weights / scale_weights.sum()

    cur1, cur2 = img1, img2
    mcs_terms = []

    # Scales 1..M-1 contribute only their contrast-structure term.
    for _ in range(levels - 1):
        _, mcs = _ssim_and_mcs(
            cur1, cur2, window_size=window_size, sigma=sigma, data_range=data_range, size_average=size_average
        )
        mcs_terms.append(mcs)
        cur1 = F.avg_pool2d(cur1, kernel_size=2, stride=2)
        cur2 = F.avg_pool2d(cur2, kernel_size=2, stride=2)

    # The coarsest scale contributes the full SSIM.
    final_ssim, _ = _ssim_and_mcs(
        cur1, cur2, window_size=window_size, sigma=sigma, data_range=data_range, size_average=size_average
    )

    # Combine exactly once (no iterative re-exponentiation).
    result = final_ssim.pow(scale_weights[levels - 1])
    for j, mcs in enumerate(mcs_terms):
        result = result * mcs.pow(scale_weights[j])

    return result
118
+
119
+
120
def reconstruction_loss(x_rec, x_true, l1_weight=0.5, ms_ssim_weight=0.5,
                        window_size=11, sigma=1.5, data_range=1.0, levels=5, weights=None):
    """Combined reconstruction loss: weighted L1 + (1 - MS-SSIM).

    Args:
        x_rec, x_true: (B, T, C, H, W) reconstruction and target sequences.
        l1_weight, ms_ssim_weight: blend coefficients for the two terms.
        window_size, sigma, data_range, levels, weights: forwarded to
            ms_ssim.

    Returns:
        (total, metrics): total is a scalar tensor; metrics is a dict of
        detached Python floats for logging.
    """
    assert x_rec.shape == x_true.shape
    B, T, C, H, W = x_rec.shape

    # MS-SSIM operates on 4-D batches, so fold time into the batch axis.
    rec_frames = x_rec.reshape(B * T, C, H, W)
    true_frames = x_true.reshape(B * T, C, H, W)

    l1_term = F.l1_loss(x_rec, x_true)

    ms_ssim_value = ms_ssim(
        rec_frames, true_frames,
        window_size=window_size, sigma=sigma, data_range=data_range,
        weights=weights, levels=levels, size_average=True
    )
    ms_ssim_term = 1.0 - ms_ssim_value

    total = l1_weight * l1_term + ms_ssim_weight * ms_ssim_term

    metrics = {
        "l1_loss": float(l1_term.detach().cpu()),
        "ms_ssim_loss": float(ms_ssim_term.detach().cpu()),
        "ms_ssim_value": float(ms_ssim_value.detach().cpu()),
    }
    return total, metrics
150
+
151
def temporal_smoothness_loss(z_seq, weight=0.1):
    """Penalize large changes between latents at adjacent timesteps.

    Args:
        z_seq: (B, T, C, H, W) latent sequence (only dim 1 is treated as
            time; trailing dims are reduced by the mean).
        weight: scalar multiplier on the mean squared difference.

    Returns:
        Scalar tensor; zero when there are fewer than two timesteps.
    """
    if z_seq.size(1) < 2:
        # A single timestep has nothing to be smooth against.
        return torch.tensor(0.0, device=z_seq.device)

    # Squared finite differences along time: (B, T-1, C, H, W).
    step_deltas = z_seq[:, 1:] - z_seq[:, :-1]
    return weight * step_deltas.pow(2).mean()
166
+
167
+
168
def classification_loss(logits, labels, criterion=None):
    """Classification loss (cross-entropy by default).

    Args:
        logits: (B, num_classes) raw classification scores.
        labels: (B,) integer ground-truth class indices.
        criterion: optional loss callable; when None, a fresh
            nn.CrossEntropyLoss is used.

    Returns:
        Scalar loss tensor.
    """
    loss_fn = nn.CrossEntropyLoss() if criterion is None else criterion
    return loss_fn(logits, labels)
180
+