JensLundsgaard committed on
Commit
939f16b
·
verified ·
1 Parent(s): 4ee72df

Upload raffael_losses.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. raffael_losses.py +161 -0
raffael_losses.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ High-Quality Loss Functions
3
+ - MS-SSIM Loss (Multi-Scale Structural Similarity)
4
+ - L1 Loss
5
+ - Combined Reconstruction Loss
6
+ - Classification Loss
7
+ """
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
def gaussian_kernel(size=11, sigma=1.5, device=None, dtype=torch.float32):
    """
    Generate a normalized 2D Gaussian kernel for SSIM.

    Args:
        size: side length of the square kernel (positive integer).
        sigma: standard deviation of the Gaussian (must be > 0).
        device: device to create the kernel on (default: PyTorch's default).
        dtype: floating-point dtype of the kernel (default: torch.float32).

    Returns:
        Tensor of shape (size, size) whose entries sum to 1.

    Raises:
        ValueError: if size < 1 or sigma <= 0 (either would otherwise
            silently yield an empty or NaN-filled kernel).
    """
    if size < 1:
        raise ValueError(f"size must be a positive integer, got {size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    # Sample positions centred on the kernel midpoint.
    coords = torch.arange(size, dtype=dtype, device=device) - size // 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()

    # The 2D Gaussian is separable: outer product of the normalized 1D
    # profile with itself, so the 2D kernel also sums to 1.
    return g.unsqueeze(0) * g.unsqueeze(1)
20
+
21
+
22
def ssim(img1, img2, kernel_size=11, sigma=1.5, C1=0.01**2, C2=0.03**2):
    """
    Single-scale SSIM, averaged over all pixels, channels and batch elements.

    Args:
        img1, img2: (B, C, H, W) floating-point images. Any channel count is
            supported; each channel is filtered independently.
        kernel_size: side length of the Gaussian window.
        sigma: standard deviation of the Gaussian window.
        C1, C2: stabilizing constants of the luminance and
            contrast/structure terms. The defaults are (0.01 * L)**2 and
            (0.03 * L)**2 with L = 1, i.e. they assume a dynamic range of 1.

    Returns:
        0-dim tensor holding the mean SSIM.
    """
    channels = img1.size(1)
    pad = kernel_size // 2

    # Match the kernel to the input's device and dtype so half/double inputs
    # do not trigger a dtype mismatch inside conv2d.
    kernel_dtype = img1.dtype if img1.is_floating_point() else torch.float32
    kernel = gaussian_kernel(kernel_size, sigma).to(
        device=img1.device, dtype=kernel_dtype
    )
    # One copy of the window per channel -> depthwise convolution
    # (weight shape (C, 1, k, k) with groups=C). For C == 1 this is the
    # plain single-channel convolution.
    kernel = kernel.expand(channels, 1, kernel_size, kernel_size).contiguous()

    def local_mean(x):
        # Gaussian-weighted local average; zero padding keeps H x W for odd
        # kernel sizes.
        return F.conv2d(x, kernel, padding=pad, groups=channels)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)

    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[xy] - E[x]E[y].
    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
               ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    return ssim_map.mean()
46
+
47
+
48
def ms_ssim(img1, img2, kernel_size=11, sigma=1.5, weights=None, levels=5):
    """
    Multi-Scale SSIM (MS-SSIM)

    Computes

        prod_{j < levels-1} cs_j ** w_j  *  ssim_last ** w_last

    where cs_j is the mean contrast/structure term at scale j, ssim_last is
    the full SSIM at the coarsest scale, and each scale is obtained from the
    previous one by 2x2 average pooling.

    Args:
        img1, img2: (B, C, H, W) floating-point images. H and W are halved
            (levels - 1) times, so both should be at least 2 ** (levels - 1).
        kernel_size: side length of the Gaussian window.
        sigma: standard deviation of the Gaussian window.
        weights: per-scale exponents (tensor or sequence), default
            [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]. Only the first
            `levels` entries are used and they are renormalized to sum to 1.
        levels: number of scales (>= 1).

    Returns:
        0-dim tensor holding the MS-SSIM value.

    Raises:
        ValueError: if levels < 1 or fewer than `levels` weights are given.
    """
    if levels < 1:
        raise ValueError(f"levels must be >= 1, got {levels}")

    if weights is None:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    weights = torch.as_tensor(weights, device=img1.device).reshape(-1)
    if not weights.is_floating_point():
        weights = weights.float()
    if weights.numel() < levels:
        raise ValueError(
            f"need at least {levels} weights, got {weights.numel()}"
        )

    # Ensure weight count matches
    weights = weights[:levels]
    weights = weights / weights.sum()

    # Fold channels into the batch axis so every level works on
    # single-channel images. All statistics are global means, so the result
    # equals per-channel filtering, and C == 1 inputs are unchanged.
    img1 = img1.reshape(-1, 1, *img1.shape[-2:])
    img2 = img2.reshape(-1, 1, *img2.shape[-2:])

    # The window is identical at every scale; build it once.
    kernel_dtype = img1.dtype if img1.is_floating_point() else torch.float32
    kernel = gaussian_kernel(kernel_size, sigma).to(
        device=img1.device, dtype=kernel_dtype
    )
    kernel = kernel.unsqueeze(0).unsqueeze(0)  # (1, 1, k, k)
    pad = kernel_size // 2
    C2 = 0.03 ** 2

    scale_terms = []
    for _ in range(levels - 1):
        # Finer scales contribute only the contrast/structure term.
        mu1 = F.conv2d(img1, kernel, padding=pad)
        mu2 = F.conv2d(img2, kernel, padding=pad)

        sigma1_sq = F.conv2d(img1 * img1, kernel, padding=pad) - mu1 ** 2
        sigma2_sq = F.conv2d(img2 * img2, kernel, padding=pad) - mu2 ** 2
        sigma12 = F.conv2d(img1 * img2, kernel, padding=pad) - mu1 * mu2

        cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
        scale_terms.append(cs_map.mean())

        # Downsample to next level
        img1 = F.avg_pool2d(img1, 2)
        img2 = F.avg_pool2d(img2, 2)

    # The coarsest scale contributes the full SSIM (luminance included).
    scale_terms.append(ssim(img1, img2, kernel_size, sigma))

    # Terms can be negative (anti-correlated patches); a fractional power of
    # a negative number is NaN, so clamp to a tiny positive floor first.
    terms = torch.stack(scale_terms).clamp(min=1e-8)

    # Each term is raised to its own weight exactly once, then multiplied.
    return torch.prod(terms ** weights)
98
+
99
+
100
def reconstruction_loss(x_rec, x_true, l1_weight=0.5, ms_ssim_weight=0.5):
    """
    Combined reconstruction loss: L1 + MS-SSIM

    Args:
        x_rec: (B, T, C, H, W) - reconstructed video
        x_true: (B, T, C, H, W) - original video
        l1_weight: L1 loss weight
        ms_ssim_weight: MS-SSIM loss weight

    Returns:
        (total_loss, stats) where total_loss is a 0-dim tensor that keeps the
        autograd graph, and stats is a dict of plain Python floats with keys
        "l1_loss", "ms_ssim_loss" and "ms_ssim_value" for logging.

    Note:
        ms_ssim is called with its default 5 levels, which halves H and W
        four times; frames should therefore be at least 16 x 16.
    """
    B, T, C, H, W = x_rec.shape

    # Merge batch and time so every frame is scored as an independent image.
    # reshape (rather than view) also accepts non-contiguous inputs, e.g.
    # tensors produced by permute/transpose.
    x_rec_flat = x_rec.reshape(B * T, C, H, W)    # (B*T, C, H, W)
    x_true_flat = x_true.reshape(B * T, C, H, W)  # (B*T, C, H, W)

    # L1 Loss
    l1_loss = F.l1_loss(x_rec, x_true)

    # MS-SSIM Loss (similarity -> dissimilarity)
    ms_ssim_val = ms_ssim(x_rec_flat, x_true_flat)
    ms_ssim_loss = 1 - ms_ssim_val

    # Combined loss
    total_loss = l1_weight * l1_loss + ms_ssim_weight * ms_ssim_loss

    return total_loss, {
        "l1_loss": l1_loss.item(),
        "ms_ssim_loss": ms_ssim_loss.item(),
        "ms_ssim_value": ms_ssim_val.item()
    }
130
+
131
+
132
def temporal_smoothness_loss(z_seq, weight=0.1):
    """
    Penalize changes between consecutive latents along the time axis.

    Args:
        z_seq: (B, T, C, H, W) - latent sequence (time is dimension 1)
        weight: multiplier applied to the mean squared step difference

    Returns:
        0-dim tensor: weight * mean((z[t+1] - z[t]) ** 2), or a constant 0.0
        on z_seq's device when there are fewer than two timesteps.
    """
    num_steps = z_seq.size(1)
    if num_steps >= 2:
        later = z_seq[:, 1:]
        earlier = z_seq[:, :-1]
        step_delta = later - earlier  # (B, T-1, C, H, W)
        return weight * (step_delta ** 2).mean()

    # Nothing to compare against with a single (or no) timestep.
    return torch.tensor(0.0, device=z_seq.device)
147
+
148
+
149
def classification_loss(logits, labels, criterion=None):
    """
    Apply a classification criterion to logits and labels.

    Args:
        logits: (B, num_classes) - classification logits
        labels: (B,) - ground truth labels
        criterion: callable taking (logits, labels); a fresh
            nn.CrossEntropyLoss() is used when None

    Returns:
        Whatever the criterion returns for (logits, labels).
    """
    loss_fn = nn.CrossEntropyLoss() if criterion is None else criterion
    return loss_fn(logits, labels)
161
+