# PromptComplexityEstimator / complexity_estimator / modeling_prompt_complexity.py
# Uploaded by ilya-kolchinsky ("Upload 3 files", commit c3c903e, verified).
import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from hf.complexity_estimator.configuration_prompt_complexity import PromptComplexityConfig
class PromptComplexityModel(PreTrainedModel):
    """Regression model that assigns a complexity score to a prompt.

    A pretrained encoder (``config.base_model_name``) produces token states.
    These are mean-pooled over non-padding positions, optionally
    layer-normalized, and optionally passed through a bottleneck projection.
    An MLP head then maps them to one scalar per input. When
    ``config.output_sigmoid`` is set, that scalar is squashed into [0, 1].
    """

    config_class = PromptComplexityConfig

    def __init__(self, config: PromptComplexityConfig):
        super().__init__(config)
        # NOTE(review): this fetches the base encoder's pretrained weights on
        # every construction, including when the whole model is restored via
        # ``from_pretrained``.
        self.encoder = AutoModel.from_pretrained(config.base_model_name)
        h = self.encoder.config.hidden_size

        self.post_ln = nn.LayerNorm(h) if config.layernorm_after_pool else nn.Identity()

        if config.use_projection:
            # Bottleneck MLP h -> ph -> h. The output width stays h, so the
            # head below is the same with or without the projection.
            ph = int(h * config.proj_hidden_ratio)
            self.proj = nn.Sequential(
                nn.Dropout(config.dropout),
                nn.Linear(h, ph),
                nn.ReLU(),
                nn.Linear(ph, h),
                nn.ReLU(),
            )
        else:
            self.proj = nn.Identity()

        # Head width defaults to half the encoder width, but never below 128.
        hidden = config.hidden if config.hidden is not None else max(h // 2, 128)
        layers = [
            nn.Dropout(config.dropout),
            nn.Linear(h, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        if config.output_sigmoid:
            layers.append(nn.Sigmoid())
        self.head = nn.Sequential(*layers)

        self.post_init()

    def _mean_pool(self, last_hidden, attention_mask=None):
        """Average token states over positions where ``attention_mask`` is non-zero.

        Args:
            last_hidden: encoder token states, presumably (batch, seq, hidden).
            attention_mask: (batch, seq) mask (non-zero = real token), or None
                to average over every position.

        Returns:
            The pooled (batch, hidden) representation.
        """
        if attention_mask is None:
            # No padding information, so every position counts
            # (same as an all-ones mask).
            return last_hidden.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        summed = (last_hidden * mask).sum(dim=1)
        # The clamp avoids division by zero for fully-masked rows.
        denom = mask.sum(dim=1).clamp_min(1e-6)
        return summed / denom

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Score a batch of tokenized inputs.

        Args:
            input_ids: token ids passed to the encoder.
            attention_mask: padding mask passed to the encoder and used for
                pooling. When None, every position is pooled.
            labels: optional target scores; any shape that flattens to (batch,).
            **kwargs: extra encoder arguments (e.g. ``token_type_ids``).

        Returns:
            A ``SequenceClassifierOutput``. Its ``logits`` hold the scores with
            shape (batch, 1). Its ``loss`` is the MSE against ``labels``, or
            None when no labels are given.
        """
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        pooled = self._mean_pool(out.last_hidden_state, attention_mask)
        pooled = self.post_ln(pooled)
        pooled = self.proj(pooled)
        # Shape [B]; values lie in [0, 1] only when config.output_sigmoid is set.
        scores = self.head(pooled).squeeze(-1)

        loss = None
        if labels is not None:
            labels = labels.to(scores.dtype).view(-1)
            loss = torch.nn.functional.mse_loss(scores, labels)

        # Scores are stored in ``logits`` (shape [B, 1]) for compatibility with
        # the sequence-classification output convention.
        return SequenceClassifierOutput(loss=loss, logits=scores.unsqueeze(-1))

    @torch.no_grad()
    def predict(self, texts, tokenizer, device=None):
        """Tokenize raw text(s) and return complexity score(s).

        Args:
            texts: a single string or a list of strings.
            tokenizer: a Hugging Face tokenizer matching the encoder.
            device: optional device. When given, the model is moved there
                first. Inputs are always placed on the model's device.

        Returns:
            A float when exactly one text is scored (including a one-element
            list). Otherwise a list of floats; [] for empty input.
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            return []

        inputs = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            # NOTE(review): relies on PromptComplexityConfig setting max_length;
            # older PretrainedConfig versions default it to 20. Confirm.
            max_length=self.config.max_length,
        )
        if device is not None:
            self.to(device)
        # Always move inputs to wherever the model lives, so calling without
        # ``device`` works for a model already placed on an accelerator.
        target = self.device
        inputs = {k: v.to(target) for k, v in inputs.items()}

        was_training = self.training
        self.eval()
        try:
            scores = self(**inputs).logits.squeeze(-1)
        finally:
            # Restore the caller's mode instead of leaving the model in eval().
            self.train(was_training)

        out = scores.detach().cpu().tolist()
        return out[0] if len(out) == 1 else out