AbstractPhil committed on
Commit 41cc7d5 · verified · 1 Parent(s): dccfee9

Create model_manager.py

Files changed (1)
  1. model_manager.py +918 -0
model_manager.py ADDED
@@ -0,0 +1,918 @@
"""
Pentachora batch generation and model creation.
Assumes vocab is already loaded as 'vocab'.
Assumes PentachoronStabilizer is already loaded.
Assumes BaselineViT is already loaded.
"""

import torch
import numpy as np

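# The surrounding notebook is assumed to define get_default_device(); the
# fallback below is a sketch so this module also runs standalone.
try:
    get_default_device  # probe for a notebook-defined helper
except NameError:
    def get_default_device():
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
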
# CIFAR-100 class names
CIFAR100_CLASSES = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
    'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
    'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
    'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
    'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
]

#config = {
#    'head_type': 'roseface',        # 'roseface' | 'legacy'
#    'prototype_mode': 'centroid',   # 'centroid' | 'rose5' | 'max_vertex'
#    'margin_type': 'cosface',       # 'arcface' | 'cosface' | 'sphereface'
#    'margin_m': 0.30,
#    'scale_s': 30.0,
#    'apply_margin_train_only': False,
#    'norm_type': 'l1',              # 'l1' | 'l2' normalization
#    'similarity_mode': 'rose',      # legacy
#}

# Model variant configurations
MODEL_CONFIGS = {
    # Ultra-light

    'vit_beatrix_shaper': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 16,
        'num_heads': 8,
        'mlp_ratio': 1.0,
        #'norm_type': 'l1',
        'margin_type': 'cosface',
        'margin_m': 0.30,
        'scale_s': 30.0,
    },
    'vit_beatrix_arc_shaper': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 16,
        'num_heads': 8,
        'mlp_ratio': 2.0,
        #'norm_type': 'l1',
        'margin_type': 'arcface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_nano_arc': {
        'embed_dim': 64,
        'vocab_dim': 64,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,
        #'norm_type': 'l1',
        'margin_type': 'arcface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_nano_cos': {
        'embed_dim': 64,
        'vocab_dim': 64,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,
        #'norm_type': 'l1',
        'margin_type': 'cosface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_nano_128_cos': {
        'embed_dim': 128,
        'vocab_dim': 128,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,
        #'norm_type': 'l1',
        'margin_type': 'cosface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_mini_cos': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,
        #'norm_type': 'l1',
        'margin_type': 'cosface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_mini_cos_large_margin': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,
        #'norm_type': 'l1',
        'margin_type': 'cosface',
        'margin_m': 0.7086,
        'scale_s': 30.0,
    },
    'vit_zana_nano': {
        'embed_dim': 128,
        'vocab_dim': 128,
        'depth': 4,
        'num_heads': 2,
        'mlp_ratio': 2.0
    },
    'vit_beatrix_base_cos': {
        'embed_dim': 512,
        'vocab_dim': 512,
        'depth': 25,
        'num_heads': 16,
        'mlp_ratio': 8.0,
        #'norm_type': 'l1',
        'margin_type': 'cosface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_zana_nano_deep': {
        'embed_dim': 128,
        'vocab_dim': 128,
        'depth': 8,
        'num_heads': 4,
        'mlp_ratio': 2.0
    },
    'vit_zana_shaper': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 32,
        'num_heads': 8,
        'mlp_ratio': 4.0
    },
    'vit_zana_nano_thicc': {
        'embed_dim': 128,
        'vocab_dim': 128,
        'depth': 4,
        'num_heads': 8,
        'mlp_ratio': 4.0
    },
    'vit_zana_micro': {
        'embed_dim': 500,
        'vocab_dim': 25,
        'depth': 6,
        'num_heads': 2,
        'mlp_ratio': 2.0
    },
    'vit_zana_micro_500': {
        'embed_dim': 500,
        'vocab_dim': 25,
        'depth': 6,
        'num_heads': 5,
        'mlp_ratio': 2.0
    },

    'vit_zana_base': {
        'embed_dim': 512,
        'vocab_dim': 512,
        'depth': 16,
        'num_heads': 4,
        'mlp_ratio': 4.0
    },
    'vit_ursula_nano_1000': {
        'embed_dim': 1000,
        'vocab_dim': 500,
        'depth': 4,
        'num_heads': 50,
        'mlp_ratio': 4.0
    },
    'vit_ursula_nano': {
        'embed_dim': 1000,
        'vocab_dim': 25,
        'depth': 4,
        'num_heads': 10,
        'mlp_ratio': 4.0
    },

    # Lightweight
    'tiny': {
        'embed_dim': 192,
        'vocab_dim': 192,
        'depth': 12,
        'num_heads': 3,
        'mlp_ratio': 4.0
    },

    'vit_ursula_mini': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 12,
        'num_heads': 4,
        'mlp_ratio': 4.0
    },

    # Standard
    'small': {
        'embed_dim': 384,
        'vocab_dim': 384,
        'depth': 12,
        'num_heads': 6,
        'mlp_ratio': 4.0
    },

    'base': {
        'embed_dim': 768,
        'vocab_dim': 768,
        'depth': 12,
        'num_heads': 12,
        'mlp_ratio': 4.0
    },

    # Experimental
    'wide_shallow': {
        'embed_dim': 1024,
        'vocab_dim': 1024,
        'depth': 4,
        'num_heads': 16,
        'mlp_ratio': 2.0
    },

    'narrow_deep': {
        'embed_dim': 192,
        'vocab_dim': 192,
        'depth': 24,
        'num_heads': 3,
        'mlp_ratio': 4.0
    },
}
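
# A rough per-variant parameter estimate (a sketch: it counts only the
# attention and MLP weights per block, ignoring patch/positional embeddings,
# norms, and the head, so read it as a lower bound):
def estimate_transformer_params(variant: str) -> int:
    cfg = MODEL_CONFIGS[variant]
    d, depth, r = cfg['embed_dim'], cfg['depth'], cfg['mlp_ratio']
    # Per block: QKV + output projection (~4*d^2) plus MLP up/down (~2*r*d^2)
    return int(depth * (4 * d * d + 2 * r * d * d))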


# Updated pentachora batch generation and model creation for the L1 norm;
# the build_model function below already incorporates this modification.

def build_model(variant='small', **override_params):
    """
    Build a model with explicit parameter handling - no hidden kwargs.

    Args:
        variant: Model variant name from MODEL_CONFIGS
        **override_params: Individual parameter overrides

    Returns:
        model: BaselineViT model with frozen pentachora
    """
    assert variant in MODEL_CONFIGS, f"Unknown variant: {variant}. Choose from {list(MODEL_CONFIGS.keys())}"
    base_config = MODEL_CONFIGS[variant].copy()

    # EXPLICIT parameter extraction with defaults
    # Core architecture parameters
    embed_dim = override_params.get('embed_dim', base_config.get('embed_dim', 512))
    vocab_dim = override_params.get('vocab_dim', base_config.get('vocab_dim', 512))
    depth = override_params.get('depth', base_config.get('depth', 12))
    num_heads = override_params.get('num_heads', base_config.get('num_heads', 8))
    mlp_ratio = override_params.get('mlp_ratio', base_config.get('mlp_ratio', 4.0))

    # Image and patch parameters
    img_size = override_params.get('img_size', base_config.get('img_size', 32))
    patch_size = override_params.get('patch_size', base_config.get('patch_size', 4))

    # Regularization parameters
    dropout = override_params.get('dropout', base_config.get('dropout', 0.0))
    attn_dropout = override_params.get('attn_dropout', base_config.get('attn_dropout', 0.0))

    # Pentachora geometry parameters
    similarity_mode = override_params.get('similarity_mode', base_config.get('similarity_mode', 'rose'))
    norm_type = override_params.get('norm_type', base_config.get('norm_type', 'l1'))

    # RoseFace head parameters
    head_type = override_params.get('head_type', base_config.get('head_type', 'roseface'))
    prototype_mode = override_params.get('prototype_mode', base_config.get('prototype_mode', 'centroid'))
    margin_type = override_params.get('margin_type', base_config.get('margin_type', 'cosface'))
    margin_m = float(override_params.get('margin_m', base_config.get('margin_m', 0.30)))
    scale_s = float(override_params.get('scale_s', base_config.get('scale_s', 30.0)))
    apply_margin_train_only = override_params.get('apply_margin_train_only',
                                                  base_config.get('apply_margin_train_only', False))

    # Dataset configuration
    num_classes = len(CIFAR100_CLASSES)

    # Print what we're building
    print(f"Building {variant}:")
    print(f"  Architecture: embed={embed_dim}, vocab={vocab_dim}, depth={depth}, heads={num_heads}")
    print(f"  Image: {img_size}x{img_size}, patch={patch_size}x{patch_size}")
    print(f"  RoseFace: {margin_type}, m={margin_m:.4f}, s={scale_s:.1f}")
    print(f"  Norm: {norm_type}, Similarity: {similarity_mode}")

    # Generate pentachora from vocab
    print(f"Generating {num_classes} pentachora from vocabulary...")
    class_names = CIFAR100_CLASSES[:num_classes]

    # vocab.encode_batch returns List[np.ndarray] where each is (5, vocab_dim)
    pentachora_np_list = vocab.encode_batch(class_names, generate=True)

    # Convert to torch tensors
    raw_penta_list = [torch.tensor(penta, dtype=torch.float32) for penta in pentachora_np_list]

    # Resize pentachora whose feature dimension does not match vocab_dim
    pentachora_list = []
    device = get_default_device()
    for i, penta in enumerate(raw_penta_list):
        current_dim = penta.shape[-1]
        if current_dim != vocab_dim:
            # Linear interpolation along the feature axis, vectorized over all
            # 5 vertices at once; the same math handles both down- and upsampling.
            indices = torch.linspace(0, current_dim - 1, vocab_dim)
            left_idx = indices.floor().long()
            right_idx = (left_idx + 1).clamp(max=current_dim - 1)
            alpha = (indices - left_idx.float()).unsqueeze(0)  # (1, vocab_dim)
            penta_resized = penta[:, left_idx] * (1 - alpha) + penta[:, right_idx] * alpha
            if i == 0:  # Only print once
                direction = "Downsampling" if current_dim > vocab_dim else "Upsampling"
                print(f"  {direction} pentachora from {current_dim} to {vocab_dim}")
            pentachora_list.append(penta_resized.to(device))
        else:
            pentachora_list.append(penta.detach().clone().to(device))

    print(f"Using {num_classes} {norm_type.upper()}-normalized pentachora")

    # Create model with EXPLICIT parameters - no **kwargs
    model = BaselineViT(
        pentachora_list=pentachora_list,
        vocab_dim=vocab_dim,
        img_size=img_size,
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        dropout=dropout,
        attn_dropout=attn_dropout,
        similarity_mode=similarity_mode,
        norm_type=norm_type,
        head_type=head_type,
        prototype_mode=prototype_mode,
        margin_type=margin_type,
        margin_m=margin_m,
        scale_s=scale_s,
        apply_margin_train_only=apply_margin_train_only
    )

    # Store complete config for checkpoint saving
    model.config = {
        'variant': variant,
        'vocab_dim': vocab_dim,
        'embed_dim': embed_dim,
        'depth': depth,
        'num_heads': num_heads,
        'mlp_ratio': mlp_ratio,
        'img_size': img_size,
        'patch_size': patch_size,
        'dropout': dropout,
        'attn_dropout': attn_dropout,
        'similarity_mode': similarity_mode,
        'norm_type': norm_type,
        'head_type': head_type,
        'prototype_mode': prototype_mode,
        'margin_type': margin_type,
        'margin_m': margin_m,
        'scale_s': scale_s,
        'apply_margin_train_only': apply_margin_train_only,
        'num_classes': num_classes,
    }

    # Print model statistics
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params

    # After creating model, before returning
    print("\nDiagnostic: Checking pentachora status...")
    for i, penta in enumerate(model.class_pentachora[:3]):  # Check first 3
        print(f"Pentachora {i}:")
        print(f"  vertices requires_grad: {penta.vertices.requires_grad}")
        print(f"  vertices mean: {penta.vertices.mean().item():.6f}")
        print(f"  vertices std: {penta.vertices.std().item():.6f}")

    # Check a main model parameter
    print("\nMain model parameters:")
    if hasattr(model, 'patch_embed'):
        print(f"  patch_embed.weight mean: {model.patch_embed.weight.mean().item():.6f}")
        print(f"  patch_embed.weight std: {model.patch_embed.weight.std().item():.6f}")

    print(f"\nModel: {variant}")
    print(f"  Classes: {num_classes}")
    print(f"  Normalization: {norm_type.upper()}")
    print(f"  Total params: {total_params:,}")
    print(f"  Trainable params: {trainable_params:,}")
    print(f"  Frozen pentachora params: {frozen_params:,}")

    return model

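# Example usage (a sketch; any key of the chosen MODEL_CONFIGS entry can be
# overridden per call, and 'vocab' / BaselineViT must be in scope per the
# module docstring):
#   model = build_model('vit_beatrix_shaper', margin_m=0.25, dropout=0.1)
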
# =========================
# Minimal load/save helpers
# =========================
import json
from pathlib import Path

try:
    from safetensors.torch import save_file, load_file
except Exception as e:
    raise RuntimeError("safetensors is required: pip install safetensors") from e

def _get_device():
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def _jsonify_obj(obj) -> dict:
    """Turn a config object or dict into a JSON-safe dict."""
    if obj is None:
        return {}
    if isinstance(obj, dict):
        return obj
    out = {}
    for k in dir(obj):
        if k.startswith('_'):
            continue
        v = getattr(obj, k)
        if callable(v):
            continue
        if isinstance(v, torch.Tensor):
            v = v.tolist()
        elif isinstance(v, np.ndarray):
            v = v.tolist()
        out[k] = v
    return out

def _ensure_model_config_dict(model):
    """Guarantee model.config is a dict describing the head + geometry relevant fields."""
    if hasattr(model, "config") and isinstance(model.config, dict):
        return model.config
    cfg = {
        "arch": type(model).__name__,
        "num_classes": getattr(model, "num_classes", None),
        "embed_dim": getattr(model, "embed_dim", None),
        "pentachora_dim": getattr(model, "pentachora_dim", None),
        "img_size": getattr(model, "img_size", 32),
        "patch_size": getattr(model, "patch_size", 4),
        "norm_type": getattr(model, "norm_type", None),
        "similarity_mode": getattr(model, "similarity_mode", None),
        "head_type": getattr(model, "head_type", None),
        "prototype_mode": getattr(model, "prototype_mode", None),
        "margin_type": getattr(model, "margin_type", None),
        "margin_m": float(getattr(model, "margin_m", 0.0)) if hasattr(model, "margin_m") else None,
        "scale_s": float(getattr(model, "scale_s", 1.0)) if hasattr(model, "scale_s") else None,
    }
    model.config = cfg
    return cfg

def _collect_state_tensors(state_dict):
    return {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}

def _session_dir(paths: dict) -> Path:
    root = Path(paths["save_dir"])
    return root / f"{paths['model_variant']}_{paths['session_timestamp']}"

def _find_local_checkpoint(paths: dict) -> tuple[Path | None, Path | None, Path | None]:
    """
    Return (weights_path, model_config_path, vocab_path) from the session dir.
    Prefer 'best_*.safetensors'; fall back to the most recent '*.safetensors'.
    """
    sdir = _session_dir(paths)
    if not sdir.exists():
        return None, None, None
    safes = sorted(sdir.glob("*.safetensors"), key=lambda p: p.stat().st_mtime)
    if not safes:
        return None, None, None
    # Prefer 'best_' if present
    bests = [p for p in safes if p.name.startswith("best_")]
    w = bests[-1] if bests else safes[-1]
    model_cfg = sdir / w.name.replace(".safetensors", "_model_config.json")
    vocab = sdir / w.name.replace(".safetensors", "_vocabulary.json")
    return w, (model_cfg if model_cfg.exists() else None), (vocab if vocab.exists() else None)

def _load_saved_vocabulary(vocab_json_path: Path) -> list[torch.Tensor]:
    """Return a list of [5, D] tensors from saved crystal JSON."""
    with open(vocab_json_path, "r") as f:
        data = json.load(f)
    crystals = data.get("crystal_to_token", [])
    # crystals[i]['crystal'] is a [5, D] nested list
    penta_list = []
    for item in crystals:
        arr = torch.tensor(item["crystal"], dtype=torch.float32)
        penta_list.append(arr)
    return penta_list

# =========================================
# SAVE: weights + model/training/vocabulary
# =========================================
def save_existing_model(
    model,
    paths: dict,
    model_config=None,
    training_config=None,
    *,
    filename_base: str | None = None,
    save_vocabulary: bool = True,
    push_to_hub: bool | None = None
):
    """
    Save the model to disk, and optionally upload to the HF Hub.

    Args:
        model: BaselineViT instance
        paths: {
            'save_dir': str,
            'model_variant': str,
            'session_timestamp': str,
            # (optional, for naming)
            'epoch': int,
            'val_acc': float,
            'is_best': bool,
            # hub
            'hub_repo': str,
            'hub_token': str|None,
        }
        model_config: dict or object (optional; if None, built from model)
        training_config: TrainingConfig or dict (optional; saved to JSON)
        filename_base: override the base filename; if None, derived from epoch/acc/best
        save_vocabulary: write *_vocabulary.json from model.class_pentachora
        push_to_hub: override paths.get('push_to_hub')
    """
    sess_dir = _session_dir(paths)
    sess_dir.mkdir(parents=True, exist_ok=True)

    # ---- filename base
    if filename_base is None:
        ep = paths.get("epoch")
        acc = paths.get("val_acc")
        is_best = bool(paths.get("is_best", False))
        tag = f"epoch{int(ep):03d}_acc{float(acc):.2f}" if (ep is not None and acc is not None) else "snapshot"
        filename_base = f"{'best_' if is_best else 'checkpoint_'}{tag}"

    # ---- weights
    weights_path = sess_dir / f"{filename_base}.safetensors"
    state = _collect_state_tensors(model.state_dict())
    save_file(state, str(weights_path))

    # ---- model config
    cfg_dict = _jsonify_obj(model_config) or _ensure_model_config_dict(model)
    model_cfg_path = sess_dir / f"{filename_base}_model_config.json"
    with open(model_cfg_path, "w") as f:
        json.dump(cfg_dict, f, indent=2, default=str)

    # ---- training config (metadata)
    if training_config is not None:
        train_cfg_dict = _jsonify_obj(training_config)
        train_cfg_path = sess_dir / f"{filename_base}_training_config.json"
        with open(train_cfg_path, "w") as f:
            json.dump(train_cfg_dict, f, indent=2, default=str)
    else:
        train_cfg_path = None

    # ---- vocabulary
    vocab_path = None
    if save_vocabulary and hasattr(model, "class_pentachora") and model.class_pentachora is not None:
        crystals = torch.stack([p.vertices for p in model.class_pentachora], dim=0).detach().cpu().numpy().tolist()
        vocab_data = {
            "vocab_dim": getattr(model, "pentachora_dim", None),
            "num_classes": len(model.class_pentachora),
            "num_vertices": 5,
            "tokens": CIFAR100_CLASSES[: len(crystals)],
            "crystal_to_token": [
                {"index": i, "token": CIFAR100_CLASSES[i], "crystal": crystals[i]}
                for i in range(len(crystals))
            ],
        }
        vocab_path = sess_dir / f"{filename_base}_vocabulary.json"
        with open(vocab_path, "w") as f:
            json.dump(vocab_data, f, indent=2)

    print(f"✓ Saved weights: {weights_path.name}")
    print(f"✓ Saved model config: {model_cfg_path.name}")
    if train_cfg_path:
        print(f"✓ Saved training config: {train_cfg_path.name}")
    if vocab_path:
        print(f"✓ Saved vocabulary: {vocab_path.name}")

    # ---- optional hub upload
    do_push = push_to_hub if push_to_hub is not None else paths.get("push_to_hub", False)
    if do_push:
        try:
            from huggingface_hub import HfApi, create_repo
            hub_repo = paths["hub_repo"]
            hub_token = paths.get("hub_token")
            subfolder = f"models/{paths['model_variant']}/{paths['session_timestamp']}"

            api = HfApi(token=hub_token)
            try:
                create_repo(hub_repo, token=hub_token, private=True, exist_ok=True)
            except Exception:
                pass

            def _up(p: Path):
                api.upload_file(
                    path_or_fileobj=str(p),
                    path_in_repo=f"{subfolder}/{p.name}",
                    repo_id=hub_repo,
                    repo_type="model"
                )

            _up(weights_path); _up(model_cfg_path)
            if train_cfg_path: _up(train_cfg_path)
            if vocab_path: _up(vocab_path)
            print(f"✓ Pushed to hub: {hub_repo}/{subfolder}")
        except Exception as e:
            print(f"⚠ Hub upload failed: {e}")

    return {
        "weights": weights_path,
        "model_config": model_cfg_path,
        "training_config": train_cfg_path,
        "vocabulary": vocab_path,
        "session_dir": sess_dir
    }
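
# Example usage (a sketch; all values are placeholders):
#   paths = {
#       'save_dir': './checkpoints',
#       'model_variant': 'vit_beatrix_shaper',
#       'session_timestamp': '20250101_120000',
#       'epoch': 10, 'val_acc': 30.30, 'is_best': True,
#   }
#   artifacts = save_existing_model(model, paths)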

# =========================================
# LOAD: from disk or hub subfolder
# =========================================
def load_existing_model(
    model_path: str | Path | None,
    paths: dict | None,
    model_config=None,
    training_config=None,
    *,
    from_hub: bool = False,
    prefer_best: bool = True,
    map_location: str | torch.device | None = None
):
    """
    Load a saved model (weights + config), reconstruct the architecture via build_model,
    and return a ready-to-use model. If a saved vocabulary is present, reuse it.

    Args:
        model_path: explicit path to a .safetensors file; if None, resolve from `paths`
        paths: {
            'save_dir': str, 'model_variant': str, 'session_timestamp': str,
            # (for hub)
            'hub_repo': str, 'hub_token': str|None
        }
        from_hub: if True, pull from HF Hub subfolder models/{variant}/{session}/
        prefer_best: when scanning a folder, pick 'best_*.safetensors' if available
        map_location: optional torch map_location

    Returns:
        model (on default device), resolved_paths dict
    """
    device = _get_device() if map_location is None else map_location

    # ---------- resolve source files ----------
    if model_path is not None:
        weights_path = Path(model_path)
        base = weights_path.name.replace(".safetensors", "")
        session_dir = weights_path.parent
        model_cfg_path = session_dir / f"{base}_model_config.json"
        vocab_path = session_dir / f"{base}_vocabulary.json"
    elif from_hub:
        try:
            from huggingface_hub import hf_hub_download
        except Exception as e:
            raise RuntimeError("huggingface_hub is required for from_hub=True") from e
        # The repo file list is not enumerated here, so an exact filename is
        # required; extend with list_repo_files if automatic resolution is needed.
        raise RuntimeError(
            "When loading from Hub, please supply the explicit .safetensors filename in model_path "
            "(e.g., '.../best_epoch010_acc30.30.safetensors') or download locally first."
        )
    else:
        # resolve from local session dir
        weights_path, model_cfg_path, vocab_path = _find_local_checkpoint(paths)
        if weights_path is None:
            raise FileNotFoundError("No checkpoint found in session folder")

    # ---------- read model config ----------
    # Prefer the on-disk config; else use the provided model_config
    if model_cfg_path and model_cfg_path.exists():
        with open(model_cfg_path, "r") as f:
            cfg = json.load(f)
    else:
        cfg = _jsonify_obj(model_config)

    # variant + overrides to rebuild the model
    variant = cfg.get("variant", paths.get("model_variant") if paths else None)
    if variant is None:
        raise ValueError("Model variant not found in config; pass paths['model_variant'] or include 'variant'.")

    overrides = {}
    # allow restoring head settings if present
    for k in ("embed_dim", "vocab_dim", "depth", "num_heads", "mlp_ratio",
              "img_size", "patch_size", "dropout", "attn_dropout",
              "norm_type", "similarity_mode",
              "head_type", "prototype_mode", "margin_type", "margin_m", "scale_s",
              "apply_margin_train_only"):
        if k in cfg and cfg[k] is not None:
            overrides[k] = cfg[k]

    # ---------- rebuild model via the factory ----------
    # build_model uses vocab.encode_batch; if a saved vocabulary JSON exists,
    # the exact saved pentachora are swapped in afterwards.
    model = build_model(variant=variant, **overrides).to(device)

    # if a vocabulary JSON exists, replace model.class_pentachora with saved crystals
    if vocab_path and vocab_path.exists():
        saved_penta = _load_saved_vocabulary(vocab_path)  # list of [5, D]
        if hasattr(model, "class_pentachora") and len(saved_penta) == len(model.class_pentachora):
            # swap in the exact saved pentachora
            import torch.nn as nn
            new_list = []
            for p in saved_penta:
                new_list.append(type(model.class_pentachora[0])(p, norm_type=getattr(model, "norm_type", "l1")))
            # rebuild the ModuleList; the constructor refreshes normalized buffers
            model.class_pentachora = nn.ModuleList(new_list)

    # ---------- load weights ----------
    sd = load_file(str(weights_path), device='cpu')
    print(f"\nCheckpoint contains {len(sd)} keys")
    print(f"First 5 keys: {list(sd.keys())[:5]}")

    # Check for the torch.compile prefix
    has_orig_mod = any(k.startswith("_orig_mod.") for k in sd.keys())
    if has_orig_mod:
        print("Detected compiled model checkpoint (_orig_mod. prefix)")

    # Strip _orig_mod. if present
    fixed = {}
    for k, v in sd.items():
        new_key = k[len("_orig_mod."):] if k.startswith("_orig_mod.") else k
        fixed[new_key] = v

    # Get model state dict for comparison
    model_state = model.state_dict()
    print(f"\nModel expects {len(model_state)} keys")
    print(f"First 5 expected: {list(model_state.keys())[:5]}")

    # Find mismatches
    checkpoint_keys = set(fixed.keys())
    model_keys = set(model_state.keys())

    missing_in_checkpoint = model_keys - checkpoint_keys
    unexpected_in_checkpoint = checkpoint_keys - model_keys

    print(f"\nKeys in model but not in checkpoint: {len(missing_in_checkpoint)}")
    if missing_in_checkpoint and len(missing_in_checkpoint) < 10:
        print(f"  Missing: {list(missing_in_checkpoint)[:10]}")

    print(f"Keys in checkpoint but not in model: {len(unexpected_in_checkpoint)}")
    if unexpected_in_checkpoint and len(unexpected_in_checkpoint) < 10:
        print(f"  Unexpected: {list(unexpected_in_checkpoint)[:10]}")

    # Load with strict=True first to surface the actual error
    try:
        model.load_state_dict(fixed, strict=True)
        print("✓ Strict load successful - all weights loaded")
    except RuntimeError as e:
        print(f"⚠ Strict load failed: {e}")
        # Fall back to non-strict
        incompatible = model.load_state_dict(fixed, strict=False)
        print("Loaded with strict=False")
        print(f"  Missing keys: {len(incompatible.missing_keys)}")
        print(f"  Unexpected keys: {len(incompatible.unexpected_keys)}")

        # Check if critical weights are missing
        critical_missing = [k for k in incompatible.missing_keys if 'weight' in k or 'bias' in k]
        if critical_missing:
            print(f"  ⚠ Critical missing weights: {critical_missing[:5]}")

    # Verify weights aren't zero
    sample_weight = next(iter(model.parameters()))
    print("\nFirst parameter stats:")
    print(f"  Shape: {sample_weight.shape}")
    print(f"  Mean: {sample_weight.mean().item():.6f}")
    print(f"  Std: {sample_weight.std().item():.6f}")
    print(f"  Min: {sample_weight.min().item():.6f}")
    print(f"  Max: {sample_weight.max().item():.6f}")

    model.eval()
    return model, {
        "weights": weights_path,
        "model_config": model_cfg_path,
        "vocabulary": vocab_path,
        "session_dir": weights_path.parent
    }
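
# Example usage (a sketch; the checkpoint path is a placeholder):
#   model, resolved = load_existing_model(
#       './checkpoints/vit_beatrix_shaper_20250101_120000/best_epoch010_acc30.30.safetensors',
#       paths=None,
#   )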

def get_parameter_groups(model, weight_decay):
    """Create parameter groups with weight decay handling."""
    no_decay = ['bias', 'LayerNorm.weight', 'norm']
    params_decay = []
    params_no_decay = []

    for name, param in model.named_parameters():
        if param.requires_grad:
            if any(nd in name for nd in no_decay):
                params_no_decay.append(param)
            else:
                params_decay.append(param)

    return [
        {'params': params_decay, 'weight_decay': weight_decay},
        {'params': params_no_decay, 'weight_decay': 0.0}
    ]
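
# Example (a sketch; lr and weight_decay values are placeholders):
#   optimizer = torch.optim.AdamW(get_parameter_groups(model, weight_decay=0.05), lr=3e-4)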

def create_scheduler(optimizer, config, start_epoch=0):
    """Create a cosine LR scheduler with linear warmup."""
    def lr_lambda(epoch):
        if epoch < config.warmup_epochs:
            return epoch / max(1, config.warmup_epochs)
        if config.epochs <= config.warmup_epochs:
            return 1.0
        return 0.5 * (1 + np.cos(np.pi * (epoch - config.warmup_epochs) /
                                 (config.epochs - config.warmup_epochs)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Fast-forward to the correct epoch if resuming
    for _ in range(start_epoch):
        scheduler.step()

    return scheduler
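
# Example (a sketch): config only needs 'warmup_epochs' and 'epochs' attributes.
#   from types import SimpleNamespace
#   sched = create_scheduler(optimizer, SimpleNamespace(warmup_epochs=5, epochs=100))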

def count_parameters(model):
    """Count model parameters."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'total': total, 'trainable': trainable}

# Test loading
if __name__ == "__main__":
    print("Testing model loader...")
    print("=" * 50)

    # Test-build a small model
    model = build_model('vit_beatrix_shaper').to(get_default_device())
    # model, resolved = load_existing_model(...)  # optionally resume from a checkpoint

    # Test forward pass
    x = torch.randn(4, 3, 32, 32).to(get_default_device())
    output = model(x)

    print("\nForward pass successful!")
    print(f"  Input shape: {x.shape}")
    print(f"  Logits shape: {output['logits'].shape}")
    print(f"  Similarities shape: {output['similarities'].shape}")

    print("\n✓ Model loader working correctly!")