Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,91 +1,105 @@
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import re
|
| 3 |
import shutil
|
| 4 |
import torch
|
| 5 |
import torch.nn.functional as F
|
| 6 |
import pandas as pd
|
| 7 |
import gradio as gr
|
|
|
|
|
|
|
| 8 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 9 |
|
| 10 |
|
| 11 |
# -----------------------------
|
| 12 |
# MODEL INITIALIZATION
|
| 13 |
# -----------------------------
|
| 14 |
-
MODEL_NAME = "desklib/ai-text-detector-v1.
|
|
|
|
| 15 |
tokenizer = None
|
| 16 |
model = None
|
| 17 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 18 |
|
|
|
|
| 19 |
|
| 20 |
-
def purge_model_cache(model_id: str) -> None:
|
| 21 |
-
"""
|
| 22 |
-
Remove cached weights/tokenizer for this model from common HF cache locations.
|
| 23 |
-
This fixes the 'state dictionary ... corrupted' error caused by partial downloads.
|
| 24 |
-
"""
|
| 25 |
-
safe = model_id.replace("/", "--")
|
| 26 |
-
|
| 27 |
-
candidates = [
|
| 28 |
-
os.path.expanduser(f"~/.cache/huggingface/hub/models--{safe}"),
|
| 29 |
-
os.path.expanduser(f"~/.cache/huggingface/transformers/models--{safe}"),
|
| 30 |
-
os.path.expanduser(f"~/.cache/huggingface/modules/models--{safe}"),
|
| 31 |
-
]
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
shutil.rmtree(path, ignore_errors=True)
|
| 37 |
-
print(f"🧹 Removed cache: {path}")
|
| 38 |
-
except Exception as e:
|
| 39 |
-
print(f"⚠️ Failed to remove cache at {path}: {e}")
|
| 40 |
|
| 41 |
|
| 42 |
-
def
|
| 43 |
"""
|
| 44 |
-
|
| 45 |
-
|
| 46 |
"""
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
| 63 |
|
| 64 |
-
|
| 65 |
|
| 66 |
-
except Exception as e:
|
| 67 |
-
print(f"⚠️ Initial load failed: {e}")
|
| 68 |
-
print("🔁 Attempting recovery: purge cache + force re-download...")
|
| 69 |
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
-
|
| 73 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
| 74 |
-
MODEL_NAME,
|
| 75 |
-
force_download=True,
|
| 76 |
-
)
|
| 77 |
|
| 78 |
-
|
| 79 |
-
MODEL_NAME,
|
| 80 |
-
use_safetensors=True, # ✅ keep safetensors on recovery too
|
| 81 |
-
ignore_mismatched_sizes=True,
|
| 82 |
-
force_download=True,
|
| 83 |
-
).to(device).eval()
|
| 84 |
|
| 85 |
-
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
-
|
| 89 |
|
| 90 |
|
| 91 |
# -----------------------------
|
|
@@ -98,22 +112,18 @@ ABBR = [
|
|
| 98 |
ABBR_REGEX = re.compile(r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", re.IGNORECASE)
|
| 99 |
|
| 100 |
|
| 101 |
-
def _protect(text
|
| 102 |
text = text.replace("...", "⟨ELLIPSIS⟩")
|
| 103 |
text = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", text)
|
| 104 |
text = ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", text)
|
| 105 |
return text
|
| 106 |
|
| 107 |
|
| 108 |
-
def _restore(text
|
| 109 |
-
return (
|
| 110 |
-
text.replace("⟨ABBRDOT⟩", ".")
|
| 111 |
-
.replace("⟨DECIMAL⟩", ".")
|
| 112 |
-
.replace("⟨ELLIPSIS⟩", "...")
|
| 113 |
-
)
|
| 114 |
|
| 115 |
|
| 116 |
-
def split_preserving_structure(text
|
| 117 |
blocks = re.split(r"(\n+)", text)
|
| 118 |
final_blocks = []
|
| 119 |
for block in blocks:
|
|
@@ -177,35 +187,21 @@ def analyze(text):
|
|
| 177 |
|
| 178 |
batch_size = 8
|
| 179 |
probs = []
|
| 180 |
-
|
| 181 |
for i in range(0, len(windows), batch_size):
|
| 182 |
batch = windows[i: i + batch_size]
|
| 183 |
-
inputs = tok(
|
| 184 |
-
batch,
|
| 185 |
-
return_tensors="pt",
|
| 186 |
-
padding=True,
|
| 187 |
-
truncation=True,
|
| 188 |
-
max_length=512,
|
| 189 |
-
).to(device)
|
| 190 |
-
|
| 191 |
output = mod(**inputs)
|
| 192 |
|
| 193 |
if output.logits.shape[1] > 1:
|
| 194 |
batch_probs = F.softmax(output.logits, dim=-1)[:, 1].detach().cpu().numpy().tolist()
|
| 195 |
else:
|
| 196 |
batch_probs = torch.sigmoid(output.logits).detach().cpu().numpy().flatten().tolist()
|
| 197 |
-
|
| 198 |
probs.extend(batch_probs)
|
| 199 |
|
| 200 |
lengths = [len(s.split()) for s in pure_sents]
|
| 201 |
total_words = sum(lengths)
|
| 202 |
-
weighted_avg = (
|
| 203 |
-
sum(p * l for p, l in zip(probs, lengths)) / total_words
|
| 204 |
-
if total_words > 0
|
| 205 |
-
else 0
|
| 206 |
-
)
|
| 207 |
|
| 208 |
-
# HTML Heatmap
|
| 209 |
highlighted_html = "<div style='font-family:sans-serif; line-height:1.8;'>"
|
| 210 |
prob_map = {idx: probs[i] for i, idx in enumerate(pure_sents_indices)}
|
| 211 |
|
|
@@ -250,11 +246,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI Detector Pro") as demo:
|
|
| 250 |
|
| 251 |
with gr.Row():
|
| 252 |
with gr.Column(scale=3):
|
| 253 |
-
text_input = gr.Textbox(
|
| 254 |
-
label="Input Text",
|
| 255 |
-
lines=15,
|
| 256 |
-
placeholder="Enter at least 250 words..."
|
| 257 |
-
)
|
| 258 |
with gr.Row():
|
| 259 |
clear_btn = gr.Button("Clear")
|
| 260 |
run_btn = gr.Button("Analyze Text", variant="primary")
|
|
|
|
| 1 |
import os

# ============================================================
# FIX FOR "state dict corrupted" ON SPACES (Xet downloads)
# ============================================================
# Disable hf-xet so large files are fetched via the LFS bridge instead;
# per the HF Hub docs, HF_HUB_DISABLE_XET turns off the hf-xet backend.
# This works around truncated/corrupted safetensors downloads seen on Spaces.
# Must be set BEFORE huggingface_hub/transformers are imported below.
os.environ["HF_HUB_DISABLE_XET"] = "1"

# Optional: place the HF cache in a writable/temp location (Spaces friendly).
# Comment these out to fall back to the default cache locations.
os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/hf/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")

# Optional: avoid tokenizer fork/parallelism warnings and races.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")


import re
import shutil
import torch
import torch.nn.functional as F
import pandas as pd
import gradio as gr

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# -----------------------------
# MODEL INITIALIZATION
# -----------------------------
MODEL_NAME = "desklib/ai-text-detector-v1.01"
LOCAL_MODEL_DIR = "/tmp/desklib_ai_text_detector_v1_01"  # local snapshot dir
# Lazy-loaded singletons; populated by get_model() on first use.
tokenizer = None
model = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Probability cutoff for flagging text as AI-generated.
# NOTE(review): 0.59 looks empirically tuned — confirm against validation data.
THRESHOLD = 0.59
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
def _rm_dir(path: str) -> None:
|
| 44 |
+
if os.path.exists(path):
|
| 45 |
+
shutil.rmtree(path, ignore_errors=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
+
def download_model_snapshot() -> str:
    """Download a clean snapshot of MODEL_NAME into LOCAL_MODEL_DIR.

    Wipes any previous local copy first and forces a fresh download, then
    sanity-checks that ``model.safetensors`` exists and is not obviously
    truncated. With HF_HUB_DISABLE_XET=1 set at module import, this avoids
    the corrupted-large-file downloads seen on some environments.

    Returns:
        Path to the local snapshot directory.

    Raises:
        RuntimeError: if ``model.safetensors`` is missing or suspiciously small.
    """
    # Wipe the local dir to ensure a truly clean download.
    _rm_dir(LOCAL_MODEL_DIR)

    print(f"⬇️ Downloading snapshot for: {MODEL_NAME}")
    # NOTE: ``local_dir_use_symlinks`` and ``resume_download`` are deprecated
    # no-ops in current huggingface_hub (local_dir downloads no longer
    # symlink, and force_download=True already implies a full re-fetch),
    # so they are intentionally omitted here.
    local_dir = snapshot_download(
        repo_id=MODEL_NAME,
        local_dir=LOCAL_MODEL_DIR,
        force_download=True,
    )

    # Basic integrity sanity check: make sure model.safetensors looks real.
    st_path = os.path.join(local_dir, "model.safetensors")
    if not os.path.exists(st_path):
        raise RuntimeError(f"model.safetensors not found in snapshot at: {st_path}")

    size_gb = os.path.getsize(st_path) / (1024**3)
    print(f"✅ model.safetensors size: {size_gb:.2f} GB")

    # The upstream repo's model.safetensors is ~1.74 GB; a drastically
    # smaller file almost certainly means a truncated download.
    if size_gb < 1.0:
        raise RuntimeError(
            f"Downloaded model.safetensors looks too small ({size_gb:.2f} GB). "
            "Likely truncated download."
        )

    return local_dir
|
| 82 |
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+
def get_model():
    """Return a ``(tokenizer, model)`` pair, loading both on first use.

    The loaded objects are cached in the module-level ``tokenizer`` and
    ``model`` globals, so subsequent calls return immediately.
    """
    global tokenizer, model

    # Fast path: both singletons already initialized.
    if tokenizer is not None and model is not None:
        return tokenizer, model

    print(f"🚀 Loading Model: {MODEL_NAME} on {device}")

    # Fetch a verified-clean local snapshot, then load strictly from disk.
    snapshot_dir = download_model_snapshot()

    tokenizer = AutoTokenizer.from_pretrained(snapshot_dir)
    loaded = AutoModelForSequenceClassification.from_pretrained(
        snapshot_dir,
        use_safetensors=True,
        ignore_mismatched_sizes=True,
        low_cpu_mem_usage=True,
    )
    model = loaded.to(device).eval()

    return tokenizer, model
|
| 103 |
|
| 104 |
|
| 105 |
# -----------------------------
|
|
|
|
| 112 |
ABBR_REGEX = re.compile(r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", re.IGNORECASE)
|
| 113 |
|
| 114 |
|
| 115 |
+
def _protect(text):
    """Mask dots that are not sentence terminators with sentinel tokens.

    Ellipses, decimal points, and abbreviation dots (via ABBR_REGEX) are
    replaced so a later sentence splitter will not break on them; the
    transformation is reversed by ``_restore``.
    """
    # Order matters: mask ellipses before the decimal/abbreviation passes.
    masked = text.replace("...", "⟨ELLIPSIS⟩")
    masked = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", masked)
    return ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", masked)
|
| 120 |
|
| 121 |
|
| 122 |
+
def _restore(text):
|
| 123 |
+
return text.replace("⟨ABBRDOT⟩", ".").replace("⟨DECIMAL⟩", ".").replace("⟨ELLIPSIS⟩", "...")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
|
| 126 |
+
def split_preserving_structure(text):
|
| 127 |
blocks = re.split(r"(\n+)", text)
|
| 128 |
final_blocks = []
|
| 129 |
for block in blocks:
|
|
|
|
| 187 |
|
| 188 |
batch_size = 8
|
| 189 |
probs = []
|
|
|
|
| 190 |
for i in range(0, len(windows), batch_size):
|
| 191 |
batch = windows[i: i + batch_size]
|
| 192 |
+
inputs = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
output = mod(**inputs)
|
| 194 |
|
| 195 |
if output.logits.shape[1] > 1:
|
| 196 |
batch_probs = F.softmax(output.logits, dim=-1)[:, 1].detach().cpu().numpy().tolist()
|
| 197 |
else:
|
| 198 |
batch_probs = torch.sigmoid(output.logits).detach().cpu().numpy().flatten().tolist()
|
|
|
|
| 199 |
probs.extend(batch_probs)
|
| 200 |
|
| 201 |
lengths = [len(s.split()) for s in pure_sents]
|
| 202 |
total_words = sum(lengths)
|
| 203 |
+
weighted_avg = sum(p * l for p, l in zip(probs, lengths)) / total_words if total_words > 0 else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
|
|
|
| 205 |
highlighted_html = "<div style='font-family:sans-serif; line-height:1.8;'>"
|
| 206 |
prob_map = {idx: probs[i] for i, idx in enumerate(pure_sents_indices)}
|
| 207 |
|
|
|
|
| 246 |
|
| 247 |
with gr.Row():
|
| 248 |
with gr.Column(scale=3):
|
| 249 |
+
text_input = gr.Textbox(label="Input Text", lines=15, placeholder="Enter at least 250 words...")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
with gr.Row():
|
| 251 |
clear_btn = gr.Button("Clear")
|
| 252 |
run_btn = gr.Button("Analyze Text", variant="primary")
|