VictorM-Coder commited on
Commit
46d3fde
·
verified ·
1 Parent(s): eea664e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -83
app.py CHANGED
@@ -1,91 +1,105 @@
1
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import re
3
  import shutil
4
  import torch
5
  import torch.nn.functional as F
6
  import pandas as pd
7
  import gradio as gr
 
 
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
 
10
 
11
  # -----------------------------
12
  # MODEL INITIALIZATION
13
  # -----------------------------
14
- MODEL_NAME = "desklib/ai-text-detector-v1.03"
 
15
  tokenizer = None
16
  model = None
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
 
19
 
20
- def purge_model_cache(model_id: str) -> None:
21
- """
22
- Remove cached weights/tokenizer for this model from common HF cache locations.
23
- This fixes the 'state dictionary ... corrupted' error caused by partial downloads.
24
- """
25
- safe = model_id.replace("/", "--")
26
-
27
- candidates = [
28
- os.path.expanduser(f"~/.cache/huggingface/hub/models--{safe}"),
29
- os.path.expanduser(f"~/.cache/huggingface/transformers/models--{safe}"),
30
- os.path.expanduser(f"~/.cache/huggingface/modules/models--{safe}"),
31
- ]
32
 
33
- for path in candidates:
34
- if os.path.exists(path):
35
- try:
36
- shutil.rmtree(path, ignore_errors=True)
37
- print(f"🧹 Removed cache: {path}")
38
- except Exception as e:
39
- print(f"⚠️ Failed to remove cache at {path}: {e}")
40
 
41
 
42
- def get_model():
43
  """
44
- Loads tokenizer + model with safetensors preferred.
45
- If load fails (often due to corrupted HF cache), purge cache + force download.
46
  """
47
- global tokenizer, model
48
-
49
- if model is not None and tokenizer is not None:
50
- return tokenizer, model
 
 
 
 
 
 
 
51
 
52
- print(f"🚀 Loading Model: {MODEL_NAME} on {device}")
 
 
 
53
 
54
- try:
55
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
56
 
57
- model = AutoModelForSequenceClassification.from_pretrained(
58
- MODEL_NAME,
59
- use_safetensors=True, # prefer safetensors
60
- ignore_mismatched_sizes=True,
61
- low_cpu_mem_usage=True,
62
- ).to(device).eval()
 
63
 
64
- return tokenizer, model
65
 
66
- except Exception as e:
67
- print(f"⚠️ Initial load failed: {e}")
68
- print("🔁 Attempting recovery: purge cache + force re-download...")
69
 
70
- purge_model_cache(MODEL_NAME)
 
 
 
71
 
72
- # Redownload everything cleanly
73
- tokenizer = AutoTokenizer.from_pretrained(
74
- MODEL_NAME,
75
- force_download=True,
76
- )
77
 
78
- model = AutoModelForSequenceClassification.from_pretrained(
79
- MODEL_NAME,
80
- use_safetensors=True, # ✅ keep safetensors on recovery too
81
- ignore_mismatched_sizes=True,
82
- force_download=True,
83
- ).to(device).eval()
84
 
85
- return tokenizer, model
86
 
 
 
 
 
 
 
87
 
88
- THRESHOLD = 0.59
89
 
90
 
91
  # -----------------------------
@@ -98,22 +112,18 @@ ABBR = [
98
  ABBR_REGEX = re.compile(r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", re.IGNORECASE)
99
 
100
 
101
- def _protect(text: str) -> str:
102
  text = text.replace("...", "⟨ELLIPSIS⟩")
103
  text = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", text)
104
  text = ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", text)
105
  return text
106
 
107
 
108
- def _restore(text: str) -> str:
109
- return (
110
- text.replace("⟨ABBRDOT⟩", ".")
111
- .replace("⟨DECIMAL⟩", ".")
112
- .replace("⟨ELLIPSIS⟩", "...")
113
- )
114
 
115
 
116
- def split_preserving_structure(text: str):
117
  blocks = re.split(r"(\n+)", text)
118
  final_blocks = []
119
  for block in blocks:
@@ -177,35 +187,21 @@ def analyze(text):
177
 
178
  batch_size = 8
179
  probs = []
180
-
181
  for i in range(0, len(windows), batch_size):
182
  batch = windows[i: i + batch_size]
183
- inputs = tok(
184
- batch,
185
- return_tensors="pt",
186
- padding=True,
187
- truncation=True,
188
- max_length=512,
189
- ).to(device)
190
-
191
  output = mod(**inputs)
192
 
193
  if output.logits.shape[1] > 1:
194
  batch_probs = F.softmax(output.logits, dim=-1)[:, 1].detach().cpu().numpy().tolist()
195
  else:
196
  batch_probs = torch.sigmoid(output.logits).detach().cpu().numpy().flatten().tolist()
197
-
198
  probs.extend(batch_probs)
199
 
200
  lengths = [len(s.split()) for s in pure_sents]
201
  total_words = sum(lengths)
202
- weighted_avg = (
203
- sum(p * l for p, l in zip(probs, lengths)) / total_words
204
- if total_words > 0
205
- else 0
206
- )
207
 
208
- # HTML Heatmap
209
  highlighted_html = "<div style='font-family:sans-serif; line-height:1.8;'>"
210
  prob_map = {idx: probs[i] for i, idx in enumerate(pure_sents_indices)}
211
 
@@ -250,11 +246,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI Detector Pro") as demo:
250
 
251
  with gr.Row():
252
  with gr.Column(scale=3):
253
- text_input = gr.Textbox(
254
- label="Input Text",
255
- lines=15,
256
- placeholder="Enter at least 250 words..."
257
- )
258
  with gr.Row():
259
  clear_btn = gr.Button("Clear")
260
  run_btn = gr.Button("Analyze Text", variant="primary")
 
1
  import os
2
+
3
+ # ============================================================
4
+ # ✅ FIX FOR "state dict corrupted" ON SPACES (Xet downloads)
5
+ # ============================================================
6
+ # Disable hf-xet usage (forces download via LFS bridge instead).
7
+ # HF docs: HF_HUB_DISABLE_XET disables using hf-xet.
8
+ os.environ["HF_HUB_DISABLE_XET"] = "1"
9
+
10
+ # Optional: place HF cache in a writable/temp location (Spaces friendly)
11
+ # You can comment this out if you prefer default cache locations.
12
+ os.environ.setdefault("HF_HOME", "/tmp/hf")
13
+ os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/hf/hub")
14
+ os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")
15
+
16
+ # (Optional) reduce parallelism issues
17
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
18
+
19
+
20
  import re
21
  import shutil
22
  import torch
23
  import torch.nn.functional as F
24
  import pandas as pd
25
  import gradio as gr
26
+
27
+ from huggingface_hub import snapshot_download
28
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
29
 
30
 
31
  # -----------------------------
32
  # MODEL INITIALIZATION
33
  # -----------------------------
34
+ MODEL_NAME = "desklib/ai-text-detector-v1.01"
35
+ LOCAL_MODEL_DIR = "/tmp/desklib_ai_text_detector_v1_01" # local snapshot dir
36
  tokenizer = None
37
  model = None
38
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
 
40
+ THRESHOLD = 0.59
41
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ def _rm_dir(path: str) -> None:
44
+ if os.path.exists(path):
45
+ shutil.rmtree(path, ignore_errors=True)
 
 
 
 
46
 
47
 
48
+ def download_model_snapshot() -> str:
49
  """
50
+ Download the HF repo snapshot to a local folder, forcing a clean download.
51
+ Disabling Xet via env var helps avoid corrupted large-file downloads on some envs.
52
  """
53
+ # wipe local dir to ensure truly clean download
54
+ _rm_dir(LOCAL_MODEL_DIR)
55
+
56
+ print(f"⬇️ Downloading snapshot for: {MODEL_NAME}")
57
+ local_dir = snapshot_download(
58
+ repo_id=MODEL_NAME,
59
+ local_dir=LOCAL_MODEL_DIR,
60
+ local_dir_use_symlinks=False,
61
+ force_download=True,
62
+ resume_download=False,
63
+ )
64
 
65
+ # Basic integrity sanity check: make sure model.safetensors looks real
66
+ st_path = os.path.join(local_dir, "model.safetensors")
67
+ if not os.path.exists(st_path):
68
+ raise RuntimeError(f"model.safetensors not found in snapshot at: {st_path}")
69
 
70
+ size_gb = os.path.getsize(st_path) / (1024**3)
71
+ print(f"✅ model.safetensors size: {size_gb:.2f} GB")
72
 
73
+ # The HF repo shows ~1.74GB for model.safetensors.
74
+ # If the file is drastically smaller, it's likely truncated.
75
+ if size_gb < 1.0:
76
+ raise RuntimeError(
77
+ f"Downloaded model.safetensors looks too small ({size_gb:.2f} GB). "
78
+ "Likely truncated download."
79
+ )
80
 
81
+ return local_dir
82
 
 
 
 
83
 
84
+ def get_model():
85
+ global tokenizer, model
86
+ if model is not None and tokenizer is not None:
87
+ return tokenizer, model
88
 
89
+ print(f"🚀 Loading Model: {MODEL_NAME} on {device}")
 
 
 
 
90
 
91
+ local_dir = download_model_snapshot()
 
 
 
 
 
92
 
93
+ tokenizer = AutoTokenizer.from_pretrained(local_dir)
94
 
95
+ model = AutoModelForSequenceClassification.from_pretrained(
96
+ local_dir,
97
+ use_safetensors=True,
98
+ ignore_mismatched_sizes=True,
99
+ low_cpu_mem_usage=True,
100
+ ).to(device).eval()
101
 
102
+ return tokenizer, model
103
 
104
 
105
  # -----------------------------
 
112
  ABBR_REGEX = re.compile(r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", re.IGNORECASE)
113
 
114
 
115
+ def _protect(text):
116
  text = text.replace("...", "⟨ELLIPSIS⟩")
117
  text = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", text)
118
  text = ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", text)
119
  return text
120
 
121
 
122
+ def _restore(text):
123
+ return text.replace("⟨ABBRDOT⟩", ".").replace("⟨DECIMAL⟩", ".").replace("⟨ELLIPSIS⟩", "...")
 
 
 
 
124
 
125
 
126
+ def split_preserving_structure(text):
127
  blocks = re.split(r"(\n+)", text)
128
  final_blocks = []
129
  for block in blocks:
 
187
 
188
  batch_size = 8
189
  probs = []
 
190
  for i in range(0, len(windows), batch_size):
191
  batch = windows[i: i + batch_size]
192
+ inputs = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
 
 
 
 
 
 
 
193
  output = mod(**inputs)
194
 
195
  if output.logits.shape[1] > 1:
196
  batch_probs = F.softmax(output.logits, dim=-1)[:, 1].detach().cpu().numpy().tolist()
197
  else:
198
  batch_probs = torch.sigmoid(output.logits).detach().cpu().numpy().flatten().tolist()
 
199
  probs.extend(batch_probs)
200
 
201
  lengths = [len(s.split()) for s in pure_sents]
202
  total_words = sum(lengths)
203
+ weighted_avg = sum(p * l for p, l in zip(probs, lengths)) / total_words if total_words > 0 else 0
 
 
 
 
204
 
 
205
  highlighted_html = "<div style='font-family:sans-serif; line-height:1.8;'>"
206
  prob_map = {idx: probs[i] for i, idx in enumerate(pure_sents_indices)}
207
 
 
246
 
247
  with gr.Row():
248
  with gr.Column(scale=3):
249
+ text_input = gr.Textbox(label="Input Text", lines=15, placeholder="Enter at least 250 words...")
 
 
 
 
250
  with gr.Row():
251
  clear_btn = gr.Button("Clear")
252
  run_btn = gr.Button("Analyze Text", variant="primary")