# Scraped page header (not part of the program): "c-ho's picture" / "Update app.py" / commit 8e89350 (verified)
# =========================================================
# app.py (STABLE VERSION - FIXED TOKEN OVERFLOW)
# =========================================================
import gradio as gr
from transformers import pipeline
from pypdf import PdfReader
from pdf2image import convert_from_path
import pytesseract
import tempfile
# =========================================================
# Models
# =========================================================
# Maps the label shown in the UI dropdown to the model identifier that is
# handed to the classification pipeline.
MODELS = dict(
    (
        (
            "English model (ubffm/academic_text_classifier_en)",
            "ubffm/academic_text_classifier_en",
        ),
        (
            "German model (ubffm/academic_text_classifier_de)",
            "ubffm/academic_text_classifier_de",
        ),
    )
)

# Dropdown entry selected when the UI first loads; must be a key of MODELS.
DEFAULT_MODEL = "English model (ubffm/academic_text_classifier_en)"
# =========================================================
# Labels
# =========================================================
# Labels offered in the noise-selection checkbox group.
LABELS = [
    "OUT OF SCOPE",
    "MAIN TEXT",
    "EXAMPLE",
    "REFERENCE",
]

# Labels that are pre-ticked as noise (lines with these labels are removed).
DEFAULT_NOISE = [
    "OUT OF SCOPE",
    "REFERENCE",
]
# =========================================================
# Pipeline cache
# =========================================================
# Cache of already-constructed pipelines, keyed by model identifier, so each
# model is loaded at most once per process.
PIPELINES = {}


def get_classifier(model_display_name):
    """Return the text-classification pipeline for a dropdown entry.

    Args:
        model_display_name: A key of ``MODELS`` (the label shown in the UI).

    Returns:
        The pipeline for the corresponding model. It is built on first use
        and served from ``PIPELINES`` afterwards.

    Raises:
        KeyError: If ``model_display_name`` is not a key of ``MODELS``.
    """
    model_name = MODELS[model_display_name]
    if model_name not in PIPELINES:
        PIPELINES[model_name] = pipeline(
            "text-classification",
            model=model_name,
            tokenizer=model_name,
            # `top_k=None` is the supported way to request the score of every
            # label; it replaces the deprecated `return_all_scores=True`.
            top_k=None,
        )
    return PIPELINES[model_name]
# =========================================================
# Safe prediction (IMPORTANT FIX)
# =========================================================
def get_best_prediction(classifier, text, max_length=512):
    """Classify ``text`` and return the highest-scoring label.

    Args:
        classifier: Callable that takes the text plus tokenizer keyword
            arguments and returns either a list of ``{"label", "score"}``
            dicts or such a list wrapped in an outer list.
        text: The text to classify.
        max_length: Token limit forwarded to the tokenizer together with
            ``truncation=True``.

    Returns:
        A ``(best, scores)`` tuple: the dict with the highest ``"score"`` and
        the full list of per-label dicts.

    Raises:
        ValueError: If the classifier returns no scores.
    """
    # Cheap character cap that bounds tokenisation cost. It does NOT bound
    # the token count on its own, which is why truncation is requested below.
    text = text[:2000]

    # Token-level truncation is what actually keeps the input within the
    # model's limit, regardless of how many tokens the characters expand to.
    result = classifier(text, truncation=True, max_length=max_length)

    # Unwrap the outer list that is returned for a single input.
    if isinstance(result, list) and result and isinstance(result[0], list):
        result = result[0]

    return max(result, key=lambda x: x["score"]), result
# =========================================================
# SAFE CHUNKING (ROBUST FIX)
# =========================================================
def safe_chunk_text(text, tokenizer, max_tokens=480):
    """Split ``text`` into chunks of at most ``max_tokens`` tokens, in order.

    The text is split on newlines and whole lines are packed greedily into a
    chunk until the next line would exceed the budget. A single line longer
    than the budget is cut into ``max_tokens``-sized token windows and each
    window is decoded back to text as its own chunk.

    Chunks are emitted in document order: lines accumulated before an
    over-long line are flushed before that line's pieces are appended.

    Note: the newline separators used to join lines inside a chunk are not
    counted toward the budget. The default of 480 is meant to leave room for
    special tokens.

    Args:
        text: The text to split.
        tokenizer: Object providing ``encode(str, add_special_tokens=False)``
            and ``decode(tokens)``.
        max_tokens: Token budget per chunk; must be positive.

    Returns:
        A list of chunk strings.

    Raises:
        ValueError: If ``max_tokens`` is not positive.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer")

    chunks = []
    current = []
    current_len = 0

    for sent in text.split("\n"):
        sent_tokens = tokenizer.encode(sent, add_special_tokens=False)
        sent_len = len(sent_tokens)

        # Case 1: a single line exceeds the budget -> hard split by tokens.
        if sent_len > max_tokens:
            # Flush pending lines first so the output keeps document order.
            if current:
                chunks.append("\n".join(current))
                current = []
                current_len = 0
            chunks.extend(
                tokenizer.decode(sent_tokens[i:i + max_tokens])
                for i in range(0, sent_len, max_tokens)
            )
            continue

        # Case 2: adding this line would overflow -> start a new chunk.
        if current_len + sent_len > max_tokens:
            chunks.append("\n".join(current))
            current = [sent]
            current_len = sent_len
        else:
            current.append(sent)
            current_len += sent_len

    if current:
        chunks.append("\n".join(current))

    return chunks
# =========================================================
# Empty line cleanup
# =========================================================
def normalize_empty_lines(lines):
    """Collapse each run of consecutive blank lines to its first line.

    A line counts as blank when it is empty or whitespace-only. Non-blank
    lines are passed through unchanged and in order.
    """
    result = []
    for line in lines:
        is_blank = not line.strip()
        # The previous kept line is blank only if something was kept at all.
        follows_blank = bool(result) and not result[-1].strip()
        if not (is_blank and follows_blank):
            result.append(line)
    return result
# =========================================================
# CORE PIPELINE (FIXED)
# =========================================================
def process_text_input(text, noise_labels, selected_model):
    """Classify ``text`` line by line and drop the lines labelled as noise.

    Args:
        text: The raw text to filter. ``None``, empty and whitespace-only
            input short-circuit to an empty result.
        noise_labels: Labels whose lines are removed (``None`` is treated as
            "remove nothing").
        selected_model: Key of ``MODELS`` selecting the classifier.

    Returns:
        A 4-tuple ``(log, filtered_text, stats, download_path)``:
        the per-line prediction log, the kept lines joined by newlines, a
        summary string, and the path of a UTF-8 ``.txt`` temp file holding
        the filtered text. For empty input it is ``("", "", "", None)``.
    """
    if not text or not text.strip():
        return "", "", "", None

    classifier = get_classifier(selected_model)
    chunks = safe_chunk_text(text, classifier.tokenizer)

    # Built once so the per-line membership test is O(1).
    noise = set(noise_labels or ())

    logs = []
    kept = []
    removed = []
    line_counter = 0

    for c_id, chunk in enumerate(chunks):
        for line in chunk.splitlines():
            line_counter += 1

            # Blank lines are kept as separators without being classified.
            if not line.strip():
                kept.append("")
                continue

            pred, _ = get_best_prediction(classifier, line)
            logs.append(
                f"[Chunk {c_id}] Line {line_counter} | "
                f"{pred['label']} ({pred['score']:.4f})\n{line}\n"
            )

            if pred["label"] in noise:
                removed.append(line)
            else:
                kept.append(line)

    kept = normalize_empty_lines(kept)
    filtered = "\n".join(kept)

    # delete=False: the file must outlive this function so the UI can serve
    # it. The `with` block guarantees the handle is closed even if the write
    # fails.
    with tempfile.NamedTemporaryFile(
        delete=False,
        suffix=".txt",
        mode="w",
        encoding="utf-8",
    ) as tmp:
        tmp.write(filtered)

    stats = (
        f"Chunks: {len(chunks)}\n"
        f"Total lines: {line_counter}\n"
        f"Removed: {len(removed)}\n"
        f"Remaining: {len(kept)}"
    )
    return "\n".join(logs), filtered, stats, tmp.name
# =========================================================
# TXT FILE
# =========================================================
def process_document_file(file, noise_labels, selected_model):
    """Read an uploaded text file and run it through ``process_text_input``.

    Args:
        file: Either a filesystem path string or an upload object exposing
            the path as ``.name``; ``None`` when nothing was uploaded.
        noise_labels: Labels whose lines are removed.
        selected_model: Key of ``MODELS`` selecting the classifier.

    Returns:
        The 4-tuple produced by ``process_text_input``, or
        ``("", "", "", None)`` when ``file`` is ``None``.
    """
    if file is None:
        return "", "", "", None

    path = file if isinstance(file, str) else file.name

    # "utf-8-sig" decodes plain UTF-8 identically but strips a leading BOM,
    # which would otherwise end up glued to the first line.
    with open(path, "r", encoding="utf-8-sig") as f:
        text = f.read()

    return process_text_input(text, noise_labels, selected_model)
# =========================================================
# PDF EXTRACTION (DIGITAL + OCR)
# =========================================================
def extract_text_from_pdf(pdf_file):
    """Extract text from a PDF, falling back to OCR when none is found.

    The embedded text layer is read page by page first. If that yields no
    text (or fails before yielding any), every page is rendered at 300 dpi
    and passed through Tesseract.

    Args:
        pdf_file: Either a filesystem path string or an upload object
            exposing the path as ``.name``.

    Returns:
        The extracted text, with pages joined by newlines.
    """
    import logging  # local import: keeps this block self-contained

    pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
    text_parts = []

    # 1. Try digital extraction.
    try:
        reader = PdfReader(pdf_path)
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text_parts.append(page_text)
    except Exception:
        # Best effort: keep whatever pages were read and let the OCR fallback
        # take over if nothing was. Unlike a bare `except`, this does not
        # swallow KeyboardInterrupt/SystemExit, and the failure is logged.
        logging.getLogger(__name__).warning(
            "Digital text extraction failed for %s", pdf_path, exc_info=True
        )

    text = "\n".join(text_parts).strip()

    # 2. OCR fallback.
    if not text:
        pages = convert_from_path(pdf_path, dpi=300)
        text = "\n".join(pytesseract.image_to_string(page) for page in pages)

    return text
# =========================================================
# PDF PIPELINE
# =========================================================
def process_pdf_file(file, noise_labels, selected_model):
    """Extract the text of an uploaded PDF and filter it.

    Returns the 4-tuple produced by ``process_text_input``, or
    ``("", "", "", None)`` when no file was supplied.
    """
    if file is not None:
        extracted = extract_text_from_pdf(file)
        return process_text_input(extracted, noise_labels, selected_model)
    return "", "", "", None
# =========================================================
# UI
# =========================================================
def _output_widgets():
    """Create the four result components, in display order.

    Must be called inside the tab where the components should appear.
    Returns them as the list used for a click handler's outputs:
    log box, filtered-text box, stats box, download file.
    """
    log_box = gr.Textbox(lines=15)
    filtered_box = gr.Textbox(lines=15)
    stats_box = gr.Textbox()
    download = gr.File()
    return [log_box, filtered_box, stats_box, download]


# Three tabs share the same layout: model dropdown, input, noise labels,
# button, then the four outputs. Components are created in that order
# because creation order determines the on-page order.
with gr.Blocks(title="Stable Academic Text Filter") as demo:
    gr.Markdown("""
# Academic Text Filter (FIXED VERSION)
✔ No tokenizer crashes
✔ OCR + PDF support
✔ Safe chunking (XLM-R compatible)
✔ Robust long-document handling
""")

    # Pasted text.
    with gr.Tab("Text"):
        model_choice = gr.Dropdown(list(MODELS), value=DEFAULT_MODEL)
        text_input = gr.Textbox(lines=20)
        noise_choice = gr.CheckboxGroup(LABELS, value=DEFAULT_NOISE)
        run_button = gr.Button("Process")
        results = _output_widgets()
        run_button.click(
            fn=process_text_input,
            inputs=[text_input, noise_choice, model_choice],
            outputs=results,
        )

    # Uploaded .txt file.
    with gr.Tab("TXT"):
        model_choice = gr.Dropdown(list(MODELS), value=DEFAULT_MODEL)
        txt_upload = gr.File(file_types=[".txt"])
        noise_choice = gr.CheckboxGroup(LABELS, value=DEFAULT_NOISE)
        run_button = gr.Button("Process")
        results = _output_widgets()
        run_button.click(
            fn=process_document_file,
            inputs=[txt_upload, noise_choice, model_choice],
            outputs=results,
        )

    # Uploaded .pdf file.
    with gr.Tab("PDF"):
        model_choice = gr.Dropdown(list(MODELS), value=DEFAULT_MODEL)
        pdf_upload = gr.File(file_types=[".pdf"])
        noise_choice = gr.CheckboxGroup(LABELS, value=DEFAULT_NOISE)
        run_button = gr.Button("Process PDF")
        results = _output_widgets()
        run_button.click(
            fn=process_pdf_file,
            inputs=[pdf_upload, noise_choice, model_choice],
            outputs=results,
        )

if __name__ == "__main__":
    demo.launch()