# GigaCheck-Detector-Multi / modeling_gigacheck.py
# Hugging Face Hub page residue: uploaded by iitolstykh ("Upload 2 files", commit ef31f0e, verified).
from transformers import AutoTokenizer
from transformers.modeling_outputs import ModelOutput
from typing import List, Dict, Optional, Union, Tuple
from dataclasses import dataclass
import torch
from gigacheck.model.mistral_ai_detector import MistralAIDetectorForSequenceClassification
from gigacheck.model.src.interval_detector.span_utils import span_cxw_to_xx
from .configuration_gigacheck import GigaCheckConfig
@dataclass
class GigaCheckOutput(ModelOutput):
    """
    Result container returned by the GigaCheck detector's ``forward``.

    Each field defaults to ``None`` and stays ``None`` when the part of the model
    that would produce it did not run.

    Args:
        pred_label_ids (torch.Tensor): shape [Batch]; index of the predicted class
            (Human/AI/Mixed) for every input text.
        classification_head_probs (torch.Tensor): shape [Batch, Num_Classes]; softmax
            probabilities from the classification head.
        ai_intervals (List[torch.Tensor]): one tensor per input text (list length Batch),
            each of shape [Num_Intervals, 3] whose rows are (start, end, score) for the
            spans detected as AI-generated.
    """

    # [Batch] predicted class indices.
    pred_label_ids: Optional[torch.Tensor] = None
    # [Batch, Num_Classes] class probabilities.
    classification_head_probs: Optional[torch.Tensor] = None
    # Batch-length list of [Num_Intervals, 3] tensors holding (start, end, score).
    ai_intervals: Optional[List[torch.Tensor]] = None
class GigaCheckForDetection(MistralAIDetectorForSequenceClassification):
    """Text-in / predictions-out wrapper around ``MistralAIDetectorForSequenceClassification``.

    ``forward`` accepts raw text (a single string or a list of strings), tokenizes it with the
    tokenizer attached by ``from_pretrained``, runs the underlying model and returns a
    ``GigaCheckOutput`` holding the classification-head prediction and/or the AI-generated
    character intervals produced by the DETR head (when ``config.with_detr`` is set).
    """

    config_class = GigaCheckConfig

    def __init__(self, config: GigaCheckConfig):
        super().__init__(
            config,
            with_detr=config.with_detr,
            detr_config=config.detr_config,
            ce_weights=None,
            freeze_backbone=False,
            id2label=config.id2label,
        )
        # Set to False by ``from_pretrained`` when the checkpoint says the head was not trained.
        self.trained_classification_head = True
        # Token budget per text, including the BOS/EOS tokens added in ``_get_inputs``.
        self._max_len = self.config.max_length
        # Attached by ``from_pretrained``; ``_get_inputs`` refuses to run while this is None.
        self.tokenizer = None
        # NOTE(review): ``forward`` falls back to ``self.config.conf_interval_thresh``, not to this
        # attribute, so reassigning it after construction does not affect ``forward``.
        self.conf_interval_thresh = config.conf_interval_thresh

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):  # type: ignore
        """Loads a pretrained GigaCheck model from a local path or the Hugging Face Hub.

        Args:
            pretrained_model_name_or_path (str): The name or path of the pretrained model.
            model_args: Additional positional arguments passed to parent class.
            kwargs: Additional keyword arguments passed to parent class. The hub-location
                options among them (``cache_dir``, ``force_download``, ``local_files_only``,
                ``proxies``, ``revision``, ``token``, ``use_auth_token``) are also forwarded to
                ``AutoTokenizer.from_pretrained`` so the tokenizer comes from the same source.
        Returns:
            GigaCheckForDetection: The initialized model with loaded weights and initialized tokenizer.
        """
        # set model weights
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
        )

        if model.config.with_detr:
            extractor_dtype = getattr(torch, model.config.detr_config["extractor_dtype"])
            print(f"Using dtype={extractor_dtype} for {type(model.model)}")
            if extractor_dtype == torch.bfloat16:
                model.model.to(torch.bfloat16)
                model.classification_head.to(torch.bfloat16)

            if model.config.to_dict().get("trained_classification_head", True) is False:
                # when only detr was trained
                model.trained_classification_head = False

        # Load the tokenizer from the same location as the weights. Forward the hub options the
        # caller supplied (revision, credentials, cache, offline mode); otherwise a pinned revision
        # or a private/offline repo would yield a mismatched tokenizer or a failed download.
        hub_option_names = (
            "cache_dir",
            "force_download",
            "local_files_only",
            "proxies",
            "revision",
            "token",
            "use_auth_token",
        )
        tokenizer_kwargs = {name: kwargs[name] for name in hub_option_names if name in kwargs}
        model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **tokenizer_kwargs)

        # Ensure pad token exists (fall back to the UNK token for padding).
        model.config.pad_token_id = model.tokenizer.pad_token_id \
            if model.tokenizer.pad_token_id is not None else model.tokenizer.unk_token_id
        if model.tokenizer.pad_token_id is None:
            model.tokenizer.pad_token_id = model.tokenizer.unk_token_id

        model.config.bos_token_id = model.tokenizer.bos_token_id
        model.config.eos_token_id = model.tokenizer.eos_token_id
        model.config.unk_token_id = model.tokenizer.unk_token_id
        return model

    def _get_inputs(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Tokenize a batch of texts as ``[BOS] + tokens + [EOS]`` and pad them.

        Texts longer than ``self._max_len - 2`` tokens are truncated. For those texts the
        reported length is the character length of the decoded truncated tokens, so that
        predicted intervals map onto the part of the text the model actually saw.

        Args:
            texts: Input strings.

        Returns:
            ``(input_ids, attention_mask, text_lens)``: padded tensors on ``self.device`` and
            the per-text character length used to scale DETR spans.

        Raises:
            AssertionError: if the model has no tokenizer / max length (not built via
                ``from_pretrained``).
        """
        assert self._max_len is not None and self.tokenizer is not None, "Model must be initialized"

        # 1. Tokenize all texts without special tokens/padding first.
        raw_encodings = self.tokenizer(texts, add_special_tokens=False)

        content_max_len = self._max_len - 2  # leave room for BOS and EOS
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id

        batch_features = []  # list of dicts, the input format of tokenizer.pad
        text_lens = []
        for text, tokens in zip(texts, raw_encodings.input_ids):
            if len(tokens) > content_max_len:
                tokens = tokens[:content_max_len]
                # Convert back to string to get the character length of the truncated part.
                text_len = len(self.tokenizer.decode(tokens, skip_special_tokens=True))
            else:
                # If no truncation, use the original text length.
                text_len = len(text)

            # Construct final token sequence: [BOS] + tokens + [EOS]
            batch_features.append({"input_ids": [bos_id] + tokens + [eos_id]})
            text_lens.append(text_len)

        # 2. Pad to the longest sequence in the batch.
        padded_output = self.tokenizer.pad(
            batch_features,
            padding=True,
            return_tensors="pt",
        )
        input_ids = padded_output["input_ids"].to(self.device)
        attention_mask = padded_output["attention_mask"].to(self.device)
        return input_ids, attention_mask, text_lens

    @staticmethod
    def _get_ai_intervals(
        detr_out: Dict[str, torch.Tensor],
        text_lens: List[int],
        conf_interval_thresh: float,
    ) -> List[torch.Tensor]:
        """Convert DETR outputs to absolute character intervals, filtered by confidence.

        Args:
            detr_out: Must contain ``"pred_spans"`` (batch, queries, 2) in normalized
                center/width form and ``"pred_logits"`` (batch, queries, 2) whose class 0 is
                the foreground (AI-generated span) class.
            text_lens: Character length of each text; spans are scaled to ``[0, length]``.
            conf_interval_thresh: Keep only intervals whose score is strictly greater.

        Returns:
            One float32 tensor of shape [Num_Intervals, 3] with rows (start, end, score) per text.
        """
        # Work in float32: when the detector runs in bfloat16, multiplying normalized spans by
        # a character count would otherwise round positions to coarse multiples (bf16 keeps
        # only 8 mantissa bits). Matches the float32 softmax used for classification logits.
        pred_spans = detr_out["pred_spans"].to(torch.float32)  # (batch_size, #queries, 2)
        src_logits = detr_out["pred_logits"].to(torch.float32)  # (batch_size, #queries, #classes=2)

        assert len(text_lens) == pred_spans.shape[0]

        # Take probs for foreground objects only (ind = 0)
        pred_probs = torch.softmax(src_logits, dim=-1)[:, :, 0:1]  # [Batch, Queries, 1]

        final_preds_batch = []
        for i, length in enumerate(text_lens):
            # Convert center-width [0,1] to [0, length] absolute start-end; pred_spans[i]: [Queries, 2]
            spans_abs = to_absolute(pred_spans[i], length)
            # Concat spans and scores: [Queries, 3] -> (start, end, score)
            preds_i = torch.cat([spans_abs, pred_probs[i]], dim=1)
            # Filter by confidence threshold
            final_preds_batch.append(preds_i[preds_i[:, 2] > conf_interval_thresh])

        return final_preds_batch

    def forward(
        self,
        text: Union[str, List[str]],
        return_dict: Optional[bool] = None,
        conf_interval_thresh: Optional[float] = None,
    ) -> Union[Tuple, GigaCheckOutput]:
        """Detect AI-generated content in one text or a batch of texts.

        Args:
            text: A single string or a list of strings.
            return_dict: If true return a ``GigaCheckOutput``, otherwise the plain tuple
                ``(pred_label_ids, classification_head_probs, ai_intervals)``.
                Defaults to ``self.config.use_return_dict``.
            conf_interval_thresh: Minimum (exclusive) score for a DETR interval to be kept.
                Defaults to ``self.config.conf_interval_thresh``.

        Returns:
            The predictions. ``pred_label_ids`` / ``classification_head_probs`` are None when the
            classification head is untrained; ``ai_intervals`` is None without the DETR head.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        conf_interval_thresh = conf_interval_thresh if conf_interval_thresh is not None else self.config.conf_interval_thresh

        if isinstance(text, str):
            text = [text]

        input_ids, attention_mask, text_lens = self._get_inputs(text)
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            return_detr_output=self.config.with_detr,
        )

        pred_label_ids = None
        classification_head_probs = None
        ai_intervals = None

        # 1. Classification Head Processing.
        # With DETR enabled, output.logits is a (classification_logits, detr_out) pair.
        if not self.config.with_detr:
            logits = output.logits
        elif self.trained_classification_head:
            logits, _ = output.logits
        else:
            logits = None

        if logits is not None:
            # logits: [Batch, NumClasses]
            probs = logits.to(torch.float32).softmax(dim=-1)
            pred_label_ids = torch.argmax(probs, dim=-1)  # [Batch]
            classification_head_probs = probs  # [Batch, NumClasses]

        # 2. Interval Detection (DETR) Processing
        if self.config.with_detr:
            _, detr_out = output.logits
            ai_intervals = self._get_ai_intervals(detr_out, text_lens, conf_interval_thresh)

        if not return_dict:
            return (pred_label_ids, classification_head_probs, ai_intervals)

        return GigaCheckOutput(
            pred_label_ids=pred_label_ids,
            classification_head_probs=classification_head_probs,
            ai_intervals=ai_intervals,
        )
def to_absolute(pred_spans: torch.Tensor, text_len: int) -> torch.Tensor:
    """Map normalized center/width spans to absolute (start, end) positions.

    Args:
        pred_spans: Spans in the form accepted by ``span_cxw_to_xx``.
        text_len: Length used to scale the normalized coordinates.

    Returns:
        The (start, end) spans multiplied by ``text_len`` and clamped into ``[0, text_len]``.
    """
    start_end = span_cxw_to_xx(pred_spans)
    scaled = start_end * text_len
    # Spans whose center/width put them partly outside [0, 1] are clipped to the text bounds.
    return scaled.clamp(min=0, max=text_len)