# Multi_Modal_Deepfake_Detection / text_detector_inference.py
# Hugging Face Space file — uploaded by pavankumarvk
# ("Update text_detector_inference.py", commit 28c4d49, verified)
"""
text_detector_inference.py
==========================
Inference wrapper for HybridAITextDetector.
Strategy
--------
1. If ``best_text_detector.pt`` exists β†’ load the custom trained model.
2. Otherwise β†’ fall back to a pretrained
HuggingFace AI-text detector so the Space keeps working immediately.
Usage
-----
from text_detector_inference import TextDetectorInference
detector = TextDetectorInference() # auto-detects checkpoint
result = detector.predict("Some text…")
"""
import os
import torch
from transformers import AutoTokenizer, pipeline as hf_pipeline
from text_detector_model import HybridAITextDetector, MODEL_NAME, MAX_LENGTH
# ─── Fallback model ───────────────────────────────────────────────────────────
# Used when best_text_detector.pt is not present in the Space.
# "Hello-SimpleAI/chatgpt-detector-roberta" is a publicly available,
# well-validated AI-text detector (RoBERTa fine-tuned on ChatGPT outputs).
FALLBACK_MODEL_ID = "Hello-SimpleAI/chatgpt-detector-roberta"
class TextDetectorInference:
    """
    Thin wrapper around HybridAITextDetector (or a fallback pretrained model)
    for single-text prediction.

    Strategy: if the custom checkpoint exists on disk it is loaded; otherwise
    a pretrained HuggingFace text-classification pipeline is used so the
    Space keeps working without the trained weights.

    Parameters
    ----------
    checkpoint : str
        Path to the .pt state-dict file for the custom model.
    threshold : float
        Decision boundary for the sigmoid probability (default 0.5).
        Set to the optimal F1 threshold found during your training run.
    device : torch.device | None
        Auto-detects CUDA if None.
    """

    def __init__(
        self,
        checkpoint: str = "best_text_detector.pt",
        threshold: float = 0.5,
        device: "torch.device | None" = None,
    ):
        self.threshold = threshold
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        # Populated by exactly one of the two loaders below.
        self._use_custom = False
        self._fallback = None
        self.model = None
        self.tokenizer = None
        if os.path.exists(checkpoint):
            self._load_custom(checkpoint)
        else:
            self._load_fallback(checkpoint)

    # ──────────────────────────────────────────────────────────────────────────
    def _load_custom(self, checkpoint: str) -> None:
        """Load the custom trained HybridAITextDetector from a state-dict file."""
        print(f"[TextDetector] ✅ Checkpoint found: {checkpoint}")
        print(f"[TextDetector] Loading tokenizer from {MODEL_NAME} …")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = HybridAITextDetector()
        # weights_only=True: the checkpoint is a plain state dict, so refuse
        # arbitrary pickled objects (avoids code execution from a tampered file).
        self.model.load_state_dict(
            torch.load(checkpoint, map_location=self.device, weights_only=True)
        )
        self.model.eval().to(self.device)
        self._use_custom = True
        print("[TextDetector] ✅ Custom model ready")

    # ──────────────────────────────────────────────────────────────────────────
    def _load_fallback(self, checkpoint: str) -> None:
        """Load the pretrained HuggingFace fallback pipeline (best effort)."""
        print(
            f"[TextDetector] ⚠️ '{checkpoint}' not found.\n"
            f"[TextDetector] Loading pretrained fallback: {FALLBACK_MODEL_ID}"
        )
        try:
            self._fallback = hf_pipeline(
                "text-classification",
                model=FALLBACK_MODEL_ID,
                device=0 if torch.cuda.is_available() else -1,
                truncation=True,
                max_length=512,
            )
            print(f"[TextDetector] ✅ Fallback model ready ({FALLBACK_MODEL_ID})")
        except Exception as e:
            # Best-effort: the Space should still start and report a clear
            # error from predict() rather than crash on import.
            print(f"[TextDetector] ❌ Fallback model failed to load: {e}")
            self._fallback = None

    # ──────────────────────────────────────────────────────────────────────────
    def predict(self, text: str) -> dict:
        """
        Classify a single text string.

        Returns
        -------
        dict with keys:
            label      : "AI-Generated" or "Human-Written"
            confidence : probability of the predicted class (0–1)
            ai_prob    : raw P(AI-generated)
            human_prob : 1 - ai_prob
            source     : "custom_model" | "pretrained_fallback"
        or {"error": ...} when the input is empty or no model is available.
        """
        # Tolerate None as well as whitespace-only input.
        text = (text or "").strip()
        if not text:
            return {"error": "Input text is empty."}
        if self._use_custom:
            return self._predict_custom(text)
        elif self._fallback is not None:
            return self._predict_fallback(text)
        else:
            return {
                "error": (
                    "No model available. Upload 'best_text_detector.pt' to the "
                    "Space, or check your internet connection so the fallback "
                    "model can be downloaded."
                )
            }

    # ──────────────────────────────────────────────────────────────────────────
    def _package(self, ai_prob: float, source: str) -> dict:
        """Build the standard result dict from P(AI), applying self.threshold."""
        human_prob = 1.0 - ai_prob
        is_ai = ai_prob >= self.threshold
        return {
            "label": "AI-Generated" if is_ai else "Human-Written",
            "confidence": round(ai_prob if is_ai else human_prob, 4),
            "ai_prob": round(ai_prob, 4),
            "human_prob": round(human_prob, 4),
            "source": source,
        }

    # ──────────────────────────────────────────────────────────────────────────
    def _predict_custom(self, text: str) -> dict:
        """Run inference with the custom HybridAITextDetector checkpoint."""
        enc = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].to(self.device)
        attention_mask = enc["attention_mask"].to(self.device)
        # Some tokenizers (e.g. RoBERTa) emit no token_type_ids; use zeros then.
        token_type_ids = enc.get(
            "token_type_ids",
            torch.zeros_like(enc["input_ids"]),
        ).to(self.device)
        with torch.no_grad():
            logit = self.model(input_ids, attention_mask, token_type_ids)
        ai_prob = torch.sigmoid(logit).item()
        return self._package(ai_prob, "custom_model")

    # ──────────────────────────────────────────────────────────────────────────
    def _predict_fallback(self, text: str) -> dict:
        """
        Run inference with the pretrained HuggingFace fallback model.

        Hello-SimpleAI/chatgpt-detector-roberta returns:
            {"label": "ChatGPT" | "Human", "score": float}
        We normalise this to the same dict shape as _predict_custom.
        """
        try:
            raw = self._fallback(text)[0]  # {"label": ..., "score": ...}
        except Exception as e:
            return {"error": f"Fallback inference failed: {e}"}
        hf_label = raw["label"].strip().lower()  # "chatgpt" or "human"
        score = float(raw["score"])  # confidence of the returned label
        # Map the pipeline's per-label score to a single P(AI-generated);
        # the final label is then decided by self.threshold in _package.
        if hf_label in ("chatgpt", "ai", "fake", "generated"):
            ai_prob = score
        else:
            ai_prob = 1.0 - score
        return self._package(ai_prob, "pretrained_fallback")

    # ──────────────────────────────────────────────────────────────────────────
    def predict_batch(self, texts: list) -> list:
        """Run predict() on a list of texts. Returns list of result dicts."""
        return [self.predict(t) for t in texts]

    # ──────────────────────────────────────────────────────────────────────────
    def format_for_gradio(self, text: str) -> tuple:
        """
        Convenience wrapper returning Gradio-friendly values:
            (label_string, confidence_float, full_result_dict)
        On error the label slot carries the error message and confidence is 0.0.
        """
        result = self.predict(text)
        if "error" in result:
            return result["error"], 0.0, result
        return result["label"], result["confidence"], result