File size: 2,954 Bytes
6f0ff99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Frozen sentence encoder + three task heads (capability, difficulty, length)."""

from __future__ import annotations

from dataclasses import dataclass

# Model identifier used as the default for both ModelSpec.encoder_name and
# Encoder(encoder_name=...); Encoder hands it to `from_pretrained`.
DEFAULT_ENCODER = "BAAI/bge-small-en-v1.5"


@dataclass
class ModelSpec:
    """Hyperparameters describing the encoder and the head stack.

    ``build_head`` reads ``embedding_dim``, ``hidden_dim``, ``dropout``,
    ``n_capabilities`` and ``n_length_buckets``. ``encoder_name`` and
    ``max_seq_len`` are not read by ``build_head``; they mirror the
    constructor arguments of ``Encoder`` (presumably the caller forwards
    them -- confirm against the call sites).
    """

    # Identifier of the pretrained sentence encoder.
    encoder_name: str = DEFAULT_ENCODER
    # Size of the last dimension of the embeddings fed to the heads (input
    # width of the first linear layer). Presumably has to match the encoder's
    # hidden size -- confirm when changing encoder_name.
    embedding_dim: int = 384
    # Width of both shared hidden layers and input width of every task head.
    hidden_dim: int = 256
    # Number of outputs of the capability head ("cap_logits").
    n_capabilities: int = 8
    # Number of outputs of the length head ("len_logits").
    n_length_buckets: int = 3
    # Dropout probability applied after each shared hidden layer.
    dropout: float = 0.1
    # Same default as Encoder's max_seq_len (tokenizer truncation length).
    max_seq_len: int = 256


def build_head(spec: ModelSpec):
    """Create the head stack described by *spec*.

    The returned ``nn.Module`` maps a batch of embeddings (last dimension
    ``spec.embedding_dim``) to a dict with three entries:

    * ``"cap_logits"`` -- ``spec.n_capabilities`` values per row,
    * ``"diff"`` -- one scalar per row (trailing size-1 dimension squeezed),
    * ``"len_logits"`` -- ``spec.n_length_buckets`` values per row.

    torch is imported inside the function rather than at module level, and
    the ``HeadStack`` class is defined afresh on every call.
    """
    from torch import nn

    class HeadStack(nn.Module):
        """Shared two-layer GELU/dropout trunk feeding three linear heads."""

        def __init__(self, cfg: ModelSpec):
            super().__init__()
            width = cfg.hidden_dim

            # The position of each layer inside the Sequential determines its
            # state_dict key ("shared.0.*", "shared.3.*"), so keep this order.
            trunk_layers = [
                nn.Linear(cfg.embedding_dim, width),
                nn.GELU(),
                nn.Dropout(cfg.dropout),
                nn.Linear(width, width),
                nn.GELU(),
                nn.Dropout(cfg.dropout),
            ]
            self.shared = nn.Sequential(*trunk_layers)

            # One linear projection per task, all reading the trunk output.
            self.cap_head = nn.Linear(width, cfg.n_capabilities)
            self.diff_head = nn.Linear(width, 1)
            self.len_head = nn.Linear(width, cfg.n_length_buckets)

        def forward(self, embeddings):
            trunk_out = self.shared(embeddings)

            capability = self.cap_head(trunk_out)
            # The difficulty head emits (..., 1); drop that last axis.
            difficulty = self.diff_head(trunk_out).squeeze(-1)
            length = self.len_head(trunk_out)

            return {
                "cap_logits": capability,
                "diff": difficulty,
                "len_logits": length,
            }

    head = HeadStack(spec)
    return head


class Encoder:
    """Frozen HuggingFace sentence encoder that is loaded on first use.

    An embedding is the attention-mask-weighted mean of the model's last
    hidden states, rescaled to unit L2 norm.
    """

    def __init__(self, encoder_name: str = DEFAULT_ENCODER, max_seq_len: int = 256):
        """Store the configuration only; no model or tokenizer is loaded here."""
        self.encoder_name, self.max_seq_len = encoder_name, max_seq_len
        # Filled in by _ensure_loaded() the first time they are needed.
        self._tokenizer = self._model = self._device = None

    def _ensure_loaded(self):
        """Load tokenizer and model once; later calls return immediately."""
        already_loaded = self._model is not None
        if already_loaded:
            return

        # Imported here rather than at module level.
        import torch
        from transformers import AutoModel, AutoTokenizer

        name = self.encoder_name
        self._tokenizer = AutoTokenizer.from_pretrained(name)
        self._model = AutoModel.from_pretrained(name)

        if torch.cuda.is_available():
            self._device = "cuda"
        else:
            self._device = "cpu"

        self._model.to(self._device)
        self._model.eval()
        # Freeze every encoder weight.
        self._model.requires_grad_(False)

    @property
    def device(self) -> str:
        """Device string ("cuda" or "cpu"); triggers loading if necessary."""
        self._ensure_loaded()
        return self._device

    def embed(self, texts: list[str]):
        """Return one unit-norm embedding row per entry of *texts*.

        Inputs are padded to the longest item in the batch and truncated to
        ``max_seq_len`` tokens. The forward pass runs under ``torch.no_grad()``
        and the result stays on ``self.device``.
        """
        import torch
        from torch.nn.functional import normalize

        self._ensure_loaded()

        batch = self._tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt",
        )
        batch = batch.to(self._device)

        with torch.no_grad():
            hidden = self._model(**batch).last_hidden_state

        # NOTE(review): the model card of the default encoder appears to use
        # CLS-token pooling; mean pooling here looks deliberate -- confirm.
        weights = batch["attention_mask"].unsqueeze(-1).float()
        summed = (hidden * weights).sum(dim=1)
        # clamp keeps the divisor at >= 1 so an all-zero mask row cannot
        # divide by zero.
        token_counts = weights.sum(dim=1).clamp(min=1)
        return normalize(summed / token_counts, dim=-1)