tomerz14 committed on
Commit
c6f65f2
·
verified ·
1 Parent(s): 220b4b1

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +27 -0
  2. app.py +199 -0
  3. requirements.txt +6 -0
README.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Binary Document Classifier — Gradio Space
2
+
3
+ This Space hosts a Gradio app for **binary text classification** on uploaded documents.
4
+ It supports long documents by **chunking** (512-token windows with overlap) and aggregates
5
+ chunk probabilities into a **document-level** prediction.
6
+
7
+ ## Configure
8
+
9
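+ Set the environment variable `MODEL_ID` in your Space to point to your trained model (the default, `bert-base-uncased`, has no fine-tuned classification head, so its predictions are not meaningful),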
10
+ e.g. `your-username/bert-binclass`. You can also set:
11
+
12
+ - `MAX_LENGTH` — tokens per chunk (default: 512)
13
+ - `STRIDE` — overlap tokens between chunks (default: 128)
14
+
15
+ ## Run locally
16
+
17
+ ```bash
18
+ pip install -r requirements.txt
19
+ python app.py
20
+ ```
21
+
22
+ Then open the printed Gradio URL.
23
+
24
+ ## Notes
25
+
26
+ - PDF extraction uses `pypdf` for simplicity. For higher-quality results or OCR,
27
+ consider `pymupdf` (fitz) or `unstructured`.
app.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Gradio App — Binary Text Classifier (Chunked Inference)
5
+ -------------------------------------------------------
6
+ - Users upload a document file (txt, md, html, pdf*), we read the text, chunk if needed,
7
+ and return a prediction with probability.
8
+ - Designed for Hugging Face Spaces.
9
+
10
+ * For PDFs, this app uses a simple text extraction via pypdf. For production-quality
11
+ extraction, consider using `pymupdf` (fitz) or `pdfminer.six`.
12
+ """
13
+
14
+ import os
15
+ import io
16
+ import re
17
+ from typing import Dict, Any
18
+
19
+ import numpy as np
20
+ import torch
21
+ import gradio as gr
22
+
23
+ from transformers import (
24
+ AutoTokenizer,
25
+ AutoModelForSequenceClassification,
26
+ )
27
+
28
+ # -----------------------------
29
+ # Config
30
+ # -----------------------------
31
# Runtime settings; each one can be overridden through an environment variable.
MODEL_ID = os.getenv("MODEL_ID", "bert-base-uncased")  # e.g., "tomerz14/human-vs-AI_bert-classifier"
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))  # tokens per chunk
STRIDE = int(os.getenv("STRIDE", "128"))  # overlapping tokens between consecutive chunks

# Backend preference: CUDA first, then Apple MPS, otherwise plain CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

if device.type == "mps":
    # Best-effort tuning only: a failure here must not stop the app from starting.
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

# The tokenizer and classifier are loaded once, at import time, and shared by
# every request. The model is moved to the selected device and put in eval mode.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
)
model = model.to(device)
model.eval()
49
+
50
+ # -----------------------------
51
+ # Utilities
52
+ # -----------------------------
53
+
54
# Extensions that are read as (possibly marked-up) text, and extensions parsed as PDF.
# NOTE: ".rtf" is read as raw text; RTF control words are not stripped.
TEXT_EXTS = {".txt", ".md", ".rtf", ".html", ".htm"}
PDF_EXTS = {".pdf"}


def _decode_text(data) -> str:
    """Return *data* as ``str``; bytes are decoded as UTF-8 without a leading BOM."""
    if isinstance(data, bytes):
        # "utf-8-sig" drops a leading byte-order mark that plain "utf-8" would
        # keep as U+FEFF; undecodable bytes are skipped rather than raising.
        return data.decode("utf-8-sig", errors="ignore")
    return data


def _strip_html(markup: str) -> str:
    """Reduce HTML markup to its visible text, separated by single spaces."""
    import html  # local import: only needed on the HTML path

    # Remove <script>/<style> elements together with their contents, so code
    # and CSS are not classified as document text.
    markup = re.sub(r"(?is)<(script|style)\b.*?</\1\s*>", " ", markup)
    # Replace every remaining tag with a space.
    markup = re.sub(r"<[^>]+>", " ", markup)
    # Decode entities only after tag removal, so an escaped "&lt;b&gt;" is kept
    # as literal text instead of being stripped as a tag.
    markup = html.unescape(markup)
    return re.sub(r"\s+", " ", markup).strip()


def read_text_from_file(file_obj) -> str:
    """
    Read text content from an uploaded file.

    Supports: .txt, .md, .rtf, .html, .htm, .pdf (via pypdf). The format is
    chosen from the extension of ``file_obj.name``; a file with a missing or
    unknown extension is treated as plain text.

    Args:
        file_obj: File-like object exposing ``read()`` and, optionally, ``name``.

    Returns:
        The extracted text. HTML is reduced to visible text and PDF text is
        whitespace-normalised; other formats are returned as decoded.
        If PDF parsing fails, the string ``"[PDF parse error] <reason>"`` is
        returned instead of raising.
    """
    name = getattr(file_obj, "name", "") or ""
    ext = os.path.splitext(name)[-1].lower()

    if ext in PDF_EXTS:
        try:
            from pypdf import PdfReader

            reader = PdfReader(io.BytesIO(file_obj.read()))
            pages = []
            for page in reader.pages:
                # One unreadable page must not discard the rest of the document.
                try:
                    pages.append(page.extract_text() or "")
                except Exception:
                    pages.append("")
            return re.sub(r"\s+", " ", "\n".join(pages)).strip()
        except Exception as e:
            return f"[PDF parse error] {e}"

    # Known text extensions and unknown extensions share the same read/decode path.
    text = _decode_text(file_obj.read())
    if ext in {".html", ".htm"}:
        text = _strip_html(text)
    return text
95
+
96
+
97
def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: str = "mean") -> Dict[str, Any]:
    """
    Chunk the document using tokenizer overflow, run the classifier on each chunk,
    and aggregate probabilities (mean or max).

    Args:
        text: Document text to classify.
        max_length: Requested tokens per chunk (special tokens included). It is
            lowered to the checkpoint's own limit when that limit is smaller.
        stride: Number of tokens shared by consecutive chunks. Must be
            non-negative and smaller than the usable chunk length.
        agg: ``"mean"`` averages the per-chunk probabilities; any other value
            takes the per-class maximum over chunks (those scores need not sum to 1).

    Returns:
        On success, a dict with ``label``, ``score``, ``all_scores``,
        ``num_chunks``, ``tokens_per_chunk`` (the chunk length actually used),
        ``stride``, ``model`` and ``device``. For empty text or invalid
        chunking parameters, ``{"error": <message>}``.
    """
    if not text or not text.strip():
        return {"error": "Empty document."}

    # Cheap sanity check before touching the tokenizer: a stride that is not
    # smaller than the chunk length can never produce advancing windows.
    if max_length < 1 or stride < 0 or stride >= max_length:
        return {
            "error": (
                f"Invalid chunking parameters: stride ({stride}) must be >= 0 and "
                f"smaller than max_length ({max_length})."
            )
        }

    # Never request chunks longer than the checkpoint supports, otherwise the
    # position embeddings are indexed out of range. Tokenizers without a
    # configured limit report a very large sentinel value, which is ignored here.
    limits = [
        limit
        for limit in (
            getattr(tokenizer, "model_max_length", None),
            getattr(model.config, "max_position_embeddings", None),
        )
        if isinstance(limit, int) and 0 < limit < 1_000_000
    ]
    if limits:
        max_length = min(max_length, *limits)

    # Special tokens ([CLS], [SEP], ...) consume part of every chunk; the stride
    # has to be smaller than what is left for real content.
    usable = max_length - tokenizer.num_special_tokens_to_add(pair=False)
    if stride >= usable:
        return {
            "error": (
                f"Invalid chunking parameters: stride ({stride}) must be smaller than "
                f"the usable chunk length ({usable} tokens for max_length={max_length})."
            )
        }

    with torch.no_grad():
        enc = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            return_overflowing_tokens=True,
            stride=stride,
            padding=True,
            return_tensors="pt",
        )

        # The encoding also carries bookkeeping entries (e.g. the overflow
        # mapping); only tensors the model accepts are forwarded.
        allowed = {"input_ids", "attention_mask", "token_type_ids"}
        inputs = {k: v.to(model.device) for k, v in enc.items() if k in allowed}

        # One chunk per forward pass keeps peak memory independent of document length.
        logits_list = []
        for i in range(inputs["input_ids"].size(0)):
            batch = {k: v[i:i + 1] for k, v in inputs.items()}
            logits_list.append(model(**batch).logits)

        logits = torch.cat(logits_list, dim=0)  # [num_chunks, num_labels]
        probs = torch.softmax(logits, dim=-1).cpu().numpy()

    num_chunks = int(probs.shape[0])
    doc_probs = probs.mean(axis=0) if agg == "mean" else probs.max(axis=0)

    pred_id = int(np.argmax(doc_probs))
    id2label = getattr(model.config, "id2label", {0: "LABEL_0", 1: "LABEL_1"})
    label = id2label.get(pred_id, str(pred_id))
    score = float(doc_probs[pred_id])
    all_scores = {id2label.get(i, str(i)): float(p) for i, p in enumerate(doc_probs)}

    return {
        "label": label,
        "score": round(score, 6),
        "all_scores": all_scores,
        "num_chunks": num_chunks,
        "tokens_per_chunk": max_length,
        "stride": stride,
        "model": MODEL_ID,
        "device": str(device),
    }
147
+
148
+
149
def predict_from_upload(file, aggregation, max_length, stride):
    """
    Gradio callback: read the uploaded file and classify its text.

    Args:
        file: The upload, given as a filesystem path (``str`` / ``os.PathLike``),
            an object whose ``name`` attribute is such a path, or a readable
            file-like object. ``None`` when nothing was uploaded.
        aggregation: Chunk aggregation mode forwarded to ``chunked_predict``.
        max_length: Tokens per chunk; converted with ``int()``.
        stride: Overlap between chunks; converted with ``int()``.

    Returns:
        The result of ``chunked_predict``, or ``{"error": <message>}`` when no
        file was given, the file cannot be read, or PDF parsing failed.
    """
    if file is None:
        return {"error": "Please upload a file."}

    # Resolve the upload to a path when possible; plain path strings are
    # accepted as well as objects that expose the path through ``.name``.
    if isinstance(file, (str, os.PathLike)):
        path = os.fsdecode(file)
    else:
        candidate = getattr(file, "name", None)
        path = candidate if isinstance(candidate, str) else None

    if path is not None:
        # Work around gradio temp file behavior: re-open the file by path and
        # hand the reader an in-memory copy that keeps the original file name.
        try:
            with open(path, "rb") as f:
                raw_bytes = f.read()
        except OSError as e:
            return {"error": f"Could not read uploaded file: {e}"}
        source = io.BytesIO(raw_bytes)
        source.name = os.path.basename(path)
    else:
        source = file

    text = read_text_from_file(source)

    # The reader reports PDF failures as a marker string; surface it as an
    # error instead of classifying the error message as if it were the document.
    ext = os.path.splitext(str(getattr(source, "name", "") or ""))[-1].lower()
    if ext == ".pdf" and isinstance(text, str) and text.startswith("[PDF parse error]"):
        return {"error": text}

    return chunked_predict(text, max_length=int(max_length), stride=int(stride), agg=aggregation)
164
+
165
+
166
+ # -----------------------------
167
+ # Gradio UI
168
+ # -----------------------------
169
# Markdown shown at the top of the page. The chunk size is taken from the
# configured default (MAX_LENGTH) rather than a hard-coded number, so the text
# stays correct when the Space overrides it.
DESCRIPTION = f"""
## Binary Document Classifier (Chunked)
Upload a document (TXT/MD/HTML/PDF) and get a **document-level prediction**.
Long files are **split into overlapping chunks** ({MAX_LENGTH} tokens by default), each chunk is classified,
and probabilities are **aggregated** (mean or max).

**Tip:** This Space expects a binary classifier with two labels in the loaded checkpoint.
"""

with gr.Blocks(title="Binary Document Classifier") as demo:
    gr.Markdown(DESCRIPTION)

    # Accepted extensions come from the reader's own tables, so the upload
    # widget and the parser cannot drift apart.
    file_in = gr.File(label="Upload a document", file_types=sorted(TEXT_EXTS | PDF_EXTS))
    aggregation = gr.Radio(choices=["mean", "max"], value="mean", label="Aggregation over chunks")

    with gr.Accordion("Advanced", open=False):
        max_len_in = gr.Slider(128, 1024, value=MAX_LENGTH, step=32, label="Tokens per chunk (max_length)")
        stride_in = gr.Slider(0, 512, value=STRIDE, step=16, label="Stride / overlap")

    btn = gr.Button("Predict")
    out_json = gr.JSON(label="Prediction")

    # Input order must match the parameter order of predict_from_upload.
    btn.click(
        fn=predict_from_upload,
        inputs=[file_in, aggregation, max_len_in, stride_in],
        outputs=[out_json],
        api_name="predict",
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ transformers>=4.44
2
+ torch
3
+ evaluate>=0.4.0
4
+ datasets>=2.20
5
+ gradio>=4.0
6
+ pypdf>=4.0