File size: 6,474 Bytes
ed99b9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""
OutfitTransformerCIR - Complementary Item Retrieval Model
==========================================================

Architecture based on Sarkar et al. with modifications:
- LaBSE instead of BERT for multilingual text encoding
- Set-wise Outfit Ranking Loss instead of InfoNCE

Usage:
    from model import OutfitTransformerCIR
    
    model = OutfitTransformerCIR()
    model.load_state_dict(torch.load("pytorch_model.bin"))
    model.eval()
    
    # context_images: (B, S, 512) - ResNet-18 features
    # context_texts: (B, S, 768) - LaBSE embeddings
    predicted = model(context_images, context_texts)
    # predicted: (B, 128) - Missing item embedding
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class OutfitTransformerCIR(nn.Module):
    """
    Complementary Item Retrieval Transformer

    Given context items (partial outfit), predicts the embedding of a missing item
    that would complete the outfit harmoniously.

    Architecture:
        - Visual projection: 512 (ResNet-18) → embedding_dim // 2
        - Text projection: 768 (LaBSE) → embedding_dim // 2
        - Combined: two halves concatenated into one embedding_dim item embedding
        - Transformer Encoder: `num_layers` layers, `nhead` heads
        - Learnable [QUERY] token for missing item prediction

    With the default ``embedding_dim=128`` the projections are exactly
    ``Linear(512, 64)`` / ``Linear(768, 64)``, so state dicts saved by earlier
    versions of this class load unchanged.

    Args:
        embedding_dim (int): Final embedding dimension (default: 128).
            Must be even (split between the two modalities) and divisible
            by ``nhead``.
        nhead (int): Number of attention heads (default: 16)
        num_layers (int): Number of transformer layers (default: 6)
        use_projection (bool): Whether to apply projection layers.
            - True: Input is raw features (512 + 768)
            - False: Input is pre-projected features (embedding_dim // 2 each)

    Raises:
        ValueError: If ``embedding_dim`` is odd or not divisible by ``nhead``.
    """

    def __init__(self, embedding_dim=128, nhead=16, num_layers=6, use_projection=True):
        super().__init__()

        # Fail fast with a clear message instead of a shape-mismatch error
        # deep inside the transformer stack.
        if embedding_dim % 2 != 0:
            raise ValueError(
                f"embedding_dim must be even (got {embedding_dim}): each modality "
                f"contributes half of the item embedding"
            )
        if embedding_dim % nhead != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by nhead ({nhead})"
            )

        self.use_projection = use_projection
        self.embedding_dim = embedding_dim

        # Projection layers (trained, not frozen). Each modality maps to half
        # of the final embedding; defaults reproduce the original 512→64 and
        # 768→64 layers byte-for-byte in the state dict.
        half_dim = embedding_dim // 2
        self.visual_proj = nn.Linear(512, half_dim)
        self.text_proj = nn.Linear(768, half_dim)

        # Transformer encoder over the [QUERY] token + context item sequence.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=512,
            batch_first=True,
            dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Learnable query token (represents the missing item)
        self.query_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

        # Output projection with normalization
        self.output_proj = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.LayerNorm(embedding_dim)
        )

    def encode_items(self, images, texts):
        """
        Encode multiple items (for context).

        ``nn.Linear`` operates on the last dimension only, so this accepts any
        number of leading batch dimensions.

        Args:
            images: (..., D_img) where D_img=512 (raw) or embedding_dim//2 (projected)
            texts: (..., D_txt) where D_txt=768 (raw) or embedding_dim//2 (projected)

        Returns:
            (..., embedding_dim) - Unified item embeddings
        """
        if self.use_projection:
            img_emb = self.visual_proj(images)
            txt_emb = self.text_proj(texts)
        else:
            img_emb = images
            txt_emb = texts

        return torch.cat((img_emb, txt_emb), dim=-1)

    def encode_single_item(self, image, text):
        """
        Encode a single item (for candidate scoring).

        Identical math to :meth:`encode_items`; kept as a separate entry point
        for API clarity and delegated to avoid duplicated logic.

        Args:
            image: (B, D_img)
            text: (B, D_txt)

        Returns:
            (B, embedding_dim) - Item embedding
        """
        return self.encode_items(image, text)

    def forward(self, context_images, context_texts, padding_mask=None):
        """
        Predict the embedding of a missing item.

        Args:
            context_images: (B, S, 512) - ResNet-18 features of context items
            context_texts: (B, S, 768) - LaBSE embeddings of context items
            padding_mask: (B, S) - True indicates padding positions

        Returns:
            (B, embedding_dim) - L2-normalized predicted embedding for the
            missing item

        Example:
            >>> model = OutfitTransformerCIR()
            >>> # Outfit with 3 items: t-shirt, jeans, watch
            >>> img_features = torch.randn(1, 3, 512)  # ResNet-18 outputs
            >>> txt_features = torch.randn(1, 3, 768)  # LaBSE outputs
            >>> predicted = model(img_features, txt_features)
            >>> # predicted: (1, 128) - embedding for ideal 4th item (e.g., shoes)
        """
        batch_size = context_images.size(0)
        device = context_images.device

        # 1. Encode context items
        item_embeddings = self.encode_items(context_images, context_texts)

        # 2. Prepend learnable query token
        query = self.query_token.expand(batch_size, -1, -1)
        x = torch.cat([query, item_embeddings], dim=1)

        # 3. Build attention mask: the query position is never masked; padded
        #    context positions are excluded from attention.
        if padding_mask is not None:
            query_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=device)
            full_mask = torch.cat([query_mask, padding_mask], dim=1)
        else:
            full_mask = None

        # 4. Transformer forward
        out = self.transformer(x, src_key_padding_mask=full_mask)

        # 5. Extract query output (first position)
        query_out = out[:, 0, :]

        # 6. Project and L2 normalize so similarities reduce to dot products
        predicted = self.output_proj(query_out)
        predicted = F.normalize(predicted, p=2, dim=-1)

        return predicted


# Convenience function for loading
def load_model(checkpoint_path, device="cpu"):
    """
    Load a trained OutfitTransformerCIR model.

    The checkpoint is loaded with ``weights_only=True`` so that only tensor
    data is unpickled: a plain state dict contains nothing else, and this
    prevents arbitrary-code execution from a tampered checkpoint file.

    Args:
        checkpoint_path: Path to pytorch_model.bin (a plain state dict)
        device: "cpu" or "cuda"

    Returns:
        Loaded model in eval mode, moved to ``device``

    Raises:
        RuntimeError: If the checkpoint's keys/shapes do not match the model.
    """
    model = OutfitTransformerCIR(
        embedding_dim=128,
        nhead=16,
        num_layers=6,
        use_projection=True
    )

    # weights_only=True restricts the unpickler to tensors/primitive
    # containers — safe for state dicts, blocks pickle-based code execution.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    return model