File size: 3,797 Bytes
275cef9 cc58dce 275cef9 40b1944 275cef9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | from __future__ import annotations
import os
import torch
import torch.nn as nn
import numpy as np
# Inlined so Space has no dependency on the ofa package.
# Hub id of the base causal LM; load_student() loads both the tokenizer and
# the model weights from it.
STUDENT_BASE = "Qwen/Qwen2.5-0.5B-Instruct"
# Hub repo that load_student() reads the LoRA adapter and "gating.pt" from.
ADAPTER_REPO = "build-small-hackathon/deku"
STUDENT_HIDDEN_DIM = 896 # Qwen2.5-0.5B hidden size; used as the gating network's input width
# Number of gating outputs, i.e. one gate weight per teacher.
N_TEACHERS = 5
class GatingNetwork(nn.Module):
    """Single linear layer followed by a softmax over teachers.

    Maps a ``hidden_dim``-wide feature vector to ``n_teachers`` non-negative
    weights that sum to 1 along the last dimension.
    """

    def __init__(self, hidden_dim: int, n_teachers: int):
        super().__init__()
        # The attribute name ``fc`` determines the state-dict keys
        # ("fc.weight", "fc.bias"), so it must not be renamed.
        self.fc = nn.Linear(in_features=hidden_dim, out_features=n_teachers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return softmax-normalised gate weights for ``x``."""
        logits = self.fc(x)
        return logits.softmax(dim=-1)
def _masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
m = mask.unsqueeze(-1).to(hidden.dtype)
return (hidden * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
def load_student(hf_token: str | None = None):
    """Load student (base + LoRA) and gating network from HF Hub.

    Args:
        hf_token: Hub access token. When falsy, the ``HF_TOKEN`` environment
            variable is used instead (which may also be unset).

    Returns:
        ``(tokenizer, student_model, gating_network)``. The student is the
        base model wrapped with the adapter from ``ADAPTER_REPO`` and put in
        eval mode. The gating network is built on CPU and put in eval mode.
    """
    # Imported lazily so importing this module does not pull in the heavy
    # model libraries.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
    from huggingface_hub import hf_hub_download

    token = hf_token or os.environ.get("HF_TOKEN")

    tok = AutoTokenizer.from_pretrained(STUDENT_BASE)
    if tok.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tok.pad_token = tok.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        STUDENT_BASE,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        output_hidden_states=True,
    )
    student = PeftModel.from_pretrained(base, ADAPTER_REPO, token=token)
    student.eval()

    gating_path = hf_hub_download(
        repo_id=ADAPTER_REPO,
        filename="gating.pt",
        repo_type="model",
        token=token,
    )
    gating = GatingNetwork(STUDENT_HIDDEN_DIM, N_TEACHERS)
    # The checkpoint is a downloaded pickle. weights_only=True restricts
    # unpickling to tensors and plain containers, which is all a state dict
    # needs, so a tampered file cannot execute arbitrary code on load.
    state_dict = torch.load(gating_path, map_location="cpu", weights_only=True)
    gating.load_state_dict(state_dict)
    gating.eval()

    return tok, student, gating
def generate_response(
    text: str,
    student: nn.Module,
    tok,
    max_new_tokens: int = 200,
) -> str:
    """Greedy-decode the student's reply to ``text`` and return it as a string.

    ``text`` is wrapped as a single user turn via the tokenizer's chat
    template. If templating fails for any reason, the raw text is used as the
    prompt. Only the tokens generated after the prompt are decoded.
    """
    target_device = next(student.parameters()).device

    conversation = [{"role": "user", "content": text}]
    try:
        prompt = tok.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Best-effort fallback: feed the raw text if templating is unavailable.
        prompt = text

    inputs = tok(prompt, return_tensors="pt").to(target_device)
    prompt_ids = inputs["input_ids"]

    with torch.no_grad():
        generated = student.generate(
            input_ids=prompt_ids,
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tok.eos_token_id,
        )

    # The output sequence starts with the prompt; keep only what follows it.
    prompt_len = prompt_ids.shape[1]
    completion_ids = generated[0][prompt_len:]
    return tok.decode(completion_ids, skip_special_tokens=True)
def run_probe(
    text: str,
    student: nn.Module,
    tok,
    gating: GatingNetwork,
    reducer,
    max_len: int = 512,
) -> tuple[dict, list[float]]:
    """Single student forward. Returns (new_umap_point, gate_weights).

    The text is tokenized (truncated to ``max_len`` tokens) and run through
    the student once. The last hidden state is mean-pooled over the attended
    tokens, then fed to ``gating`` and to ``reducer.transform``.

    new_umap_point : {"x": float, "y": float, "z": float, "label": str}
    gate_weights   : list of N_TEACHERS floats summing to 1.0
    """
    enc = tok([text], return_tensors="pt", truncation=True, max_length=max_len)
    device = next(student.parameters()).device
    enc = {k: v.to(device) for k, v in enc.items()}

    with torch.no_grad():
        out = student(**enc, output_hidden_states=True, use_cache=False)
        last_hidden = out.hidden_states[-1]
        # Keep the mask on the same device as the hidden state it multiplies.
        mask = enc["attention_mask"].to(last_hidden.device)
        pooled = _masked_mean(last_hidden, mask).float()

        # The gating network may live on a different device than the student
        # (e.g. gating on CPU, student on GPU). Align the input with the
        # gating parameters so the linear layer does not raise a
        # device/dtype mismatch error.
        gate_param = next(gating.parameters(), None)
        gate_input = pooled
        if gate_param is not None:
            gate_input = pooled.to(device=gate_param.device, dtype=gate_param.dtype)
        gate_weights: list[float] = gating(gate_input).squeeze(0).tolist()

    coords3d = reducer.transform(pooled.cpu().numpy())  # (1, 3)
    return {
        "x": float(coords3d[0, 0]),
        "y": float(coords3d[0, 1]),
        "z": float(coords3d[0, 2]),
        "label": "probe",
    }, gate_weights
|