tomerz14 commited on
Commit
4cf9509
·
verified ·
1 Parent(s): 4c97017

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -62
app.py CHANGED
@@ -1,48 +1,44 @@
1
  #!/usr/bin/env python
2
  # -*- coding: utf-8 -*-
3
  """
4
- Gradio App — Binary Text Classifier (Chunked Inference)
5
- -------------------------------------------------------
6
- - Users upload a document file (txt, md, html, pdf*), we read the text, chunk if needed,
7
- and return a prediction with probability.
8
- - Designed for Hugging Face Spaces.
9
-
10
- * For PDFs, this app uses a simple text extraction via pypdf. For production-quality
11
- extraction, consider using `pymupdf` (fitz) or `pdfminer.six`.
 
12
  """
13
 
14
  import os
15
  import io
16
  import re
17
- from typing import Dict, Any
18
 
19
  import numpy as np
20
  import torch
21
  import gradio as gr
22
-
23
- from transformers import (
24
- AutoTokenizer,
25
- AutoModelForSequenceClassification,
26
- )
27
 
28
  # -----------------------------
29
  # Config
30
  # -----------------------------
31
- MODEL_ID = os.getenv("MODEL_ID", "bert-base-uncased") # e.g., "tomerz14/human-vs-AI_bert-classifier"
32
  MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
33
  STRIDE = int(os.getenv("STRIDE", "128"))
34
 
35
- # Device selection (CPU by default on Spaces)
36
  device = torch.device("cuda" if torch.cuda.is_available() else
37
  "mps" if torch.backends.mps.is_available() else "cpu")
38
-
39
  if device.type == "mps":
40
  try:
41
  torch.set_float32_matmul_precision("high")
42
  except Exception:
43
  pass
44
 
45
- # Load model & tokenizer at startup
46
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
47
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, torch_dtype=torch.float32).to(device)
48
  model.eval()
@@ -50,7 +46,6 @@ model.eval()
50
  # -----------------------------
51
  # Utilities
52
  # -----------------------------
53
-
54
  TEXT_EXTS = {".txt", ".md", ".rtf", ".html", ".htm"}
55
  PDF_EXTS = {".pdf"}
56
 
@@ -87,7 +82,7 @@ def read_text_from_file(file_obj) -> str:
87
  except Exception as e:
88
  return f"[PDF parse error] {e}"
89
 
90
- # Fallback: try to treat as text
91
  data = file_obj.read()
92
  if isinstance(data, bytes):
93
  data = data.decode("utf-8", errors="ignore")
@@ -96,8 +91,8 @@ def read_text_from_file(file_obj) -> str:
96
 
97
  def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: str = "mean") -> Dict[str, Any]:
98
  """
99
- Chunk the document using tokenizer overflow, run the classifier on each chunk,
100
- and aggregate probabilities (mean or max).
101
  """
102
  if not text or not text.strip():
103
  return {"error": "Empty document."}
@@ -122,27 +117,32 @@ def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: st
122
  out = model(**batch)
123
  logits_list.append(out.logits)
124
 
125
- logits = torch.cat(logits_list, dim=0) # [num_chunks, num_labels]
126
- probs = torch.softmax(logits, dim=-1).cpu().numpy()
127
  num_chunks = int(probs.shape[0])
128
 
129
- doc_probs = probs.mean(axis=0) if agg == "mean" else probs.max(axis=0)
 
 
 
 
 
 
 
 
130
 
131
- pred_id = int(np.argmax(doc_probs))
132
- id2label = getattr(model.config, "id2label", {0: "LABEL_0", 1: "LABEL_1"})
133
- label = id2label.get(pred_id, str(pred_id))
134
- score = float(doc_probs[pred_id])
135
- all_scores = {id2label.get(i, str(i)): float(doc_probs[i]) for i in range(len(doc_probs))}
136
 
137
  return {
138
- "label": label,
139
- "score": round(score, 6),
140
- "all_scores": all_scores,
141
  "num_chunks": num_chunks,
142
- "tokens_per_chunk": max_length,
 
143
  "stride": stride,
144
- "model": MODEL_ID,
145
- "device": str(device),
146
  }
147
 
148
 
@@ -153,10 +153,9 @@ def predict_from_upload(file, aggregation, max_length, stride):
153
  # Work around gradio temp file behavior
154
  if hasattr(file, "name") and isinstance(file.name, str):
155
  with open(file.name, "rb") as f:
156
- raw_bytes = f.read()
157
- mem = io.BytesIO(raw_bytes)
158
- mem.name = os.path.basename(file.name)
159
- text = read_text_from_file(mem)
160
  else:
161
  text = read_text_from_file(file)
162
 
@@ -164,36 +163,109 @@ def predict_from_upload(file, aggregation, max_length, stride):
164
 
165
 
166
  # -----------------------------
167
- # Gradio UI
168
  # -----------------------------
169
- DESCRIPTION = """
170
- ## Binary Document Classifier (Chunked)
171
- Upload a document (TXT/MD/HTML/PDF) and get a **document-level prediction**.
172
- Long files are **split into overlapping 512-token chunks**, each chunk is classified,
173
- and probabilities are **aggregated** (mean or max).
 
 
 
 
 
 
174
 
175
- **Tip:** This Space expects a binary classifier with two labels in the loaded checkpoint.
176
- """
 
 
177
 
178
- with gr.Blocks(title="Binary Document Classifier") as demo:
179
- gr.Markdown(DESCRIPTION)
 
 
180
 
181
- file_in = gr.File(label="Upload a document", file_types=[".txt", ".md", ".rtf", ".html", ".htm", ".pdf"])
182
- aggregation = gr.Radio(choices=["mean", "max"], value="mean", label="Aggregation over chunks")
183
 
184
- with gr.Accordion("Advanced", open=False):
185
- max_len_in = gr.Slider(128, 1024, value=MAX_LENGTH, step=32, label="Tokens per chunk (max_length)")
186
- stride_in = gr.Slider(0, 512, value=STRIDE, step=16, label="Stride / overlap")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
- btn = gr.Button("Predict")
189
- out_json = gr.JSON(label="Prediction")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  btn.click(
192
- fn=predict_from_upload,
193
- inputs=[file_in, aggregation, max_len_in, stride_in],
194
- outputs=[out_json],
195
- api_name="predict",
196
  )
197
 
198
  if __name__ == "__main__":
199
- demo.launch()
 
1
  #!/usr/bin/env python
2
  # -*- coding: utf-8 -*-
3
  """
4
+ Gradio App — AI vs Human Document Classifier (Chunked Inference)
5
+ ----------------------------------------------------------------
6
+ Features:
7
+ - Upload a document (TXT/MD/HTML/PDF), chunk if needed, classify each chunk, aggregate to document.
8
+ - Shows:
9
+ 1) Probability bars with raw numbers (AI generated / Human written)
10
+ 2) Confidence badge ("Likely AI" / "Likely Human") with traffic-light color
11
+ 3) Tabs for Basic / Advanced controls
12
+ 4) Chunk details accordion with per-chunk probabilities
13
  """
14
 
15
  import os
16
  import io
17
  import re
18
+ from typing import Dict, Any, List, Tuple
19
 
20
  import numpy as np
21
  import torch
22
  import gradio as gr
23
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
 
 
24
 
25
  # -----------------------------
26
  # Config
27
  # -----------------------------
28
+ MODEL_ID = os.getenv("MODEL_ID", "bert-base-uncased") # e.g., "username/bert-binclass"
29
  MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
30
  STRIDE = int(os.getenv("STRIDE", "128"))
31
 
32
+ # Device
33
  device = torch.device("cuda" if torch.cuda.is_available() else
34
  "mps" if torch.backends.mps.is_available() else "cpu")
 
35
  if device.type == "mps":
36
  try:
37
  torch.set_float32_matmul_precision("high")
38
  except Exception:
39
  pass
40
 
41
+ # Load model & tokenizer
42
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
43
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, torch_dtype=torch.float32).to(device)
44
  model.eval()
 
46
  # -----------------------------
47
  # Utilities
48
  # -----------------------------
 
49
  TEXT_EXTS = {".txt", ".md", ".rtf", ".html", ".htm"}
50
  PDF_EXTS = {".pdf"}
51
 
 
82
  except Exception as e:
83
  return f"[PDF parse error] {e}"
84
 
85
+ # Fallback: try as text
86
  data = file_obj.read()
87
  if isinstance(data, bytes):
88
  data = data.decode("utf-8", errors="ignore")
 
91
 
92
  def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: str = "mean") -> Dict[str, Any]:
93
  """
94
+ Chunk the document using tokenizer overflow, run classifier on each chunk,
95
+ aggregate probabilities, and return both doc-level and chunk-level results.
96
  """
97
  if not text or not text.strip():
98
  return {"error": "Empty document."}
 
117
  out = model(**batch)
118
  logits_list.append(out.logits)
119
 
120
+ logits = torch.cat(logits_list, dim=0) # [num_chunks, num_labels]
121
+ probs = torch.softmax(logits, dim=-1).cpu().numpy()
122
  num_chunks = int(probs.shape[0])
123
 
124
+ # Aggregate
125
+ if agg == "max":
126
+ doc_probs = probs.max(axis=0)
127
+ else:
128
+ doc_probs = probs.mean(axis=0)
129
+
130
+ # By convention: 0 -> Human, 1 -> AI
131
+ prob_human = float(doc_probs[0])
132
+ prob_ai = float(doc_probs[1])
133
 
134
+ # Per-chunk table rows
135
+ chunk_rows = []
136
+ for i, p in enumerate(probs):
137
+ chunk_rows.append([i + 1, float(p[1]), float(p[0])]) # [chunk, AI, Human]
 
138
 
139
  return {
140
+ "ai_prob": prob_ai,
141
+ "human_prob": prob_human,
 
142
  "num_chunks": num_chunks,
143
+ "chunk_rows": chunk_rows, # list of [chunk, AI, Human]
144
+ "max_length": max_length,
145
  "stride": stride,
 
 
146
  }
147
 
148
 
 
153
  # Work around gradio temp file behavior
154
  if hasattr(file, "name") and isinstance(file.name, str):
155
  with open(file.name, "rb") as f:
156
+ raw = io.BytesIO(f.read())
157
+ raw.name = os.path.basename(file.name)
158
+ text = read_text_from_file(raw)
 
159
  else:
160
  text = read_text_from_file(file)
161
 
 
163
 
164
 
165
  # -----------------------------
166
+ # UI Helpers (HTML formatting)
167
  # -----------------------------
168
+ def probability_bar_html(label: str, prob: float) -> str:
169
+ """Return an HTML row with label, percent, and a bar."""
170
+ pct = prob * 100.0
171
+ return f"""
172
+ <div class="prob-row"><div class="prob-label"><b>{label}</b></div>
173
+ <div class="prob-value">{pct:.2f}%</div>
174
+ <div class="prob-bar">
175
+ <div class="prob-fill" style="width:{pct:.2f}%"></div>
176
+ </div>
177
+ </div>
178
+ """
179
 
180
+ def verdict_badge_html(prob_ai: float, threshold: float = 0.5) -> str:
181
+ label = "Likely AI" if prob_ai >= threshold else "Likely Human"
182
+ color = "#ef4444" if prob_ai >= threshold else "#10b981" # red / green
183
+ return f"<span class='pill' style='background:{color}22;color:{color}'>{label}</span>"
184
 
185
+ def format_outputs(result: Dict[str, Any], threshold: float = 0.5):
186
+ """Produce (verdict_html, probs_html, chunk_table_data, details_md)."""
187
+ if "error" in result:
188
+ return f"<span style='color:#ef4444'>{result['error']}</span>", "", [], ""
189
 
190
+ ai, human = result["ai_prob"], result["human_prob"]
191
+ verdict_html = verdict_badge_html(ai, threshold=threshold)
192
 
193
+ probs_html = ""
194
+ probs_html += probability_bar_html("AI generated", ai)
195
+ probs_html += probability_bar_html("Human written", human)
196
+
197
+ # Chunk table rows
198
+ table_data = result["chunk_rows"]
199
+
200
+ details_md = (
201
+ f"**Chunks:** `{result['num_chunks']}` \n"
202
+ f"**Tokens per chunk:** `{result['max_length']}` \n"
203
+ f"**Stride:** `{result['stride']}`"
204
+ )
205
+
206
+ return verdict_html, probs_html, table_data, details_md
207
+
208
+
209
+ # -----------------------------
210
+ # Gradio Interface
211
+ # -----------------------------
212
+ CSS = """
213
+ .pill {padding:6px 12px; border-radius:999px; display:inline-block; margin: 6px 0; font-weight:600;}
214
+ .prob-row {display:flex; align-items:center; gap:10px; margin:6px 0;}
215
+ .prob-label {min-width:140px;}
216
+ .prob-value {min-width:80px; text-align:right; font-variant-numeric: tabular-nums;}
217
+ .prob-bar {flex:1; background:#e5e7eb; height:12px; border-radius:6px; overflow:hidden;}
218
+ .prob-fill {height:12px; background:#6366f1;}
219
+ .small-note {font-size:0.9rem; color:#6b7280;}
220
+ """
221
+
222
+ DESCRIPTION = """
223
+ ### 🔎 AI vs Human — Document Classifier
224
+ Upload a file to get **document-level probabilities**.
225
+ Long inputs are **chunked** into overlapping windows; chunk predictions are **aggregated**.
226
+ """
227
+
228
+ with gr.Blocks(
229
+ title="AI vs Human Document Classifier",
230
+ theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
231
+ css=CSS
232
+ ) as demo:
233
+ gr.Markdown(DESCRIPTION)
234
 
235
+ with gr.Tabs():
236
+ with gr.Tab("Predict"):
237
+ file_in = gr.File(label="Upload a document", file_types=[".txt", ".md", ".rtf", ".html", ".htm", ".pdf"])
238
+ agg_in = gr.Radio(choices=["mean", "max"], value="mean", label="Aggregation over chunks")
239
+ btn = gr.Button("Predict", variant="primary")
240
+ verdict_html = gr.HTML(label="Verdict")
241
+ probs_html = gr.HTML(label="Probabilities")
242
+
243
+ with gr.Accordion("Chunk details", open=False):
244
+ chunk_table = gr.Dataframe(
245
+ headers=["Chunk", "AI generated", "Human written"],
246
+ datatype=["number", "number", "number"],
247
+ label="Per-chunk probabilities",
248
+ wrap=True,
249
+ interactive=False,
250
+ height=240
251
+ )
252
+ details_md = gr.Markdown("", elem_classes=["small-note"])
253
+
254
+ with gr.Tab("Advanced"):
255
+ gr.Markdown("Adjust chunking parameters below.")
256
+ max_len_in = gr.Slider(128, 1024, value=MAX_LENGTH, step=32, label="Tokens per chunk (max_length)")
257
+ stride_in = gr.Slider(0, 512, value=STRIDE, step=16, label="Stride / overlap")
258
+ gr.Markdown("You can also set `MODEL_ID`, `MAX_LENGTH`, and `STRIDE` via Space Variables.")
259
+
260
+ def predict_and_prettify(file, aggregation, max_length=MAX_LENGTH, stride=STRIDE):
261
+ res = predict_from_upload(file, aggregation, max_length, stride)
262
+ return format_outputs(res)
263
 
264
  btn.click(
265
+ fn=predict_and_prettify,
266
+ inputs=[file_in, agg_in, max_len_in, stride_in],
267
+ outputs=[verdict_html, probs_html, chunk_table, details_md],
 
268
  )
269
 
270
  if __name__ == "__main__":
271
+ demo.launch()