tomerz14's picture
Upload 3 files
c6f65f2 verified
raw
history blame
6.6 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Gradio App — Binary Text Classifier (Chunked Inference)
-------------------------------------------------------
- Users upload a document file (txt, md, rtf, html, pdf*); the app reads the text, splits it
into overlapping chunks when needed, and returns a prediction with its probability.
- Designed for Hugging Face Spaces.
* For PDFs, this app uses a simple text extraction via pypdf. For production-quality
extraction, consider using `pymupdf` (fitz) or `pdfminer.six`.
"""
import os
import io
import re
from typing import Dict, Any
import numpy as np
import torch
import gradio as gr
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
)
# -----------------------------
# Config
# -----------------------------
# Each setting can be overridden with an environment variable of the same name.
# Hub checkpoint to load, e.g. "tomerz14/human-vs-AI_bert-classifier".
MODEL_ID = os.environ.get("MODEL_ID", "bert-base-uncased")
# Default tokens per chunk and default token overlap between consecutive chunks.
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "512"))
STRIDE = int(os.environ.get("STRIDE", "128"))
# Device selection: prefer CUDA, then Apple MPS, otherwise CPU (CPU by default on Spaces).
# torch.backends.mps only exists on newer torch builds (added in PyTorch 1.12), so it is
# looked up defensively instead of being accessed as a bare attribute.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

if device.type == "mps":
    try:
        torch.set_float32_matmul_precision("high")
    except (AttributeError, RuntimeError):
        # Best effort only: builds without this setting simply keep their default precision.
        pass
# Load the tokenizer and classifier once, at import time, so every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
)
model = model.to(device)
# Evaluation mode (disables dropout): this app only ever runs inference.
model.eval()
# -----------------------------
# Utilities
# -----------------------------
# Extensions read as plain text (unknown non-PDF extensions are read the same way, best effort).
TEXT_EXTS = {".txt", ".md", ".rtf", ".html", ".htm"}
PDF_EXTS = {".pdf"}
HTML_EXTS = {".html", ".htm"}

# Precompiled patterns shared by the text readers below.
_WHITESPACE_RE = re.compile(r"\s+")
_TAG_RE = re.compile(r"<[^>]+>")
# <script>/<style> bodies are code, not document text, so they are removed whole.
_SCRIPT_STYLE_RE = re.compile(
    r"<(script|style)\b[^>]*>.*?</\1\s*>", re.IGNORECASE | re.DOTALL
)


def _decode_text(data) -> str:
    """Return *data* as ``str``; bytes are decoded as UTF-8, dropping a leading BOM."""
    if isinstance(data, bytes):
        return data.decode("utf-8-sig", errors="ignore")
    return data


def _normalize_whitespace(text: str) -> str:
    """Collapse each whitespace run to a single space and trim both ends."""
    return _WHITESPACE_RE.sub(" ", text).strip()


def _html_to_text(markup: str) -> str:
    """Crude HTML-to-text: drop script/style bodies, strip tags, decode entities."""
    import html  # stdlib; imported locally so the module header stays unchanged

    markup = _SCRIPT_STYLE_RE.sub(" ", markup)
    markup = _TAG_RE.sub(" ", markup)
    # Decode entities only after tags are gone, so escaped text such as "&lt;b&gt;"
    # survives as literal text instead of being stripped as a tag.
    return html.unescape(markup)


def read_text_from_file(file_obj) -> str:
    """
    Read text content from an uploaded file.

    Supports: .txt, .md, .rtf, .html, .htm, .pdf (via pypdf). Any other extension is
    read as UTF-8 text. The extension is taken from ``file_obj.name`` when present.

    Args:
        file_obj: File-like object exposing ``read()`` and, optionally, a ``name``
            attribute holding the file name or path.

    Returns:
        The extracted text with every whitespace run collapsed to one space. If PDF
        parsing fails (including when pypdf is not installed), a string starting with
        ``"[PDF parse error]"`` is returned instead of raising.

    Note:
        .rtf files are read as plain text; RTF control words are not removed.
    """
    name = getattr(file_obj, "name", "") or ""
    if not isinstance(name, str):
        # e.g. a file opened from a descriptor reports an int name; no usable extension.
        name = ""
    ext = os.path.splitext(name)[-1].lower()

    if ext in PDF_EXTS:
        try:
            from pypdf import PdfReader

            reader = PdfReader(io.BytesIO(file_obj.read()))
            pages = []
            for page in reader.pages:
                try:
                    pages.append(page.extract_text() or "")
                except Exception:
                    # Best effort: one unreadable page should not sink the whole document.
                    pages.append("")
        except Exception as e:
            # Deliberate: callers get a marker string rather than an exception.
            return f"[PDF parse error] {e}"
        return _normalize_whitespace("\n".join(pages))

    text = _decode_text(file_obj.read())
    if ext in HTML_EXTS:
        text = _html_to_text(text)
    # TEXT_EXTS and unknown extensions share the same normalization as the PDF path.
    return _normalize_whitespace(text)
def _clamp_window(max_length: int, stride: int):
    """Clamp chunk size and overlap to what the loaded tokenizer/model accept.

    Returns an ``(max_length, stride)`` pair of ints.
    """
    limit = getattr(tokenizer, "model_max_length", None)
    # Tokenizers without a configured limit report a huge sentinel value (~1e30).
    if not isinstance(limit, int) or limit <= 0 or limit > 100_000:
        limit = getattr(model.config, "max_position_embeddings", None) or max_length
    max_length = max(1, min(int(max_length), int(limit)))
    # Overflow chunking requires the overlap to be strictly smaller than the number of
    # content tokens per window (max_length minus the special tokens the tokenizer adds).
    n_special = tokenizer.num_special_tokens_to_add(pair=False)
    stride = max(0, min(int(stride), max_length - n_special - 1))
    return max_length, stride


def chunked_predict(
    text: str,
    max_length: int = 512,
    stride: int = 128,
    agg: str = "mean",
    batch_size: int = 8,
) -> Dict[str, Any]:
    """
    Chunk the document using tokenizer overflow, run the classifier on each chunk,
    and aggregate probabilities (mean or max).

    Args:
        text: Full document text.
        max_length: Requested tokens per chunk; clamped to the checkpoint's limit.
        stride: Requested token overlap between consecutive chunks; clamped so the
            tokenizer accepts it.
        agg: "mean" averages per-label probabilities over chunks; any other value
            takes the per-label maximum (those scores need not sum to 1).
        batch_size: Number of chunks sent through the model per forward pass.

    Returns:
        ``{"error": "Empty document."}`` for blank text; otherwise a dict with the
        predicted ``label``, its ``score``, ``all_scores`` per label, ``num_chunks``,
        the effective ``tokens_per_chunk`` and ``stride``, plus ``model`` and ``device``.
    """
    if not text or not text.strip():
        return {"error": "Empty document."}

    max_length, stride = _clamp_window(max_length, stride)
    batch_size = max(1, int(batch_size))

    enc = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_overflowing_tokens=True,
        stride=stride,
        padding=True,
        return_tensors="pt",
    )
    # Forward only the tensors the model accepts; the tokenizer also returns overflow
    # bookkeeping that model(**batch) would reject.
    allowed = {"input_ids", "attention_mask", "token_type_ids"}
    inputs = {k: v.to(model.device) for k, v in enc.items() if k in allowed}

    num_chunks = int(inputs["input_ids"].size(0))
    logits_parts = []
    with torch.no_grad():
        for start in range(0, num_chunks, batch_size):
            batch = {k: v[start:start + batch_size] for k, v in inputs.items()}
            logits_parts.append(model(**batch).logits)
    logits = torch.cat(logits_parts, dim=0)  # [num_chunks, num_labels]
    probs = torch.softmax(logits, dim=-1).cpu().numpy()

    doc_probs = probs.mean(axis=0) if agg == "mean" else probs.max(axis=0)
    pred_id = int(np.argmax(doc_probs))
    id2label = getattr(model.config, "id2label", {0: "LABEL_0", 1: "LABEL_1"})
    label = id2label.get(pred_id, str(pred_id))
    score = float(doc_probs[pred_id])
    all_scores = {id2label.get(i, str(i)): float(p) for i, p in enumerate(doc_probs)}
    return {
        "label": label,
        "score": round(score, 6),
        "all_scores": all_scores,
        "num_chunks": num_chunks,
        "tokens_per_chunk": max_length,
        "stride": stride,
        "model": MODEL_ID,
        "device": str(device),
    }
def _upload_path(file):
    """Return the filesystem path behind a Gradio upload value, or None if it has none.

    Gradio 4+ passes ``gr.File`` values to handlers as path strings, while Gradio 3
    passed a temp-file wrapper whose ``.name`` attribute is the path.
    """
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    name = getattr(file, "name", None)
    return name if isinstance(name, str) else None


def predict_from_upload(file, aggregation, max_length, stride):
    """
    Gradio click handler: read the uploaded document and classify it.

    Args:
        file: Value from ``gr.File``. This is a path string, a temp-file wrapper with
            a ``.name`` path, or a readable file-like object.
        aggregation: "mean" or "max", forwarded to ``chunked_predict``.
        max_length: Tokens per chunk (slider value, cast to int).
        stride: Token overlap between chunks (slider value, cast to int).

    Returns:
        The ``chunked_predict`` result dict, or ``{"error": ...}`` when no file was
        given or the PDF could not be parsed.
    """
    if file is None or (isinstance(file, str) and not file):
        return {"error": "Please upload a file."}

    path = _upload_path(file)
    if path is not None:
        # Work around Gradio temp-file behavior: re-open the upload by path instead
        # of reading the object Gradio handed us.
        with open(path, "rb") as fh:
            text = read_text_from_file(fh)
    else:
        text = read_text_from_file(file)

    # read_text_from_file reports PDF failures as a marker string. Surface that as an
    # error instead of classifying the error message itself.
    if isinstance(text, str) and text.startswith("[PDF parse error]"):
        return {"error": text}

    return chunked_predict(text, max_length=int(max_length), stride=int(stride), agg=aggregation)
# -----------------------------
# Gradio UI
# -----------------------------
def _ui_max_tokens(cap: int = 1024, floor: int = 128) -> int:
    """Return the upper bound for the chunk-size slider.

    Uses the tokenizer's ``model_max_length`` (falling back to the model config's
    ``max_position_embeddings``), clamped to ``[floor, cap]``. This keeps the slider
    from offering windows larger than the loaded checkpoint can embed.
    """
    limit = getattr(tokenizer, "model_max_length", None)
    # Tokenizers without a configured limit report a huge sentinel value (~1e30).
    if not isinstance(limit, int) or limit <= 0 or limit > 100_000:
        limit = getattr(model.config, "max_position_embeddings", None) or cap
    return max(floor, min(cap, int(limit)))


UI_MAX_LENGTH = _ui_max_tokens()
# Initial slider position: the configured MAX_LENGTH, kept inside the slider's range.
UI_DEFAULT_LENGTH = max(128, min(MAX_LENGTH, UI_MAX_LENGTH))

DESCRIPTION = f"""
## Binary Document Classifier (Chunked)
Upload a document (TXT/MD/RTF/HTML/PDF) and get a **document-level prediction**.
Long files are **split into overlapping {UI_DEFAULT_LENGTH}-token chunks** (adjustable under *Advanced*), each chunk is classified,
and probabilities are **aggregated** (mean or max).
**Tip:** This Space expects a binary classifier with two labels in the loaded checkpoint.
"""

with gr.Blocks(title="Binary Document Classifier") as demo:
    gr.Markdown(DESCRIPTION)
    file_in = gr.File(label="Upload a document", file_types=[".txt", ".md", ".rtf", ".html", ".htm", ".pdf"])
    aggregation = gr.Radio(choices=["mean", "max"], value="mean", label="Aggregation over chunks")
    with gr.Accordion("Advanced", open=False):
        max_len_in = gr.Slider(128, UI_MAX_LENGTH, value=UI_DEFAULT_LENGTH, step=32, label="Tokens per chunk (max_length)")
        # NOTE: the stride must stay below the chunk size; the two sliders are not
        # cross-validated here.
        stride_in = gr.Slider(0, 512, value=STRIDE, step=16, label="Stride / overlap")
    btn = gr.Button("Predict")
    out_json = gr.JSON(label="Prediction")
    btn.click(
        fn=predict_from_upload,
        inputs=[file_in, aggregation, max_len_in, stride_in],
        outputs=[out_json],
        api_name="predict",
    )

if __name__ == "__main__":
    demo.launch()