# spectralman's picture
# Initial deploy: classifier + FastAPI router
# 6f0ff99 verified
"""Frozen sentence encoder + three task heads (capability, difficulty, length)."""
from __future__ import annotations
from dataclasses import dataclass
# Model id handed to the HuggingFace ``from_pretrained`` loaders when no
# encoder name is supplied.
DEFAULT_ENCODER: str = "BAAI/bge-small-en-v1.5"
@dataclass()
class ModelSpec:
    """Hyper-parameters for the frozen encoder and the trainable head stack.

    Field order defines the positional constructor signature, so keep it stable.
    """

    # HuggingFace model id of the frozen encoder (presumably forwarded to
    # ``Encoder`` by the caller -- not visible in this file).
    encoder_name: str = DEFAULT_ENCODER
    # Input width of the first shared Linear layer; must equal the encoder's
    # output width (384 presumably matches the default encoder -- verify when
    # swapping encoders).
    embedding_dim: int = 384
    # Width of both shared layers and the input width of every task head.
    hidden_dim: int = 256
    # Output width of the capability head ("cap_logits").
    n_capabilities: int = 8
    # Output width of the length head ("len_logits").
    n_length_buckets: int = 3
    # Dropout probability applied after each shared GELU.
    dropout: float = 0.1
    # Token truncation length; not read by the heads themselves (presumably
    # passed to ``Encoder(max_seq_len=...)`` -- confirm at the call site).
    max_seq_len: int = 256
def build_head(spec: ModelSpec):
    """Build the trainable head stack that sits on top of sentence embeddings.

    The stack is a shared two-layer MLP (Linear -> GELU -> Dropout, twice)
    feeding three independent linear heads.

    Args:
        spec: Supplies ``embedding_dim``, ``hidden_dim``, ``dropout``,
            ``n_capabilities`` and ``n_length_buckets``.

    Returns:
        An ``nn.Module`` whose ``forward(embeddings)`` returns a dict with
        ``"cap_logits"`` (last dim ``n_capabilities``), ``"diff"`` (the
        single-output head with its trailing size-1 axis squeezed away) and
        ``"len_logits"`` (last dim ``n_length_buckets``).
    """
    # torch is imported locally so importing this module does not require it.
    import torch.nn as nn

    class HeadStack(nn.Module):
        def __init__(self, s: ModelSpec):
            super().__init__()
            # Two identical Linear/GELU/Dropout stages; the first consumes the
            # embedding width, the second the hidden width. Creation order (and
            # therefore Sequential indices and seeded init) matches a literal
            # six-element Sequential.
            layers = []
            width_in = s.embedding_dim
            for _ in range(2):
                layers.extend(
                    (
                        nn.Linear(width_in, s.hidden_dim),
                        nn.GELU(),
                        nn.Dropout(s.dropout),
                    )
                )
                width_in = s.hidden_dim
            self.shared = nn.Sequential(*layers)
            self.cap_head = nn.Linear(s.hidden_dim, s.n_capabilities)
            self.diff_head = nn.Linear(s.hidden_dim, 1)
            self.len_head = nn.Linear(s.hidden_dim, s.n_length_buckets)

        def forward(self, embeddings):
            features = self.shared(embeddings)
            cap_logits = self.cap_head(features)
            # diff_head has a single output unit; drop that axis.
            difficulty = self.diff_head(features).squeeze(-1)
            len_logits = self.len_head(features)
            return {
                "cap_logits": cap_logits,
                "diff": difficulty,
                "len_logits": len_logits,
            }

    return HeadStack(spec)
class Encoder:
    """Lazy wrapper around a HuggingFace sentence encoder, mean-pooled and L2-normalized.

    The tokenizer and model are loaded on first use (``embed`` or ``device``),
    placed on CUDA when available (CPU otherwise), switched to eval mode and
    frozen (``requires_grad = False`` on every parameter).

    NOTE(review): the BGE model card recommends CLS-token pooling, whereas this
    class mean-pools over non-padding tokens. Heads trained on these embeddings
    depend on the pooling choice, so do not change it without retraining.
    """

    def __init__(self, encoder_name: str = DEFAULT_ENCODER, max_seq_len: int = 256):
        self.encoder_name = encoder_name
        # Tokenizer truncation length (in tokens) used by ``embed``.
        self.max_seq_len = max_seq_len
        self._tokenizer = None
        # ``_model`` doubles as the "already loaded" sentinel.
        self._model = None
        self._device = None

    def _ensure_loaded(self):
        """Load tokenizer and model on first call; no-op once loaded.

        Everything is prepared in locals and published to ``self`` only after
        setup has fully succeeded. If any step raises (download failure, an
        out-of-memory error while moving to the device, ...), the encoder stays
        unloaded and the next call retries. Otherwise a half-initialised model
        would be left behind: on the wrong device, in train mode (dropout
        active) and not frozen.
        """
        if self._model is not None:
            return
        import torch
        from transformers import AutoModel, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(self.encoder_name)
        model = AutoModel.from_pretrained(self.encoder_name)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device).eval()
        for p in model.parameters():
            p.requires_grad = False
        # Publish only now; assign the sentinel (``_model``) last.
        self._tokenizer = tokenizer
        self._device = device
        self._model = model

    @property
    def device(self) -> str:
        """Device string ("cuda" or "cpu") the model runs on; triggers loading."""
        self._ensure_loaded()
        return self._device

    def embed(self, texts: list[str]):
        """Embed a batch of texts.

        Args:
            texts: Strings to embed; each is truncated to ``max_seq_len`` tokens.

        Returns:
            A tensor on ``self.device`` with one row per text. Each row is the
            attention-mask-weighted mean of the model's last hidden state,
            L2-normalised. An empty ``texts`` sequence yields a tensor with zero
            rows and ``config.hidden_size`` columns.
        """
        import torch
        import torch.nn.functional as F

        self._ensure_loaded()
        if not isinstance(texts, str) and len(texts) == 0:
            # Short-circuit empty batches instead of handing them to the
            # tokenizer/model, which are not guaranteed to accept them.
            return torch.empty((0, self._model.config.hidden_size), device=self._device)
        enc = self._tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt",
        ).to(self._device)
        with torch.no_grad():
            out = self._model(**enc)
        # Mean-pool over real tokens only: padding positions are zeroed by the
        # mask and excluded from the count (clamped to avoid division by zero).
        mask = enc["attention_mask"].unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return F.normalize(pooled, dim=-1)