Leacb4 commited on
Commit
dd11813
·
verified ·
1 Parent(s): 1994248

Upload hierarchy_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. hierarchy_model.py +338 -14
hierarchy_model.py CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  import torch
3
  import torch.nn as nn
@@ -17,7 +25,22 @@ import config
17
  # -------------------------
18
 
19
  class HierarchyDataset(Dataset):
 
 
 
 
 
 
 
20
  def __init__(self, dataframe, use_local_images=True, image_size=224):
 
 
 
 
 
 
 
 
21
  self.dataframe = dataframe
22
  self.use_local_images = use_local_images
23
  self.image_size = image_size
@@ -51,13 +74,28 @@ class HierarchyDataset(Dataset):
51
 
52
 
53
  def set_training_mode(self, training=True):
54
- """Switch between training and validation transforms"""
 
 
 
 
 
55
  self.training_mode = training
56
 
57
  def __len__(self):
 
58
  return len(self.dataframe)
59
 
60
  def __getitem__(self, idx):
 
 
 
 
 
 
 
 
 
61
  row = self.dataframe.iloc[idx]
62
 
63
  # Try to load local image first
@@ -83,7 +121,15 @@ class HierarchyDataset(Dataset):
83
  return image, description, hierarchy
84
 
85
  def _download_image(self, img_url):
86
- """Download an image from a URL with timeout"""
 
 
 
 
 
 
 
 
87
  response = requests.get(img_url, timeout=10)
88
  response.raise_for_status()
89
  image = Image.open(BytesIO(response.content)).convert("RGB")
@@ -94,9 +140,21 @@ class HierarchyDataset(Dataset):
94
  # -------------------------
95
 
96
  class HierarchyExtractor:
97
- """Extract hierarchy directly from text using matching"""
 
 
 
 
 
98
 
99
  def __init__(self, hierarchy_classes, verbose=False):
 
 
 
 
 
 
 
100
  self.hierarchy_classes = sorted(hierarchy_classes)
101
  self.class_to_idx = {cls: idx for idx, cls in enumerate(self.hierarchy_classes)}
102
  self.idx_to_class = {idx: cls for idx, cls in enumerate(self.hierarchy_classes)}
@@ -109,7 +167,15 @@ class HierarchyExtractor:
109
  print(f"📋 Classes: {self.hierarchy_classes}")
110
 
111
  def _create_patterns(self):
112
- """Create regex patterns for each hierarchy"""
 
 
 
 
 
 
 
 
113
  patterns = {}
114
 
115
  for hierarchy in self.hierarchy_classes:
@@ -162,7 +228,15 @@ class HierarchyExtractor:
162
  return patterns
163
 
164
  def extract_hierarchy(self, text):
165
- """Extract hierarchy from text using pattern matching"""
 
 
 
 
 
 
 
 
166
  text_lower = text.lower()
167
 
168
  # Try exact match first
@@ -179,14 +253,31 @@ class HierarchyExtractor:
179
  return None
180
 
181
  def extract_hierarchy_idx(self, text):
182
- """Extract hierarchy index from text"""
 
 
 
 
 
 
 
 
183
  hierarchy = self.extract_hierarchy(text)
184
  if hierarchy:
185
  return self.class_to_idx[hierarchy]
186
  return None
187
 
188
  def get_hierarchy_embedding(self, text, embed_dim=config.hierarchy_emb_dim):
189
- """Create embedding from hierarchy index"""
 
 
 
 
 
 
 
 
 
190
  hierarchy_idx = self.extract_hierarchy_idx(text)
191
  if hierarchy_idx is not None:
192
  # Create one-hot encoding
@@ -206,7 +297,21 @@ class HierarchyExtractor:
206
  # -------------------------
207
 
208
  class PretrainedImageEncoder(nn.Module):
 
 
 
 
 
 
 
209
  def __init__(self, embed_dim, dropout=0.3):
 
 
 
 
 
 
 
210
  super().__init__()
211
 
212
  self.backbone = models.resnet18(pretrained=True)
@@ -230,7 +335,12 @@ class PretrainedImageEncoder(nn.Module):
230
  self._freeze_backbone_layers()
231
 
232
  def _freeze_backbone_layers(self):
233
- """Freeze early layers to prevent overfitting"""
 
 
 
 
 
234
  if hasattr(self.backbone, 'children'):
235
  layers = list(self.backbone.children())
236
  freeze_until = int(len(layers) * 0.7)
@@ -240,13 +350,35 @@ class PretrainedImageEncoder(nn.Module):
240
  param.requires_grad = False
241
 
242
  def forward(self, x):
 
 
 
 
 
 
 
 
 
243
  features = self.backbone(x)
244
  return self.projection(features)
245
 
246
  class HierarchyEncoder(nn.Module):
247
- """Encoder that takes hierarchy index directly"""
 
 
 
 
 
248
 
249
  def __init__(self, num_hierarchies, embed_dim, dropout=0.3):
 
 
 
 
 
 
 
 
250
  super().__init__()
251
  self.num_hierarchies = num_hierarchies
252
  self.embed_dim = embed_dim
@@ -267,7 +399,9 @@ class HierarchyEncoder(nn.Module):
267
  self._init_weights()
268
 
269
  def _init_weights(self):
270
- """Initialize weights properly"""
 
 
271
  nn.init.xavier_uniform_(self.embedding.weight)
272
  for module in self.projection.modules():
273
  if isinstance(module, nn.Linear):
@@ -276,6 +410,19 @@ class HierarchyEncoder(nn.Module):
276
  nn.init.zeros_(module.bias)
277
 
278
  def forward(self, hierarchy_indices):
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  # hierarchy_indices: (B,) - batch of hierarchy indices
280
  # Workaround for MPS: embedding layers don't work well with MPS, so do lookup on CPU
281
  device = next(self.parameters()).device
@@ -293,7 +440,23 @@ class HierarchyEncoder(nn.Module):
293
  return self.projection(emb)
294
 
295
  class HierarchyClassifierHead(nn.Module):
 
 
 
 
 
 
 
296
  def __init__(self, in_dim, num_classes, hidden_dim=None, dropout=0.3):
 
 
 
 
 
 
 
 
 
297
  super().__init__()
298
  if hidden_dim is None:
299
  hidden_dim = max(in_dim // 2, num_classes * 2)
@@ -309,10 +472,34 @@ class HierarchyClassifierHead(nn.Module):
309
  )
310
 
311
  def forward(self, x):
 
 
 
 
 
 
 
 
 
312
  return self.classifier(x)
313
 
314
  class Model(nn.Module):
 
 
 
 
 
 
 
315
  def __init__(self, num_hierarchy_classes, embed_dim, dropout=0.3):
 
 
 
 
 
 
 
 
316
  super().__init__()
317
  self.img_enc = PretrainedImageEncoder(embed_dim, dropout)
318
  self.hierarchy_enc = HierarchyEncoder(num_hierarchy_classes, embed_dim, dropout)
@@ -321,6 +508,20 @@ class Model(nn.Module):
321
  self.num_hierarchy_classes = num_hierarchy_classes
322
 
323
  def forward(self, image=None, hierarchy_indices=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  out = {}
325
  if image is not None:
326
  z_img = self.img_enc(image)
@@ -339,11 +540,27 @@ class Model(nn.Module):
339
  return out
340
 
341
  def set_hierarchy_extractor(self, hierarchy_extractor):
342
- """Set the hierarchy extractor for text processing"""
 
 
 
 
 
343
  self.hierarchy_extractor = hierarchy_extractor
344
 
345
  def get_text_embeddings(self, text):
346
- """Get text embeddings for a given text string or list of strings"""
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  with torch.no_grad():
349
  # Get the device of the model
@@ -387,7 +604,18 @@ class Model(nn.Module):
387
  raise ValueError(f"Expected string or list/tuple of strings, got {type(text)}: {text}")
388
 
389
  def get_image_embeddings(self, image):
390
- """Get image embeddings for a given image tensor"""
 
 
 
 
 
 
 
 
 
 
 
391
  with torch.no_grad():
392
  if not isinstance(image, torch.Tensor):
393
  raise ValueError("Image must be a torch.Tensor")
@@ -410,9 +638,27 @@ class Model(nn.Module):
410
  # -------------------------
411
 
412
  class Loss(nn.Module):
 
 
 
 
 
 
 
413
  def __init__(self, hierarchy_classes, classification_weight=1.0,
414
  consistency_weight=0.3, contrastive_weight=0.2,
415
  temperature=0.07, label_smoothing=0.1):
 
 
 
 
 
 
 
 
 
 
 
416
  super().__init__()
417
  self.classification_weight = classification_weight
418
  self.consistency_weight = consistency_weight
@@ -428,7 +674,16 @@ class Loss(nn.Module):
428
  self.mse = nn.MSELoss()
429
 
430
  def contrastive_loss(self, img_emb, txt_emb):
431
- """InfoNCE contrastive loss"""
 
 
 
 
 
 
 
 
 
432
  sim_matrix = torch.matmul(img_emb, txt_emb.T) / self.temperature
433
  labels = torch.arange(img_emb.size(0), device=img_emb.device)
434
 
@@ -438,6 +693,19 @@ class Loss(nn.Module):
438
  return (loss_i2t + loss_t2i) / 2
439
 
440
  def forward(self, img_logits, txt_logits, img_embeddings, txt_embeddings, target_hierarchies):
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  device = img_embeddings.device
442
 
443
  # Convert hierarchy names to indices
@@ -467,6 +735,19 @@ class Loss(nn.Module):
467
  # -------------------------
468
 
469
  def collate_fn(batch, hierarchy_extractor):
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  images = torch.stack([b[0] for b in batch], dim=0)
471
  texts = [b[1] for b in batch]
472
  hierarchies = [b[2] for b in batch]
@@ -492,6 +773,17 @@ def collate_fn(batch, hierarchy_extractor):
492
  }
493
 
494
  def calculate_accuracy(logits, target_hierarchies, hierarchy_classes):
 
 
 
 
 
 
 
 
 
 
 
495
  batch_size = logits.size(0)
496
  correct = 0
497
  pred_indices = torch.argmax(logits, dim=1).cpu().numpy()
@@ -505,6 +797,23 @@ def calculate_accuracy(logits, target_hierarchies, hierarchy_classes):
505
  return correct / batch_size
506
 
507
  def train_one_epoch(model, dataloader, optimizer, device, hierarchy_classes, scheduler=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  model.train()
509
  total_loss = 0.0
510
  total_acc_img = 0.0
@@ -570,6 +879,21 @@ def train_one_epoch(model, dataloader, optimizer, device, hierarchy_classes, sch
570
  }
571
 
572
  def validate(model, dataloader, device, hierarchy_classes):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
  model.eval()
574
  total_loss = 0.0
575
  total_acc_img = 0.0
 
1
+ """
2
+ Hierarchy model for learning clothing category-aligned embeddings.
3
+ This file contains the hierarchy model that learns to encode images and texts
4
+ in an embedding space specialized for representing clothing categories (dress, shirt, etc.).
5
+ It includes a regex pattern-based hierarchy extractor, a ResNet image encoder,
6
+ a hierarchy embedding encoder, and loss functions for training.
7
+ """
8
+
9
  import pandas as pd
10
  import torch
11
  import torch.nn as nn
 
25
  # -------------------------
26
 
27
  class HierarchyDataset(Dataset):
28
+ """
29
+ Dataset class for hierarchy embedding training.
30
+
31
+ Handles loading images from local paths or URLs, extracting hierarchy information
32
+ from text descriptions, and applying appropriate transformations for training.
33
+ """
34
+
35
  def __init__(self, dataframe, use_local_images=True, image_size=224):
36
+ """
37
+ Initialize the hierarchy dataset.
38
+
39
+ Args:
40
+ dataframe: DataFrame with columns for image paths/URLs, text descriptions, and hierarchy labels
41
+ use_local_images: Whether to prefer local images over URLs (default: True)
42
+ image_size: Size of images after resizing (default: 224)
43
+ """
44
  self.dataframe = dataframe
45
  self.use_local_images = use_local_images
46
  self.image_size = image_size
 
74
 
75
 
76
  def set_training_mode(self, training=True):
77
+ """
78
+ Switch between training and validation transforms.
79
+
80
+ Args:
81
+ training: If True, use training transforms with augmentation; if False, use validation transforms
82
+ """
83
  self.training_mode = training
84
 
85
  def __len__(self):
86
+ """Return the number of samples in the dataset."""
87
  return len(self.dataframe)
88
 
89
  def __getitem__(self, idx):
90
+ """
91
+ Get a sample from the dataset.
92
+
93
+ Args:
94
+ idx: Index of the sample
95
+
96
+ Returns:
97
+ Tuple of (image_tensor, description_text, hierarchy_label)
98
+ """
99
  row = self.dataframe.iloc[idx]
100
 
101
  # Try to load local image first
 
121
  return image, description, hierarchy
122
 
123
  def _download_image(self, img_url):
124
+ """
125
+ Download an image from a URL with timeout.
126
+
127
+ Args:
128
+ img_url: URL of the image to download
129
+
130
+ Returns:
131
+ PIL Image object
132
+ """
133
  response = requests.get(img_url, timeout=10)
134
  response.raise_for_status()
135
  image = Image.open(BytesIO(response.content)).convert("RGB")
 
140
  # -------------------------
141
 
142
  class HierarchyExtractor:
143
+ """
144
+ Extract hierarchy categories directly from text using pattern matching.
145
+
146
+ This class uses regex patterns to identify clothing categories (e.g., shirt, dress)
147
+ from text descriptions, handling variations, plurals, and common fashion terms.
148
+ """
149
 
150
  def __init__(self, hierarchy_classes, verbose=False):
151
+ """
152
+ Initialize the hierarchy extractor.
153
+
154
+ Args:
155
+ hierarchy_classes: List of hierarchy class names
156
+ verbose: Whether to print initialization information (default: False)
157
+ """
158
  self.hierarchy_classes = sorted(hierarchy_classes)
159
  self.class_to_idx = {cls: idx for idx, cls in enumerate(self.hierarchy_classes)}
160
  self.idx_to_class = {idx: cls for idx, cls in enumerate(self.hierarchy_classes)}
 
167
  print(f"📋 Classes: {self.hierarchy_classes}")
168
 
169
  def _create_patterns(self):
170
+ """
171
+ Create regex patterns for each hierarchy class.
172
+
173
+ Creates patterns that match variations, plurals, and common fashion terms
174
+ for each hierarchy class.
175
+
176
+ Returns:
177
+ Dictionary mapping hierarchy classes to regex patterns
178
+ """
179
  patterns = {}
180
 
181
  for hierarchy in self.hierarchy_classes:
 
228
  return patterns
229
 
230
  def extract_hierarchy(self, text):
231
+ """
232
+ Extract hierarchy category from text using pattern matching.
233
+
234
+ Args:
235
+ text: Input text string
236
+
237
+ Returns:
238
+ Hierarchy class name if found, None otherwise
239
+ """
240
  text_lower = text.lower()
241
 
242
  # Try exact match first
 
253
  return None
254
 
255
  def extract_hierarchy_idx(self, text):
256
+ """
257
+ Extract hierarchy index from text.
258
+
259
+ Args:
260
+ text: Input text string
261
+
262
+ Returns:
263
+ Hierarchy index if found, None otherwise
264
+ """
265
  hierarchy = self.extract_hierarchy(text)
266
  if hierarchy:
267
  return self.class_to_idx[hierarchy]
268
  return None
269
 
270
  def get_hierarchy_embedding(self, text, embed_dim=config.hierarchy_emb_dim):
271
+ """
272
+ Create embedding from hierarchy index extracted from text.
273
+
274
+ Args:
275
+ text: Input text string
276
+ embed_dim: Dimension of the embedding (default: hierarchy_emb_dim)
277
+
278
+ Returns:
279
+ Embedding tensor of shape (embed_dim,)
280
+ """
281
  hierarchy_idx = self.extract_hierarchy_idx(text)
282
  if hierarchy_idx is not None:
283
  # Create one-hot encoding
 
297
  # -------------------------
298
 
299
  class PretrainedImageEncoder(nn.Module):
300
+ """
301
+ Image encoder based on pretrained ResNet18 for extracting image embeddings.
302
+
303
+ Uses a pretrained ResNet18 backbone and freezes early layers to prevent overfitting.
304
+ Adds a custom projection head to output embeddings of the specified dimension.
305
+ """
306
+
307
  def __init__(self, embed_dim, dropout=0.3):
308
+ """
309
+ Initialize the pretrained image encoder.
310
+
311
+ Args:
312
+ embed_dim: Dimension of the output embedding
313
+ dropout: Dropout rate for regularization (default: 0.3)
314
+ """
315
  super().__init__()
316
 
317
  self.backbone = models.resnet18(pretrained=True)
 
335
  self._freeze_backbone_layers()
336
 
337
  def _freeze_backbone_layers(self):
338
+ """
339
+ Freeze early layers to prevent overfitting.
340
+
341
+ Freezes the first 70% of backbone layers, allowing only the last layers
342
+ to be fine-tuned during training.
343
+ """
344
  if hasattr(self.backbone, 'children'):
345
  layers = list(self.backbone.children())
346
  freeze_until = int(len(layers) * 0.7)
 
350
  param.requires_grad = False
351
 
352
  def forward(self, x):
353
+ """
354
+ Forward pass through the image encoder.
355
+
356
+ Args:
357
+ x: Image tensor [batch_size, channels, height, width]
358
+
359
+ Returns:
360
+ Image embeddings [batch_size, embed_dim]
361
+ """
362
  features = self.backbone(x)
363
  return self.projection(features)
364
 
365
  class HierarchyEncoder(nn.Module):
366
+ """
367
+ Encoder that takes hierarchy indices directly.
368
+
369
+ Uses an embedding layer to convert hierarchy indices to embeddings,
370
+ followed by a projection head to output embeddings of the specified dimension.
371
+ """
372
 
373
  def __init__(self, num_hierarchies, embed_dim, dropout=0.3):
374
+ """
375
+ Initialize the hierarchy encoder.
376
+
377
+ Args:
378
+ num_hierarchies: Number of hierarchy classes
379
+ embed_dim: Dimension of the output embedding
380
+ dropout: Dropout rate for regularization (default: 0.3)
381
+ """
382
  super().__init__()
383
  self.num_hierarchies = num_hierarchies
384
  self.embed_dim = embed_dim
 
399
  self._init_weights()
400
 
401
  def _init_weights(self):
402
+ """
403
+ Initialize weights properly using Xavier uniform initialization.
404
+ """
405
  nn.init.xavier_uniform_(self.embedding.weight)
406
  for module in self.projection.modules():
407
  if isinstance(module, nn.Linear):
 
410
  nn.init.zeros_(module.bias)
411
 
412
  def forward(self, hierarchy_indices):
413
+ """
414
+ Forward pass through the hierarchy encoder.
415
+
416
+ Args:
417
+ hierarchy_indices: Tensor of hierarchy indices [batch_size]
418
+
419
+ Returns:
420
+ Hierarchy embeddings [batch_size, embed_dim]
421
+
422
+ Note:
423
+ Includes workaround for MPS device: embedding layers don't work well with MPS,
424
+ so embedding lookup is done on CPU and results are moved back to device.
425
+ """
426
  # hierarchy_indices: (B,) - batch of hierarchy indices
427
  # Workaround for MPS: embedding layers don't work well with MPS, so do lookup on CPU
428
  device = next(self.parameters()).device
 
440
  return self.projection(emb)
441
 
442
  class HierarchyClassifierHead(nn.Module):
443
+ """
444
+ Classifier head for hierarchy classification.
445
+
446
+ Multi-layer perceptron that takes embeddings as input and outputs
447
+ classification logits for hierarchy classes.
448
+ """
449
+
450
  def __init__(self, in_dim, num_classes, hidden_dim=None, dropout=0.3):
451
+ """
452
+ Initialize the hierarchy classifier head.
453
+
454
+ Args:
455
+ in_dim: Input embedding dimension
456
+ num_classes: Number of hierarchy classes
457
+ hidden_dim: Hidden layer dimension (default: max(in_dim // 2, num_classes * 2))
458
+ dropout: Dropout rate for regularization (default: 0.3)
459
+ """
460
  super().__init__()
461
  if hidden_dim is None:
462
  hidden_dim = max(in_dim // 2, num_classes * 2)
 
472
  )
473
 
474
  def forward(self, x):
475
+ """
476
+ Forward pass through the classifier head.
477
+
478
+ Args:
479
+ x: Input embeddings [batch_size, in_dim]
480
+
481
+ Returns:
482
+ Classification logits [batch_size, num_classes]
483
+ """
484
  return self.classifier(x)
485
 
486
  class Model(nn.Module):
487
+ """
488
+ Main hierarchy model for learning clothing category-aligned embeddings.
489
+
490
+ Combines image encoder, hierarchy encoder, and classifier heads to learn
491
+ aligned embeddings for images and text descriptions based on clothing categories.
492
+ """
493
+
494
  def __init__(self, num_hierarchy_classes, embed_dim, dropout=0.3):
495
+ """
496
+ Initialize the hierarchy model.
497
+
498
+ Args:
499
+ num_hierarchy_classes: Number of hierarchy classes
500
+ embed_dim: Dimension of the embedding space
501
+ dropout: Dropout rate for regularization (default: 0.3)
502
+ """
503
  super().__init__()
504
  self.img_enc = PretrainedImageEncoder(embed_dim, dropout)
505
  self.hierarchy_enc = HierarchyEncoder(num_hierarchy_classes, embed_dim, dropout)
 
508
  self.num_hierarchy_classes = num_hierarchy_classes
509
 
510
  def forward(self, image=None, hierarchy_indices=None):
511
+ """
512
+ Forward pass through the model.
513
+
514
+ Args:
515
+ image: Optional image tensor [batch_size, channels, height, width]
516
+ hierarchy_indices: Optional hierarchy indices tensor [batch_size]
517
+
518
+ Returns:
519
+ Dictionary containing:
520
+ - 'z_img': Image embeddings [batch_size, embed_dim] (if image provided)
521
+ - 'z_txt': Text embeddings [batch_size, embed_dim] (if hierarchy_indices provided)
522
+ - 'hierarchy_logits_img': Image classification logits [batch_size, num_classes] (if image provided)
523
+ - 'hierarchy_logits_txt': Text classification logits [batch_size, num_classes] (if hierarchy_indices provided)
524
+ """
525
  out = {}
526
  if image is not None:
527
  z_img = self.img_enc(image)
 
540
  return out
541
 
542
  def set_hierarchy_extractor(self, hierarchy_extractor):
543
+ """
544
+ Set the hierarchy extractor for text processing.
545
+
546
+ Args:
547
+ hierarchy_extractor: HierarchyExtractor instance
548
+ """
549
  self.hierarchy_extractor = hierarchy_extractor
550
 
551
  def get_text_embeddings(self, text):
552
+ """
553
+ Get text embeddings for a given text string or list of strings.
554
+
555
+ Args:
556
+ text: Text string or list of text strings
557
+
558
+ Returns:
559
+ Text embeddings tensor [batch_size, embed_dim]
560
+
561
+ Raises:
562
+ ValueError: If hierarchy cannot be extracted from text
563
+ """
564
 
565
  with torch.no_grad():
566
  # Get the device of the model
 
604
  raise ValueError(f"Expected string or list/tuple of strings, got {type(text)}: {text}")
605
 
606
  def get_image_embeddings(self, image):
607
+ """
608
+ Get image embeddings for a given image tensor.
609
+
610
+ Args:
611
+ image: Image tensor [channels, height, width] or [batch_size, channels, height, width]
612
+
613
+ Returns:
614
+ Image embeddings tensor [batch_size, embed_dim]
615
+
616
+ Raises:
617
+ ValueError: If image is not a torch.Tensor
618
+ """
619
  with torch.no_grad():
620
  if not isinstance(image, torch.Tensor):
621
  raise ValueError("Image must be a torch.Tensor")
 
638
  # -------------------------
639
 
640
  class Loss(nn.Module):
641
+ """
642
+ Combined loss function for hierarchy model training.
643
+
644
+ Combines classification loss, contrastive loss, and consistency loss
645
+ to learn aligned embeddings while maintaining classification accuracy.
646
+ """
647
+
648
  def __init__(self, hierarchy_classes, classification_weight=1.0,
649
  consistency_weight=0.3, contrastive_weight=0.2,
650
  temperature=0.07, label_smoothing=0.1):
651
+ """
652
+ Initialize the loss function.
653
+
654
+ Args:
655
+ hierarchy_classes: List of hierarchy class names
656
+ classification_weight: Weight for classification loss (default: 1.0)
657
+ consistency_weight: Weight for consistency loss (default: 0.3)
658
+ contrastive_weight: Weight for contrastive loss (default: 0.2)
659
+ temperature: Temperature scaling for contrastive loss (default: 0.07)
660
+ label_smoothing: Label smoothing parameter (default: 0.1)
661
+ """
662
  super().__init__()
663
  self.classification_weight = classification_weight
664
  self.consistency_weight = consistency_weight
 
674
  self.mse = nn.MSELoss()
675
 
676
  def contrastive_loss(self, img_emb, txt_emb):
677
+ """
678
+ InfoNCE contrastive loss for aligning image and text embeddings.
679
+
680
+ Args:
681
+ img_emb: Image embeddings [batch_size, embed_dim]
682
+ txt_emb: Text embeddings [batch_size, embed_dim]
683
+
684
+ Returns:
685
+ Contrastive loss value
686
+ """
687
  sim_matrix = torch.matmul(img_emb, txt_emb.T) / self.temperature
688
  labels = torch.arange(img_emb.size(0), device=img_emb.device)
689
 
 
693
  return (loss_i2t + loss_t2i) / 2
694
 
695
  def forward(self, img_logits, txt_logits, img_embeddings, txt_embeddings, target_hierarchies):
696
+ """
697
+ Forward pass through the loss function.
698
+
699
+ Args:
700
+ img_logits: Image classification logits [batch_size, num_classes]
701
+ txt_logits: Text classification logits [batch_size, num_classes]
702
+ img_embeddings: Image embeddings [batch_size, embed_dim]
703
+ txt_embeddings: Text embeddings [batch_size, embed_dim]
704
+ target_hierarchies: List of target hierarchy class names [batch_size]
705
+
706
+ Returns:
707
+ Combined loss value
708
+ """
709
  device = img_embeddings.device
710
 
711
  # Convert hierarchy names to indices
 
735
  # -------------------------
736
 
737
  def collate_fn(batch, hierarchy_extractor):
738
+ """
739
+ Collate function for DataLoader that processes batches and extracts hierarchy indices.
740
+
741
+ Args:
742
+ batch: List of (image, description, hierarchy) tuples
743
+ hierarchy_extractor: HierarchyExtractor instance
744
+
745
+ Returns:
746
+ Dictionary containing:
747
+ - 'image': Stacked image tensors [batch_size, channels, height, width]
748
+ - 'hierarchy_indices': Hierarchy indices tensor [batch_size]
749
+ - hierarchy_column: List of hierarchy class names [batch_size]
750
+ """
751
  images = torch.stack([b[0] for b in batch], dim=0)
752
  texts = [b[1] for b in batch]
753
  hierarchies = [b[2] for b in batch]
 
773
  }
774
 
775
  def calculate_accuracy(logits, target_hierarchies, hierarchy_classes):
776
+ """
777
+ Calculate classification accuracy.
778
+
779
+ Args:
780
+ logits: Classification logits [batch_size, num_classes]
781
+ target_hierarchies: List of target hierarchy class names [batch_size]
782
+ hierarchy_classes: List of hierarchy class names
783
+
784
+ Returns:
785
+ Accuracy score (float between 0 and 1)
786
+ """
787
  batch_size = logits.size(0)
788
  correct = 0
789
  pred_indices = torch.argmax(logits, dim=1).cpu().numpy()
 
797
  return correct / batch_size
798
 
799
  def train_one_epoch(model, dataloader, optimizer, device, hierarchy_classes, scheduler=None):
800
+ """
801
+ Train the model for one epoch.
802
+
803
+ Args:
804
+ model: Model instance to train
805
+ dataloader: DataLoader for training data
806
+ optimizer: Optimizer instance
807
+ device: Device to train on
808
+ hierarchy_classes: List of hierarchy class names
809
+ scheduler: Optional learning rate scheduler
810
+
811
+ Returns:
812
+ Dictionary containing training metrics:
813
+ - 'loss': Average training loss
814
+ - 'acc_img': Average image classification accuracy
815
+ - 'acc_txt': Average text classification accuracy
816
+ """
817
  model.train()
818
  total_loss = 0.0
819
  total_acc_img = 0.0
 
879
  }
880
 
881
  def validate(model, dataloader, device, hierarchy_classes):
882
+ """
883
+ Validate the model on validation data.
884
+
885
+ Args:
886
+ model: Model instance to validate
887
+ dataloader: DataLoader for validation data
888
+ device: Device to validate on
889
+ hierarchy_classes: List of hierarchy class names
890
+
891
+ Returns:
892
+ Dictionary containing validation metrics:
893
+ - 'loss': Average validation loss
894
+ - 'acc_img': Average image classification accuracy
895
+ - 'acc_txt': Average text classification accuracy
896
+ """
897
  model.eval()
898
  total_loss = 0.0
899
  total_acc_img = 0.0