""" Lazy loaders for the CLIP model and GPU backbone+SAE runner. Both are loaded at most once, on first use. Neither is required: - CLIP is needed only for free-text feature search - The GPU runner is needed only for live patch-activation inference """ import os import sys import numpy as np import torch from .args import args from .state import _all_datasets # ---------- CLIP ---------- _clip_handle = None # (model, processor, device) once loaded def get_clip(): """Return (model, processor, device), loading CLIP on first call.""" global _clip_handle if _clip_handle is None: sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src')) from clip_utils import load_clip dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(f"[CLIP] Loading {args.clip_model} on {dev} (first free-text query)...") model, processor = load_clip(dev, model_name=args.clip_model) _clip_handle = (model, processor, dev) print("[CLIP] Ready.") return _clip_handle # ---------- GPU backbone + SAE ---------- _gpu_runner = None # (fwd_fn, sae, transform_fn, n_reg, extract_fn, backbone_name, device) def get_gpu_runner(): """Return the runner tuple, loading on first call. Returns None if unavailable.""" global _gpu_runner if _gpu_runner is not None: return _gpu_runner if not args.sae_path: return None src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'src')) sys.path.insert(0, src_dir) from backbone_runners import load_batched_backbone from precompute_utils import load_sae, extract_tokens dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") d_model = _all_datasets[0]['d_model'] # SAE output dim fixed to primary dataset print(f"[GPU runner] Loading {args.backbone} layer {args.layer} + SAE on {dev} ...") fwd, d_hidden, n_reg, tfm = load_batched_backbone(args.backbone, args.layer, dev) sae = load_sae(args.sae_path, d_hidden, d_model, args.top_k, dev) _gpu_runner = (fwd, sae, tfm, n_reg, extract_tokens, args.backbone, dev) print("[GPU runner] Ready.") return _gpu_runner def run_gpu_inference(pil_img) -> np.ndarray | None: """Run pil_img through backbone→SAE; return (n_patches, d_sae) float32 or None.""" runner = get_gpu_runner() if runner is None: return None fwd, sae, tfm, n_reg, extract_tokens, backbone_name, dev = runner tensor = tfm(pil_img).unsqueeze(0).to(dev) with torch.inference_mode(): hidden = fwd(tensor) tokens = extract_tokens(hidden, backbone_name, 'spatial', n_reg) flat = tokens.reshape(-1, tokens.shape[-1]) _, z, _ = sae(flat) print(f"[GPU runner] z shape={z.shape}, " f"nonzero={int((z > 0).sum())}, max={float(z.max()):.4f}") return z.cpu().float().numpy() def warmup_gpu_runner(): """Load the GPU runner in a background thread so the first patch request is fast.""" import threading if args.sae_path: def _warmup(): try: get_gpu_runner() except Exception as e: print(f"[GPU runner] Warmup failed: {e}") threading.Thread(target=_warmup, daemon=True).start()