Leacb4 committed on
Commit
ec5c397
·
verified ·
1 Parent(s): fc411a2

Upload training/hierarchy_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/hierarchy_model.py +451 -780
training/hierarchy_model.py CHANGED
@@ -1,9 +1,20 @@
1
  """
2
  Hierarchy model for learning clothing category-aligned embeddings.
3
- This file contains the hierarchy model that learns to encode images and texts
4
- in an embedding space specialized for representing clothing categories (dress, shirt, etc.).
5
- It includes a regex pattern-based hierarchy extractor, a ResNet image encoder,
6
- a hierarchy embedding encoder, and loss functions for training.
 
 
 
 
 
 
 
 
 
 
 
7
  """
8
 
9
  import pandas as pd
@@ -11,146 +22,28 @@ import torch
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
  from torch.utils.data import Dataset, DataLoader
14
- from torchvision import transforms, models
15
  from PIL import Image
16
  from tqdm import tqdm
17
  from sklearn.model_selection import train_test_split
18
  import re
19
- import requests
20
- from io import BytesIO
21
  import config
22
 
23
  # -------------------------
24
- # 1) Dataset
25
- # -------------------------
26
-
27
- class HierarchyDataset(Dataset):
28
- """
29
- Dataset class for hierarchy embedding training.
30
-
31
- Handles loading images from local paths or URLs, extracting hierarchy information
32
- from text descriptions, and applying appropriate transformations for training.
33
- """
34
-
35
- def __init__(self, dataframe, use_local_images=True, image_size=224):
36
- """
37
- Initialize the hierarchy dataset.
38
-
39
- Args:
40
- dataframe: DataFrame with columns for image paths/URLs, text descriptions, and hierarchy labels
41
- use_local_images: Whether to prefer local images over URLs (default: True)
42
- image_size: Size of images after resizing (default: 224)
43
- """
44
- self.dataframe = dataframe
45
- self.use_local_images = use_local_images
46
- self.image_size = image_size
47
-
48
- # transforms with data augmentation for training
49
- self.transform = transforms.Compose([
50
- transforms.Resize((image_size, image_size)),
51
- transforms.RandomHorizontalFlip(p=0.3),
52
- transforms.RandomRotation(10),
53
- transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
54
- transforms.ToTensor(),
55
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
56
- ])
57
-
58
- # Validation transforms (no augmentation)
59
- self.val_transform = transforms.Compose([
60
- transforms.Resize((image_size, image_size)),
61
- transforms.ToTensor(),
62
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
63
- ])
64
-
65
- # Check local image availability
66
- if use_local_images:
67
- if config.column_local_image_path not in dataframe.columns:
68
- print(f"⚠️ Column {config.column_local_image_path} not found. Using URLs.")
69
- self.use_local_images = False
70
- else:
71
- local_available = dataframe[config.column_local_image_path].notna().sum()
72
- total = len(dataframe)
73
- print(f"📁 Local images available: {local_available}/{total} ({local_available/total*100:.1f}%)")
74
-
75
-
76
- def set_training_mode(self, training=True):
77
- """
78
- Switch between training and validation transforms.
79
-
80
- Args:
81
- training: If True, use training transforms with augmentation; if False, use validation transforms
82
- """
83
- self.training_mode = training
84
-
85
- def __len__(self):
86
- """Return the number of samples in the dataset."""
87
- return len(self.dataframe)
88
-
89
- def __getitem__(self, idx):
90
- """
91
- Get a sample from the dataset.
92
-
93
- Args:
94
- idx: Index of the sample
95
-
96
- Returns:
97
- Tuple of (image_tensor, description_text, hierarchy_label)
98
- """
99
- row = self.dataframe.iloc[idx]
100
-
101
- # Try to load local image first
102
- if self.use_local_images and pd.notna(row.get(config.column_local_image_path, '')):
103
- local_path = row[config.column_local_image_path]
104
- image = Image.open(local_path).convert("RGB")
105
- # Check if image is a dictionary of bytes
106
- elif isinstance(row[config.column_url_image], dict):
107
- image = Image.open(BytesIO(row[config.column_url_image]['bytes'])).convert('RGB')
108
- # Otherwise, try to download from URL
109
- else:
110
- image = self._download_image(row[config.column_url_image])
111
-
112
- # Apply transforms
113
- if hasattr(self, 'training_mode') and not self.training_mode:
114
- image = self.val_transform(image)
115
- else:
116
- image = self.transform(image)
117
-
118
- description = row[config.text_column]
119
- hierarchy = row[config.hierarchy_column]
120
-
121
- return image, description, hierarchy
122
-
123
- def _download_image(self, img_url):
124
- """
125
- Download an image from a URL with timeout.
126
-
127
- Args:
128
- img_url: URL of the image to download
129
-
130
- Returns:
131
- PIL Image object
132
- """
133
- response = requests.get(img_url, timeout=10)
134
- response.raise_for_status()
135
- image = Image.open(BytesIO(response.content)).convert("RGB")
136
- return image
137
-
138
- # -------------------------
139
- # 2) Hierarchy Extractor
140
  # -------------------------
141
 
142
  class HierarchyExtractor:
143
  """
144
  Extract hierarchy categories directly from text using pattern matching.
145
-
146
  This class uses regex patterns to identify clothing categories (e.g., shirt, dress)
147
  from text descriptions, handling variations, plurals, and common fashion terms.
148
  """
149
-
150
  def __init__(self, hierarchy_classes, verbose=False):
151
  """
152
  Initialize the hierarchy extractor.
153
-
154
  Args:
155
  hierarchy_classes: List of hierarchy class names
156
  verbose: Whether to print initialization information (default: False)
@@ -158,39 +51,39 @@ class HierarchyExtractor:
158
  self.hierarchy_classes = sorted(hierarchy_classes)
159
  self.class_to_idx = {cls: idx for idx, cls in enumerate(self.hierarchy_classes)}
160
  self.idx_to_class = {idx: cls for idx, cls in enumerate(self.hierarchy_classes)}
161
-
162
  # Create patterns for each hierarchy
163
  self.patterns = self._create_patterns()
164
-
165
  if verbose:
166
- print(f"🎯 Hierarchy extractor initialized with {len(self.hierarchy_classes)} classes")
167
- print(f"📋 Classes: {self.hierarchy_classes}")
168
-
169
  def _create_patterns(self):
170
  """
171
  Create regex patterns for each hierarchy class.
172
-
173
  Creates patterns that match variations, plurals, and common fashion terms
174
  for each hierarchy class.
175
-
176
  Returns:
177
  Dictionary mapping hierarchy classes to regex patterns
178
  """
179
  patterns = {}
180
-
181
  for hierarchy in self.hierarchy_classes:
182
  # Create variations of the hierarchy name
183
  variations = [hierarchy.lower()]
184
-
185
  # Add common variations
186
  if '-' in hierarchy:
187
  variations.append(hierarchy.replace('-', ' '))
188
  variations.append(hierarchy.replace('-', ''))
189
-
190
  # Add plural forms
191
  if not hierarchy.endswith('s'):
192
  variations.append(hierarchy + 's')
193
-
194
  # Add common fashion terms
195
  fashion_terms = {
196
  'shirt': ['shirt', 'shirts', 'tee', 't-shirt', 'tshirt'],
@@ -215,50 +108,50 @@ class HierarchyExtractor:
215
  'glove': ['glove', 'gloves'],
216
  'sandal': ['sandal', 'sandals']
217
  }
218
-
219
  # Add fashion terms if hierarchy matches
220
  for key, terms in fashion_terms.items():
221
  if key in hierarchy.lower():
222
  variations.extend(terms)
223
-
224
  # Create regex pattern
225
  pattern = r'\b(' + '|'.join(re.escape(v) for v in variations) + r')\b'
226
  patterns[hierarchy] = pattern
227
-
228
  return patterns
229
-
230
  def extract_hierarchy(self, text):
231
  """
232
  Extract hierarchy category from text using pattern matching.
233
-
234
  Args:
235
  text: Input text string
236
-
237
  Returns:
238
  Hierarchy class name if found, None otherwise
239
  """
240
  text_lower = text.lower()
241
-
242
  # Try exact match first
243
  for hierarchy in self.hierarchy_classes:
244
  if hierarchy.lower() in text_lower:
245
  return hierarchy
246
-
247
  # Try pattern matching
248
  for hierarchy, pattern in self.patterns.items():
249
  if re.search(pattern, text_lower):
250
  return hierarchy
251
-
252
- # If no match found, return the most common hierarchy or None
253
  return None
254
-
255
  def extract_hierarchy_idx(self, text):
256
  """
257
  Extract hierarchy index from text.
258
-
259
  Args:
260
  text: Input text string
261
-
262
  Returns:
263
  Hierarchy index if found, None otherwise
264
  """
@@ -266,15 +159,15 @@ class HierarchyExtractor:
266
  if hierarchy:
267
  return self.class_to_idx[hierarchy]
268
  return None
269
-
270
  def get_hierarchy_embedding(self, text, embed_dim=config.hierarchy_emb_dim):
271
  """
272
  Create embedding from hierarchy index extracted from text.
273
-
274
  Args:
275
  text: Input text string
276
  embed_dim: Dimension of the embedding (default: hierarchy_emb_dim)
277
-
278
  Returns:
279
  Embedding tensor of shape (embed_dim,)
280
  """
@@ -293,164 +186,21 @@ class HierarchyExtractor:
293
  return torch.zeros(embed_dim)
294
 
295
  # -------------------------
296
- # 3) Models
297
  # -------------------------
298
 
299
- class PretrainedImageEncoder(nn.Module):
300
- """
301
- Image encoder based on pretrained ResNet18 for extracting image embeddings.
302
-
303
- Uses a pretrained ResNet18 backbone and freezes early layers to prevent overfitting.
304
- Adds a custom projection head to output embeddings of the specified dimension.
305
- """
306
-
307
- def __init__(self, embed_dim, dropout=0.3):
308
- """
309
- Initialize the pretrained image encoder.
310
-
311
- Args:
312
- embed_dim: Dimension of the output embedding
313
- dropout: Dropout rate for regularization (default: 0.3)
314
- """
315
- super().__init__()
316
-
317
- self.backbone = models.resnet18(pretrained=True)
318
- backbone_dim = 512
319
-
320
- # Remove the final classification layer
321
- self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
322
-
323
- # Add custom projection head
324
- self.projection = nn.Sequential(
325
- nn.Flatten(),
326
- nn.Dropout(dropout),
327
- nn.Linear(backbone_dim, embed_dim * 2),
328
- nn.ReLU(inplace=True),
329
- nn.Dropout(dropout * 0.5),
330
- nn.Linear(embed_dim * 2, embed_dim),
331
- nn.LayerNorm(embed_dim)
332
- )
333
-
334
- # Fine-tune only the last few layers
335
- self._freeze_backbone_layers()
336
-
337
- def _freeze_backbone_layers(self):
338
- """
339
- Freeze early layers to prevent overfitting.
340
-
341
- Freezes the first 70% of backbone layers, allowing only the last layers
342
- to be fine-tuned during training.
343
- """
344
- if hasattr(self.backbone, 'children'):
345
- layers = list(self.backbone.children())
346
- freeze_until = int(len(layers) * 0.7)
347
- for i, layer in enumerate(layers):
348
- if i < freeze_until:
349
- for param in layer.parameters():
350
- param.requires_grad = False
351
-
352
- def forward(self, x):
353
- """
354
- Forward pass through the image encoder.
355
-
356
- Args:
357
- x: Image tensor [batch_size, channels, height, width]
358
-
359
- Returns:
360
- Image embeddings [batch_size, embed_dim]
361
- """
362
- features = self.backbone(x)
363
- return self.projection(features)
364
-
365
- class HierarchyEncoder(nn.Module):
366
- """
367
- Encoder that takes hierarchy indices directly.
368
-
369
- Uses an embedding layer to convert hierarchy indices to embeddings,
370
- followed by a projection head to output embeddings of the specified dimension.
371
- """
372
-
373
- def __init__(self, num_hierarchies, embed_dim, dropout=0.3):
374
- """
375
- Initialize the hierarchy encoder.
376
-
377
- Args:
378
- num_hierarchies: Number of hierarchy classes
379
- embed_dim: Dimension of the output embedding
380
- dropout: Dropout rate for regularization (default: 0.3)
381
- """
382
- super().__init__()
383
- self.num_hierarchies = num_hierarchies
384
- self.embed_dim = embed_dim
385
-
386
- # Embedding layer
387
- self.embedding = nn.Embedding(num_hierarchies, embed_dim)
388
-
389
- # Projection layer
390
- self.projection = nn.Sequential(
391
- nn.Linear(embed_dim, embed_dim * 2),
392
- nn.ReLU(inplace=True),
393
- nn.Dropout(dropout),
394
- nn.Linear(embed_dim * 2, embed_dim),
395
- nn.LayerNorm(embed_dim)
396
- )
397
-
398
- # Initialize weights
399
- self._init_weights()
400
-
401
- def _init_weights(self):
402
- """
403
- Initialize weights properly using Xavier uniform initialization.
404
- """
405
- nn.init.xavier_uniform_(self.embedding.weight)
406
- for module in self.projection.modules():
407
- if isinstance(module, nn.Linear):
408
- nn.init.xavier_uniform_(module.weight)
409
- if module.bias is not None:
410
- nn.init.zeros_(module.bias)
411
-
412
- def forward(self, hierarchy_indices):
413
- """
414
- Forward pass through the hierarchy encoder.
415
-
416
- Args:
417
- hierarchy_indices: Tensor of hierarchy indices [batch_size]
418
-
419
- Returns:
420
- Hierarchy embeddings [batch_size, embed_dim]
421
-
422
- Note:
423
- Includes workaround for MPS device: embedding layers don't work well with MPS,
424
- so embedding lookup is done on CPU and results are moved back to device.
425
- """
426
- # hierarchy_indices: (B,) - batch of hierarchy indices
427
- # Workaround for MPS: embedding layers don't work well with MPS, so do lookup on CPU
428
- device = next(self.parameters()).device
429
- if device.type == 'mps':
430
- # Move indices to CPU for embedding lookup
431
- indices_cpu = hierarchy_indices.cpu()
432
- # Use functional embedding with explicit weight handling for MPS compatibility
433
- emb_weight = self.embedding.weight.cpu()
434
- emb = F.embedding(indices_cpu, emb_weight)
435
- # Move result back to model device (MPS) - ensure it's contiguous
436
- emb = emb.contiguous().to(device)
437
- else:
438
- emb = self.embedding(hierarchy_indices)
439
- # Ensure emb is on the same device as projection before calling it
440
- return self.projection(emb)
441
-
442
  class HierarchyClassifierHead(nn.Module):
443
  """
444
  Classifier head for hierarchy classification.
445
-
446
  Multi-layer perceptron that takes embeddings as input and outputs
447
  classification logits for hierarchy classes.
448
  """
449
-
450
  def __init__(self, in_dim, num_classes, hidden_dim=None, dropout=0.3):
451
  """
452
  Initialize the hierarchy classifier head.
453
-
454
  Args:
455
  in_dim: Input embedding dimension
456
  num_classes: Number of hierarchy classes
@@ -460,7 +210,7 @@ class HierarchyClassifierHead(nn.Module):
460
  super().__init__()
461
  if hidden_dim is None:
462
  hidden_dim = max(in_dim // 2, num_classes * 2)
463
-
464
  self.classifier = nn.Sequential(
465
  nn.Linear(in_dim, hidden_dim),
466
  nn.ReLU(inplace=True),
@@ -470,168 +220,222 @@ class HierarchyClassifierHead(nn.Module):
470
  nn.Dropout(dropout * 0.5),
471
  nn.Linear(hidden_dim // 2, num_classes)
472
  )
473
-
474
  def forward(self, x):
475
  """
476
  Forward pass through the classifier head.
477
-
478
  Args:
479
  x: Input embeddings [batch_size, in_dim]
480
-
481
  Returns:
482
  Classification logits [batch_size, num_classes]
483
  """
484
  return self.classifier(x)
485
 
486
- class Model(nn.Module):
 
487
  """
488
- Main hierarchy model for learning clothing category-aligned embeddings.
489
-
490
- Combines image encoder, hierarchy encoder, and classifier heads to learn
491
- aligned embeddings for images and text descriptions based on clothing categories.
492
  """
493
-
494
- def __init__(self, num_hierarchy_classes, embed_dim, dropout=0.3):
495
- """
496
- Initialize the hierarchy model.
497
-
498
- Args:
499
- num_hierarchy_classes: Number of hierarchy classes
500
- embed_dim: Dimension of the embedding space
501
- dropout: Dropout rate for regularization (default: 0.3)
502
- """
503
  super().__init__()
504
- self.img_enc = PretrainedImageEncoder(embed_dim, dropout)
505
- self.hierarchy_enc = HierarchyEncoder(num_hierarchy_classes, embed_dim, dropout)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
  self.hierarchy_head_img = HierarchyClassifierHead(embed_dim, num_hierarchy_classes, dropout=dropout)
507
  self.hierarchy_head_txt = HierarchyClassifierHead(embed_dim, num_hierarchy_classes, dropout=dropout)
508
- self.num_hierarchy_classes = num_hierarchy_classes
509
 
510
- def forward(self, image=None, hierarchy_indices=None):
511
- """
512
- Forward pass through the model.
513
-
514
- Args:
515
- image: Optional image tensor [batch_size, channels, height, width]
516
- hierarchy_indices: Optional hierarchy indices tensor [batch_size]
517
-
518
- Returns:
519
- Dictionary containing:
520
- - 'z_img': Image embeddings [batch_size, embed_dim] (if image provided)
521
- - 'z_txt': Text embeddings [batch_size, embed_dim] (if hierarchy_indices provided)
522
- - 'hierarchy_logits_img': Image classification logits [batch_size, num_classes] (if image provided)
523
- - 'hierarchy_logits_txt': Text classification logits [batch_size, num_classes] (if hierarchy_indices provided)
524
- """
525
- out = {}
526
- if image is not None:
527
- z_img = self.img_enc(image)
528
- z_img = F.normalize(z_img, p=2, dim=1)
529
- hierarchy_logits_img = self.hierarchy_head_img(z_img)
530
- out['hierarchy_logits_img'] = hierarchy_logits_img
531
- out['z_img'] = z_img
532
-
533
- if hierarchy_indices is not None:
534
- z_txt = self.hierarchy_enc(hierarchy_indices)
535
- z_txt = F.normalize(z_txt, p=2, dim=1)
536
- hierarchy_logits_txt = self.hierarchy_head_txt(z_txt)
537
- out['hierarchy_logits_txt'] = hierarchy_logits_txt
538
- out['z_txt'] = z_txt
539
-
540
- return out
541
-
542
  def set_hierarchy_extractor(self, hierarchy_extractor):
543
- """
544
- Set the hierarchy extractor for text processing.
545
-
546
- Args:
547
- hierarchy_extractor: HierarchyExtractor instance
548
- """
549
  self.hierarchy_extractor = hierarchy_extractor
550
-
551
- def get_text_embeddings(self, text):
552
- """
553
- Get text embeddings for a given text string or list of strings.
554
-
555
- Args:
556
- text: Text string or list of text strings
557
-
558
- Returns:
559
- Text embeddings tensor [batch_size, embed_dim]
560
-
561
- Raises:
562
- ValueError: If hierarchy cannot be extracted from text
563
- """
564
-
565
  with torch.no_grad():
566
- # Get the device of the model
567
- model_device = next(self.parameters()).device
568
-
569
- # Handle case where text is a list/tuple of hierarchies
570
- if isinstance(text, (list, tuple)):
571
- # Process multiple hierarchies
572
- hierarchy_indices = []
573
- for hierarchy_text in text:
574
- if isinstance(hierarchy_text, str):
575
- hierarchy_idx = self.hierarchy_extractor.extract_hierarchy_idx(hierarchy_text)
576
- if hierarchy_idx is None:
577
- raise ValueError(f"Could not extract hierarchy for text: '{hierarchy_text}'. Available classes: {self.hierarchy_extractor.hierarchy_classes}")
578
- hierarchy_indices.append(hierarchy_idx)
579
- else:
580
- raise ValueError(f"Expected string, got {type(hierarchy_text)}: {hierarchy_text}")
581
-
582
- # Convert to tensor and move to device
583
- hierarchy_indices = torch.tensor(hierarchy_indices, device=model_device)
584
-
585
- # Get text embeddings for all hierarchies
586
- output = self.forward(hierarchy_indices=hierarchy_indices)
587
- return output['z_txt']
588
-
589
- # Handle single string case
590
- elif isinstance(text, str):
591
- # Extract hierarchy index from text
592
- hierarchy_idx = self.hierarchy_extractor.extract_hierarchy_idx(text)
593
- if hierarchy_idx is None:
594
- raise ValueError(f"Could not extract hierarchy for text: '{text}'. Available classes: {self.hierarchy_extractor.hierarchy_classes}")
595
-
596
- # Convert to tensor and move to device
597
- hierarchy_indices = torch.tensor([hierarchy_idx], device=model_device)
598
-
599
- # Get text embeddings
600
- output = self.forward(hierarchy_indices=hierarchy_indices)
601
- return output['z_txt']
602
-
603
- else:
604
- raise ValueError(f"Expected string or list/tuple of strings, got {type(text)}: {text}")
605
-
606
- def get_image_embeddings(self, image):
607
- """
608
- Get image embeddings for a given image tensor.
609
-
610
- Args:
611
- image: Image tensor [channels, height, width] or [batch_size, channels, height, width]
612
-
613
- Returns:
614
- Image embeddings tensor [batch_size, embed_dim]
615
-
616
- Raises:
617
- ValueError: If image is not a torch.Tensor
618
- """
619
  with torch.no_grad():
620
- if not isinstance(image, torch.Tensor):
621
- raise ValueError("Image must be a torch.Tensor")
622
-
623
- # Ensure image is on the same device as model
624
- device = next(self.parameters()).device
625
- if image.device != device:
626
- image = image.to(device)
627
-
628
- # Add batch dimension if needed
629
- if image.dim() == 3:
630
- image = image.unsqueeze(0)
631
-
632
- # Get image embeddings
633
- output = self.forward(image=image)
634
- return output['z_img']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
635
 
636
  # -------------------------
637
  # 4) Loss functions
@@ -640,17 +444,17 @@ class Model(nn.Module):
640
  class Loss(nn.Module):
641
  """
642
  Combined loss function for hierarchy model training.
643
-
644
  Combines classification loss, contrastive loss, and consistency loss
645
  to learn aligned embeddings while maintaining classification accuracy.
646
  """
647
-
648
- def __init__(self, hierarchy_classes, classification_weight=1.0,
649
- consistency_weight=0.3, contrastive_weight=0.2,
650
  temperature=0.07, label_smoothing=0.1):
651
  """
652
  Initialize the loss function.
653
-
654
  Args:
655
  hierarchy_classes: List of hierarchy class names
656
  classification_weight: Weight for classification loss (default: 1.0)
@@ -664,422 +468,289 @@ class Loss(nn.Module):
664
  self.consistency_weight = consistency_weight
665
  self.contrastive_weight = contrastive_weight
666
  self.temperature = temperature
667
-
668
  self.hierarchy_classes = sorted(list(set(hierarchy_classes)))
669
  self.num_classes = len(self.hierarchy_classes)
670
  self.class_to_idx = {cls: idx for idx, cls in enumerate(self.hierarchy_classes)}
671
-
672
  # Loss functions with label smoothing
673
  self.ce = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
674
  self.mse = nn.MSELoss()
675
-
676
  def contrastive_loss(self, img_emb, txt_emb):
677
  """
678
  InfoNCE contrastive loss for aligning image and text embeddings.
679
-
680
  Args:
681
  img_emb: Image embeddings [batch_size, embed_dim]
682
  txt_emb: Text embeddings [batch_size, embed_dim]
683
-
684
  Returns:
685
  Contrastive loss value
686
  """
687
  sim_matrix = torch.matmul(img_emb, txt_emb.T) / self.temperature
688
  labels = torch.arange(img_emb.size(0), device=img_emb.device)
689
-
690
  loss_i2t = F.cross_entropy(sim_matrix, labels)
691
  loss_t2i = F.cross_entropy(sim_matrix.T, labels)
692
-
693
  return (loss_i2t + loss_t2i) / 2
694
-
695
  def forward(self, img_logits, txt_logits, img_embeddings, txt_embeddings, target_hierarchies):
696
  """
697
  Forward pass through the loss function.
698
-
699
  Args:
700
  img_logits: Image classification logits [batch_size, num_classes]
701
  txt_logits: Text classification logits [batch_size, num_classes]
702
  img_embeddings: Image embeddings [batch_size, embed_dim]
703
  txt_embeddings: Text embeddings [batch_size, embed_dim]
704
  target_hierarchies: List of target hierarchy class names [batch_size]
705
-
706
  Returns:
707
  Combined loss value
708
  """
709
  device = img_embeddings.device
710
-
711
  # Convert hierarchy names to indices
712
  target_classes = torch.tensor([
713
  self.class_to_idx.get(hierarchy, 0) for hierarchy in target_hierarchies
714
  ], device=device)
715
-
716
  # 1. Classification loss
717
- classification_loss = (self.ce(img_logits, target_classes) +
718
  self.ce(txt_logits, target_classes)) / 2
719
-
720
  # 2. Contrastive loss for alignment
721
  contrastive_loss = self.contrastive_loss(img_embeddings, txt_embeddings)
722
-
723
  # 3. Consistency loss between modalities
724
  consistency_loss = self.mse(img_embeddings, txt_embeddings)
725
-
726
  # Combined loss
727
  total_loss = (self.classification_weight * classification_loss +
728
  self.contrastive_weight * contrastive_loss +
729
  self.consistency_weight * consistency_loss)
730
-
731
  return total_loss
732
 
733
  # -------------------------
734
- # 5) Training
735
  # -------------------------
736
 
737
- def collate_fn(batch, hierarchy_extractor):
738
- """
739
- Collate function for DataLoader that processes batches and extracts hierarchy indices.
740
-
741
- Args:
742
- batch: List of (image, description, hierarchy) tuples
743
- hierarchy_extractor: HierarchyExtractor instance
744
-
745
- Returns:
746
- Dictionary containing:
747
- - 'image': Stacked image tensors [batch_size, channels, height, width]
748
- - 'hierarchy_indices': Hierarchy indices tensor [batch_size]
749
- - hierarchy_column: List of hierarchy class names [batch_size]
750
- """
751
- images = torch.stack([b[0] for b in batch], dim=0)
752
- texts = [b[1] for b in batch]
753
- hierarchies = [b[2] for b in batch]
754
-
755
- # Extract hierarchy indices from texts
756
- hierarchy_indices = []
757
- for text in texts:
758
- idx = hierarchy_extractor.extract_hierarchy_idx(text)
759
- if idx is not None:
760
- hierarchy_indices.append(idx)
761
- else:
762
- # If no hierarchy found, use the target hierarchy
763
- target_hierarchy = hierarchies[len(hierarchy_indices)]
764
- idx = hierarchy_extractor.class_to_idx.get(target_hierarchy, 0)
765
- hierarchy_indices.append(idx)
766
-
767
- hierarchy_indices = torch.tensor(hierarchy_indices, dtype=torch.long)
768
-
769
- return {
770
- 'image': images,
771
- 'hierarchy_indices': hierarchy_indices,
772
- config.hierarchy_column: hierarchies
773
- }
774
-
775
  def calculate_accuracy(logits, target_hierarchies, hierarchy_classes):
776
  """
777
  Calculate classification accuracy.
778
-
779
  Args:
780
  logits: Classification logits [batch_size, num_classes]
781
  target_hierarchies: List of target hierarchy class names [batch_size]
782
  hierarchy_classes: List of hierarchy class names
783
-
784
  Returns:
785
  Accuracy score (float between 0 and 1)
786
  """
787
  batch_size = logits.size(0)
788
  correct = 0
789
  pred_indices = torch.argmax(logits, dim=1).cpu().numpy()
790
-
791
  for i in range(batch_size):
792
  pred_class = hierarchy_classes[pred_indices[i]] if pred_indices[i] < len(hierarchy_classes) else ""
793
  target_class = target_hierarchies[i]
794
  if pred_class == target_class:
795
  correct += 1
796
-
797
- return correct / batch_size
798
 
799
- def train_one_epoch(model, dataloader, optimizer, device, hierarchy_classes, scheduler=None):
800
- """
801
- Train the model for one epoch.
802
-
803
- Args:
804
- model: Model instance to train
805
- dataloader: DataLoader for training data
806
- optimizer: Optimizer instance
807
- device: Device to train on
808
- hierarchy_classes: List of hierarchy class names
809
- scheduler: Optional learning rate scheduler
810
-
811
- Returns:
812
- Dictionary containing training metrics:
813
- - 'loss': Average training loss
814
- - 'acc_img': Average image classification accuracy
815
- - 'acc_txt': Average text classification accuracy
816
- """
817
- model.train()
818
- total_loss = 0.0
819
- total_acc_img = 0.0
820
- total_acc_txt = 0.0
821
- num_batches = 0
822
-
823
- loss_fn = Loss(
824
- hierarchy_classes,
825
- classification_weight=1.0,
826
- consistency_weight=0.3,
827
- contrastive_weight=0.2,
828
- label_smoothing=0.1
829
- ).to(device)
830
-
831
- pbar = tqdm(dataloader, desc="Training", leave=False)
832
- for batch in pbar:
833
- images = batch['image'].to(device)
834
- hierarchy_indices = batch['hierarchy_indices'].to(device)
835
- target_hierarchies = batch[config.hierarchy_column]
836
-
837
- # Set dataset to training mode
838
- if hasattr(dataloader.dataset, 'set_training_mode'):
839
- dataloader.dataset.set_training_mode(True)
840
-
841
- out = model(image=images, hierarchy_indices=hierarchy_indices)
842
- hierarchy_logits_img = out['hierarchy_logits_img']
843
- hierarchy_logits_txt = out['hierarchy_logits_txt']
844
- z_img, z_txt = out['z_img'], out['z_txt']
845
-
846
- # Calculate loss
847
- loss = loss_fn(hierarchy_logits_img, hierarchy_logits_txt, z_img, z_txt, target_hierarchies)
848
-
849
- optimizer.zero_grad()
850
- loss.backward()
851
-
852
- # Gradient clipping
853
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
854
-
855
- optimizer.step()
856
-
857
- if scheduler is not None:
858
- scheduler.step()
859
-
860
- # Calculate accuracies
861
- acc_img = calculate_accuracy(hierarchy_logits_img, target_hierarchies, hierarchy_classes)
862
- acc_txt = calculate_accuracy(hierarchy_logits_txt, target_hierarchies, hierarchy_classes)
863
-
864
- total_loss += loss.item()
865
- total_acc_img += acc_img
866
- total_acc_txt += acc_txt
867
- num_batches += 1
868
-
869
- pbar.set_postfix({
870
- 'loss': f'{loss.item():.4f}',
871
- 'acc_img': f'{acc_img:.3f}',
872
- 'acc_txt': f'{acc_txt:.3f}',
873
- })
874
-
875
- return {
876
- 'loss': total_loss / num_batches,
877
- 'acc_img': total_acc_img / num_batches,
878
- 'acc_txt': total_acc_txt / num_batches
879
- }
880
-
881
- def validate(model, dataloader, device, hierarchy_classes):
882
- """
883
- Validate the model on validation data.
884
-
885
- Args:
886
- model: Model instance to validate
887
- dataloader: DataLoader for validation data
888
- device: Device to validate on
889
- hierarchy_classes: List of hierarchy class names
890
-
891
- Returns:
892
- Dictionary containing validation metrics:
893
- - 'loss': Average validation loss
894
- - 'acc_img': Average image classification accuracy
895
- - 'acc_txt': Average text classification accuracy
896
- """
897
- model.eval()
898
- total_loss = 0.0
899
- total_acc_img = 0.0
900
- total_acc_txt = 0.0
901
- num_batches = 0
902
-
903
- loss_fn = Loss(
904
- hierarchy_classes,
905
- classification_weight=1.0,
906
- consistency_weight=0.3,
907
- contrastive_weight=0.2
908
- ).to(device)
909
-
910
- pbar = tqdm(dataloader, desc="Validation", leave=False)
911
- with torch.no_grad():
912
- for batch in pbar:
913
- images = batch['image'].to(device)
914
- hierarchy_indices = batch['hierarchy_indices'].to(device)
915
- target_hierarchies = batch[config.hierarchy_column]
916
-
917
- # Set dataset to validation mode
918
- if hasattr(dataloader.dataset, 'set_training_mode'):
919
- dataloader.dataset.set_training_mode(False)
920
-
921
- out = model(image=images, hierarchy_indices=hierarchy_indices)
922
- hierarchy_logits_img = out['hierarchy_logits_img']
923
- hierarchy_logits_txt = out['hierarchy_logits_txt']
924
- z_img, z_txt = out['z_img'], out['z_txt']
925
-
926
- # Calculate loss
927
- loss = loss_fn(hierarchy_logits_img, hierarchy_logits_txt, z_img, z_txt, target_hierarchies)
928
-
929
- # Calculate accuracies
930
- acc_img = calculate_accuracy(hierarchy_logits_img, target_hierarchies, hierarchy_classes)
931
- acc_txt = calculate_accuracy(hierarchy_logits_txt, target_hierarchies, hierarchy_classes)
932
-
933
- total_loss += loss.item()
934
- total_acc_img += acc_img
935
- total_acc_txt += acc_txt
936
- num_batches += 1
937
-
938
- pbar.set_postfix({
939
- 'loss': f'{loss.item():.4f}',
940
- 'acc_img': f'{acc_img:.3f}',
941
- 'acc_txt': f'{acc_txt:.3f}',
942
- })
943
-
944
- return {
945
- 'loss': total_loss / num_batches,
946
- 'acc_img': total_acc_img / num_batches,
947
- 'acc_txt': total_acc_txt / num_batches
948
- }
949
 
950
  # -------------------------
951
  # 6) Main training script
952
  # -------------------------
953
 
954
- if __name__ == "__main__":
955
- # Configuration
 
 
 
 
 
 
 
956
  device = config.device
957
- batch_size = 16
958
- lr = 5e-5
959
- epochs = 20
960
- val_split = 0.2
961
- dropout = 0.4
962
- weight_decay = 1e-3
963
-
964
- print(f"🚀 Starting hierarchical training on device: {device}")
965
- print(f"📊 Config: {epochs} epochs, batch={batch_size}, lr={lr}, embed_dim={config.hierarchy_emb_dim}")
966
-
967
- # Load dataset
968
- print(f"📁 Using dataset: { config.local_dataset_path}")
969
- df = pd.read_csv(config.local_dataset_path)
970
- print(f"📁 Loaded {len(df)} samples")
971
-
972
- # Get unique hierarchy classes
 
 
 
 
 
 
 
 
 
 
 
 
973
  hierarchy_classes = sorted(df[config.hierarchy_column].unique().tolist())
974
- print(f"📋 Found {len(hierarchy_classes)} hierarchy classes")
975
-
976
- # Create hierarchy extractor
977
- hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=True)
978
-
979
- # Train/validation split
980
  train_df, val_df = train_test_split(
981
- df,
982
- test_size=val_split,
983
- random_state=42,
984
- stratify=df[config.hierarchy_column]
985
  )
986
  train_df = train_df.reset_index(drop=True)
987
  val_df = val_df.reset_index(drop=True)
988
-
989
- print(f"📈 Train: {len(train_df)}, Validation: {len(val_df)}")
990
-
991
- # Create datasets
992
- train_ds = HierarchyDataset(train_df, image_size=224)
993
- val_ds = HierarchyDataset(val_df, image_size=224)
994
-
995
- # Create data loaders
996
- train_dl = DataLoader(
997
- train_ds,
998
- batch_size=batch_size,
999
- shuffle=True,
1000
- collate_fn=lambda batch: collate_fn(batch, hierarchy_extractor)
1001
  )
1002
- val_dl = DataLoader(
1003
- val_ds,
1004
- batch_size=batch_size,
1005
- shuffle=False,
1006
- collate_fn=lambda batch: collate_fn(batch, hierarchy_extractor)
1007
  )
1008
-
1009
- # Create model
1010
- model = Model(
1011
- num_hierarchy_classes=len(hierarchy_classes),
1012
- embed_dim=config.hierarchy_emb_dim,
1013
- dropout=dropout
 
 
 
 
 
 
 
 
 
 
1014
  ).to(device)
1015
-
1016
- # Optimizer and scheduler
1017
- optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
1018
- scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=lr/10)
1019
-
1020
- print(f"🎯 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
1021
- print("\n" + "="*80)
1022
-
1023
- # Training loop
1024
- best_val_loss = float('inf')
1025
- training_history = {'train_loss': [], 'val_loss': [], 'val_acc_img': [], 'val_acc_txt': []}
1026
-
1027
- for e in range(epochs):
1028
- print(f"\n🔄 Epoch {e+1}/{epochs}")
1029
- print("-" * 50)
1030
-
1031
- # Training
1032
- train_metrics = train_one_epoch(model, train_dl, optimizer, device, hierarchy_classes, scheduler)
1033
-
1034
- # Validation
1035
- val_metrics = validate(model, val_dl, device, hierarchy_classes)
1036
-
1037
- # Track history
1038
- training_history['train_loss'].append(train_metrics['loss'])
1039
- training_history['val_loss'].append(val_metrics['loss'])
1040
- training_history['val_acc_img'].append(val_metrics['acc_img'])
1041
- training_history['val_acc_txt'].append(val_metrics['acc_txt'])
1042
-
1043
- # Display results
1044
- print(f"📊 TRAIN - Loss: {train_metrics['loss']:.6f} | "
1045
- f"Img Acc: {train_metrics['acc_img']:.3f} | "
1046
- f"Txt Acc: {train_metrics['acc_txt']:.3f}")
1047
-
1048
- print(f"✅ VAL - Loss: {val_metrics['loss']:.6f} | "
1049
- f"Img Acc: {val_metrics['acc_img']:.3f} | "
1050
- f"Txt Acc: {val_metrics['acc_txt']:.3f}")
1051
-
1052
- # Save best model
1053
- if val_metrics['loss'] < best_val_loss:
1054
- best_val_loss = val_metrics['loss']
1055
- print(f"💾 New best validation loss! Saving model...")
1056
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1057
  torch.save({
1058
- 'model_state': model.state_dict(),
1059
- 'hierarchy_classes': hierarchy_classes,
1060
- 'epoch': e+1,
1061
- 'config': {
1062
- 'embed_dim': config.hierarchy_emb_dim,
1063
- 'dropout': dropout
1064
- }
 
 
1065
  }, config.hierarchy_model_path)
1066
-
1067
- # Save model every 2 epochs
1068
- if (e + 1) % 2 == 0:
1069
- print(f"💾 Saving checkpoint at epoch {e+1}...")
1070
-
1071
- torch.save({
1072
- 'model_state': model.state_dict(),
1073
- 'hierarchy_classes': hierarchy_classes,
1074
- 'epoch': e+1,
1075
- 'config': {
1076
- 'embed_dim': config.hierarchy_emb_dim,
1077
- 'dropout': dropout
1078
- }
1079
- }, f"model_checkpoint_epoch_{e+1}.pth")
1080
-
1081
- print("\n" + "="*80)
1082
- print("🎉 Training completed!")
1083
- print(f"🏆 Best validation loss: {best_val_loss:.6f}")
1084
-
1085
- print(f"\n📈 Final validation accuracy: Image={training_history['val_acc_img'][-1]:.3f}, Text={training_history['val_acc_txt'][-1]:.3f}")
 
1
  """
2
  Hierarchy model for learning clothing category-aligned embeddings.
3
+
4
+ Architecture: frozen CLIP (ViT-B/32) encoders with trainable MLP projections
5
+ to a 64-dimensional embedding space, plus classifier heads for hierarchy
6
+ category prediction. The CLIP backbone provides strong image and text
7
+ understanding while the lightweight projection heads learn a compact,
8
+ category-aligned representation suitable for fast nearest-neighbor search.
9
+
10
+ Components:
11
+ - HierarchyExtractor: regex pattern-based text-to-category mapper
12
+ - HierarchyClassifierHead: MLP classifier on top of projected embeddings
13
+ - HierarchyModel: frozen CLIP + trainable projections + classifier heads
14
+ - HierarchyDataset: CLIP-preprocessed images + raw text for training
15
+ - PrecomputedHierarchyDataset: pre-computed CLIP features for fast training
16
+ - Loss: combined classification, contrastive, and consistency loss
17
+ - train_hierarchy: end-to-end training loop using pre-computed features
18
  """
19
 
20
  import pandas as pd
 
22
  import torch.nn as nn
23
  import torch.nn.functional as F
24
  from torch.utils.data import Dataset, DataLoader
 
25
  from PIL import Image
26
  from tqdm import tqdm
27
  from sklearn.model_selection import train_test_split
28
  import re
 
 
29
  import config
30
 
31
  # -------------------------
32
+ # 1) Hierarchy Extractor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # -------------------------
34
 
35
  class HierarchyExtractor:
36
  """
37
  Extract hierarchy categories directly from text using pattern matching.
38
+
39
  This class uses regex patterns to identify clothing categories (e.g., shirt, dress)
40
  from text descriptions, handling variations, plurals, and common fashion terms.
41
  """
42
+
43
  def __init__(self, hierarchy_classes, verbose=False):
44
  """
45
  Initialize the hierarchy extractor.
46
+
47
  Args:
48
  hierarchy_classes: List of hierarchy class names
49
  verbose: Whether to print initialization information (default: False)
 
51
  self.hierarchy_classes = sorted(hierarchy_classes)
52
  self.class_to_idx = {cls: idx for idx, cls in enumerate(self.hierarchy_classes)}
53
  self.idx_to_class = {idx: cls for idx, cls in enumerate(self.hierarchy_classes)}
54
+
55
  # Create patterns for each hierarchy
56
  self.patterns = self._create_patterns()
57
+
58
  if verbose:
59
+ print(f"Hierarchy extractor initialized with {len(self.hierarchy_classes)} classes")
60
+ print(f"Classes: {self.hierarchy_classes}")
61
+
62
  def _create_patterns(self):
63
  """
64
  Create regex patterns for each hierarchy class.
65
+
66
  Creates patterns that match variations, plurals, and common fashion terms
67
  for each hierarchy class.
68
+
69
  Returns:
70
  Dictionary mapping hierarchy classes to regex patterns
71
  """
72
  patterns = {}
73
+
74
  for hierarchy in self.hierarchy_classes:
75
  # Create variations of the hierarchy name
76
  variations = [hierarchy.lower()]
77
+
78
  # Add common variations
79
  if '-' in hierarchy:
80
  variations.append(hierarchy.replace('-', ' '))
81
  variations.append(hierarchy.replace('-', ''))
82
+
83
  # Add plural forms
84
  if not hierarchy.endswith('s'):
85
  variations.append(hierarchy + 's')
86
+
87
  # Add common fashion terms
88
  fashion_terms = {
89
  'shirt': ['shirt', 'shirts', 'tee', 't-shirt', 'tshirt'],
 
108
  'glove': ['glove', 'gloves'],
109
  'sandal': ['sandal', 'sandals']
110
  }
111
+
112
  # Add fashion terms if hierarchy matches
113
  for key, terms in fashion_terms.items():
114
  if key in hierarchy.lower():
115
  variations.extend(terms)
116
+
117
  # Create regex pattern
118
  pattern = r'\b(' + '|'.join(re.escape(v) for v in variations) + r')\b'
119
  patterns[hierarchy] = pattern
120
+
121
  return patterns
122
+
123
  def extract_hierarchy(self, text):
124
  """
125
  Extract hierarchy category from text using pattern matching.
126
+
127
  Args:
128
  text: Input text string
129
+
130
  Returns:
131
  Hierarchy class name if found, None otherwise
132
  """
133
  text_lower = text.lower()
134
+
135
  # Try exact match first
136
  for hierarchy in self.hierarchy_classes:
137
  if hierarchy.lower() in text_lower:
138
  return hierarchy
139
+
140
  # Try pattern matching
141
  for hierarchy, pattern in self.patterns.items():
142
  if re.search(pattern, text_lower):
143
  return hierarchy
144
+
145
+ # If no match found, return None
146
  return None
147
+
148
  def extract_hierarchy_idx(self, text):
149
  """
150
  Extract hierarchy index from text.
151
+
152
  Args:
153
  text: Input text string
154
+
155
  Returns:
156
  Hierarchy index if found, None otherwise
157
  """
 
159
  if hierarchy:
160
  return self.class_to_idx[hierarchy]
161
  return None
162
+
163
  def get_hierarchy_embedding(self, text, embed_dim=config.hierarchy_emb_dim):
164
  """
165
  Create embedding from hierarchy index extracted from text.
166
+
167
  Args:
168
  text: Input text string
169
  embed_dim: Dimension of the embedding (default: hierarchy_emb_dim)
170
+
171
  Returns:
172
  Embedding tensor of shape (embed_dim,)
173
  """
 
186
  return torch.zeros(embed_dim)
187
 
188
  # -------------------------
189
+ # 2) Models
190
  # -------------------------
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  class HierarchyClassifierHead(nn.Module):
193
  """
194
  Classifier head for hierarchy classification.
195
+
196
  Multi-layer perceptron that takes embeddings as input and outputs
197
  classification logits for hierarchy classes.
198
  """
199
+
200
  def __init__(self, in_dim, num_classes, hidden_dim=None, dropout=0.3):
201
  """
202
  Initialize the hierarchy classifier head.
203
+
204
  Args:
205
  in_dim: Input embedding dimension
206
  num_classes: Number of hierarchy classes
 
210
  super().__init__()
211
  if hidden_dim is None:
212
  hidden_dim = max(in_dim // 2, num_classes * 2)
213
+
214
  self.classifier = nn.Sequential(
215
  nn.Linear(in_dim, hidden_dim),
216
  nn.ReLU(inplace=True),
 
220
  nn.Dropout(dropout * 0.5),
221
  nn.Linear(hidden_dim // 2, num_classes)
222
  )
223
+
224
  def forward(self, x):
225
  """
226
  Forward pass through the classifier head.
227
+
228
  Args:
229
  x: Input embeddings [batch_size, in_dim]
230
+
231
  Returns:
232
  Classification logits [batch_size, num_classes]
233
  """
234
  return self.classifier(x)
235
 
236
+
237
class HierarchyModel(nn.Module):
    """
    Hierarchy model: frozen CLIP encoders + trainable MLP projections.

    Replaces ResNet18 image encoder and discrete embedding lookup with CLIP's
    full encoders. CLIP features are mapped by two small MLPs (one per
    modality) to an ``embed_dim``-dimensional, L2-normalized embedding, and a
    classifier head per modality predicts the hierarchy class from it.

    Only the projections and the classifier heads are trainable; the CLIP
    backbone has ``requires_grad=False`` and is kept in eval mode.
    """

    CLIP_MODEL_NAME = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

    def __init__(self, num_hierarchy_classes: int, embed_dim: int = config.hierarchy_emb_dim,
                 clip_model_name: str | None = None, dropout: float = 0.3):
        """
        Build the model.

        Args:
            num_hierarchy_classes: Number of hierarchy classes to predict
            embed_dim: Dimension of the projected embedding space
            clip_model_name: Name passed to ``from_pretrained``; defaults to
                ``CLIP_MODEL_NAME`` when None
            dropout: Dropout rate used in the projections and classifier heads
        """
        super().__init__()
        # Local import: transformers is only needed once a model is built.
        from transformers import CLIPModel as _CLIPModel, CLIPProcessor as _CLIPProc

        self.embed_dim = embed_dim
        self.num_hierarchy_classes = num_hierarchy_classes
        self.clip_model_name = clip_model_name or self.CLIP_MODEL_NAME

        # Frozen CLIP backbone
        self.clip = _CLIPModel.from_pretrained(self.clip_model_name)
        self.processor = _CLIPProc.from_pretrained(self.clip_model_name)
        for p in self.clip.parameters():
            p.requires_grad = False
        self.clip.eval()

        clip_dim = self.clip.config.projection_dim  # 512 for the default ViT-B/32

        # Trainable MLP projections (same architecture for both modalities)
        self.image_projection = self._make_projection(clip_dim, embed_dim, dropout)
        self.text_projection = self._make_projection(clip_dim, embed_dim, dropout)

        # Classification heads
        self.hierarchy_head_img = HierarchyClassifierHead(embed_dim, num_hierarchy_classes, dropout=dropout)
        self.hierarchy_head_txt = HierarchyClassifierHead(embed_dim, num_hierarchy_classes, dropout=dropout)

        # Will be set after init
        self.hierarchy_extractor = None

    @staticmethod
    def _make_projection(in_dim: int, out_dim: int, dropout: float) -> nn.Sequential:
        """Build the Linear-ReLU-Dropout-Linear-LayerNorm projection MLP."""
        return nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, out_dim),
            nn.LayerNorm(out_dim),
        )

    def train(self, mode: bool = True):
        """
        Switch the trainable modules between train and eval mode.

        The frozen CLIP backbone is always put back into eval mode so that
        ``model.train()`` cannot change how the backbone computes features.
        """
        super().train(mode)
        self.clip.eval()
        return self

    def set_hierarchy_extractor(self, hierarchy_extractor):
        """Attach the HierarchyExtractor used alongside this model."""
        self.hierarchy_extractor = hierarchy_extractor

    # ------ forward ------

    def _clip_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return CLIP image features for preprocessed pixel values (no gradients)."""
        with torch.no_grad():
            return self.clip.get_image_features(pixel_values=pixel_values)

    def _clip_text_features(self, texts: list[str]) -> torch.Tensor:
        """Tokenize raw strings and return CLIP text features (no gradients)."""
        device = next(self.parameters()).device
        with torch.no_grad():
            inputs = self.processor(text=texts, padding=True, truncation=True, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            return self.clip.get_text_features(**inputs)

    def forward(self, pixel_values: torch.Tensor | None = None,
                texts: list[str] | None = None):
        """
        Forward pass. Accepts images and/or raw text strings.

        Args:
            pixel_values: Preprocessed image batch, or None to skip the image branch
            texts: List of raw text strings, or None to skip the text branch

        Returns:
            Dict with ``z_img`` / ``hierarchy_logits_img`` when images are given
            and ``z_txt`` / ``hierarchy_logits_txt`` when texts are given. The
            dict is empty when both inputs are None.
        """
        out = {}
        if pixel_values is not None:
            img_feat = self._clip_image_features(pixel_values)
            z_img = F.normalize(self.image_projection(img_feat), p=2, dim=-1)
            out["z_img"] = z_img
            out["hierarchy_logits_img"] = self.hierarchy_head_img(z_img)

        if texts is not None:
            txt_feat = self._clip_text_features(texts)
            z_txt = F.normalize(self.text_projection(txt_feat), p=2, dim=-1)
            out["z_txt"] = z_txt
            out["hierarchy_logits_txt"] = self.hierarchy_head_txt(z_txt)

        return out

    # ------ API expected by main_model.py ------

    def get_text_embeddings(self, texts) -> torch.Tensor:
        """Returns [B, embed_dim] from text strings (hierarchy labels or descriptions)."""
        if isinstance(texts, str):
            texts = [texts]
        with torch.no_grad():
            txt_feat = self._clip_text_features(texts)
            return F.normalize(self.text_projection(txt_feat), p=2, dim=-1)

    def get_image_embeddings(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Returns [B, embed_dim] from preprocessed pixel_values."""
        # A single image without a batch axis is promoted to a batch of one.
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)
        with torch.no_grad():
            img_feat = self._clip_image_features(pixel_values)
            return F.normalize(self.image_projection(img_feat), p=2, dim=-1)

    # ------ serialization ------

    def save_checkpoint(self, path: str, hierarchy_classes: list[str], epoch: int = 0):
        """
        Save the trainable parts of the model.

        The CLIP backbone is not stored; only its name is, so that
        ``from_checkpoint`` can reload it with ``from_pretrained``.

        Args:
            path: Destination file
            hierarchy_classes: Class names matching the classifier outputs
            epoch: Epoch number recorded in the checkpoint
        """
        torch.save({
            "model_version": "v2",
            "embedding_dim": self.embed_dim,
            "clip_model_name": self.clip_model_name,
            "hierarchy_classes": hierarchy_classes,
            "epoch": epoch,
            "image_projection": self.image_projection.state_dict(),
            "text_projection": self.text_projection.state_dict(),
            "hierarchy_head_img": self.hierarchy_head_img.state_dict(),
            "hierarchy_head_txt": self.hierarchy_head_txt.state_dict(),
        }, path)

    @classmethod
    def from_checkpoint(cls, path: str, device: torch.device | str = "cpu"):
        """
        Rebuild a model from a file written by ``save_checkpoint``.

        Args:
            path: Checkpoint file
            device: Device the model is moved to

        Returns:
            HierarchyModel in eval mode with a HierarchyExtractor attached
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (or pass weights_only=True where supported).
        ckpt = torch.load(path, map_location=device)
        hierarchy_classes = ckpt["hierarchy_classes"]
        model = cls(
            num_hierarchy_classes=len(hierarchy_classes),
            embed_dim=ckpt["embedding_dim"],
            clip_model_name=ckpt.get("clip_model_name", cls.CLIP_MODEL_NAME),
        )
        model.image_projection.load_state_dict(ckpt["image_projection"])
        model.text_projection.load_state_dict(ckpt["text_projection"])
        model.hierarchy_head_img.load_state_dict(ckpt["hierarchy_head_img"])
        model.hierarchy_head_txt.load_state_dict(ckpt["hierarchy_head_txt"])
        # Set up hierarchy extractor
        extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
        model.set_hierarchy_extractor(extractor)
        model.to(device)
        model.eval()
        return model
373
+
374
+
375
+ # -------------------------
376
+ # 3) Datasets
377
+ # -------------------------
378
+
379
class HierarchyDataset(Dataset):
    """
    Dataset for HierarchyModel -- CLIP-preprocessed images + raw text.

    Each item is ``(pixel_values, text, hierarchy)``. Samples whose image
    cannot be loaded are returned as None so the collate function can drop
    them.
    """

    def __init__(self, dataframe, processor, hierarchy_extractor):
        """
        Args:
            dataframe: DataFrame holding the image path, text and hierarchy
                columns named in ``config``
            processor: Callable used as ``processor(images=img, return_tensors="pt")``
            hierarchy_extractor: Stored on the instance; not used by this class itself
        """
        self.df = dataframe.reset_index(drop=True)
        self.processor = processor
        self.hierarchy_extractor = hierarchy_extractor

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Looked up outside the try block: a wrong column name is a
        # configuration error and must not be reported as an unreadable image.
        image_path = row[config.column_local_image_path]
        try:
            # The context manager closes the file handle; convert() returns
            # an independent in-memory image.
            with Image.open(image_path) as raw_img:
                img = raw_img.convert("RGB")
        except Exception:
            # Deliberate best-effort: a missing or corrupt image skips the
            # sample instead of aborting the epoch.
            return None
        pixel_values = self.processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)
        text = str(row[config.text_column])
        hierarchy = str(row[config.hierarchy_column])
        return pixel_values, text, hierarchy
400
+
401
+
402
def collate_fn(batch):
    """
    Collate for HierarchyDataset.

    Drops samples that are None, stacks the image tensors along a new first
    axis and gathers texts and hierarchy labels into lists. Returns None when
    no usable sample is left.
    """
    pixel_tensors, captions, labels = [], [], []
    for sample in batch:
        if sample is None:
            continue
        pixels, caption, label = sample
        pixel_tensors.append(pixels)
        captions.append(caption)
        labels.append(label)

    if not pixel_tensors:
        return None
    return torch.stack(pixel_tensors, 0), captions, labels
409
+
410
+
411
class PrecomputedHierarchyDataset(Dataset):
    """
    Dataset using pre-computed CLIP features for fast hierarchy training.

    Image features are looked up by image path and text features by hierarchy
    label; a sample with a missing feature is returned as None.
    """

    def __init__(self, image_paths, hierarchies, image_features, text_features):
        self.image_paths = image_paths
        self.hierarchies = hierarchies
        self.image_features = image_features
        self.text_features = text_features

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_key = self.image_paths[idx]
        label = self.hierarchies[idx]
        image_vec = self.image_features.get(image_key)
        text_vec = self.text_features.get(label)
        if image_vec is not None and text_vec is not None:
            return image_vec, text_vec, label
        return None

    @staticmethod
    def collate(batch):
        """Drop None samples and stack features; None if nothing is left."""
        usable = [sample for sample in batch if sample is not None]
        if len(usable) == 0:
            return None
        image_stack = torch.stack([sample[0] for sample in usable], 0)
        text_stack = torch.stack([sample[1] for sample in usable], 0)
        labels = [sample[2] for sample in usable]
        return image_stack, text_stack, labels
439
 
440
  # -------------------------
441
  # 4) Loss functions
 
444
  class Loss(nn.Module):
445
  """
446
  Combined loss function for hierarchy model training.
447
+
448
  Combines classification loss, contrastive loss, and consistency loss
449
  to learn aligned embeddings while maintaining classification accuracy.
450
  """
451
+
452
+ def __init__(self, hierarchy_classes, classification_weight=1.0,
453
+ consistency_weight=0.3, contrastive_weight=0.2,
454
  temperature=0.07, label_smoothing=0.1):
455
  """
456
  Initialize the loss function.
457
+
458
  Args:
459
  hierarchy_classes: List of hierarchy class names
460
  classification_weight: Weight for classification loss (default: 1.0)
 
468
  self.consistency_weight = consistency_weight
469
  self.contrastive_weight = contrastive_weight
470
  self.temperature = temperature
471
+
472
  self.hierarchy_classes = sorted(list(set(hierarchy_classes)))
473
  self.num_classes = len(self.hierarchy_classes)
474
  self.class_to_idx = {cls: idx for idx, cls in enumerate(self.hierarchy_classes)}
475
+
476
  # Loss functions with label smoothing
477
  self.ce = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
478
  self.mse = nn.MSELoss()
479
+
480
  def contrastive_loss(self, img_emb, txt_emb):
481
  """
482
  InfoNCE contrastive loss for aligning image and text embeddings.
483
+
484
  Args:
485
  img_emb: Image embeddings [batch_size, embed_dim]
486
  txt_emb: Text embeddings [batch_size, embed_dim]
487
+
488
  Returns:
489
  Contrastive loss value
490
  """
491
  sim_matrix = torch.matmul(img_emb, txt_emb.T) / self.temperature
492
  labels = torch.arange(img_emb.size(0), device=img_emb.device)
493
+
494
  loss_i2t = F.cross_entropy(sim_matrix, labels)
495
  loss_t2i = F.cross_entropy(sim_matrix.T, labels)
496
+
497
  return (loss_i2t + loss_t2i) / 2
498
+
499
  def forward(self, img_logits, txt_logits, img_embeddings, txt_embeddings, target_hierarchies):
500
  """
501
  Forward pass through the loss function.
502
+
503
  Args:
504
  img_logits: Image classification logits [batch_size, num_classes]
505
  txt_logits: Text classification logits [batch_size, num_classes]
506
  img_embeddings: Image embeddings [batch_size, embed_dim]
507
  txt_embeddings: Text embeddings [batch_size, embed_dim]
508
  target_hierarchies: List of target hierarchy class names [batch_size]
509
+
510
  Returns:
511
  Combined loss value
512
  """
513
  device = img_embeddings.device
514
+
515
  # Convert hierarchy names to indices
516
  target_classes = torch.tensor([
517
  self.class_to_idx.get(hierarchy, 0) for hierarchy in target_hierarchies
518
  ], device=device)
519
+
520
  # 1. Classification loss
521
+ classification_loss = (self.ce(img_logits, target_classes) +
522
  self.ce(txt_logits, target_classes)) / 2
523
+
524
  # 2. Contrastive loss for alignment
525
  contrastive_loss = self.contrastive_loss(img_embeddings, txt_embeddings)
526
+
527
  # 3. Consistency loss between modalities
528
  consistency_loss = self.mse(img_embeddings, txt_embeddings)
529
+
530
  # Combined loss
531
  total_loss = (self.classification_weight * classification_loss +
532
  self.contrastive_weight * contrastive_loss +
533
  self.consistency_weight * consistency_loss)
534
+
535
  return total_loss
536
 
537
  # -------------------------
538
+ # 5) Training utilities
539
  # -------------------------
540
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541
def calculate_accuracy(logits, target_hierarchies, hierarchy_classes):
    """
    Calculate classification accuracy.

    Args:
        logits: Classification logits [batch_size, num_classes]
        target_hierarchies: List of target hierarchy class names [batch_size]
        hierarchy_classes: List of hierarchy class names

    Returns:
        Accuracy score (float between 0 and 1); 0.0 for an empty batch
    """
    batch_size = logits.size(0)
    if batch_size == 0:
        # Nothing to score; avoids a division by zero below.
        return 0.0

    num_classes = len(hierarchy_classes)
    pred_indices = torch.argmax(logits, dim=1).tolist()

    correct = 0
    for i, pred_idx in enumerate(pred_indices):
        # An index beyond the class list maps to "" and so only matches an
        # empty target name.
        pred_class = hierarchy_classes[pred_idx] if pred_idx < num_classes else ""
        if pred_class == target_hierarchies[i]:
            correct += 1

    return correct / batch_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
 
565
  # -------------------------
566
  # 6) Main training script
567
  # -------------------------
568
 
569
def _build_projection(in_dim, out_dim, dropout):
    """Build the Linear-ReLU-Dropout-Linear-LayerNorm projection MLP."""
    return nn.Sequential(
        nn.Linear(in_dim, 128),
        nn.ReLU(inplace=True),
        nn.Dropout(dropout),
        nn.Linear(128, out_dim),
        nn.LayerNorm(out_dim),
    )


def _embed_and_classify(features, projection, head):
    """Project features, L2-normalize the result and run the classifier head."""
    z = F.normalize(projection(features), p=2, dim=-1)
    return z, head(z)


def train_hierarchy(batch_size=256, lr=5e-4, epochs=30, val_split=0.2,
                    dropout=0.3, weight_decay=1e-3, patience=10):
    """
    Train HierarchyModel using pre-computed CLIP features (fast).

    Only the projection MLPs and classifier heads are built and trained here;
    CLIP itself is never loaded. The best checkpoint (lowest validation loss)
    is written to ``config.hierarchy_model_path`` using the same keys as
    ``HierarchyModel.save_checkpoint``.

    Args:
        batch_size: Mini-batch size for training and validation
        lr: Initial AdamW learning rate
        epochs: Maximum number of epochs
        val_split: Fraction of samples held out for validation
        dropout: Dropout rate of the projections and classifier heads
        weight_decay: AdamW weight decay
        patience: Number of epochs without validation improvement before
            training stops early

    Returns:
        None. Returns early when the feature files are missing or no sample
        survives filtering.
    """
    from pathlib import Path

    device = config.device
    print(f"Starting HierarchyModel training on device: {device}")

    path_col = config.column_local_image_path
    label_col = config.hierarchy_column

    # Load pre-computed features
    feat_dir = Path(config.local_dataset_path).parent
    img_feat_path = feat_dir / "clip_image_features.pt"
    txt_feat_path = feat_dir / "clip_text_features.pt"

    if not img_feat_path.exists() or not txt_feat_path.exists():
        print("Pre-computed features not found. Run data/precompute_clip_features.py first.")
        return

    print("Loading pre-computed CLIP features...")
    image_features = torch.load(img_feat_path, map_location="cpu")
    text_features = torch.load(txt_feat_path, map_location="cpu")
    print(f" Image features: {len(image_features)}, Text features: {len(text_features)}")

    # Load data and keep only rows that have both features available
    df = pd.read_csv(config.local_dataset_path)
    df = df.dropna(subset=[path_col, label_col])
    df = df[df[path_col].isin(image_features.keys())]
    df = df[df[label_col].isin(text_features.keys())]
    # Filter out classes with fewer than 2 samples (required for stratified split)
    class_counts = df[label_col].value_counts()
    valid_classes = class_counts[class_counts >= 2].index
    df = df[df[label_col].isin(valid_classes)]
    df = df.reset_index(drop=True)
    print(f"Valid samples: {len(df)} (dropped {len(class_counts) - len(valid_classes)} singleton classes)")

    if df.empty:
        # train_test_split would fail with a less helpful message.
        print("No valid samples left after filtering; nothing to train.")
        return

    hierarchy_classes = sorted(df[label_col].unique().tolist())
    num_classes = len(hierarchy_classes)
    print(f"Hierarchy classes ({num_classes}): {hierarchy_classes}")

    # Split
    train_df, val_df = train_test_split(
        df, test_size=val_split, random_state=42,
        stratify=df[label_col],
    )
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)
    print(f"Train: {len(train_df)}, Val: {len(val_df)}")

    train_ds = PrecomputedHierarchyDataset(
        train_df[path_col].tolist(),
        train_df[label_col].tolist(),
        image_features, text_features,
    )
    val_ds = PrecomputedHierarchyDataset(
        val_df[path_col].tolist(),
        val_df[label_col].tolist(),
        image_features, text_features,
    )
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          collate_fn=PrecomputedHierarchyDataset.collate, num_workers=0)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                        collate_fn=PrecomputedHierarchyDataset.collate, num_workers=0)

    # Trainable modules. The input width is read from the stored features
    # instead of being hard-coded to 512.
    img_dim = image_features[df[path_col].iloc[0]].shape[-1]
    txt_dim = text_features[df[label_col].iloc[0]].shape[-1]
    emb_dim = config.hierarchy_emb_dim

    image_proj = _build_projection(img_dim, emb_dim, dropout).to(device)
    text_proj = _build_projection(txt_dim, emb_dim, dropout).to(device)
    head_img = HierarchyClassifierHead(emb_dim, num_classes, dropout=dropout).to(device)
    head_txt = HierarchyClassifierHead(emb_dim, num_classes, dropout=dropout).to(device)
    modules = (image_proj, text_proj, head_img, head_txt)

    trainable_params = [p for module in modules for p in module.parameters()]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    loss_fn = Loss(hierarchy_classes, classification_weight=1.0,
                   consistency_weight=0.3, contrastive_weight=0.2,
                   label_smoothing=0.1).to(device)

    best_val_loss = float("inf")
    patience_counter = 0
    checkpoint_written = False

    for epoch in range(epochs):
        for module in modules:
            module.train()

        train_loss_sum, train_batches = 0.0, 0
        for batch in tqdm(train_dl, desc=f"Epoch {epoch+1}/{epochs} train", leave=False):
            if batch is None:
                continue
            img_feat, txt_feat, hierarchies = batch
            img_feat, txt_feat = img_feat.to(device), txt_feat.to(device)

            optimizer.zero_grad()
            z_img, logits_img = _embed_and_classify(img_feat, image_proj, head_img)
            z_txt, logits_txt = _embed_and_classify(txt_feat, text_proj, head_txt)

            loss = loss_fn(logits_img, logits_txt, z_img, z_txt, hierarchies)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
            optimizer.step()

            train_loss_sum += loss.item()
            train_batches += 1

        scheduler.step()
        avg_train = train_loss_sum / max(train_batches, 1)

        # Validate
        for module in modules:
            module.eval()
        val_loss_sum, val_batches = 0.0, 0
        val_correct_img, val_correct_txt, val_total = 0, 0, 0

        with torch.no_grad():
            for batch in val_dl:
                if batch is None:
                    continue
                img_feat, txt_feat, hierarchies = batch
                img_feat, txt_feat = img_feat.to(device), txt_feat.to(device)

                z_img, logits_img = _embed_and_classify(img_feat, image_proj, head_img)
                z_txt, logits_txt = _embed_and_classify(txt_feat, text_proj, head_txt)

                loss = loss_fn(logits_img, logits_txt, z_img, z_txt, hierarchies)
                val_loss_sum += loss.item()
                val_batches += 1

                # Per-batch accuracies are re-weighted by batch size so the
                # epoch accuracy is a per-sample average.
                acc_img = calculate_accuracy(logits_img, hierarchies, hierarchy_classes)
                acc_txt = calculate_accuracy(logits_txt, hierarchies, hierarchy_classes)
                val_correct_img += acc_img * len(hierarchies)
                val_correct_txt += acc_txt * len(hierarchies)
                val_total += len(hierarchies)

        avg_val = val_loss_sum / max(val_batches, 1)
        vacc_img = val_correct_img / max(val_total, 1)
        vacc_txt = val_correct_txt / max(val_total, 1)

        print(f"Epoch {epoch+1}/{epochs} train={avg_train:.4f} val={avg_val:.4f} "
              f"img_acc={vacc_img:.3f} txt_acc={vacc_txt:.3f}")

        if avg_val < best_val_loss:
            best_val_loss = avg_val
            patience_counter = 0
            torch.save({
                "model_version": "v2",
                "embedding_dim": emb_dim,
                "clip_model_name": HierarchyModel.CLIP_MODEL_NAME,
                "hierarchy_classes": hierarchy_classes,
                "epoch": epoch + 1,
                "image_projection": image_proj.state_dict(),
                "text_projection": text_proj.state_dict(),
                "hierarchy_head_img": head_img.state_dict(),
                "hierarchy_head_txt": head_txt.state_dict(),
            }, config.hierarchy_model_path)
            checkpoint_written = True
            print(f" -> Saved best model (val_loss={avg_val:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")
    if checkpoint_written:
        # Only claim a saved model when a checkpoint was actually written
        # (e.g. not when the validation loss was never finite).
        print(f"Model saved to: {config.hierarchy_model_path}")
752
+
753
+
754
# Script entry point: run the training loop with its default hyperparameters.
if __name__ == "__main__":
    train_hierarchy()