| from __future__ import annotations |
|
|
| import re |
| from dataclasses import dataclass |
| from typing import Iterable |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import AutoTokenizer, PreTrainedModel |
| from transformers.modeling_outputs import ModelOutput |
| from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Model |
|
|
| try: |
| from .configuration_leg import LEGConfig |
| except ImportError: |
| from configuration_leg import LEGConfig |
|
|
|
|
# Pattern used by _clean_and_split_prompt: removes every character that is
# not a word character, whitespace, or hyphen, plus quote marks not flanked
# by word characters on both sides.
# NOTE(review): the two quote alternatives appear redundant — quotes are
# already matched by the first branch `[^\w\s-]` — but are kept as-is.
_PROMPT_CLEAN_PATTERN = r"[^\w\s-]|(?<!\w)['\"]|['\"](?!\w)"
|
|
|
|
@dataclass
class LEGModelOutput(ModelOutput):
    """Output of LEGForSafetyExplanation.forward when called with tokenized inputs."""

    # (batch, 2) logits over [safe, unsafe] for the whole prompt.
    prompt_logits: torch.FloatTensor | None = None
    # (batch, seq_len, 2) logits over [safe, unsafe] for each token.
    token_logits: torch.FloatTensor | None = None
|
|
|
|
class AttentionPooling(nn.Module):
    """Attention-weighted mean over the sequence dimension.

    A single linear layer scores each position; scores are softmax-normalized
    (with padded positions suppressed when a mask is given) and used as
    weights for a weighted sum of the input embeddings.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # One scalar relevance score per sequence position.
        # NB: keep the attribute named `attn` — it is a state_dict key.
        self.attn = nn.Linear(hidden_size, 1)

    def forward(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Pool (batch, seq, hidden) embeddings down to (batch, hidden)."""
        scores = self.attn(embeddings).squeeze(-1)

        if attention_mask is not None:
            # Drive masked positions to the dtype minimum so softmax gives
            # them effectively zero weight.
            padded = attention_mask == 0
            scores = scores.masked_fill(padded, torch.finfo(scores.dtype).min)

        weights = torch.softmax(scores, dim=-1).unsqueeze(-1)
        return (embeddings * weights).sum(dim=1)
|
|
|
|
class LEGForSafetyExplanation(PreTrainedModel):
    """DeBERTa-v2 encoder with two binary heads for prompt safety.

    One head classifies the whole prompt as safe/unsafe (via attention
    pooling over the token embeddings); the other classifies each token,
    which is aggregated to word level to produce an explanation.

    Two entry points:
      * ``forward(input_ids=..., attention_mask=...)`` — standard tensor
        path, returns :class:`LEGModelOutput` with raw logits.
      * ``predict_safety(prompts=...)`` / ``forward(prompts=...)`` — raw
        string path, handles tokenization and thresholding and returns
        formatted dicts.
    """

    config_class = LEGConfig
    # Encoder weights are stored under the "bert." prefix in checkpoints.
    base_model_prefix = "bert"

    def __init__(self, config: LEGConfig):
        super().__init__(config)
        self.bert = DebertaV2Model(config)
        self.attention_pooling = AttentionPooling(config.hidden_size)
        # Two-class (safe/unsafe) heads for the pooled prompt and per token.
        self.prompt_classifier = nn.Linear(config.hidden_size, 2)
        self.token_classifier = nn.Linear(config.hidden_size, 2)

        # Learned log-sigmas — presumably for uncertainty-based multi-task
        # loss weighting during training; the loss code is not in this file,
        # so confirm against the trainer before relying on that reading.
        self.log_sigma_prompt = nn.Parameter(torch.zeros(()))
        self.log_sigma_token = nn.Parameter(torch.zeros(()))

        # Tokenizer state for the string-based inference path; resolved
        # lazily by _get_tokenizer.
        self._cached_tokenizer = None
        self._inference_tokenizer_source = None
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load as usual, but remember the checkpoint path/name so the
        matching tokenizer can be auto-loaded later by ``_get_tokenizer``."""
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        model._inference_tokenizer_source = str(pretrained_model_name_or_path)
        return model

    def _get_tokenizer(self, tokenizer=None):
        """Resolve the tokenizer used for string-based inference.

        Preference order: explicit ``tokenizer`` argument, then the cached
        instance, then auto-loading from the remembered checkpoint source.

        Raises:
            ValueError: if no tokenizer source can be determined.
        """
        if tokenizer is not None:
            return tokenizer

        if self._cached_tokenizer is not None:
            return self._cached_tokenizer

        source = (
            self._inference_tokenizer_source
            or getattr(self, "name_or_path", None)
            or getattr(self.config, "_name_or_path", None)
        )
        if not source:
            raise ValueError(
                "Tokenizer source could not be resolved automatically. "
                "Pass tokenizer=AutoTokenizer.from_pretrained(...) explicitly."
            )

        # use_fast=True matters: _predict_from_tokenized relies on
        # encodings.word_ids(), which only fast tokenizers provide.
        self._cached_tokenizer = AutoTokenizer.from_pretrained(source, use_fast=True)
        return self._cached_tokenizer

    @staticmethod
    def _clean_and_split_prompt(prompt: str) -> list[str]:
        """Strip punctuation per _PROMPT_CLEAN_PATTERN and split on whitespace."""
        cleaned_text = re.sub(_PROMPT_CLEAN_PATTERN, "", prompt)
        return cleaned_text.split()

    @staticmethod
    def _normalize_prompts(prompts: str | Iterable[str]) -> tuple[list[str], bool]:
        """Return ``(prompt_list, was_single_string)``.

        A bare string becomes a one-element list with the flag set, so the
        caller can unwrap the result back to a single dict.
        """
        if isinstance(prompts, str):
            return [prompts], True
        prompt_list = list(prompts)
        return prompt_list, False

    def _predict_from_tokenized(
        self,
        encodings,
        words_batch: list[list[str]],
        prompt_threshold: float,
        word_threshold: float,
    ) -> list[dict]:
        """Run the model on a tokenized batch and format predictions.

        Args:
            encodings: fast-tokenizer ``BatchEncoding`` (must support
                ``word_ids``) built with ``is_split_into_words=True``.
            words_batch: the word lists that produced ``encodings``, one
                list per batch element.
            prompt_threshold: minimum unsafe probability for a prompt to be
                labeled unsafe.
            word_threshold: minimum unsafe probability for a word to be
                labeled unsafe.

        Returns:
            One dict per prompt with keys ``"safety_label"`` (0/1) and
            ``"explanation"`` (list of ``(word, 0/1)`` pairs).
        """
        device = next(self.parameters()).device
        model_inputs = {
            "input_ids": encodings["input_ids"].to(device),
            "attention_mask": encodings["attention_mask"].to(device),
        }

        with torch.inference_mode():
            outputs = self.forward(**model_inputs)

        prompt_probs = torch.softmax(outputs.prompt_logits, dim=1).cpu()
        token_probs = torch.softmax(outputs.token_logits, dim=2).cpu()

        formatted_outputs = []
        for batch_index, words in enumerate(words_batch):
            prompt_safe = prompt_probs[batch_index, 0].item()
            prompt_unsafe = prompt_probs[batch_index, 1].item()
            # Unsafe only if it both beats the safe probability AND clears
            # the absolute threshold.
            safety_label = int(
                prompt_unsafe > prompt_safe and prompt_unsafe > prompt_threshold
            )

            token_safe = token_probs[batch_index, :, 0].tolist()
            token_unsafe = token_probs[batch_index, :, 1].tolist()
            word_ids = encodings.word_ids(batch_index=batch_index)

            word_id_to_label = {}
            for token_index, word_id in enumerate(word_ids):
                # Skip special/padding tokens (word_id is None) and any
                # index beyond the probability rows.
                if word_id is None or token_index >= len(token_unsafe):
                    continue

                predicted_label = int(
                    token_unsafe[token_index] > token_safe[token_index]
                    and token_unsafe[token_index] > word_threshold
                )
                # First sub-token of each word decides that word's label.
                if word_id not in word_id_to_label:
                    word_id_to_label[word_id] = predicted_label

            # Words truncated away by max_length default to label 0.
            explanation = [
                (word, word_id_to_label.get(word_index, 0))
                for word_index, word in enumerate(words)
            ]
            formatted_outputs.append(
                {
                    "safety_label": safety_label,
                    "explanation": explanation,
                }
            )

        return formatted_outputs

    def predict_safety(
        self,
        prompts: str | Iterable[str],
        tokenizer=None,
        prompt_threshold: float | None = None,
        word_threshold: float | None = None,
        max_length: int | None = None,
        batch_size: int | None = None,
    ):
        """Classify raw prompt strings and explain at word level.

        Args:
            prompts: a single prompt string or an iterable of them.
            tokenizer: optional explicit tokenizer; otherwise resolved via
                ``_get_tokenizer``.
            prompt_threshold / word_threshold / max_length: fall back to the
                corresponding values on ``self.config`` when None.
            batch_size: optional chunk size; all prompts in one batch when
                None. Must be positive if given.

        Returns:
            A single result dict when ``prompts`` was a string, otherwise a
            list of dicts (see ``_predict_from_tokenized``).

        Raises:
            ValueError: if ``batch_size`` is given and not positive.
        """
        prompt_list, single_input = self._normalize_prompts(prompts)
        tokenizer = self._get_tokenizer(tokenizer=tokenizer)

        # NOTE(review): the single_input arm here looks unreachable — a bare
        # string always normalizes to a one-element list — but is harmless.
        if not prompt_list:
            return [] if not single_input else {
                "safety_label": 0,
                "explanation": [],
            }

        prompt_threshold = (
            self.config.prompt_threshold
            if prompt_threshold is None
            else prompt_threshold
        )
        word_threshold = (
            self.config.word_threshold if word_threshold is None else word_threshold
        )
        max_length = (
            self.config.inference_max_length
            if max_length is None
            else max_length
        )

        effective_batch_size = len(prompt_list)
        if batch_size is not None:
            if batch_size <= 0:
                raise ValueError("batch_size must be a positive integer when provided.")
            effective_batch_size = batch_size

        formatted_outputs = []
        for start_idx in range(0, len(prompt_list), effective_batch_size):
            prompt_chunk = prompt_list[start_idx : start_idx + effective_batch_size]
            # Guard against None entries; cleaning happens before tokenizing
            # so word_ids align with these word lists.
            words_batch = [
                self._clean_and_split_prompt(prompt_text or "")
                for prompt_text in prompt_chunk
            ]

            encodings = tokenizer(
                words_batch,
                is_split_into_words=True,
                max_length=max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            )

            formatted_outputs.extend(
                self._predict_from_tokenized(
                    encodings=encodings,
                    words_batch=words_batch,
                    prompt_threshold=prompt_threshold,
                    word_threshold=word_threshold,
                )
            )
        return formatted_outputs[0] if single_input else formatted_outputs

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        prompts: str | Iterable[str] | None = None,
        prompt: str | None = None,
        tokenizer=None,
        prompt_threshold: float | None = None,
        word_threshold: float | None = None,
        max_length: int | None = None,
        batch_size: int | None = None,
        **kwargs,
    ):
        """Dual-mode forward.

        With raw ``prompts``/``prompt`` strings (and no ``input_ids``),
        delegates to :meth:`predict_safety` and returns formatted dicts.
        With tensor inputs, runs the encoder + both heads and returns a
        :class:`LEGModelOutput` of raw logits.

        Raises:
            ValueError: if neither tensors nor strings are provided.
        """
        if prompts is None and prompt is not None:
            prompts = prompt

        if prompts is not None and input_ids is None:
            return self.predict_safety(
                prompts=prompts,
                tokenizer=tokenizer,
                prompt_threshold=prompt_threshold,
                word_threshold=word_threshold,
                max_length=max_length,
                batch_size=batch_size,
            )

        if input_ids is None:
            raise ValueError(
                "Provide either tokenized inputs (`input_ids`, `attention_mask`) or "
                "raw `prompts`/`prompt` strings."
            )

        # Only forward token_type_ids when the caller supplied them.
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "return_dict": True,
        }
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids

        encoder_outputs = self.bert(**encoder_kwargs)
        hidden_states = encoder_outputs.last_hidden_state
        pooled_output = self.attention_pooling(hidden_states, attention_mask)

        return LEGModelOutput(
            prompt_logits=self.prompt_classifier(pooled_output),
            token_logits=self.token_classifier(hidden_states),
        )
|
|