Mrkomiljon committed on
Commit
d874350
·
verified ·
1 Parent(s): a7ca176

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -15
app.py CHANGED
@@ -9,10 +9,16 @@ import numpy as np
9
  import pandas as pd
10
  import warnings
11
  import nltk
 
 
12
  from sentence_transformers import SentenceTransformer
13
  from huggingface_hub import hf_hub_download
14
- from sklearn.calibration import CalibratedClassifierCV
15
- from sklearn.model_selection import train_test_split
 
 
 
 
16
 
17
  warnings.filterwarnings("ignore")
18
 
@@ -35,7 +41,7 @@ FORCED_DIM = 768
35
  def ensure_nltk():
36
  resources = {
37
  "punkt": "tokenizers/punkt",
38
- "punkt_tab": "tokenizers/punkt_tab/english",
39
  }
40
  for pkg, path in resources.items():
41
  try:
@@ -61,7 +67,7 @@ def preprocess_text(text: str, max_chars: int = 100000) -> str:
61
  t = t[:max_chars]
62
  return t
63
 
64
- def chunk_by_words(text: str, words_per_chunk: int = 350):
65
  words = text.split()
66
  if not words:
67
  return []
@@ -73,9 +79,16 @@ def chunk_by_words(text: str, words_per_chunk: int = 350):
73
  return chunks
74
 
75
  # -------------------------------------------------
76
- # Load classifier + embedder (forced 768-dim)
77
  # -------------------------------------------------
 
78
  def load_embedding_model():
 
 
 
 
 
 
79
  path = hf_hub_download(
80
  repo_id=REPO_ID,
81
  filename=FILENAME,
@@ -96,7 +109,7 @@ def load_embedding_model():
96
  if actual_dim != FORCED_DIM:
97
  raise RuntimeError(f"Loaded embedder dim={actual_dim}, expected {FORCED_DIM}")
98
 
99
- # Classifier sanity check
100
  clf_dim = getattr(clf, "n_features_in_", None)
101
  if clf_dim and clf_dim != FORCED_DIM:
102
  raise RuntimeError(
@@ -104,15 +117,131 @@ def load_embedding_model():
104
  )
105
 
106
  # -------------------------------------------------
107
- # Ensure calibration
108
  # -------------------------------------------------
109
  if not hasattr(clf, "predict_proba") or "CalibratedClassifierCV" not in str(type(clf)):
110
- print("⚠️ Classifier not calibrated — applying Platt scaling (logistic calibration)")
111
- # Create dummy calibration using embeddings stored in joblib (if available)
112
- X_train = data.get("X_val")
113
- y_train = data.get("y_val")
114
- if X_train is not None and y_train is not None:
115
- clf = CalibratedClassifierCV(base_estimator=clf, method="sigmoid", cv="prefit")
116
- clf.fit(X_train, y_train)
 
 
 
 
 
 
 
117
  else:
118
- print("⚠️ No calibration data found probabilities may still be uncalibrated")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import pandas as pd
10
  import warnings
11
  import nltk
12
+ from functools import lru_cache
13
+ from typing import List, Dict, Any
14
  from sentence_transformers import SentenceTransformer
15
  from huggingface_hub import hf_hub_download
16
+
17
+ # Optional: sklearn imports (some envs have slightly different names)
18
+ try:
19
+ from sklearn.calibration import CalibratedClassifierCV
20
+ except Exception:
21
+ CalibratedClassifierCV = None
22
 
23
  warnings.filterwarnings("ignore")
24
 
 
41
  def ensure_nltk():
42
  resources = {
43
  "punkt": "tokenizers/punkt",
44
+ "punkt_tab": "tokenizers/punkt_tab/english", # may not exist on older NLTK; safe to ignore
45
  }
46
  for pkg, path in resources.items():
47
  try:
 
67
  t = t[:max_chars]
68
  return t
69
 
70
+ def chunk_by_words(text: str, words_per_chunk: int = 350) -> List[str]:
71
  words = text.split()
72
  if not words:
73
  return []
 
79
  return chunks
80
 
81
  # -------------------------------------------------
82
+ # Model loader (cached)
83
  # -------------------------------------------------
84
+ @lru_cache(maxsize=1)
85
  def load_embedding_model():
86
+ """
87
+ Returns:
88
+ clf: a classifier with predict / predict_proba
89
+ embedding_model: SentenceTransformer (768-dim)
90
+ meta: dict with any metadata found in the joblib file
91
+ """
92
  path = hf_hub_download(
93
  repo_id=REPO_ID,
94
  filename=FILENAME,
 
109
  if actual_dim != FORCED_DIM:
110
  raise RuntimeError(f"Loaded embedder dim={actual_dim}, expected {FORCED_DIM}")
111
 
112
+ # Ensure classifier feature dimension matches
113
  clf_dim = getattr(clf, "n_features_in_", None)
114
  if clf_dim and clf_dim != FORCED_DIM:
115
  raise RuntimeError(
 
117
  )
118
 
119
  # -------------------------------------------------
120
+ # Ensure calibration (Platt scaling if needed)
121
  # -------------------------------------------------
122
  if not hasattr(clf, "predict_proba") or "CalibratedClassifierCV" not in str(type(clf)):
123
+ print("⚠️ Classifier not calibrated — applying Platt scaling (logistic calibration).")
124
+ X_val = data.get("X_val")
125
+ y_val = data.get("y_val")
126
+ if X_val is not None and y_val is not None and CalibratedClassifierCV is not None:
127
+ try:
128
+ # Newer sklearn uses 'estimator'; older uses 'base_estimator'
129
+ try:
130
+ clf = CalibratedClassifierCV(estimator=clf, method="sigmoid", cv="prefit")
131
+ except TypeError:
132
+ clf = CalibratedClassifierCV(base_estimator=clf, method="sigmoid", cv="prefit")
133
+ clf.fit(X_val, y_val)
134
+ print("✅ Calibration complete using provided validation split.")
135
+ except Exception as e:
136
+ print(f"⚠️ Calibration failed: {e}. Continuing with uncalibrated probabilities.")
137
  else:
138
+ print("⚠️ No calibration data found or CalibratedClassifierCV unavailable.")
139
+
140
+ meta = {k: v for k, v in data.items() if k not in {"model"}}
141
+ return clf, embedding_model, meta
142
+
143
# -------------------------------------------------
# Inference
# -------------------------------------------------
def embed_texts(embedding_model: SentenceTransformer, texts: List[str]) -> np.ndarray:
    """Encode a list of text chunks into a float32 embedding matrix.

    Runs the encoder under ``torch.no_grad()`` so no autograd state is
    accumulated during inference.
    """
    with torch.no_grad():
        matrix = embedding_model.encode(
            texts,
            batch_size=32,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=False,  # keep raw vectors — classifier was trained on them
        )
    return matrix.astype(np.float32)
156
+
157
def aggregate_probs(chunk_probs: List[float], mode: str = "mean") -> float:
    """Collapse per-chunk AI probabilities into one document-level score.

    mode: "max", "median", or "mean" (any unrecognized mode falls back
    to the mean). An empty list yields 0.0.
    """
    if not chunk_probs:
        return 0.0
    values = np.array(chunk_probs, dtype=np.float32)
    # Dispatch table keeps the reduction choice in one place.
    reducers = {"max": np.max, "median": np.median}
    reduce_fn = reducers.get(mode, np.mean)
    return float(reduce_fn(values))
166
+
167
def predict_single(text: str,
                   words_per_chunk: int = 350,
                   agg_mode: str = "mean") -> Dict[str, Any]:
    """Chunk *text*, score each chunk with the classifier, and aggregate.

    Args:
        text: raw input text (preprocessed and split into word chunks).
        words_per_chunk: chunk size passed to chunk_by_words.
        agg_mode: aggregation mode forwarded to aggregate_probs.

    Returns:
        Dict with the document-level AI probability, the settings used,
        and per-chunk details (index, probability, 120-char preview).
    """
    clf, embedding_model, _ = load_embedding_model()
    clean = preprocess_text(text)
    # Fall back to the whole (cleaned) text if chunking produced nothing.
    chunks = chunk_by_words(clean, words_per_chunk=words_per_chunk) or [clean]

    embs = embed_texts(embedding_model, chunks)
    # Binary classifier convention assumed here: column 1 = AI, column 0 = Human.
    if hasattr(clf, "predict_proba"):
        proba = clf.predict_proba(embs)[:, 1]
    else:
        # decision_function returns unbounded margins; always squash them
        # through a sigmoid. (The previous range check skipped the sigmoid
        # whenever all scores happened to fall inside [0, 1], silently
        # misreporting raw margins as probabilities.)
        proba = 1.0 / (1.0 + np.exp(-clf.decision_function(embs)))

    chunk_outputs = [
        {"chunk_index": i, "proba_ai": float(p), "text_preview": chunks[i][:120]}
        for i, p in enumerate(proba)
    ]
    doc_proba = aggregate_probs([co["proba_ai"] for co in chunk_outputs], mode=agg_mode)

    return {
        "doc_proba_ai": float(doc_proba),
        "agg_mode": agg_mode,
        "words_per_chunk": words_per_chunk,
        "num_chunks": len(chunks),
        "chunks": chunk_outputs,
    }
191
+
192
def classify_text(text: str,
                  decision_threshold: float = 0.5,
                  words_per_chunk: int = 350,
                  agg_mode: str = "mean") -> Dict[str, Any]:
    """Run prediction and attach a hard label.

    decision_threshold: label 'AI' if doc_proba_ai >= threshold.
    """
    outcome = predict_single(text, words_per_chunk=words_per_chunk, agg_mode=agg_mode)
    threshold = float(decision_threshold)
    is_ai = outcome["doc_proba_ai"] >= threshold
    outcome["decision_threshold"] = threshold
    outcome["label"] = "AI" if is_ai else "Human"
    return outcome
204
+
205
# -------------------------------------------------
# Gradio UI
# -------------------------------------------------
with gr.Blocks(title="Text AI Detector") as demo:
    gr.Markdown("## 🔎 Text AI Detector (MPNet 768-dim + Calibrated Classifier)")

    with gr.Row():
        text_input = gr.Textbox(label="Input text", lines=12, placeholder="Paste or type English text here...")
    with gr.Row():
        threshold_slider = gr.Slider(0.0, 1.0, value=0.50, step=0.01, label="Decision threshold (AI if ≥ threshold)")
    with gr.Row():
        chunk_slider = gr.Slider(100, 800, value=350, step=50, label="Words per chunk")
        agg_dropdown = gr.Dropdown(choices=["mean", "median", "max"], value="mean", label="Chunk aggregation")
    with gr.Row():
        classify_btn = gr.Button("Classify")
    with gr.Row():
        label_out = gr.Textbox(label="Label", interactive=False)
        proba_out = gr.Number(label="Document AI probability", interactive=False, precision=4)
    with gr.Row():
        details_out = gr.JSON(label="Details (per-chunk probabilities)")

    def _handle_classify(text, threshold, words_per_chunk, agg_mode):
        # Reject empty submissions before touching the model.
        if not text or not text.strip():
            return ("", 0.0, {"error": "Empty text"})
        try:
            res = classify_text(text, float(threshold), int(words_per_chunk), agg_mode)
            return (res["label"], res["doc_proba_ai"], res)
        except Exception as e:
            # Surface the failure in the JSON panel instead of crashing the UI.
            return ("", 0.0, {"error": str(e)})

    classify_btn.click(
        _handle_classify,
        inputs=[text_input, threshold_slider, chunk_slider, agg_dropdown],
        outputs=[label_out, proba_out, details_out],
    )
236
+
237
# Provide ASGI `app` for hosts that expect it (e.g., Uvicorn/Gunicorn)
try:
    from fastapi import FastAPI
    fastapi_app = FastAPI()
    # Mount the Gradio UI at the root path of the FastAPI app.
    app = gr.mount_gradio_app(fastapi_app, demo, path="/")
except Exception:
    # FastAPI unavailable or mounting failed — no ASGI app is exported;
    # the `__main__` launch path below still works.
    app = None
244
+
245
if __name__ == "__main__":
    # Local run. Gradio 4.x removed queue(concurrency_count=...) in favor of
    # default_concurrency_limit; try the old keyword first and fall back so
    # the app starts on either major version.
    try:
        queued = demo.queue(concurrency_count=2)
    except TypeError:
        queued = demo.queue(default_concurrency_limit=2)
    queued.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))