"""
text_detector_inference.py
==========================
Inference wrapper for HybridAITextDetector.

Strategy
--------
1. If  ``best_text_detector.pt``  exists  →  load the custom trained model.
2. Otherwise                               →  fall back to a pretrained
   HuggingFace AI-text detector so the Space keeps working immediately.

Usage
-----
from text_detector_inference import TextDetectorInference

detector = TextDetectorInference()          # auto-detects checkpoint
result   = detector.predict("Some text…")
"""

import os
import torch
from transformers import AutoTokenizer, pipeline as hf_pipeline
from text_detector_model import HybridAITextDetector, MODEL_NAME, MAX_LENGTH

# ─── Fallback model ───────────────────────────────────────────────────────────
# Used when best_text_detector.pt is not present in the Space.
# "Hello-SimpleAI/chatgpt-detector-roberta" is a publicly available,
# well-validated AI-text detector (RoBERTa fine-tuned on ChatGPT outputs).
FALLBACK_MODEL_ID = "Hello-SimpleAI/chatgpt-detector-roberta"


class TextDetectorInference:
    """
    Thin wrapper around HybridAITextDetector (or a fallback pretrained model)
    for single-text prediction.

    Parameters
    ----------
    checkpoint : str
        Path to the .pt state-dict file for the custom model.
    threshold  : float
        Decision boundary for the sigmoid probability (default 0.5).
        Set to the optimal F1 threshold found during your training run.
    device     : torch.device | None
        Auto-detects CUDA if None.
    """

    def __init__(
        self,
        checkpoint: str   = "best_text_detector.pt",
        threshold:  float = 0.5,
        device: torch.device | None = None,
    ):
        self.threshold   = threshold
        self.device      = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self._use_custom = False
        self._fallback   = None
        self.model       = None
        self.tokenizer   = None

        if os.path.exists(checkpoint):
            # ── Load custom trained HybridAITextDetector ──────────────────────
            print(f"[TextDetector] βœ… Checkpoint found: {checkpoint}")
            print(f"[TextDetector] Loading tokenizer from {MODEL_NAME} …")
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            self.model = HybridAITextDetector()
            self.model.load_state_dict(
                torch.load(checkpoint, map_location=self.device)
            )
            self.model.eval().to(self.device)
            self._use_custom = True
            print("[TextDetector] βœ… Custom model ready")
        else:
            # ── Fall back to pretrained HuggingFace model ─────────────────────
            print(
                f"[TextDetector] ⚠️  '{checkpoint}' not found.\n"
                f"[TextDetector] Loading pretrained fallback: {FALLBACK_MODEL_ID}"
            )
            try:
                self._fallback = hf_pipeline(
                    "text-classification",
                    model=FALLBACK_MODEL_ID,
                    device=0 if torch.cuda.is_available() else -1,
                    truncation=True,
                    max_length=512,
                )
                print(f"[TextDetector] βœ… Fallback model ready ({FALLBACK_MODEL_ID})")
            except Exception as e:
                print(f"[TextDetector] ❌ Fallback model failed to load: {e}")
                self._fallback = None

    # ──────────────────────────────────────────────────────────────────────────
    def predict(self, text: str) -> dict:
        """
        Classify a single text string.

        Returns
        -------
        dict with keys:
            label      : "AI-Generated" or "Human-Written"
            confidence : probability of the predicted class (0–1)
            ai_prob    : raw P(AI-generated)
            human_prob : 1 - ai_prob
            source     : "custom_model" | "pretrained_fallback"
        """
        text = text.strip()
        if not text:
            return {"error": "Input text is empty."}

        if self._use_custom:
            return self._predict_custom(text)
        elif self._fallback is not None:
            return self._predict_fallback(text)
        else:
            return {
                "error": (
                    "No model available. Upload 'best_text_detector.pt' to the "
                    "Space, or check your internet connection so the fallback "
                    "model can be downloaded."
                )
            }

    # ──────────────────────────────────────────────────────────────────────────
    def _predict_custom(self, text: str) -> dict:
        """Run inference with the custom HybridAITextDetector checkpoint."""
        enc = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        input_ids      = enc["input_ids"].to(self.device)
        attention_mask = enc["attention_mask"].to(self.device)
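        # Some tokenizers (e.g. RoBERTa-based ones) do not produce
        # token_type_ids; fall back to zeros so the forward call below
        # always receives a valid tensor.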
        token_type_ids = enc.get(
            "token_type_ids",
            torch.zeros_like(enc["input_ids"]),
        ).to(self.device)

        with torch.no_grad():
            logit   = self.model(input_ids, attention_mask, token_type_ids)
            ai_prob = torch.sigmoid(logit).item()

        human_prob = 1.0 - ai_prob
        is_ai      = ai_prob >= self.threshold
        label      = "AI-Generated" if is_ai else "Human-Written"
        confidence = ai_prob if is_ai else human_prob

        return {
            "label":      label,
            "confidence": round(confidence, 4),
            "ai_prob":    round(ai_prob, 4),
            "human_prob": round(human_prob, 4),
            "source":     "custom_model",
        }

    # ──────────────────────────────────────────────────────────────────────────
    def _predict_fallback(self, text: str) -> dict:
        """
        Run inference with the pretrained HuggingFace fallback model.

        Hello-SimpleAI/chatgpt-detector-roberta returns:
            {"label": "ChatGPT" | "Human", "score": float}
        We normalise this to the same dict shape as _predict_custom.
        """
        try:
            raw = self._fallback(text)[0]          # {"label": ..., "score": ...}
        except Exception as e:
            return {"error": f"Fallback inference failed: {e}"}

        hf_label = raw["label"].strip().lower()    # "chatgpt" or "human"
        score    = float(raw["score"])             # confidence of the returned label

        # Map the fallback's own label to P(AI); the final label is decided
        # by self.threshold below, so only the probability is needed here.
        if hf_label in ("chatgpt", "ai", "fake", "generated"):
            ai_prob = score
        else:
            ai_prob = 1.0 - score
        human_prob = 1.0 - ai_prob

        is_ai      = ai_prob >= self.threshold
        label      = "AI-Generated" if is_ai else "Human-Written"
        confidence = ai_prob if is_ai else human_prob

        return {
            "label":      label,
            "confidence": round(confidence, 4),
            "ai_prob":    round(ai_prob, 4),
            "human_prob": round(human_prob, 4),
            "source":     "pretrained_fallback",
        }

    # ──────────────────────────────────────────────────────────────────────────
    def predict_batch(self, texts: list) -> list:
        """Run predict() on a list of texts. Returns list of result dicts."""
        return [self.predict(t) for t in texts]

    # ──────────────────────────────────────────────────────────────────────────
    def format_for_gradio(self, text: str) -> tuple:
        """
        Convenience wrapper returning Gradio-friendly values:
            (label_string, confidence_float, full_result_dict)
        """
        result = self.predict(text)
        if "error" in result:
            return result["error"], 0.0, result
        return result["label"], result["confidence"], result
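

# ──────────────────────────────────────────────────────────────────────────────
# Minimal smoke test (illustrative only; not part of the public API).
# Assumes either a local 'best_text_detector.pt' checkpoint or network access
# so the HuggingFace fallback model can be downloaded.
if __name__ == "__main__":
    detector = TextDetectorInference()
    sample = "This is a short sample sentence for a quick sanity check."
    print(detector.predict(sample))
    label, confidence, full = detector.format_for_gradio(sample)
    print(label, confidence)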