File size: 3,797 Bytes
275cef9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc58dce
275cef9
 
 
 
 
 
 
 
 
 
40b1944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275cef9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from __future__ import annotations
import os
import torch
import torch.nn as nn
import numpy as np

# Inlined so the Space has no dependency on the ofa package.
# Hub id of the base checkpoint; load_student uses it for both the tokenizer
# and the base causal-LM weights.
STUDENT_BASE = "Qwen/Qwen2.5-0.5B-Instruct"
# Hub repo from which load_student fetches the PEFT adapter and "gating.pt".
ADAPTER_REPO = "build-small-hackathon/deku"
# Input width of the gating network built in load_student.
STUDENT_HIDDEN_DIM = 896  # Qwen2.5-0.5B hidden size
# Output width of the gating network (one gate weight per teacher).
N_TEACHERS = 5


class GatingNetwork(nn.Module):
    """Linear projection followed by a softmax over teachers.

    Maps feature vectors whose last dimension is ``hidden_dim`` to
    ``n_teachers`` non-negative weights that sum to 1 along the last axis.
    """

    def __init__(self, hidden_dim: int, n_teachers: int):
        super().__init__()
        # The attribute name fixes the state-dict keys ("fc.weight", "fc.bias").
        self.fc = nn.Linear(in_features=hidden_dim, out_features=n_teachers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the softmax-normalised gate weights for ``x``."""
        logits = self.fc(x)
        return logits.softmax(dim=-1)


def _masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    m = mask.unsqueeze(-1).to(hidden.dtype)
    return (hidden * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)


def load_student(hf_token: str | None = None):
    """Load student (base + LoRA) and gating network from HF Hub.

    Args:
        hf_token: Hub access token. When falsy, the ``HF_TOKEN`` environment
            variable is used instead (which may also be unset).

    Returns:
        ``(tokenizer, student_model, gating_network)``. The student is put in
        eval mode and placed by ``device_map="auto"``. The gating network is
        loaded onto CPU, put in eval mode, and is not moved afterwards.
    """
    # Imported lazily so importing this module does not pull in the heavy
    # model libraries.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
    from huggingface_hub import hf_hub_download

    token = hf_token or os.environ.get("HF_TOKEN")

    tok = AutoTokenizer.from_pretrained(STUDENT_BASE)
    if tok.pad_token is None:
        # Fall back to EOS so padded batches can be built.
        tok.pad_token = tok.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        STUDENT_BASE,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        output_hidden_states=True,
    )
    student = PeftModel.from_pretrained(base, ADAPTER_REPO, token=token)
    student.eval()

    gating_path = hf_hub_download(
        repo_id=ADAPTER_REPO,
        filename="gating.pt",
        repo_type="model",
        token=token,
    )
    gating = GatingNetwork(STUDENT_HIDDEN_DIM, N_TEACHERS)
    # The checkpoint is a downloaded file: restrict unpickling to tensors and
    # plain containers so a tampered file cannot execute arbitrary code.
    state_dict = torch.load(gating_path, map_location="cpu", weights_only=True)
    gating.load_state_dict(state_dict)
    gating.eval()

    return tok, student, gating


def generate_response(
    text: str,
    student: nn.Module,
    tok,
    max_new_tokens: int = 200,
) -> str:
    """Greedy-decode the student's reply to ``text`` and return it as a string.

    The text is wrapped as a single user turn via the tokenizer's chat
    template. If templating fails for any reason, the raw text is used as the
    prompt instead. Only the newly generated tokens are decoded, with special
    tokens skipped.
    """
    target_device = next(student.parameters()).device
    messages = [{"role": "user", "content": text}]

    try:
        prompt = tok.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Best-effort fallback: use the raw text as the prompt.
        prompt = text

    batch = tok(prompt, return_tensors="pt").to(target_device)
    prompt_ids = batch["input_ids"]
    prompt_len = prompt_ids.shape[1]

    generation_kwargs = {
        "input_ids": prompt_ids,
        "attention_mask": batch["attention_mask"],
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "pad_token_id": tok.eos_token_id,
    }
    with torch.no_grad():
        sequences = student.generate(**generation_kwargs)

    # The returned sequence starts with the prompt; keep only what follows it.
    completion = sequences[0][prompt_len:]
    return tok.decode(completion, skip_special_tokens=True)


def run_probe(
    text: str,
    student: nn.Module,
    tok,
    gating: GatingNetwork,
    reducer,
    max_len: int = 512,
) -> tuple[dict, list[float]]:
    """Single student forward. Returns (new_umap_point, gate_weights).

    The last hidden layer is mean-pooled over the attention mask and cast to
    float32. That pooled vector is fed to ``gating`` and to
    ``reducer.transform``.

    Args:
        text: Input string; tokenized as a batch of one and truncated to
            ``max_len`` tokens.
        student: Model called with the tokenizer output plus
            ``output_hidden_states=True`` and ``use_cache=False``.
        tok: Tokenizer callable returning a mapping of tensors that includes
            ``"attention_mask"``.
        gating: Gating network mapping the pooled vector to teacher weights.
        reducer: Object with a ``transform`` method returning an array indexed
            as ``[0, 0..2]`` -- presumably a fitted 3-D UMAP reducer; verify
            against the caller.
        max_len: Truncation length passed to the tokenizer.

    Returns:
        new_umap_point : {"x": float, "y": float, "z": float, "label": str}
        gate_weights   : list of floats, one per gating output, summing to 1.0
    """
    enc = tok([text], return_tensors="pt", truncation=True, max_length=max_len)
    device = next(student.parameters()).device
    enc = {k: v.to(device) for k, v in enc.items()}

    with torch.no_grad():
        out = student(**enc, output_hidden_states=True, use_cache=False)
        pooled = _masked_mean(out.hidden_states[-1], enc["attention_mask"]).float()

        # The gating network may live on a different device than the student
        # (e.g. gate on CPU, student on an accelerator). Align the input with
        # the gate's parameters to avoid a device-mismatch error.
        gate_param = next(gating.parameters(), None)
        gate_input = pooled
        if gate_param is not None:
            gate_input = pooled.to(device=gate_param.device, dtype=gate_param.dtype)
        gate_weights: list[float] = gating(gate_input).squeeze(0).tolist()

    coords3d = reducer.transform(pooled.cpu().numpy())  # (1, 3)
    return {
        "x": float(coords3d[0, 0]),
        "y": float(coords3d[0, 1]),
        "z": float(coords3d[0, 2]),
        "label": "probe",
    }, gate_weights