VictorM-Coder committed on
Commit
785a5d2
·
verified ·
1 Parent(s): 3b0d005

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -68
app.py CHANGED
@@ -2,23 +2,26 @@ import os
2
  import re
3
  import shutil
4
 
5
- import torch
6
- import torch.nn.functional as F
7
- import pandas as pd
8
- import gradio as gr
9
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
10
-
11
  # ============================================================
12
- # ENV (set BEFORE loading models)
13
  # ============================================================
14
- # Use a predictable cache location (helps avoid reusing a corrupt home cache)
15
  os.environ.setdefault("HF_HOME", "/tmp/hf")
16
  os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/hf/hub")
17
  os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")
18
- # Disable Xet (helps avoid partial/corrupt downloads in some environments)
19
- os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
20
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
21
 
 
 
 
 
 
 
 
 
 
 
 
22
  # -----------------------------
23
  # MODEL INITIALIZATION
24
  # -----------------------------
@@ -26,10 +29,17 @@ MODEL_NAME = "desklib/ai-text-detector-v1.01"
26
  tokenizer = None
27
  model = None
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
-
30
  THRESHOLD = 0.59
31
 
32
 
 
 
 
 
 
 
 
 
33
  def wipe_model_cache(model_id: str) -> int:
34
  """
35
  Delete cached files for this model from common HF cache locations.
@@ -49,81 +59,104 @@ def wipe_model_cache(model_id: str) -> int:
49
  removed = 0
50
  for path in candidates:
51
  if os.path.exists(path):
52
- try:
53
- shutil.rmtree(path, ignore_errors=True)
54
- removed += 1
55
- except Exception:
56
- # ignore deletion errors (permissions etc.)
57
- pass
58
  return removed
59
 
60
 
61
- def _build_error_card(msg: str) -> str:
62
- return (
63
- "<div style='color:#b80d0d; padding:14px; border:1px solid #b80d0d; "
64
- "border-radius:10px; background:rgba(184,13,13,0.06);'>"
65
- f"{msg}</div>"
66
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
 
69
- def get_model(force_redownload: bool = False):
70
  """
71
- Normal load uses cache (fast).
72
- If force_redownload=True (from the Nuke button), we wipe cache + re-download.
 
 
 
73
  """
74
  global tokenizer, model
75
 
76
- if (not force_redownload) and model is not None and tokenizer is not None:
77
  return tokenizer, model
78
 
79
  if force_redownload:
80
- print("💣 NUKE requested: wiping cache + forcing re-download...")
81
  removed = wipe_model_cache(MODEL_NAME)
82
  print(f"🧹 Cache dirs removed: {removed}")
83
  tokenizer = None
84
  model = None
85
 
86
- print(f"🚀 Loading Model: {MODEL_NAME} on {device}")
 
 
87
 
88
- # Tokenizer
89
- tokenizer = AutoTokenizer.from_pretrained(
90
- MODEL_NAME,
 
91
  force_download=force_redownload,
92
  )
93
 
94
- # Model (prefer safetensors)
95
- model = AutoModelForSequenceClassification.from_pretrained(
96
- MODEL_NAME,
97
- use_safetensors=True,
98
- ignore_mismatched_sizes=True,
99
- low_cpu_mem_usage=True,
100
- force_download=force_redownload,
101
- ).to(device).eval()
 
 
 
 
 
 
 
102
 
 
103
  return tokenizer, model
104
 
105
 
106
  # -----------------------------
107
  # UTILITIES
108
  # -----------------------------
109
- ABBR = [
110
- "e.g", "i.e", "mr", "mrs", "ms", "dr", "prof", "vs", "etc",
111
- "fig", "al", "jr", "sr", "st", "inc", "ltd", "u.s", "u.k"
112
- ]
113
  ABBR_REGEX = re.compile(r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", re.IGNORECASE)
114
 
115
-
116
  def _protect(text):
117
  text = text.replace("...", "⟨ELLIPSIS⟩")
118
  text = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", text)
119
  text = ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", text)
120
  return text
121
 
122
-
123
  def _restore(text):
124
  return text.replace("⟨ABBRDOT⟩", ".").replace("⟨DECIMAL⟩", ".").replace("⟨ELLIPSIS⟩", "...")
125
 
126
-
127
  def split_preserving_structure(text):
128
  blocks = re.split(r"(\n+)", text)
129
  final_blocks = []
@@ -157,14 +190,11 @@ def analyze(text):
157
 
158
  word_count = len(text.split())
159
  if word_count < 250:
160
- warning_msg = (
161
- f"⚠️ <b>Insufficient Text:</b> Your input has {word_count} words. "
162
- f"Please enter at least 250 words for accurate results."
163
- )
164
  return "Too Short", "N/A", _build_error_card(warning_msg), None, ""
165
 
166
  try:
167
- tok, mod = get_model(force_redownload=False)
168
  except Exception as e:
169
  return "ERROR", "0%", _build_error_card(f"<b>Failed to load model:</b><br>{str(e)}"), None, ""
170
 
@@ -186,19 +216,16 @@ def analyze(text):
186
  for i in range(0, len(windows), batch_size):
187
  batch = windows[i: i + batch_size]
188
  inputs = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
189
- output = mod(**inputs)
190
-
191
- if output.logits.shape[1] > 1:
192
- batch_probs = F.softmax(output.logits, dim=-1)[:, 1].detach().cpu().numpy().tolist()
193
- else:
194
- batch_probs = torch.sigmoid(output.logits).detach().cpu().numpy().flatten().tolist()
195
  probs.extend(batch_probs)
196
 
197
  lengths = [len(s.split()) for s in pure_sents]
198
  total_words = sum(lengths)
199
  weighted_avg = sum(p * l for p, l in zip(probs, lengths)) / total_words if total_words > 0 else 0
200
 
201
- highlighted_html = "<div style='font-family:sans-serif; line-height:1.8;'>"
 
202
  prob_map = {idx: probs[i] for i, idx in enumerate(pure_sents_indices)}
203
 
204
  for i, block in enumerate(blocks):
@@ -216,9 +243,9 @@ def analyze(text):
216
  border = "1px solid transparent"
217
 
218
  highlighted_html += (
219
- f"<span style='background:{bg}; padding:1px 2px; border-radius:3px; border-bottom:{border}; cursor:help;' "
220
  f"title='AI Confidence: {score:.2%}'>"
221
- f"<span style='color:{color}; font-weight:bold; font-size:0.75em; vertical-align:super; margin-right:2px;'>{score:.0%}</span>"
222
  f"{block}</span>"
223
  )
224
  else:
@@ -234,24 +261,21 @@ def analyze(text):
234
 
235
 
236
  def nuke_and_reload():
237
- """
238
- UI button: wipe cache + force re-download + try to load.
239
- Returns a status message.
240
- """
241
  try:
242
- get_model(force_redownload=True)
243
  return (
244
  "✅ **Nuked cache and reloaded model successfully.**\n\n"
245
  "- Cache wiped\n"
246
  "- Fresh download forced\n"
 
247
  "- Model ready ✅"
248
  )
249
  except Exception as e:
250
  return (
251
  "❌ **Nuke attempted but model still failed to load.**\n\n"
252
  f"**Error:** `{str(e)}`\n\n"
253
- "If this keeps happening, it usually means the downloaded weights are getting truncated "
254
- "(network/storage) or the runtime stack (Python/Torch) is incompatible."
255
  )
256
 
257
 
 
2
  import re
3
  import shutil
4
 
 
 
 
 
 
 
5
  # ============================================================
6
+ # ENV (set BEFORE transformers/hub usage)
7
  # ============================================================
 
8
  os.environ.setdefault("HF_HOME", "/tmp/hf")
9
  os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/hf/hub")
10
  os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")
11
+ os.environ.setdefault("HF_HUB_DISABLE_XET", "1") # disable hf-xet if present
 
12
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
13
 
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import pandas as pd
18
+ import gradio as gr
19
+
20
+ from huggingface_hub import hf_hub_download
21
+ from transformers import AutoConfig, AutoTokenizer, AutoModel
22
+ from safetensors.torch import load_file
23
+
24
+
25
  # -----------------------------
26
  # MODEL INITIALIZATION
27
  # -----------------------------
 
29
  tokenizer = None
30
  model = None
31
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
32
  THRESHOLD = 0.59
33
 
34
 
35
+ def _build_error_card(msg: str) -> str:
36
+ return (
37
+ "<div style='color:#b80d0d; padding:14px; border:1px solid #b80d0d; "
38
+ "border-radius:10px; background:rgba(184,13,13,0.06);'>"
39
+ f"{msg}</div>"
40
+ )
41
+
42
+
43
  def wipe_model_cache(model_id: str) -> int:
44
  """
45
  Delete cached files for this model from common HF cache locations.
 
59
  removed = 0
60
  for path in candidates:
61
  if os.path.exists(path):
62
+ shutil.rmtree(path, ignore_errors=True)
63
+ removed += 1
 
 
 
 
64
  return removed
65
 
66
 
67
class DesklibAIDetectionModel(nn.Module):
    """AI-text detector matching the desklib/ai-text-detector-v1.01 checkpoint.

    Architecture: base transformer encoder + attention-masked mean pooling +
    a single linear head producing one logit per input (sigmoid(logit) is
    interpreted downstream as P(AI-written)).

    NOTE: the encoder attribute MUST be named ``model`` so the checkpoint's
    state-dict keys (``model.*`` / ``classifier.*``) line up in
    ``load_state_dict``. The previous name ``backbone`` made every encoder
    weight land in missing/unexpected under strict=False, silently leaving
    the encoder randomly initialized.
    """

    def __init__(self, config):
        super().__init__()
        # Encoder built from config only; pretrained weights are loaded
        # separately from model.safetensors by the loader.
        self.model = AutoModel.from_config(config)
        # Single-logit classification head.
        self.classifier = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state  # (B, T, H)

        if attention_mask is None:
            # No mask available: plain mean over the sequence dimension.
            pooled = last_hidden.mean(dim=1)
        else:
            # Mean over non-padding tokens only; clamp avoids divide-by-zero
            # for an all-zero mask row.
            mask = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
            summed = torch.sum(last_hidden * mask, dim=1)
            denom = torch.clamp(mask.sum(dim=1), min=1e-9)
            pooled = summed / denom

        logits = self.classifier(pooled)  # (B, 1)
        return logits
92
 
93
 
94
def load_desklib_model(force_redownload: bool = False):
    """Load (and memoize in module globals) the desklib tokenizer + model.

    Plain calls reuse the cached ``tokenizer``/``model`` globals.
    ``force_redownload=True`` (Nuke button) wipes the HF cache first, then
    forces fresh downloads of config, tokenizer and ``model.safetensors``.
    Weights are fetched explicitly via ``hf_hub_download`` and loaded with
    ``safetensors.torch.load_file`` into our module with ``strict=False``.
    """
    global tokenizer, model

    already_loaded = tokenizer is not None and model is not None
    if already_loaded and not force_redownload:
        return tokenizer, model

    if force_redownload:
        print("💣 NUKE requested: wiping cache + forcing fresh downloads...")
        removed = wipe_model_cache(MODEL_NAME)
        print(f"🧹 Cache dirs removed: {removed}")
        tokenizer, model = None, None

    print(f"🚀 Loading tokenizer/config: {MODEL_NAME}")
    config = AutoConfig.from_pretrained(MODEL_NAME, force_download=force_redownload)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, force_download=force_redownload)

    print("⬇️ Downloading model.safetensors explicitly...")
    weights_path = hf_hub_download(
        repo_id=MODEL_NAME,
        filename="model.safetensors",
        force_download=force_redownload,
    )

    # Log the on-disk size so truncated downloads are easy to spot.
    size_gb = os.path.getsize(weights_path) / (1024**3)
    print(f"✅ model.safetensors path: {weights_path}")
    print(f"✅ model.safetensors size: {size_gb:.2f} GB")

    print("🧠 Building DesklibAIDetectionModel + loading weights...")
    net = DesklibAIDetectionModel(config)
    state = load_file(weights_path)  # raises if the file is genuinely corrupt
    missing, unexpected = net.load_state_dict(state, strict=False)

    # Debug aid only — mismatched keys are reported, never fatal.
    if missing:
        print(f"⚠️ Missing keys (first 20): {missing[:20]}")
    if unexpected:
        print(f"⚠️ Unexpected keys (first 20): {unexpected[:20]}")

    model = net.to(device).eval()
    return tokenizer, model
143
 
144
 
145
  # -----------------------------
146
  # UTILITIES
147
  # -----------------------------
148
+ ABBR = ["e.g", "i.e", "mr", "mrs", "ms", "dr", "prof", "vs", "etc", "fig", "al", "jr", "sr", "st", "inc", "ltd", "u.s", "u.k"]
 
 
 
149
  ABBR_REGEX = re.compile(r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", re.IGNORECASE)
150
 
 
151
def _protect(text):
    """Mask dots that do not end a sentence (ellipses, decimals, abbreviations)."""
    masked = text.replace("...", "⟨ELLIPSIS⟩")
    masked = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", masked)
    return ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", masked)
156
 
 
157
  def _restore(text):
158
  return text.replace("⟨ABBRDOT⟩", ".").replace("⟨DECIMAL⟩", ".").replace("⟨ELLIPSIS⟩", "...")
159
 
 
160
  def split_preserving_structure(text):
161
  blocks = re.split(r"(\n+)", text)
162
  final_blocks = []
 
190
 
191
  word_count = len(text.split())
192
  if word_count < 250:
193
+ warning_msg = f"⚠️ <b>Insufficient Text:</b> Your input has {word_count} words. Please enter at least 250 words for accurate results."
 
 
 
194
  return "Too Short", "N/A", _build_error_card(warning_msg), None, ""
195
 
196
  try:
197
+ tok, mod = load_desklib_model(force_redownload=False)
198
  except Exception as e:
199
  return "ERROR", "0%", _build_error_card(f"<b>Failed to load model:</b><br>{str(e)}"), None, ""
200
 
 
216
  for i in range(0, len(windows), batch_size):
217
  batch = windows[i: i + batch_size]
218
  inputs = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
219
+ logits = mod(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask"))
220
+ batch_probs = torch.sigmoid(logits).detach().cpu().numpy().flatten().tolist()
 
 
 
 
221
  probs.extend(batch_probs)
222
 
223
  lengths = [len(s.split()) for s in pure_sents]
224
  total_words = sum(lengths)
225
  weighted_avg = sum(p * l for p, l in zip(probs, lengths)) / total_words if total_words > 0 else 0
226
 
227
+ # HTML Heatmap
228
+ highlighted_html = "<div style='font-family: sans-serif; line-height: 1.8;'>"
229
  prob_map = {idx: probs[i] for i, idx in enumerate(pure_sents_indices)}
230
 
231
  for i, block in enumerate(blocks):
 
243
  border = "1px solid transparent"
244
 
245
  highlighted_html += (
246
+ f"<span style='background:{bg}; padding:1px 2px; border-radius:3px; border-bottom: {border}; cursor: help;' "
247
  f"title='AI Confidence: {score:.2%}'>"
248
+ f"<span style='color:{color}; font-weight: bold; font-size: 0.75em; vertical-align: super; margin-right: 2px;'>{score:.0%}</span>"
249
  f"{block}</span>"
250
  )
251
  else:
 
261
 
262
 
263
def nuke_and_reload():
    """UI handler: wipe the HF cache, force a fresh download, and reload.

    Returns a Markdown status string for the Gradio UI; never raises.
    """
    success_msg = (
        "✅ **Nuked cache and reloaded model successfully.**\n\n"
        "- Cache wiped\n"
        "- Fresh download forced\n"
        "- Custom loader used (DesklibAIDetectionModel)\n"
        "- Model ready ✅"
    )
    try:
        load_desklib_model(force_redownload=True)
    except Exception as e:
        return (
            "❌ **Nuke attempted but model still failed to load.**\n\n"
            f"**Error:** `{str(e)}`\n\n"
            "If this error happens inside `load_file(model.safetensors)`, the file is truly corrupted/truncated.\n"
            "If it happens after that, it’s likely key mismatches (shown in logs as missing/unexpected keys)."
        )
    return success_msg
280
 
281