#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Gradio App — AI vs Human Document Classifier (Chunked Inference)
----------------------------------------------------------------
Features:
- Upload a document (TXT/MD/HTML/PDF), chunk if needed, classify each chunk, aggregate to document.
- UI includes:
1) Probability bars with raw numbers (AI generated / Human written)
2) Confidence badge ("Likely AI" / "Likely Human") with traffic-light color
3) Tabs for Basic / Advanced controls
4) Chunk details accordion with per-chunk probabilities
5) NEW: Per-chunk **snippet** extracted using tokenizer offset_mapping
"""
import os
import io
import re
from typing import Dict, Any, List, Tuple
import numpy as np
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# -----------------------------
# Config
# -----------------------------
# Model id and chunking parameters are overridable via environment variables
# (on a HF Space these come from "Space Variables").
MODEL_ID = os.getenv("MODEL_ID", "bert-base-uncased")
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))  # tokens per chunk
STRIDE = int(os.getenv("STRIDE", "128"))          # token overlap between chunks
# Device: prefer CUDA, then Apple-Silicon MPS, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else
                      "mps" if torch.backends.mps.is_available() else "cpu")
if device.type == "mps":
    try:
        # Relax matmul precision on MPS; ignore if this torch build lacks it.
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass
# Load model & tokenizer once at import time (module-level singletons used
# by chunked_predict below).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, torch_dtype=torch.float32).to(device)
model.eval()
# -----------------------------
# Utilities
# -----------------------------
# Extensions handled as plain text vs. routed to the PDF extractor.
TEXT_EXTS = {".txt", ".md", ".rtf", ".html", ".htm"}
PDF_EXTS = {".pdf"}


def read_text_from_file(file_obj) -> str:
    """
    Read text content from an uploaded file.

    Supports: .txt, .md, .rtf, .html, .htm, .pdf (via pypdf).
    Unknown extensions fall back to a best-effort UTF-8 text read.

    Parameters
    ----------
    file_obj : file-like
        Open binary file-like object; its ``name`` attribute (if present)
        selects the parser by extension.

    Returns
    -------
    str
        Extracted plain text, or an ``[PDF parse error] ...`` marker string
        when PDF parsing fails.
    """
    import html  # stdlib; decodes HTML entities after tag stripping

    name = getattr(file_obj, "name", "") or ""
    ext = os.path.splitext(name)[-1].lower()
    if ext in TEXT_EXTS:
        data = file_obj.read()
        if isinstance(data, bytes):
            data = data.decode("utf-8", errors="ignore")
        if ext in {".html", ".htm"}:
            # Strip tags, then decode entities (&amp; -> &) so the classifier
            # sees natural text, then collapse whitespace.
            data = re.sub(r"<[^>]+>", " ", data)
            data = html.unescape(data)
            data = re.sub(r"\s+", " ", data).strip()
        return data
    if ext in PDF_EXTS:
        try:
            from pypdf import PdfReader
            reader = PdfReader(io.BytesIO(file_obj.read()))
            pages = []
            for p in reader.pages:
                try:
                    pages.append(p.extract_text() or "")
                except Exception:
                    # One unreadable page should not sink the whole document.
                    pages.append("")
            text = "\n".join(pages)
            text = re.sub(r"\s+", " ", text).strip()
            return text
        except Exception as e:
            # Surface the failure as text so the UI can display it.
            return f"[PDF parse error] {e}"
    # Fallback: try as text
    data = file_obj.read()
    if isinstance(data, bytes):
        data = data.decode("utf-8", errors="ignore")
    return data
def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: str = "mean") -> Dict[str, Any]:
    """
    Chunk the document using tokenizer overflow, run the classifier on each
    chunk, aggregate probabilities, and return both doc-level and chunk-level
    results, including a short snippet per chunk derived from offset_mapping.

    Parameters
    ----------
    text : str
        Raw document text.
    max_length : int
        Maximum tokens per chunk (tokenizer truncation window).
    stride : int
        Token overlap between consecutive chunks.
    agg : str
        Aggregation over chunk probabilities: "max", or anything else -> mean.

    Returns
    -------
    dict
        ``{"error": ...}`` for empty input, otherwise a dict with
        ``ai_prob``, ``human_prob``, ``num_chunks``, ``chunk_rows``
        (list of ``[chunk#, AI, Human, Snippet]``), ``max_length``, ``stride``.
    """
    if not text or not text.strip():
        return {"error": "Empty document."}
    with torch.no_grad():
        enc = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            return_overflowing_tokens=True,
            stride=stride,
            padding=True,
            return_offsets_mapping=True,  # character offsets per token, for snippets
            return_tensors="pt",
        )
        # Forward only keys the model understands; offset_mapping and overflow
        # bookkeeping must not reach the model.
        allowed = {"input_ids", "attention_mask", "token_type_ids"}
        inputs = {k: v.to(model.device) for k, v in enc.items() if k in allowed}
        # One chunk per forward pass keeps peak memory low on small devices.
        logits_list = []
        for i in range(inputs["input_ids"].size(0)):
            batch = {k: v[i:i + 1] for k, v in inputs.items()}
            out = model(**batch)
            logits_list.append(out.logits)
        logits = torch.cat(logits_list, dim=0)  # [num_chunks, num_labels]
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
    num_chunks = int(probs.shape[0])
    # Aggregate chunk distributions into one document-level distribution.
    if agg == "max":
        doc_probs = probs.max(axis=0)
    else:
        doc_probs = probs.mean(axis=0)
    # By convention: label 0 -> Human, label 1 -> AI
    # (assumes the fine-tuned head follows this order — TODO confirm).
    prob_human = float(doc_probs[0])
    prob_ai = float(doc_probs[1])
    # --- Build snippets per chunk from offset mapping ---
    offsets = enc["offset_mapping"]  # [num_chunks, seq_len, 2] char spans
    attn = enc["attention_mask"]     # [num_chunks, seq_len]
    snippets: List[str] = []
    PREVIEW = 120  # max characters shown per snippet
    for i in range(offsets.shape[0]):
        offs = offsets[i].tolist()
        mask = attn[i].tolist()
        # Keep real tokens only: attended, and with a non-degenerate span
        # ((0, 0) marks special/padding tokens).
        spans = [(s, e) for (s, e), m in zip(offs, mask) if m == 1 and not (s == 0 and e == 0)]
        if spans:
            s0 = min(s for s, _ in spans)
            e0 = max(e for _, e in spans)
            raw = text[s0:e0].strip()
            raw = " ".join(raw.split())
            if len(raw) > PREVIEW:
                raw = raw[:PREVIEW].rstrip() + "…"  # proper ellipsis (was mojibake)
            snippets.append(raw)
        else:
            snippets.append("")
    # Per-chunk rows: [chunk#, AI prob, Human prob, Snippet]
    chunk_rows: List[List[Any]] = []
    for i, p in enumerate(probs):
        chunk_rows.append([i + 1, float(p[1]), float(p[0]), snippets[i]])
    return {
        "ai_prob": prob_ai,
        "human_prob": prob_human,
        "num_chunks": num_chunks,
        "chunk_rows": chunk_rows,  # list of [chunk, AI, Human, Snippet]
        "max_length": max_length,
        "stride": stride,
    }
def predict_from_upload(file, aggregation, max_length, stride):
    """Extract text from an uploaded file and run the chunked classifier on it."""
    if file is None:
        return {"error": "Please upload a file."}
    path = getattr(file, "name", None)
    if isinstance(path, str):
        # Gradio hands us a temp-file wrapper; re-read from disk into a fresh
        # seekable buffer that carries the original filename for ext sniffing.
        with open(path, "rb") as fh:
            payload = io.BytesIO(fh.read())
        payload.name = os.path.basename(path)
        text = read_text_from_file(payload)
    else:
        text = read_text_from_file(file)
    return chunked_predict(text, max_length=int(max_length), stride=int(stride), agg=aggregation)
# -----------------------------
# UI Helpers (HTML formatting)
# -----------------------------
def probability_bar_html(label: str, prob: float) -> str:
    """Render one probability as an HTML row: label, percent, and a filled bar."""
    pct = f"{prob * 100.0:.2f}"  # format once; used for both text and bar width
    return f"""
<div class="prob-row"><div class="prob-label"><b>{label}</b></div>
<div class="prob-value">{pct}%</div>
<div class="prob-bar">
<div class="prob-fill" style="width:{pct}%"></div>
</div>
</div>
"""
def verdict_badge_html(prob_ai: float, threshold: float = 0.5) -> str:
    """Return a colored pill badge: red "Likely AI" or green "Likely Human"."""
    is_ai = prob_ai >= threshold
    label = "Likely AI" if is_ai else "Likely Human"
    color = "#ef4444" if is_ai else "#10b981"  # red for AI, green for human
    return f"<span class='pill' style='background:{color}22;color:{color}'>{label}</span>"
def format_outputs(result: Dict[str, Any], threshold: float = 0.5):
    """Turn a chunked_predict result into (verdict_html, probs_html, table_data, details_md)."""
    if "error" in result:
        return f"<span style='color:#ef4444'>{result['error']}</span>", "", [], ""
    ai = result["ai_prob"]
    human = result["human_prob"]
    badge = verdict_badge_html(ai, threshold=threshold)
    bars = probability_bar_html("AI generated", ai) + probability_bar_html("Human written", human)
    details = (
        f"**Chunks:** `{result['num_chunks']}` \n"
        f"**Tokens per chunk:** `{result['max_length']}` \n"
        f"**Stride:** `{result['stride']}`"
    )
    # Chunk table rows were already built server-side by chunked_predict.
    return badge, bars, result["chunk_rows"], details
# -----------------------------
# Gradio Interface
# -----------------------------
# Styling for the verdict pill, probability bars, and the chunk table.
CSS = """
.pill {padding:6px 12px; border-radius:999px; display:inline-block; margin: 6px 0; font-weight:600;}
.prob-row {display:flex; align-items:center; gap:10px; margin:6px 0;}
.prob-label {min-width:140px;}
.prob-value {min-width:80px; text-align:right; font-variant-numeric: tabular-nums;}
.prob-bar {flex:1; background:#e5e7eb; height:12px; border-radius:6px; overflow:hidden;}
.prob-fill {height:12px; background:#6366f1;}
.small-note {font-size:0.9rem; color:#6b7280;}
/* Wrap long snippet text within the DataFrame cells */
.gr-dataframe table td { white-space: normal; }
/* Scrollable chunk table container */
#chunkgroup { max-height: 260px; overflow: auto; }
#details_note { font-size: 0.9rem; color: #6b7280; }
"""
# User-facing blurb (mojibake repaired: magnifier emoji and em dash).
DESCRIPTION = """
### 🔎 AI vs Human — Document Classifier
Upload a file to get **document-level probabilities**.
Long inputs are **chunked** into overlapping windows; chunk predictions are **aggregated**.
"""
with gr.Blocks(
    title="AI vs Human Document Classifier",
    theme='Nymbo/rounded-gradient',
    css=CSS
) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        with gr.Tab("Predict"):
            file_in = gr.File(label="Upload a document", file_types=[".txt", ".md", ".rtf", ".html", ".htm", ".pdf"])
            agg_in = gr.Radio(choices=["mean", "max"], value="mean", label="Aggregation over chunks")
            btn = gr.Button("Predict", variant="primary")
            verdict_html = gr.HTML(label="Verdict")
            probs_html = gr.HTML(label="Probabilities")
            with gr.Accordion("Chunk details", open=False):
                with gr.Group(elem_id="chunkgroup"):
                    chunk_table = gr.Dataframe(
                        headers=["Chunk", "AI generated", "Human written", "Snippet"],
                        datatype=["number", "number", "number", "str"],
                        label="Per-chunk probabilities",
                        wrap=True,
                        interactive=False,
                        row_count=(0, "dynamic"),
                        col_count=(4, "fixed"),
                    )
                details_md = gr.Markdown("", elem_id="details_note")
        with gr.Tab("Advanced"):
            gr.Markdown("Adjust chunking parameters below.")
            max_len_in = gr.Slider(128, 1024, value=MAX_LENGTH, step=32, label="Tokens per chunk (max_length)")
            stride_in = gr.Slider(0, 512, value=STRIDE, step=16, label="Stride / overlap")
            gr.Markdown("You can also set `MODEL_ID`, `MAX_LENGTH`, and `STRIDE` via Space Variables.")

    def predict_and_prettify(file, aggregation, max_length=MAX_LENGTH, stride=STRIDE):
        """Run prediction on the upload and format the result for the UI widgets."""
        res = predict_from_upload(file, aggregation, max_length, stride)
        return format_outputs(res)

    # Wire the button to the prediction pipeline; outputs map 1:1 to widgets.
    btn.click(
        fn=predict_and_prettify,
        inputs=[file_in, agg_in, max_len_in, stride_in],
        outputs=[verdict_html, probs_html, chunk_table, details_md],
    )

if __name__ == "__main__":
    demo.launch()