from transformers import AutoTokenizer
from transformers.modeling_outputs import ModelOutput
from typing import List, Dict, Optional, Union, Tuple
from dataclasses import dataclass
import torch
from gigacheck.model.mistral_ai_detector import MistralAIDetectorForSequenceClassification
from gigacheck.model.src.interval_detector.span_utils import span_cxw_to_xx
from .configuration_gigacheck import GigaCheckConfig
@dataclass
class GigaCheckOutput(ModelOutput):
    """Container for GigaCheck detection results.

    Attributes:
        pred_label_ids: [Batch] — index of the predicted class (Human/AI/Mixed)
            for each input, or None when the classification head is unused.
        classification_head_probs: [Batch, Num_Classes] — softmax probabilities
            produced by the classification head, or None.
        ai_intervals: one tensor per batch element, each of shape
            [Num_Intervals, 3] with (start, end, score) rows describing the
            detected AI-generated spans; None when interval detection is off.
    """
    pred_label_ids: Optional[torch.Tensor] = None
    classification_head_probs: Optional[torch.Tensor] = None
    ai_intervals: Optional[List[torch.Tensor]] = None
class GigaCheckForDetection(MistralAIDetectorForSequenceClassification):
    """AI-generated-text detector with optional DETR-style interval detection.

    Callers pass raw text straight into :meth:`forward`; tokenization,
    truncation and padding are handled internally. When ``config.with_detr``
    is set, the model additionally predicts character-level intervals of
    AI-generated text (see :meth:`_get_ai_intervals`).
    """

    # Config class consumed by the HF from_pretrained machinery.
    config_class = GigaCheckConfig

    def __init__(self, config: GigaCheckConfig):
        """Build the detector from a ``GigaCheckConfig``.

        Args:
            config: Carries backbone settings plus ``with_detr``,
                ``detr_config``, ``id2label``, ``max_length`` and
                ``conf_interval_thresh``.
        """
        super().__init__(
            config,
            with_detr = config.with_detr,
            detr_config = config.detr_config,
            ce_weights = None,  # no class re-weighting at inference time
            freeze_backbone = False,
            id2label = config.id2label,
        )
        # Assume the classification head was trained; from_pretrained flips
        # this to False when the loaded config says only DETR was trained.
        self.trained_classification_head = True
        # Max total sequence length (incl. BOS/EOS) enforced by _get_inputs.
        self._max_len = self.config.max_length
        # Set by from_pretrained; _get_inputs asserts it is initialized.
        self.tokenizer = None
        # Default confidence threshold for keeping predicted AI intervals.
        self.conf_interval_thresh = config.conf_interval_thresh

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):  # type: ignore
        """Loads a pretrained GigaCheck model from a local path or the Hugging Face Hub.

        Args:
            pretrained_model_name_or_path (str): The name or path of the pretrained model.
            model_args: Additional positional arguments passed to parent class.
            kwargs: Additional keyword arguments passed to parent class.

        Returns:
            GigaCheckForDetection: The initialized model with loaded weights and initialized tokenizer.
        """
        # set model weights
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
        )
        if model.config.with_detr:
            # detr_config stores the dtype by name (e.g. "bfloat16"); resolve
            # it to the actual torch dtype object.
            extractor_dtype = getattr(torch, model.config.detr_config["extractor_dtype"])
            print(f"Using dtype={extractor_dtype} for {type(model.model)}")
            if extractor_dtype == torch.bfloat16:
                # Only the backbone and classification head are downcast; the
                # DETR head presumably stays in its own dtype — confirm.
                model.model.to(torch.bfloat16)
                model.classification_head.to(torch.bfloat16)
        if model.config.to_dict().get("trained_classification_head", True) is False:
            # when only detr was trained — forward() will then skip the
            # classification-head logits entirely
            model.trained_classification_head = False
        model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        # Ensure pad token exists: fall back to UNK when the tokenizer
        # defines no pad token, and mirror the choice onto the config.
        model.config.pad_token_id = model.tokenizer.pad_token_id \
            if model.tokenizer.pad_token_id is not None else model.tokenizer.unk_token_id
        if model.tokenizer.pad_token_id is None:
            model.tokenizer.pad_token_id = model.tokenizer.unk_token_id
        # Keep the config's special-token ids in sync with the tokenizer.
        model.config.bos_token_id = model.tokenizer.bos_token_id
        model.config.eos_token_id = model.tokenizer.eos_token_id
        model.config.unk_token_id = model.tokenizer.unk_token_id
        return model

    def _get_inputs(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """
        Tokenizes a batch of texts handling specific truncation logic to preserve exact text length mapping.

        Args:
            texts: Raw input strings, one per batch element.

        Returns:
            Tuple of ``(input_ids, attention_mask, text_lens)`` where
            ``text_lens[i]`` is the character length of the (possibly
            truncated) text actually seen by the model — used later to map
            normalized DETR spans back to character offsets.
        """
        assert self._max_len is not None and self.tokenizer is not None, "Model must be initialized"
        # 1. Tokenize all texts without special tokens/padding first
        raw_encodings = self.tokenizer(texts, add_special_tokens=False)
        batch_features = []  # List of dicts for tokenizer.pad
        text_lens = []
        # Reserve two slots for the BOS and EOS tokens appended below.
        content_max_len = self._max_len - 2
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id
        for i, tokens in enumerate(raw_encodings.input_ids):
            if len(tokens) > content_max_len:
                tokens = tokens[:content_max_len]
                # Convert back to string to get the exact character length of the truncated part
                cur_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
                text_len = len(cur_text)
            else:
                # If no truncation, use the original text length
                text_len = len(texts[i])
            # Construct final token sequence: [BOS] + tokens + [EOS]
            final_tokens = [bos_id] + tokens + [eos_id]
            # Append as dictionary for tokenizer.pad
            batch_features.append({"input_ids": final_tokens})
            text_lens.append(text_len)
        # 2. Pad using tokenizer.pad
        padded_output = self.tokenizer.pad(
            batch_features,
            padding=True,
            return_tensors="pt"
        )
        input_ids = padded_output["input_ids"].to(self.device)
        attention_mask = padded_output["attention_mask"].to(self.device)
        return input_ids, attention_mask, text_lens

    @staticmethod
    def _get_ai_intervals(detr_out: Dict[str, torch.Tensor], text_lens: List[int], conf_interval_thresh: float) -> List[torch.Tensor]:
        """
        Converts DETR outputs to absolute text intervals.

        Args:
            detr_out: Expects keys ``"pred_spans"`` (batch, #queries, 2) with
                normalized center-width spans and ``"pred_logits"``
                (batch, #queries, 2).
            text_lens: Character length of each batch element's text.
            conf_interval_thresh: Keep only intervals whose foreground
                probability strictly exceeds this value.

        Returns:
            One tensor per sample, each of shape [num_kept, 3] holding
            (start, end, score) rows.
        """
        pred_spans = detr_out["pred_spans"]  # (batch_size, #queries, 2)
        src_logits = detr_out["pred_logits"]  # (batch_size, #queries, #classes=2)
        assert len(text_lens) == pred_spans.shape[0]
        # Take probs for foreground objects only (ind = 0)
        pred_probs = torch.softmax(src_logits, dim=-1)[:, :, 0:1]  # [Batch, Queries, 1]
        final_preds_batch = []
        for i, length in enumerate(text_lens):
            # Convert center-width [0,1] to [0, length] absolute start-end
            # pred_spans[i]: [Queries, 2]
            spans_abs = to_absolute(pred_spans[i], length)
            # Concat spans and scores: [Queries, 3] -> (start, end, score)
            scores = pred_probs[i]
            preds_i = torch.cat([spans_abs, scores], dim=1)
            # Filter by confidence threshold
            mask = preds_i[:, 2] > conf_interval_thresh
            filtered_preds = preds_i[mask]
            final_preds_batch.append(filtered_preds)
        return final_preds_batch

    def forward(
        self,
        text: Union[str, List[str]],
        return_dict: Optional[bool] = None,
        conf_interval_thresh: Optional[float] = None,
    ) -> Union[Tuple, GigaCheckOutput]:
        """Run detection on raw text.

        Args:
            text: A single string or a batch of strings.
            return_dict: Return a :class:`GigaCheckOutput` instead of a plain
                tuple; defaults to ``config.use_return_dict``.
            conf_interval_thresh: Per-call override for the interval
                confidence threshold; defaults to the config value.

        Returns:
            ``GigaCheckOutput`` (or 3-tuple) of predicted label ids, class
            probabilities, and per-sample AI interval tensors; any of the
            three may be ``None`` depending on the model configuration.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        conf_interval_thresh = conf_interval_thresh if conf_interval_thresh is not None else self.config.conf_interval_thresh
        if isinstance(text, str):
            text = [text]
        input_ids, attention_mask, text_lens = self._get_inputs(text)
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            return_detr_output=self.config.with_detr,
        )
        pred_label_ids = None
        classification_head_probs = None
        ai_intervals = None
        # 1. Classification Head Processing
        if not self.config.with_detr:
            logits = output.logits
        elif self.trained_classification_head:
            # With DETR enabled, output.logits is a (cls_logits, detr_out) pair.
            logits, _ = output.logits
        else:
            # Classification head was never trained; skip its predictions.
            logits = None
        if logits is not None:
            # logits: [Batch, NumClasses]
            probs = logits.to(torch.float32).softmax(dim=-1)
            pred_label_ids = torch.argmax(probs, dim=-1)  # [Batch]
            classification_head_probs = probs  # [Batch, NumClasses]
        # 2. Interval Detection (DETR) Processing
        if self.config.with_detr:
            _, detr_out = output.logits
            ai_intervals = self._get_ai_intervals(detr_out, text_lens, conf_interval_thresh)
        if not return_dict:
            return (pred_label_ids, classification_head_probs, ai_intervals)
        return GigaCheckOutput(
            pred_label_ids=pred_label_ids,
            classification_head_probs=classification_head_probs,
            ai_intervals=ai_intervals,
        )
def to_absolute(pred_spans: torch.Tensor, text_len: int) -> torch.Tensor:
    """Map normalized center-width spans to absolute [start, end] offsets.

    Args:
        pred_spans: Tensor of normalized (center, width) pairs in [0, 1].
        text_len: Character length of the corresponding text.

    Returns:
        Tensor of the same shape with (start, end) values clamped to
        the range [0, text_len].
    """
    scaled = span_cxw_to_xx(pred_spans) * text_len
    return scaled.clamp(min=0, max=text_len)
|