VictorM-Coder committed on
Commit
3b317ce
·
verified ·
1 Parent(s): 62bd0d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -57
app.py CHANGED
@@ -1,10 +1,12 @@
 
 
 
1
  import torch
2
  import torch.nn.functional as F
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
- import re
5
  import pandas as pd
6
  import gradio as gr
7
- import os
 
8
 
9
  # -----------------------------
10
  # MODEL INITIALIZATION
@@ -14,51 +16,109 @@ tokenizer = None
14
  model = None
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def get_model():
 
 
 
 
18
  global tokenizer, model
19
- if model is None:
20
- try:
21
- print(f"🚀 Loading Model: {MODEL_NAME} on {device}")
22
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
23
- # added use_safetensors=True and low_cpu_mem_usage for stability
24
- model = AutoModelForSequenceClassification.from_pretrained(
25
- MODEL_NAME,
26
- ignore_mismatched_sizes=True,
27
- low_cpu_mem_usage=True
28
- ).to(device).eval()
29
- except Exception as e:
30
- print(f"⚠️ Initial load failed: {e}. Attempting recovery...")
31
- # If it fails, it might be a corrupted cache.
32
- # This forces a clean redownload of the weights.
33
- model = AutoModelForSequenceClassification.from_pretrained(
34
- MODEL_NAME,
35
- ignore_mismatched_sizes=True,
36
- force_download=True
37
- ).to(device).eval()
38
- return tokenizer, model
39
-
40
- THRESHOLD = 0.59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # -----------------------------
43
  # UTILITIES
44
  # -----------------------------
45
- ABBR = ["e.g", "i.e", "mr", "mrs", "ms", "dr", "prof", "vs", "etc", "fig", "al", "jr", "sr", "st", "inc", "ltd", "u.s", "u.k"]
 
 
 
46
  ABBR_REGEX = re.compile(r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", re.IGNORECASE)
47
 
48
- def _protect(text):
 
49
  text = text.replace("...", "⟨ELLIPSIS⟩")
50
  text = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", text)
51
  text = ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", text)
52
  return text
53
 
54
- def _restore(text):
55
- return text.replace("⟨ABBRDOT⟩", ".").replace("⟨DECIMAL⟩", ".").replace("⟨ELLIPSIS⟩", "...")
56
 
57
- def split_preserving_structure(text):
 
 
 
 
 
 
 
 
58
  blocks = re.split(r"(\n+)", text)
59
  final_blocks = []
60
  for block in blocks:
61
- if not block: continue
 
62
  if block.startswith("\n"):
63
  final_blocks.append(block)
64
  else:
@@ -66,27 +126,36 @@ def split_preserving_structure(text):
66
  parts = re.split(r"([.?!])(\s+)", protected)
67
  for i in range(0, len(parts), 3):
68
  sentence = parts[i]
69
- punct = parts[i+1] if i+1 < len(parts) else ""
70
- space = parts[i+2] if i+2 < len(parts) else ""
71
  if sentence.strip():
72
  final_blocks.append(_restore(sentence + punct))
73
  if space:
74
  final_blocks.append(space)
75
  return final_blocks
76
 
 
77
  # -----------------------------
78
  # ANALYSIS
79
  # -----------------------------
80
  @torch.inference_mode()
81
  def analyze(text):
82
- text = text.strip()
83
  if not text:
84
  return "—", "—", "<em>Please enter text...</em>", None
85
-
86
  word_count = len(text.split())
87
  if word_count < 250:
88
- warning_msg = f"⚠️ <b>Insufficient Text:</b> Your input has {word_count} words. Please enter at least 250 words for accurate results."
89
- return "Too Short", "N/A", f"<div style='color: #b80d0d; padding: 20px; border: 1px solid #b80d0d; border-radius: 8px;'>{warning_msg}</div>", None
 
 
 
 
 
 
 
 
90
 
91
  try:
92
  tok, mod = get_model()
@@ -96,7 +165,7 @@ def analyze(text):
96
  blocks = split_preserving_structure(text)
97
  pure_sents_indices = [i for i, b in enumerate(blocks) if b.strip() and not b.startswith("\n")]
98
  pure_sents = [blocks[i] for i in pure_sents_indices]
99
-
100
  if not pure_sents:
101
  return "—", "—", "<em>No sentences detected.</em>", None
102
 
@@ -108,30 +177,43 @@ def analyze(text):
108
 
109
  batch_size = 8
110
  probs = []
 
111
  for i in range(0, len(windows), batch_size):
112
- batch = windows[i : i + batch_size]
113
- inputs = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
 
 
 
 
 
 
 
114
  output = mod(**inputs)
115
-
116
  if output.logits.shape[1] > 1:
117
- batch_probs = F.softmax(output.logits, dim=-1)[:, 1].cpu().numpy().tolist()
118
  else:
119
- batch_probs = torch.sigmoid(output.logits).cpu().numpy().flatten().tolist()
 
120
  probs.extend(batch_probs)
121
 
122
  lengths = [len(s.split()) for s in pure_sents]
123
  total_words = sum(lengths)
124
- weighted_avg = sum(p * l for p, l in zip(probs, lengths)) / total_words if total_words > 0 else 0
 
 
 
 
125
 
126
  # HTML Heatmap
127
- highlighted_html = "<div style='font-family: sans-serif; line-height: 1.8;'>"
128
  prob_map = {idx: probs[i] for i, idx in enumerate(pure_sents_indices)}
129
-
130
  for i, block in enumerate(blocks):
131
  if block.startswith("\n") or block.isspace():
132
  highlighted_html += block.replace("\n", "<br>")
133
  continue
134
-
135
  if i in prob_map:
136
  score = prob_map[i]
137
  if score >= THRESHOLD:
@@ -140,49 +222,59 @@ def analyze(text):
140
  else:
141
  color, bg = "#2e7d32", "rgba(46, 125, 50, 0.08)"
142
  border = "1px solid transparent"
143
-
144
  highlighted_html += (
145
- f"<span style='background:{bg}; padding:1px 2px; border-radius:3px; border-bottom: {border}; cursor: help;' "
146
  f"title='AI Confidence: {score:.2%}'>"
147
- f"<span style='color:{color}; font-weight: bold; font-size: 0.75em; vertical-align: super; margin-right: 2px;'>{score:.0%}</span>"
148
  f"{block}</span>"
149
  )
150
  else:
151
  highlighted_html += block
 
152
  highlighted_html += "</div>"
153
 
154
  label = f"{weighted_avg:.1%} AI Written"
155
  display_score = f"{weighted_avg:.2%}"
156
  df = pd.DataFrame({"Sentence": pure_sents, "AI Confidence": [f"{p:.2%}" for p in probs]})
157
-
158
  return label, display_score, highlighted_html, df
159
 
 
160
  # -----------------------------
161
  # INTERFACE
162
  # -----------------------------
163
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Detector Pro") as demo:
164
  gr.Markdown("# 🕵️ AI Detector Pro")
165
  gr.Markdown(f"Model: **{MODEL_NAME}** | Highlight Threshold: **{THRESHOLD*100:.0f}%**")
166
-
167
  with gr.Row():
168
  with gr.Column(scale=3):
169
- text_input = gr.Textbox(label="Input Text", lines=15, placeholder="Enter at least 250 words...")
 
 
 
 
170
  with gr.Row():
171
  clear_btn = gr.Button("Clear")
172
  run_btn = gr.Button("Analyze Text", variant="primary")
173
-
174
  with gr.Column(scale=1):
175
  verdict_out = gr.Label(label="Global Verdict")
176
  score_out = gr.Label(label="Weighted Probability")
177
-
178
  with gr.Tabs():
179
  with gr.TabItem("Visual Heatmap"):
180
  html_out = gr.HTML()
181
  with gr.TabItem("Data Breakdown"):
182
  table_out = gr.Dataframe(headers=["Sentence", "AI Confidence"], wrap=True)
183
-
184
  run_btn.click(analyze, inputs=text_input, outputs=[verdict_out, score_out, html_out, table_out])
185
- clear_btn.click(lambda: ["", "", "", "", None], outputs=[text_input, verdict_out, score_out, html_out, table_out])
 
 
 
 
186
 
187
  if __name__ == "__main__":
188
  demo.launch()
 
1
+ import os
2
+ import re
3
+ import shutil
4
  import torch
5
  import torch.nn.functional as F
 
 
6
  import pandas as pd
7
  import gradio as gr
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
+
10
 
11
  # -----------------------------
12
  # MODEL INITIALIZATION
 
16
  model = None
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
19
+
20
def purge_model_cache(model_id: str) -> None:
    """
    Remove cached weights/tokenizer for *model_id* from the common Hugging Face
    cache locations.

    This recovers from the 'state dictionary ... corrupted' error caused by
    partial/interrupted downloads: deleting the cache forces a clean
    re-download on the next ``from_pretrained`` call.

    Parameters
    ----------
    model_id:
        Hub id such as ``"org/model"``; '/' is mapped to '--' to match the
        on-disk cache directory naming scheme (``models--org--model``).
    """
    safe = model_id.replace("/", "--")

    # Current and historical HF cache layouts; missing paths are harmless.
    candidates = [
        os.path.expanduser(f"~/.cache/huggingface/hub/models--{safe}"),
        os.path.expanduser(f"~/.cache/huggingface/transformers/models--{safe}"),
        os.path.expanduser(f"~/.cache/huggingface/modules/models--{safe}"),
    ]

    for path in candidates:
        if os.path.exists(path):
            try:
                # BUGFIX: the original passed ignore_errors=True, which
                # suppresses rmtree failures — the except branch below was
                # unreachable and "Removed cache" printed even on failure.
                shutil.rmtree(path)
                print(f"🧹 Removed cache: {path}")
            except Exception as e:
                print(f"⚠️ Failed to remove cache at {path}: {e}")
40
+
41
+
42
def get_model():
    """
    Lazily load and cache the tokenizer + sequence-classification model.

    Prefers safetensors weights. If the first load fails (most commonly a
    corrupted/partial Hugging Face cache), the cache for MODEL_NAME is purged
    and everything is force re-downloaded once; a failure of that second
    attempt propagates to the caller.

    Returns
    -------
    (tokenizer, model):
        The module-level AutoTokenizer and eval-mode
        AutoModelForSequenceClassification moved to the global ``device``.
    """
    global tokenizer, model

    # Already initialized — reuse the cached pair.
    if model is not None and tokenizer is not None:
        return tokenizer, model

    print(f"🚀 Loading Model: {MODEL_NAME} on {device}")

    def _load(force_download: bool):
        """Single code path for both the initial and the recovery load.

        The original duplicated this logic and inconsistently dropped
        low_cpu_mem_usage on the recovery attempt; unified here.
        """
        tok = AutoTokenizer.from_pretrained(MODEL_NAME, force_download=force_download)
        mod = AutoModelForSequenceClassification.from_pretrained(
            MODEL_NAME,
            use_safetensors=True,         # prefer safetensors weights
            ignore_mismatched_sizes=True,
            low_cpu_mem_usage=True,       # now applied on both attempts
            force_download=force_download,
        ).to(device).eval()
        return tok, mod

    try:
        tokenizer, model = _load(force_download=False)
    except Exception as e:
        print(f"⚠️ Initial load failed: {e}")
        print("🔁 Attempting recovery: purge cache + force re-download...")
        purge_model_cache(MODEL_NAME)
        # Second failure is deliberately not caught — surface it to analyze().
        tokenizer, model = _load(force_download=True)

    return tokenizer, model
86
+
87
+
88
+ THRESHOLD = 0.59
89
+
90
 
91
  # -----------------------------
92
  # UTILITIES
93
  # -----------------------------
94
+ ABBR = [
95
+ "e.g", "i.e", "mr", "mrs", "ms", "dr", "prof", "vs", "etc",
96
+ "fig", "al", "jr", "sr", "st", "inc", "ltd", "u.s", "u.k"
97
+ ]
98
  ABBR_REGEX = re.compile(r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", re.IGNORECASE)
99
 
100
+
101
+ def _protect(text: str) -> str:
102
  text = text.replace("...", "⟨ELLIPSIS⟩")
103
  text = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", text)
104
  text = ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", text)
105
  return text
106
 
 
 
107
 
108
+ def _restore(text: str) -> str:
109
+ return (
110
+ text.replace("⟨ABBRDOT⟩", ".")
111
+ .replace("⟨DECIMAL⟩", ".")
112
+ .replace("⟨ELLIPSIS⟩", "...")
113
+ )
114
+
115
+
116
+ def split_preserving_structure(text: str):
117
  blocks = re.split(r"(\n+)", text)
118
  final_blocks = []
119
  for block in blocks:
120
+ if not block:
121
+ continue
122
  if block.startswith("\n"):
123
  final_blocks.append(block)
124
  else:
 
126
  parts = re.split(r"([.?!])(\s+)", protected)
127
  for i in range(0, len(parts), 3):
128
  sentence = parts[i]
129
+ punct = parts[i + 1] if i + 1 < len(parts) else ""
130
+ space = parts[i + 2] if i + 2 < len(parts) else ""
131
  if sentence.strip():
132
  final_blocks.append(_restore(sentence + punct))
133
  if space:
134
  final_blocks.append(space)
135
  return final_blocks
136
 
137
+
138
  # -----------------------------
139
  # ANALYSIS
140
  # -----------------------------
141
  @torch.inference_mode()
142
  def analyze(text):
143
+ text = (text or "").strip()
144
  if not text:
145
  return "—", "—", "<em>Please enter text...</em>", None
146
+
147
  word_count = len(text.split())
148
  if word_count < 250:
149
+ warning_msg = (
150
+ f"⚠️ <b>Insufficient Text:</b> Your input has {word_count} words. "
151
+ f"Please enter at least 250 words for accurate results."
152
+ )
153
+ return (
154
+ "Too Short",
155
+ "N/A",
156
+ f"<div style='color:#b80d0d; padding:20px; border:1px solid #b80d0d; border-radius:8px;'>{warning_msg}</div>",
157
+ None,
158
+ )
159
 
160
  try:
161
  tok, mod = get_model()
 
165
  blocks = split_preserving_structure(text)
166
  pure_sents_indices = [i for i, b in enumerate(blocks) if b.strip() and not b.startswith("\n")]
167
  pure_sents = [blocks[i] for i in pure_sents_indices]
168
+
169
  if not pure_sents:
170
  return "—", "—", "<em>No sentences detected.</em>", None
171
 
 
177
 
178
  batch_size = 8
179
  probs = []
180
+
181
  for i in range(0, len(windows), batch_size):
182
+ batch = windows[i: i + batch_size]
183
+ inputs = tok(
184
+ batch,
185
+ return_tensors="pt",
186
+ padding=True,
187
+ truncation=True,
188
+ max_length=512,
189
+ ).to(device)
190
+
191
  output = mod(**inputs)
192
+
193
  if output.logits.shape[1] > 1:
194
+ batch_probs = F.softmax(output.logits, dim=-1)[:, 1].detach().cpu().numpy().tolist()
195
  else:
196
+ batch_probs = torch.sigmoid(output.logits).detach().cpu().numpy().flatten().tolist()
197
+
198
  probs.extend(batch_probs)
199
 
200
  lengths = [len(s.split()) for s in pure_sents]
201
  total_words = sum(lengths)
202
+ weighted_avg = (
203
+ sum(p * l for p, l in zip(probs, lengths)) / total_words
204
+ if total_words > 0
205
+ else 0
206
+ )
207
 
208
  # HTML Heatmap
209
+ highlighted_html = "<div style='font-family:sans-serif; line-height:1.8;'>"
210
  prob_map = {idx: probs[i] for i, idx in enumerate(pure_sents_indices)}
211
+
212
  for i, block in enumerate(blocks):
213
  if block.startswith("\n") or block.isspace():
214
  highlighted_html += block.replace("\n", "<br>")
215
  continue
216
+
217
  if i in prob_map:
218
  score = prob_map[i]
219
  if score >= THRESHOLD:
 
222
  else:
223
  color, bg = "#2e7d32", "rgba(46, 125, 50, 0.08)"
224
  border = "1px solid transparent"
225
+
226
  highlighted_html += (
227
+ f"<span style='background:{bg}; padding:1px 2px; border-radius:3px; border-bottom:{border}; cursor:help;' "
228
  f"title='AI Confidence: {score:.2%}'>"
229
+ f"<span style='color:{color}; font-weight:bold; font-size:0.75em; vertical-align:super; margin-right:2px;'>{score:.0%}</span>"
230
  f"{block}</span>"
231
  )
232
  else:
233
  highlighted_html += block
234
+
235
  highlighted_html += "</div>"
236
 
237
  label = f"{weighted_avg:.1%} AI Written"
238
  display_score = f"{weighted_avg:.2%}"
239
  df = pd.DataFrame({"Sentence": pure_sents, "AI Confidence": [f"{p:.2%}" for p in probs]})
240
+
241
  return label, display_score, highlighted_html, df
242
 
243
+
244
# -----------------------------
# INTERFACE
# -----------------------------
# Gradio UI: text input plus Analyze/Clear buttons on the left, verdict labels
# on the right, and two result tabs (visual HTML heatmap, per-sentence table).
with gr.Blocks(theme=gr.themes.Soft(), title="AI Detector Pro") as demo:
    gr.Markdown("# 🕵️ AI Detector Pro")
    gr.Markdown(f"Model: **{MODEL_NAME}** | Highlight Threshold: **{THRESHOLD*100:.0f}%**")

    with gr.Row():
        with gr.Column(scale=3):
            # Main input; analyze() rejects inputs under 250 words.
            text_input = gr.Textbox(
                label="Input Text",
                lines=15,
                placeholder="Enter at least 250 words..."
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                run_btn = gr.Button("Analyze Text", variant="primary")

        with gr.Column(scale=1):
            verdict_out = gr.Label(label="Global Verdict")
            score_out = gr.Label(label="Weighted Probability")

    with gr.Tabs():
        with gr.TabItem("Visual Heatmap"):
            html_out = gr.HTML()
        with gr.TabItem("Data Breakdown"):
            table_out = gr.Dataframe(headers=["Sentence", "AI Confidence"], wrap=True)

    # analyze() returns (label, score, html, dataframe) — order must match
    # the outputs list below.
    run_btn.click(analyze, inputs=text_input, outputs=[verdict_out, score_out, html_out, table_out])

    def _clear():
        # Reset all five components (input + four outputs) to their initial
        # "empty" state; the HTML placeholder mirrors analyze()'s empty-input
        # response.
        return "", "—", "—", "<em>Please enter text...</em>", None

    clear_btn.click(_clear, outputs=[text_input, verdict_out, score_out, html_out, table_out])

if __name__ == "__main__":
    demo.launch()