Spaces:
Sleeping
Sleeping
| """Frozen sentence encoder + three task heads (capability, difficulty, length).""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
# HuggingFace model id used when no encoder name is given (loaded via AutoModel/AutoTokenizer).
DEFAULT_ENCODER = "BAAI/bge-small-en-v1.5"


# Without @dataclass the annotated defaults below are plain class attributes:
# ``ModelSpec(hidden_dim=128)`` would raise TypeError and there would be no
# __init__/__eq__/__repr__. The module already imports ``dataclass`` for this.
@dataclass
class ModelSpec:
    """Hyperparameters for the frozen encoder and the trainable head stack.

    Attributes:
        encoder_name: HuggingFace model id or path of the sentence encoder.
        embedding_dim: Width of the pooled sentence embedding fed to the heads.
            Must equal the encoder's hidden size (384 is presumably the size
            for the default encoder; re-check when changing ``encoder_name``).
        hidden_dim: Width of the shared two-layer MLP trunk.
        n_capabilities: Number of output logits of the capability head.
        n_length_buckets: Number of output logits of the length-bucket head.
        dropout: Dropout probability applied after each trunk layer.
        max_seq_len: Token truncation length for the encoder. NOTE(review): not
            wired to ``Encoder`` here; presumably the caller passes it through.
    """

    encoder_name: str = DEFAULT_ENCODER
    embedding_dim: int = 384
    hidden_dim: int = 256
    n_capabilities: int = 8
    n_length_buckets: int = 3
    dropout: float = 0.1
    max_seq_len: int = 256
def build_head(spec: ModelSpec):
    """Construct the trainable head stack that sits on top of sentence embeddings.

    Architecture: a shared trunk of two ``Linear -> GELU -> Dropout`` stages
    (``embedding_dim -> hidden_dim -> hidden_dim``), followed by three
    independent linear heads.

    Args:
        spec: Hyperparameters; reads ``embedding_dim``, ``hidden_dim``,
            ``dropout``, ``n_capabilities`` and ``n_length_buckets``.

    Returns:
        An ``nn.Module`` whose ``forward(embeddings)`` returns a dict with
        ``"cap_logits"`` (last dim ``n_capabilities``), ``"diff"`` (the single
        difficulty output with its trailing size-1 dim squeezed away) and
        ``"len_logits"`` (last dim ``n_length_buckets``).
    """
    # Imported inside the function, presumably so the module can be imported
    # without torch being installed.
    import torch.nn as nn

    class HeadStack(nn.Module):
        def __init__(self, s: ModelSpec):
            super().__init__()
            # Two identical stages built in order; the Linears land at indices
            # 0 and 3 of the Sequential (state-dict keys ``shared.0`` / ``shared.3``).
            stages = []
            for fan_in in (s.embedding_dim, s.hidden_dim):
                stages.extend(
                    (nn.Linear(fan_in, s.hidden_dim), nn.GELU(), nn.Dropout(s.dropout))
                )
            self.shared = nn.Sequential(*stages)
            self.cap_head = nn.Linear(s.hidden_dim, s.n_capabilities)
            self.diff_head = nn.Linear(s.hidden_dim, 1)
            self.len_head = nn.Linear(s.hidden_dim, s.n_length_buckets)

        def forward(self, embeddings):
            features = self.shared(embeddings)
            capability = self.cap_head(features)
            difficulty = self.diff_head(features).squeeze(-1)
            length = self.len_head(features)
            return {"cap_logits": capability, "diff": difficulty, "len_logits": length}

    return HeadStack(spec)
class Encoder:
    """Lazy wrapper around a HuggingFace sentence encoder, mean-pooled and L2-normalized.

    Nothing is loaded until the first call to :meth:`device` or :meth:`embed`.
    On load the model is moved to CUDA when available (else CPU), put in eval
    mode, and all of its parameters are frozen.

    NOTE(review): pooling is attention-masked mean pooling. Some checkpoints
    (BGE included) are documented upstream with CLS pooling; anything trained
    on these embeddings depends on the current choice, so change it deliberately.
    """

    def __init__(self, encoder_name: str = DEFAULT_ENCODER, max_seq_len: int = 256):
        """Record configuration only; no download or device work happens here.

        Args:
            encoder_name: HuggingFace model id or local path passed to
                ``AutoTokenizer.from_pretrained`` / ``AutoModel.from_pretrained``.
            max_seq_len: Maximum number of tokens per text; longer inputs are truncated.
        """
        self.encoder_name = encoder_name
        self.max_seq_len = max_seq_len
        self._tokenizer = None
        self._model = None
        self._device = None

    def _ensure_loaded(self):
        """Load tokenizer and model on first use; later calls are no-ops.

        Everything is built in locals and ``self._model`` is assigned last,
        because ``self._model is not None`` is the "loaded" flag. A failure
        part-way through (download error, device OOM) therefore leaves the
        wrapper unloaded so the next call retries. It does not leave a
        half-initialised model that is treated as ready.
        """
        if self._model is not None:
            return
        import torch
        from transformers import AutoModel, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(self.encoder_name)
        model = AutoModel.from_pretrained(self.encoder_name)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device).eval()
        model.requires_grad_(False)  # encoder stays frozen; only heads train

        self._tokenizer = tokenizer
        self._device = device
        self._model = model  # assigned last: marks loading as complete

    def device(self) -> str:
        """Return the device string (``"cuda"`` or ``"cpu"``), loading the model if needed."""
        self._ensure_loaded()
        return self._device

    def embed(self, texts: list[str]):
        """Encode a batch of texts into L2-normalized mean-pooled embeddings.

        Args:
            texts: Texts to encode; each is truncated to ``max_seq_len`` tokens.

        Returns:
            A float tensor on :meth:`device` with one row per input text and
            ``hidden_size`` columns; each non-zero row has unit L2 norm. An
            empty input yields a ``(0, hidden_size)`` tensor instead of a
            tokenizer error.
        """
        import torch
        import torch.nn.functional as F

        self._ensure_loaded()
        if not isinstance(texts, str) and len(texts) == 0:
            # Tokenizers fail on an empty batch; return a correctly shaped empty result.
            return torch.empty(0, self._model.config.hidden_size, device=self._device)

        enc = self._tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt",
        ).to(self._device)
        with torch.no_grad():
            out = self._model(**enc)
            # Zero out padding positions so they do not contribute to the mean.
            mask = enc["attention_mask"].unsqueeze(-1).float()
            summed = (out.last_hidden_state * mask).sum(dim=1)
            pooled = summed / mask.sum(dim=1).clamp(min=1)  # clamp guards all-padding rows
            return F.normalize(pooled, dim=-1)