#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Gradio App — Binary Text Classifier (Chunked Inference)
-------------------------------------------------------
- Users upload a document file (txt, md, html, pdf*), we read the text,
  chunk if needed, and return a prediction with probability.
- Designed for Hugging Face Spaces.

* For PDFs, this app uses a simple text extraction via pypdf. For
  production-quality extraction, consider using `pymupdf` (fitz) or
  `pdfminer.six`.
"""

import io
import os
import re
from typing import Any, Dict

import numpy as np
import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

# -----------------------------
# Config
# -----------------------------
MODEL_ID = os.getenv("MODEL_ID", "bert-base-uncased")  # e.g., "tomerz14/human-vs-AI_bert-classifier"
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
STRIDE = int(os.getenv("STRIDE", "128"))

# Device selection (CPU by default on Spaces).
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
if device.type == "mps":
    # Best-effort precision hint; not all torch builds expose this knob.
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

# Load model & tokenizer once at startup so per-request latency stays low.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID, torch_dtype=torch.float32
).to(device)
model.eval()

# -----------------------------
# Utilities
# -----------------------------
TEXT_EXTS = {".txt", ".md", ".rtf", ".html", ".htm"}
PDF_EXTS = {".pdf"}


def read_text_from_file(file_obj) -> str:
    """Read text content from an uploaded file.

    Supports: .txt, .md, .rtf, .html, .htm, .pdf (via pypdf).

    Args:
        file_obj: A binary file-like object. The extension is taken from its
            ``name`` attribute (empty string if absent).

    Returns:
        The extracted text. HTML tags are stripped for .html/.htm; whitespace
        is collapsed for HTML and PDF input. On PDF parse failure a
        ``"[PDF parse error] ..."`` marker string is returned instead of
        raising.
    """
    name = getattr(file_obj, "name", "") or ""
    ext = os.path.splitext(name)[-1].lower()

    if ext in TEXT_EXTS:
        data = file_obj.read()
        if isinstance(data, bytes):
            data = data.decode("utf-8", errors="ignore")
        if ext in {".html", ".htm"}:
            # Crude tag strip + whitespace collapse; good enough for
            # classification input, not for display.
            data = re.sub(r"<[^>]+>", " ", data)
            data = re.sub(r"\s+", " ", data).strip()
        return data

    if ext in PDF_EXTS:
        try:
            from pypdf import PdfReader

            reader = PdfReader(io.BytesIO(file_obj.read()))
            pages = []
            for p in reader.pages:
                # A single unreadable page shouldn't sink the whole document.
                try:
                    pages.append(p.extract_text() or "")
                except Exception:
                    pages.append("")
            text = "\n".join(pages)
            text = re.sub(r"\s+", " ", text).strip()
            return text
        except Exception as e:
            return f"[PDF parse error] {e}"

    # Fallback: unknown extension — try to treat the payload as text.
    data = file_obj.read()
    if isinstance(data, bytes):
        data = data.decode("utf-8", errors="ignore")
    return data


def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: str = "mean") -> Dict[str, Any]:
    """Chunk the document via tokenizer overflow, classify each chunk, and
    aggregate chunk probabilities into a document-level prediction.

    Args:
        text: Raw document text.
        max_length: Tokens per chunk (includes special tokens).
        stride: Token overlap between consecutive chunks; clamped to
            ``max_length - 1`` because the tokenizer requires stride to be
            strictly smaller than max_length.
        agg: ``"mean"`` for averaged probabilities, anything else for the
            per-class maximum (note: max-aggregated scores need not sum to 1).

    Returns:
        A dict with the predicted label, its score, all class scores, and
        run metadata — or ``{"error": ...}`` for empty input.
    """
    if not text or not text.strip():
        return {"error": "Empty document."}

    # FIX: the UI sliders allow stride >= max_length, which is invalid for
    # overflowing tokenization (windows could never advance). Clamp it.
    stride = max(0, min(int(stride), int(max_length) - 1))

    # Tokenization needs no autograd context; keep no_grad for the forward
    # passes only.
    enc = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_overflowing_tokens=True,
        stride=stride,
        padding=True,
        return_tensors="pt",
    )
    # Drop helper outputs (e.g. overflow_to_sample_mapping) the model rejects.
    allowed = {"input_ids", "attention_mask", "token_type_ids"}
    inputs = {k: v.to(model.device) for k, v in enc.items() if k in allowed}

    logits_list = []
    with torch.no_grad():
        # One chunk per forward pass keeps peak memory flat for long docs.
        for i in range(inputs["input_ids"].size(0)):
            batch = {k: v[i:i + 1] for k, v in inputs.items()}
            logits_list.append(model(**batch).logits)
    logits = torch.cat(logits_list, dim=0)  # [num_chunks, num_labels]
    probs = torch.softmax(logits, dim=-1).cpu().numpy()

    num_chunks = int(probs.shape[0])
    doc_probs = probs.mean(axis=0) if agg == "mean" else probs.max(axis=0)
    pred_id = int(np.argmax(doc_probs))

    # FIX: id2label keys can be strings when a config round-trips through
    # JSON; normalize to int so the lookup below doesn't silently miss.
    raw_id2label = getattr(model.config, "id2label", None) or {0: "LABEL_0", 1: "LABEL_1"}
    id2label = {int(k): v for k, v in raw_id2label.items()}
    label = id2label.get(pred_id, str(pred_id))
    score = float(doc_probs[pred_id])
    all_scores = {id2label.get(i, str(i)): float(doc_probs[i]) for i in range(len(doc_probs))}

    return {
        "label": label,
        "score": round(score, 6),
        "all_scores": all_scores,
        "num_chunks": num_chunks,
        "tokens_per_chunk": max_length,
        "stride": stride,
        "model": MODEL_ID,
        "device": str(device),
    }


def predict_from_upload(file, aggregation, max_length, stride):
    """Gradio callback: read the uploaded document and classify it.

    Args:
        file: The upload — a plain filepath string, a tempfile-like object
            with a ``name`` attribute, or ``None``.
        aggregation: ``"mean"`` or ``"max"`` chunk aggregation.
        max_length: Tokens per chunk (slider value, may arrive as float).
        stride: Chunk overlap (slider value, may arrive as float).

    Returns:
        The dict produced by :func:`chunked_predict`, or an error dict.
    """
    if file is None:
        return {"error": "Please upload a file."}

    # FIX: modern gr.File may hand the callback a plain path string, which
    # the original code passed straight to read_text_from_file and crashed
    # on .read(). Resolve a path from either form, then re-read from disk to
    # work around gradio temp-file behavior.
    path = file if isinstance(file, str) else getattr(file, "name", None)
    if isinstance(path, str) and os.path.exists(path):
        with open(path, "rb") as f:
            raw_bytes = f.read()
        mem = io.BytesIO(raw_bytes)
        # Preserve the extension so read_text_from_file can dispatch on it.
        mem.name = os.path.basename(path)
        text = read_text_from_file(mem)
    else:
        text = read_text_from_file(file)

    return chunked_predict(text, max_length=int(max_length), stride=int(stride), agg=aggregation)


# -----------------------------
# Gradio UI
# -----------------------------
DESCRIPTION = """
## Binary Document Classifier (Chunked)
Upload a document (TXT/MD/HTML/PDF) and get a **document-level prediction**.

Long files are **split into overlapping 512-token chunks**, each chunk is classified, and probabilities are **aggregated** (mean or max).

**Tip:** This Space expects a binary classifier with two labels in the loaded checkpoint.
"""

with gr.Blocks(title="Binary Document Classifier") as demo:
    gr.Markdown(DESCRIPTION)
    file_in = gr.File(
        label="Upload a document",
        file_types=[".txt", ".md", ".rtf", ".html", ".htm", ".pdf"],
    )
    aggregation = gr.Radio(choices=["mean", "max"], value="mean", label="Aggregation over chunks")
    with gr.Accordion("Advanced", open=False):
        max_len_in = gr.Slider(128, 1024, value=MAX_LENGTH, step=32, label="Tokens per chunk (max_length)")
        stride_in = gr.Slider(0, 512, value=STRIDE, step=16, label="Stride / overlap")
    btn = gr.Button("Predict")
    out_json = gr.JSON(label="Prediction")

    btn.click(
        fn=predict_from_upload,
        inputs=[file_in, aggregation, max_len_in, stride_in],
        outputs=[out_json],
        api_name="predict",
    )

if __name__ == "__main__":
    demo.launch()