|
|
""" |
|
|
Hierarchy model for learning clothing category-aligned embeddings. |
|
|
This file contains the hierarchy model that learns to encode images and texts |
|
|
in an embedding space specialized for representing clothing categories (dress, shirt, etc.). |
|
|
It includes a regex pattern-based hierarchy extractor, a ResNet image encoder, |
|
|
a hierarchy embedding encoder, and loss functions for training. |
|
|
""" |
|
|
|
|
|
import pandas as pd |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torchvision import transforms, models |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
from sklearn.model_selection import train_test_split |
|
|
import re |
|
|
import requests |
|
|
from io import BytesIO |
|
|
import config |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HierarchyDataset(Dataset):
    """
    Dataset class for hierarchy embedding training.

    Handles loading images from local paths, inline bytes, or URLs, and applies
    either augmented training transforms or deterministic validation transforms
    depending on the current mode.
    """

    def __init__(self, dataframe, use_local_images=True, image_size=224):
        """
        Initialize the hierarchy dataset.

        Args:
            dataframe: DataFrame with columns for image paths/URLs, text
                descriptions, and hierarchy labels (column names come from ``config``)
            use_local_images: Whether to prefer local images over URLs (default: True)
            image_size: Size of images after resizing (default: 224)
        """
        self.dataframe = dataframe
        self.use_local_images = use_local_images
        self.image_size = image_size
        # FIX: initialize explicitly (was only ever set by set_training_mode(),
        # forcing a hasattr() check in __getitem__). Default matches the old
        # behavior: training transforms are used until told otherwise.
        self.training_mode = True

        # Training transforms: light augmentation to reduce overfitting.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.3),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            # ImageNet statistics, matching the pretrained ResNet backbone.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Validation transforms: deterministic, no augmentation.
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        if use_local_images:
            if config.column_local_image_path not in dataframe.columns:
                print(f"β οΈ Column {config.column_local_image_path} not found. Using URLs.")
                self.use_local_images = False
            else:
                local_available = dataframe[config.column_local_image_path].notna().sum()
                total = len(dataframe)
                # Guard the percentage against an empty dataframe.
                pct = (local_available / total * 100) if total else 0.0
                print(f"π Local images available: {local_available}/{total} ({pct:.1f}%)")

    def set_training_mode(self, training=True):
        """
        Switch between training and validation transforms.

        Args:
            training: If True, use training transforms with augmentation;
                if False, use validation transforms
        """
        self.training_mode = training

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx: Index of the sample

        Returns:
            Tuple of (image_tensor, description_text, hierarchy_label)
        """
        row = self.dataframe.iloc[idx]

        # Image source priority: local file, then inline bytes, then URL download.
        if self.use_local_images and pd.notna(row.get(config.column_local_image_path, '')):
            local_path = row[config.column_local_image_path]
            image = Image.open(local_path).convert("RGB")
        elif isinstance(row[config.column_url_image], dict):
            # Some datasets store the raw image inline as {'bytes': ...}.
            image = Image.open(BytesIO(row[config.column_url_image]['bytes'])).convert('RGB')
        else:
            image = self._download_image(row[config.column_url_image])

        image = self.transform(image) if self.training_mode else self.val_transform(image)

        description = row[config.text_column]
        hierarchy = row[config.hierarchy_column]

        return image, description, hierarchy

    def _download_image(self, img_url):
        """
        Download an image from a URL with a 10-second timeout.

        Args:
            img_url: URL of the image to download

        Returns:
            PIL Image object converted to RGB

        Raises:
            requests.HTTPError: If the server responds with an error status
        """
        response = requests.get(img_url, timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HierarchyExtractor:
    """
    Extract hierarchy categories directly from text using pattern matching.

    This class uses regex patterns to identify clothing categories (e.g., shirt, dress)
    from text descriptions, handling variations, plurals, and common fashion terms.
    """

    # Synonym/plural lists keyed by the substring of the class name that
    # triggers them. Hoisted to a class constant so the dict is not rebuilt
    # once per hierarchy class inside _create_patterns().
    FASHION_TERMS = {
        'shirt': ['shirt', 'shirts', 'tee', 't-shirt', 'tshirt'],
        'jacket': ['jacket', 'jackets', 'coat', 'coats'],
        'pant': ['pant', 'pants', 'trouser', 'trousers', 'jean', 'jeans'],
        'dress': ['dress', 'dresses'],
        'skirt': ['skirt', 'skirts'],
        'shoe': ['shoe', 'shoes', 'boot', 'boots', 'sneaker', 'sneakers'],
        'bag': ['bag', 'bags', 'handbag', 'handbags', 'purse', 'purses'],
        'hat': ['hat', 'hats', 'cap', 'caps'],
        'scarf': ['scarf', 'scarves'],
        'belt': ['belt', 'belts'],
        'sock': ['sock', 'socks'],
        'underwear': ['underwear', 'underpant', 'underpants'],
        'sweater': ['sweater', 'sweaters', 'jumper', 'jumpers'],
        'blouse': ['blouse', 'blouses'],
        'vest': ['vest', 'vests'],
        'short': ['short', 'shorts'],
        'legging': ['legging', 'leggings'],
        'suit': ['suit', 'suits'],
        'tie': ['tie', 'ties'],
        'glove': ['glove', 'gloves'],
        'sandal': ['sandal', 'sandals']
    }

    def __init__(self, hierarchy_classes, verbose=False):
        """
        Initialize the hierarchy extractor.

        Args:
            hierarchy_classes: List of hierarchy class names
            verbose: Whether to print initialization information (default: False)
        """
        self.hierarchy_classes = sorted(hierarchy_classes)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.hierarchy_classes)}
        self.idx_to_class = {idx: cls for idx, cls in enumerate(self.hierarchy_classes)}

        self.patterns = self._create_patterns()

        if verbose:
            print(f"π― Hierarchy extractor initialized with {len(self.hierarchy_classes)} classes")
            print(f"π Classes: {self.hierarchy_classes}")

    def _create_patterns(self):
        """
        Create regex patterns for each hierarchy class.

        Creates patterns that match variations, plurals, and common fashion
        terms for each hierarchy class. All variants are lowercased because
        matching is done against lowercased text.

        Returns:
            Dictionary mapping hierarchy classes to regex pattern strings
        """
        patterns = {}

        for hierarchy in self.hierarchy_classes:
            # FIX: derive every variation from the lowercased name. Previously
            # hyphen variants kept the original casing, so they could never
            # match the lowercased search text for capitalized class names.
            base = hierarchy.lower()
            variations = [base]

            # Hyphenated names also match with a space or with the hyphen removed.
            if '-' in base:
                variations.append(base.replace('-', ' '))
                variations.append(base.replace('-', ''))

            # Naive plural.
            if not base.endswith('s'):
                variations.append(base + 's')

            # Add synonym lists for any fashion term contained in the class name.
            for key, terms in self.FASHION_TERMS.items():
                if key in base:
                    variations.extend(terms)

            # Deduplicate while preserving order to keep the alternation compact.
            seen = set()
            unique = [v for v in variations if not (v in seen or seen.add(v))]

            patterns[hierarchy] = r'\b(' + '|'.join(re.escape(v) for v in unique) + r')\b'

        return patterns

    def extract_hierarchy(self, text):
        """
        Extract hierarchy category from text using pattern matching.

        Args:
            text: Input text string

        Returns:
            Hierarchy class name if found, None otherwise
        """
        text_lower = text.lower()

        # Fast path: exact substring of the class name itself.
        for hierarchy in self.hierarchy_classes:
            if hierarchy.lower() in text_lower:
                return hierarchy

        # Fallback: word-boundary regex over variations/synonyms.
        for hierarchy, pattern in self.patterns.items():
            if re.search(pattern, text_lower):
                return hierarchy

        return None

    def extract_hierarchy_idx(self, text):
        """
        Extract hierarchy index from text.

        Args:
            text: Input text string

        Returns:
            Hierarchy index if found, None otherwise
        """
        hierarchy = self.extract_hierarchy(text)
        if hierarchy:
            return self.class_to_idx[hierarchy]
        return None

    def get_hierarchy_embedding(self, text, embed_dim=None):
        """
        Create a fixed sparse embedding from the hierarchy index found in text.

        Args:
            text: Input text string
            embed_dim: Dimension of the embedding
                (default: ``config.hierarchy_emb_dim``, resolved lazily at call
                time instead of at class-definition time)

        Returns:
            Embedding tensor of shape (embed_dim,); all zeros when no
            hierarchy can be extracted from the text.
        """
        if embed_dim is None:
            embed_dim = config.hierarchy_emb_dim
        hierarchy_idx = self.extract_hierarchy_idx(text)
        if hierarchy_idx is None:
            return torch.zeros(embed_dim)

        # Deterministic code per class: a decaying 3-spike pattern whose
        # position wraps around the embedding dimension.
        embedding = torch.zeros(embed_dim)
        start_idx = (hierarchy_idx * 3) % embed_dim
        embedding[start_idx] = 1.0
        embedding[(start_idx + 1) % embed_dim] = 0.5
        embedding[(start_idx + 2) % embed_dim] = 0.3
        return embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PretrainedImageEncoder(nn.Module):
    """
    Image encoder based on pretrained ResNet18 for extracting image embeddings.

    Uses a pretrained ResNet18 backbone (final fc layer removed) with the first
    ~70% of its child modules frozen to reduce overfitting, followed by a
    custom projection head that outputs embeddings of the requested dimension.
    """

    def __init__(self, embed_dim, dropout=0.3):
        """
        Initialize the pretrained image encoder.

        Args:
            embed_dim: Dimension of the output embedding
            dropout: Dropout rate for regularization (default: 0.3)
        """
        super().__init__()

        resnet = models.resnet18(pretrained=True)
        # Read the feature width from the model instead of hard-coding 512,
        # so a wider ResNet variant can be swapped in without editing this line.
        backbone_dim = resnet.fc.in_features

        # Drop the final fc layer; the backbone now ends at global average
        # pooling, producing [batch, backbone_dim, 1, 1] features.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Projection head: backbone_dim -> 2*embed_dim -> embed_dim, LayerNorm'd.
        self.projection = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(backbone_dim, embed_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LayerNorm(embed_dim)
        )

        self._freeze_backbone_layers()

    def _freeze_backbone_layers(self):
        """
        Freeze the first 70% of backbone child modules.

        Only the last backbone stages remain trainable, which limits
        overfitting when fine-tuning on a small dataset.
        """
        layers = list(self.backbone.children())
        freeze_until = int(len(layers) * 0.7)
        for layer in layers[:freeze_until]:
            for param in layer.parameters():
                param.requires_grad = False

    def forward(self, x):
        """
        Forward pass through the image encoder.

        Args:
            x: Image tensor [batch_size, channels, height, width]

        Returns:
            Image embeddings [batch_size, embed_dim]
        """
        features = self.backbone(x)
        return self.projection(features)
|
|
|
|
|
class HierarchyEncoder(nn.Module):
    """
    Encoder that maps hierarchy class indices to dense embeddings.

    A learnable embedding table converts each index into a vector, which is
    then refined by a small MLP projection head ending in LayerNorm.
    """

    def __init__(self, num_hierarchies, embed_dim, dropout=0.3):
        """
        Initialize the hierarchy encoder.

        Args:
            num_hierarchies: Number of hierarchy classes
            embed_dim: Dimension of the output embedding
            dropout: Dropout rate for regularization (default: 0.3)
        """
        super().__init__()
        self.num_hierarchies = num_hierarchies
        self.embed_dim = embed_dim

        # One learnable vector per hierarchy class.
        self.embedding = nn.Embedding(num_hierarchies, embed_dim)

        # Refinement MLP: embed_dim -> 2*embed_dim -> embed_dim.
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LayerNorm(embed_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """Apply Xavier-uniform init to the embedding table and all linear layers."""
        nn.init.xavier_uniform_(self.embedding.weight)
        for layer in self.projection.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, hierarchy_indices):
        """
        Forward pass through the hierarchy encoder.

        Args:
            hierarchy_indices: Tensor of hierarchy indices [batch_size]

        Returns:
            Hierarchy embeddings [batch_size, embed_dim]

        Note:
            Workaround for the MPS backend: embedding lookups are performed on
            CPU and the result is moved back to the device afterwards.
        """
        device = next(self.parameters()).device
        if device.type == 'mps':
            looked_up = F.embedding(hierarchy_indices.cpu(), self.embedding.weight.cpu())
            looked_up = looked_up.contiguous().to(device)
        else:
            looked_up = self.embedding(hierarchy_indices)

        return self.projection(looked_up)
|
|
|
|
|
class HierarchyClassifierHead(nn.Module):
    """
    MLP classifier head mapping embeddings to hierarchy-class logits.

    Two hidden layers with ReLU activations and progressively lighter dropout,
    ending in a linear layer that emits one logit per class.
    """

    def __init__(self, in_dim, num_classes, hidden_dim=None, dropout=0.3):
        """
        Initialize the hierarchy classifier head.

        Args:
            in_dim: Input embedding dimension
            num_classes: Number of hierarchy classes
            hidden_dim: Hidden layer dimension (default: max(in_dim // 2, num_classes * 2))
            dropout: Dropout rate for regularization (default: 0.3)
        """
        super().__init__()
        hidden = hidden_dim if hidden_dim is not None else max(in_dim // 2, num_classes * 2)

        self.classifier = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(inplace=True),
            # Second stage uses half the dropout of the first.
            nn.Dropout(dropout * 0.5),
            nn.Linear(hidden // 2, num_classes),
        )

    def forward(self, x):
        """
        Forward pass through the classifier head.

        Args:
            x: Input embeddings [batch_size, in_dim]

        Returns:
            Classification logits [batch_size, num_classes]
        """
        return self.classifier(x)
|
|
|
|
|
class Model(nn.Module):
    """
    Main hierarchy model for learning clothing category-aligned embeddings.

    Combines an image encoder, a hierarchy encoder, and two classifier heads to
    learn aligned embeddings for images and text descriptions based on clothing
    categories.
    """

    def __init__(self, num_hierarchy_classes, embed_dim, dropout=0.3):
        """
        Initialize the hierarchy model.

        Args:
            num_hierarchy_classes: Number of hierarchy classes
            embed_dim: Dimension of the embedding space
            dropout: Dropout rate for regularization (default: 0.3)
        """
        super().__init__()
        self.img_enc = PretrainedImageEncoder(embed_dim, dropout)
        self.hierarchy_enc = HierarchyEncoder(num_hierarchy_classes, embed_dim, dropout)
        self.hierarchy_head_img = HierarchyClassifierHead(embed_dim, num_hierarchy_classes, dropout=dropout)
        self.hierarchy_head_txt = HierarchyClassifierHead(embed_dim, num_hierarchy_classes, dropout=dropout)
        self.num_hierarchy_classes = num_hierarchy_classes

    def forward(self, image=None, hierarchy_indices=None):
        """
        Forward pass through the model.

        Args:
            image: Optional image tensor [batch_size, channels, height, width]
            hierarchy_indices: Optional hierarchy indices tensor [batch_size]

        Returns:
            Dictionary containing (only the keys for the inputs provided):
                - 'z_img': L2-normalized image embeddings [batch_size, embed_dim]
                - 'z_txt': L2-normalized text embeddings [batch_size, embed_dim]
                - 'hierarchy_logits_img': Image classification logits [batch_size, num_classes]
                - 'hierarchy_logits_txt': Text classification logits [batch_size, num_classes]
        """
        out = {}
        if image is not None:
            z_img = self.img_enc(image)
            # L2-normalize so embeddings live on the unit sphere (used by the
            # contrastive and consistency losses).
            z_img = F.normalize(z_img, p=2, dim=1)
            out['hierarchy_logits_img'] = self.hierarchy_head_img(z_img)
            out['z_img'] = z_img

        if hierarchy_indices is not None:
            z_txt = self.hierarchy_enc(hierarchy_indices)
            z_txt = F.normalize(z_txt, p=2, dim=1)
            out['hierarchy_logits_txt'] = self.hierarchy_head_txt(z_txt)
            out['z_txt'] = z_txt

        return out

    def set_hierarchy_extractor(self, hierarchy_extractor):
        """
        Set the hierarchy extractor for text processing.

        Args:
            hierarchy_extractor: HierarchyExtractor instance
        """
        self.hierarchy_extractor = hierarchy_extractor

    def get_text_embeddings(self, text):
        """
        Get text embeddings for a given text string or list of strings.

        Args:
            text: Text string or list/tuple of text strings

        Returns:
            Text embeddings tensor [batch_size, embed_dim]
            (batch_size is 1 for a single string)

        Raises:
            ValueError: If a hierarchy cannot be extracted from the text or
                the input is not a string / list of strings
        """
        with torch.no_grad():
            model_device = next(self.parameters()).device

            # Normalize the single-string case onto the batched code path
            # (previously the whole extraction/forward sequence was duplicated).
            if isinstance(text, str):
                texts = [text]
            elif isinstance(text, (list, tuple)):
                texts = text
            else:
                raise ValueError(f"Expected string or list/tuple of strings, got {type(text)}: {text}")

            hierarchy_indices = []
            for hierarchy_text in texts:
                if not isinstance(hierarchy_text, str):
                    raise ValueError(f"Expected string, got {type(hierarchy_text)}: {hierarchy_text}")
                hierarchy_idx = self.hierarchy_extractor.extract_hierarchy_idx(hierarchy_text)
                if hierarchy_idx is None:
                    raise ValueError(f"Could not extract hierarchy for text: '{hierarchy_text}'. Available classes: {self.hierarchy_extractor.hierarchy_classes}")
                hierarchy_indices.append(hierarchy_idx)

            hierarchy_indices = torch.tensor(hierarchy_indices, device=model_device)
            output = self.forward(hierarchy_indices=hierarchy_indices)
            return output['z_txt']

    def get_image_embeddings(self, image):
        """
        Get image embeddings for a given image tensor.

        Args:
            image: Image tensor [channels, height, width] or
                [batch_size, channels, height, width]

        Returns:
            Image embeddings tensor [batch_size, embed_dim]

        Raises:
            ValueError: If image is not a torch.Tensor
        """
        with torch.no_grad():
            if not isinstance(image, torch.Tensor):
                raise ValueError("Image must be a torch.Tensor")

            device = next(self.parameters()).device
            if image.device != device:
                image = image.to(device)

            # Promote a single image to a batch of one.
            if image.dim() == 3:
                image = image.unsqueeze(0)

            output = self.forward(image=image)
            return output['z_img']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Loss(nn.Module):
    """
    Combined training objective for the hierarchy model.

    Sums three weighted terms:
      * classification — cross-entropy on both image and text logits
      * contrastive   — InfoNCE alignment between image and text embeddings
      * consistency   — MSE pulling paired image/text embeddings together
    """

    def __init__(self, hierarchy_classes, classification_weight=1.0,
                 consistency_weight=0.3, contrastive_weight=0.2,
                 temperature=0.07, label_smoothing=0.1):
        """
        Initialize the loss function.

        Args:
            hierarchy_classes: List of hierarchy class names
            classification_weight: Weight for classification loss (default: 1.0)
            consistency_weight: Weight for consistency loss (default: 0.3)
            contrastive_weight: Weight for contrastive loss (default: 0.2)
            temperature: Temperature scaling for contrastive loss (default: 0.07)
            label_smoothing: Label smoothing parameter (default: 0.1)
        """
        super().__init__()
        self.classification_weight = classification_weight
        self.consistency_weight = consistency_weight
        self.contrastive_weight = contrastive_weight
        self.temperature = temperature

        # Deduplicate and sort so class indices are deterministic.
        self.hierarchy_classes = sorted(set(hierarchy_classes))
        self.num_classes = len(self.hierarchy_classes)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.hierarchy_classes)}

        self.ce = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        self.mse = nn.MSELoss()

    def contrastive_loss(self, img_emb, txt_emb):
        """
        Symmetric InfoNCE loss aligning image and text embeddings.

        Position i in each modality is treated as the positive pair; every
        other batch entry is a negative. NOTE(review): duplicate classes
        within a batch therefore act as false negatives.

        Args:
            img_emb: Image embeddings [batch_size, embed_dim]
            txt_emb: Text embeddings [batch_size, embed_dim]

        Returns:
            Contrastive loss value (scalar tensor)
        """
        scaled_sim = torch.matmul(img_emb, txt_emb.T) / self.temperature
        targets = torch.arange(img_emb.size(0), device=img_emb.device)

        image_to_text = F.cross_entropy(scaled_sim, targets)
        text_to_image = F.cross_entropy(scaled_sim.T, targets)

        return (image_to_text + text_to_image) / 2

    def forward(self, img_logits, txt_logits, img_embeddings, txt_embeddings, target_hierarchies):
        """
        Compute the combined weighted loss.

        Args:
            img_logits: Image classification logits [batch_size, num_classes]
            txt_logits: Text classification logits [batch_size, num_classes]
            img_embeddings: Image embeddings [batch_size, embed_dim]
            txt_embeddings: Text embeddings [batch_size, embed_dim]
            target_hierarchies: List of target hierarchy class names [batch_size]

        Returns:
            Combined loss value (scalar tensor)
        """
        device = img_embeddings.device

        # Map class names to indices; unknown names silently fall back to 0.
        target_classes = torch.tensor(
            [self.class_to_idx.get(name, 0) for name in target_hierarchies],
            device=device,
        )

        classification_loss = (self.ce(img_logits, target_classes) +
                               self.ce(txt_logits, target_classes)) / 2
        contrastive_loss = self.contrastive_loss(img_embeddings, txt_embeddings)
        consistency_loss = self.mse(img_embeddings, txt_embeddings)

        return (self.classification_weight * classification_loss +
                self.contrastive_weight * contrastive_loss +
                self.consistency_weight * consistency_loss)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collate_fn(batch, hierarchy_extractor):
    """
    Collate function for DataLoader that processes batches and extracts hierarchy indices.

    Args:
        batch: List of (image, description, hierarchy) tuples
        hierarchy_extractor: HierarchyExtractor instance

    Returns:
        Dictionary containing:
            - 'image': Stacked image tensors [batch_size, channels, height, width]
            - 'hierarchy_indices': Hierarchy indices tensor [batch_size]
            - config.hierarchy_column: List of hierarchy class names [batch_size]
    """
    images = torch.stack([b[0] for b in batch], dim=0)
    texts = [b[1] for b in batch]
    hierarchies = [b[2] for b in batch]

    # Prefer the index extracted from the free text; fall back to the sample's
    # ground-truth hierarchy label (and to class 0 if even that is unknown).
    # FIX: iterate text/label pairs with zip instead of the fragile
    # hierarchies[len(hierarchy_indices)] positional lookup.
    hierarchy_indices = []
    for text, target_hierarchy in zip(texts, hierarchies):
        idx = hierarchy_extractor.extract_hierarchy_idx(text)
        if idx is None:
            idx = hierarchy_extractor.class_to_idx.get(target_hierarchy, 0)
        hierarchy_indices.append(idx)

    hierarchy_indices = torch.tensor(hierarchy_indices, dtype=torch.long)

    return {
        'image': images,
        'hierarchy_indices': hierarchy_indices,
        config.hierarchy_column: hierarchies
    }
|
|
|
|
|
def calculate_accuracy(logits, target_hierarchies, hierarchy_classes):
    """
    Calculate classification accuracy.

    Args:
        logits: Classification logits [batch_size, num_classes]
        target_hierarchies: List of target hierarchy class names [batch_size]
        hierarchy_classes: List of hierarchy class names

    Returns:
        Accuracy score (float between 0 and 1)
    """
    batch_size = logits.size(0)
    pred_indices = torch.argmax(logits, dim=1).cpu().numpy()
    n_known = len(hierarchy_classes)

    # A predicted index outside the known class list counts as "" (never matches).
    matches = sum(
        1
        for i in range(batch_size)
        if (hierarchy_classes[pred_indices[i]] if pred_indices[i] < n_known else "")
        == target_hierarchies[i]
    )

    return matches / batch_size
|
|
|
|
|
def train_one_epoch(model, dataloader, optimizer, device, hierarchy_classes, scheduler=None):
    """
    Train the model for one epoch.

    Args:
        model: Model instance to train
        dataloader: DataLoader for training data
        optimizer: Optimizer instance
        device: Device to train on
        hierarchy_classes: List of hierarchy class names
        scheduler: Optional learning rate scheduler (stepped once per batch)

    Returns:
        Dictionary containing training metrics:
            - 'loss': Average training loss
            - 'acc_img': Average image classification accuracy
            - 'acc_txt': Average text classification accuracy
    """
    model.train()
    total_loss = 0.0
    total_acc_img = 0.0
    total_acc_txt = 0.0
    num_batches = 0

    # FIX: switch the dataset to augmented training transforms before
    # iterating. Previously this was called inside the batch loop, so the
    # first batch could still use stale validation transforms left over
    # from a preceding validate() call.
    # NOTE(review): with num_workers > 0 this flag would not propagate to
    # worker processes — confirm the loaders run with num_workers=0.
    if hasattr(dataloader.dataset, 'set_training_mode'):
        dataloader.dataset.set_training_mode(True)

    loss_fn = Loss(
        hierarchy_classes,
        classification_weight=1.0,
        consistency_weight=0.3,
        contrastive_weight=0.2,
        label_smoothing=0.1
    ).to(device)

    pbar = tqdm(dataloader, desc="Training", leave=False)
    for batch in pbar:
        images = batch['image'].to(device)
        hierarchy_indices = batch['hierarchy_indices'].to(device)
        target_hierarchies = batch[config.hierarchy_column]

        # Joint forward pass over both modalities.
        out = model(image=images, hierarchy_indices=hierarchy_indices)
        hierarchy_logits_img = out['hierarchy_logits_img']
        hierarchy_logits_txt = out['hierarchy_logits_txt']
        z_img, z_txt = out['z_img'], out['z_txt']

        loss = loss_fn(hierarchy_logits_img, hierarchy_logits_txt, z_img, z_txt, target_hierarchies)

        optimizer.zero_grad()
        loss.backward()
        # Clip gradients to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        acc_img = calculate_accuracy(hierarchy_logits_img, target_hierarchies, hierarchy_classes)
        acc_txt = calculate_accuracy(hierarchy_logits_txt, target_hierarchies, hierarchy_classes)

        total_loss += loss.item()
        total_acc_img += acc_img
        total_acc_txt += acc_txt
        num_batches += 1

        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc_img': f'{acc_img:.3f}',
            'acc_txt': f'{acc_txt:.3f}',
        })

    return {
        'loss': total_loss / num_batches,
        'acc_img': total_acc_img / num_batches,
        'acc_txt': total_acc_txt / num_batches
    }
|
|
|
|
|
def validate(model, dataloader, device, hierarchy_classes):
    """
    Validate the model on validation data.

    Args:
        model: Model instance to validate
        dataloader: DataLoader for validation data
        device: Device to validate on
        hierarchy_classes: List of hierarchy class names

    Returns:
        Dictionary containing validation metrics:
            - 'loss': Average validation loss
            - 'acc_img': Average image classification accuracy
            - 'acc_txt': Average text classification accuracy
    """
    model.eval()
    total_loss = 0.0
    total_acc_img = 0.0
    total_acc_txt = 0.0
    num_batches = 0

    # FIX: switch the dataset to deterministic validation transforms before
    # iterating. Previously this was called inside the batch loop, so the
    # first batch could still be produced with augmented training transforms.
    if hasattr(dataloader.dataset, 'set_training_mode'):
        dataloader.dataset.set_training_mode(False)

    loss_fn = Loss(
        hierarchy_classes,
        classification_weight=1.0,
        consistency_weight=0.3,
        contrastive_weight=0.2
    ).to(device)

    pbar = tqdm(dataloader, desc="Validation", leave=False)
    with torch.no_grad():
        for batch in pbar:
            images = batch['image'].to(device)
            hierarchy_indices = batch['hierarchy_indices'].to(device)
            target_hierarchies = batch[config.hierarchy_column]

            # Joint forward pass over both modalities.
            out = model(image=images, hierarchy_indices=hierarchy_indices)
            hierarchy_logits_img = out['hierarchy_logits_img']
            hierarchy_logits_txt = out['hierarchy_logits_txt']
            z_img, z_txt = out['z_img'], out['z_txt']

            loss = loss_fn(hierarchy_logits_img, hierarchy_logits_txt, z_img, z_txt, target_hierarchies)

            acc_img = calculate_accuracy(hierarchy_logits_img, target_hierarchies, hierarchy_classes)
            acc_txt = calculate_accuracy(hierarchy_logits_txt, target_hierarchies, hierarchy_classes)

            total_loss += loss.item()
            total_acc_img += acc_img
            total_acc_txt += acc_txt
            num_batches += 1

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc_img': f'{acc_img:.3f}',
                'acc_txt': f'{acc_txt:.3f}',
            })

    return {
        'loss': total_loss / num_batches,
        'acc_img': total_acc_img / num_batches,
        'acc_txt': total_acc_txt / num_batches
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Hyperparameters
    # ------------------------------------------------------------------
    device = config.device
    batch_size = 16
    lr = 5e-5
    epochs = 20
    val_split = 0.2
    dropout = 0.4
    weight_decay = 1e-3

    print(f"π Starting hierarchical training on device: {device}")
    print(f"π Config: {epochs} epochs, batch={batch_size}, lr={lr}, embed_dim={config.hierarchy_emb_dim}")

    # ------------------------------------------------------------------
    # Data loading and stratified train/validation split
    # ------------------------------------------------------------------
    print(f"π Using dataset: {config.local_dataset_path}")
    df = pd.read_csv(config.local_dataset_path)
    print(f"π Loaded {len(df)} samples")

    hierarchy_classes = sorted(df[config.hierarchy_column].unique().tolist())
    print(f"π Found {len(hierarchy_classes)} hierarchy classes")

    hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=True)

    # Stratify on the hierarchy column so train/val class distributions match.
    train_df, val_df = train_test_split(
        df,
        test_size=val_split,
        random_state=42,
        stratify=df[config.hierarchy_column]
    )
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)

    print(f"π Train: {len(train_df)}, Validation: {len(val_df)}")

    train_ds = HierarchyDataset(train_df, image_size=224)
    val_ds = HierarchyDataset(val_df, image_size=224)

    # NOTE(review): the lambda collate_fn is not picklable, so num_workers > 0
    # would fail on spawn-based platforms; fine with the default num_workers=0.
    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda batch: collate_fn(batch, hierarchy_extractor)
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda batch: collate_fn(batch, hierarchy_extractor)
    )

    # ------------------------------------------------------------------
    # Model, optimizer, scheduler
    # ------------------------------------------------------------------
    model = Model(
        num_hierarchy_classes=len(hierarchy_classes),
        embed_dim=config.hierarchy_emb_dim,
        dropout=dropout
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=lr/10)

    print(f"π― Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print("\n" + "="*80)

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    best_val_loss = float('inf')
    training_history = {'train_loss': [], 'val_loss': [], 'val_acc_img': [], 'val_acc_txt': []}

    for e in range(epochs):
        print(f"\nπ Epoch {e+1}/{epochs}")
        print("-" * 50)

        train_metrics = train_one_epoch(model, train_dl, optimizer, device, hierarchy_classes, scheduler)
        val_metrics = validate(model, val_dl, device, hierarchy_classes)

        training_history['train_loss'].append(train_metrics['loss'])
        training_history['val_loss'].append(val_metrics['loss'])
        training_history['val_acc_img'].append(val_metrics['acc_img'])
        training_history['val_acc_txt'].append(val_metrics['acc_txt'])

        print(f"π TRAIN - Loss: {train_metrics['loss']:.6f} | "
              f"Img Acc: {train_metrics['acc_img']:.3f} | "
              f"Txt Acc: {train_metrics['acc_txt']:.3f}")

        # FIX: this print previously contained a literal newline inside the
        # f-string literal (a syntax error); re-joined onto one literal.
        print(f"β VAL - Loss: {val_metrics['loss']:.6f} | "
              f"Img Acc: {val_metrics['acc_img']:.3f} | "
              f"Txt Acc: {val_metrics['acc_txt']:.3f}")

        # Keep the best model by validation loss.
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            print(f"πΎ New best validation loss! Saving model...")
            torch.save({
                'model_state': model.state_dict(),
                'hierarchy_classes': hierarchy_classes,
                'epoch': e+1,
                'config': {
                    'embed_dim': config.hierarchy_emb_dim,
                    'dropout': dropout
                }
            }, config.hierarchy_model_path)

        # Periodic checkpoint every 2 epochs, independent of best-model saving.
        if (e + 1) % 2 == 0:
            print(f"πΎ Saving checkpoint at epoch {e+1}...")
            torch.save({
                'model_state': model.state_dict(),
                'hierarchy_classes': hierarchy_classes,
                'epoch': e+1,
                'config': {
                    'embed_dim': config.hierarchy_emb_dim,
                    'dropout': dropout
                }
            }, f"model_checkpoint_epoch_{e+1}.pth")

    print("\n" + "="*80)
    print("π Training completed!")
    print(f"π Best validation loss: {best_val_loss:.6f}")

    print(f"\nπ Final validation accuracy: Image={training_history['val_acc_img'][-1]:.3f}, Text={training_history['val_acc_txt'][-1]:.3f}")
|
|
|