""" text_detector_inference.py ========================== Inference wrapper for HybridAITextDetector. Strategy -------- 1. If ``best_text_detector.pt`` exists → load the custom trained model. 2. Otherwise → fall back to a pretrained HuggingFace AI-text detector so the Space keeps working immediately. Usage ----- from text_detector_inference import TextDetectorInference detector = TextDetectorInference() # auto-detects checkpoint result = detector.predict("Some text…") """ import os import torch from transformers import AutoTokenizer, pipeline as hf_pipeline from text_detector_model import HybridAITextDetector, MODEL_NAME, MAX_LENGTH # ─── Fallback model ─────────────────────────────────────────────────────────── # Used when best_text_detector.pt is not present in the Space. # "Hello-SimpleAI/chatgpt-detector-roberta" is a publicly available, # well-validated AI-text detector (RoBERTa fine-tuned on ChatGPT outputs). FALLBACK_MODEL_ID = "Hello-SimpleAI/chatgpt-detector-roberta" class TextDetectorInference: """ Thin wrapper around HybridAITextDetector (or a fallback pretrained model) for single-text prediction. Parameters ---------- checkpoint : str Path to the .pt state-dict file for the custom model. threshold : float Decision boundary for the sigmoid probability (default 0.5). Set to the optimal F1 threshold found during your training run. device : torch.device | None Auto-detects CUDA if None. """ def __init__( self, checkpoint: str = "best_text_detector.pt", threshold: float = 0.5, device: torch.device = None, ): self.threshold = threshold self.device = device or torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) self._use_custom = False self._fallback = None self.model = None self.tokenizer = None if os.path.exists(checkpoint): # ── Load custom trained HybridAITextDetector ────────────────────── print(f"[TextDetector] ✅ Checkpoint found: {checkpoint}") print(f"[TextDetector] Loading tokenizer from {MODEL_NAME} …") self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) self.model = HybridAITextDetector() self.model.load_state_dict( torch.load(checkpoint, map_location=self.device) ) self.model.eval().to(self.device) self._use_custom = True print("[TextDetector] ✅ Custom model ready") else: # ── Fall back to pretrained HuggingFace model ───────────────────── print( f"[TextDetector] ⚠️ '{checkpoint}' not found.\n" f"[TextDetector] Loading pretrained fallback: {FALLBACK_MODEL_ID}" ) try: self._fallback = hf_pipeline( "text-classification", model=FALLBACK_MODEL_ID, device=0 if torch.cuda.is_available() else -1, truncation=True, max_length=512, ) print(f"[TextDetector] ✅ Fallback model ready ({FALLBACK_MODEL_ID})") except Exception as e: print(f"[TextDetector] ❌ Fallback model failed to load: {e}") self._fallback = None # ────────────────────────────────────────────────────────────────────────── def predict(self, text: str) -> dict: """ Classify a single text string. Returns ------- dict with keys: label : "AI-Generated" or "Human-Written" confidence : probability of the predicted class (0–1) ai_prob : raw P(AI-generated) human_prob : 1 - ai_prob source : "custom_model" | "pretrained_fallback" """ text = text.strip() if not text: return {"error": "Input text is empty."} if self._use_custom: return self._predict_custom(text) elif self._fallback is not None: return self._predict_fallback(text) else: return { "error": ( "No model available. Upload 'best_text_detector.pt' to the " "Space, or check your internet connection so the fallback " "model can be downloaded." ) } # ────────────────────────────────────────────────────────────────────────── def _predict_custom(self, text: str) -> dict: """Run inference with the custom HybridAITextDetector checkpoint.""" enc = self.tokenizer( text, truncation=True, padding="max_length", max_length=MAX_LENGTH, return_tensors="pt", ) input_ids = enc["input_ids"].to(self.device) attention_mask = enc["attention_mask"].to(self.device) token_type_ids = enc.get( "token_type_ids", torch.zeros_like(enc["input_ids"]), ).to(self.device) with torch.no_grad(): logit = self.model(input_ids, attention_mask, token_type_ids) ai_prob = torch.sigmoid(logit).item() human_prob = 1.0 - ai_prob is_ai = ai_prob >= self.threshold label = "AI-Generated" if is_ai else "Human-Written" confidence = ai_prob if is_ai else human_prob return { "label": label, "confidence": round(confidence, 4), "ai_prob": round(ai_prob, 4), "human_prob": round(human_prob, 4), "source": "custom_model", } # ────────────────────────────────────────────────────────────────────────── def _predict_fallback(self, text: str) -> dict: """ Run inference with the pretrained HuggingFace fallback model. Hello-SimpleAI/chatgpt-detector-roberta returns: {"label": "ChatGPT" | "Human", "score": float} We normalise this to the same dict shape as _predict_custom. """ try: raw = self._fallback(text)[0] # {"label": ..., "score": ...} except Exception as e: return {"error": f"Fallback inference failed: {e}"} hf_label = raw["label"].strip().lower() # "chatgpt" or "human" score = float(raw["score"]) # confidence of the returned label if hf_label in ("chatgpt", "ai", "fake", "generated"): ai_prob = score human_prob = 1.0 - score label = "AI-Generated" else: human_prob = score ai_prob = 1.0 - score label = "Human-Written" is_ai = ai_prob >= self.threshold label = "AI-Generated" if is_ai else "Human-Written" confidence = ai_prob if is_ai else human_prob return { "label": label, "confidence": round(confidence, 4), "ai_prob": round(ai_prob, 4), "human_prob": round(human_prob, 4), "source": "pretrained_fallback", } # ────────────────────────────────────────────────────────────────────────── def predict_batch(self, texts: list) -> list: """Run predict() on a list of texts. Returns list of result dicts.""" return [self.predict(t) for t in texts] # ────────────────────────────────────────────────────────────────────────── def format_for_gradio(self, text: str) -> tuple: """ Convenience wrapper returning Gradio-friendly values: (label_string, confidence_float, full_result_dict) """ result = self.predict(text) if "error" in result: return result["error"], 0.0, result return result["label"], result["confidence"], result