from transformers import AutoTokenizer
from transformers.modeling_outputs import ModelOutput
from typing import List, Dict, Optional, Union, Tuple
from dataclasses import dataclass
import torch
from gigacheck.model.mistral_ai_detector import MistralAIDetectorForSequenceClassification
from gigacheck.model.src.interval_detector.span_utils import span_cxw_to_xx
from .configuration_gigacheck import GigaCheckConfig
@dataclass
class GigaCheckOutput(ModelOutput):
    """
    Output type for GigaCheck model.

    Args:
        pred_label_ids (torch.Tensor): [Batch] Indices of the predicted classes (Human/AI/Mixed).
        classification_head_probs (torch.Tensor): [Batch, Num_Classes] Softmax probabilities.
    """

    # [Batch] argmax class index per input text (Human/AI/Mixed).
    pred_label_ids: Optional[torch.Tensor] = None
    # [Batch, Num_Classes] softmax probabilities from the classification head.
    classification_head_probs: Optional[torch.Tensor] = None
class GigaCheckForSequenceClassification(MistralAIDetectorForSequenceClassification):
    """Sequence classifier labelling raw texts as Human/AI/Mixed.

    Wraps the Mistral-based detector backbone and adds tokenizer handling so
    the model can be called directly on strings instead of token ids.
    """

    config_class = GigaCheckConfig

    def __init__(self, config: GigaCheckConfig):
        """Initialize the backbone without the DETR interval head.

        Args:
            config (GigaCheckConfig): Model configuration; must provide
                ``max_length`` and ``id2label``.
        """
        super().__init__(
            config,
            with_detr=False,
            detr_config=None,
            ce_weights=None,
            freeze_backbone=False,
            id2label=config.id2label,
        )
        # Assume a trained classification head unless the loaded checkpoint's
        # config says otherwise (overridden in from_pretrained).
        self.trained_classification_head = True
        self._max_len = self.config.max_length
        # Populated by from_pretrained; forward() cannot run before that.
        self.tokenizer = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):  # type: ignore
        """Loads a pretrained GigaCheck model from a local path or the Hugging Face Hub.

        Args:
            pretrained_model_name_or_path (str): The name or path of the pretrained model.
            model_args: Additional positional arguments passed to parent class.
            kwargs: Additional keyword arguments passed to parent class.

        Returns:
            GigaCheckForSequenceClassification: The initialized model with loaded
            weights and initialized tokenizer.
        """
        # Load model weights via the parent implementation.
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
        )
        if model.config.to_dict().get("trained_classification_head", True) is False:
            # Checkpoint where only the DETR head was trained.
            model.trained_classification_head = False

        model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

        # Ensure a pad token exists, falling back to UNK. Fix: single source of
        # truth — the original computed the fallback twice (once into the
        # config via a conditional expression, once into the tokenizer).
        if model.tokenizer.pad_token_id is None:
            model.tokenizer.pad_token_id = model.tokenizer.unk_token_id
        model.config.pad_token_id = model.tokenizer.pad_token_id

        # Mirror the remaining special-token ids onto the config.
        model.config.bos_token_id = model.tokenizer.bos_token_id
        model.config.eos_token_id = model.tokenizer.eos_token_id
        model.config.unk_token_id = model.tokenizer.unk_token_id
        return model

    def _get_inputs(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """
        Tokenizes a batch of texts handling specific truncation logic to preserve
        exact text length mapping.

        Args:
            texts (List[str]): Batch of raw input texts.

        Returns:
            Tuple of:
                input_ids (torch.Tensor): [Batch, SeqLen] padded token ids on the model device.
                attention_mask (torch.Tensor): [Batch, SeqLen] padding mask.
                text_lens (List[int]): Character length of the text actually fed to the
                    model (re-decoded when truncation occurred, so spans map back exactly).

        Raises:
            RuntimeError: If the model was not initialized via ``from_pretrained``.
        """
        # Fix: this was an `assert`, which is stripped under `python -O`.
        if self._max_len is None or self.tokenizer is None:
            raise RuntimeError("Model must be initialized via from_pretrained before use")

        # 1. Tokenize all texts without special tokens/padding first.
        raw_encodings = self.tokenizer(texts, add_special_tokens=False)

        batch_features = []  # list of dicts for tokenizer.pad
        text_lens = []
        # Reserve two slots for the BOS/EOS tokens added below.
        content_max_len = self._max_len - 2
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id

        for i, tokens in enumerate(raw_encodings.input_ids):
            if len(tokens) > content_max_len:
                tokens = tokens[:content_max_len]
                # Decode the kept tokens so text_len matches the truncated text
                # exactly in characters.
                cur_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
                text_len = len(cur_text)
            else:
                # No truncation: use the original text length.
                text_len = len(texts[i])

            # Construct final token sequence: [BOS] + tokens + [EOS].
            final_tokens = [bos_id] + tokens + [eos_id]
            batch_features.append({"input_ids": final_tokens})
            text_lens.append(text_len)

        # 2. Pad to a rectangular batch using tokenizer.pad.
        padded_output = self.tokenizer.pad(
            batch_features,
            padding=True,
            return_tensors="pt",
        )
        input_ids = padded_output["input_ids"].to(self.device)
        attention_mask = padded_output["attention_mask"].to(self.device)
        return input_ids, attention_mask, text_lens

    def forward(
        self,
        text: Union[str, List[str]],
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, GigaCheckOutput]:
        """Classify one text or a batch of texts.

        Args:
            text (Union[str, List[str]]): Input text(s); a single string is
                wrapped into a one-element batch.
            return_dict (Optional[bool]): Whether to return a ``GigaCheckOutput``
                instead of a plain tuple. Defaults to ``config.use_return_dict``.

        Returns:
            GigaCheckOutput or Tuple of (pred_label_ids, classification_head_probs).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if isinstance(text, str):
            text = [text]

        # Fix: text_lens was an unused local; keep the unpack explicit but unnamed.
        input_ids, attention_mask, _ = self._get_inputs(text)

        # The DETR interval output is requested when config enables it, but only
        # the classification-head results are surfaced by this wrapper.
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            return_detr_output=self.config.with_detr,
        )

        # 1. Classification head processing.
        logits = output.logits  # [Batch, NumClasses]
        # Cast to float32 before softmax for numerical stability under fp16/bf16.
        probs = logits.to(torch.float32).softmax(dim=-1)
        pred_label_ids = torch.argmax(probs, dim=-1)  # [Batch]
        classification_head_probs = probs  # [Batch, NumClasses]

        if not return_dict:
            return (pred_label_ids, classification_head_probs)
        return GigaCheckOutput(
            pred_label_ids=pred_label_ids,
            classification_head_probs=classification_head_probs,
        )
def to_absolute(pred_spans: torch.Tensor, text_len: int) -> torch.Tensor:
    """Convert normalized (center, width) spans to absolute character offsets.

    Args:
        pred_spans (torch.Tensor): Spans in normalized (cx, w) form.
        text_len (int): Character length of the source text.

    Returns:
        torch.Tensor: [start, end] offsets scaled to the text and clamped to [0, text_len].
    """
    absolute_spans = span_cxw_to_xx(pred_spans) * text_len
    return absolute_spans.clamp(min=0, max=text_len)