VictorM-Coder committed on
Commit
3b0d005
·
verified ·
1 Parent(s): 344cbaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -84
app.py CHANGED
@@ -1,44 +1,28 @@
1
- import os, shutil, glob
2
-
3
- # Put HF caches somewhere "fresh" (avoid reusing an old corrupt cache)
4
- os.environ["HF_HOME"] = "/tmp/hf"
5
- os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/hf/hub"
6
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
7
- os.environ["HF_HUB_DISABLE_XET"] = "1" # also avoids xet-related partial downloads
8
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
9
-
10
- def wipe_model_cache(model_id: str):
11
- safe = model_id.replace("/", "--")
12
- paths = [
13
- f"/tmp/hf/hub/models--{safe}",
14
- f"/tmp/hf/transformers/models--{safe}",
15
- # also wipe common defaults in case something else wrote there
16
- os.path.expanduser(f"~/.cache/huggingface/hub/models--{safe}"),
17
- os.path.expanduser(f"~/.cache/huggingface/transformers/models--{safe}"),
18
- ]
19
- for p in paths:
20
- if os.path.exists(p):
21
- shutil.rmtree(p, ignore_errors=True)
22
-
23
- # wipe the specific model cache on startup
24
- wipe_model_cache("desklib/ai-text-detector-v1.01")
25
-
26
  import re
27
  import shutil
 
28
  import torch
29
  import torch.nn.functional as F
30
  import pandas as pd
31
  import gradio as gr
32
-
33
- from huggingface_hub import snapshot_download
34
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
35
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  # -----------------------------
38
  # MODEL INITIALIZATION
39
  # -----------------------------
40
  MODEL_NAME = "desklib/ai-text-detector-v1.01"
41
- LOCAL_MODEL_DIR = "/tmp/desklib_ai_text_detector_v1_01" # local snapshot dir
42
  tokenizer = None
43
  model = None
44
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -46,63 +30,74 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
  THRESHOLD = 0.59
47
 
48
 
49
- def _rm_dir(path: str) -> None:
50
- if os.path.exists(path):
51
- shutil.rmtree(path, ignore_errors=True)
52
-
53
-
54
- def download_model_snapshot() -> str:
55
  """
56
- Download the HF repo snapshot to a local folder, forcing a clean download.
57
- Disabling Xet via env var helps avoid corrupted large-file downloads on some envs.
58
  """
59
- # wipe local dir to ensure truly clean download
60
- _rm_dir(LOCAL_MODEL_DIR)
61
-
62
- print(f"⬇️ Downloading snapshot for: {MODEL_NAME}")
63
- local_dir = snapshot_download(
64
- repo_id=MODEL_NAME,
65
- local_dir=LOCAL_MODEL_DIR,
66
- local_dir_use_symlinks=False,
67
- force_download=True,
68
- resume_download=False,
69
- )
70
-
71
- # Basic integrity sanity check: make sure model.safetensors looks real
72
- st_path = os.path.join(local_dir, "model.safetensors")
73
- if not os.path.exists(st_path):
74
- raise RuntimeError(f"model.safetensors not found in snapshot at: {st_path}")
75
-
76
- size_gb = os.path.getsize(st_path) / (1024**3)
77
- print(f"βœ… model.safetensors size: {size_gb:.2f} GB")
78
-
79
- # The HF repo shows ~1.74GB for model.safetensors. :contentReference[oaicite:3]{index=3}
80
- # If the file is drastically smaller, it's likely truncated.
81
- if size_gb < 1.0:
82
- raise RuntimeError(
83
- f"Downloaded model.safetensors looks too small ({size_gb:.2f} GB). "
84
- "Likely truncated download."
85
- )
86
 
87
- return local_dir
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
 
90
- def get_model():
 
 
 
 
91
  global tokenizer, model
92
- if model is not None and tokenizer is not None:
 
93
  return tokenizer, model
94
 
95
- print(f"πŸš€ Loading Model: {MODEL_NAME} on {device}")
 
 
 
 
 
96
 
97
- local_dir = download_model_snapshot()
98
 
99
- tokenizer = AutoTokenizer.from_pretrained(local_dir)
 
 
 
 
100
 
 
101
  model = AutoModelForSequenceClassification.from_pretrained(
102
- local_dir,
103
  use_safetensors=True,
104
  ignore_mismatched_sizes=True,
105
  low_cpu_mem_usage=True,
 
106
  ).to(device).eval()
107
 
108
  return tokenizer, model
@@ -158,7 +153,7 @@ def split_preserving_structure(text):
158
  def analyze(text):
159
  text = (text or "").strip()
160
  if not text:
161
- return "β€”", "β€”", "<em>Please enter text...</em>", None
162
 
163
  word_count = len(text.split())
164
  if word_count < 250:
@@ -166,24 +161,19 @@ def analyze(text):
166
  f"⚠️ <b>Insufficient Text:</b> Your input has {word_count} words. "
167
  f"Please enter at least 250 words for accurate results."
168
  )
169
- return (
170
- "Too Short",
171
- "N/A",
172
- f"<div style='color:#b80d0d; padding:20px; border:1px solid #b80d0d; border-radius:8px;'>{warning_msg}</div>",
173
- None,
174
- )
175
 
176
  try:
177
- tok, mod = get_model()
178
  except Exception as e:
179
- return "ERROR", "0%", f"Failed to load model: {str(e)}", None
180
 
181
  blocks = split_preserving_structure(text)
182
  pure_sents_indices = [i for i, b in enumerate(blocks) if b.strip() and not b.startswith("\n")]
183
  pure_sents = [blocks[i] for i in pure_sents_indices]
184
 
185
  if not pure_sents:
186
- return "β€”", "β€”", "<em>No sentences detected.</em>", None
187
 
188
  windows = []
189
  for i in range(len(pure_sents)):
@@ -240,7 +230,29 @@ def analyze(text):
240
  display_score = f"{weighted_avg:.2%}"
241
  df = pd.DataFrame({"Sentence": pure_sents, "AI Confidence": [f"{p:.2%}" for p in probs]})
242
 
243
- return label, display_score, highlighted_html, df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
 
246
  # -----------------------------
@@ -256,23 +268,27 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI Detector Pro") as demo:
256
  with gr.Row():
257
  clear_btn = gr.Button("Clear")
258
  run_btn = gr.Button("Analyze Text", variant="primary")
 
259
 
260
  with gr.Column(scale=1):
261
  verdict_out = gr.Label(label="Global Verdict")
262
  score_out = gr.Label(label="Weighted Probability")
263
 
 
 
264
  with gr.Tabs():
265
  with gr.TabItem("Visual Heatmap"):
266
  html_out = gr.HTML()
267
  with gr.TabItem("Data Breakdown"):
268
  table_out = gr.Dataframe(headers=["Sentence", "AI Confidence"], wrap=True)
269
 
270
- run_btn.click(analyze, inputs=text_input, outputs=[verdict_out, score_out, html_out, table_out])
271
 
272
  def _clear():
273
- return "", "β€”", "β€”", "<em>Please enter text...</em>", None
274
 
275
- clear_btn.click(_clear, outputs=[text_input, verdict_out, score_out, html_out, table_out])
 
276
 
277
  if __name__ == "__main__":
278
  demo.launch()
 
1
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import re
3
  import shutil
4
+
5
  import torch
6
  import torch.nn.functional as F
7
  import pandas as pd
8
  import gradio as gr
 
 
9
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
10
 
11
+ # ============================================================
12
+ # ENV (set BEFORE loading models)
13
+ # ============================================================
14
+ # Use a predictable cache location (helps avoid reusing a corrupt home cache)
15
+ os.environ.setdefault("HF_HOME", "/tmp/hf")
16
+ os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/hf/hub")
17
+ os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")
18
+ # Disable Xet (helps avoid partial/corrupt downloads in some environments)
19
+ os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
20
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
21
 
22
  # -----------------------------
23
  # MODEL INITIALIZATION
24
  # -----------------------------
25
  MODEL_NAME = "desklib/ai-text-detector-v1.01"
 
26
  tokenizer = None
27
  model = None
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
30
  THRESHOLD = 0.59
31
 
32
 
33
def wipe_model_cache(model_id: str) -> int:
    """
    Delete cached files for *model_id* from the common Hugging Face cache
    locations (our /tmp cache plus the default home caches).

    Args:
        model_id: HF repo id, e.g. "desklib/ai-text-detector-v1.01".

    Returns:
        Number of cache directories that existed and were removed.
    """
    # HF cache dirs encode the repo id with "/" replaced by "--".
    safe = model_id.replace("/", "--")
    candidates = [
        # our /tmp cache (recommended)
        f"/tmp/hf/hub/models--{safe}",
        f"/tmp/hf/transformers/models--{safe}",
        # default home cache (in case something wrote there)
        os.path.expanduser(f"~/.cache/huggingface/hub/models--{safe}"),
        os.path.expanduser(f"~/.cache/huggingface/transformers/models--{safe}"),
        os.path.expanduser(f"~/.cache/huggingface/modules/models--{safe}"),
    ]

    removed = 0
    for path in candidates:
        if os.path.exists(path):
            # ignore_errors=True already suppresses deletion errors
            # (permissions etc.), so no extra try/except is needed.
            shutil.rmtree(path, ignore_errors=True)
            removed += 1
    return removed
59
+
60
+
61
def _build_error_card(msg: str) -> str:
    """Wrap *msg* in a red, rounded error <div> for display in the UI."""
    card_style = (
        "color:#b80d0d; padding:14px; border:1px solid #b80d0d; "
        "border-radius:10px; background:rgba(184,13,13,0.06);"
    )
    return f"<div style='{card_style}'>{msg}</div>"
67
 
68
 
69
def get_model(force_redownload: bool = False):
    """
    Lazily load (and cache in module globals) the tokenizer/model pair.

    With force_redownload=False the already-loaded singletons are reused
    (fast path). With force_redownload=True (triggered by the Nuke button)
    the HF cache is wiped first and a fresh download is forced.
    """
    global tokenizer, model

    already_loaded = tokenizer is not None and model is not None
    if already_loaded and not force_redownload:
        return tokenizer, model

    if force_redownload:
        print("💣 NUKE requested: wiping cache + forcing re-download...")
        wiped = wipe_model_cache(MODEL_NAME)
        print(f"🧹 Cache dirs removed: {wiped}")
        # Drop stale references so a failed reload cannot hand back old objects.
        tokenizer = None
        model = None

    print(f"🚀 Loading Model: {MODEL_NAME} on {device}")

    # Tokenizer first; force_download mirrors the nuke request.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        force_download=force_redownload,
    )

    # Model weights: prefer safetensors, then move to the target device.
    loaded = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        use_safetensors=True,
        ignore_mismatched_sizes=True,
        low_cpu_mem_usage=True,
        force_download=force_redownload,
    )
    model = loaded.to(device).eval()

    return tokenizer, model
 
153
  def analyze(text):
154
  text = (text or "").strip()
155
  if not text:
156
+ return "β€”", "β€”", "<em>Please enter text...</em>", None, ""
157
 
158
  word_count = len(text.split())
159
  if word_count < 250:
 
161
  f"⚠️ <b>Insufficient Text:</b> Your input has {word_count} words. "
162
  f"Please enter at least 250 words for accurate results."
163
  )
164
+ return "Too Short", "N/A", _build_error_card(warning_msg), None, ""
 
 
 
 
 
165
 
166
  try:
167
+ tok, mod = get_model(force_redownload=False)
168
  except Exception as e:
169
+ return "ERROR", "0%", _build_error_card(f"<b>Failed to load model:</b><br>{str(e)}"), None, ""
170
 
171
  blocks = split_preserving_structure(text)
172
  pure_sents_indices = [i for i, b in enumerate(blocks) if b.strip() and not b.startswith("\n")]
173
  pure_sents = [blocks[i] for i in pure_sents_indices]
174
 
175
  if not pure_sents:
176
+ return "β€”", "β€”", "<em>No sentences detected.</em>", None, ""
177
 
178
  windows = []
179
  for i in range(len(pure_sents)):
 
230
  display_score = f"{weighted_avg:.2%}"
231
  df = pd.DataFrame({"Sentence": pure_sents, "AI Confidence": [f"{p:.2%}" for p in probs]})
232
 
233
+ return label, display_score, highlighted_html, df, ""
234
+
235
+
236
def nuke_and_reload():
    """
    Handler for the "Nuke Model Cache" UI button: wipe the cache, force a
    re-download, and attempt to load the model.

    Returns:
        A Markdown status string describing success or failure.
    """
    try:
        get_model(force_redownload=True)
    except Exception as err:
        # Surface the failure in the UI rather than crashing the app.
        return (
            "❌ **Nuke attempted but model still failed to load.**\n\n"
            f"**Error:** `{str(err)}`\n\n"
            "If this keeps happening, it usually means the downloaded weights are getting truncated "
            "(network/storage) or the runtime stack (Python/Torch) is incompatible."
        )
    return (
        "✅ **Nuked cache and reloaded model successfully.**\n\n"
        "- Cache wiped\n"
        "- Fresh download forced\n"
        "- Model ready ✅"
    )
256
 
257
 
258
  # -----------------------------
 
268
  with gr.Row():
269
  clear_btn = gr.Button("Clear")
270
  run_btn = gr.Button("Analyze Text", variant="primary")
271
+ nuke_btn = gr.Button("πŸ’£ Nuke Model Cache", variant="stop")
272
 
273
  with gr.Column(scale=1):
274
  verdict_out = gr.Label(label="Global Verdict")
275
  score_out = gr.Label(label="Weighted Probability")
276
 
277
+ status_out = gr.Markdown()
278
+
279
  with gr.Tabs():
280
  with gr.TabItem("Visual Heatmap"):
281
  html_out = gr.HTML()
282
  with gr.TabItem("Data Breakdown"):
283
  table_out = gr.Dataframe(headers=["Sentence", "AI Confidence"], wrap=True)
284
 
285
+ run_btn.click(analyze, inputs=text_input, outputs=[verdict_out, score_out, html_out, table_out, status_out])
286
 
287
  def _clear():
288
+ return "", "β€”", "β€”", "<em>Please enter text...</em>", None, ""
289
 
290
+ clear_btn.click(_clear, outputs=[text_input, verdict_out, score_out, html_out, table_out, status_out])
291
+ nuke_btn.click(nuke_and_reload, outputs=status_out)
292
 
293
  if __name__ == "__main__":
294
  demo.launch()