|
|
from transformers import AutoTokenizer |
|
|
from transformers.modeling_outputs import ModelOutput |
|
|
from typing import List, Dict, Optional, Union, Tuple |
|
|
from dataclasses import dataclass |
|
|
import torch |
|
|
|
|
|
from gigacheck.model.mistral_ai_detector import MistralAIDetectorForSequenceClassification |
|
|
from gigacheck.model.src.interval_detector.span_utils import span_cxw_to_xx |
|
|
|
|
|
from .configuration_gigacheck import GigaCheckConfig |
|
|
|
|
|
|
|
|
@dataclass
class GigaCheckOutput(ModelOutput):
    """
    Output type for GigaCheck model.

    Fields default to ``None`` so that, following the ``ModelOutput``
    convention, unset fields are dropped from the tuple/dict views.

    Args:
        pred_label_ids (torch.Tensor): [Batch] Indices of the predicted classes (Human/AI/Mixed).
        classification_head_probs (torch.Tensor): [Batch, Num_Classes] Softmax probabilities.
    """
    # NOTE: field order matters — it defines the ordering of ModelOutput.to_tuple().
    pred_label_ids: Optional[torch.Tensor] = None
    classification_head_probs: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
|
class GigaCheckForSequenceClassification(MistralAIDetectorForSequenceClassification):
    """Sequence classifier that labels a text as Human / AI / Mixed.

    Thin wrapper around ``MistralAIDetectorForSequenceClassification`` that
    disables the DETR interval-detection head and bundles tokenization into
    :meth:`forward`, so callers can pass raw strings directly.
    """

    config_class = GigaCheckConfig

    def __init__(self, config: GigaCheckConfig):
        """Build the classifier with the DETR interval head disabled.

        Args:
            config (GigaCheckConfig): Model configuration; ``id2label`` and
                ``max_length`` are read from it.
        """
        super().__init__(
            config,
            with_detr=False,
            detr_config=None,
            ce_weights=None,
            freeze_backbone=False,
            id2label=config.id2label,
        )
        # Assume the classification head was trained unless the loaded config
        # explicitly says otherwise (see from_pretrained).
        self.trained_classification_head = True
        self._max_len = self.config.max_length
        # Populated by from_pretrained; forward() cannot be used before that.
        self.tokenizer = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Loads a pretrained GigaCheck model from a local path or the Hugging Face Hub.

        Besides loading the weights, this also instantiates the matching
        tokenizer and mirrors its special-token ids into ``model.config``.

        Args:
            pretrained_model_name_or_path (str): The name or path of the pretrained model.
            model_args: Additional positional arguments passed to parent class.
            kwargs: Additional keyword arguments passed to parent class.

        Returns:
            GigaCheckForSequenceClassification: The initialized model with loaded weights and initialized tokenizer.
        """
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
        )

        # Checkpoints may mark the classification head as untrained; default
        # to True when the flag is absent from the config.
        if model.config.to_dict().get("trained_classification_head", True) is False:
            model.trained_classification_head = False

        model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

        # Fall back to the UNK token when the tokenizer defines no PAD token,
        # then keep tokenizer and config in sync with a single assignment
        # (replaces the previous duplicated conditional logic).
        if model.tokenizer.pad_token_id is None:
            model.tokenizer.pad_token_id = model.tokenizer.unk_token_id
        model.config.pad_token_id = model.tokenizer.pad_token_id

        model.config.bos_token_id = model.tokenizer.bos_token_id
        model.config.eos_token_id = model.tokenizer.eos_token_id
        model.config.unk_token_id = model.tokenizer.unk_token_id

        return model

    def _get_inputs(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """
        Tokenizes a batch of texts handling specific truncation logic to preserve exact text length mapping.

        Truncation is applied at the token level to ``max_length - 2`` (room
        is reserved for BOS/EOS), and the character length of each — possibly
        truncated — text is returned so downstream consumers can map
        predictions back onto the original strings.

        Args:
            texts (List[str]): Raw input texts.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, List[int]]: padded ``input_ids``
            and ``attention_mask`` on the model's device, plus the character
            length of each (possibly truncated) text.

        Raises:
            RuntimeError: If the model was not initialized via ``from_pretrained``.
        """
        # Explicit check instead of `assert`: asserts are stripped under -O,
        # which would turn a clear error into an obscure downstream failure.
        if self._max_len is None or self.tokenizer is None:
            raise RuntimeError("Model must be initialized")

        # Tokenize without special tokens; BOS/EOS are appended manually below
        # so truncation can never eat them.
        raw_encodings = self.tokenizer(texts, add_special_tokens=False)

        batch_features = []
        text_lens = []

        content_max_len = self._max_len - 2  # reserve two slots for BOS and EOS
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id

        for i, tokens in enumerate(raw_encodings.input_ids):
            if len(tokens) > content_max_len:
                tokens = tokens[:content_max_len]
                # Decode the surviving tokens to measure how many characters
                # of the original text remain after truncation.
                cur_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
                text_len = len(cur_text)
            else:
                text_len = len(texts[i])

            final_tokens = [bos_id] + tokens + [eos_id]

            batch_features.append({"input_ids": final_tokens})
            text_lens.append(text_len)

        padded_output = self.tokenizer.pad(
            batch_features,
            padding=True,
            return_tensors="pt"
        )

        input_ids = padded_output["input_ids"].to(self.device)
        attention_mask = padded_output["attention_mask"].to(self.device)

        return input_ids, attention_mask, text_lens

    def forward(
        self,
        text: Union[str, List[str]],
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, GigaCheckOutput]:
        """Classify raw text(s) as Human / AI / Mixed.

        Args:
            text: A single string or a batch of strings.
            return_dict: Whether to return a :class:`GigaCheckOutput` instead
                of a plain tuple. Defaults to ``config.use_return_dict``.

        Returns:
            GigaCheckOutput or a ``(pred_label_ids, classification_head_probs)``
            tuple, depending on ``return_dict``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Normalize single-string input to a batch of one.
        if isinstance(text, str):
            text = [text]

        input_ids, attention_mask, text_lens = self._get_inputs(text)

        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            return_detr_output=self.config.with_detr,
        )

        logits = output.logits

        # Softmax in float32 for numerical stability with reduced-precision backbones.
        probs = logits.to(torch.float32).softmax(dim=-1)
        pred_label_ids = torch.argmax(probs, dim=-1)
        classification_head_probs = probs

        if not return_dict:
            return (pred_label_ids, classification_head_probs)

        return GigaCheckOutput(
            pred_label_ids=pred_label_ids,
            classification_head_probs=classification_head_probs,
        )
|
|
|
|
|
|
|
|
def to_absolute(pred_spans: torch.Tensor, text_len: int) -> torch.Tensor:
    """Convert normalized (center, width) spans to absolute (start, end) offsets.

    The spans are scaled by ``text_len`` and clipped to the valid range
    ``[0, text_len]``.
    """
    absolute_spans = span_cxw_to_xx(pred_spans) * text_len
    return absolute_spans.clamp(0, text_len)
|
|
|