"""Frozen sentence encoder + three task heads (capability, difficulty, length).""" from __future__ import annotations from dataclasses import dataclass DEFAULT_ENCODER = "BAAI/bge-small-en-v1.5" @dataclass class ModelSpec: encoder_name: str = DEFAULT_ENCODER embedding_dim: int = 384 hidden_dim: int = 256 n_capabilities: int = 8 n_length_buckets: int = 3 dropout: float = 0.1 max_seq_len: int = 256 def build_head(spec: ModelSpec): import torch.nn as nn class HeadStack(nn.Module): def __init__(self, s: ModelSpec): super().__init__() self.shared = nn.Sequential( nn.Linear(s.embedding_dim, s.hidden_dim), nn.GELU(), nn.Dropout(s.dropout), nn.Linear(s.hidden_dim, s.hidden_dim), nn.GELU(), nn.Dropout(s.dropout), ) self.cap_head = nn.Linear(s.hidden_dim, s.n_capabilities) self.diff_head = nn.Linear(s.hidden_dim, 1) self.len_head = nn.Linear(s.hidden_dim, s.n_length_buckets) def forward(self, embeddings): h = self.shared(embeddings) return { "cap_logits": self.cap_head(h), "diff": self.diff_head(h).squeeze(-1), "len_logits": self.len_head(h), } return HeadStack(spec) class Encoder: """Lazy wrapper around a HuggingFace sentence encoder, mean-pooled and L2-normalized.""" def __init__(self, encoder_name: str = DEFAULT_ENCODER, max_seq_len: int = 256): self.encoder_name = encoder_name self.max_seq_len = max_seq_len self._tokenizer = None self._model = None self._device = None def _ensure_loaded(self): if self._model is not None: return import torch from transformers import AutoModel, AutoTokenizer self._tokenizer = AutoTokenizer.from_pretrained(self.encoder_name) self._model = AutoModel.from_pretrained(self.encoder_name) self._device = "cuda" if torch.cuda.is_available() else "cpu" self._model.to(self._device).eval() for p in self._model.parameters(): p.requires_grad = False @property def device(self) -> str: self._ensure_loaded() return self._device def embed(self, texts: list[str]): import torch import torch.nn.functional as F self._ensure_loaded() enc = self._tokenizer( texts, padding=True, truncation=True, max_length=self.max_seq_len, return_tensors="pt", ).to(self._device) with torch.no_grad(): out = self._model(**enc) mask = enc["attention_mask"].unsqueeze(-1).float() pooled = (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) return F.normalize(pooled, dim=-1)