Spaces:
Sleeping
Sleeping
| # ========================================================= | |
| # app.py (STABLE VERSION - FIXED TOKEN OVERFLOW) | |
| # ========================================================= | |
| import gradio as gr | |
| from transformers import pipeline | |
| from pypdf import PdfReader | |
| from pdf2image import convert_from_path | |
| import pytesseract | |
| import tempfile | |
# =========================================================
# Models
# =========================================================
# Maps the human-readable entry shown in the UI model dropdowns to the
# Hugging Face Hub model id that get_classifier loads for it.
MODELS = dict(
    [
        (
            "English model (ubffm/academic_text_classifier_en)",
            "ubffm/academic_text_classifier_en",
        ),
        (
            "German model (ubffm/academic_text_classifier_de)",
            "ubffm/academic_text_classifier_de",
        ),
    ]
)
# Dropdown entry pre-selected in every tab; must be one of the MODELS keys.
DEFAULT_MODEL = "English model (ubffm/academic_text_classifier_en)"
# =========================================================
# Labels
# =========================================================
# Every label offered in the "noise" checkbox groups of the UI.
LABELS = [
    "OUT OF SCOPE",
    "MAIN TEXT",
    "EXAMPLE",
    "REFERENCE",
]
# Labels ticked by default; lines predicted with a ticked label are removed.
DEFAULT_NOISE = [label for label in LABELS if label in ("OUT OF SCOPE", "REFERENCE")]
# =========================================================
# Pipeline cache
# =========================================================
# Loaded pipelines keyed by Hub model id, so a model is normally
# instantiated only once instead of on every request.
PIPELINES = {}


def get_classifier(model_display_name):
    """Return the cached text-classification pipeline for a dropdown entry.

    Args:
        model_display_name: A key of ``MODELS`` (the label shown in the UI).

    Returns:
        A transformers ``text-classification`` pipeline configured to return
        a score for every label.

    Raises:
        KeyError: If *model_display_name* is not a key of ``MODELS``.
    """
    model_name = MODELS[model_display_name]
    classifier = PIPELINES.get(model_name)
    if classifier is None:
        # top_k=None returns scores for all labels; it is the documented
        # replacement for the deprecated return_all_scores=True.
        classifier = pipeline(
            "text-classification",
            model=model_name,
            tokenizer=model_name,
            top_k=None,
        )
        PIPELINES[model_name] = classifier
    return classifier
# =========================================================
# Safe prediction (IMPORTANT FIX)
# =========================================================
def get_best_prediction(classifier, text, max_chars=2000):
    """Classify *text* and return ``(best, all_scores)``.

    Args:
        classifier: A text-classification pipeline (or any callable with the
            same call convention) returning ``{"label": ..., "score": ...}``
            dicts, optionally wrapped in an outer list.
        text: The text to classify.
        max_chars: Character cap applied before tokenization; a cheap
            pre-trim only.

    Returns:
        ``best`` is the highest-scoring label dict; ``all_scores`` is the list
        of label dicts it was chosen from.
    """
    # The character cap alone does not bound the token count (dense or
    # OCR-garbled text can tokenize to more tokens than the model accepts);
    # truncation=True makes the tokenizer enforce the model's limit.
    result = classifier(text[:max_chars], truncation=True)
    # For a single input string the pipeline may wrap the score list in
    # another list; a legacy top-1 configuration may return a bare dict.
    if isinstance(result, list) and result and isinstance(result[0], list):
        result = result[0]
    elif isinstance(result, dict):
        result = [result]
    return max(result, key=lambda x: x["score"]), result
# =========================================================
# SAFE CHUNKING (ROBUST FIX)
# =========================================================
def safe_chunk_text(text, tokenizer, max_tokens=480):
    """Split *text* into newline-joined chunks of at most ``max_tokens`` tokens.

    Lines are packed greedily into chunks in document order. A single line
    longer than ``max_tokens`` is hard-split on token boundaries and each
    piece is decoded back to text as its own chunk. The default of 480
    leaves headroom below the model's sequence limit for special tokens.

    Args:
        text: Input text; lines are separated by ``"\\n"``.
        tokenizer: Object providing ``encode(s, add_special_tokens=False)``
            and ``decode(token_ids)``.
        max_tokens: Token budget per chunk (excluding special tokens).

    Returns:
        List of chunk strings, preserving the original order of the text.

    Raises:
        ValueError: If ``max_tokens`` is smaller than 1.
    """
    if max_tokens < 1:
        raise ValueError("max_tokens must be >= 1")

    chunks = []
    current = []
    current_len = 0
    for sent in text.split("\n"):
        sent_tokens = tokenizer.encode(sent, add_special_tokens=False)
        sent_len = len(sent_tokens)

        # Case 1: single line too large -> hard split.
        if sent_len > max_tokens:
            # Flush pending lines first; otherwise the split pieces would be
            # emitted before text that precedes them in the document.
            if current:
                chunks.append("\n".join(current))
                current = []
                current_len = 0
            for start in range(0, sent_len, max_tokens):
                chunks.append(tokenizer.decode(sent_tokens[start:start + max_tokens]))
            continue

        # Case 2: adding this line would overflow the current chunk.
        if current and current_len + sent_len > max_tokens:
            chunks.append("\n".join(current))
            current = []
            current_len = 0

        current.append(sent)
        current_len += sent_len

    if current:
        chunks.append("\n".join(current))
    return chunks
# =========================================================
# Empty line cleanup
# =========================================================
def normalize_empty_lines(lines):
    """Collapse every run of blank (whitespace-only) lines to its first line.

    Non-blank lines are always kept, and the original order is preserved.

    Args:
        lines: Sequence of strings.

    Returns:
        A new list with consecutive blank lines reduced to one.
    """
    result = []
    last_was_blank = False
    for line in lines:
        is_blank = line.strip() == ""
        if not (is_blank and last_was_blank):
            result.append(line)
        last_was_blank = is_blank
    return result
# =========================================================
# CORE PIPELINE (FIXED)
# =========================================================
def process_text_input(text, noise_labels, selected_model):
    """Classify every line of *text* and drop lines whose label counts as noise.

    Args:
        text: Raw document text. ``None`` or whitespace-only input yields
            empty outputs.
        noise_labels: Labels whose lines are removed (``None`` means none).
        selected_model: A key of ``MODELS``.

    Returns:
        ``(log, filtered_text, stats, path)`` where *path* is a UTF-8 ``.txt``
        file holding *filtered_text* (``None`` for empty input).
    """
    if not text or not text.strip():
        return "", "", "", None

    classifier = get_classifier(selected_model)
    chunks = safe_chunk_text(text, classifier.tokenizer)
    noise = set(noise_labels or ())

    logs = []
    kept = []
    removed = []
    line_counter = 0
    for c_id, chunk in enumerate(chunks):
        for line in chunk.splitlines():
            line_counter += 1
            if not line.strip():
                # Blank lines are kept (normalized to "") to preserve paragraph
                # breaks; runs of them are collapsed below.
                kept.append("")
                continue
            pred, _ = get_best_prediction(classifier, line)
            logs.append(
                f"[Chunk {c_id}] Line {line_counter} | "
                f"{pred['label']} ({pred['score']:.4f})\n{line}\n"
            )
            if pred["label"] in noise:
                removed.append(line)
            else:
                kept.append(line)

    kept = normalize_empty_lines(kept)
    filtered = "\n".join(kept)

    # delete=False: the file must outlive this call because its path is
    # returned to the UI's download component. The `with` guarantees the
    # handle is closed even if the write fails.
    with tempfile.NamedTemporaryFile(
        delete=False,
        suffix=".txt",
        mode="w",
        encoding="utf-8",
    ) as tmp:
        tmp.write(filtered)
        out_path = tmp.name

    stats = (
        f"Chunks: {len(chunks)}\n"
        f"Total lines: {line_counter}\n"
        f"Removed: {len(removed)}\n"
        f"Remaining: {len(kept)}"
    )
    return "\n".join(logs), filtered, stats, out_path
# =========================================================
# TXT FILE
# =========================================================
def process_document_file(file, noise_labels, selected_model):
    """Read an uploaded plain-text file and filter it with process_text_input.

    Args:
        file: Upload value from the UI — either a filesystem path string or
            a tempfile-like object exposing ``.name`` (presumably depends on
            the installed Gradio version); ``None`` when nothing is uploaded.
        noise_labels: Labels whose lines are removed.
        selected_model: A key of ``MODELS``.

    Returns:
        The 4-tuple produced by ``process_text_input``; all-empty when no
        file was uploaded.
    """
    if file is None:
        return "", "", "", None
    path = getattr(file, "name", file)
    # utf-8-sig strips a leading BOM (otherwise it would end up in the first
    # classified line); errors="replace" stops one stray non-UTF-8 byte from
    # aborting the whole request.
    with open(path, "r", encoding="utf-8-sig", errors="replace") as f:
        text = f.read()
    return process_text_input(text, noise_labels, selected_model)
# =========================================================
# PDF EXTRACTION (DIGITAL + OCR)
# =========================================================
def extract_text_from_pdf(pdf_file):
    """Return the text of a PDF, falling back to OCR when no digital text exists.

    Digital extraction via pypdf is tried first. Only if it yields no
    non-whitespace text are the pages rasterised at 300 dpi and run through
    Tesseract.

    Args:
        pdf_file: A filesystem path string or a tempfile-like object
            exposing ``.name``.

    Returns:
        The extracted text (possibly empty).
    """
    import logging

    path = getattr(pdf_file, "name", pdf_file)

    # 1. Try digital extraction.
    text_parts = []
    try:
        reader = PdfReader(path)
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text_parts.append(page_text)
    except Exception:
        # Best effort: a PDF pypdf cannot read still gets the OCR fallback
        # below, but the failure is logged instead of silently discarded.
        logging.getLogger(__name__).warning(
            "Digital text extraction failed for %s; falling back to OCR",
            path,
            exc_info=True,
        )

    text = "\n".join(text_parts).strip()
    if text:
        return text

    # 2. OCR fallback.
    pages = convert_from_path(path, dpi=300)
    return "\n".join(pytesseract.image_to_string(page) for page in pages)
# =========================================================
# PDF PIPELINE
# =========================================================
def process_pdf_file(file, noise_labels, selected_model):
    """Extract text from an uploaded PDF and filter it with process_text_input.

    Returns the same 4-tuple as ``process_text_input``, or four empty values
    when no file was uploaded.
    """
    if file is not None:
        extracted = extract_text_from_pdf(file)
        return process_text_input(extracted, noise_labels, selected_model)
    return "", "", "", None
# =========================================================
# UI
# =========================================================
def _build_tab(title, make_input, handler, button_text):
    """Add one tab: model picker, document input, noise labels, run button, outputs.

    Args:
        title: Tab caption.
        make_input: Zero-argument callable that creates the document input
            component; it is called inside the tab so the component lands there.
        handler: Callback taking ``(document, noise_labels, model)`` and
            returning ``(log, filtered_text, stats, file_path)``.
        button_text: Caption of the run button.
    """
    with gr.Tab(title):
        model = gr.Dropdown(list(MODELS.keys()), value=DEFAULT_MODEL, label="Model")
        document = make_input()
        noise = gr.CheckboxGroup(LABELS, value=DEFAULT_NOISE, label="Labels to remove")
        run = gr.Button(button_text)
        outputs = [
            gr.Textbox(lines=15, label="Classification log"),
            gr.Textbox(lines=15, label="Filtered text"),
            gr.Textbox(label="Statistics"),
            gr.File(label="Download filtered text"),
        ]
        run.click(handler, [document, noise, model], outputs)


with gr.Blocks(title="Stable Academic Text Filter") as demo:
    gr.Markdown("""
# Academic Text Filter (FIXED VERSION)
✔ No tokenizer crashes
✔ OCR + PDF support
✔ Safe chunking (XLM-R compatible)
✔ Robust long-document handling
""")
    _build_tab(
        "Text",
        lambda: gr.Textbox(lines=20, label="Input text"),
        process_text_input,
        "Process",
    )
    _build_tab(
        "TXT",
        lambda: gr.File(file_types=[".txt"], label="TXT file"),
        process_document_file,
        "Process",
    )
    _build_tab(
        "PDF",
        lambda: gr.File(file_types=[".pdf"], label="PDF file"),
        process_pdf_file,
        "Process PDF",
    )

if __name__ == "__main__":
    demo.launch()