ilya-kolchinsky committed on
Commit
c3c903e
·
verified ·
1 Parent(s): 9856d09

Upload 3 files

Browse files
complexity_estimator/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
"""Public API of the ``complexity_estimator`` package.

Re-exports the config and model classes so callers can simply do
``from complexity_estimator import PromptComplexityModel``.
"""

from .configuration_prompt_complexity import PromptComplexityConfig
from .modeling_prompt_complexity import PromptComplexityModel

__all__ = ["PromptComplexityConfig", "PromptComplexityModel"]
complexity_estimator/configuration_prompt_complexity.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class PromptComplexityConfig(PretrainedConfig):
    """Configuration for the prompt-complexity regression model.

    Holds the encoder backbone name plus the pooling/head hyperparameters
    consumed by ``PromptComplexityModel``. All extra keyword arguments are
    forwarded to ``PretrainedConfig``.
    """

    model_type = "prompt-complexity"

    def __init__(
        self,
        base_model_name: str = "microsoft/deberta-v3-base",
        max_length: int = 512,
        dropout: float = 0.1,
        hidden: int | None = None,
        layernorm_after_pool: bool = True,
        use_projection: bool = False,
        proj_hidden_ratio: float = 1.0,
        output_sigmoid: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Checkpoint name loaded with AutoModel.from_pretrained by the model.
        self.base_model_name = base_model_name
        # Tokenizer truncation length used by PromptComplexityModel.predict.
        self.max_length = max_length
        self.dropout = dropout
        # Regression-head hidden width; None means "derive from encoder size".
        self.hidden = hidden
        self.layernorm_after_pool = layernorm_after_pool
        # Optional extra MLP applied to the pooled embedding before the head.
        self.use_projection = use_projection
        self.proj_hidden_ratio = proj_hidden_ratio
        # When True the head ends in a Sigmoid, constraining scores to [0, 1].
        self.output_sigmoid = output_sigmoid
complexity_estimator/modeling_prompt_complexity.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

# Relative import so the module resolves inside the uploaded
# `complexity_estimator` package (matching __init__.py). The previous
# absolute `hf.complexity_estimator...` path only existed in the author's
# training-repo layout and raised ImportError here.
from .configuration_prompt_complexity import PromptComplexityConfig


class PromptComplexityModel(PreTrainedModel):
    """Scores prompt complexity with a single scalar per input text.

    Pipeline: encoder -> masked mean pool -> optional LayerNorm ->
    optional projection MLP -> small MLP head producing one value
    (sigmoid-squashed into [0, 1] when ``config.output_sigmoid``).
    """

    config_class = PromptComplexityConfig

    def __init__(self, config: PromptComplexityConfig):
        super().__init__(config)

        # Backbone encoder (e.g. DeBERTa-v3); pooling is done manually in forward.
        self.encoder = AutoModel.from_pretrained(config.base_model_name)
        h = self.encoder.config.hidden_size

        self.post_ln = nn.LayerNorm(h) if config.layernorm_after_pool else nn.Identity()

        if config.use_projection:
            ph = int(h * config.proj_hidden_ratio)
            self.proj = nn.Sequential(
                nn.Dropout(config.dropout),
                nn.Linear(h, ph),
                nn.ReLU(),
                nn.Linear(ph, h),
                nn.ReLU(),
            )
        else:
            self.proj = nn.Identity()

        # Head width defaults to half the encoder size, floored at 128.
        hidden = config.hidden if config.hidden is not None else max(h // 2, 128)

        layers = [
            nn.Dropout(config.dropout),
            nn.Linear(h, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        if config.output_sigmoid:
            layers.append(nn.Sigmoid())
        self.head = nn.Sequential(*layers)

        self.post_init()

    def _mean_pool(self, last_hidden, attention_mask):
        """Mean over real (unmasked) tokens.

        last_hidden: [B, T, H]; attention_mask: [B, T] with 1 for real tokens.
        Returns a [B, H] tensor.
        """
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        summed = (last_hidden * mask).sum(dim=1)
        # clamp_min guards against division by zero for a fully-masked row.
        denom = mask.sum(dim=1).clamp_min(1e-6)
        return summed / denom

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Return SequenceClassifierOutput with logits of shape [B, 1].

        When ``labels`` (floats, same scale as the scores) are given, an MSE
        loss is computed against the squeezed scores.
        """
        # Previously _mean_pool crashed when attention_mask was omitted even
        # though the signature allows None; treat every position as real then.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        pooled = self._mean_pool(out.last_hidden_state, attention_mask)
        pooled = self.post_ln(pooled)
        pooled = self.proj(pooled)

        scores = self.head(pooled).squeeze(-1)  # [B], in [0, 1] when output_sigmoid

        loss = None
        if labels is not None:
            labels = labels.to(scores.dtype).view(-1)
            loss = torch.nn.functional.mse_loss(scores, labels)

        # Scores are stored inside logits for HF compatibility (shape [B, 1]).
        return SequenceClassifierOutput(loss=loss, logits=scores.unsqueeze(-1))

    @torch.no_grad()
    def predict(self, texts, tokenizer, device=None):
        """Score raw text(s); returns a float for one input, else a list.

        NOTE: switches the model to eval mode (and moves it to ``device``
        when one is given) as a side effect.
        """
        if isinstance(texts, str):
            texts = [texts]
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.config.max_length,
        )
        if device is not None:
            self.to(device)
        else:
            # Previously inputs stayed on CPU even when the model lived on
            # another device, causing a device mismatch; follow the model.
            device = next(self.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        self.eval()
        scores = self(**inputs).logits.squeeze(-1)
        out = scores.detach().cpu().tolist()
        return out[0] if len(out) == 1 else out