import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from hf.complexity_estimator.configuration_prompt_complexity import PromptComplexityConfig


class PromptComplexityModel(PreTrainedModel):
    config_class = PromptComplexityConfig

    def __init__(self, config: PromptComplexityConfig):
        super().__init__(config)

        self.encoder = AutoModel.from_pretrained(config.base_model_name)
        h = self.encoder.config.hidden_size

        self.post_ln = nn.LayerNorm(h) if config.layernorm_after_pool else nn.Identity()

        if config.use_projection:
            ph = int(h * config.proj_hidden_ratio)
            self.proj = nn.Sequential(
                nn.Dropout(config.dropout),
                nn.Linear(h, ph),
                nn.ReLU(),
                nn.Linear(ph, h),
                nn.ReLU(),
            )
        else:
            self.proj = nn.Identity()

        hidden = config.hidden if config.hidden is not None else max(h // 2, 128)

        layers = [
            nn.Dropout(config.dropout),
            nn.Linear(h, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        if config.output_sigmoid:
            layers.append(nn.Sigmoid())
        self.head = nn.Sequential(*layers)

        self.post_init()

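    def _mean_pool(self, last_hidden, attention_mask):
        """Average token embeddings over non-padded positions: [B, T, H] -> [B, H]."""
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)  # [B, T, 1]
        summed = (last_hidden * mask).sum(dim=1)
        # Guard against division by zero for fully masked rows.
        denom = mask.sum(dim=1).clamp_min(1e-6)
        return summed / denom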

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
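        if attention_mask is None and input_ids is not None:
            # Without a mask, treat every token as real so mean pooling still works.
            attention_mask = torch.ones_like(input_ids)
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **kwargs)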
        pooled = self._mean_pool(out.last_hidden_state, attention_mask)
        pooled = self.post_ln(pooled)
        pooled = self.proj(pooled)

        # Shape [B]; values are bounded to [0, 1] only when config.output_sigmoid is set.
        scores = self.head(pooled).squeeze(-1)

        loss = None
        if labels is not None:
            labels = labels.to(scores.dtype).view(-1)
            loss = torch.nn.functional.mse_loss(scores, labels)

        # Scores are returned in the `logits` field (shape [B, 1]) for compatibility
        # with SequenceClassifierOutput; they are regression outputs, not class logits.
        return SequenceClassifierOutput(loss=loss, logits=scores.unsqueeze(-1))

    @torch.no_grad()
    def predict(self, texts, tokenizer, device=None):
        if isinstance(texts, str):
            texts = [texts]
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.config.max_length,
        )
        if device is not None:
            self.to(device)
        # Keep inputs on the model's device, even when `device` is omitted.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        self.eval()
        scores = self(**inputs).logits.squeeze(-1)
        out = scores.detach().cpu().tolist()
        return out[0] if len(out) == 1 else out
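

# Illustrative sketch, not part of the original module: a standalone check of the
# masked mean pooling performed by `PromptComplexityModel._mean_pool`. The helper
# name `masked_mean_pool_demo` and the toy tensors are assumptions made purely for
# this example; it shows that padded positions do not affect the pooled vector.
def masked_mean_pool_demo():
    import torch  # local import keeps the sketch self-contained

    # Batch of 2 sequences, 3 tokens each, hidden size 2.
    last_hidden = torch.tensor(
        [
            [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # third token is padding
            [[2.0, 2.0], [4.0, 4.0], [6.0, 6.0]],
        ]
    )
    attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])

    mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)  # [B, T, 1]
    summed = (last_hidden * mask).sum(dim=1)  # padded positions contribute zero
    denom = mask.sum(dim=1).clamp_min(1e-6)  # count of real tokens per sequence
    return summed / denom  # [B, H]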