"""
text_detector_inference.py
==========================
Inference wrapper for HybridAITextDetector.
Strategy
--------
1. If ``best_text_detector.pt`` exists → load the custom trained model.
2. Otherwise → fall back to a pretrained
   HuggingFace AI-text detector so the Space keeps working immediately.
Usage
-----
from text_detector_inference import TextDetectorInference
detector = TextDetectorInference() # auto-detects checkpoint
result = detector.predict("Some textβ¦")
"""
import os
import torch
from transformers import AutoTokenizer, pipeline as hf_pipeline
from text_detector_model import HybridAITextDetector, MODEL_NAME, MAX_LENGTH
# ── Fallback model ───────────────────────────────────────────────────────────
# Used when best_text_detector.pt is not present in the Space.
# "Hello-SimpleAI/chatgpt-detector-roberta" is a publicly available,
# well-validated AI-text detector (RoBERTa fine-tuned on ChatGPT outputs).
FALLBACK_MODEL_ID = "Hello-SimpleAI/chatgpt-detector-roberta"
class TextDetectorInference:
    """
    Thin wrapper around HybridAITextDetector (or a fallback pretrained model)
    for single-text prediction.

    Parameters
    ----------
    checkpoint : str
        Path to the .pt state-dict file for the custom model.
    threshold : float
        Decision boundary for the sigmoid probability (default 0.5).
        Set to the optimal F1 threshold found during your training run.
    device : torch.device | None
        Auto-detects CUDA if None.
    """

    def __init__(
        self,
        checkpoint: str = "best_text_detector.pt",
        threshold: float = 0.5,
        device: "torch.device | None" = None,
    ):
        self.threshold = threshold
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self._use_custom = False
        self._fallback = None
        self.model = None
        self.tokenizer = None

        if os.path.exists(checkpoint):
            # ── Load custom trained HybridAITextDetector ──────────────────────
            print(f"[TextDetector] ✅ Checkpoint found: {checkpoint}")
            print(f"[TextDetector] Loading tokenizer from {MODEL_NAME} …")
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            self.model = HybridAITextDetector()
            # NOTE(security): torch.load unpickles arbitrary objects; only load
            # checkpoints you trust. Pass weights_only=True on torch >= 1.13 if
            # the checkpoint is a plain state dict.
            self.model.load_state_dict(
                torch.load(checkpoint, map_location=self.device)
            )
            self.model.eval().to(self.device)
            self._use_custom = True
            print("[TextDetector] ✅ Custom model ready")
        else:
            # ── Fall back to pretrained HuggingFace model ─────────────────────
            print(
                f"[TextDetector] ⚠️ '{checkpoint}' not found.\n"
                f"[TextDetector] Loading pretrained fallback: {FALLBACK_MODEL_ID}"
            )
            try:
                self._fallback = hf_pipeline(
                    "text-classification",
                    model=FALLBACK_MODEL_ID,
                    device=0 if torch.cuda.is_available() else -1,
                    truncation=True,
                    max_length=512,
                )
                print(f"[TextDetector] ✅ Fallback model ready ({FALLBACK_MODEL_ID})")
            except Exception as e:
                # Best-effort: leave self._fallback as None so predict() can
                # return a helpful error instead of crashing at import time.
                print(f"[TextDetector] ❌ Fallback model failed to load: {e}")
                self._fallback = None

    # ──────────────────────────────────────────────────────────────────────────
    def predict(self, text: str) -> dict:
        """
        Classify a single text string.

        Returns
        -------
        dict with keys:
            label       : "AI-Generated" or "Human-Written"
            confidence  : probability of the predicted class (0–1)
            ai_prob     : raw P(AI-generated)
            human_prob  : 1 - ai_prob
            source      : "custom_model" | "pretrained_fallback"
        On failure, a dict with a single "error" key instead.
        """
        text = text.strip()
        if not text:
            return {"error": "Input text is empty."}
        if self._use_custom:
            return self._predict_custom(text)
        if self._fallback is not None:
            return self._predict_fallback(text)
        return {
            "error": (
                "No model available. Upload 'best_text_detector.pt' to the "
                "Space, or check your internet connection so the fallback "
                "model can be downloaded."
            )
        }

    # ──────────────────────────────────────────────────────────────────────────
    def _predict_custom(self, text: str) -> dict:
        """Run inference with the custom HybridAITextDetector checkpoint."""
        enc = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].to(self.device)
        attention_mask = enc["attention_mask"].to(self.device)
        # Some tokenizers (e.g. RoBERTa) don't emit token_type_ids; the model
        # expects a tensor either way, so substitute zeros.
        token_type_ids = enc.get(
            "token_type_ids",
            torch.zeros_like(enc["input_ids"]),
        ).to(self.device)

        with torch.no_grad():
            logit = self.model(input_ids, attention_mask, token_type_ids)
            ai_prob = torch.sigmoid(logit).item()

        human_prob = 1.0 - ai_prob
        is_ai = ai_prob >= self.threshold
        label = "AI-Generated" if is_ai else "Human-Written"
        confidence = ai_prob if is_ai else human_prob
        return {
            "label": label,
            "confidence": round(confidence, 4),
            "ai_prob": round(ai_prob, 4),
            "human_prob": round(human_prob, 4),
            "source": "custom_model",
        }

    # ──────────────────────────────────────────────────────────────────────────
    def _predict_fallback(self, text: str) -> dict:
        """
        Run inference with the pretrained HuggingFace fallback model.

        Hello-SimpleAI/chatgpt-detector-roberta returns:
            {"label": "ChatGPT" | "Human", "score": float}
        We normalise this to the same dict shape as _predict_custom.
        """
        try:
            raw = self._fallback(text)[0]  # {"label": ..., "score": ...}
        except Exception as e:
            return {"error": f"Fallback inference failed: {e}"}

        hf_label = raw["label"].strip().lower()  # "chatgpt" or "human"
        score = float(raw["score"])  # confidence of the returned label
        # Map the pipeline's label/score pair onto P(AI); the final label is
        # decided by self.threshold below, never by the raw pipeline label.
        if hf_label in ("chatgpt", "ai", "fake", "generated"):
            ai_prob = score
        else:
            ai_prob = 1.0 - score
        human_prob = 1.0 - ai_prob

        is_ai = ai_prob >= self.threshold
        label = "AI-Generated" if is_ai else "Human-Written"
        confidence = ai_prob if is_ai else human_prob
        return {
            "label": label,
            "confidence": round(confidence, 4),
            "ai_prob": round(ai_prob, 4),
            "human_prob": round(human_prob, 4),
            "source": "pretrained_fallback",
        }

    # ──────────────────────────────────────────────────────────────────────────
    def predict_batch(self, texts: list) -> list:
        """Run predict() on a list of texts. Returns list of result dicts."""
        return [self.predict(t) for t in texts]

    # ──────────────────────────────────────────────────────────────────────────
    def format_for_gradio(self, text: str) -> tuple:
        """
        Convenience wrapper returning Gradio-friendly values:
            (label_string, confidence_float, full_result_dict)
        On error, the label slot carries the error message and confidence is 0.
        """
        result = self.predict(text)
        if "error" in result:
            return result["error"], 0.0, result
        return result["label"], result["confidence"], result