# one-for-all / _probe.py
# frankyy03's picture
# fix: iframe Three.js, model response output, UI labels
# 40b1944 verified
from __future__ import annotations
import os
import torch
import torch.nn as nn
import numpy as np
# Inlined so Space has no dependency on the ofa package.
# Hub id used for both the tokenizer and the base causal-LM in load_student().
STUDENT_BASE = "Qwen/Qwen2.5-0.5B-Instruct"
# Hub repo that provides the PEFT adapter and the "gating.pt" weights file.
ADAPTER_REPO = "build-small-hackathon/deku"
# Input width of the gating network's linear layer.
STUDENT_HIDDEN_DIM = 896  # Qwen2.5-0.5B hidden size
# Output width of the gating network, i.e. one gate weight per teacher.
N_TEACHERS = 5
class GatingNetwork(nn.Module):
    """Map a pooled hidden vector to a softmax distribution over teachers.

    A single linear projection from ``hidden_dim`` to ``n_teachers`` followed
    by a softmax over the last dimension, so each output row sums to 1.
    """

    def __init__(self, hidden_dim: int, n_teachers: int):
        super().__init__()
        # The attribute name ``fc`` determines the state_dict keys
        # ("fc.weight", "fc.bias"), so it must not be renamed.
        self.fc = nn.Linear(in_features=hidden_dim, out_features=n_teachers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-teacher weights for ``x`` (normalised over the last dim)."""
        logits = self.fc(x)
        return logits.softmax(dim=-1)
def _masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average ``hidden`` over dim 1, counting only positions where ``mask`` is set.

    ``mask`` is cast to ``hidden``'s dtype and broadcast over the trailing
    dimension. The divisor is clamped to at least 1, so a row whose mask is
    all zeros yields zeros instead of NaN.
    """
    weights = mask.to(hidden.dtype).unsqueeze(-1)
    summed = (hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1.0)
    return summed / counts
def load_student(hf_token: str | None = None):
    """Load student (base + LoRA) and gating network from HF Hub.

    Args:
        hf_token: Hugging Face access token. When omitted or empty, the
            ``HF_TOKEN`` environment variable is used instead (which may
            also be unset, giving anonymous access).

    Returns:
        ``(tokenizer, student_model, gating_network)``. The student and the
        gating network are both put in eval mode. The gating weights are
        loaded onto the CPU; the student is placed by ``device_map="auto"``.
    """
    # These third-party packages are imported at call time, not at module import.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
    from huggingface_hub import hf_hub_download

    token = hf_token or os.environ.get("HF_TOKEN")

    tok = AutoTokenizer.from_pretrained(STUDENT_BASE)
    if tok.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tok.pad_token = tok.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        STUDENT_BASE,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        output_hidden_states=True,
    )
    student = PeftModel.from_pretrained(base, ADAPTER_REPO, token=token)
    student.eval()

    gating_path = hf_hub_download(
        repo_id=ADAPTER_REPO,
        filename="gating.pt",
        repo_type="model",
        token=token,
    )
    gating = GatingNetwork(STUDENT_HIDDEN_DIM, N_TEACHERS)
    # The file is downloaded content and is only used as a state dict, so
    # restrict unpickling to tensors/containers rather than arbitrary objects.
    state_dict = torch.load(gating_path, map_location="cpu", weights_only=True)
    gating.load_state_dict(state_dict)
    gating.eval()

    return tok, student, gating
def generate_response(
    text: str,
    student: nn.Module,
    tok,
    max_new_tokens: int = 200,
) -> str:
    """Generate the student's answer to ``text`` and return it decoded.

    The text is wrapped as a single user message via the tokenizer's chat
    template; if templating raises, the raw text is used as the prompt.
    Decoding is greedy (``do_sample=False``) and only the tokens produced
    after the prompt are decoded, with special tokens skipped.
    """
    target_device = next(student.parameters()).device

    messages = [{"role": "user", "content": text}]
    try:
        prompt = tok.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Best-effort fallback: feed the raw text if templating fails.
        prompt = text

    inputs = tok(prompt, return_tensors="pt").to(target_device)
    prompt_len = inputs["input_ids"].shape[1]

    generate_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tok.eos_token_id,
    )
    with torch.no_grad():
        sequences = student.generate(**generate_kwargs)

    # Drop the echoed prompt so only the newly generated tokens are decoded.
    completion_ids = sequences[0][prompt_len:]
    return tok.decode(completion_ids, skip_special_tokens=True)
def run_probe(
    text: str,
    student: nn.Module,
    tok,
    gating: GatingNetwork,
    reducer,
    max_len: int = 512,
) -> tuple[dict, list[float]]:
    """Single student forward. Returns (new_umap_point, gate_weights).

    Args:
        text: Input text to probe; tokenised as a batch of one and truncated
            to ``max_len`` tokens.
        student: Model whose last hidden state is mean-pooled over the
            attention mask.
        tok: Tokenizer matching ``student``.
        gating: Gating network applied to the pooled vector.
        reducer: Fitted object exposing ``transform(array)``; its output is
            indexed as ``[0, 0..2]``, so it must yield at least three columns.
        max_len: Maximum number of tokens kept from ``text``.

    Returns:
        new_umap_point : {"x": float, "y": float, "z": float, "label": str}
        gate_weights : list of N_TEACHERS floats summing to 1.0
    """
    enc = tok([text], return_tensors="pt", truncation=True, max_length=max_len)
    student_device = next(student.parameters()).device
    enc = {k: v.to(student_device) for k, v in enc.items()}

    with torch.no_grad():
        out = student(**enc, output_hidden_states=True, use_cache=False)
        pooled = _masked_mean(out.hidden_states[-1], enc["attention_mask"]).float()

        # The gating network may live on a different device than the student
        # (e.g. gating on CPU, student on an accelerator), so move the pooled
        # vector to the gating parameters' device/dtype before the forward.
        gate_param = next(gating.parameters(), None)
        if gate_param is None:
            gate_input = pooled
        else:
            gate_input = pooled.to(device=gate_param.device, dtype=gate_param.dtype)
        gate_weights: list[float] = gating(gate_input).squeeze(0).tolist()

    coords3d = reducer.transform(pooled.cpu().numpy())  # (1, 3)
    return {
        "x": float(coords3d[0, 0]),
        "y": float(coords3d[0, 1]),
        "z": float(coords3d[0, 2]),
        "label": "probe",
    }, gate_weights