File size: 11,123 Bytes
c6f65f2 4cf9509 cfaed4d 4cf9509 cfaed4d c6f65f2 4cf9509 c6f65f2 4cf9509 c6f65f2 cfaed4d c6f65f2 4cf9509 c6f65f2 4cf9509 c6f65f2 4cf9509 c6f65f2 4cf9509 cfaed4d c6f65f2 cfaed4d c6f65f2 4cf9509 c6f65f2 4cf9509 c6f65f2 cfaed4d 4cf9509 cfaed4d c6f65f2 4cf9509 c6f65f2 cfaed4d 4cf9509 c6f65f2 4cf9509 c6f65f2 4cf9509 c6f65f2 4cf9509 c6f65f2 4cf9509 c6f65f2 4cf9509 c6f65f2 4cf9509 c6f65f2 4cf9509 cfaed4d 4cf9509 cfaed4d 4d68a43 4cf9509 2d4eea4 4cf9509 c6f65f2 4cf9509 cfaed4d 4d68a43 cfaed4d 4d68a43 cfaed4d 4cf9509 c6f65f2 4cf9509 c6f65f2 4cf9509 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Gradio App — AI vs Human Document Classifier (Chunked Inference)
----------------------------------------------------------------
Features:
- Upload a document (TXT/MD/RTF/HTML/PDF); long documents are split into
  overlapping token windows, each window is classified, and the window scores
  are aggregated into a document-level result.
- UI includes:
  1) Probability bars with raw percentages (AI generated / Human written)
  2) Verdict badge ("Likely AI" / "Likely Human"), red for AI and green for Human
  3) Tabs for prediction ("Predict") and chunking controls ("Advanced")
  4) Chunk details accordion with per-chunk probabilities
  5) A short per-chunk text snippet recovered from the tokenizer's offset_mapping
"""
import html
import io
import os
import re
from typing import Dict, Any, List, Tuple

import numpy as np
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# -----------------------------
# Config
# -----------------------------
# All three settings can be overridden through environment variables
# (e.g. Hugging Face Space Variables).
# NOTE(review): "bert-base-uncased" is a base checkpoint, not an AI-vs-human
# classifier; unless MODEL_ID points at a fine-tuned model, the classification
# head is freshly initialized and the scores are not meaningful. chunked_predict
# assumes label 0 = Human and label 1 = AI; confirm against the chosen model.
MODEL_ID = os.getenv("MODEL_ID", "bert-base-uncased")
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
STRIDE = int(os.getenv("STRIDE", "128"))


def _pick_device() -> torch.device:
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps does not exist on older torch builds; treat that as "no MPS".
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Device
device = _pick_device()
if device.type == "mps":
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        # Best effort only; presumably guards torch builds without this API.
        pass

# Load model & tokenizer. A fast tokenizer is required because chunked_predict
# requests return_offsets_mapping.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, torch_dtype=torch.float32).to(device)
model.eval()  # inference only: disables dropout
# -----------------------------
# Utilities
# -----------------------------
TEXT_EXTS = {".txt", ".md", ".rtf", ".html", ".htm"}
PDF_EXTS = {".pdf"}
_HTML_EXTS = {".html", ".htm"}
# <script>/<style> elements, including their contents, which are not visible text.
_HTML_NON_TEXT_RE = re.compile(r"<(script|style)\b[^>]*>.*?</\1\s*>", re.IGNORECASE | re.DOTALL)
_HTML_TAG_RE = re.compile(r"<[^>]+>")
_WHITESPACE_RE = re.compile(r"\s+")


def _decode_upload(data) -> str:
    """Decode uploaded bytes as UTF-8, dropping a leading BOM and undecodable bytes; pass str through."""
    if isinstance(data, bytes):
        return data.decode("utf-8-sig", errors="ignore")
    return data


def read_text_from_file(file_obj) -> str:
    """
    Read the text content of an uploaded file-like object.

    Dispatch is on the extension of ``file_obj.name``:
    - .txt/.md/.rtf/.html/.htm are decoded as UTF-8. For HTML, script/style
      blocks and tags are removed and entities (``&amp;`` etc.) are unescaped.
      Whitespace is collapsed to single spaces.
      NOTE(review): .rtf is read as plain text; RTF control words are not stripped.
    - .pdf is parsed with pypdf (imported lazily). Page text is joined and the
      whitespace collapsed. Any failure, including pypdf missing, is returned
      in-band as ``"[PDF parse error] <exception>"`` instead of being raised.
    - Anything else is decoded as UTF-8 and returned without normalization.

    Args:
        file_obj: Object with ``.read()`` returning bytes or str, and optionally ``.name``.

    Returns:
        The extracted text.
    """
    name = getattr(file_obj, "name", "") or ""
    if not isinstance(name, str):
        name = ""  # e.g. an integer file descriptor; no usable extension
    ext = os.path.splitext(name)[-1].lower()

    if ext in TEXT_EXTS:
        data = _decode_upload(file_obj.read())
        if ext in _HTML_EXTS:
            data = _HTML_NON_TEXT_RE.sub(" ", data)
            data = _HTML_TAG_RE.sub(" ", data)
            # Unescape after tag removal so escaped "&lt;tag&gt;" stays literal text.
            data = html.unescape(data)
        return _WHITESPACE_RE.sub(" ", data).strip()

    if ext in PDF_EXTS:
        try:
            from pypdf import PdfReader
            reader = PdfReader(io.BytesIO(file_obj.read()))
            pages = []
            for page in reader.pages:
                try:
                    pages.append(page.extract_text() or "")
                except Exception:
                    # Best effort: one unreadable page should not sink the whole document.
                    pages.append("")
            return _WHITESPACE_RE.sub(" ", "\n".join(pages)).strip()
        except Exception as e:
            return f"[PDF parse error] {e}"

    # Fallback: try as text
    return _decode_upload(file_obj.read())
def _effective_window(max_length: int, stride: int) -> Tuple[int, int]:
    """Clamp a requested (max_length, stride) pair to values the tokenizer accepts.

    - ``max_length`` is capped at ``tokenizer.model_max_length`` when the
      tokenizer declares a real limit. Longer windows would exceed the model's
      position embeddings and fail at inference.
    - ``stride`` must be strictly smaller than the content part of a window
      (``max_length`` minus the special tokens the tokenizer adds). Otherwise
      overflow-window tokenization errors out.
    """
    limit = getattr(tokenizer, "model_max_length", None)
    # HF stores a huge sentinel (int(1e30)) when a checkpoint declares no limit.
    if isinstance(limit, int) and 0 < limit < 1_000_000:
        max_length = min(max_length, limit)
    content_len = max_length - tokenizer.num_special_tokens_to_add(pair=False)
    stride = max(0, min(stride, content_len - 1))
    return max_length, stride


def _chunk_snippet(text: str, offsets: List[List[int]], mask: List[int], preview: int = 120) -> str:
    """Return a whitespace-normalized preview of the text span one chunk covers.

    Tokens with attention mask 0 (padding) and tokens with a (0, 0) character
    offset (which HF fast tokenizers use for special tokens) are ignored.
    Previews longer than ``preview`` characters are cut and end with "…".
    """
    spans = [(s, e) for (s, e), m in zip(offsets, mask) if m == 1 and not (s == 0 and e == 0)]
    if not spans:
        return ""
    start = min(s for s, _ in spans)
    end = max(e for _, e in spans)
    snippet = " ".join(text[start:end].split())
    if len(snippet) > preview:
        snippet = snippet[:preview].rstrip() + "…"
    return snippet


def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: str = "mean",
                    batch_size: int = 8) -> Dict[str, Any]:
    """
    Classify ``text`` as AI-generated vs human-written using overlapping chunks.

    The document is tokenized into overlapping windows (tokenizer overflow).
    Each window is scored by the classifier, and the window probabilities are
    aggregated into a document-level result. A short text snippet per window is
    recovered from the tokenizer's ``offset_mapping``.

    Args:
        text: Full document text.
        max_length: Tokens per window, including special tokens. Capped at the
            tokenizer's declared ``model_max_length``.
        stride: Overlap in tokens between consecutive windows. Clamped to be
            smaller than a window's content length.
        agg: ``"mean"`` averages the chunk distributions. ``"max"`` reports the
            distribution of the single most AI-like chunk. Any other value is
            treated as ``"mean"``.
        batch_size: Number of windows per forward pass.

    Returns:
        On success, a dict with ``ai_prob``, ``human_prob``, ``num_chunks``,
        ``chunk_rows`` (``[chunk#, AI prob, Human prob, snippet]`` per window),
        and the effective ``max_length`` / ``stride`` actually used.
        On failure, ``{"error": message}``.
    """
    if not text or not text.strip():
        return {"error": "Empty document."}

    max_length, stride = _effective_window(int(max_length), int(stride))
    batch_size = max(1, int(batch_size))

    enc = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_overflowing_tokens=True,
        stride=stride,
        padding=True,
        return_offsets_mapping=True,  # character span of every token, used for snippets
        return_tensors="pt",
    )
    # The encoding also carries offset_mapping / overflow_to_sample_mapping,
    # which are not model inputs.
    allowed = {"input_ids", "attention_mask", "token_type_ids"}
    inputs = {k: v.to(model.device) for k, v in enc.items() if k in allowed}
    num_chunks = int(inputs["input_ids"].size(0))

    logits_list = []
    with torch.no_grad():
        for start in range(0, num_chunks, batch_size):
            batch = {k: v[start:start + batch_size] for k, v in inputs.items()}
            logits_list.append(model(**batch).logits)
    logits = torch.cat(logits_list, dim=0)  # [num_chunks, num_labels]
    probs = torch.softmax(logits.float(), dim=-1).cpu().numpy()

    if probs.shape[-1] < 2:
        return {"error": f"Model has {probs.shape[-1]} output label(s); expected at least 2 (Human, AI)."}

    # By convention: 0 -> Human, 1 -> AI
    if agg == "max":
        # Use the full distribution of the most AI-like chunk, so both reported
        # numbers come from the same chunk. A column-wise max could make them
        # sum to more than 1.
        doc_probs = probs[int(probs[:, 1].argmax())]
    else:
        doc_probs = probs.mean(axis=0)
    prob_human = float(doc_probs[0])
    prob_ai = float(doc_probs[1])

    offsets = enc["offset_mapping"].tolist()  # per chunk: list of [start, end] per token
    masks = enc["attention_mask"].tolist()
    chunk_rows: List[List[Any]] = [
        [i + 1, float(p[1]), float(p[0]), _chunk_snippet(text, offsets[i], masks[i])]
        for i, p in enumerate(probs)
    ]

    return {
        "ai_prob": prob_ai,
        "human_prob": prob_human,
        "num_chunks": num_chunks,
        "chunk_rows": chunk_rows,  # list of [chunk, AI, Human, Snippet]
        "max_length": max_length,  # effective value after clamping
        "stride": stride,  # effective value after clamping
    }
def predict_from_upload(file, aggregation, max_length, stride):
    """
    Read an uploaded document and run chunked classification on it.

    ``file`` is whatever gr.File hands to the callback. Depending on the Gradio
    version this is a filepath string (possibly a str subclass) or a
    tempfile-like object exposing ``.name``. Any other object with ``.read()``
    is read directly.

    Returns:
        The dict produced by ``chunked_predict``, or ``{"error": message}``.
    """
    if file is None:
        return {"error": "Please upload a file."}

    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
    else:
        path = getattr(file, "name", None)

    if isinstance(path, str) and os.path.isfile(path):
        # Work around gradio temp file behavior: re-read from disk into memory
        # instead of relying on the state of Gradio's file handle.
        with open(path, "rb") as fh:
            buffer = io.BytesIO(fh.read())
        buffer.name = os.path.basename(path)  # read_text_from_file dispatches on the extension
        text = read_text_from_file(buffer)
    elif hasattr(file, "read"):
        text = read_text_from_file(file)
    else:
        return {"error": "Could not read the uploaded file."}

    # read_text_from_file reports PDF failures in-band; do not classify the error text itself.
    if text.startswith("[PDF parse error]"):
        return {"error": text}

    return chunked_predict(text, max_length=int(max_length), stride=int(stride), agg=aggregation)
# -----------------------------
# UI Helpers (HTML formatting)
# -----------------------------
def probability_bar_html(label: str, prob: float) -> str:
    """Build one probability row as an HTML fragment.

    The row shows ``label`` in bold and ``prob`` as a percentage with two
    decimals. It also has a bar whose fill width is that same percentage.
    Styling comes from the ``prob-*`` CSS classes.
    """
    percent = f"{prob * 100.0:.2f}%"
    lines = [
        "",
        f'<div class="prob-row"><div class="prob-label"><b>{label}</b></div>',
        f'<div class="prob-value">{percent}</div>',
        '<div class="prob-bar">',
        f'<div class="prob-fill" style="width:{percent}"></div>',
        "</div>",
        "</div>",
        "",
    ]
    return "\n".join(lines)
def verdict_badge_html(prob_ai: float, threshold: float = 0.5) -> str:
    """Return a colored pill with the verdict for an AI probability.

    Probabilities at or above ``threshold`` are labelled "Likely AI" (red).
    Anything else is "Likely Human" (green). The background is the text color
    with a ``22`` alpha suffix, giving a light tint.
    """
    if prob_ai >= threshold:
        verdict, hue = "Likely AI", "#ef4444"  # red
    else:
        verdict, hue = "Likely Human", "#10b981"  # green
    return f"<span class='pill' style='background:{hue}22;color:{hue}'>{verdict}</span>"
def format_outputs(result: Dict[str, Any], threshold: float = 0.5):
    """Convert a prediction result into the values for the four output components.

    Args:
        result: Dict that either contains an ``"error"`` key, or the keys
            ``ai_prob``, ``human_prob``, ``num_chunks``, ``chunk_rows``,
            ``max_length`` and ``stride``.
        threshold: AI probability at or above which the verdict is "Likely AI".

    Returns:
        ``(verdict_html, probs_html, chunk_table_data, details_md)``. On error,
        the first slot holds the HTML-escaped message in red and the other
        three are empty.
    """
    if "error" in result:
        # The message is arbitrary text interpolated into HTML, so escape it.
        message = html.escape(str(result["error"]))
        return f"<span style='color:#ef4444'>{message}</span>", "", [], ""

    ai, human = result["ai_prob"], result["human_prob"]
    verdict = verdict_badge_html(ai, threshold=threshold)
    bars = probability_bar_html("AI generated", ai) + probability_bar_html("Human written", human)

    # Chunk table rows are already built server-side as [chunk, AI, Human, Snippet].
    table_data = result["chunk_rows"]
    # Two trailing spaces before "\n" force a Markdown hard line break;
    # otherwise the three facts render on a single line.
    details = (
        f"**Chunks:** `{result['num_chunks']}`  \n"
        f"**Tokens per chunk:** `{result['max_length']}`  \n"
        f"**Stride:** `{result['stride']}`"
    )
    return verdict, bars, table_data, details
# -----------------------------
# Gradio Interface
# -----------------------------
# Stylesheet passed to gr.Blocks(css=...). It styles the verdict pill, the
# probability rows, and the elements tagged with the #chunkgroup /
# #details_note element ids.
# NOTE(review): .small-note appears unused in this file.
CSS = (
    "\n"
    ".pill {padding:6px 12px; border-radius:999px; display:inline-block; margin: 6px 0; font-weight:600;}\n"
    ".prob-row {display:flex; align-items:center; gap:10px; margin:6px 0;}\n"
    ".prob-label {min-width:140px;}\n"
    ".prob-value {min-width:80px; text-align:right; font-variant-numeric: tabular-nums;}\n"
    ".prob-bar {flex:1; background:#e5e7eb; height:12px; border-radius:6px; overflow:hidden;}\n"
    ".prob-fill {height:12px; background:#6366f1;}\n"
    ".small-note {font-size:0.9rem; color:#6b7280;}\n"
    "/* Wrap long snippet text within the DataFrame cells */\n"
    ".gr-dataframe table td { white-space: normal; }\n"
    "/* Scrollable chunk table container */\n"
    "#chunkgroup { max-height: 260px; overflow: auto; }\n"
    "#details_note { font-size: 0.9rem; color: #6b7280; }\n"
)
# Markdown rendered at the top of the app.
DESCRIPTION = (
    "\n"
    "### 🔎 AI vs Human — Document Classifier\n"
    "Upload a file to get **document-level probabilities**.\n"
    "Long inputs are **chunked** into overlapping windows; chunk predictions are **aggregated**.\n"
)
with gr.Blocks(
    title="AI vs Human Document Classifier",
    theme='Nymbo/rounded-gradient',
    css=CSS
) as demo:
    gr.Markdown(DESCRIPTION)

    # Upper bound for the window slider is the tokenizer's declared limit, never
    # above 1024. HF stores a huge sentinel when a checkpoint declares no limit,
    # and min() absorbs it. Longer windows would overflow the model's position
    # embeddings.
    _ui_max_len = max(128, min(1024, int(getattr(tokenizer, "model_max_length", 1024) or 1024)))
    _ui_default_len = min(max(MAX_LENGTH, 128), _ui_max_len)

    with gr.Tabs():
        with gr.Tab("Predict"):
            # Offer exactly the extensions read_text_from_file knows how to parse.
            file_in = gr.File(label="Upload a document", file_types=sorted(TEXT_EXTS | PDF_EXTS))
            agg_in = gr.Radio(choices=["mean", "max"], value="mean", label="Aggregation over chunks")
            btn = gr.Button("Predict", variant="primary")
            verdict_html = gr.HTML(label="Verdict")
            probs_html = gr.HTML(label="Probabilities")
            with gr.Accordion("Chunk details", open=False):
                # elem_id ties into the #chunkgroup CSS rule (scrollable container).
                with gr.Group(elem_id="chunkgroup"):
                    chunk_table = gr.Dataframe(
                        headers=["Chunk", "AI generated", "Human written", "Snippet"],
                        datatype=["number", "number", "number", "str"],
                        label="Per-chunk probabilities",
                        wrap=True,
                        interactive=False,
                        row_count=(0, "dynamic"),
                        col_count=(4, "fixed"),
                    )
                details_md = gr.Markdown("", elem_id="details_note")
        with gr.Tab("Advanced"):
            gr.Markdown("Adjust chunking parameters below.")
            max_len_in = gr.Slider(128, _ui_max_len, value=_ui_default_len, step=32, label="Tokens per chunk (max_length)")
            stride_in = gr.Slider(0, 512, value=STRIDE, step=16, label="Stride / overlap")
            gr.Markdown("You can also set `MODEL_ID`, `MAX_LENGTH`, and `STRIDE` via Space Variables.")

    def predict_and_prettify(file, aggregation, max_length=MAX_LENGTH, stride=STRIDE):
        """Gradio callback: run the prediction and format it for the four output components."""
        res = predict_from_upload(file, aggregation, max_length, stride)
        return format_outputs(res)

    # Event listeners must be registered inside the Blocks context.
    btn.click(
        fn=predict_and_prettify,
        inputs=[file_in, agg_in, max_len_in, stride_in],
        outputs=[verdict_html, probs_html, chunk_table, details_md],
    )
# Script entry point: start the Gradio server (the trailing "|" was scrape residue).
if __name__ == "__main__":
    demo.launch()