import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import torch
from transformers import AutoTokenizer
from transformers.modeling_outputs import ModelOutput

from gigacheck.model.mistral_ai_detector import MistralAIDetectorForSequenceClassification
from gigacheck.model.src.interval_detector.span_utils import span_cxw_to_xx

from .configuration_gigacheck import GigaCheckConfig

logger = logging.getLogger(__name__)


@dataclass
class GigaCheckOutput(ModelOutput):
    """
    Output type for GigaCheck model.

    Args:
        pred_label_ids (torch.Tensor):
            [Batch] Indices of the predicted classes (Human/AI/Mixed).
        classification_head_probs (torch.Tensor):
            [Batch, Num_Classes] Softmax probabilities.
        ai_intervals (List[torch.Tensor]):
            List of length Batch. Each element is a tensor of shape
            [Num_Intervals, 3] containing (start, end, score) for detected
            AI-generated spans, in absolute character coordinates.
    """

    pred_label_ids: Optional[torch.Tensor] = None
    classification_head_probs: Optional[torch.Tensor] = None
    ai_intervals: Optional[List[torch.Tensor]] = None


class GigaCheckForDetection(MistralAIDetectorForSequenceClassification):
    """AI-generated-text detector combining a sequence classifier with an
    optional DETR-style interval-detection head.

    Accepts raw text (not token ids) in ``forward`` and returns per-text
    class predictions plus, when ``config.with_detr`` is set, character-level
    intervals of detected AI-generated spans.
    """

    config_class = GigaCheckConfig

    def __init__(self, config: GigaCheckConfig):
        super().__init__(
            config,
            with_detr=config.with_detr,
            detr_config=config.detr_config,
            ce_weights=None,
            freeze_backbone=False,
            id2label=config.id2label,
        )
        # May be flipped to False in `from_pretrained` when only the DETR
        # head was trained and the classification head should be ignored.
        self.trained_classification_head = True
        self._max_len = self.config.max_length
        # Tokenizer is attached in `from_pretrained`; the model is unusable
        # for raw-text inference before that.
        self.tokenizer = None
        # NOTE(review): `forward` reads the threshold from `self.config`;
        # this attribute duplicates it and is kept for interface
        # compatibility with external callers.
        self.conf_interval_thresh = config.conf_interval_thresh

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):  # type: ignore
        """Loads a pretrained GigaCheck model from a local path or the Hugging Face Hub.

        Besides loading weights via the parent class, this also:
        - casts the backbone and classification head to the dtype declared in
          ``config.detr_config["extractor_dtype"]`` (bfloat16 only);
        - attaches the matching tokenizer and ensures a pad token exists
          (falling back to the UNK token);
        - mirrors special-token ids onto ``model.config``.

        Args:
            pretrained_model_name_or_path (str): The name or path of the pretrained model.
            model_args: Additional positional arguments passed to parent class.
            kwargs: Additional keyword arguments passed to parent class.

        Returns:
            GigaCheckForDetection: The initialized model with loaded weights
            and initialized tokenizer.
        """
        # set model weights
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
        )

        if model.config.with_detr:
            extractor_dtype = getattr(torch, model.config.detr_config["extractor_dtype"])
            logger.info("Using dtype=%s for %s", extractor_dtype, type(model.model))
            # Only bfloat16 is handled explicitly; other dtypes keep whatever
            # the parent loader produced.
            if extractor_dtype == torch.bfloat16:
                model.model.to(torch.bfloat16)
                model.classification_head.to(torch.bfloat16)

        if model.config.to_dict().get("trained_classification_head", True) is False:
            # when only detr was trained
            model.trained_classification_head = False

        model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

        # Ensure pad token exists: fall back to UNK when the tokenizer
        # defines no pad token (common for decoder-only LMs).
        model.config.pad_token_id = model.tokenizer.pad_token_id \
            if model.tokenizer.pad_token_id is not None else model.tokenizer.unk_token_id
        if model.tokenizer.pad_token_id is None:
            model.tokenizer.pad_token_id = model.tokenizer.unk_token_id

        model.config.bos_token_id = model.tokenizer.bos_token_id
        model.config.eos_token_id = model.tokenizer.eos_token_id
        model.config.unk_token_id = model.tokenizer.unk_token_id

        return model

    def _get_inputs(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """
        Tokenizes a batch of texts handling specific truncation logic
        to preserve exact text length mapping.

        Returns:
            Tuple of (input_ids, attention_mask, text_lens) where
            ``text_lens[i]`` is the character length of the (possibly
            truncated) text actually seen by the model — needed to map
            normalized DETR spans back to character offsets.
        """
        assert self._max_len is not None and self.tokenizer is not None, "Model must be initialized"

        # 1. Tokenize all texts without special tokens/padding first
        raw_encodings = self.tokenizer(texts, add_special_tokens=False)

        batch_features = []  # List of dicts for tokenizer.pad
        text_lens = []
        # Reserve two positions for the BOS/EOS tokens added below.
        content_max_len = self._max_len - 2
        # NOTE(review): assumes the tokenizer defines both BOS and EOS ids —
        # confirm for the tokenizers actually shipped with this model.
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id

        for i, tokens in enumerate(raw_encodings.input_ids):
            if len(tokens) > content_max_len:
                tokens = tokens[:content_max_len]
                # Convert back to string to get the exact character length
                # of the truncated part
                cur_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
                text_len = len(cur_text)
            else:
                # If no truncation, use the original text length
                text_len = len(texts[i])

            # Construct final token sequence: [BOS] + tokens + [EOS]
            final_tokens = [bos_id] + tokens + [eos_id]

            # Append as dictionary for tokenizer.pad
            batch_features.append({"input_ids": final_tokens})
            text_lens.append(text_len)

        # 2. Pad using tokenizer.pad
        padded_output = self.tokenizer.pad(
            batch_features,
            padding=True,
            return_tensors="pt"
        )

        input_ids = padded_output["input_ids"].to(self.device)
        attention_mask = padded_output["attention_mask"].to(self.device)

        return input_ids, attention_mask, text_lens

    @staticmethod
    def _get_ai_intervals(detr_out: Dict[str, torch.Tensor],
                          text_lens: List[int],
                          conf_interval_thresh: float) -> List[torch.Tensor]:
        """
        Converts DETR outputs to absolute text intervals.

        Args:
            detr_out: Dict with ``pred_spans`` [Batch, Queries, 2]
                (normalized center-width) and ``pred_logits``
                [Batch, Queries, 2] (foreground/background logits).
            text_lens: Character length of each text in the batch.
            conf_interval_thresh: Minimum foreground probability for an
                interval to be kept.

        Returns:
            Per-text tensors of shape [Kept_Intervals, 3] with
            (start_char, end_char, score).
        """
        pred_spans = detr_out["pred_spans"]    # (batch_size, #queries, 2)
        src_logits = detr_out["pred_logits"]   # (batch_size, #queries, #classes=2)

        assert len(text_lens) == pred_spans.shape[0]

        # Take probs for foreground objects only (ind = 0)
        pred_probs = torch.softmax(src_logits, dim=-1)[:, :, 0:1]  # [Batch, Queries, 1]

        final_preds_batch = []
        for i, length in enumerate(text_lens):
            # Convert center-width [0,1] to [0, length] absolute start-end
            # pred_spans[i]: [Queries, 2]
            spans_abs = to_absolute(pred_spans[i], length)

            # Concat spans and scores: [Queries, 3] -> (start, end, score)
            scores = pred_probs[i]
            preds_i = torch.cat([spans_abs, scores], dim=1)

            # Filter by confidence threshold
            mask = preds_i[:, 2] > conf_interval_thresh
            filtered_preds = preds_i[mask]

            final_preds_batch.append(filtered_preds)

        return final_preds_batch

    def forward(
        self,
        text: Union[str, List[str]],
        return_dict: Optional[bool] = None,
        conf_interval_thresh: Optional[float] = None,
    ) -> Union[Tuple, GigaCheckOutput]:
        """Run detection on raw text.

        Args:
            text: A single string or a batch of strings.
            return_dict: Whether to return a :class:`GigaCheckOutput`;
                defaults to ``config.use_return_dict``.
            conf_interval_thresh: Per-call override of the interval
                confidence threshold; defaults to the config value.

        Returns:
            ``GigaCheckOutput`` (or a tuple) of predicted label ids,
            classification probabilities, and per-text AI intervals.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        conf_interval_thresh = conf_interval_thresh if conf_interval_thresh is not None \
            else self.config.conf_interval_thresh

        if isinstance(text, str):
            text = [text]

        input_ids, attention_mask, text_lens = self._get_inputs(text)

        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            return_detr_output=self.config.with_detr,
        )

        pred_label_ids = None
        classification_head_probs = None
        ai_intervals = None

        # 1. Classification Head Processing
        # With DETR enabled, `output.logits` is a (cls_logits, detr_out) pair;
        # otherwise it is the plain classification logits.
        if not self.config.with_detr:
            logits = output.logits
        elif self.trained_classification_head:
            logits, _ = output.logits
        else:
            # Only the DETR head was trained — classification output is
            # meaningless, skip it.
            logits = None

        if logits is not None:
            # logits: [Batch, NumClasses]; cast to fp32 for a stable softmax
            # (the head may run in bfloat16, see `from_pretrained`).
            probs = logits.to(torch.float32).softmax(dim=-1)
            pred_label_ids = torch.argmax(probs, dim=-1)  # [Batch]
            classification_head_probs = probs             # [Batch, NumClasses]

        # 2. Interval Detection (DETR) Processing
        if self.config.with_detr:
            _, detr_out = output.logits
            ai_intervals = self._get_ai_intervals(detr_out, text_lens, conf_interval_thresh)

        if not return_dict:
            return (pred_label_ids, classification_head_probs, ai_intervals)

        return GigaCheckOutput(
            pred_label_ids=pred_label_ids,
            classification_head_probs=classification_head_probs,
            ai_intervals=ai_intervals,
        )


def to_absolute(pred_spans: torch.Tensor, text_len: int) -> torch.Tensor:
    """Convert normalized (center, width) spans to absolute (start, end)
    character positions, clamped to [0, text_len]."""
    spans = span_cxw_to_xx(pred_spans) * text_len
    return torch.clamp(spans, 0, text_len)