ReignShad0 committed
Commit a836d31 · verified · 1 Parent(s): 879b44d

Update app/utils/model_loader.py


The dangling "def predict" stub at the end of the file is not a valid definition, so importing the module raises a SyntaxError when attempting to load the .safetensors model instead of the default ONNX model. This commit comments the stub out.
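For context, a minimal sketch of the failure mode (illustrative only, not code from this repo): a bare "def predict" with no signature or body is rejected by Python's parser, so the module fails to compile at import time, before any weights are ever read.

# Minimal sketch (illustrative): a bare "def predict" is a SyntaxError,
# so the module cannot even be compiled, let alone load model weights.
broken_source = "x = 1\ndef predict"    # no signature, no body
try:
    compile(broken_source, "model_loader.py", "exec")
except SyntaxError as exc:
    print(f"SyntaxError: {exc.msg} (line {exc.lineno})")

# The fix applied here, commenting the stub out, compiles cleanly:
compile("x = 1\n# def predict", "model_loader.py", "exec")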

Files changed (1)
  1. app/utils/model_loader.py +379 -379
app/utils/model_loader.py CHANGED
@@ -1,379 +1,379 @@
- import torch
- import torch.nn as nn
- from torch.nn import GroupNorm, LayerNorm
- import torch.nn.functional as F
- import torch.utils.checkpoint as checkpoint
- import timm
-
- class ViTWrapper(nn.Module):
-     """Wrapper to make ViT compatible with feature extraction for ImageTagger"""
-     def __init__(self, vit_model):
-         super().__init__()
-         self.vit = vit_model
-         self.out_indices = (-1,)  # mimic timm.features_only
-
-         # Get patch size and embedding dim from the model
-         self.patch_size = vit_model.patch_embed.patch_size[0]
-         self.embed_dim = vit_model.embed_dim
-
-     def forward(self, x):
-         B = x.size(0)
-
-         # ➊ patch tokens
-         x = self.vit.patch_embed(x)  # (B, N, C)
-
-         # ➋ prepend CLS
-         cls_tok = self.vit.cls_token.expand(B, -1, -1)  # (B, 1, C)
-         x = torch.cat((cls_tok, x), dim=1)  # (B, 1+N, C)
-
-         # ➌ add positional encodings (full, incl. CLS)
-         if self.vit.pos_embed is not None:
-             x = x + self.vit.pos_embed[:, : x.size(1), :]
-
-         x = self.vit.pos_drop(x)
-
-         for blk in self.vit.blocks:
-             x = blk(x)
-
-         x = self.vit.norm(x)  # (B, 1+N, C)
-
-         # ➍ split back out
-         cls_final = x[:, 0]  # (B, C)
-         patch_tokens = x[:, 1:]  # (B, N, C)
-
-         # ➎ reshape patches to (B, C, H, W)
-         B, N, C = patch_tokens.shape
-         h = w = int(N ** 0.5)  # square assumption
-         patch_features = patch_tokens.permute(0, 2, 1).reshape(B, C, h, w)
-
-         # Return **both**: (patch map, CLS)
-         return patch_features, cls_final
-
-     def set_grad_checkpointing(self, enable=True):
-         """Enable gradient checkpointing if supported"""
-         if hasattr(self.vit, 'set_grad_checkpointing'):
-             self.vit.set_grad_checkpointing(enable)
-             return True
-         return False
-
- class ImageTagger(nn.Module):
-     """
-     ImageTagger with Vision Transformer backbone
-     """
-     def __init__(self, total_tags, dataset, model_name='vit_base_patch16_224',
-                  num_heads=16, dropout=0.1, pretrained=True, tag_context_size=256,
-                  use_gradient_checkpointing=False, img_size=224):
-         super().__init__()
-
-         # Store checkpointing config
-         self.use_gradient_checkpointing = use_gradient_checkpointing
-         self.model_name = model_name
-         self.img_size = img_size
-
-         # Debug and stats flags
-         self._flags = {
-             'debug': False,
-             'model_stats': True
-         }
-
-         # Core model config
-         self.dataset = dataset
-         self.tag_context_size = tag_context_size
-         self.total_tags = total_tags
-
-         print(f"🏗️ Building ImageTagger with ViT backbone and {total_tags} tags")
-         print(f" Backbone: {model_name}")
-         print(f" Image size: {img_size}x{img_size}")
-         print(f" Tag context size: {tag_context_size}")
-         print(f" Gradient checkpointing: {use_gradient_checkpointing}")
-         print(f" 🎯 Custom embeddings, PyTorch native attention, no ground truth inclusion")
-
-         # 1. Vision Transformer Backbone
-         print("📦 Loading Vision Transformer backbone...")
-         self._load_vit_backbone()
-
-         # Get backbone dimensions by running a test forward pass
-         self._determine_backbone_dimensions()
-
-         self.embedding_dim = self.backbone.embed_dim
-
-         # 2. Custom Tag Embeddings (no CLIP)
-         print("🎯 Using custom tag embeddings (no CLIP)")
-         self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
-
-         # 3. Shared weights approach - tag bias for initial predictions
-         print("🔗 Using shared weights between initial head and tag embeddings")
-         self.tag_bias = nn.Parameter(torch.zeros(total_tags))
-
-
-         # 4. Image token extraction (for attention AND global pooling)
-         self.image_token_proj = nn.Identity()
-
-         # 5. Tags-as-queries cross-attention (using PyTorch's optimized implementation)
-         self.cross_attention = nn.MultiheadAttention(
-             embed_dim=self.embedding_dim,
-             num_heads=num_heads,
-             dropout=dropout,
-             batch_first=True  # Use (batch, seq, feature) format
-         )
-         self.cross_norm = nn.LayerNorm(self.embedding_dim)
-
-         # Initialize weights
-         self._init_weights()
-
-         # Enable gradient checkpointing
-         if self.use_gradient_checkpointing:
-             self._enable_gradient_checkpointing()
-
-         print(f"✅ ImageTagger with ViT initialized!")
-         self._print_parameter_count()
-
-     def _load_vit_backbone(self):
-         """Load Vision Transformer model from timm"""
-         print(f" Loading from timm: {self.model_name}")
-
-         # Load the ViT model (not features_only, we want the full model for token extraction)
-         vit_model = timm.create_model(
-             self.model_name,
-             pretrained=True,
-             img_size=self.img_size,
-             num_classes=0  # Remove classification head
-         )
-
-         # Wrap it in our compatibility layer
-         self.backbone = ViTWrapper(vit_model)
-
-         print(f" ✅ ViT loaded successfully")
-         print(f" Patch size: {self.backbone.patch_size}x{self.backbone.patch_size}")
-         print(f" Embed dim: {self.backbone.embed_dim}")
-
-     def _determine_backbone_dimensions(self):
-         """Determine backbone output dimensions"""
-         print(" 🔍 Determining backbone dimensions...")
-
-         with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
-             # Create a dummy input
-             dummy_input = torch.randn(1, 3, self.img_size, self.img_size)
-
-             # Get features
-             backbone_features, cls_dummy = self.backbone(dummy_input)
-             feature_tensor = backbone_features
-
-             self.backbone_dim = feature_tensor.shape[1]
-             self.feature_map_size = feature_tensor.shape[2]
-
-             print(f" Backbone output: {self.backbone_dim}D, {self.feature_map_size}x{self.feature_map_size} spatial")
-             print(f" Total patch tokens: {self.feature_map_size * self.feature_map_size}")
-
-     def _enable_gradient_checkpointing(self):
-         """Enable gradient checkpointing for memory efficiency"""
-         print("🔄 Enabling gradient checkpointing...")
-
-         # Enable checkpointing for ViT backbone
-         if self.backbone.set_grad_checkpointing(True):
-             print(" ✅ ViT backbone checkpointing enabled")
-         else:
-             print(" ⚠️ ViT backbone doesn't support built-in checkpointing, will checkpoint manually")
-
-     def _checkpoint_backbone(self, x):
-         """Wrapper for backbone with gradient checkpointing"""
-         if self.use_gradient_checkpointing and self.training:
-             return checkpoint.checkpoint(self.backbone, x, use_reentrant=False)
-         else:
-             return self.backbone(x)
-
-     def _checkpoint_image_proj(self, x):
-         """Wrapper for image projection with gradient checkpointing"""
-         if self.use_gradient_checkpointing and self.training:
-             return checkpoint.checkpoint(self.image_token_proj, x, use_reentrant=False)
-         else:
-             return self.image_token_proj(x)
-
-     def _checkpoint_cross_attention(self, query, key, value):
-         """Wrapper for cross attention with gradient checkpointing"""
-         def _attention_forward(q, k, v):
-             attended_features, _ = self.cross_attention(query=q, key=k, value=v)
-             return self.cross_norm(attended_features)
-
-         if self.use_gradient_checkpointing and self.training:
-             return checkpoint.checkpoint(_attention_forward, query, key, value, use_reentrant=False)
-         else:
-             return _attention_forward(query, key, value)
-
-     def _checkpoint_candidate_selection(self, initial_logits):
-         """Wrapper for candidate selection with gradient checkpointing"""
-         def _candidate_forward(logits):
-             return self._get_candidate_tags(logits)
-
-         if self.use_gradient_checkpointing and self.training:
-             return checkpoint.checkpoint(_candidate_forward, initial_logits, use_reentrant=False)
-         else:
-             return _candidate_forward(initial_logits)
-
-     def _checkpoint_final_scoring(self, attended_features, candidate_indices):
-         """Wrapper for final scoring with gradient checkpointing"""
-         def _scoring_forward(features, indices):
-             emb = self.tag_embedding(indices)
-             # BF16 in, BF16 out
-             return (features * emb).sum(dim=-1)
-
-         if self.use_gradient_checkpointing and self.training:
-             return checkpoint.checkpoint(_scoring_forward, attended_features, candidate_indices, use_reentrant=False)
-         else:
-             return _scoring_forward(attended_features, candidate_indices)
-
-     def _init_weights(self):
-         """Initialize weights for new modules"""
-         def _init_layer(layer):
-             if isinstance(layer, nn.Linear):
-                 nn.init.xavier_uniform_(layer.weight)
-                 if layer.bias is not None:
-                     nn.init.zeros_(layer.bias)
-             elif isinstance(layer, nn.Conv2d):
-                 nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
-                 if layer.bias is not None:
-                     nn.init.zeros_(layer.bias)
-             elif isinstance(layer, nn.Embedding):
-                 nn.init.normal_(layer.weight, mean=0, std=0.02)
-
-         # Initialize new components
-         self.image_token_proj.apply(_init_layer)
-
-         # Initialize tag embeddings with normal distribution
-         nn.init.normal_(self.tag_embedding.weight, mean=0, std=0.02)
-
-         # Initialize tag bias
-         nn.init.zeros_(self.tag_bias)
-
-     def _print_parameter_count(self):
-         """Print parameter statistics"""
-         total_params = sum(p.numel() for p in self.parameters())
-         trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
-         backbone_params = sum(p.numel() for p in self.backbone.parameters())
-
-         print(f"📊 Parameter Statistics:")
-         print(f" Total parameters: {total_params/1e6:.1f}M")
-         print(f" Trainable parameters: {trainable_params/1e6:.1f}M")
-         print(f" Frozen parameters: {(total_params-trainable_params)/1e6:.1f}M")
-         print(f" Backbone parameters: {backbone_params/1e6:.1f}M")
-
-         if self.use_gradient_checkpointing:
-             print(f" 🔄 Gradient checkpointing enabled for memory efficiency")
-
-     @property
-     def debug(self):
-         return self._flags['debug']
-
-     @property
-     def model_stats(self):
-         return self._flags['model_stats']
-
-     def _get_candidate_tags(self, initial_logits, target_tags=None, hard_negatives=None):
-         """Select candidate tags - no ground truth inclusion"""
-         batch_size = initial_logits.size(0)
-
-         # Simply select top K candidates based on initial predictions
-         top_probs, top_indices = torch.topk(
-             torch.sigmoid(initial_logits),
-             k=min(self.tag_context_size, self.total_tags),
-             dim=1, largest=True, sorted=True
-         )
-
-         return top_indices
-
-     def _analyze_predictions(self, predictions, tag_indices):
-         """Analyze prediction patterns"""
-         if not self.model_stats:
-             return {}
-
-         if torch._dynamo.is_compiling():
-             return {}
-
-         with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
-             probs = torch.sigmoid(predictions)
-             relevant_probs = torch.gather(probs, 1, tag_indices)
-
-             return {
-                 'prediction_confidence': relevant_probs.mean().item(),
-                 'prediction_entropy': -(relevant_probs * torch.log(relevant_probs + 1e-9)).mean().item(),
-                 'high_confidence_ratio': (relevant_probs > 0.7).float().mean().item(),
-                 'above_threshold_ratio': (relevant_probs > 0.5).float().mean().item(),
-             }
-
-     def forward(self, x, targets=None, hard_negatives=None):
-         """
-         Forward pass with ViT backbone, CLS token support and gradient-checkpointing.
-         All arithmetic tensors stay in the backbone's dtype (BF16 under autocast,
-         FP32 otherwise). Anything that must mix dtypes is cast to match.
-         """
-         batch_size = x.size(0)
-         model_stats = {} if self.model_stats else {}
-
-         # ------------------------------------------------------------------
-         # 1. Backbone → patch map + CLS token
-         # ------------------------------------------------------------------
-         patch_map, cls_token = self._checkpoint_backbone(x)  # patch_map: [B, C, H, W]
-                                                              # cls_token: [B, C]
-
-         # ------------------------------------------------------------------
-         # 2. Tokens → global image vector
-         # ------------------------------------------------------------------
-         image_tokens_4d = self._checkpoint_image_proj(patch_map)  # [B, C, H, W]
-         image_tokens = image_tokens_4d.flatten(2).transpose(1, 2)  # [B, N, C]
-
-         # "Dual-pool": mean-pool patches ⊕ CLS
-         global_features = 0.5 * (image_tokens.mean(dim=1, dtype=image_tokens.dtype) + cls_token)  # [B, C]
-
-         compute_dtype = global_features.dtype  # BF16 or FP32
-
-         # ------------------------------------------------------------------
-         # 3. Initial logits (shared weights)
-         # ------------------------------------------------------------------
-         tag_weights = self.tag_embedding.weight.to(compute_dtype)  # [T, C]
-         tag_bias = self.tag_bias.to(compute_dtype)  # [T]
-
-         initial_logits = global_features @ tag_weights.t() + tag_bias  # [B, T]
-         initial_logits = initial_logits.to(compute_dtype)  # keep dtype uniform
-         initial_preds = initial_logits  # alias
-
-         # ------------------------------------------------------------------
-         # 4. Candidate set
-         # ------------------------------------------------------------------
-         candidate_indices = self._checkpoint_candidate_selection(initial_logits)  # [B, K]
-
-         tag_embeddings = self.tag_embedding(candidate_indices).to(compute_dtype)  # [B, K, C]
-
-         attended_features = self._checkpoint_cross_attention(  # [B, K, C]
-             tag_embeddings, image_tokens, image_tokens
-         )
-
-         # ------------------------------------------------------------------
-         # 5. Score candidates & scatter back
-         # ------------------------------------------------------------------
-         candidate_logits = self._checkpoint_final_scoring(attended_features, candidate_indices)  # [B, K]
-
-         # --- align dtypes so scatter never throws ---
-         if candidate_logits.dtype != initial_logits.dtype:
-             candidate_logits = candidate_logits.to(initial_logits.dtype)
-
-         refined_logits = initial_logits.clone()
-         refined_logits.scatter_(1, candidate_indices, candidate_logits)
-         refined_preds = refined_logits
-
-         # ------------------------------------------------------------------
-         # 6. Optional stats
-         # ------------------------------------------------------------------
-         if self.model_stats and targets is not None and not torch._dynamo.is_compiling():
-             model_stats['initial_prediction_stats'] = self._analyze_predictions(initial_preds,
-                                                                                 candidate_indices)
-             model_stats['refined_prediction_stats'] = self._analyze_predictions(refined_preds,
-                                                                                 candidate_indices)
-
-         return {
-             'initial_predictions': initial_preds,
-             'refined_predictions': refined_preds,
-             'selected_candidates': candidate_indices,
-             'model_stats': model_stats
-         }
-
-     def predict
 
+ import torch
+ import torch.nn as nn
+ from torch.nn import GroupNorm, LayerNorm
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+ import timm
+
+ class ViTWrapper(nn.Module):
+     """Wrapper to make ViT compatible with feature extraction for ImageTagger"""
+     def __init__(self, vit_model):
+         super().__init__()
+         self.vit = vit_model
+         self.out_indices = (-1,)  # mimic timm.features_only
+
+         # Get patch size and embedding dim from the model
+         self.patch_size = vit_model.patch_embed.patch_size[0]
+         self.embed_dim = vit_model.embed_dim
+
+     def forward(self, x):
+         B = x.size(0)
+
+         # ➊ patch tokens
+         x = self.vit.patch_embed(x)  # (B, N, C)
+
+         # ➋ prepend CLS
+         cls_tok = self.vit.cls_token.expand(B, -1, -1)  # (B, 1, C)
+         x = torch.cat((cls_tok, x), dim=1)  # (B, 1+N, C)
+
+         # ➌ add positional encodings (full, incl. CLS)
+         if self.vit.pos_embed is not None:
+             x = x + self.vit.pos_embed[:, : x.size(1), :]
+
+         x = self.vit.pos_drop(x)
+
+         for blk in self.vit.blocks:
+             x = blk(x)
+
+         x = self.vit.norm(x)  # (B, 1+N, C)
+
+         # ➍ split back out
+         cls_final = x[:, 0]  # (B, C)
+         patch_tokens = x[:, 1:]  # (B, N, C)
+
+         # ➎ reshape patches to (B, C, H, W)
+         B, N, C = patch_tokens.shape
+         h = w = int(N ** 0.5)  # square assumption
+         patch_features = patch_tokens.permute(0, 2, 1).reshape(B, C, h, w)
+
+         # Return **both**: (patch map, CLS)
+         return patch_features, cls_final
+
+     def set_grad_checkpointing(self, enable=True):
+         """Enable gradient checkpointing if supported"""
+         if hasattr(self.vit, 'set_grad_checkpointing'):
+             self.vit.set_grad_checkpointing(enable)
+             return True
+         return False
+
+ class ImageTagger(nn.Module):
+     """
+     ImageTagger with Vision Transformer backbone
+     """
+     def __init__(self, total_tags, dataset, model_name='vit_base_patch16_224',
+                  num_heads=16, dropout=0.1, pretrained=True, tag_context_size=256,
+                  use_gradient_checkpointing=False, img_size=224):
+         super().__init__()
+
+         # Store checkpointing config
+         self.use_gradient_checkpointing = use_gradient_checkpointing
+         self.model_name = model_name
+         self.img_size = img_size
+
+         # Debug and stats flags
+         self._flags = {
+             'debug': False,
+             'model_stats': True
+         }
+
+         # Core model config
+         self.dataset = dataset
+         self.tag_context_size = tag_context_size
+         self.total_tags = total_tags
+
+         print(f"🏗️ Building ImageTagger with ViT backbone and {total_tags} tags")
+         print(f" Backbone: {model_name}")
+         print(f" Image size: {img_size}x{img_size}")
+         print(f" Tag context size: {tag_context_size}")
+         print(f" Gradient checkpointing: {use_gradient_checkpointing}")
+         print(f" 🎯 Custom embeddings, PyTorch native attention, no ground truth inclusion")
+
+         # 1. Vision Transformer Backbone
+         print("📦 Loading Vision Transformer backbone...")
+         self._load_vit_backbone()
+
+         # Get backbone dimensions by running a test forward pass
+         self._determine_backbone_dimensions()
+
+         self.embedding_dim = self.backbone.embed_dim
+
+         # 2. Custom Tag Embeddings (no CLIP)
+         print("🎯 Using custom tag embeddings (no CLIP)")
+         self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
+
+         # 3. Shared weights approach - tag bias for initial predictions
+         print("🔗 Using shared weights between initial head and tag embeddings")
+         self.tag_bias = nn.Parameter(torch.zeros(total_tags))
+
+
+         # 4. Image token extraction (for attention AND global pooling)
+         self.image_token_proj = nn.Identity()
+
+         # 5. Tags-as-queries cross-attention (using PyTorch's optimized implementation)
+         self.cross_attention = nn.MultiheadAttention(
+             embed_dim=self.embedding_dim,
+             num_heads=num_heads,
+             dropout=dropout,
+             batch_first=True  # Use (batch, seq, feature) format
+         )
+         self.cross_norm = nn.LayerNorm(self.embedding_dim)
+
+         # Initialize weights
+         self._init_weights()
+
+         # Enable gradient checkpointing
+         if self.use_gradient_checkpointing:
+             self._enable_gradient_checkpointing()
+
+         print(f"✅ ImageTagger with ViT initialized!")
+         self._print_parameter_count()
+
+     def _load_vit_backbone(self):
+         """Load Vision Transformer model from timm"""
+         print(f" Loading from timm: {self.model_name}")
+
+         # Load the ViT model (not features_only, we want the full model for token extraction)
+         vit_model = timm.create_model(
+             self.model_name,
+             pretrained=True,
+             img_size=self.img_size,
+             num_classes=0  # Remove classification head
+         )
+
+         # Wrap it in our compatibility layer
+         self.backbone = ViTWrapper(vit_model)
+
+         print(f" ✅ ViT loaded successfully")
+         print(f" Patch size: {self.backbone.patch_size}x{self.backbone.patch_size}")
+         print(f" Embed dim: {self.backbone.embed_dim}")
+
+     def _determine_backbone_dimensions(self):
+         """Determine backbone output dimensions"""
+         print(" 🔍 Determining backbone dimensions...")
+
+         with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
+             # Create a dummy input
+             dummy_input = torch.randn(1, 3, self.img_size, self.img_size)
+
+             # Get features
+             backbone_features, cls_dummy = self.backbone(dummy_input)
+             feature_tensor = backbone_features
+
+             self.backbone_dim = feature_tensor.shape[1]
+             self.feature_map_size = feature_tensor.shape[2]
+
+             print(f" Backbone output: {self.backbone_dim}D, {self.feature_map_size}x{self.feature_map_size} spatial")
+             print(f" Total patch tokens: {self.feature_map_size * self.feature_map_size}")
+
+     def _enable_gradient_checkpointing(self):
+         """Enable gradient checkpointing for memory efficiency"""
+         print("🔄 Enabling gradient checkpointing...")
+
+         # Enable checkpointing for ViT backbone
+         if self.backbone.set_grad_checkpointing(True):
+             print(" ✅ ViT backbone checkpointing enabled")
+         else:
+             print(" ⚠️ ViT backbone doesn't support built-in checkpointing, will checkpoint manually")
+
+     def _checkpoint_backbone(self, x):
+         """Wrapper for backbone with gradient checkpointing"""
+         if self.use_gradient_checkpointing and self.training:
+             return checkpoint.checkpoint(self.backbone, x, use_reentrant=False)
+         else:
+             return self.backbone(x)
+
+     def _checkpoint_image_proj(self, x):
+         """Wrapper for image projection with gradient checkpointing"""
+         if self.use_gradient_checkpointing and self.training:
+             return checkpoint.checkpoint(self.image_token_proj, x, use_reentrant=False)
+         else:
+             return self.image_token_proj(x)
+
+     def _checkpoint_cross_attention(self, query, key, value):
+         """Wrapper for cross attention with gradient checkpointing"""
+         def _attention_forward(q, k, v):
+             attended_features, _ = self.cross_attention(query=q, key=k, value=v)
+             return self.cross_norm(attended_features)
+
+         if self.use_gradient_checkpointing and self.training:
+             return checkpoint.checkpoint(_attention_forward, query, key, value, use_reentrant=False)
+         else:
+             return _attention_forward(query, key, value)
+
+     def _checkpoint_candidate_selection(self, initial_logits):
+         """Wrapper for candidate selection with gradient checkpointing"""
+         def _candidate_forward(logits):
+             return self._get_candidate_tags(logits)
+
+         if self.use_gradient_checkpointing and self.training:
+             return checkpoint.checkpoint(_candidate_forward, initial_logits, use_reentrant=False)
+         else:
+             return _candidate_forward(initial_logits)
+
+     def _checkpoint_final_scoring(self, attended_features, candidate_indices):
+         """Wrapper for final scoring with gradient checkpointing"""
+         def _scoring_forward(features, indices):
+             emb = self.tag_embedding(indices)
+             # BF16 in, BF16 out
+             return (features * emb).sum(dim=-1)
+
+         if self.use_gradient_checkpointing and self.training:
+             return checkpoint.checkpoint(_scoring_forward, attended_features, candidate_indices, use_reentrant=False)
+         else:
+             return _scoring_forward(attended_features, candidate_indices)
+
+     def _init_weights(self):
+         """Initialize weights for new modules"""
+         def _init_layer(layer):
+             if isinstance(layer, nn.Linear):
+                 nn.init.xavier_uniform_(layer.weight)
+                 if layer.bias is not None:
+                     nn.init.zeros_(layer.bias)
+             elif isinstance(layer, nn.Conv2d):
+                 nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
+                 if layer.bias is not None:
+                     nn.init.zeros_(layer.bias)
+             elif isinstance(layer, nn.Embedding):
+                 nn.init.normal_(layer.weight, mean=0, std=0.02)
+
+         # Initialize new components
+         self.image_token_proj.apply(_init_layer)
+
+         # Initialize tag embeddings with normal distribution
+         nn.init.normal_(self.tag_embedding.weight, mean=0, std=0.02)
+
+         # Initialize tag bias
+         nn.init.zeros_(self.tag_bias)
+
+     def _print_parameter_count(self):
+         """Print parameter statistics"""
+         total_params = sum(p.numel() for p in self.parameters())
+         trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+         backbone_params = sum(p.numel() for p in self.backbone.parameters())
+
+         print(f"📊 Parameter Statistics:")
+         print(f" Total parameters: {total_params/1e6:.1f}M")
+         print(f" Trainable parameters: {trainable_params/1e6:.1f}M")
+         print(f" Frozen parameters: {(total_params-trainable_params)/1e6:.1f}M")
+         print(f" Backbone parameters: {backbone_params/1e6:.1f}M")
+
+         if self.use_gradient_checkpointing:
+             print(f" 🔄 Gradient checkpointing enabled for memory efficiency")
+
+     @property
+     def debug(self):
+         return self._flags['debug']
+
+     @property
+     def model_stats(self):
+         return self._flags['model_stats']
+
+     def _get_candidate_tags(self, initial_logits, target_tags=None, hard_negatives=None):
+         """Select candidate tags - no ground truth inclusion"""
+         batch_size = initial_logits.size(0)
+
+         # Simply select top K candidates based on initial predictions
+         top_probs, top_indices = torch.topk(
+             torch.sigmoid(initial_logits),
+             k=min(self.tag_context_size, self.total_tags),
+             dim=1, largest=True, sorted=True
+         )
+
+         return top_indices
+
+     def _analyze_predictions(self, predictions, tag_indices):
+         """Analyze prediction patterns"""
+         if not self.model_stats:
+             return {}
+
+         if torch._dynamo.is_compiling():
+             return {}
+
+         with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
+             probs = torch.sigmoid(predictions)
+             relevant_probs = torch.gather(probs, 1, tag_indices)
+
+             return {
+                 'prediction_confidence': relevant_probs.mean().item(),
+                 'prediction_entropy': -(relevant_probs * torch.log(relevant_probs + 1e-9)).mean().item(),
+                 'high_confidence_ratio': (relevant_probs > 0.7).float().mean().item(),
+                 'above_threshold_ratio': (relevant_probs > 0.5).float().mean().item(),
+             }
+
+     def forward(self, x, targets=None, hard_negatives=None):
+         """
+         Forward pass with ViT backbone, CLS token support and gradient-checkpointing.
+         All arithmetic tensors stay in the backbone's dtype (BF16 under autocast,
+         FP32 otherwise). Anything that must mix dtypes is cast to match.
+         """
+         batch_size = x.size(0)
+         model_stats = {} if self.model_stats else {}
+
+         # ------------------------------------------------------------------
+         # 1. Backbone → patch map + CLS token
+         # ------------------------------------------------------------------
+         patch_map, cls_token = self._checkpoint_backbone(x)  # patch_map: [B, C, H, W]
+                                                              # cls_token: [B, C]
+
+         # ------------------------------------------------------------------
+         # 2. Tokens → global image vector
+         # ------------------------------------------------------------------
+         image_tokens_4d = self._checkpoint_image_proj(patch_map)  # [B, C, H, W]
+         image_tokens = image_tokens_4d.flatten(2).transpose(1, 2)  # [B, N, C]
+
+         # "Dual-pool": mean-pool patches ⊕ CLS
+         global_features = 0.5 * (image_tokens.mean(dim=1, dtype=image_tokens.dtype) + cls_token)  # [B, C]
+
+         compute_dtype = global_features.dtype  # BF16 or FP32
+
+         # ------------------------------------------------------------------
+         # 3. Initial logits (shared weights)
+         # ------------------------------------------------------------------
+         tag_weights = self.tag_embedding.weight.to(compute_dtype)  # [T, C]
+         tag_bias = self.tag_bias.to(compute_dtype)  # [T]
+
+         initial_logits = global_features @ tag_weights.t() + tag_bias  # [B, T]
+         initial_logits = initial_logits.to(compute_dtype)  # keep dtype uniform
+         initial_preds = initial_logits  # alias
+
+         # ------------------------------------------------------------------
+         # 4. Candidate set
+         # ------------------------------------------------------------------
+         candidate_indices = self._checkpoint_candidate_selection(initial_logits)  # [B, K]
+
+         tag_embeddings = self.tag_embedding(candidate_indices).to(compute_dtype)  # [B, K, C]
+
+         attended_features = self._checkpoint_cross_attention(  # [B, K, C]
+             tag_embeddings, image_tokens, image_tokens
+         )
+
+         # ------------------------------------------------------------------
+         # 5. Score candidates & scatter back
+         # ------------------------------------------------------------------
+         candidate_logits = self._checkpoint_final_scoring(attended_features, candidate_indices)  # [B, K]
+
+         # --- align dtypes so scatter never throws ---
+         if candidate_logits.dtype != initial_logits.dtype:
+             candidate_logits = candidate_logits.to(initial_logits.dtype)
+
+         refined_logits = initial_logits.clone()
+         refined_logits.scatter_(1, candidate_indices, candidate_logits)
+         refined_preds = refined_logits
+
+         # ------------------------------------------------------------------
+         # 6. Optional stats
+         # ------------------------------------------------------------------
+         if self.model_stats and targets is not None and not torch._dynamo.is_compiling():
+             model_stats['initial_prediction_stats'] = self._analyze_predictions(initial_preds,
+                                                                                 candidate_indices)
+             model_stats['refined_prediction_stats'] = self._analyze_predictions(refined_preds,
+                                                                                 candidate_indices)
+
+         return {
+             'initial_predictions': initial_preds,
+             'refined_predictions': refined_preds,
+             'selected_candidates': candidate_indices,
+             'model_stats': model_stats
+         }
+
+     # def predict
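
The finished predict was never committed, so the stub stays commented out above. Purely as a hypothetical sketch of what such a method might look like for this model (idx_to_tag and threshold are illustrative names, not part of this repository), a thresholded inference helper over the dict returned by forward could be shaped like this:

import torch

# Hypothetical sketch only: the real predict() was never committed.
# idx_to_tag and threshold are illustrative names, not repo API.
@torch.no_grad()
def predict(model, x, idx_to_tag=None, threshold=0.5):
    """Return (tag, probability) pairs above threshold for each image in x."""
    model.eval()
    out = model(x)                                      # dict from ImageTagger.forward
    probs = torch.sigmoid(out['refined_predictions'])   # [B, total_tags]
    results = []
    for row in probs:
        keep = (row > threshold).nonzero(as_tuple=True)[0]
        if idx_to_tag is not None:
            results.append([(idx_to_tag[i.item()], row[i].item()) for i in keep])
        else:
            results.append([(i.item(), row[i].item()) for i in keep])
    return results

Until something along these lines lands, commenting the stub out keeps the module importable, which is all this commit needed to do.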