"""
Gradio App – AI vs Human Document Classifier (Chunked Inference)
----------------------------------------------------------------
Features:
- Upload a document (TXT/MD/RTF/HTML/PDF), chunk it if needed, classify each chunk,
  and aggregate to a document-level score.
- UI includes:
  1) Probability bars with raw numbers (AI generated / Human written)
  2) Confidence badge ("Likely AI" / "Likely Human") with traffic-light color
  3) Tabs for Predict / Advanced controls
  4) Chunk details accordion with per-chunk probabilities
  5) NEW: Per-chunk snippet extracted using the tokenizer's offset_mapping
"""
import os
import io
import re
from typing import Dict, Any, List, Tuple

import numpy as np
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
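
# Configuration (overridable via environment variables / Space Variables).
# Note: the default "bert-base-uncased" ships without a fine-tuned classification
# head, so transformers will attach a randomly initialized one; predictions are
# only meaningful once MODEL_ID points at a trained AI-vs-human detector.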
MODEL_ID = os.getenv("MODEL_ID", "bert-base-uncased")
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
STRIDE = int(os.getenv("STRIDE", "128"))
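
# Device selection: prefer CUDA, then Apple MPS, then CPU.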
device = torch.device("cuda" if torch.cuda.is_available() else
                      "mps" if torch.backends.mps.is_available() else "cpu")
if device.type == "mps":
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass
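
# Load the tokenizer and classifier once at startup and switch to eval mode.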
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, torch_dtype=torch.float32).to(device)
model.eval()
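
# Supported upload types and text-extraction helpers.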
TEXT_EXTS = {".txt", ".md", ".rtf", ".html", ".htm"}
PDF_EXTS = {".pdf"}


def read_text_from_file(file_obj) -> str:
    """
    Read text content from an uploaded file.
    Supports: .txt, .md, .rtf, .html, .htm, .pdf (via pypdf).
    """
    name = getattr(file_obj, "name", "") or ""
    ext = os.path.splitext(name)[-1].lower()

    if ext in TEXT_EXTS:
        data = file_obj.read()
        if isinstance(data, bytes):
            data = data.decode("utf-8", errors="ignore")
        if ext in {".html", ".htm"}:
            data = re.sub(r"<[^>]+>", " ", data)
            data = re.sub(r"\s+", " ", data).strip()
        return data

    if ext in PDF_EXTS:
        try:
            from pypdf import PdfReader
            reader = PdfReader(io.BytesIO(file_obj.read()))
            pages = []
            for p in reader.pages:
                try:
                    pages.append(p.extract_text() or "")
                except Exception:
                    pages.append("")
            text = "\n".join(pages)
            text = re.sub(r"\s+", " ", text).strip()
            return text
        except Exception as e:
            return f"[PDF parse error] {e}"
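    # Unknown extension: fall back to decoding the raw bytes as UTF-8 text.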
    data = file_obj.read()
    if isinstance(data, bytes):
        data = data.decode("utf-8", errors="ignore")
    return data

def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: str = "mean") -> Dict[str, Any]:
    """
    Chunk the document using tokenizer overflow, run the classifier on each chunk,
    aggregate probabilities, and return both doc-level and chunk-level results,
    including a short snippet per chunk derived from offset_mapping.
    """
    if not text or not text.strip():
        return {"error": "Empty document."}

    with torch.no_grad():
        enc = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            return_overflowing_tokens=True,
            stride=stride,
            padding=True,
            return_offsets_mapping=True,
            return_tensors="pt",
        )
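        # Keep only the tensors the model's forward pass accepts (offset_mapping
        # and overflow_to_sample_mapping are tokenizer bookkeeping, not model inputs).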
        allowed = {"input_ids", "attention_mask", "token_type_ids"}
        inputs = {k: v.to(model.device) for k, v in enc.items() if k in allowed}

        logits_list = []
        for i in range(inputs["input_ids"].size(0)):
            batch = {k: v[i:i + 1] for k, v in inputs.items()}
            out = model(**batch)
            logits_list.append(out.logits)

        logits = torch.cat(logits_list, dim=0)
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
        num_chunks = int(probs.shape[0])
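
    # Aggregate per-chunk probabilities into a single document-level distribution.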
    if agg == "max":
        doc_probs = probs.max(axis=0)
    else:
        doc_probs = probs.mean(axis=0)
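
    # Assumed label order: index 0 = Human written, index 1 = AI generated
    # (verify against the loaded checkpoint's id2label mapping).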
    prob_human = float(doc_probs[0])
    prob_ai = float(doc_probs[1])
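
    # Build a short preview snippet per chunk by mapping its tokens back to
    # character offsets in the original text.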
    offsets = enc["offset_mapping"]
    attn = enc["attention_mask"]
    snippets: List[str] = []
    PREVIEW = 120

    for i in range(offsets.shape[0]):
        offs = offsets[i].tolist()
        mask = attn[i].tolist()
        spans = [(s, e) for (s, e), m in zip(offs, mask) if m == 1 and not (s == 0 and e == 0)]
        if spans:
            s0 = min(s for s, _ in spans)
            e0 = max(e for _, e in spans)
            raw = text[s0:e0].strip()
            raw = " ".join(raw.split())
            if len(raw) > PREVIEW:
                raw = raw[:PREVIEW].rstrip() + "…"
            snippets.append(raw)
        else:
            snippets.append("")

    chunk_rows: List[List[Any]] = []
    for i, p in enumerate(probs):
        ai_p = float(p[1])
        hu_p = float(p[0])
        chunk_rows.append([i + 1, ai_p, hu_p, snippets[i]])

    return {
        "ai_prob": prob_ai,
        "human_prob": prob_human,
        "num_chunks": num_chunks,
        "chunk_rows": chunk_rows,
        "max_length": max_length,
        "stride": stride,
    }
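
# Example use outside the UI (a sketch; "essay.txt" is a placeholder path and the
# probabilities depend entirely on the checkpoint loaded above):
#
#   result = chunked_predict(open("essay.txt", encoding="utf-8").read(), max_length=256, stride=64)
#   print(result["ai_prob"], result["human_prob"], result["num_chunks"])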


def predict_from_upload(file, aggregation, max_length, stride):
    if file is None:
        return {"error": "Please upload a file."}
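
    # gr.File may deliver a plain filepath string or a temp-file object with a
    # .name attribute depending on the Gradio version; normalize both into a
    # named, readable binary stream for read_text_from_file.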
    path = file if isinstance(file, str) else getattr(file, "name", None)
    if isinstance(path, str) and os.path.exists(path):
        with open(path, "rb") as f:
            raw = io.BytesIO(f.read())
        raw.name = os.path.basename(path)
        text = read_text_from_file(raw)
    else:
        text = read_text_from_file(file)

    return chunked_predict(text, max_length=int(max_length), stride=int(stride), agg=aggregation)
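

# UI helpers: HTML fragments for the probability bars and the verdict badge.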
def probability_bar_html(label: str, prob: float) -> str:
    """Return an HTML row with label, percent, and a bar."""
    pct = prob * 100.0
    return f"""
    <div class="prob-row"><div class="prob-label"><b>{label}</b></div>
    <div class="prob-value">{pct:.2f}%</div>
    <div class="prob-bar">
      <div class="prob-fill" style="width:{pct:.2f}%"></div>
    </div>
    </div>
    """

def verdict_badge_html(prob_ai: float, threshold: float = 0.5) -> str:
    label = "Likely AI" if prob_ai >= threshold else "Likely Human"
    color = "#ef4444" if prob_ai >= threshold else "#10b981"
    return f"<span class='pill' style='background:{color}22;color:{color}'>{label}</span>"

def format_outputs(result: Dict[str, Any], threshold: float = 0.5):
    """Produce (verdict_html, probs_html, chunk_table_data, details_md)."""
    if "error" in result:
        return f"<span style='color:#ef4444'>{result['error']}</span>", "", [], ""

    ai, human = result["ai_prob"], result["human_prob"]
    verdict_html = verdict_badge_html(ai, threshold=threshold)

    probs_html = ""
    probs_html += probability_bar_html("AI generated", ai)
    probs_html += probability_bar_html("Human written", human)

    table_data = result["chunk_rows"]

    details_md = (
        f"**Chunks:** `{result['num_chunks']}`  \n"
        f"**Tokens per chunk:** `{result['max_length']}`  \n"
        f"**Stride:** `{result['stride']}`"
    )

    return verdict_html, probs_html, table_data, details_md
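

# Gradio UI: styling, description, and layout.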
CSS = """
.pill {padding:6px 12px; border-radius:999px; display:inline-block; margin: 6px 0; font-weight:600;}
.prob-row {display:flex; align-items:center; gap:10px; margin:6px 0;}
.prob-label {min-width:140px;}
.prob-value {min-width:80px; text-align:right; font-variant-numeric: tabular-nums;}
.prob-bar {flex:1; background:#e5e7eb; height:12px; border-radius:6px; overflow:hidden;}
.prob-fill {height:12px; background:#6366f1;}
.small-note {font-size:0.9rem; color:#6b7280;}
/* Wrap long snippet text within the DataFrame cells */
.gr-dataframe table td { white-space: normal; }
/* Scrollable chunk table container */
#chunkgroup { max-height: 260px; overflow: auto; }
#details_note { font-size: 0.9rem; color: #6b7280; }
"""

DESCRIPTION = """
### AI vs Human – Document Classifier
Upload a file to get **document-level probabilities**.
Long inputs are **chunked** into overlapping windows; chunk predictions are **aggregated**.
"""

with gr.Blocks(
    title="AI vs Human Document Classifier",
    theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
    css=CSS,
) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        with gr.Tab("Predict"):
            file_in = gr.File(label="Upload a document", file_types=[".txt", ".md", ".rtf", ".html", ".htm", ".pdf"])
            agg_in = gr.Radio(choices=["mean", "max"], value="mean", label="Aggregation over chunks")
            btn = gr.Button("Predict", variant="primary")
            verdict_html = gr.HTML(label="Verdict")
            probs_html = gr.HTML(label="Probabilities")

            with gr.Accordion("Chunk details", open=False):
                with gr.Group(elem_id="chunkgroup"):
                    chunk_table = gr.Dataframe(
                        headers=["Chunk", "AI generated", "Human written", "Snippet"],
                        datatype=["number", "number", "number", "str"],
                        label="Per-chunk probabilities",
                        wrap=True,
                        interactive=False,
                        row_count=(0, "dynamic"),
                        col_count=(4, "fixed"),
                    )
                details_md = gr.Markdown("", elem_id="details_note")

        with gr.Tab("Advanced"):
            gr.Markdown("Adjust chunking parameters below.")
            max_len_in = gr.Slider(128, 1024, value=MAX_LENGTH, step=32, label="Tokens per chunk (max_length)")
            stride_in = gr.Slider(0, 512, value=STRIDE, step=16, label="Stride / overlap")
            gr.Markdown("You can also set `MODEL_ID`, `MAX_LENGTH`, and `STRIDE` via Space Variables.")

    def predict_and_prettify(file, aggregation, max_length=MAX_LENGTH, stride=STRIDE):
        res = predict_from_upload(file, aggregation, max_length, stride)
        return format_outputs(res)

    btn.click(
        fn=predict_and_prettify,
        inputs=[file_in, agg_in, max_len_in, stride_in],
        outputs=[verdict_html, probs_html, chunk_table, details_md],
    )


if __name__ == "__main__":
    demo.launch()