| """ |
| text_detector_inference.py |
| ========================== |
| Inference wrapper for HybridAITextDetector. |
| |
| Strategy |
| -------- |
| 1. If ``best_text_detector.pt`` exists β load the custom trained model. |
| 2. Otherwise β fall back to a pretrained |
| HuggingFace AI-text detector so the Space keeps working immediately. |
| |
| Usage |
| ----- |
| from text_detector_inference import TextDetectorInference |
| |
| detector = TextDetectorInference() # auto-detects checkpoint |
| result = detector.predict("Some textβ¦") |
| """ |
|
|
| import os |
| import torch |
| from transformers import AutoTokenizer, pipeline as hf_pipeline |
| from text_detector_model import HybridAITextDetector, MODEL_NAME, MAX_LENGTH |
|
|
| |
| |
| |
| |
| FALLBACK_MODEL_ID = "Hello-SimpleAI/chatgpt-detector-roberta" |
|
|
|
|
class TextDetectorInference:
    """
    Thin wrapper around HybridAITextDetector (or a fallback pretrained model)
    for single-text prediction.

    Parameters
    ----------
    checkpoint : str
        Path to the ``.pt`` state-dict file for the custom model.
    threshold : float
        Decision boundary for the sigmoid probability (default 0.5).
        Set to the optimal F1 threshold found during your training run.
    device : torch.device | None
        Auto-detects CUDA if None.
| """ |
|
|
| def __init__( |
| self, |
| checkpoint: str = "best_text_detector.pt", |
| threshold: float = 0.5, |
| device: torch.device = None, |
| ): |
| self.threshold = threshold |
| self.device = device or torch.device( |
| "cuda" if torch.cuda.is_available() else "cpu" |
| ) |
| self._use_custom = False |
| self._fallback = None |
| self.model = None |
| self.tokenizer = None |
|
|
| if os.path.exists(checkpoint): |
| |
| print(f"[TextDetector] β
Checkpoint found: {checkpoint}") |
| print(f"[TextDetector] Loading tokenizer from {MODEL_NAME} β¦") |
| self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
| self.model = HybridAITextDetector() |
| self.model.load_state_dict( |
| torch.load(checkpoint, map_location=self.device) |
| ) |
| self.model.eval().to(self.device) |
| self._use_custom = True |
| print("[TextDetector] β
Custom model ready") |
| else: |
| |
| print( |
| f"[TextDetector] β οΈ '{checkpoint}' not found.\n" |
| f"[TextDetector] Loading pretrained fallback: {FALLBACK_MODEL_ID}" |
| ) |
| try: |
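                # transformers pipeline convention: device=0 selects the
                # first GPU, device=-1 runs on CPU.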
                self._fallback = hf_pipeline(
                    "text-classification",
                    model=FALLBACK_MODEL_ID,
                    device=0 if torch.cuda.is_available() else -1,
                    truncation=True,
                    max_length=512,
                )
                print(f"[TextDetector] ✅ Fallback model ready ({FALLBACK_MODEL_ID})")
            except Exception as e:
                print(f"[TextDetector] ❌ Fallback model failed to load: {e}")
                self._fallback = None

    def predict(self, text: str) -> dict:
        """
        Classify a single text string.

        Returns
        -------
        dict with keys:
            label : "AI-Generated" or "Human-Written"
            confidence : probability of the predicted class (0–1)
            ai_prob : raw P(AI-generated)
            human_prob : 1 - ai_prob
            source : "custom_model" | "pretrained_fallback"
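
        Examples
        --------
        Illustrative output shape only; the probability values are made up.

        >>> detector.predict("The sky is blue.")   # doctest: +SKIP
        {'label': 'Human-Written', 'confidence': 0.93, 'ai_prob': 0.07,
         'human_prob': 0.93, 'source': 'custom_model'}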
| """ |
| text = text.strip() |
| if not text: |
| return {"error": "Input text is empty."} |
|
|
| if self._use_custom: |
| return self._predict_custom(text) |
| elif self._fallback is not None: |
| return self._predict_fallback(text) |
| else: |
| return { |
| "error": ( |
| "No model available. Upload 'best_text_detector.pt' to the " |
| "Space, or check your internet connection so the fallback " |
| "model can be downloaded." |
| ) |
| } |
|
|
| |
| def _predict_custom(self, text: str) -> dict: |
| """Run inference with the custom HybridAITextDetector checkpoint.""" |
| enc = self.tokenizer( |
| text, |
| truncation=True, |
| padding="max_length", |
| max_length=MAX_LENGTH, |
| return_tensors="pt", |
| ) |
| input_ids = enc["input_ids"].to(self.device) |
| attention_mask = enc["attention_mask"].to(self.device) |
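        # Some tokenizers (e.g. RoBERTa- or DistilBERT-style) do not return
        # token_type_ids, so fall back to a zero tensor of the same shape.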
        token_type_ids = enc.get(
            "token_type_ids",
            torch.zeros_like(enc["input_ids"]),
        ).to(self.device)

        with torch.no_grad():
            logit = self.model(input_ids, attention_mask, token_type_ids)
            ai_prob = torch.sigmoid(logit).item()

        human_prob = 1.0 - ai_prob
        is_ai = ai_prob >= self.threshold
        label = "AI-Generated" if is_ai else "Human-Written"
        confidence = ai_prob if is_ai else human_prob

        return {
            "label": label,
            "confidence": round(confidence, 4),
            "ai_prob": round(ai_prob, 4),
            "human_prob": round(human_prob, 4),
            "source": "custom_model",
        }

    def _predict_fallback(self, text: str) -> dict:
        """
        Run inference with the pretrained HuggingFace fallback model.

        Hello-SimpleAI/chatgpt-detector-roberta returns:
            {"label": "ChatGPT" | "Human", "score": float}
        We normalise this to the same dict shape as _predict_custom.
        """
        try:
            raw = self._fallback(text)[0]
        except Exception as e:
            return {"error": f"Fallback inference failed: {e}"}

        hf_label = raw["label"].strip().lower()
        score = float(raw["score"])

        if hf_label in ("chatgpt", "ai", "fake", "generated"):
            ai_prob = score
            human_prob = 1.0 - score
        else:
            human_prob = score
            ai_prob = 1.0 - score

        is_ai = ai_prob >= self.threshold
        label = "AI-Generated" if is_ai else "Human-Written"
        confidence = ai_prob if is_ai else human_prob

        return {
            "label": label,
            "confidence": round(confidence, 4),
            "ai_prob": round(ai_prob, 4),
            "human_prob": round(human_prob, 4),
            "source": "pretrained_fallback",
        }

    def predict_batch(self, texts: list) -> list:
        """
        Run predict() on a list of texts and return a list of result dicts.

        Note: texts are scored one at a time; this is not a batched
        forward pass.
        """
        return [self.predict(t) for t in texts]

    def format_for_gradio(self, text: str) -> tuple:
        """
        Convenience wrapper returning Gradio-friendly values:
        (label_string, confidence_float, full_result_dict)
        """
        result = self.predict(text)
        if "error" in result:
            return result["error"], 0.0, result
        return result["label"], result["confidence"], result
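

# Minimal smoke test (illustrative): run this file directly to sanity-check
# whichever model loads. The sample strings below are assumptions for
# demonstration, not part of any test suite.
if __name__ == "__main__":
    detector = TextDetectorInference()
    samples = [
        "The committee met on Tuesday to review the budget.",
        "As an AI language model, I am happy to help with that request.",
    ]
    for text, result in zip(samples, detector.predict_batch(samples)):
        print(f"{text[:40]!r} -> {result}")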