# Hugging Face file-view metadata (tomerz14 — "Update app.py", commit 4cf9509,
# raw / history / blame, 9.64 kB) — kept as a comment so the module stays
# importable as Python.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Gradio App — AI vs Human Document Classifier (Chunked Inference)
----------------------------------------------------------------
Features:
- Upload a document (TXT/MD/HTML/PDF), chunk if needed, classify each chunk, aggregate to document.
- Shows:
1) Probability bars with raw numbers (AI generated / Human written)
2) Confidence badge ("Likely AI" / "Likely Human") with traffic-light color
3) Tabs for Basic / Advanced controls
4) Chunk details accordion with per-chunk probabilities
"""
import os
import io
import re
from typing import Dict, Any, List, Tuple
import numpy as np
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# -----------------------------
# Config
# -----------------------------
# All three knobs are overridable via environment variables (HF Space Variables).
MODEL_ID = os.getenv("MODEL_ID", "bert-base-uncased") # e.g., "username/bert-binclass"
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))  # tokens per chunk window
STRIDE = int(os.getenv("STRIDE", "128"))          # token overlap between consecutive windows
# Device
# Prefer CUDA, then Apple-Silicon MPS, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else
                      "mps" if torch.backends.mps.is_available() else "cpu")
if device.type == "mps":
    try:
        # Best-effort: this API is missing on older torch builds, hence the guard.
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass
# Load model & tokenizer
# NOTE(review): the default "bert-base-uncased" is a base checkpoint without a
# fine-tuned 2-label classification head, so predictions are presumably not
# meaningful until MODEL_ID points at a trained classifier — confirm intent.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, torch_dtype=torch.float32).to(device)
model.eval()
# -----------------------------
# Utilities
# -----------------------------
# Extensions handled as (mostly) plain text vs. PDFs that need pypdf.
TEXT_EXTS = {".txt", ".md", ".rtf", ".html", ".htm"}
PDF_EXTS = {".pdf"}
def read_text_from_file(file_obj) -> str:
    """
    Extract plain text from an uploaded file-like object.

    Supports .txt/.md/.rtf (read as text), .html/.htm (tags stripped and
    HTML entities such as ``&amp;`` decoded), and .pdf (via pypdf). Any
    other extension is decoded as UTF-8 text on a best-effort basis.

    Returns the extracted text; PDF parse failures are reported as an
    "[PDF parse error] ..." string so the UI can display them.
    """
    import html  # stdlib; decodes entities left behind after tag stripping

    name = getattr(file_obj, "name", "") or ""
    ext = os.path.splitext(name)[-1].lower()

    if ext in TEXT_EXTS:
        data = file_obj.read()
        if isinstance(data, bytes):
            # Uploads arrive as bytes; tolerate bad sequences instead of crashing.
            data = data.decode("utf-8", errors="ignore")
        if ext in {".html", ".htm"}:
            data = re.sub(r"<[^>]+>", " ", data)  # drop markup tags
            data = html.unescape(data)            # fix: &amp;, &lt;, ... were left encoded
            data = re.sub(r"\s+", " ", data).strip()
        return data

    if ext in PDF_EXTS:
        try:
            from pypdf import PdfReader  # imported lazily: only needed for PDFs
            reader = PdfReader(io.BytesIO(file_obj.read()))
            pages = []
            for page in reader.pages:
                try:
                    pages.append(page.extract_text() or "")
                except Exception:
                    # One unreadable page should not sink the whole document.
                    pages.append("")
            text = "\n".join(pages)
            text = re.sub(r"\s+", " ", text).strip()
            return text
        except Exception as e:
            return f"[PDF parse error] {e}"

    # Fallback: unknown extension — treat the payload as UTF-8 text.
    data = file_obj.read()
    if isinstance(data, bytes):
        data = data.decode("utf-8", errors="ignore")
    return data
def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: str = "mean") -> Dict[str, Any]:
    """
    Split *text* into overlapping token windows, classify each window, and
    aggregate the per-window probabilities into a document-level verdict.

    Returns a dict with document probabilities ("ai_prob"/"human_prob"),
    per-chunk rows, and the chunking parameters used — or {"error": ...}
    for an empty document.
    """
    if not text or not text.strip():
        return {"error": "Empty document."}

    with torch.no_grad():
        encoded = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            return_overflowing_tokens=True,
            stride=stride,
            padding=True,
            return_tensors="pt",
        )
        # Keep only the tensors the model's forward() understands.
        model_keys = {"input_ids", "attention_mask", "token_type_ids"}
        tensors = {k: t.to(model.device) for k, t in encoded.items() if k in model_keys}

        # Run chunks one at a time (keeps peak memory low for long documents).
        n_windows = tensors["input_ids"].size(0)
        per_chunk_logits = [
            model(**{k: t[idx:idx + 1] for k, t in tensors.items()}).logits
            for idx in range(n_windows)
        ]
        logits = torch.cat(per_chunk_logits, dim=0)  # [num_chunks, num_labels]
        probs = torch.softmax(logits, dim=-1).cpu().numpy()

    # Aggregate chunk probabilities into one document-level distribution.
    doc_probs = probs.max(axis=0) if agg == "max" else probs.mean(axis=0)

    # Label convention: index 0 -> Human, index 1 -> AI.
    rows = [[idx + 1, float(p[1]), float(p[0])] for idx, p in enumerate(probs)]
    return {
        "ai_prob": float(doc_probs[1]),
        "human_prob": float(doc_probs[0]),
        "num_chunks": int(probs.shape[0]),
        "chunk_rows": rows,  # list of [chunk, AI, Human]
        "max_length": max_length,
        "stride": stride,
    }
def predict_from_upload(file, aggregation, max_length, stride):
    """Read an uploaded document and run the chunked classifier on its text.

    Parameters
    ----------
    file : Gradio upload value — a tempfile-like object with a ``.name`` path,
        a plain filepath string, or ``None`` when nothing was uploaded.
    aggregation : "mean" or "max" — how chunk probabilities are combined.
    max_length, stride : chunking parameters (coerced to int because sliders
        may deliver floats).

    Returns the dict produced by ``chunked_predict`` or an {"error": ...} dict.
    """
    if file is None:
        return {"error": "Please upload a file."}

    # Work around gradio temp-file behavior: depending on version, `file` may
    # be a plain path string (no .read()/.name attrs) or a tempfile-like
    # object whose .name is the path. Normalize both to a BytesIO that carries
    # a .name so read_text_from_file can sniff the extension.
    path = file if isinstance(file, str) else getattr(file, "name", None)
    if isinstance(path, str):
        with open(path, "rb") as f:
            raw = io.BytesIO(f.read())
        raw.name = os.path.basename(path)
        text = read_text_from_file(raw)
    else:
        # Already a readable file-like object without a usable path.
        text = read_text_from_file(file)
    return chunked_predict(text, max_length=int(max_length), stride=int(stride), agg=aggregation)
# -----------------------------
# UI Helpers (HTML formatting)
# -----------------------------
def probability_bar_html(label: str, prob: float) -> str:
    """Render one probability row: bold label, percent readout, and a filled bar."""
    pct_txt = f"{prob * 100.0:.2f}"  # format once, reuse for readout and bar width
    return f"""
    <div class="prob-row"><div class="prob-label"><b>{label}</b></div>
      <div class="prob-value">{pct_txt}%</div>
      <div class="prob-bar">
        <div class="prob-fill" style="width:{pct_txt}%"></div>
      </div>
    </div>
    """
def verdict_badge_html(prob_ai: float, threshold: float = 0.5) -> str:
    """Render the traffic-light verdict pill (red = likely AI, green = likely human)."""
    is_ai = prob_ai >= threshold
    label = "Likely AI" if is_ai else "Likely Human"
    color = "#ef4444" if is_ai else "#10b981"  # red / green
    return f"<span class='pill' style='background:{color}22;color:{color}'>{label}</span>"
def format_outputs(result: Dict[str, Any], threshold: float = 0.5):
    """Shape a prediction dict into UI widget values.

    Returns a 4-tuple: (verdict_html, probs_html, chunk_table_data, details_md).
    On error the verdict slot carries a red message and the rest are empty.
    """
    if "error" in result:
        return f"<span style='color:#ef4444'>{result['error']}</span>", "", [], ""

    ai_prob = result["ai_prob"]
    human_prob = result["human_prob"]
    verdict = verdict_badge_html(ai_prob, threshold=threshold)
    bars = probability_bar_html("AI generated", ai_prob) + probability_bar_html("Human written", human_prob)
    details = (
        f"**Chunks:** `{result['num_chunks']}` \n"
        f"**Tokens per chunk:** `{result['max_length']}` \n"
        f"**Stride:** `{result['stride']}`"
    )
    return verdict, bars, result["chunk_rows"], details
# -----------------------------
# Gradio Interface
# -----------------------------
# Styling for the verdict pill and probability bars.
CSS = """
.pill {padding:6px 12px; border-radius:999px; display:inline-block; margin: 6px 0; font-weight:600;}
.prob-row {display:flex; align-items:center; gap:10px; margin:6px 0;}
.prob-label {min-width:140px;}
.prob-value {min-width:80px; text-align:right; font-variant-numeric: tabular-nums;}
.prob-bar {flex:1; background:#e5e7eb; height:12px; border-radius:6px; overflow:hidden;}
.prob-fill {height:12px; background:#6366f1;}
.small-note {font-size:0.9rem; color:#6b7280;}
"""
# User-facing intro markdown. Fix: the original text was mojibake-damaged
# (UTF-8 decoded with the wrong codec), garbling the em-dash and the emoji.
DESCRIPTION = """
### 🔎 AI vs Human — Document Classifier
Upload a file to get **document-level probabilities**.
Long inputs are **chunked** into overlapping windows; chunk predictions are **aggregated**.
"""
# Build the UI: a "Predict" tab (upload + results) and an "Advanced" tab
# (chunking sliders, which still feed the Predict button's click handler).
with gr.Blocks(
    title="AI vs Human Document Classifier",
    theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
    css=CSS
) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        with gr.Tab("Predict"):
            # Inputs: document upload plus chunk-aggregation strategy.
            file_in = gr.File(label="Upload a document", file_types=[".txt", ".md", ".rtf", ".html", ".htm", ".pdf"])
            agg_in = gr.Radio(choices=["mean", "max"], value="mean", label="Aggregation over chunks")
            btn = gr.Button("Predict", variant="primary")
            # Outputs: verdict badge, probability bars, and per-chunk detail.
            verdict_html = gr.HTML(label="Verdict")
            probs_html = gr.HTML(label="Probabilities")
            with gr.Accordion("Chunk details", open=False):
                chunk_table = gr.Dataframe(
                    headers=["Chunk", "AI generated", "Human written"],
                    datatype=["number", "number", "number"],
                    label="Per-chunk probabilities",
                    wrap=True,
                    interactive=False,
                    height=240
                )
                details_md = gr.Markdown("", elem_classes=["small-note"])
        with gr.Tab("Advanced"):
            gr.Markdown("Adjust chunking parameters below.")
            # Defaults mirror the MAX_LENGTH / STRIDE environment settings.
            max_len_in = gr.Slider(128, 1024, value=MAX_LENGTH, step=32, label="Tokens per chunk (max_length)")
            stride_in = gr.Slider(0, 512, value=STRIDE, step=16, label="Stride / overlap")
            gr.Markdown("You can also set `MODEL_ID`, `MAX_LENGTH`, and `STRIDE` via Space Variables.")

    def predict_and_prettify(file, aggregation, max_length=MAX_LENGTH, stride=STRIDE):
        # Thin glue: run the prediction, then shape the result for the widgets.
        res = predict_from_upload(file, aggregation, max_length, stride)
        return format_outputs(res)

    btn.click(
        fn=predict_and_prettify,
        inputs=[file_in, agg_in, max_len_in, stride_in],
        outputs=[verdict_html, probs_html, chunk_table, details_md],
    )

if __name__ == "__main__":
    demo.launch()