# GigaCheck-Classifier-Multi / modeling_gigacheck.py
# Provenance: uploaded by iitolstykh ("Upload 2 files", commit 88b272e, verified)
from transformers import AutoTokenizer
from transformers.modeling_outputs import ModelOutput
from typing import List, Dict, Optional, Union, Tuple
from dataclasses import dataclass
import torch
from gigacheck.model.mistral_ai_detector import MistralAIDetectorForSequenceClassification
from gigacheck.model.src.interval_detector.span_utils import span_cxw_to_xx
from .configuration_gigacheck import GigaCheckConfig
@dataclass
class GigaCheckOutput(ModelOutput):
    """
    Output type for GigaCheck model.

    Both fields default to None so the dataclass can be partially populated,
    following the transformers `ModelOutput` convention.

    Args:
        pred_label_ids (torch.Tensor): [Batch] Indices of the predicted classes (Human/AI/Mixed).
        classification_head_probs (torch.Tensor): [Batch, Num_Classes] Softmax probabilities.
    """
    # Argmax of the softmax probabilities, one class index per batch element.
    pred_label_ids: Optional[torch.Tensor] = None
    # Full softmax distribution over the classification head's classes.
    classification_head_probs: Optional[torch.Tensor] = None
class GigaCheckForSequenceClassification(MistralAIDetectorForSequenceClassification):
    """Sequence classifier labeling text as one of the classes in ``config.id2label``
    (per GigaCheckOutput, Human/AI/Mixed).

    Thin wrapper over ``MistralAIDetectorForSequenceClassification`` that
    disables the DETR interval head at construction time and accepts raw
    strings in ``forward`` by tokenizing internally via ``_get_inputs``.
    """

    # Config class consumed by transformers' from_pretrained machinery.
    config_class = GigaCheckConfig

    def __init__(self, config: GigaCheckConfig):
        """Initialize the backbone in classification-only mode.

        Args:
            config: Model configuration; ``config.id2label`` supplies the
                class-index-to-label mapping and ``config.max_length`` the
                total token budget used by ``_get_inputs``.
        """
        super().__init__(
            config,
            with_detr = False,      # classification only: no interval (DETR) head
            detr_config = None,
            ce_weights = None,      # no class re-weighting of the CE loss
            freeze_backbone = False,
            id2label = config.id2label,
        )
        # May be flipped to False in from_pretrained when the saved config
        # says only the DETR head was trained.
        self.trained_classification_head = True
        # Total token budget including BOS/EOS (see _get_inputs).
        self._max_len = self.config.max_length
        # Populated by from_pretrained; forward() cannot run without it.
        self.tokenizer: Optional[AutoTokenizer] = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):  # type: ignore
        """Loads a pretrained GigaCheck model from a local path or the Hugging Face Hub.

        Besides loading weights, this also loads the matching tokenizer and
        mirrors its special-token ids into the model config.

        Args:
            pretrained_model_name_or_path (str): The name or path of the pretrained model.
            model_args: Additional positional arguments passed to parent class.
            kwargs: Additional keyword arguments passed to parent class.
        Returns:
            GigaCheckForSequenceClassification: The initialized model with loaded weights and initialized tokenizer.
        """
        # set model weights
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
        )
        # Probe the serialized config via to_dict() so a missing key defaults
        # to True (head assumed trained) rather than raising.
        if model.config.to_dict().get("trained_classification_head", True) is False:
            # when only detr was trained
            model.trained_classification_head = False
        model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        # Ensure pad token exists: fall back to the UNK id on both the config
        # and the tokenizer when the tokenizer declares no pad token.
        # NOTE(review): if unk_token_id is also None, pad_token_id stays None
        # and tokenizer.pad() below would fail — presumably never the case for
        # the tokenizers this ships with; confirm.
        model.config.pad_token_id = model.tokenizer.pad_token_id \
            if model.tokenizer.pad_token_id is not None else model.tokenizer.unk_token_id
        if model.tokenizer.pad_token_id is None:
            model.tokenizer.pad_token_id = model.tokenizer.unk_token_id
        # Mirror the remaining special-token ids into the config.
        model.config.bos_token_id = model.tokenizer.bos_token_id
        model.config.eos_token_id = model.tokenizer.eos_token_id
        model.config.unk_token_id = model.tokenizer.unk_token_id
        return model

    def _get_inputs(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """
        Tokenizes a batch of texts handling specific truncation logic to preserve exact text length mapping.

        Args:
            texts: Batch of raw input strings.

        Returns:
            Tuple of (input_ids, attention_mask, text_lens) where the tensors
            are padded to a common length and moved to ``self.device``, and
            ``text_lens[i]`` is the character length of the (possibly
            truncated) portion of ``texts[i]`` that the tokens actually cover.
        """
        assert self._max_len is not None and self.tokenizer is not None, "Model must be initialized"
        # 1. Tokenize all texts without special tokens/padding first
        raw_encodings = self.tokenizer(texts, add_special_tokens=False)
        batch_features = []  # List of dicts for tokenizer.pad
        text_lens = []
        # Reserve 2 slots of the budget for the BOS and EOS tokens added below.
        content_max_len = self._max_len - 2
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id
        for i, tokens in enumerate(raw_encodings.input_ids):
            if len(tokens) > content_max_len:
                tokens = tokens[:content_max_len]
                # Convert back to string to get the exact character length of the truncated part
                cur_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
                text_len = len(cur_text)
            else:
                # If no truncation, use the original text length
                text_len = len(texts[i])
            # Construct final token sequence: [BOS] + tokens + [EOS]
            final_tokens = [bos_id] + tokens + [eos_id]
            # Append as dictionary for tokenizer.pad
            batch_features.append({"input_ids": final_tokens})
            text_lens.append(text_len)
        # 2. Pad using tokenizer.pad
        padded_output = self.tokenizer.pad(
            batch_features,
            padding=True,
            return_tensors="pt"
        )
        input_ids = padded_output["input_ids"].to(self.device)
        attention_mask = padded_output["attention_mask"].to(self.device)
        return input_ids, attention_mask, text_lens

    def forward(
        self,
        text: Union[str, List[str]],
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, GigaCheckOutput]:
        """Classify raw text.

        Args:
            text: A single string or a batch of strings.
            return_dict: Whether to return a ``GigaCheckOutput`` instead of a
                plain tuple; defaults to ``config.use_return_dict``.

        Returns:
            ``GigaCheckOutput`` (or a ``(pred_label_ids,
            classification_head_probs)`` tuple) with per-sample predicted
            class indices and softmax probabilities.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Normalize single-string input to a batch of one.
        if isinstance(text, str):
            text = [text]
        input_ids, attention_mask, text_lens = self._get_inputs(text)
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            # NOTE(review): __init__ passes with_detr=False to the parent;
            # presumably config.with_detr matches that — confirm.
            return_detr_output=self.config.with_detr,
        )
        # 1. Classification Head Processing
        logits = output.logits
        # logits: [Batch, NumClasses]; upcast to float32 before softmax for
        # numerical stability under half-precision inference.
        probs = logits.to(torch.float32).softmax(dim=-1)
        pred_label_ids = torch.argmax(probs, dim=-1)  # [Batch]
        classification_head_probs = probs  # [Batch, NumClasses]
        if not return_dict:
            return (pred_label_ids, classification_head_probs)
        return GigaCheckOutput(
            pred_label_ids=pred_label_ids,
            classification_head_probs=classification_head_probs,
        )
def to_absolute(pred_spans: torch.Tensor, text_len: int) -> torch.Tensor:
    """Convert normalized (center, width) spans to absolute character offsets.

    Args:
        pred_spans: Spans in normalized (cx, w) form, as produced by the
            interval detector.
        text_len: Character length of the source text.

    Returns:
        Spans in (start, end) form scaled to character units and clamped
        to the valid range ``[0, text_len]``.
    """
    scaled = span_cxw_to_xx(pred_spans).mul(text_len)
    return scaled.clamp(min=0, max=text_len)