#!/usr/bin/env python # -*- coding: utf-8 -*- """ Gradio App — AI vs Human Document Classifier (Chunked Inference) ---------------------------------------------------------------- Features: - Upload a document (TXT/MD/HTML/PDF), chunk if needed, classify each chunk, aggregate to document. - Shows: 1) Probability bars with raw numbers (AI generated / Human written) 2) Confidence badge ("Likely AI" / "Likely Human") with traffic-light color 3) Tabs for Basic / Advanced controls 4) Chunk details accordion with per-chunk probabilities """ import os import io import re from typing import Dict, Any, List, Tuple import numpy as np import torch import gradio as gr from transformers import AutoTokenizer, AutoModelForSequenceClassification # ----------------------------- # Config # ----------------------------- MODEL_ID = os.getenv("MODEL_ID", "bert-base-uncased") # e.g., "username/bert-binclass" MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512")) STRIDE = int(os.getenv("STRIDE", "128")) # Device device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") if device.type == "mps": try: torch.set_float32_matmul_precision("high") except Exception: pass # Load model & tokenizer tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True) model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, torch_dtype=torch.float32).to(device) model.eval() # ----------------------------- # Utilities # ----------------------------- TEXT_EXTS = {".txt", ".md", ".rtf", ".html", ".htm"} PDF_EXTS = {".pdf"} def read_text_from_file(file_obj) -> str: """ Read text content from an uploaded file. Supports: .txt, .md, .rtf, .html, .htm, .pdf (via pypdf). """ name = getattr(file_obj, "name", "") or "" ext = os.path.splitext(name)[-1].lower() if ext in TEXT_EXTS: data = file_obj.read() if isinstance(data, bytes): data = data.decode("utf-8", errors="ignore") if ext in {".html", ".htm"}: data = re.sub(r"<[^>]+>", " ", data) data = re.sub(r"\s+", " ", data).strip() return data if ext in PDF_EXTS: try: from pypdf import PdfReader reader = PdfReader(io.BytesIO(file_obj.read())) pages = [] for p in reader.pages: try: pages.append(p.extract_text() or "") except Exception: pages.append("") text = "\n".join(pages) text = re.sub(r"\s+", " ", text).strip() return text except Exception as e: return f"[PDF parse error] {e}" # Fallback: try as text data = file_obj.read() if isinstance(data, bytes): data = data.decode("utf-8", errors="ignore") return data def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: str = "mean") -> Dict[str, Any]: """ Chunk the document using tokenizer overflow, run classifier on each chunk, aggregate probabilities, and return both doc-level and chunk-level results. """ if not text or not text.strip(): return {"error": "Empty document."} with torch.no_grad(): enc = tokenizer( text, truncation=True, max_length=max_length, return_overflowing_tokens=True, stride=stride, padding=True, return_tensors="pt", ) allowed = {"input_ids", "attention_mask", "token_type_ids"} inputs = {k: v.to(model.device) for k, v in enc.items() if k in allowed} logits_list = [] for i in range(inputs["input_ids"].size(0)): batch = {k: v[i:i+1] for k, v in inputs.items()} out = model(**batch) logits_list.append(out.logits) logits = torch.cat(logits_list, dim=0) # [num_chunks, num_labels] probs = torch.softmax(logits, dim=-1).cpu().numpy() num_chunks = int(probs.shape[0]) # Aggregate if agg == "max": doc_probs = probs.max(axis=0) else: doc_probs = probs.mean(axis=0) # By convention: 0 -> Human, 1 -> AI prob_human = float(doc_probs[0]) prob_ai = float(doc_probs[1]) # Per-chunk table rows chunk_rows = [] for i, p in enumerate(probs): chunk_rows.append([i + 1, float(p[1]), float(p[0])]) # [chunk, AI, Human] return { "ai_prob": prob_ai, "human_prob": prob_human, "num_chunks": num_chunks, "chunk_rows": chunk_rows, # list of [chunk, AI, Human] "max_length": max_length, "stride": stride, } def predict_from_upload(file, aggregation, max_length, stride): if file is None: return {"error": "Please upload a file."} # Work around gradio temp file behavior if hasattr(file, "name") and isinstance(file.name, str): with open(file.name, "rb") as f: raw = io.BytesIO(f.read()) raw.name = os.path.basename(file.name) text = read_text_from_file(raw) else: text = read_text_from_file(file) return chunked_predict(text, max_length=int(max_length), stride=int(stride), agg=aggregation) # ----------------------------- # UI Helpers (HTML formatting) # ----------------------------- def probability_bar_html(label: str, prob: float) -> str: """Return an HTML row with label, percent, and a bar.""" pct = prob * 100.0 return f"""