#!/usr/bin/env python # -*- coding: utf-8 -*- """ Gradio App — AI vs Human Document Classifier (Chunked Inference) ---------------------------------------------------------------- Features: - Upload a document (TXT/MD/HTML/PDF), chunk if needed, classify each chunk, aggregate to document. - UI includes: 1) Probability bars with raw numbers (AI generated / Human written) 2) Confidence badge ("Likely AI" / "Likely Human") with traffic-light color 3) Tabs for Basic / Advanced controls 4) Chunk details accordion with per-chunk probabilities 5) NEW: Per-chunk **snippet** extracted using tokenizer offset_mapping """ import os import io import re from typing import Dict, Any, List, Tuple import numpy as np import torch import gradio as gr from transformers import AutoTokenizer, AutoModelForSequenceClassification # ----------------------------- # Config # ----------------------------- MODEL_ID = os.getenv("MODEL_ID", "bert-base-uncased") MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512")) STRIDE = int(os.getenv("STRIDE", "128")) # Device device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") if device.type == "mps": try: torch.set_float32_matmul_precision("high") except Exception: pass # Load model & tokenizer tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True) model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, torch_dtype=torch.float32).to(device) model.eval() # ----------------------------- # Utilities # ----------------------------- TEXT_EXTS = {".txt", ".md", ".rtf", ".html", ".htm"} PDF_EXTS = {".pdf"} def read_text_from_file(file_obj) -> str: """ Read text content from an uploaded file. Supports: .txt, .md, .rtf, .html, .htm, .pdf (via pypdf). """ name = getattr(file_obj, "name", "") or "" ext = os.path.splitext(name)[-1].lower() if ext in TEXT_EXTS: data = file_obj.read() if isinstance(data, bytes): data = data.decode("utf-8", errors="ignore") if ext in {".html", ".htm"}: data = re.sub(r"<[^>]+>", " ", data) data = re.sub(r"\s+", " ", data).strip() return data if ext in PDF_EXTS: try: from pypdf import PdfReader reader = PdfReader(io.BytesIO(file_obj.read())) pages = [] for p in reader.pages: try: pages.append(p.extract_text() or "") except Exception: pages.append("") text = "\n".join(pages) text = re.sub(r"\s+", " ", text).strip() return text except Exception as e: return f"[PDF parse error] {e}" # Fallback: try as text data = file_obj.read() if isinstance(data, bytes): data = data.decode("utf-8", errors="ignore") return data def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: str = "mean") -> Dict[str, Any]: """ Chunk the document using tokenizer overflow, run classifier on each chunk, aggregate probabilities, and return both doc-level and chunk-level results, including a short snippet per chunk derived from offset_mapping. """ if not text or not text.strip(): return {"error": "Empty document."} with torch.no_grad(): enc = tokenizer( text, truncation=True, max_length=max_length, return_overflowing_tokens=True, stride=stride, padding=True, return_offsets_mapping=True, # NEW: get character offsets per token return_tensors="pt", ) allowed = {"input_ids", "attention_mask", "token_type_ids"} inputs = {k: v.to(model.device) for k, v in enc.items() if k in allowed} logits_list = [] for i in range(inputs["input_ids"].size(0)): batch = {k: v[i:i+1] for k, v in inputs.items()} out = model(**batch) logits_list.append(out.logits) logits = torch.cat(logits_list, dim=0) # [num_chunks, num_labels] probs = torch.softmax(logits, dim=-1).cpu().numpy() num_chunks = int(probs.shape[0]) # Aggregate if agg == "max": doc_probs = probs.max(axis=0) else: doc_probs = probs.mean(axis=0) # By convention: 0 -> Human, 1 -> AI prob_human = float(doc_probs[0]) prob_ai = float(doc_probs[1]) # --- Build snippets per chunk from offset mapping --- offsets = enc["offset_mapping"] # tensor of pairs attn = enc["attention_mask"] # [num_chunks, seq_len] snippets: List[str] = [] PREVIEW = 120 for i in range(offsets.shape[0]): offs = offsets[i].tolist() mask = attn[i].tolist() spans = [(s, e) for (s, e), m in zip(offs, mask) if m == 1 and not (s == 0 and e == 0)] if spans: s0 = min(s for s, _ in spans) e0 = max(e for _, e in spans) raw = text[s0:e0].strip() raw = " ".join(raw.split()) if len(raw) > PREVIEW: raw = raw[:PREVIEW].rstrip() + "…" snippets.append(raw) else: snippets.append("") # Per-chunk rows: [chunk#, AI prob, Human prob, Snippet] chunk_rows: List[List[Any]] = [] for i, p in enumerate(probs): ai_p = float(p[1]) hu_p = float(p[0]) chunk_rows.append([i + 1, ai_p, hu_p, snippets[i]]) return { "ai_prob": prob_ai, "human_prob": prob_human, "num_chunks": num_chunks, "chunk_rows": chunk_rows, # list of [chunk, AI, Human, Snippet] "max_length": max_length, "stride": stride, } def predict_from_upload(file, aggregation, max_length, stride): if file is None: return {"error": "Please upload a file."} # Work around gradio temp file behavior if hasattr(file, "name") and isinstance(file.name, str): with open(file.name, "rb") as f: raw = io.BytesIO(f.read()) raw.name = os.path.basename(file.name) text = read_text_from_file(raw) else: text = read_text_from_file(file) return chunked_predict(text, max_length=int(max_length), stride=int(stride), agg=aggregation) # ----------------------------- # UI Helpers (HTML formatting) # ----------------------------- def probability_bar_html(label: str, prob: float) -> str: """Return an HTML row with label, percent, and a bar.""" pct = prob * 100.0 return f"""