VictorM-Coder committed on
Commit
cec130d
·
verified ·
1 Parent(s): 5c9e54f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +318 -0
app.py CHANGED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import shutil
4
+
5
# ============================================================
# ENV (set BEFORE transformers/hub usage)
# ============================================================
# Point every HF cache at /tmp so the app works even when the home
# directory is read-only (e.g. Hugging Face Spaces containers).
# setdefault() keeps any values the deployment already exported.
os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/hf/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")  # disable hf-xet if present
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import pandas as pd
18
+ import gradio as gr
19
+
20
+ from huggingface_hub import hf_hub_download
21
+ from transformers import AutoConfig, AutoTokenizer, AutoModel
22
+ from safetensors.torch import load_file
23
+
24
+
25
# -----------------------------
# MODEL INITIALIZATION
# -----------------------------
MODEL_NAME = "desklib/ai-text-detector-v1.01"
# Lazily populated by load_desklib_model(); both stay None until first use.
tokenizer = None
model = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Sentences whose AI probability is >= THRESHOLD are highlighted red in the heatmap.
THRESHOLD = 0.59
33
+
34
+
35
+ def _build_error_card(msg: str) -> str:
36
+ return (
37
+ "<div style='color:#b80d0d; padding:14px; border:1px solid #b80d0d; "
38
+ "border-radius:10px; background:rgba(184,13,13,0.06);'>"
39
+ f"{msg}</div>"
40
+ )
41
+
42
+
43
def wipe_model_cache(model_id: str) -> int:
    """Delete cached files for *model_id* from the common HF cache locations.

    Returns the number of cache directories that were actually removed.
    """
    repo_dir = "models--" + model_id.replace("/", "--")
    home_cache = os.path.expanduser("~/.cache/huggingface")
    locations = [
        # our /tmp cache (recommended)
        f"/tmp/hf/hub/{repo_dir}",
        f"/tmp/hf/transformers/{repo_dir}",
        # default home cache (in case something wrote there)
        f"{home_cache}/hub/{repo_dir}",
        f"{home_cache}/transformers/{repo_dir}",
        f"{home_cache}/modules/{repo_dir}",
    ]

    deleted = 0
    for candidate in locations:
        if not os.path.exists(candidate):
            continue
        shutil.rmtree(candidate, ignore_errors=True)
        deleted += 1
    return deleted
65
+
66
+
67
class DesklibAIDetectionModel(nn.Module):
    """AI-text detector matching the desklib architecture.

    A base transformer backbone, attention-mask-aware mean pooling over the
    token embeddings, and a single linear layer producing one logit per input.
    The repo config lists "architectures": ["DesklibAIDetectionModel"].
    """

    def __init__(self, config):
        super().__init__()
        # Backbone is built from config alone; weights are loaded separately.
        self.backbone = AutoModel.from_config(config)
        self.classifier = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        hidden = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state  # (B, T, H)

        if attention_mask is None:
            # No mask available: plain average over the token dimension.
            pooled = hidden.mean(dim=1)
        else:
            # Masked mean: zero out padding positions, then divide by the
            # real-token count (clamped to avoid division by zero).
            weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
            pooled = (hidden * weights).sum(dim=1) / torch.clamp(
                weights.sum(dim=1), min=1e-9
            )

        return self.classifier(pooled)  # (B, 1) logits
92
+
93
+
94
def load_desklib_model(force_redownload: bool = False):
    """
    Robust loader:
    - downloads config/tokenizer normally
    - downloads model.safetensors explicitly
    - loads safetensors via safetensors.torch.load_file
    - loads into our matching PyTorch module with strict=False

    Caches the result in the module-level ``tokenizer``/``model`` globals so
    repeated calls are cheap. With ``force_redownload=True`` the on-disk HF
    cache is wiped first and everything is fetched fresh.

    Returns:
        (tokenizer, model) with the model moved to ``device`` in eval mode.
    """
    global tokenizer, model

    # Fast path: already loaded and no forced refresh requested.
    if (not force_redownload) and tokenizer is not None and model is not None:
        return tokenizer, model

    if force_redownload:
        print("💣 NUKE requested: wiping cache + forcing fresh downloads...")
        removed = wipe_model_cache(MODEL_NAME)
        print(f"🧹 Cache dirs removed: {removed}")
        # Drop the in-memory copies too, so a failure below can't leave a
        # stale model paired with a fresh tokenizer.
        tokenizer = None
        model = None

    print(f"🚀 Loading tokenizer/config: {MODEL_NAME}")
    config = AutoConfig.from_pretrained(MODEL_NAME, force_download=force_redownload)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, force_download=force_redownload)

    # Download the weights file directly instead of via AutoModel so we can
    # inspect it and load it into our custom module below.
    print("⬇️ Downloading model.safetensors explicitly...")
    weights_path = hf_hub_download(
        repo_id=MODEL_NAME,
        filename="model.safetensors",
        force_download=force_redownload,
    )

    # Log path/size for debugging truncated or corrupt downloads.
    size_gb = os.path.getsize(weights_path) / (1024**3)
    print(f"✅ model.safetensors path: {weights_path}")
    print(f"✅ model.safetensors size: {size_gb:.2f} GB")

    # Build model + load weights
    print("🧠 Building DesklibAIDetectionModel + loading weights...")
    m = DesklibAIDetectionModel(config)
    state = load_file(weights_path)  # this will throw if file is truly corrupt
    # strict=False: tolerate naming differences between checkpoint and module.
    missing, unexpected = m.load_state_dict(state, strict=False)

    # Helpful debug (won't crash)
    if missing:
        print(f"⚠️ Missing keys (first 20): {missing[:20]}")
    if unexpected:
        print(f"⚠️ Unexpected keys (first 20): {unexpected[:20]}")

    model = m.to(device).eval()
    return tokenizer, model
143
+
144
+
145
+ # -----------------------------
146
+ # UTILITIES
147
+ # -----------------------------
148
+ ABBR = ["e.g", "i.e", "mr", "mrs", "ms", "dr", "prof", "vs", "etc", "fig", "al", "jr", "sr", "st", "inc", "ltd", "u.s", "u.k"]
149
+ ABBR_REGEX = re.compile(r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", re.IGNORECASE)
150
+
151
+ def _protect(text):
152
+ text = text.replace("...", "⟨ELLIPSIS⟩")
153
+ text = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", text)
154
+ text = ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", text)
155
+ return text
156
+
157
+ def _restore(text):
158
+ return text.replace("⟨ABBRDOT⟩", ".").replace("⟨DECIMAL⟩", ".").replace("⟨ELLIPSIS⟩", "...")
159
+
160
def split_preserving_structure(text):
    """Split *text* into sentence-level blocks while keeping layout intact.

    Returns a flat list, in document order, containing: sentence strings
    (with their terminal punctuation and protected dots restored), the
    whitespace that followed each sentence, and runs of newlines.
    """
    result = []
    for segment in re.split(r"(\n+)", text):
        if not segment:
            continue
        if segment.startswith("\n"):
            # Newline runs pass through untouched so paragraph layout survives.
            result.append(segment)
            continue
        # Guard ellipses/decimals/abbreviation dots before sentence splitting.
        guarded = _protect(segment)
        pieces = re.split(r"([.?!])(\s+)", guarded)
        # With two capture groups, re.split yields repeating triples of
        # (sentence text, terminal punctuation, trailing whitespace).
        for start in range(0, len(pieces), 3):
            sentence = pieces[start]
            punct = pieces[start + 1] if start + 1 < len(pieces) else ""
            trailing = pieces[start + 2] if start + 2 < len(pieces) else ""
            if sentence.strip():
                result.append(_restore(sentence + punct))
            if trailing:
                result.append(trailing)
    return result
180
+
181
+
182
# -----------------------------
# ANALYSIS
# -----------------------------
@torch.inference_mode()
def analyze(text):
    """Run AI-text detection on *text* and build all UI outputs.

    Returns a 5-tuple matching the Gradio outputs:
        (verdict label, weighted score string, heatmap HTML,
         per-sentence DataFrame or None, status markdown).
    Early-exit tuples are returned for empty input, < 250 words,
    model-load failure, and no detected sentences.
    """
    text = (text or "").strip()
    if not text:
        return "—", "—", "<em>Please enter text...</em>", None, ""

    # Detector accuracy degrades on short inputs; enforce a minimum length.
    word_count = len(text.split())
    if word_count < 250:
        warning_msg = f"⚠️ <b>Insufficient Text:</b> Your input has {word_count} words. Please enter at least 250 words for accurate results."
        return "Too Short", "N/A", _build_error_card(warning_msg), None, ""

    try:
        tok, mod = load_desklib_model(force_redownload=False)
    except Exception as e:
        return "ERROR", "0%", _build_error_card(f"<b>Failed to load model:</b><br>{str(e)}"), None, ""

    # Split into sentence/whitespace blocks; remember which indices are
    # actual sentences (vs newline/whitespace blocks) for the heatmap pass.
    blocks = split_preserving_structure(text)
    pure_sents_indices = [i for i, b in enumerate(blocks) if b.strip() and not b.startswith("\n")]
    pure_sents = [blocks[i] for i in pure_sents_indices]

    if not pure_sents:
        return "—", "—", "<em>No sentences detected.</em>", None, ""

    # Score each sentence with one sentence of context on each side
    # (a sliding 3-sentence window) to stabilize per-sentence probabilities.
    windows = []
    for i in range(len(pure_sents)):
        start = max(0, i - 1)
        end = min(len(pure_sents), i + 2)
        windows.append(" ".join(pure_sents[start:end]))

    # Batched inference; sigmoid turns the single logit into P(AI-written).
    batch_size = 8
    probs = []
    for i in range(0, len(windows), batch_size):
        batch = windows[i: i + batch_size]
        inputs = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        logits = mod(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask"))
        batch_probs = torch.sigmoid(logits).detach().cpu().numpy().flatten().tolist()
        probs.extend(batch_probs)

    # Document-level score: average of sentence probabilities weighted by
    # sentence word count, so long sentences count proportionally more.
    lengths = [len(s.split()) for s in pure_sents]
    total_words = sum(lengths)
    weighted_avg = sum(p * l for p, l in zip(probs, lengths)) / total_words if total_words > 0 else 0

    # HTML Heatmap
    highlighted_html = "<div style='font-family: sans-serif; line-height: 1.8;'>"
    # Map block index -> probability (only sentence blocks have a score).
    prob_map = {idx: probs[i] for i, idx in enumerate(pure_sents_indices)}

    for i, block in enumerate(blocks):
        if block.startswith("\n") or block.isspace():
            # Preserve paragraph breaks in the rendered HTML.
            highlighted_html += block.replace("\n", "<br>")
            continue

        if i in prob_map:
            score = prob_map[i]
            # Red styling for likely-AI sentences, green for likely-human.
            if score >= THRESHOLD:
                color, bg = "#d32f2f", "rgba(211, 47, 47, 0.12)"
                border = "2px solid #d32f2f"
            else:
                color, bg = "#2e7d32", "rgba(46, 125, 50, 0.08)"
                border = "1px solid transparent"

            highlighted_html += (
                f"<span style='background:{bg}; padding:1px 2px; border-radius:3px; border-bottom: {border}; cursor: help;' "
                f"title='AI Confidence: {score:.2%}'>"
                f"<span style='color:{color}; font-weight: bold; font-size: 0.75em; vertical-align: super; margin-right: 2px;'>{score:.0%}</span>"
                f"{block}</span>"
            )
        else:
            highlighted_html += block

    highlighted_html += "</div>"

    label = f"{weighted_avg:.1%} AI Written"
    display_score = f"{weighted_avg:.2%}"
    df = pd.DataFrame({"Sentence": pure_sents, "AI Confidence": [f"{p:.2%}" for p in probs]})

    return label, display_score, highlighted_html, df, ""
261
+
262
+
263
def nuke_and_reload():
    """Wipe the model cache, force a fresh download, and report the outcome.

    Returns a Markdown status string for the UI (success or failure details).
    """
    try:
        load_desklib_model(force_redownload=True)
    except Exception as e:
        return (
            "❌ **Nuke attempted but model still failed to load.**\n\n"
            f"**Error:** `{str(e)}`\n\n"
            "If this error happens inside `load_file(model.safetensors)`, the file is truly corrupted/truncated.\n"
            "If it happens after that, it’s likely key mismatches (shown in logs as missing/unexpected keys)."
        )
    return (
        "✅ **Nuked cache and reloaded model successfully.**\n\n"
        "- Cache wiped\n"
        "- Fresh download forced\n"
        "- Custom loader used (DesklibAIDetectionModel)\n"
        "- Model ready ✅"
    )
280
+
281
+
282
# -----------------------------
# INTERFACE
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="AI Detector Pro") as demo:
    gr.Markdown("# 🕵️ AI Detector Pro")
    gr.Markdown(f"Model: **{MODEL_NAME}** | Highlight Threshold: **{THRESHOLD*100:.0f}%**")

    with gr.Row():
        # Left column: input text area and action buttons.
        with gr.Column(scale=3):
            text_input = gr.Textbox(label="Input Text", lines=15, placeholder="Enter at least 250 words...")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                run_btn = gr.Button("Analyze Text", variant="primary")
                nuke_btn = gr.Button("💣 Nuke Model Cache", variant="stop")

        # Right column: document-level verdict and score.
        with gr.Column(scale=1):
            verdict_out = gr.Label(label="Global Verdict")
            score_out = gr.Label(label="Weighted Probability")

    status_out = gr.Markdown()

    with gr.Tabs():
        with gr.TabItem("Visual Heatmap"):
            html_out = gr.HTML()
        with gr.TabItem("Data Breakdown"):
            table_out = gr.Dataframe(headers=["Sentence", "AI Confidence"], wrap=True)

    # analyze() returns exactly one value per output component, in this order.
    run_btn.click(analyze, inputs=text_input, outputs=[verdict_out, score_out, html_out, table_out, status_out])

    def _clear():
        # Reset values for: text_input, verdict, score, heatmap, table, status.
        return "", "—", "—", "<em>Please enter text...</em>", None, ""

    clear_btn.click(_clear, outputs=[text_input, verdict_out, score_out, html_out, table_out, status_out])
    nuke_btn.click(nuke_and_reload, outputs=status_out)

if __name__ == "__main__":
    demo.launch()