AbstractPhil committed on
Commit
dbffff8
·
verified ·
1 Parent(s): ea86ebc

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +67 -18
model.py CHANGED
@@ -10,9 +10,9 @@ Input: (B, 8, 16, 16) β€” adapted latent patches
10
  Output: gate_vectors (B, 64, 17), patch_features (B, 64, 256), logits
11
 
12
  Usage:
13
- from geometric_model import SuperpositionPatchClassifier, load_from_hub
14
 
15
- model = load_from_hub() # downloads from AbstractPhil/geovocab-patch-maker
16
  out = model(patches)
17
 
18
  # Gate vectors: explicit geometric properties per patch
@@ -324,30 +324,66 @@ class SuperpositionPatchClassifier(nn.Module):
324
  # Hub Loading
325
  # ══════════════════════════════════════════════════════════════════════════════
326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  def load_from_hub(
328
  repo_id="AbstractPhil/geovocab-patch-maker",
329
- filename="model.pt",
 
330
  device="cuda" if torch.cuda.is_available() else "cpu",
331
  ):
332
- """Load pretrained model from HuggingFace Hub."""
 
 
 
 
 
333
  from huggingface_hub import hf_hub_download
334
 
335
- path = hf_hub_download(repo_id=repo_id, filename=filename)
336
- ckpt = torch.load(path, map_location=device, weights_only=False)
337
- cfg = ckpt["config"]
 
 
 
 
 
 
 
338
 
339
- model = SuperpositionPatchClassifier(
340
- embed_dim=cfg["embed_dim"],
341
- patch_dim=cfg["patch_dim"],
342
- n_bootstrap=cfg["n_bootstrap"],
343
- n_geometric=cfg["n_geometric"],
344
- n_heads=cfg["n_heads"],
345
- dropout=0.0,
346
- ).to(device).eval()
347
 
 
348
  model.load_state_dict(ckpt["model_state_dict"])
349
- print(f"βœ“ Loaded {repo_id} (epoch {ckpt.get('epoch', '?')})")
350
- return model
 
 
 
 
351
 
352
 
353
  @torch.no_grad()
@@ -395,6 +431,9 @@ def extract_features(model, patches, batch_size=256):
395
  # ══════════════════════════════════════════════════════════════════════════════
396
 
397
  if __name__ == "__main__":
 
 
 
398
  model = SuperpositionPatchClassifier()
399
  n_params = sum(p.numel() for p in model.parameters())
400
  print(f"SuperpositionPatchClassifier: {n_params:,} parameters")
@@ -406,4 +445,14 @@ if __name__ == "__main__":
406
  print(f" local_dim: {out['local_dim_logits'].shape}")
407
  print(f" struct_topo: {out['struct_topo_logits'].shape}")
408
  print(f" patch_shapes: {out['patch_shape_logits'].shape}")
409
- print(f" global_features: {out['global_features'].shape}")
 
 
 
 
 
 
 
 
 
 
 
10
  Output: gate_vectors (B, 64, 17), patch_features (B, 64, 256), logits
11
 
12
  Usage:
13
+ from geometric_model import load_from_hub, extract_features
14
 
15
+ model, config = load_from_hub() # reads config.json + model.pt from Hub
16
  out = model(patches)
17
 
18
  # Gate vectors: explicit geometric properties per patch
 
324
  # Hub Loading
325
  # ══════════════════════════════════════════════════════════════════════════════
326
 
327
def load_config(repo_id="AbstractPhil/geovocab-patch-maker", config_file="config.json"):
    """Download and parse the model's JSON config from the HuggingFace Hub.

    Args:
        repo_id: Hub repository to download from.
        config_file: Name of the JSON config file inside the repo.

    Returns:
        dict: the parsed configuration.
    """
    import json
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(repo_id=repo_id, filename=config_file)
    # Explicit encoding: JSON files on the Hub are UTF-8; the platform
    # default (e.g. cp1252 on Windows) could corrupt non-ASCII content.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
335
+
336
+
337
def from_config(config, device="cpu"):
    """Build a SuperpositionPatchClassifier from a config dict.

    Only the architecture is constructed — no pretrained weights are
    loaded. ``dropout`` falls back to 0.0 when absent from the config.
    """
    arch_kwargs = dict(
        embed_dim=config["embed_dim"],
        patch_dim=config["patch_dim"],
        n_bootstrap=config["n_bootstrap"],
        n_geometric=config["n_geometric"],
        n_heads=config["n_heads"],
        dropout=config.get("dropout", 0.0),
    )
    model = SuperpositionPatchClassifier(**arch_kwargs)
    return model.to(device)
347
+
348
+
349
def load_from_hub(
    repo_id="AbstractPhil/geovocab-patch-maker",
    weights_file="model.pt",
    config_file="config.json",
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Load a pretrained model (and its config) from the HuggingFace Hub.

    Architecture comes from ``config.json`` when present; otherwise the
    config dict embedded in the checkpoint is used.

    Args:
        repo_id: Hub repository to download from.
        weights_file: Checkpoint filename inside the repo.
        config_file: JSON config filename inside the repo.
        device: Target device. NOTE(review): the default is evaluated
            once at import time, not per call — intentional-looking,
            but confirm.

    Returns:
        (model, config): the model in eval mode on ``device`` and the
        config dict used to build it.

    Raises:
        KeyError: if no usable ``config_file`` exists on the Hub and the
            checkpoint has no embedded "config" entry.
    """
    from huggingface_hub import hf_hub_download

    # Prefer the standalone config.json. Any failure (missing file,
    # network error, malformed JSON) falls back to the checkpoint config,
    # so a broad Exception catch is deliberate here.
    try:
        config = load_config(repo_id, config_file)
        print(f"βœ“ Config loaded from {config_file}")
    except Exception:
        config = None

    # Download and deserialize the checkpoint.
    # NOTE(security): weights_only=False unpickles arbitrary objects —
    # only load checkpoints from repos you trust.
    weights_path = hf_hub_download(repo_id=repo_id, filename=weights_file)
    ckpt = torch.load(weights_path, map_location=device, weights_only=False)

    # Config priority: config.json > checkpoint config
    if config is None:
        if "config" not in ckpt:
            # Same exception type as the bare ckpt["config"] lookup,
            # but with an actionable message.
            raise KeyError(
                f"No {config_file} found in {repo_id} and checkpoint "
                f"{weights_file} has no embedded 'config' entry"
            )
        config = ckpt["config"]
        print(f" Config from checkpoint (no {config_file} found)")

    model = from_config(config, device=device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    epoch = ckpt.get("epoch", "?")
    n_params = sum(p.numel() for p in model.parameters())
    print(f"βœ“ Loaded {repo_id} (epoch {epoch}, {n_params:,} params)")
    return model, config
387
 
388
 
389
  @torch.no_grad()
 
431
  # ══════════════════════════════════════════════════════════════════════════════
432
 
433
  if __name__ == "__main__":
434
+ import json
435
+
436
+ # Test 1: Direct instantiation
437
  model = SuperpositionPatchClassifier()
438
  n_params = sum(p.numel() for p in model.parameters())
439
  print(f"SuperpositionPatchClassifier: {n_params:,} parameters")
 
445
  print(f" local_dim: {out['local_dim_logits'].shape}")
446
  print(f" struct_topo: {out['struct_topo_logits'].shape}")
447
  print(f" patch_shapes: {out['patch_shape_logits'].shape}")
448
+ print(f" global_features: {out['global_features'].shape}")
449
+
450
+ # Test 2: From config
451
+ import os
452
+ cfg_path = os.path.join(os.path.dirname(__file__), "config.json")
453
+ if os.path.exists(cfg_path):
454
+ with open(cfg_path) as f:
455
+ config = json.load(f)
456
+ model2 = from_config(config)
457
+ print(f"\n from_config: {sum(p.numel() for p in model2.parameters()):,} params")
458
+ print(f" config: {config['model_type']} embed={config['embed_dim']} patches={config['num_patches']}")