fxxkingusername committed on
Commit
3fc91dd
·
verified ·
1 Parent(s): d2aee5b

Upload src/training\losses.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/training//losses.py +386 -0
src/training//losses.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loss functions for architectural style classification.
3
+ Includes hierarchical, contrastive, style-relationship, multi-style, focal, label-smoothing, and combined losses.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Dict, List, Optional, Tuple
10
+ import numpy as np
11
+
12
+
13
class HierarchicalLoss(nn.Module):
    """Hierarchical loss tying broad (period) and fine-grained (style) predictions.

    The returned value is
    ``fine_ce + alpha * broad_ce + beta * consistency`` where

    * ``fine_ce`` is cross-entropy of ``fine_logits`` against ``targets``,
    * ``broad_ce`` is cross-entropy of ``broad_logits`` against the broad
      class derived from each fine target via ``style_hierarchy``,
    * ``consistency`` is the mean squared difference between the softmax
      mass the fine head puts on the styles of the target's broad class and
      the softmax probability the broad head gives that broad class.
    """

    def __init__(self, alpha: float = 0.5, beta: float = 0.3):
        """Store the term weights and build the class-hierarchy lookups.

        Args:
            alpha: Weight for the broad classification loss.
            beta: Weight for the consistency loss.
        """
        super().__init__()
        self.alpha = alpha  # Weight for broad classification loss
        self.beta = beta  # Weight for consistency loss

        # Style hierarchy mapping: broad class index -> fine class indices
        self.style_hierarchy = {
            0: [0, 1, 2, 3, 4],  # Ancient
            1: [5, 6, 7, 8, 9],  # Medieval
            2: [10, 11, 12, 13, 14],  # Renaissance
            3: [15, 16, 17, 18, 19],  # Modern
            4: [20, 21, 22, 23, 24],  # Contemporary
        }

        # NOTE: despite its name this dict maps fine class -> broad class.
        # The attribute name is kept so existing callers keep working.
        self.broad_to_fine = self._create_broad_to_fine_mapping()

        # Tensor versions of the hierarchy so forward() can index on-device
        # instead of looping over samples with ``.item()`` (one host sync per
        # sample). Non-persistent: they never appear in ``state_dict``.
        num_fine = max(self.broad_to_fine) + 1
        num_broad = max(self.style_hierarchy) + 1
        fine_to_broad = torch.zeros(num_fine, dtype=torch.long)
        membership = torch.zeros(num_broad, num_fine)
        for fine_class, broad_class in self.broad_to_fine.items():
            fine_to_broad[fine_class] = broad_class
            membership[broad_class, fine_class] = 1.0
        self.register_buffer("_fine_to_broad", fine_to_broad, persistent=False)
        self.register_buffer("_membership", membership, persistent=False)

    def _create_broad_to_fine_mapping(self) -> Dict[int, int]:
        """Create mapping from fine-grained classes to broad classes."""
        return {
            fine_class: broad_class
            for broad_class, fine_classes in self.style_hierarchy.items()
            for fine_class in fine_classes
        }

    def forward(self, broad_logits: torch.Tensor, fine_logits: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """Compute hierarchical loss.

        Args:
            broad_logits: Unnormalised broad-class scores, one row per sample.
            fine_logits: Unnormalised fine-class scores, one row per sample.
            targets: Integer fine-class indices covered by ``style_hierarchy``.

        Returns:
            Scalar tensor ``fine_ce + alpha * broad_ce + beta * consistency``.
        """
        # Convert fine-grained targets to broad targets with one indexing op.
        # ``.to`` is a no-op when the module already lives on the same device;
        # it keeps the loss usable when the criterion was never moved.
        broad_targets = self._fine_to_broad.to(targets.device)[targets]

        # Broad classification loss
        broad_loss = F.cross_entropy(broad_logits, broad_targets)

        # Fine-grained classification loss
        fine_loss = F.cross_entropy(fine_logits, targets)

        # Consistency loss: fine predictions should agree with broad ones
        broad_probs = F.softmax(broad_logits, dim=1)
        fine_probs = F.softmax(fine_logits, dim=1)
        consistency_loss = self._compute_consistency_loss(broad_probs, fine_probs, targets)

        return fine_loss + self.alpha * broad_loss + self.beta * consistency_loss

    def _compute_consistency_loss(self, broad_probs: torch.Tensor,
                                  fine_probs: torch.Tensor,
                                  targets: torch.Tensor) -> torch.Tensor:
        """Compute consistency loss between broad and fine predictions.

        For every sample, the fine probabilities of the styles belonging to
        the target's broad class are summed and compared (squared error) with
        the broad probability of that class; the result is averaged over the
        batch.
        """
        device = fine_probs.device
        broad_targets = self._fine_to_broad.to(device)[targets.to(device)]
        membership = self._membership.to(device=device, dtype=fine_probs.dtype)

        # Row i selects the fine classes that belong to sample i's broad class.
        in_target_broad = membership[broad_targets]
        num_fine = membership.size(1)
        fine_mass = (fine_probs[:, :num_fine] * in_target_broad).sum(dim=1)

        # Broad probability assigned to the correct broad class.
        broad_prob = broad_probs.gather(1, broad_targets.unsqueeze(1)).squeeze(1)

        return (fine_mass - broad_prob).pow(2).mean()
92
+
93
+
94
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over a batch of projections.

    For each anchor the loss is
    ``relu(margin - mean_positive_similarity + hardest_negative_similarity)``
    where similarities are temperature-scaled cosine similarities, positives
    are other samples with the same target, and negatives are samples with a
    different target. The per-anchor values are averaged.
    """

    def __init__(self, temperature: float = 0.07, margin: float = 1.0):
        """Store hyper-parameters.

        Args:
            temperature: Divisor applied to the cosine similarities.
            margin: Required gap between positive and hardest negative.
        """
        super().__init__()
        self.temperature = temperature
        self.margin = margin

    def forward(self, projections: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute contrastive loss.

        Args:
            projections: One embedding row per sample (2-D).
            targets: 1-D class index per sample.

        Returns:
            Scalar tensor with the mean per-anchor loss. Anchors without any
            positive use a positive similarity of 0; anchors without any
            negative use a hardest-negative similarity of 0.
        """
        batch_size = targets.size(0)
        if batch_size == 0:
            # Nothing to contrast; return a zero that stays in the graph.
            return projections.sum() * 0.0

        # Cosine similarities scaled by the temperature.
        projections = F.normalize(projections, dim=1)
        similarity_matrix = torch.matmul(projections, projections.t()) / self.temperature

        # Boolean pair masks; an anchor is never its own positive.
        same_class = targets.unsqueeze(0) == targets.unsqueeze(1)
        identity = torch.eye(batch_size, dtype=torch.bool, device=same_class.device)
        positive_mask = same_class & ~identity
        negative_mask = ~same_class

        # Mean similarity to positives (clamp avoids division by zero).
        positive_sum = similarity_matrix.masked_fill(~positive_mask, 0.0).sum(dim=1)
        num_positives = positive_mask.sum(dim=1).clamp(min=1).to(similarity_matrix.dtype)
        positive_similarities = positive_sum / num_positives

        # Hardest negative per anchor. Non-negatives are masked with -inf, not
        # multiplied by 0: a zero would win the max whenever every real
        # negative has negative similarity, inflating the loss.
        masked_negatives = similarity_matrix.masked_fill(~negative_mask, float("-inf"))
        hardest_negative_similarities = masked_negatives.max(dim=1)[0]
        has_negative = negative_mask.any(dim=1)
        hardest_negative_similarities = torch.where(
            has_negative,
            hardest_negative_similarities,
            torch.zeros_like(hardest_negative_similarities),
        )

        loss = F.relu(self.margin - positive_similarities + hardest_negative_similarities)
        return loss.mean()
138
+
139
+
140
class StyleRelationshipLoss(nn.Module):
    """Loss that matches prediction similarity to a prior style-relationship matrix.

    For every ordered pair of distinct samples ``(i, j)`` in the batch, the
    cosine similarity between their softmax outputs is compared (squared
    error) with ``style_relationships[target_i, target_j]``. The mean over
    pairs is scaled by ``relationship_weight``.
    """

    def __init__(self, relationship_weight: float = 0.1):
        """Store the weight and build the relationship matrix.

        Args:
            relationship_weight: Multiplier applied to the mean pair loss.
        """
        super().__init__()
        self.relationship_weight = relationship_weight

        # Define style relationships (simplified)
        self.style_relationships = self._initialize_style_relationships()

    def _initialize_style_relationships(self) -> torch.Tensor:
        """Initialize style relationship matrix.

        Returns:
            A 25x25 tensor with 1.0 on the diagonal (same style), 0.8 between
            different styles of the same period, 0.3 from a style to any style
            of the next period, and 0.0 elsewhere.
        """
        num_styles = 25
        period_size = 5
        relationships = torch.zeros(num_styles, num_styles)
        period_starts = list(range(0, num_styles, period_size))

        # Same period relationships: high similarity within a period.
        for start in period_starts:
            stop = start + period_size
            relationships[start:stop, start:stop] = 0.8

        # Cross-period relationships (evolutionary): medium similarity from
        # each period to the next one.
        # NOTE(review): only the earlier -> later direction is filled, so the
        # matrix is asymmetric while cosine similarity is symmetric; confirm
        # whether that is intended.
        for start, next_start in zip(period_starts, period_starts[1:]):
            relationships[start:start + period_size,
                          next_start:next_start + period_size] = 0.3

        # Two samples of the *same* style must be expected to look alike.
        # A zero diagonal would push same-class predictions apart.
        relationships.fill_diagonal_(1.0)

        return relationships

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute style relationship loss.

        Args:
            logits: Unnormalised style scores, one row per sample.
            targets: Integer style index per sample, valid for
                ``style_relationships``.

        Returns:
            Scalar tensor: ``relationship_weight`` times the mean squared
            error over all ordered pairs of distinct samples. Zero (as a
            tensor) when the batch has fewer than two samples.
        """
        batch_size = targets.size(0)
        num_pairs = batch_size * (batch_size - 1)
        if num_pairs <= 0:
            # No pairs to compare; keep the return type a tensor.
            return logits.sum() * 0.0

        # Pairwise cosine similarity of the predicted distributions.
        probs = F.softmax(logits, dim=1)
        unit_probs = F.normalize(probs, dim=1, eps=1e-8)
        actual_similarity = unit_probs @ unit_probs.t()

        # Expected similarity for every (target_i, target_j) pair.
        relationships = self.style_relationships.to(
            device=logits.device, dtype=actual_similarity.dtype
        )
        target_index = targets.to(logits.device)
        expected_similarity = relationships[target_index][:, target_index]

        # Squared error over ordered pairs, excluding each sample with itself.
        squared_error = (actual_similarity - expected_similarity).pow(2)
        off_diagonal = ~torch.eye(batch_size, dtype=torch.bool, device=logits.device)
        relationship_loss = squared_error.masked_select(off_diagonal).sum() / num_pairs

        return self.relationship_weight * relationship_loss
224
+
225
+
226
class MultiStyleLoss(nn.Module):
    """Loss for multi-style detection and classification.

    Combines three terms:
    ``style_nll + mixture_weight * mixture_bce + 0.1 * multi_label`` where
    ``multi_label`` is the negative entropy of ``style_probs`` summed over the
    samples flagged as mixtures and divided by the batch size (so minimising
    it spreads probability mass for mixtures).
    """

    def __init__(self, mixture_weight: float = 0.2):
        """Store the mixture weight and build the component criteria.

        Args:
            mixture_weight: Multiplier for the mixture-detection BCE term.
        """
        super().__init__()
        self.mixture_weight = mixture_weight
        self.bce_loss = nn.BCELoss()
        # Kept for backward compatibility; forward() no longer uses it because
        # its inputs are probabilities, not logits.
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, style_probs: torch.Tensor, mixture_prob: torch.Tensor,
                targets: torch.Tensor, is_mixture: torch.Tensor) -> torch.Tensor:
        """Compute multi-style loss.

        Args:
            style_probs: Per-sample style probabilities (rows are
                distributions; the entropy term below takes their log).
            mixture_prob: Predicted probability that each sample is a mixture;
                passed to ``BCELoss`` so it must match ``is_mixture`` in shape.
            targets: Integer style index per sample.
            is_mixture: Per-sample flag, truthy where the sample is a mixture.

        Returns:
            Scalar tensor with the weighted sum of the three terms.
        """
        batch_size = targets.size(0)

        # Style classification loss. ``style_probs`` are already
        # probabilities, so use NLL on their log instead of CrossEntropyLoss,
        # which would apply a second softmax.
        log_probs = torch.log(style_probs.clamp(min=1e-8))
        style_loss = F.nll_loss(log_probs, targets)

        # Mixture detection loss
        mixture_loss = self.bce_loss(mixture_prob, is_mixture.float())

        # Multi-label loss for mixtures: maximise entropy (minimise negative
        # entropy) of the style distribution on mixture samples.
        multi_label_loss = 0.0
        if batch_size > 0:
            mixture_rows = is_mixture.reshape(batch_size).bool()
            negative_entropy = torch.sum(style_probs * torch.log(style_probs + 1e-8), dim=1)
            multi_label_loss = negative_entropy[mixture_rows].sum() / batch_size

        # Total loss
        return style_loss + self.mixture_weight * mixture_loss + 0.1 * multi_label_loss
262
+
263
+
264
class FocalLoss(nn.Module):
    """Focal loss: cross-entropy down-weighted on well-classified samples."""

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0):
        """Store the scaling factor ``alpha`` and focusing exponent ``gamma``."""
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the batch mean of ``alpha * (1 - p_t) ** gamma * ce``.

        ``ce`` is the per-sample cross-entropy and ``p_t = exp(-ce)`` is the
        probability the model assigns to the target class.
        """
        per_sample_ce = F.cross_entropy(logits, targets, reduction='none')
        target_prob = torch.exp(-per_sample_ce)
        modulation = (1 - target_prob) ** self.gamma
        weighted = self.alpha * modulation * per_sample_ce
        return weighted.mean()
278
+
279
+
280
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against label-smoothed target distributions."""

    def __init__(self, smoothing: float = 0.1, num_classes: int = 25):
        """Store the smoothing amount and the number of classes."""
        super().__init__()
        self.smoothing = smoothing
        self.num_classes = num_classes

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the mean smoothed cross-entropy over the batch.

        The target class receives ``1 - smoothing``; every other class
        receives ``smoothing / (num_classes - 1)``.
        """
        off_value = self.smoothing / (self.num_classes - 1)
        on_value = 1 - self.smoothing

        # Start from the off-target value everywhere, then overwrite the
        # target column of each row.
        target_dist = torch.full(
            (targets.size(0), self.num_classes),
            off_value,
            dtype=torch.get_default_dtype(),
            device=logits.device,
        )
        target_dist.scatter_(1, targets.unsqueeze(1), on_value)

        per_sample = -torch.sum(target_dist * F.log_softmax(logits, dim=1), dim=1)
        return per_sample.mean()
300
+
301
+
302
class CombinedLoss(nn.Module):
    """Combined loss function with multiple, individually weighted components.

    ``forward`` returns a dict with one entry per component that was computed
    plus ``'total_loss'``, the weighted sum of those components.
    """

    def __init__(self,
                 use_hierarchical: bool = True,
                 use_contrastive: bool = False,
                 use_style_relationship: bool = True,
                 use_focal: bool = False,
                 use_label_smoothing: bool = True,
                 weights: Optional[Dict[str, float]] = None):
        """Build the enabled component losses.

        Args:
            use_hierarchical: Enable ``HierarchicalLoss``.
            use_contrastive: Enable ``ContrastiveLoss``.
            use_style_relationship: Enable ``StyleRelationshipLoss``.
            use_focal: Use ``FocalLoss`` as the classification term.
            use_label_smoothing: Use ``LabelSmoothingLoss`` as the
                classification term when focal loss is disabled.
            weights: Optional per-component weights. Missing keys fall back to
                the defaults, so a partial dict is accepted.
        """
        super().__init__()

        self.use_hierarchical = use_hierarchical
        self.use_contrastive = use_contrastive
        self.use_style_relationship = use_style_relationship
        self.use_focal = use_focal
        self.use_label_smoothing = use_label_smoothing

        # Initialize loss functions
        self.hierarchical_loss = HierarchicalLoss() if use_hierarchical else None
        self.contrastive_loss = ContrastiveLoss() if use_contrastive else None
        self.style_relationship_loss = StyleRelationshipLoss() if use_style_relationship else None
        self.focal_loss = FocalLoss() if use_focal else None
        self.label_smoothing_loss = LabelSmoothingLoss() if use_label_smoothing else None
        self.ce_loss = nn.CrossEntropyLoss()

        # Loss weights: user values override the defaults key by key, so a
        # partial dict no longer causes a KeyError in forward().
        # NOTE(review): 'focal' and 'label_smoothing' are not read by
        # forward(); the classification term always uses weights['ce'].
        default_weights = {
            'ce': 1.0,
            'hierarchical': 0.5,
            'contrastive': 0.1,
            'style_relationship': 0.1,
            'focal': 1.0,
            'label_smoothing': 1.0,
        }
        self.weights = {**default_weights, **(weights or {})}

    def forward(self, outputs: Dict[str, torch.Tensor],
                targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute combined loss.

        Args:
            outputs: Model outputs. Recognised keys are ``'fine_logits'``,
                ``'broad_logits'`` and ``'projections'``; a component is
                skipped when a key it needs is absent.
            targets: Fine-grained class index per sample.

        Returns:
            Dict with the unweighted value of every computed component
            (``'ce_loss'``, ``'hierarchical_loss'``,
            ``'style_relationship_loss'``, ``'contrastive_loss'``) and the
            weighted sum under ``'total_loss'``.
        """
        total_loss = 0.0
        loss_dict = {}
        fine_logits = outputs.get('fine_logits')

        # Classification loss: focal takes precedence over label smoothing,
        # plain cross-entropy is the fallback.
        if fine_logits is not None:
            if self.use_focal:
                ce_loss = self.focal_loss(fine_logits, targets)
            elif self.use_label_smoothing:
                ce_loss = self.label_smoothing_loss(fine_logits, targets)
            else:
                ce_loss = self.ce_loss(fine_logits, targets)

            total_loss += self.weights['ce'] * ce_loss
            loss_dict['ce_loss'] = ce_loss

        # Hierarchical loss needs both heads; skip it when either is missing
        # rather than raising KeyError on 'fine_logits'.
        if (self.use_hierarchical and self.hierarchical_loss is not None
                and 'broad_logits' in outputs and fine_logits is not None):
            hierarchical_loss = self.hierarchical_loss(
                outputs['broad_logits'],
                fine_logits,
                targets
            )
            total_loss += self.weights['hierarchical'] * hierarchical_loss
            loss_dict['hierarchical_loss'] = hierarchical_loss

        # Style relationship loss
        if (self.use_style_relationship and self.style_relationship_loss is not None
                and fine_logits is not None):
            relationship_loss = self.style_relationship_loss(fine_logits, targets)
            total_loss += self.weights['style_relationship'] * relationship_loss
            loss_dict['style_relationship_loss'] = relationship_loss

        # Contrastive loss
        if (self.use_contrastive and self.contrastive_loss is not None
                and 'projections' in outputs):
            contrastive_loss = self.contrastive_loss(outputs['projections'], targets)
            total_loss += self.weights['contrastive'] * contrastive_loss
            loss_dict['contrastive_loss'] = contrastive_loss

        loss_dict['total_loss'] = total_loss

        return loss_dict