""" OutfitTransformerCIR - Complementary Item Retrieval Model ========================================================== Architecture based on Sarkar et al. with modifications: - LaBSE instead of BERT for multilingual text encoding - Set-wise Outfit Ranking Loss instead of InfoNCE Usage: from model import OutfitTransformerCIR model = OutfitTransformerCIR() model.load_state_dict(torch.load("pytorch_model.bin")) model.eval() # context_images: (B, S, 512) - ResNet-18 features # context_texts: (B, S, 768) - LaBSE embeddings predicted = model(context_images, context_texts) # predicted: (B, 128) - Missing item embedding """ import torch import torch.nn as nn import torch.nn.functional as F class OutfitTransformerCIR(nn.Module): """ Complementary Item Retrieval Transformer Given context items (partial outfit), predicts the embedding of a missing item that would complete the outfit harmoniously. Architecture: - Visual projection: 512 (ResNet-18) → 64 - Text projection: 768 (LaBSE) → 64 - Combined: 64 + 64 = 128 dim item embedding - Transformer Encoder: 6 layers, 16 heads - Learnable [QUERY] token for missing item prediction Args: embedding_dim (int): Final embedding dimension (default: 128) nhead (int): Number of attention heads (default: 16) num_layers (int): Number of transformer layers (default: 6) use_projection (bool): Whether to apply projection layers. - True: Input is raw features (512 + 768) - False: Input is pre-projected features (64 + 64) """ def __init__(self, embedding_dim=128, nhead=16, num_layers=6, use_projection=True): super(OutfitTransformerCIR, self).__init__() self.use_projection = use_projection self.embedding_dim = embedding_dim # Projection layers (trained, not frozen) self.visual_proj = nn.Linear(512, 64) self.text_proj = nn.Linear(768, 64) # Transformer encoder encoder_layer = nn.TransformerEncoderLayer( d_model=embedding_dim, nhead=nhead, dim_feedforward=512, batch_first=True, dropout=0.1 ) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) # Learnable query token (represents the missing item) self.query_token = nn.Parameter(torch.randn(1, 1, embedding_dim)) # Output projection with normalization self.output_proj = nn.Sequential( nn.Linear(embedding_dim, embedding_dim), nn.LayerNorm(embedding_dim) ) def encode_items(self, images, texts): """ Encode multiple items (for context). Args: images: (B, S, D_img) where D_img=512 (raw) or 64 (projected) texts: (B, S, D_txt) where D_txt=768 (raw) or 64 (projected) Returns: (B, S, 128) - Unified item embeddings """ if self.use_projection: img_emb = self.visual_proj(images) txt_emb = self.text_proj(texts) else: img_emb = images txt_emb = texts return torch.cat((img_emb, txt_emb), dim=-1) def encode_single_item(self, image, text): """ Encode a single item (for candidate scoring). Args: image: (B, D_img) text: (B, D_txt) Returns: (B, 128) - Item embedding """ if self.use_projection: img_emb = self.visual_proj(image) txt_emb = self.text_proj(text) else: img_emb = image txt_emb = text return torch.cat((img_emb, txt_emb), dim=-1) def forward(self, context_images, context_texts, padding_mask=None): """ Predict the embedding of a missing item. 
    def forward(self, context_images, context_texts, padding_mask=None):
        """
        Predict the embedding of a missing item.

        Args:
            context_images: (B, S, 512) - ResNet-18 features of context items
            context_texts: (B, S, 768) - LaBSE embeddings of context items
            padding_mask: (B, S) - True indicates padding positions

        Returns:
            (B, 128) - Predicted embedding for the missing item

        Example:
            >>> model = OutfitTransformerCIR()
            >>> # Outfit with 3 items: t-shirt, jeans, watch
            >>> img_features = torch.randn(1, 3, 512)  # ResNet-18 outputs
            >>> txt_features = torch.randn(1, 3, 768)  # LaBSE outputs
            >>> predicted = model(img_features, txt_features)
            >>> # predicted: (1, 128) - embedding for the ideal 4th item (e.g., shoes)
        """
        batch_size = context_images.size(0)
        device = context_images.device

        # 1. Encode context items
        item_embeddings = self.encode_items(context_images, context_texts)

        # 2. Prepend learnable query token
        query = self.query_token.expand(batch_size, -1, -1)
        x = torch.cat([query, item_embeddings], dim=1)

        # 3. Build attention mask (query always attends; padding positions masked)
        if padding_mask is not None:
            query_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=device)
            full_mask = torch.cat([query_mask, padding_mask], dim=1)
        else:
            full_mask = None

        # 4. Transformer forward
        out = self.transformer(x, src_key_padding_mask=full_mask)

        # 5. Extract query output (first position)
        query_out = out[:, 0, :]

        # 6. Project and L2-normalize
        predicted = self.output_proj(query_out)
        predicted = F.normalize(predicted, p=2, dim=-1)

        return predicted


# Convenience function for loading
def load_model(checkpoint_path, device="cpu"):
    """
    Load a trained OutfitTransformerCIR model.

    Args:
        checkpoint_path: Path to pytorch_model.bin
        device: "cpu" or "cuda"

    Returns:
        Loaded model in eval mode
    """
    model = OutfitTransformerCIR(
        embedding_dim=128,
        nhead=16,
        num_layers=6,
        use_projection=True,
    )
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
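
# --- Minimal smoke test (illustrative only; random tensors stand in for
# real ResNet-18 / LaBSE features, so the outputs are meaningless) ---
if __name__ == "__main__":
    model = OutfitTransformerCIR()
    model.eval()

    # Batch of 2 outfits, up to 4 context items each; the first outfit
    # has only 3 real items, so its last slot is marked as padding.
    imgs = torch.randn(2, 4, 512)
    txts = torch.randn(2, 4, 768)
    pad = torch.tensor([
        [False, False, False, True],
        [False, False, False, False],
    ])

    with torch.no_grad():
        pred = model(imgs, txts, padding_mask=pad)

    print(pred.shape)         # torch.Size([2, 128])
    print(pred.norm(dim=-1))  # ~1.0, since forward() L2-normalizes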