|
|
from transformers import AutoTokenizer |
|
|
from transformers.modeling_outputs import ModelOutput |
|
|
from typing import List, Dict, Optional, Union, Tuple |
|
|
from dataclasses import dataclass |
|
|
import torch |
|
|
|
|
|
from gigacheck.model.mistral_ai_detector import MistralAIDetectorForSequenceClassification |
|
|
from gigacheck.model.src.interval_detector.span_utils import span_cxw_to_xx |
|
|
|
|
|
from .configuration_gigacheck import GigaCheckConfig |
|
|
|
|
|
|
|
|
@dataclass
class GigaCheckOutput(ModelOutput):
    """
    Container for GigaCheck detection results.

    Attributes:
        pred_label_ids: [Batch] index of the predicted class for each input
            text (Human/AI/Mixed), or None when the classification head is
            not used.
        classification_head_probs: [Batch, Num_Classes] softmax probabilities
            produced by the classification head, or None.
        ai_intervals: List of length Batch; each element is a tensor of shape
            [Num_Intervals, 3] holding (start, end, score) for detected
            AI-generated character spans, or None when the interval head is
            disabled.
    """
    pred_label_ids: Optional[torch.Tensor] = None
    classification_head_probs: Optional[torch.Tensor] = None
    ai_intervals: Optional[List[torch.Tensor]] = None
|
|
|
|
|
|
|
|
class GigaCheckForDetection(MistralAIDetectorForSequenceClassification):
    """End-to-end detector for AI-generated text.

    Combines a sequence-classification head (one Human/AI/Mixed label per
    text) with an optional DETR-style interval head that localizes
    AI-generated character spans inside each text.  Inputs are raw strings;
    tokenization is handled internally (see `_get_inputs`).
    """

    # Config class consumed by `from_pretrained` when deserializing.
    config_class = GigaCheckConfig

    def __init__(self, config: GigaCheckConfig):
        """Builds the backbone and heads from `config`.

        Args:
            config (GigaCheckConfig): Model configuration; `with_detr` and
                `detr_config` control whether the interval-detection head is
                instantiated.
        """
        super().__init__(
            config,
            with_detr = config.with_detr,
            detr_config = config.detr_config,
            ce_weights = None,  # no class re-weighting at inference time
            freeze_backbone = False,
            id2label = config.id2label,
        )
        # Assume the classification head carries trained weights unless the
        # loaded config says otherwise (overridden in `from_pretrained`).
        self.trained_classification_head = True
        self._max_len = self.config.max_length
        # Filled in by `from_pretrained`; `forward` asserts it is set.
        self.tokenizer = None
        # Default confidence threshold for keeping predicted AI intervals.
        self.conf_interval_thresh = config.conf_interval_thresh

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Loads a pretrained GigaCheck model from a local path or the Hugging Face Hub.

        Beyond loading weights, this also:
          * casts the backbone and classification head to bfloat16 when the
            DETR config requests it,
          * records whether the classification head was actually trained,
          * loads the tokenizer and syncs its special-token ids into the config.

        Args:
            pretrained_model_name_or_path (str): The name or path of the pretrained model.
            model_args: Additional positional arguments passed to parent class.
            kwargs: Additional keyword arguments passed to parent class.

        Returns:
            GigaCheckForDetection: The initialized model with loaded weights and initialized tokenizer.
        """

        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
        )

        if model.config.with_detr:
            # detr_config stores the dtype by name (e.g. "bfloat16"); resolve
            # it to the actual torch dtype object.
            extractor_dtype = getattr(torch, model.config.detr_config["extractor_dtype"])
            print(f"Using dtype={extractor_dtype} for {type(model.model)}")
            if extractor_dtype == torch.bfloat16:
                model.model.to(torch.bfloat16)
                model.classification_head.to(torch.bfloat16)

        # Some checkpoints ship without a trained classification head; in that
        # case `forward` skips the label/probability outputs entirely.
        if model.config.to_dict().get("trained_classification_head", True) is False:

            model.trained_classification_head = False

        model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

        # Fall back to the UNK id when the tokenizer defines no PAD token —
        # applied both to the model config and to the tokenizer itself.
        model.config.pad_token_id = model.tokenizer.pad_token_id \
            if model.tokenizer.pad_token_id is not None else model.tokenizer.unk_token_id
        if model.tokenizer.pad_token_id is None:
            model.tokenizer.pad_token_id = model.tokenizer.unk_token_id

        # Keep the remaining special-token ids in sync with the tokenizer.
        model.config.bos_token_id = model.tokenizer.bos_token_id
        model.config.eos_token_id = model.tokenizer.eos_token_id
        model.config.unk_token_id = model.tokenizer.unk_token_id

        return model

    def _get_inputs(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """
        Tokenizes a batch of texts handling specific truncation logic to preserve exact text length mapping.

        Args:
            texts: Raw input strings, one per batch element.

        Returns:
            Tuple of:
              * input_ids: [Batch, Seq] padded token ids (BOS/EOS added), on `self.device`;
              * attention_mask: [Batch, Seq] matching mask from the tokenizer's padding;
              * text_lens: per-text character length AFTER truncation — used later to
                map normalized interval predictions back to character offsets.
        """
        assert self._max_len is not None and self.tokenizer is not None, "Model must be initialized"

        # Tokenize without special tokens; BOS/EOS are appended manually below
        # so the truncation budget stays exact.
        raw_encodings = self.tokenizer(texts, add_special_tokens=False)

        batch_features = []
        text_lens = []

        # Reserve two positions for the BOS and EOS tokens added below.
        content_max_len = self._max_len - 2
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id

        for i, tokens in enumerate(raw_encodings.input_ids):
            if len(tokens) > content_max_len:
                tokens = tokens[:content_max_len]

                # Re-decode the truncated tokens to recover the exact number
                # of characters that survived truncation.
                cur_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
                text_len = len(cur_text)
            else:
                # No truncation: the original character length is preserved.
                text_len = len(texts[i])

            final_tokens = [bos_id] + tokens + [eos_id]

            batch_features.append({"input_ids": final_tokens})
            text_lens.append(text_len)

        # Pad to the longest sequence in the batch.
        padded_output = self.tokenizer.pad(
            batch_features,
            padding=True,
            return_tensors="pt"
        )

        input_ids = padded_output["input_ids"].to(self.device)
        attention_mask = padded_output["attention_mask"].to(self.device)

        return input_ids, attention_mask, text_lens

    @staticmethod
    def _get_ai_intervals(detr_out: Dict[str, torch.Tensor], text_lens: List[int], conf_interval_thresh: float) -> List[torch.Tensor]:
        """
        Converts DETR outputs to absolute text intervals.

        Args:
            detr_out: DETR head output with "pred_spans" (assumed
                [Batch, Num_Queries, 2] normalized center/width spans — see
                `to_absolute`) and "pred_logits" ([Batch, Num_Queries, Num_Classes]).
            text_lens: Character length of each (possibly truncated) text.
            conf_interval_thresh: Minimum score for an interval to be kept.

        Returns:
            List of length Batch; each element is a [Num_Kept, 3] tensor of
            (start_char, end_char, score).
        """
        pred_spans = detr_out["pred_spans"]
        src_logits = detr_out["pred_logits"]
        assert len(text_lens) == pred_spans.shape[0]

        # Probability of class 0 for every query, kept as a trailing size-1
        # dim so it concatenates with the spans below.  NOTE(review): assumes
        # index 0 is the "AI span" class — confirm against the DETR head.
        pred_probs = torch.softmax(src_logits, dim=-1)[:, :, 0:1]

        final_preds_batch = []

        for i, length in enumerate(text_lens):

            # Scale normalized spans to absolute character positions.
            spans_abs = to_absolute(pred_spans[i], length)

            # [Num_Queries, 3]: (start, end, score).
            scores = pred_probs[i]
            preds_i = torch.cat([spans_abs, scores], dim=1)

            # Keep only intervals whose score clears the threshold.
            mask = preds_i[:, 2] > conf_interval_thresh
            filtered_preds = preds_i[mask]

            final_preds_batch.append(filtered_preds)

        return final_preds_batch

    def forward(
        self,
        text: Union[str, List[str]],
        return_dict: Optional[bool] = None,
        conf_interval_thresh: Optional[float] = None,
    ) -> Union[Tuple, GigaCheckOutput]:
        """Runs detection on raw text.

        Args:
            text: A single string or a batch of strings.
            return_dict: Return a `GigaCheckOutput` instead of a plain tuple;
                defaults to `config.use_return_dict`.
            conf_interval_thresh: Per-call override of the interval confidence
                threshold; defaults to `config.conf_interval_thresh`.

        Returns:
            `GigaCheckOutput` (or the tuple of its fields when `return_dict`
            is falsy).  A field is None when the corresponding head is
            disabled or untrained.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        conf_interval_thresh = conf_interval_thresh if conf_interval_thresh is not None else self.config.conf_interval_thresh

        # Normalize single-string input to a batch of one.
        if isinstance(text, str):
            text = [text]

        input_ids, attention_mask, text_lens = self._get_inputs(text)

        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            return_detr_output=self.config.with_detr,
        )

        pred_label_ids = None
        classification_head_probs = None
        ai_intervals = None

        # With the DETR head enabled, `output.logits` is a
        # (classification_logits, detr_output) pair; otherwise it is the
        # classification logits alone.
        if not self.config.with_detr:
            logits = output.logits
        elif self.trained_classification_head:
            logits, _ = output.logits
        else:
            # Untrained classification head: skip label prediction entirely.
            logits = None

        if logits is not None:
            # Cast to float32 before softmax, since the backbone may run in
            # bfloat16 (see `from_pretrained`).
            probs = logits.to(torch.float32).softmax(dim=-1)
            pred_label_ids = torch.argmax(probs, dim=-1)
            classification_head_probs = probs

        if self.config.with_detr:
            _, detr_out = output.logits
            ai_intervals = self._get_ai_intervals(detr_out, text_lens, conf_interval_thresh)

        if not return_dict:
            return (pred_label_ids, classification_head_probs, ai_intervals)

        return GigaCheckOutput(
            pred_label_ids=pred_label_ids,
            classification_head_probs=classification_head_probs,
            ai_intervals=ai_intervals,
        )
|
|
|
|
|
|
|
|
def to_absolute(pred_spans: torch.Tensor, text_len: int) -> torch.Tensor:
    """Map normalized spans onto absolute character positions.

    Converts (center, width) spans to (start, end) via `span_cxw_to_xx`,
    scales them by `text_len`, and clips the result into [0, text_len].

    Args:
        pred_spans: Normalized span predictions in center/width format.
        text_len: Character length of the corresponding text.

    Returns:
        Tensor of the same leading shape with absolute (start, end) positions.
    """
    scaled = span_cxw_to_xx(pred_spans) * text_len
    return scaled.clamp(0, text_len)
|
|
|