Leacb4 commited on
Commit
22ab4d2
·
verified ·
1 Parent(s): 1b41378

Delete models/hierarchy_model.py

Browse files
Files changed (1) hide show
  1. models/hierarchy_model.py +0 -1085
models/hierarchy_model.py DELETED
@@ -1,1085 +0,0 @@
1
- """
2
- Hierarchy model for learning clothing category-aligned embeddings.
3
- This file contains the hierarchy model that learns to encode images and texts
4
- in an embedding space specialized for representing clothing categories (dress, shirt, etc.).
5
- It includes a regex pattern-based hierarchy extractor, a ResNet image encoder,
6
- a hierarchy embedding encoder, and loss functions for training.
7
- """
8
-
9
- import pandas as pd
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
- from torch.utils.data import Dataset, DataLoader
14
- from torchvision import transforms, models
15
- from PIL import Image
16
- from tqdm import tqdm
17
- from sklearn.model_selection import train_test_split
18
- import re
19
- import requests
20
- from io import BytesIO
21
- import config
22
-
23
- # -------------------------
24
- # 1) Dataset
25
- # -------------------------
26
-
27
class HierarchyDataset(Dataset):
    """
    Dataset for hierarchy-embedding training.

    Loads images from local paths, inline byte dicts, or remote URLs, and
    returns (image_tensor, description_text, hierarchy_label) triples, with
    light augmentation applied when in training mode.
    """

    def __init__(self, dataframe, use_local_images=True, image_size=224):
        """
        Initialize the hierarchy dataset.

        Args:
            dataframe: DataFrame with image path/URL, text, and hierarchy columns.
            use_local_images: Prefer local image files over URLs (default: True).
            image_size: Side length images are resized to (default: 224).
        """
        self.dataframe = dataframe
        self.use_local_images = use_local_images
        self.image_size = image_size
        # Fix: initialize explicitly — the original never set this in
        # __init__ and used hasattr() checks in __getitem__ instead.
        # Training transforms remain the default, matching old behavior.
        self.training_mode = True

        # ImageNet statistics (the image backbone is ImageNet-pretrained).
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        # Training transforms with data augmentation.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.3),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            normalize,
        ])

        # Validation transforms (no augmentation).
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            normalize,
        ])

        # Check local image availability; fall back to URLs when the
        # expected column is missing from the dataframe.
        if use_local_images:
            if config.column_local_image_path not in dataframe.columns:
                print(f"⚠️ Column {config.column_local_image_path} not found. Using URLs.")
                self.use_local_images = False
            else:
                local_available = dataframe[config.column_local_image_path].notna().sum()
                total = len(dataframe)
                print(f"📁 Local images available: {local_available}/{total} ({local_available/total*100:.1f}%)")

    def set_training_mode(self, training=True):
        """
        Switch between training (augmented) and validation transforms.

        Args:
            training: If True, use training transforms with augmentation;
                if False, use validation transforms.
        """
        self.training_mode = training

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx: Index of the sample.

        Returns:
            Tuple of (image_tensor, description_text, hierarchy_label).
        """
        row = self.dataframe.iloc[idx]

        # Image source priority: local file, then inline byte dict, then URL.
        if self.use_local_images and pd.notna(row.get(config.column_local_image_path, '')):
            image = Image.open(row[config.column_local_image_path]).convert("RGB")
        elif isinstance(row[config.column_url_image], dict):
            image = Image.open(BytesIO(row[config.column_url_image]['bytes'])).convert('RGB')
        else:
            image = self._download_image(row[config.column_url_image])

        # Fix: plain attribute read replaces the original hasattr() dance,
        # now that training_mode is always initialized.
        image = self.transform(image) if self.training_mode else self.val_transform(image)

        return image, row[config.text_column], row[config.hierarchy_column]

    def _download_image(self, img_url):
        """
        Download an image from a URL with a 10-second timeout.

        Args:
            img_url: URL of the image to download.

        Returns:
            PIL Image object in RGB mode.

        Raises:
            requests.HTTPError: on a non-2xx HTTP response.
        """
        response = requests.get(img_url, timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
137
-
138
- # -------------------------
139
- # 2) Hierarchy Extractor
140
- # -------------------------
141
-
142
class HierarchyExtractor:
    """
    Extract hierarchy categories from text using regex pattern matching.

    Identifies clothing categories (e.g. shirt, dress) in free text,
    handling name variations, plural forms, and common fashion synonyms.
    """

    # Common fashion synonyms keyed by a substring of the hierarchy name.
    # Fix: hoisted to a class attribute — the original rebuilt this dict on
    # every iteration of the per-class loop in _create_patterns().
    FASHION_TERMS = {
        'shirt': ['shirt', 'shirts', 'tee', 't-shirt', 'tshirt'],
        'jacket': ['jacket', 'jackets', 'coat', 'coats'],
        'pant': ['pant', 'pants', 'trouser', 'trousers', 'jean', 'jeans'],
        'dress': ['dress', 'dresses'],
        'skirt': ['skirt', 'skirts'],
        'shoe': ['shoe', 'shoes', 'boot', 'boots', 'sneaker', 'sneakers'],
        'bag': ['bag', 'bags', 'handbag', 'handbags', 'purse', 'purses'],
        'hat': ['hat', 'hats', 'cap', 'caps'],
        'scarf': ['scarf', 'scarves'],
        'belt': ['belt', 'belts'],
        'sock': ['sock', 'socks'],
        'underwear': ['underwear', 'underpant', 'underpants'],
        'sweater': ['sweater', 'sweaters', 'jumper', 'jumpers'],
        'blouse': ['blouse', 'blouses'],
        'vest': ['vest', 'vests'],
        'short': ['short', 'shorts'],
        'legging': ['legging', 'leggings'],
        'suit': ['suit', 'suits'],
        'tie': ['tie', 'ties'],
        'glove': ['glove', 'gloves'],
        'sandal': ['sandal', 'sandals']
    }

    def __init__(self, hierarchy_classes, verbose=False):
        """
        Initialize the hierarchy extractor.

        Args:
            hierarchy_classes: List of hierarchy class names.
            verbose: Whether to print initialization information (default: False).
        """
        self.hierarchy_classes = sorted(hierarchy_classes)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.hierarchy_classes)}
        self.idx_to_class = {idx: cls for idx, cls in enumerate(self.hierarchy_classes)}

        # Pre-compile one alternation pattern per hierarchy class.
        self.patterns = self._create_patterns()

        if verbose:
            print(f"🎯 Hierarchy extractor initialized with {len(self.hierarchy_classes)} classes")
            print(f"📋 Classes: {self.hierarchy_classes}")

    def _create_patterns(self):
        """
        Create a word-boundary regex pattern for each hierarchy class.

        Each pattern matches the class name plus hyphen variants, a plural
        form, and any FASHION_TERMS synonyms whose key occurs in the name.

        Returns:
            Dictionary mapping hierarchy classes to regex pattern strings.
        """
        patterns = {}

        for hierarchy in self.hierarchy_classes:
            variations = [hierarchy.lower()]

            # Hyphen variants: "t-shirt" -> "t shirt", "tshirt".
            if '-' in hierarchy:
                variations.append(hierarchy.replace('-', ' '))
                variations.append(hierarchy.replace('-', ''))

            # Naive plural (skip names already ending in 's').
            if not hierarchy.endswith('s'):
                variations.append(hierarchy + 's')

            # Add synonyms for any fashion term contained in the class name.
            for key, terms in self.FASHION_TERMS.items():
                if key in hierarchy.lower():
                    variations.extend(terms)

            patterns[hierarchy] = r'\b(' + '|'.join(re.escape(v) for v in variations) + r')\b'

        return patterns

    def extract_hierarchy(self, text):
        """
        Extract a hierarchy category from text.

        Tries a simple substring match against each class name first, then
        falls back to the pre-built regex patterns.

        Args:
            text: Input text string.

        Returns:
            Hierarchy class name if found, None otherwise.
        """
        text_lower = text.lower()

        # Exact (substring) match first.
        for hierarchy in self.hierarchy_classes:
            if hierarchy.lower() in text_lower:
                return hierarchy

        # Pattern matching on variants and synonyms.
        for hierarchy, pattern in self.patterns.items():
            if re.search(pattern, text_lower):
                return hierarchy

        return None

    def extract_hierarchy_idx(self, text):
        """
        Extract a hierarchy class index from text.

        Args:
            text: Input text string.

        Returns:
            Hierarchy index if found, None otherwise.
        """
        hierarchy = self.extract_hierarchy(text)
        if hierarchy:
            return self.class_to_idx[hierarchy]
        return None

    def get_hierarchy_embedding(self, text, embed_dim=None):
        """
        Create a sparse embedding from the hierarchy extracted from text.

        Args:
            text: Input text string.
            embed_dim: Dimension of the embedding. Defaults to
                config.hierarchy_emb_dim (fix: resolved at call time instead
                of at class-definition time, so importing this class no
                longer requires config to be loaded first).

        Returns:
            Embedding tensor of shape (embed_dim,); all-zero when no
            hierarchy can be extracted.
        """
        if embed_dim is None:
            embed_dim = config.hierarchy_emb_dim

        hierarchy_idx = self.extract_hierarchy_idx(text)
        if hierarchy_idx is not None:
            # Sparse encoding: three decaying values at a class-dependent offset.
            embedding = torch.zeros(embed_dim)
            start_idx = (hierarchy_idx * 3) % embed_dim
            embedding[start_idx] = 1.0
            embedding[(start_idx + 1) % embed_dim] = 0.5
            embedding[(start_idx + 2) % embed_dim] = 0.3
            return embedding
        else:
            # Zero embedding for unknown hierarchy.
            return torch.zeros(embed_dim)
294
-
295
- # -------------------------
296
- # 3) Models
297
- # -------------------------
298
-
299
class PretrainedImageEncoder(nn.Module):
    """
    ResNet18-based image encoder producing embed_dim embeddings.

    Replaces the pretrained backbone's classification layer with a dropout-
    regularized projection head, and freezes roughly the first 70% of the
    backbone's children to limit overfitting.
    """

    def __init__(self, embed_dim, dropout=0.3):
        """
        Initialize the pretrained image encoder.

        Args:
            embed_dim: Dimension of the output embedding.
            dropout: Dropout rate for the projection head (default: 0.3).
        """
        super().__init__()

        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # releases; switch to `weights=models.ResNet18_Weights.DEFAULT` when
        # the torchvision dependency is upgraded.
        self.backbone = models.resnet18(pretrained=True)
        backbone_dim = 512  # ResNet18 feature width before its fc layer

        # Drop the final classification (fc) layer; keep through avgpool.
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])

        # Projection head: 512 -> 2*embed_dim -> embed_dim, LayerNorm output.
        self.projection = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(backbone_dim, embed_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LayerNorm(embed_dim)
        )

        # Fine-tune only the last few backbone layers.
        self._freeze_backbone_layers()

    def _freeze_backbone_layers(self):
        """
        Freeze the first 70% of backbone children to prevent overfitting.

        Fix: removed the original `hasattr(self.backbone, 'children')` guard,
        which is always true for any nn.Module and therefore dead code.
        """
        layers = list(self.backbone.children())
        freeze_until = int(len(layers) * 0.7)
        for layer in layers[:freeze_until]:
            for param in layer.parameters():
                param.requires_grad = False

    def forward(self, x):
        """
        Forward pass through the image encoder.

        Args:
            x: Image tensor [batch_size, channels, height, width].

        Returns:
            Image embeddings [batch_size, embed_dim].
        """
        features = self.backbone(x)
        return self.projection(features)
364
-
365
class HierarchyEncoder(nn.Module):
    """
    Encode integer hierarchy indices into the shared embedding space.

    A learned embedding table maps each index to a vector, which a small
    MLP then projects (with LayerNorm) to the final embedding dimension.
    """

    def __init__(self, num_hierarchies, embed_dim, dropout=0.3):
        """
        Build the embedding table and projection head.

        Args:
            num_hierarchies: Number of hierarchy classes.
            embed_dim: Dimension of the output embedding.
            dropout: Dropout rate for regularization (default: 0.3).
        """
        super().__init__()
        self.num_hierarchies = num_hierarchies
        self.embed_dim = embed_dim

        # Learned lookup table: one row per hierarchy class.
        self.embedding = nn.Embedding(num_hierarchies, embed_dim)

        # Two-layer MLP projection with LayerNorm output.
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LayerNorm(embed_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """Apply Xavier-uniform init to the table and every linear layer."""
        nn.init.xavier_uniform_(self.embedding.weight)
        linear_layers = [m for m in self.projection.modules()
                         if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, hierarchy_indices):
        """
        Encode a batch of hierarchy indices.

        Args:
            hierarchy_indices: Tensor of hierarchy indices [batch_size].

        Returns:
            Hierarchy embeddings [batch_size, embed_dim].

        Note:
            On MPS devices the embedding lookup is routed through the CPU
            (embedding layers are unreliable on MPS) and the result is moved
            back to the model's device afterwards.
        """
        target_device = next(self.parameters()).device
        if target_device.type != 'mps':
            looked_up = self.embedding(hierarchy_indices)
        else:
            # MPS workaround: do the lookup on CPU with an explicit weight
            # copy, then move the contiguous result back to the device.
            cpu_weight = self.embedding.weight.cpu()
            looked_up = F.embedding(hierarchy_indices.cpu(), cpu_weight)
            looked_up = looked_up.contiguous().to(target_device)
        return self.projection(looked_up)
441
-
442
class HierarchyClassifierHead(nn.Module):
    """
    MLP head that maps embeddings to hierarchy-class logits.

    Three linear layers with ReLU activations and tapering dropout.
    """

    def __init__(self, in_dim, num_classes, hidden_dim=None, dropout=0.3):
        """
        Build the classifier MLP.

        Args:
            in_dim: Input embedding dimension.
            num_classes: Number of hierarchy classes.
            hidden_dim: Hidden layer dimension; when None it scales with the
                input size and class count: max(in_dim // 2, num_classes * 2).
            dropout: Dropout rate for regularization (default: 0.3).
        """
        super().__init__()
        hidden = hidden_dim if hidden_dim is not None else max(in_dim // 2, num_classes * 2)

        stages = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5),  # lighter dropout deeper in the head
            nn.Linear(hidden // 2, num_classes),
        ]
        self.classifier = nn.Sequential(*stages)

    def forward(self, x):
        """
        Compute classification logits.

        Args:
            x: Input embeddings [batch_size, in_dim].

        Returns:
            Logits [batch_size, num_classes].
        """
        return self.classifier(x)
485
-
486
class Model(nn.Module):
    """
    Main hierarchy model for learning clothing-category-aligned embeddings.

    Combines an image encoder, a hierarchy encoder, and two classifier heads
    so that image and text embeddings align in a shared, L2-normalized space.
    """

    def __init__(self, num_hierarchy_classes, embed_dim, dropout=0.3):
        """
        Initialize the hierarchy model.

        Args:
            num_hierarchy_classes: Number of hierarchy classes.
            embed_dim: Dimension of the shared embedding space.
            dropout: Dropout rate for regularization (default: 0.3).
        """
        super().__init__()
        self.img_enc = PretrainedImageEncoder(embed_dim, dropout)
        self.hierarchy_enc = HierarchyEncoder(num_hierarchy_classes, embed_dim, dropout)
        self.hierarchy_head_img = HierarchyClassifierHead(embed_dim, num_hierarchy_classes, dropout=dropout)
        self.hierarchy_head_txt = HierarchyClassifierHead(embed_dim, num_hierarchy_classes, dropout=dropout)
        self.num_hierarchy_classes = num_hierarchy_classes

    def forward(self, image=None, hierarchy_indices=None):
        """
        Encode whichever modalities are provided.

        Args:
            image: Optional image tensor [batch_size, channels, height, width].
            hierarchy_indices: Optional hierarchy indices tensor [batch_size].

        Returns:
            Dictionary containing (only for the modalities provided):
                - 'z_img': L2-normalized image embeddings [batch_size, embed_dim]
                - 'z_txt': L2-normalized text embeddings [batch_size, embed_dim]
                - 'hierarchy_logits_img': image logits [batch_size, num_classes]
                - 'hierarchy_logits_txt': text logits [batch_size, num_classes]
        """
        out = {}
        if image is not None:
            z_img = F.normalize(self.img_enc(image), p=2, dim=1)
            out['hierarchy_logits_img'] = self.hierarchy_head_img(z_img)
            out['z_img'] = z_img

        if hierarchy_indices is not None:
            z_txt = F.normalize(self.hierarchy_enc(hierarchy_indices), p=2, dim=1)
            out['hierarchy_logits_txt'] = self.hierarchy_head_txt(z_txt)
            out['z_txt'] = z_txt

        return out

    def set_hierarchy_extractor(self, hierarchy_extractor):
        """
        Attach the HierarchyExtractor used by get_text_embeddings().

        Args:
            hierarchy_extractor: HierarchyExtractor instance.
        """
        self.hierarchy_extractor = hierarchy_extractor

    def _extract_index(self, item):
        """
        Map one text string to a hierarchy index.

        Args:
            item: A text string.

        Returns:
            Integer hierarchy index.

        Raises:
            ValueError: If item is not a string or no hierarchy matches.
        """
        if not isinstance(item, str):
            raise ValueError(f"Expected string, got {type(item)}: {item}")
        hierarchy_idx = self.hierarchy_extractor.extract_hierarchy_idx(item)
        if hierarchy_idx is None:
            raise ValueError(f"Could not extract hierarchy for text: '{item}'. Available classes: {self.hierarchy_extractor.hierarchy_classes}")
        return hierarchy_idx

    def get_text_embeddings(self, text):
        """
        Get text embeddings for a text string or a list/tuple of strings.

        Fix: the single-string and list cases previously duplicated the same
        extraction-and-raise logic; both now share one code path through
        _extract_index, with identical error messages.

        Args:
            text: Text string or list/tuple of text strings.

        Returns:
            Text embeddings tensor [batch_size, embed_dim].

        Raises:
            ValueError: If input is not str/list/tuple or extraction fails.
        """
        with torch.no_grad():
            model_device = next(self.parameters()).device

            # Normalize input to a list of strings.
            if isinstance(text, str):
                items = [text]
            elif isinstance(text, (list, tuple)):
                items = text
            else:
                raise ValueError(f"Expected string or list/tuple of strings, got {type(text)}: {text}")

            indices = torch.tensor([self._extract_index(item) for item in items],
                                   device=model_device)
            return self.forward(hierarchy_indices=indices)['z_txt']

    def get_image_embeddings(self, image):
        """
        Get image embeddings for an image tensor.

        Args:
            image: Image tensor [channels, height, width] or
                [batch_size, channels, height, width].

        Returns:
            Image embeddings tensor [batch_size, embed_dim].

        Raises:
            ValueError: If image is not a torch.Tensor.
        """
        with torch.no_grad():
            if not isinstance(image, torch.Tensor):
                raise ValueError("Image must be a torch.Tensor")

            # Move to the model's device if necessary.
            device = next(self.parameters()).device
            if image.device != device:
                image = image.to(device)

            # Promote a single image to a batch of one.
            if image.dim() == 3:
                image = image.unsqueeze(0)

            return self.forward(image=image)['z_img']
635
-
636
- # -------------------------
637
- # 4) Loss functions
638
- # -------------------------
639
-
640
class Loss(nn.Module):
    """
    Combined training objective for the hierarchy model.

    Sums three weighted terms: cross-entropy classification on both
    modalities, symmetric InfoNCE contrastive alignment, and an MSE
    consistency penalty between the two embedding sets.
    """

    def __init__(self, hierarchy_classes, classification_weight=1.0,
                 consistency_weight=0.3, contrastive_weight=0.2,
                 temperature=0.07, label_smoothing=0.1):
        """
        Set up the loss terms and the class-name-to-index mapping.

        Args:
            hierarchy_classes: List of hierarchy class names.
            classification_weight: Weight for the classification term (default: 1.0).
            consistency_weight: Weight for the consistency term (default: 0.3).
            contrastive_weight: Weight for the contrastive term (default: 0.2).
            temperature: Temperature scaling for InfoNCE (default: 0.07).
            label_smoothing: Label smoothing for cross-entropy (default: 0.1).
        """
        super().__init__()
        self.classification_weight = classification_weight
        self.consistency_weight = consistency_weight
        self.contrastive_weight = contrastive_weight
        self.temperature = temperature

        # Deduplicate and sort class names for a stable index mapping.
        self.hierarchy_classes = sorted(set(hierarchy_classes))
        self.num_classes = len(self.hierarchy_classes)
        self.class_to_idx = {name: i for i, name in enumerate(self.hierarchy_classes)}

        self.ce = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        self.mse = nn.MSELoss()

    def contrastive_loss(self, img_emb, txt_emb):
        """
        Symmetric InfoNCE loss aligning image and text embeddings.

        Args:
            img_emb: Image embeddings [batch_size, embed_dim].
            txt_emb: Text embeddings [batch_size, embed_dim].

        Returns:
            Scalar contrastive loss.
        """
        logits = img_emb @ txt_emb.T / self.temperature
        # Matching pairs sit on the diagonal of the similarity matrix.
        diagonal = torch.arange(img_emb.size(0), device=img_emb.device)
        image_to_text = F.cross_entropy(logits, diagonal)
        text_to_image = F.cross_entropy(logits.T, diagonal)
        return (image_to_text + text_to_image) / 2

    def forward(self, img_logits, txt_logits, img_embeddings, txt_embeddings, target_hierarchies):
        """
        Compute the combined loss.

        Args:
            img_logits: Image classification logits [batch_size, num_classes].
            txt_logits: Text classification logits [batch_size, num_classes].
            img_embeddings: Image embeddings [batch_size, embed_dim].
            txt_embeddings: Text embeddings [batch_size, embed_dim].
            target_hierarchies: List of target hierarchy class names [batch_size].

        Returns:
            Scalar combined loss.
        """
        device = img_embeddings.device

        # Map class names to indices; unknown names fall back to index 0.
        targets = torch.tensor(
            [self.class_to_idx.get(name, 0) for name in target_hierarchies],
            device=device,
        )

        # 1. Classification on both modalities, averaged.
        cls_term = (self.ce(img_logits, targets) + self.ce(txt_logits, targets)) / 2

        # 2. Contrastive alignment between modalities.
        align_term = self.contrastive_loss(img_embeddings, txt_embeddings)

        # 3. Direct consistency between the embedding sets.
        match_term = self.mse(img_embeddings, txt_embeddings)

        return (self.classification_weight * cls_term
                + self.contrastive_weight * align_term
                + self.consistency_weight * match_term)
732
-
733
- # -------------------------
734
- # 5) Training
735
- # -------------------------
736
-
737
def collate_fn(batch, hierarchy_extractor):
    """
    Collate function for the DataLoader: stack images and resolve hierarchy indices.

    Args:
        batch: List of (image, description, hierarchy) tuples.
        hierarchy_extractor: HierarchyExtractor instance.

    Returns:
        Dictionary containing:
            - 'image': Stacked image tensors [batch_size, channels, height, width]
            - 'hierarchy_indices': Hierarchy indices tensor [batch_size]
            - config.hierarchy_column: List of hierarchy class names [batch_size]
    """
    images = torch.stack([sample[0] for sample in batch], dim=0)
    texts = [sample[1] for sample in batch]
    hierarchies = [sample[2] for sample in batch]

    # Prefer the hierarchy extracted from the description text; fall back to
    # the sample's own target label (index 0 when the label is unknown).
    # Fix: iterate pairwise with zip instead of the original's fragile
    # `hierarchies[len(hierarchy_indices)]` running-length indexing trick.
    hierarchy_indices = []
    for text, target_hierarchy in zip(texts, hierarchies):
        idx = hierarchy_extractor.extract_hierarchy_idx(text)
        if idx is None:
            idx = hierarchy_extractor.class_to_idx.get(target_hierarchy, 0)
        hierarchy_indices.append(idx)

    return {
        'image': images,
        'hierarchy_indices': torch.tensor(hierarchy_indices, dtype=torch.long),
        config.hierarchy_column: hierarchies,
    }
774
-
775
def calculate_accuracy(logits, target_hierarchies, hierarchy_classes):
    """
    Compute classification accuracy from logits against string targets.

    Args:
        logits: Classification logits [batch_size, num_classes].
        target_hierarchies: List of target hierarchy class names [batch_size].
        hierarchy_classes: List of hierarchy class names.

    Returns:
        Fraction of correct predictions (float between 0 and 1).
    """
    batch_size = logits.size(0)
    predictions = torch.argmax(logits, dim=1).cpu().numpy()
    num_known = len(hierarchy_classes)

    correct = 0
    for position in range(batch_size):
        pred_idx = predictions[position]
        # Out-of-range predictions map to an empty label (always wrong).
        predicted_name = hierarchy_classes[pred_idx] if pred_idx < num_known else ""
        correct += int(predicted_name == target_hierarchies[position])

    return correct / batch_size
798
-
799
def train_one_epoch(model, dataloader, optimizer, device, hierarchy_classes, scheduler=None):
    """
    Train the model for one epoch.

    Args:
        model: Model instance to train.
        dataloader: DataLoader for training data.
        optimizer: Optimizer instance.
        device: Device to train on.
        hierarchy_classes: List of hierarchy class names.
        scheduler: Optional learning rate scheduler (stepped per batch).

    Returns:
        Dictionary containing training metrics:
            - 'loss': Average training loss
            - 'acc_img': Average image classification accuracy
            - 'acc_txt': Average text classification accuracy
    """
    model.train()

    # Fix: switch the dataset to training transforms once, BEFORE iterating.
    # The original toggled it inside the batch loop, after each batch had
    # already been produced, so the call never affected the current batch.
    # NOTE(review): with num_workers > 0 this flag may not propagate to
    # already-forked worker processes — confirm the DataLoader configuration.
    if hasattr(dataloader.dataset, 'set_training_mode'):
        dataloader.dataset.set_training_mode(True)

    total_loss = 0.0
    total_acc_img = 0.0
    total_acc_txt = 0.0
    num_batches = 0

    loss_fn = Loss(
        hierarchy_classes,
        classification_weight=1.0,
        consistency_weight=0.3,
        contrastive_weight=0.2,
        label_smoothing=0.1
    ).to(device)

    pbar = tqdm(dataloader, desc="Training", leave=False)
    for batch in pbar:
        images = batch['image'].to(device)
        hierarchy_indices = batch['hierarchy_indices'].to(device)
        target_hierarchies = batch[config.hierarchy_column]

        out = model(image=images, hierarchy_indices=hierarchy_indices)
        hierarchy_logits_img = out['hierarchy_logits_img']
        hierarchy_logits_txt = out['hierarchy_logits_txt']
        z_img, z_txt = out['z_img'], out['z_txt']

        loss = loss_fn(hierarchy_logits_img, hierarchy_logits_txt, z_img, z_txt, target_hierarchies)

        optimizer.zero_grad()
        loss.backward()
        # Clip gradients to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        acc_img = calculate_accuracy(hierarchy_logits_img, target_hierarchies, hierarchy_classes)
        acc_txt = calculate_accuracy(hierarchy_logits_txt, target_hierarchies, hierarchy_classes)

        total_loss += loss.item()
        total_acc_img += acc_img
        total_acc_txt += acc_txt
        num_batches += 1

        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc_img': f'{acc_img:.3f}',
            'acc_txt': f'{acc_txt:.3f}',
        })

    return {
        'loss': total_loss / num_batches,
        'acc_img': total_acc_img / num_batches,
        'acc_txt': total_acc_txt / num_batches
    }
880
-
881
def validate(model, dataloader, device, hierarchy_classes):
    """
    Validate the model on validation data.

    Args:
        model: Model instance to validate.
        dataloader: DataLoader for validation data.
        device: Device to validate on.
        hierarchy_classes: List of hierarchy class names.

    Returns:
        Dictionary containing validation metrics:
            - 'loss': Average validation loss
            - 'acc_img': Average image classification accuracy
            - 'acc_txt': Average text classification accuracy
    """
    model.eval()

    # Fix: switch the dataset to validation transforms once, BEFORE iterating.
    # The original toggled it inside the batch loop, after each batch had
    # already been produced, so the call never affected the current batch.
    if hasattr(dataloader.dataset, 'set_training_mode'):
        dataloader.dataset.set_training_mode(False)

    total_loss = 0.0
    total_acc_img = 0.0
    total_acc_txt = 0.0
    num_batches = 0

    # label_smoothing made explicit for consistency with train_one_epoch
    # (0.1 was already the default).
    loss_fn = Loss(
        hierarchy_classes,
        classification_weight=1.0,
        consistency_weight=0.3,
        contrastive_weight=0.2,
        label_smoothing=0.1
    ).to(device)

    pbar = tqdm(dataloader, desc="Validation", leave=False)
    with torch.no_grad():
        for batch in pbar:
            images = batch['image'].to(device)
            hierarchy_indices = batch['hierarchy_indices'].to(device)
            target_hierarchies = batch[config.hierarchy_column]

            out = model(image=images, hierarchy_indices=hierarchy_indices)
            hierarchy_logits_img = out['hierarchy_logits_img']
            hierarchy_logits_txt = out['hierarchy_logits_txt']
            z_img, z_txt = out['z_img'], out['z_txt']

            loss = loss_fn(hierarchy_logits_img, hierarchy_logits_txt, z_img, z_txt, target_hierarchies)

            acc_img = calculate_accuracy(hierarchy_logits_img, target_hierarchies, hierarchy_classes)
            acc_txt = calculate_accuracy(hierarchy_logits_txt, target_hierarchies, hierarchy_classes)

            total_loss += loss.item()
            total_acc_img += acc_img
            total_acc_txt += acc_txt
            num_batches += 1

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc_img': f'{acc_img:.3f}',
                'acc_txt': f'{acc_txt:.3f}',
            })

    return {
        'loss': total_loss / num_batches,
        'acc_img': total_acc_img / num_batches,
        'acc_txt': total_acc_txt / num_batches
    }
949
-
950
- # -------------------------
951
- # 6) Main training script
952
- # -------------------------
953
-
954
if __name__ == "__main__":
    # -------------------------
    # Configuration
    # -------------------------
    device = config.device
    batch_size = 16
    lr = 5e-5
    epochs = 20
    val_split = 0.2
    dropout = 0.4
    weight_decay = 1e-3

    print(f"🚀 Starting hierarchical training on device: {device}")
    print(f"📊 Config: {epochs} epochs, batch={batch_size}, lr={lr}, embed_dim={config.hierarchy_emb_dim}")

    # Load dataset
    print(f"📁 Using dataset: {config.local_dataset_path}")
    df = pd.read_csv(config.local_dataset_path)
    print(f"📁 Loaded {len(df)} samples")

    # Get unique hierarchy classes (sorted so class indices are stable
    # across runs).
    hierarchy_classes = sorted(df[config.hierarchy_column].unique().tolist())
    print(f"📋 Found {len(hierarchy_classes)} hierarchy classes")

    # Create hierarchy extractor
    hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=True)

    # Stratified train/validation split so every class appears in both sets
    train_df, val_df = train_test_split(
        df,
        test_size=val_split,
        random_state=42,
        stratify=df[config.hierarchy_column]
    )
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)

    print(f"📈 Train: {len(train_df)}, Validation: {len(val_df)}")

    # Create datasets
    train_ds = HierarchyDataset(train_df, image_size=224)
    val_ds = HierarchyDataset(val_df, image_size=224)

    # Create data loaders (custom collate injects hierarchy indices)
    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda batch: collate_fn(batch, hierarchy_extractor)
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda batch: collate_fn(batch, hierarchy_extractor)
    )

    # Create model
    model = Model(
        num_hierarchy_classes=len(hierarchy_classes),
        embed_dim=config.hierarchy_emb_dim,
        dropout=dropout
    ).to(device)

    # Optimizer and scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=lr/10)

    print(f"🎯 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print("\n" + "="*80)

    def save_checkpoint(path, epoch):
        """Persist model weights plus the metadata needed to reload them.

        Single helper so the best-model save and the periodic checkpoint
        always write the same schema (the original code duplicated this
        dict in two places).
        """
        torch.save({
            'model_state': model.state_dict(),
            'hierarchy_classes': hierarchy_classes,
            'epoch': epoch,
            'config': {
                'embed_dim': config.hierarchy_emb_dim,
                'dropout': dropout
            }
        }, path)

    # Training loop
    best_val_loss = float('inf')
    training_history = {'train_loss': [], 'val_loss': [], 'val_acc_img': [], 'val_acc_txt': []}

    for e in range(epochs):
        print(f"\n🔄 Epoch {e+1}/{epochs}")
        print("-" * 50)

        # Training
        train_metrics = train_one_epoch(model, train_dl, optimizer, device, hierarchy_classes, scheduler)

        # Validation
        val_metrics = validate(model, val_dl, device, hierarchy_classes)

        # Track history
        training_history['train_loss'].append(train_metrics['loss'])
        training_history['val_loss'].append(val_metrics['loss'])
        training_history['val_acc_img'].append(val_metrics['acc_img'])
        training_history['val_acc_txt'].append(val_metrics['acc_txt'])

        # Display results
        print(f"📊 TRAIN - Loss: {train_metrics['loss']:.6f} | "
              f"Img Acc: {train_metrics['acc_img']:.3f} | "
              f"Txt Acc: {train_metrics['acc_txt']:.3f}")

        print(f"✅ VAL - Loss: {val_metrics['loss']:.6f} | "
              f"Img Acc: {val_metrics['acc_img']:.3f} | "
              f"Txt Acc: {val_metrics['acc_txt']:.3f}")

        # Save best model
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            print(f"💾 New best validation loss! Saving model...")
            save_checkpoint(config.hierarchy_model_path, e+1)

        # Save model every 2 epochs
        if (e + 1) % 2 == 0:
            print(f"💾 Saving checkpoint at epoch {e+1}...")
            save_checkpoint(f"model_checkpoint_epoch_{e+1}.pth", e+1)

    print("\n" + "="*80)
    print("🎉 Training completed!")
    print(f"🏆 Best validation loss: {best_val_loss:.6f}")

    print(f"\n📈 Final validation accuracy: Image={training_history['val_acc_img'][-1]:.3f}, Text={training_history['val_acc_txt'][-1]:.3f}")