File size: 3,322 Bytes
4c1c394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02c5c77
4c1c394
3529287
4c1c394
 
 
3529287
4c1c394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Lazy loaders for the CLIP model and GPU backbone+SAE runner.

Both are loaded at most once, on first use.  Neither is required:
  - CLIP is needed only for free-text feature search
  - The GPU runner is needed only for live patch-activation inference
"""

import os
import sys
import threading

import numpy as np
import torch

from .args import args
from .state import _all_datasets


# ---------- CLIP ----------

_clip_handle = None   # (model, processor, device) once loaded
# Serializes the one-time load so that concurrent first free-text queries cannot
# each load their own copy of CLIP.
_clip_lock = threading.Lock()


def get_clip():
    """Return ``(model, processor, device)``, loading CLIP on the first call.

    The model named by ``args.clip_model`` is loaded with
    ``clip_utils.load_clip`` onto ``cuda:0`` when CUDA is available (CPU
    otherwise) and cached for every later call.

    Concurrent first calls are serialized, so the model is loaded at most once.
    If loading raises, the exception propagates, the cache stays empty, and the
    next call tries again.
    """
    global _clip_handle
    if _clip_handle is not None:      # fast path: already loaded, no locking needed
        return _clip_handle
    with _clip_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if _clip_handle is None:
            src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
            # Prepend only once; a retry after a failed load would otherwise
            # stack duplicate sys.path entries.
            if src_dir not in sys.path:
                sys.path.insert(0, src_dir)
            from clip_utils import load_clip
            dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            print(f"[CLIP] Loading {args.clip_model} on {dev} (first free-text query)...")
            model, processor = load_clip(dev, model_name=args.clip_model)
            _clip_handle = (model, processor, dev)
            print("[CLIP] Ready.")
    return _clip_handle


# ---------- GPU backbone + SAE ----------

_gpu_runner = None   # (fwd_fn, sae, transform_fn, n_reg, extract_fn, backbone_name, device)
# Serializes the one-time load. warmup_gpu_runner() calls get_gpu_runner() from a
# background thread, so a request arriving mid-warmup would otherwise start a
# second, concurrent load of the backbone and SAE onto the same device.
_gpu_runner_lock = threading.Lock()


def get_gpu_runner():
    """Return the cached runner tuple, loading backbone + SAE on the first call.

    Returns:
        ``(fwd, sae, tfm, n_reg, extract_tokens, backbone_name, device)``, or
        ``None`` when ``args.sae_path`` is empty (live inference disabled).

    Exceptions raised while importing or loading (for example from
    ``load_batched_backbone`` or ``load_sae``) propagate to the caller and
    leave the runner unloaded, so a later call retries. Concurrent callers
    are serialized, so the models are loaded at most once.
    """
    global _gpu_runner
    if _gpu_runner is not None:          # fast path: already loaded
        return _gpu_runner
    if not args.sae_path:
        return None

    with _gpu_runner_lock:
        if _gpu_runner is not None:      # another thread finished the load while we waited
            return _gpu_runner

        src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
        # Prepend only once; retries after a failed load would otherwise stack duplicates.
        if src_dir not in sys.path:
            sys.path.insert(0, src_dir)
        from backbone_runners import load_batched_backbone
        from precompute_utils import load_sae, extract_tokens

        dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        d_model = _all_datasets[0]['d_model']   # SAE output dim fixed to primary dataset
        print(f"[GPU runner] Loading {args.backbone} layer {args.layer} + SAE on {dev} ...")
        fwd, d_hidden, n_reg, tfm = load_batched_backbone(args.backbone, args.layer, dev)
        sae = load_sae(args.sae_path, d_hidden, d_model, args.top_k, dev)
        _gpu_runner = (fwd, sae, tfm, n_reg, extract_tokens, args.backbone, dev)
        print("[GPU runner] Ready.")
    return _gpu_runner


def run_gpu_inference(pil_img) -> np.ndarray | None:
    """Run pil_img through backbone→SAE; return (n_patches, d_sae) float32 or None."""
    runner = get_gpu_runner()
    if runner is None:
        return None
    fwd, sae, tfm, n_reg, extract_tokens, backbone_name, dev = runner
    tensor = tfm(pil_img).unsqueeze(0).to(dev)
    with torch.inference_mode():
        hidden = fwd(tensor)
        tokens = extract_tokens(hidden, backbone_name, 'spatial', n_reg)
        flat   = tokens.reshape(-1, tokens.shape[-1])
        _, z, _ = sae(flat)
        print(f"[GPU runner] z shape={z.shape}, "
              f"nonzero={int((z > 0).sum())}, max={float(z.max()):.4f}")
    return z.cpu().float().numpy()


def warmup_gpu_runner():
    """Kick off loading the GPU runner on a daemon thread.

    Does nothing when ``args.sae_path`` is empty. Otherwise it starts a
    background thread that calls ``get_gpu_runner()``, so that the first patch
    request does not pay the load cost. Any exception from that call is caught
    and printed rather than propagated.
    """
    import threading
    if not args.sae_path:
        return

    def _load_in_background():
        try:
            get_gpu_runner()
        except Exception as err:
            # Best-effort warmup: report the failure and let the thread end.
            print(f"[GPU runner] Warmup failed: {err}")

    worker = threading.Thread(target=_load_in_background, daemon=True)
    worker.start()