| from __future__ import annotations |
| import os |
| import torch |
| import torch.nn as nn |
| import numpy as np |
|
|
| |
# Hub id of the base checkpoint; load_student() uses it for both the tokenizer
# and the causal-LM weights.
STUDENT_BASE = "Qwen/Qwen2.5-0.5B-Instruct"

# Hub repo that load_student() reads the LoRA adapter and "gating.pt" from.
ADAPTER_REPO = "build-small-hackathon/deku"

# Input width of the gating network. It must equal the width of the pooled
# last hidden state that run_probe() feeds into it.
STUDENT_HIDDEN_DIM = 896

# Output width of the gating network: one softmax weight per teacher.
N_TEACHERS = 5
|
|
|
|
class GatingNetwork(nn.Module):
    """Single linear layer followed by a softmax over teachers.

    Maps a feature vector of width ``hidden_dim`` to ``n_teachers``
    non-negative weights that sum to 1 along the last dimension.
    """

    def __init__(self, hidden_dim: int, n_teachers: int):
        """Create the projection from ``hidden_dim`` features to ``n_teachers`` logits."""
        super().__init__()
        # The attribute name is part of the checkpoint format: saved weights
        # are keyed "fc.weight" / "fc.bias".
        self.fc = nn.Linear(hidden_dim, n_teachers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return softmax-normalised teacher weights for ``x``.

        Args:
            x: Tensor whose last dimension has size ``hidden_dim``.

        Returns:
            Tensor shaped like ``x`` except that the last dimension has
            size ``n_teachers``.
        """
        logits = self.fc(x)
        return logits.softmax(dim=-1)
|
|
|
|
def _masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average ``hidden`` over dim 1, counting only positions where ``mask`` is set.

    Args:
        hidden: Tensor with the positions to average along dim 1 and the
            features in a trailing dimension.
        mask: Tensor with one entry per position; it is cast to
            ``hidden.dtype`` and used as a per-position weight.

    Returns:
        The weighted sum over dim 1 divided by the mask total. Rows whose
        mask is all zero come back as zeros instead of NaN.
    """
    weights = mask.unsqueeze(-1).to(hidden.dtype)
    weighted_sum = (hidden * weights).sum(dim=1)
    # Clamp the divisor so a fully masked row divides by 1 rather than 0.
    denom = weights.sum(dim=1).clamp(min=1.0)
    return weighted_sum / denom
|
|
|
|
def load_student(hf_token: str | None = None):
    """Load student (base + LoRA) and gating network from HF Hub.

    Args:
        hf_token: Hugging Face access token. When falsy, the ``HF_TOKEN``
            environment variable is used instead (which may also be unset).

    Returns:
        ``(tokenizer, student_model, gating_network)``. The student is the
        base model wrapped with the LoRA adapter, in eval mode. The gating
        network is in eval mode with its weights loaded onto CPU; it is not
        moved to the student's device here.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
    from huggingface_hub import hf_hub_download

    token = hf_token or os.environ.get("HF_TOKEN")

    tok = AutoTokenizer.from_pretrained(STUDENT_BASE)
    if tok.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tok.pad_token = tok.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        STUDENT_BASE,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        output_hidden_states=True,
    )
    student = PeftModel.from_pretrained(base, ADAPTER_REPO, token=token)
    student.eval()

    gating_path = hf_hub_download(
        repo_id=ADAPTER_REPO,
        filename="gating.pt",
        repo_type="model",
        token=token,
    )
    gating = GatingNetwork(STUDENT_HIDDEN_DIM, N_TEACHERS)
    # SECURITY: the checkpoint comes from a remote repo. A plain torch.load()
    # runs the full pickle unpickler and can execute arbitrary code, so
    # restrict it to tensors and primitive containers. That is all a
    # state_dict needs.
    state = torch.load(gating_path, map_location="cpu", weights_only=True)
    gating.load_state_dict(state)
    gating.eval()

    return tok, student, gating
|
|
|
|
def generate_response(
    text: str,
    student: nn.Module,
    tok,
    max_new_tokens: int = 200,
) -> str:
    """Greedy-decode the student's reply to ``text`` and return it as a string.

    Args:
        text: User message to answer.
        student: Model exposing ``parameters()`` and ``generate(...)``.
        tok: Tokenizer exposing ``apply_chat_template``, ``__call__``,
            ``decode`` and ``eos_token_id``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation only (prompt tokens are stripped), with
        special tokens skipped.
    """
    target_device = next(student.parameters()).device

    messages = [{"role": "user", "content": text}]
    try:
        prompt = tok.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Best-effort: if the chat template cannot be applied for any reason,
        # fall back to feeding the raw text.
        prompt = text

    batch = tok(prompt, return_tensors="pt").to(target_device)
    prompt_ids = batch["input_ids"]

    generation_kwargs = {
        "input_ids": prompt_ids,
        "attention_mask": batch["attention_mask"],
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "pad_token_id": tok.eos_token_id,
    }
    with torch.no_grad():
        sequences = student.generate(**generation_kwargs)

    # generate() returns prompt + continuation; keep only what comes after
    # the prompt.
    prompt_len = prompt_ids.shape[1]
    continuation = sequences[0][prompt_len:]
    return tok.decode(continuation, skip_special_tokens=True)
|
|
|
|
def run_probe(
    text: str,
    student: nn.Module,
    tok,
    gating: GatingNetwork,
    reducer,
    max_len: int = 512,
) -> tuple[dict, list[float]]:
    """Single student forward. Returns (new_umap_point, gate_weights).

    The text is tokenised (truncated to ``max_len`` tokens), run through the
    student once, and the last hidden layer is mean-pooled over non-padding
    positions. The pooled vector is fed to the gating network and to
    ``reducer.transform``.

    Args:
        text: Raw input text. No chat template is applied here.
        student: Model whose output exposes ``hidden_states``.
        tok: Tokenizer returning at least ``attention_mask`` alongside the
            model inputs.
        gating: Gating network. It may live on a different device from the
            student (for example CPU while the student is on GPU).
        reducer: Object with ``transform(ndarray)`` returning an array with
            at least three columns. Presumably a fitted 3-component UMAP;
            confirm against the caller.
        max_len: Maximum number of tokens kept by the tokenizer.

    Returns:
        new_umap_point : {"x": float, "y": float, "z": float, "label": str}
        gate_weights   : list of softmax weights, one per teacher, summing to 1.0
    """
    device = next(student.parameters()).device
    enc = tok([text], return_tensors="pt", truncation=True, max_length=max_len)
    enc = {k: v.to(device) for k, v in enc.items()}

    # No gradients are needed for the student or for the gating head.
    with torch.no_grad():
        out = student(**enc, output_hidden_states=True, use_cache=False)
        last_hidden = out.hidden_states[-1]
        # Keep the mask on the same device as the hidden states it weights.
        mask = enc["attention_mask"].to(last_hidden.device)
        pooled = _masked_mean(last_hidden, mask).float()

        # The gating network is not guaranteed to share the student's device
        # or dtype; align the input with its parameters so the linear layer
        # does not raise a device-mismatch error.
        gate_param = next(gating.parameters())
        gate_input = pooled.to(device=gate_param.device, dtype=gate_param.dtype)
        gate_weights: list[float] = gating(gate_input).squeeze(0).tolist()

    coords3d = reducer.transform(pooled.cpu().numpy())
    point = {
        "x": float(coords3d[0, 0]),
        "y": float(coords3d[0, 1]),
        "z": float(coords3d[0, 2]),
        "label": "probe",
    }
    return point, gate_weights
|
|