| """ |
| Lazy loaders for the CLIP model and GPU backbone+SAE runner. |
| |
| Both are loaded at most once, on first use. Neither is required: |
| - CLIP is needed only for free-text feature search |
| - The GPU runner is needed only for live patch-activation inference |
| """ |
|
|
import os
import sys
import threading

import numpy as np
import torch

from .args import args
from .state import _all_datasets

# Cached CLIP handle: (model, processor, device), populated on first use.
_clip_handle = None
def get_clip():
    """Return (model, processor, device), loading CLIP on first call."""
    global _clip_handle
    if _clip_handle is None:
        # Make the repo's src/ directory importable for the local import below.
        sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
        from clip_utils import load_clip
        dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"[CLIP] Loading {args.clip_model} on {dev} (first free-text query)...")
        model, processor = load_clip(dev, model_name=args.clip_model)
        _clip_handle = (model, processor, dev)
        print("[CLIP] Ready.")
    return _clip_handle
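# Hedged example of a free-text query against the cached handle. This assumes
# load_clip returns Hugging Face-style (CLIPModel, CLIPProcessor) objects,
# which this module does not itself guarantee:
#
#     model, processor, dev = get_clip()
#     inputs = processor(text=["a red fire truck"], return_tensors="pt").to(dev)
#     with torch.inference_mode():
#         text_emb = model.get_text_features(**inputs)  # (1, d_clip)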
# Cached GPU runner tuple:
# (fwd, sae, tfm, n_reg, extract_tokens, backbone_name, device).
_gpu_runner = None
# Guards loading so the warmup thread and request handlers cannot load twice.
_gpu_lock = threading.Lock()
def get_gpu_runner():
    """Return the runner tuple, loading it on first call. Returns None if args.sae_path is unset."""
    global _gpu_runner
    if _gpu_runner is not None:
        return _gpu_runner
    if not args.sae_path:
        return None

    with _gpu_lock:
        # Re-check under the lock: the warmup thread may have finished loading
        # while this caller was waiting.
        if _gpu_runner is not None:
            return _gpu_runner

        src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
        sys.path.insert(0, src_dir)
        from backbone_runners import load_batched_backbone
        from precompute_utils import load_sae, extract_tokens

        dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        d_model = _all_datasets[0]['d_model']
        print(f"[GPU runner] Loading {args.backbone} layer {args.layer} + SAE on {dev} ...")
        fwd, d_hidden, n_reg, tfm = load_batched_backbone(args.backbone, args.layer, dev)
        sae = load_sae(args.sae_path, d_hidden, d_model, args.top_k, dev)
        _gpu_runner = (fwd, sae, tfm, n_reg, extract_tokens, args.backbone, dev)
        print("[GPU runner] Ready.")
    return _gpu_runner
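# Callers treat a None runner as "live inference disabled" rather than as an
# error. A hypothetical endpoint might do:
#
#     runner = get_gpu_runner()
#     if runner is None:
#         return {"error": "no SAE configured"}, 503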
def run_gpu_inference(pil_img) -> np.ndarray | None:
    """Run pil_img through backbone→SAE; return (n_patches, d_sae) float32, or None if no runner."""
    runner = get_gpu_runner()
    if runner is None:
        return None
    fwd, sae, tfm, n_reg, extract_tokens, backbone_name, dev = runner
    tensor = tfm(pil_img).unsqueeze(0).to(dev)  # (1, C, H, W) on the target device
    with torch.inference_mode():
        hidden = fwd(tensor)
        # Keep only spatial patch tokens (CLS/register tokens are dropped).
        tokens = extract_tokens(hidden, backbone_name, 'spatial', n_reg)
        flat = tokens.reshape(-1, tokens.shape[-1])  # (n_patches, d_hidden)
        _, z, _ = sae(flat)  # z: sparse feature activations
    print(f"[GPU runner] z shape={z.shape}, "
          f"nonzero={int((z > 0).sum())}, max={float(z.max()):.4f}")
    return z.cpu().float().numpy()
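# Hedged usage sketch (`img_path` is hypothetical):
#
#     from PIL import Image
#     acts = run_gpu_inference(Image.open(img_path).convert("RGB"))
#     if acts is not None:
#         top = np.argsort(acts.max(axis=0))[::-1][:10]  # ten strongest features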
def warmup_gpu_runner():
    """Load the GPU runner in a background thread so the first patch request is fast."""
    if not args.sae_path:
        return
    def _warmup():
        try:
            get_gpu_runner()
        except Exception as e:
            print(f"[GPU runner] Warmup failed: {e}")
    threading.Thread(target=_warmup, daemon=True).start()
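# Hedged wiring sketch for server startup (`app` is a hypothetical Flask-style
# object; the point is only that warmup is kicked off before the first request):
#
#     warmup_gpu_runner()  # returns immediately; loading continues in background
#     app.run()            # early patch requests simply block in get_gpu_runner()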