# LEG-1.0-aegis2.0-base — modeling_leg.py
# Source: asiful109's Hugging Face upload "Upload LEG model export"
# (revision e18d152, verified).
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Iterable
import torch
import torch.nn as nn
from transformers import AutoTokenizer, PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Model
try:
from .configuration_leg import LEGConfig
except ImportError: # pragma: no cover
from configuration_leg import LEGConfig
# Pattern for characters stripped from prompts before whitespace-splitting:
# anything that is not a word character, whitespace, or hyphen, plus quotes
# not attached to a word on the matching side.
# NOTE(review): the quote alternatives appear redundant — quotes already
# match the first branch `[^\w\s-]` — confirm before simplifying.
_PROMPT_CLEAN_PATTERN = r"[^\w\s-]|(?<!\w)['\"]|['\"](?!\w)"
@dataclass
class LEGModelOutput(ModelOutput):
    """Output container for :class:`LEGForSafetyExplanation.forward`.

    Fields default to ``None`` until populated by the forward pass.
    """

    # (batch, 2) logits for whole-prompt safe/unsafe classification.
    prompt_logits: torch.FloatTensor | None = None
    # (batch, seq_len, 2) logits for per-token safe/unsafe classification.
    token_logits: torch.FloatTensor | None = None
class AttentionPooling(nn.Module):
def __init__(self, hidden_size: int):
super().__init__()
self.attn = nn.Linear(hidden_size, 1)
def forward(
self,
embeddings: torch.Tensor,
attention_mask: torch.Tensor | None,
) -> torch.Tensor:
attn_scores = self.attn(embeddings).squeeze(-1)
if attention_mask is not None:
neg_inf = torch.finfo(attn_scores.dtype).min
attn_scores = attn_scores.masked_fill(attention_mask == 0, neg_inf)
attn_weights = torch.softmax(attn_scores, dim=-1)
return torch.sum(embeddings * attn_weights.unsqueeze(-1), dim=1)
class LEGForSafetyExplanation(PreTrainedModel):
    """LEG safety model: prompt-level classification plus word-level explanation.

    Wraps a DeBERTa-v2 encoder with two 2-class (safe=0 / unsafe=1) linear
    heads:

    * ``prompt_classifier`` — logits over an attention-pooled sequence
      embedding (whole-prompt verdict).
    * ``token_classifier`` — logits per token (explanation).

    Besides the standard tensor-based ``forward``, the model offers a
    string-in/dict-out convenience path (:meth:`predict_safety`) that handles
    tokenization, thresholding, and mapping token predictions back to words.
    """

    config_class = LEGConfig
    # The source checkpoint stores encoder weights under the "bert" prefix
    # even though the encoder is DeBERTa-v2 — keep the name so weights load.
    base_model_prefix = "bert"

    def __init__(self, config: LEGConfig):
        super().__init__(config)
        self.bert = DebertaV2Model(config)
        self.attention_pooling = AttentionPooling(config.hidden_size)
        # Two-class heads: index 0 = safe, index 1 = unsafe.
        self.prompt_classifier = nn.Linear(config.hidden_size, 2)
        self.token_classifier = nn.Linear(config.hidden_size, 2)
        # Kept only because these parameters exist in the source checkpoint.
        self.log_sigma_prompt = nn.Parameter(torch.zeros(()))
        self.log_sigma_token = nn.Parameter(torch.zeros(()))
        # Lazily resolved tokenizer state used by predict_safety().
        self._cached_tokenizer = None
        self._inference_tokenizer_source = None
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model and remember the checkpoint path.

        The remembered path lets :meth:`_get_tokenizer` later load the
        matching tokenizer without the caller passing one explicitly.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        model._inference_tokenizer_source = str(pretrained_model_name_or_path)
        return model

    def _get_tokenizer(self, tokenizer=None):
        """Resolve a tokenizer: explicit argument > cache > checkpoint path.

        The fallback source is, in order: the path recorded by
        :meth:`from_pretrained`, ``self.name_or_path``, then
        ``config._name_or_path``. The resolved tokenizer is cached.

        Raises:
            ValueError: if no source can be resolved automatically.
        """
        if tokenizer is not None:
            return tokenizer
        if self._cached_tokenizer is not None:
            return self._cached_tokenizer
        source = (
            self._inference_tokenizer_source
            or getattr(self, "name_or_path", None)
            or getattr(self.config, "_name_or_path", None)
        )
        if not source:
            raise ValueError(
                "Tokenizer source could not be resolved automatically. "
                "Pass tokenizer=AutoTokenizer.from_pretrained(...) explicitly."
            )
        self._cached_tokenizer = AutoTokenizer.from_pretrained(source, use_fast=True)
        return self._cached_tokenizer

    @staticmethod
    def _clean_and_split_prompt(prompt: str) -> list[str]:
        """Strip characters matching _PROMPT_CLEAN_PATTERN, then whitespace-split."""
        cleaned_text = re.sub(_PROMPT_CLEAN_PATTERN, "", prompt)
        return cleaned_text.split()

    @staticmethod
    def _normalize_prompts(prompts: str | Iterable[str]) -> tuple[list[str], bool]:
        """Return ``(prompt_list, was_single_string)`` for uniform batch handling."""
        if isinstance(prompts, str):
            return [prompts], True
        prompt_list = list(prompts)
        return prompt_list, False

    def _predict_from_tokenized(
        self,
        encodings,
        words_batch: list[list[str]],
        prompt_threshold: float,
        word_threshold: float,
    ) -> list[dict]:
        """Run inference on tokenized input and format per-word results.

        Args:
            encodings: fast-tokenizer batch output (must support ``word_ids``).
            words_batch: the cleaned, whitespace-split words of each prompt.
            prompt_threshold: minimum unsafe probability for prompt label 1.
            word_threshold: minimum unsafe probability for a word label 1.

        Returns:
            One dict per prompt of the form
            ``{"safety_label": 0|1, "explanation": [(word, 0|1), ...]}``.
        """
        device = next(self.parameters()).device
        model_inputs = {
            "input_ids": encodings["input_ids"].to(device),
            "attention_mask": encodings["attention_mask"].to(device),
        }
        with torch.inference_mode():
            outputs = self.forward(**model_inputs)
        prompt_probs = torch.softmax(outputs.prompt_logits, dim=1).cpu()
        token_probs = torch.softmax(outputs.token_logits, dim=2).cpu()
        formatted_outputs = []
        for batch_index, words in enumerate(words_batch):
            prompt_safe = prompt_probs[batch_index, 0].item()
            prompt_unsafe = prompt_probs[batch_index, 1].item()
            # Unsafe only when it both beats the safe class and clears the
            # configured threshold.
            safety_label = int(
                prompt_unsafe > prompt_safe and prompt_unsafe > prompt_threshold
            )
            token_safe = token_probs[batch_index, :, 0].tolist()
            token_unsafe = token_probs[batch_index, :, 1].tolist()
            word_ids = encodings.word_ids(batch_index=batch_index)
            word_id_to_label = {}
            for token_index, word_id in enumerate(word_ids):
                # Skip special/padding tokens (word_id is None) and any token
                # index beyond the scored sequence length.
                if word_id is None or token_index >= len(token_unsafe):
                    continue
                predicted_label = int(
                    token_unsafe[token_index] > token_safe[token_index]
                    and token_unsafe[token_index] > word_threshold
                )
                # Only the first sub-token of each word decides its label.
                if word_id not in word_id_to_label:
                    word_id_to_label[word_id] = predicted_label
            # Words truncated away by max_length default to label 0.
            explanation = [
                (word, word_id_to_label.get(word_index, 0))
                for word_index, word in enumerate(words)
            ]
            formatted_outputs.append(
                {
                    "safety_label": safety_label,
                    "explanation": explanation,
                }
            )
        return formatted_outputs

    def predict_safety(
        self,
        prompts: str | Iterable[str],
        tokenizer=None,
        prompt_threshold: float | None = None,
        word_threshold: float | None = None,
        max_length: int | None = None,
        batch_size: int | None = None,
    ):
        """Classify raw prompt strings as safe/unsafe with word explanations.

        Args:
            prompts: a single prompt string or an iterable of prompts.
            tokenizer: optional explicit tokenizer; otherwise resolved via
                :meth:`_get_tokenizer`.
            prompt_threshold: overrides ``config.prompt_threshold`` when given.
            word_threshold: overrides ``config.word_threshold`` when given.
            max_length: overrides ``config.inference_max_length`` when given.
            batch_size: chunk size for inference; defaults to all prompts
                in one batch.

        Returns:
            A single result dict when ``prompts`` was a string, otherwise a
            list of result dicts (see :meth:`_predict_from_tokenized`).

        Raises:
            ValueError: if ``batch_size`` is provided but not positive.
        """
        prompt_list, single_input = self._normalize_prompts(prompts)
        tokenizer = self._get_tokenizer(tokenizer=tokenizer)
        if not prompt_list:
            # NOTE(review): a string input always yields a one-element list,
            # so the dict arm below is unreachable in practice; this branch
            # effectively guards empty iterables.
            return [] if not single_input else {
                "safety_label": 0,
                "explanation": [],
            }
        prompt_threshold = (
            self.config.prompt_threshold
            if prompt_threshold is None
            else prompt_threshold
        )
        word_threshold = (
            self.config.word_threshold if word_threshold is None else word_threshold
        )
        max_length = (
            self.config.inference_max_length
            if max_length is None
            else max_length
        )
        effective_batch_size = len(prompt_list)
        if batch_size is not None:
            if batch_size <= 0:
                raise ValueError("batch_size must be a positive integer when provided.")
            effective_batch_size = batch_size
        formatted_outputs = []
        for start_idx in range(0, len(prompt_list), effective_batch_size):
            prompt_chunk = prompt_list[start_idx : start_idx + effective_batch_size]
            # "or ''" guards against None entries in the iterable.
            words_batch = [
                self._clean_and_split_prompt(prompt_text or "")
                for prompt_text in prompt_chunk
            ]
            encodings = tokenizer(
                words_batch,
                is_split_into_words=True,
                max_length=max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            )
            formatted_outputs.extend(
                self._predict_from_tokenized(
                    encodings=encodings,
                    words_batch=words_batch,
                    prompt_threshold=prompt_threshold,
                    word_threshold=word_threshold,
                )
            )
        return formatted_outputs[0] if single_input else formatted_outputs

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        prompts: str | Iterable[str] | None = None,
        prompt: str | None = None,
        tokenizer=None,
        prompt_threshold: float | None = None,
        word_threshold: float | None = None,
        max_length: int | None = None,
        batch_size: int | None = None,
        **kwargs,
    ):
        """Dual-mode forward pass.

        When raw ``prompts``/``prompt`` strings are given (and no
        ``input_ids``), delegates to :meth:`predict_safety` and returns its
        formatted output. When tensor inputs are given, runs the encoder and
        both heads and returns a :class:`LEGModelOutput`.

        Raises:
            ValueError: if neither tensor inputs nor prompt strings are given.
        """
        if prompts is None and prompt is not None:
            prompts = prompt
        if prompts is not None and input_ids is None:
            return self.predict_safety(
                prompts=prompts,
                tokenizer=tokenizer,
                prompt_threshold=prompt_threshold,
                word_threshold=word_threshold,
                max_length=max_length,
                batch_size=batch_size,
            )
        if input_ids is None:
            raise ValueError(
                "Provide either tokenized inputs (`input_ids`, `attention_mask`) or "
                "raw `prompts`/`prompt` strings."
            )
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "return_dict": True,
        }
        # Only forward token_type_ids when supplied, keeping the encoder's
        # own default otherwise.
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        encoder_outputs = self.bert(**encoder_kwargs)
        hidden_states = encoder_outputs.last_hidden_state
        pooled_output = self.attention_pooling(hidden_states, attention_mask)
        return LEGModelOutput(
            prompt_logits=self.prompt_classifier(pooled_output),
            token_logits=self.token_classifier(hidden_states),
        )