Spaces:
Running
Running
from __future__ import annotations

import argparse
import csv
import logging
import tempfile
from pathlib import Path

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM

from .common import (
    DEFAULT_APP_FALLBACK_MODEL,
    DEFAULT_INPUT_MAX_LENGTH,
    default_device,
    ensure_project_dirs,
    existing_default_checkpoint,
    load_json,
    load_tokenizer,
    normalize_text,
    resolve_model_reference,
)
LOGGER = logging.getLogger(__name__)

# Optional dependency: HAS_PYPDF2 records whether PDF text extraction is available.
try:
    import PyPDF2
except ImportError:
    HAS_PYPDF2 = False
else:
    HAS_PYPDF2 = True

# ── Generation Presets ────────────────────────────────────────────────────────
# Each preset supplies the values for the four "Advanced Tuning" sliders
# (max/min new tokens, beam count, length penalty).
MODE_PRESETS = {
    "QUICK PULSE": dict(max_new_tokens=72, min_new_tokens=18, num_beams=4, length_penalty=1.25),
    "KEY NOTES": dict(max_new_tokens=104, min_new_tokens=24, num_beams=5, length_penalty=1.05),
    "DEEP CONTEXT": dict(max_new_tokens=152, min_new_tokens=34, num_beams=6, length_penalty=0.92),
}
DEFAULT_MODE = "QUICK PULSE"
# ── Wonder Makers-inspired CSS ────────────────────────────────────────────────
# NOTE: the "Kill ALL default Gradio backgrounds" rule below forces
# `background: transparent !important` and `color: var(--white) !important`
# on every element inside .gradio-container. Any decorative background/color
# declared afterwards must therefore also be `!important` (with equal or
# higher specificity) or it is silently overridden. Without that, e.g. the
# gradient hero title (transparent text fill + clipped background) would
# render invisible.
APP_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800;900&family=JetBrains+Mono:wght@400;500&display=swap');

:root {
    --black: #000000;
    --white: #FFFFFF;
    --lime: #D4FF00;
    --lime-dim: rgba(212, 255, 0, 0.15);
    --lime-glow: rgba(212, 255, 0, 0.08);
    --grey-100: #F5F5F5;
    --grey-400: #9CA3AF;
    --grey-600: #52525B;
    --grey-800: #27272A;
    --grey-900: #18181B;
    --border: rgba(255, 255, 255, 0.06);
    --border-hover: rgba(255, 255, 255, 0.12);
    --fn: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
    --mono: 'JetBrains Mono', monospace;
    --ease: cubic-bezier(0.16, 1, 0.3, 1);
}

/* ─── Global Reset ─── */
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }

body {
    background: var(--black) !important;
    color: var(--white) !important;
    font-family: var(--fn) !important;
    -webkit-font-smoothing: antialiased;
    -moz-osx-font-smoothing: grayscale;
    overflow-x: hidden;
}

/* Ambient glow — subtle purple/blue vignette like Wonder Makers */
body::before {
    content: '';
    position: fixed;
    inset: 0;
    background:
        radial-gradient(ellipse 50% 50% at 0% 0%, rgba(120, 80, 255, 0.06), transparent 70%),
        radial-gradient(ellipse 40% 40% at 100% 100%, rgba(212, 255, 0, 0.03), transparent 60%);
    pointer-events: none;
    z-index: -1;
}

/* ─── Gradio Container Overrides ─── */
.gradio-container {
    max-width: 1100px !important;
    margin: 0 auto !important;
    padding: 0 !important;
    background: transparent !important;
}

footer { display: none !important; }

/* Kill ALL default Gradio backgrounds.
   Everything below that sets a background/color must use !important. */
.gradio-container, .gradio-container *,
.gr-box, .gr-panel, .gr-form, .gr-block,
[class*="block"], [class*="form"], [class*="panel"],
[class*="accordion"], [class*="markdown"] {
    background: transparent !important;
    color: var(--white) !important;
}

/* ─── HERO HEADER ─── */
.wm-hero {
    text-align: center;
    padding: 64px 24px 48px;
    position: relative;
}

.wm-hero h1 {
    font-family: var(--fn) !important;
    font-size: 3.2rem !important;
    font-weight: 900 !important;
    letter-spacing: -0.04em !important;
    text-transform: uppercase !important;
    line-height: 1.05 !important;
    margin: 0 0 16px 0 !important;
    background: linear-gradient(135deg, var(--white) 60%, var(--grey-400)) !important;
    -webkit-background-clip: text !important;
    -webkit-text-fill-color: transparent;
    background-clip: text !important;
}

.wm-hero .wm-sub {
    font-size: 0.95rem;
    color: var(--grey-400) !important;
    font-weight: 400;
    letter-spacing: 0.08em;
    text-transform: uppercase;
    margin-bottom: 0;
}

.wm-hero .wm-accent {
    display: inline-block;
    background: var(--lime) !important;
    color: var(--black) !important;
    font-weight: 700;
    font-size: 0.7rem;
    letter-spacing: 0.15em;
    text-transform: uppercase;
    padding: 6px 18px;
    border-radius: 100px;
    margin-top: 20px;
}

/* ─── DIVIDER LINE ─── */
.wm-divider {
    height: 1px;
    background: var(--border) !important;
    margin: 0 32px;
}

/* ─── WORKSPACE ─── */
.wm-workspace {
    display: grid !important;
    grid-template-columns: 1fr 1fr;
    gap: 2px;
    padding: 0 !important;
    margin: 0 !important;
}

.wm-pane {
    padding: 40px 36px !important;
    min-height: 480px;
    display: flex;
    flex-direction: column;
    background: transparent !important;
    border: none !important;
    border-radius: 0 !important;
    position: relative;
}

/* Vertical separator between panes */
.wm-pane:first-child {
    border-right: 1px solid var(--border) !important;
}

.wm-pane-label {
    font-size: 0.65rem !important;
    font-weight: 600 !important;
    letter-spacing: 0.2em !important;
    text-transform: uppercase !important;
    color: var(--grey-600) !important;
    margin-bottom: 24px !important;
    display: flex;
    align-items: center;
    gap: 10px;
}

.wm-pane-label .wm-dot {
    width: 6px;
    height: 6px;
    border-radius: 50%;
    background: var(--lime) !important;
    box-shadow: 0 0 8px var(--lime);
}

.wm-pane-label .wm-dot-cyan {
    background: #06b6d4 !important;
    box-shadow: 0 0 8px rgba(6, 182, 212, 0.6);
}

/* ─── TEXT AREAS ─── */
.wm-input textarea, .wm-output textarea {
    background: rgba(255, 255, 255, 0.02) !important;
    border: 1px solid var(--border) !important;
    border-radius: 12px !important;
    color: var(--white) !important;
    font-family: var(--fn) !important;
    font-size: 0.95rem !important;
    line-height: 1.8 !important;
    padding: 20px 24px !important;
    resize: none !important;
    transition: border-color 0.4s var(--ease), box-shadow 0.4s var(--ease) !important;
}

.wm-input textarea:focus {
    border-color: rgba(212, 255, 0, 0.3) !important;
    box-shadow: 0 0 0 4px var(--lime-glow), inset 0 1px 4px rgba(0,0,0,0.3) !important;
    outline: none !important;
}

.wm-input textarea::placeholder {
    color: var(--grey-600) !important;
    font-style: italic;
}

/* ─── BUTTONS ─── */
.wm-btn-primary {
    background: var(--lime) !important;
    color: var(--black) !important;
    font-family: var(--fn) !important;
    font-weight: 700 !important;
    font-size: 0.75rem !important;
    letter-spacing: 0.12em !important;
    text-transform: uppercase !important;
    border: none !important;
    border-radius: 100px !important;
    padding: 16px 40px !important;
    cursor: pointer !important;
    transition: transform 0.3s var(--ease), box-shadow 0.3s var(--ease), background 0.3s !important;
}

.wm-btn-primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 8px 32px rgba(212, 255, 0, 0.25) !important;
    background: #e0ff33 !important;
}

.wm-btn-primary:active {
    transform: translateY(0) !important;
}

.wm-btn-ghost {
    background: transparent !important;
    color: var(--grey-400) !important;
    font-family: var(--fn) !important;
    font-weight: 500 !important;
    font-size: 0.75rem !important;
    letter-spacing: 0.1em !important;
    text-transform: uppercase !important;
    border: 1px solid var(--border) !important;
    border-radius: 100px !important;
    padding: 14px 28px !important;
    cursor: pointer !important;
    transition: all 0.3s var(--ease) !important;
}

.wm-btn-ghost:hover {
    border-color: var(--grey-400) !important;
    color: var(--white) !important;
}

/* ─── ACTION ROW ─── */
.wm-actions {
    display: flex;
    gap: 12px;
    margin-top: 20px;
    align-items: center;
}

/* ─── TOKEN COUNTER ─── */
.wm-tokens {
    font-family: var(--mono) !important;
    font-size: 0.7rem !important;
    letter-spacing: 0.05em;
    margin-top: 12px;
}

.wm-tokens-normal { color: var(--grey-600) !important; }

.wm-tokens-warning {
    color: #FF6B6B !important;
    text-shadow: 0 0 12px rgba(255, 107, 107, 0.3);
}

/* ─── SIDEBAR ─── */
.wm-sidebar {
    background: rgba(0, 0, 0, 0.95) !important;
    border-right: 1px solid var(--border) !important;
    padding: 32px 24px !important;
}

.wm-sidebar h3, .wm-sidebar h4 {
    font-size: 0.6rem !important;
    font-weight: 600 !important;
    letter-spacing: 0.2em !important;
    text-transform: uppercase !important;
    color: var(--grey-600) !important;
    margin-bottom: 16px !important;
}

/* ─── FILE UPLOAD ─── */
.wm-upload [data-testid="dropzone"] {
    border: 1px dashed var(--border) !important;
    border-radius: 12px !important;
    background: transparent !important;
    padding: 24px !important;
    transition: border-color 0.3s var(--ease) !important;
}

.wm-upload [data-testid="dropzone"]:hover {
    border-color: rgba(212, 255, 0, 0.3) !important;
}

/* ─── TABS ─── */
.tabs { border: none !important; }

button.tab-nav {
    font-family: var(--fn) !important;
    font-size: 0.65rem !important;
    font-weight: 600 !important;
    letter-spacing: 0.18em !important;
    text-transform: uppercase !important;
    color: var(--grey-600) !important;
    border: none !important;
    background: transparent !important;
    padding: 12px 24px !important;
    transition: color 0.3s !important;
}

button.tab-nav.selected {
    color: var(--white) !important;
    border-bottom: 2px solid var(--lime) !important;
}

button.tab-nav:hover { color: var(--white) !important; }

/* ─── ACCORDION ─── */
.wm-accordion button {
    font-family: var(--fn) !important;
    font-size: 0.65rem !important;
    letter-spacing: 0.15em !important;
    text-transform: uppercase !important;
    color: var(--grey-400) !important;
    background: transparent !important;
    border: 1px solid var(--border) !important;
    border-radius: 8px !important;
}

/* ─── MODEL INFO ─── */
.wm-model-info {
    padding: 20px 0;
    border-top: 1px solid var(--border);
    margin-top: 24px;
}

.wm-model-info p, .wm-model-info li {
    font-size: 0.8rem !important;
    color: var(--grey-400) !important;
    line-height: 1.7 !important;
}

.wm-model-info strong {
    color: var(--white) !important;
}

/* ─── BATCH TAB ─── */
.wm-batch-info {
    background: rgba(212, 255, 0, 0.04) !important;
    border: 1px solid rgba(212, 255, 0, 0.1);
    border-radius: 12px;
    padding: 20px 24px;
    font-family: var(--mono);
    font-size: 0.8rem;
    line-height: 1.8;
    color: var(--grey-400) !important;
    margin: 16px 0 24px;
}

.wm-batch-info strong {
    color: var(--lime) !important;
    font-weight: 600;
}

/* ─── SLIDERS ─── */
input[type="range"] {
    accent-color: var(--lime) !important;
}

/* ─── RESPONSIVE ─── */
@media (max-width: 768px) {
    .wm-workspace { grid-template-columns: 1fr !important; }
    .wm-pane:first-child {
        border-right: none !important;
        border-bottom: 1px solid var(--border) !important;
    }
    .wm-hero h1 { font-size: 2rem !important; }
}
"""
# ── CLI ───────────────────────────────────────────────────────────────────────
def parse_args() -> argparse.Namespace:
    """Parse ``sys.argv`` into the options used to launch the summarization UI.

    Returns:
        Namespace with ``model_path``, ``fallback_model``, ``max_input_length``,
        ``server_name``, ``server_port`` and ``share``.
    """
    cli = argparse.ArgumentParser(description="Launch the ML summarization UI.")
    # (flag, add_argument keyword arguments) in the order they are registered.
    option_table = (
        ("--model-path", {"default": existing_default_checkpoint()}),
        ("--fallback-model", {"default": DEFAULT_APP_FALLBACK_MODEL}),
        ("--max-input-length", {"type": int, "default": DEFAULT_INPUT_MAX_LENGTH}),
        ("--server-name", {"default": "127.0.0.1"}),
        ("--server-port", {"type": int, "default": 7860}),
        ("--share", {"action": "store_true"}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def load_model_info(model_path: str) -> str:
    """Return a Markdown snippet describing the active model for the sidebar.

    A ``model_path`` that is not an existing local path is treated as a hub id
    and gets a one-line description. For a local checkpoint directory, test
    ROUGE-1 / ROUGE-L scores are appended when ``metrics/test_metrics.json``
    exists. Only keys present in that file are shown, so a missing score is
    never rendered as a misleading ``0.0000``.

    Args:
        model_path: Local checkpoint directory or hub model id.

    Returns:
        Markdown text.
    """
    path = Path(model_path)
    if not path.exists():
        return f"**Hub Model** — `{model_path}`"

    info = f"**Checkpoint** — `{path.name}`\n"
    metrics_path = path / "metrics" / "test_metrics.json"
    if not metrics_path.exists():
        return info

    try:
        metrics = load_json(metrics_path)
        rows = [
            f"- {label}: **{float(metrics[key]):.4f}**\n"
            for key, label in (("test_rouge1", "ROUGE-1"), ("test_rougeL", "ROUGE-L"))
            if key in metrics
        ]
    except Exception:
        # Best-effort: an unreadable/malformed metrics file must not stop the UI
        # from starting, but it should not vanish silently either.
        LOGGER.warning("Could not read model metrics from %s", metrics_path, exc_info=True)
        return info
    return info + "".join(rows)
def read_file_content(file_obj) -> str:
    """Return the text of an uploaded ``.txt`` or ``.pdf`` file.

    Args:
        file_obj: Value produced by a ``gr.File`` component. This is either a
            path string or an object exposing the path as ``.name``; ``None``
            when the upload is cleared.

    Returns:
        The decoded text, or ``""`` when nothing is uploaded. PDF pages are
        joined with newlines.

    Raises:
        gr.Error: PyPDF2 is missing for a PDF, or the file cannot be read or
            decoded as UTF-8.
    """
    if file_obj is None:
        return ""
    # Accept both a plain path string and legacy file wrappers with `.name`.
    file_path = Path(getattr(file_obj, "name", file_obj))

    if file_path.suffix.lower() == ".pdf":
        if not HAS_PYPDF2:
            raise gr.Error("PyPDF2 is not installed. Run `pip install pypdf2` for PDF support.")
        try:
            # The reader is consumed inside the `with` so the file stays open
            # while pages are extracted.
            with open(file_path, "rb") as f:
                reader = PyPDF2.PdfReader(f)
                # Defensive: treat a falsy/None page result as an empty page
                # instead of letting str.join fail on it.
                return "\n".join(page.extract_text() or "" for page in reader.pages)
        except Exception as e:
            raise gr.Error(f"Failed to read PDF: {e}") from e

    try:
        return file_path.read_text(encoding="utf-8")
    except Exception as e:
        raise gr.Error(f"Failed to read file: {e}") from e
# ── Build the UI ──────────────────────────────────────────────────────────────
def build_demo(
    model, tokenizer, model_reference: str, max_input_length: int, device: torch.device
) -> gr.Blocks:
    """Assemble the Gradio Blocks app for interactive and batch summarization.

    Args:
        model: Seq2seq model whose ``generate`` is called per request. Inputs
            are moved to ``device`` before generation.
        tokenizer: Tokenizer matching ``model``.
        model_reference: Local checkpoint path or hub id, shown in the sidebar.
        max_input_length: Token limit for model inputs; longer inputs are truncated.
        device: Device the tokenized inputs are moved to.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    default_preset = MODE_PRESETS[DEFAULT_MODE]

    def render_token_count(count: int) -> str:
        """Render the token-counter HTML, always wrapped in the ``wm-tokens`` div.

        The initial value, the CLEAR handler and live updates all use this, so
        the counter keeps its ``wm-tokens`` styling after the first edit.
        """
        if count > max_input_length:
            inner = (
                f"<span class='wm-tokens-warning'>⚠ {count:,} / {max_input_length} TOKENS "
                f"— INPUT WILL BE TRUNCATED</span>"
            )
        else:
            shown = "000" if count == 0 else f"{count:,}"
            inner = f"<span class='wm-tokens-normal'>{shown} / {max_input_length} TOKENS</span>"
        return f"<div class='wm-tokens'>{inner}</div>"

    def count_tokens(text: str) -> str:
        """Return counter HTML for ``text`` (untruncated token count)."""
        cleaned = normalize_text(text)
        if not cleaned:
            return render_token_count(0)
        return render_token_count(len(tokenizer(cleaned, truncation=False)["input_ids"]))

    def summarize(text, max_new_tokens, min_new_tokens, num_beams, length_penalty):
        """Summarize one document with the given generation settings.

        Raises:
            gr.Error: On empty input, CUDA OOM, or any generation failure.
        """
        cleaned_text = normalize_text(text)
        if not cleaned_text:
            raise gr.Error("Please enter a document to summarize.")

        # Slider values may arrive as floats; generate() needs integer lengths
        # and beam counts.
        max_new_tokens = int(max_new_tokens)
        # The sliders are independent, so min can exceed max; clamp it.
        min_new_tokens = min(int(min_new_tokens), max_new_tokens)
        num_beams = int(num_beams)

        tokenized = tokenizer(
            cleaned_text, return_tensors="pt", truncation=True, max_length=max_input_length
        ).to(device)
        try:
            generated = model.generate(
                **tokenized,
                max_new_tokens=max_new_tokens,
                # min_new_tokens bounds only newly generated tokens, matching the
                # "Min tokens" slider; min_length bounds total sequence length.
                min_new_tokens=min_new_tokens,
                num_beams=num_beams,
                length_penalty=float(length_penalty),
                no_repeat_ngram_size=3,
                early_stopping=True,
                max_time=45.0,
            )
        except torch.cuda.OutOfMemoryError as e:
            raise gr.Error(
                "CUDA Out of Memory. Reduce input length or beam count."
            ) from e
        except Exception as e:
            raise gr.Error(f"Generation failed: {e}") from e
        return tokenizer.decode(generated[0], skip_special_tokens=True).strip()

    def batch_summarize(file_obj, max_new_tokens, min_new_tokens, num_beams, length_penalty):
        """Summarize every non-empty line of an uploaded .txt file.

        Returns:
            Path to a CSV with ``source``/``summary`` columns.
        """
        if file_obj is None:
            raise gr.Error("Upload a .txt file with one document per line.")
        # gr.File may yield a path string or a wrapper exposing `.name`.
        source_path = Path(getattr(file_obj, "name", file_obj))
        try:
            lines = source_path.read_text(encoding="utf-8").splitlines()
        except Exception as e:
            raise gr.Error(f"Failed to read file: {e}") from e

        results = []
        for line in lines:
            # Skip lines that normalize to nothing; otherwise summarize() would
            # raise and abort the whole batch.
            if not line.strip() or not normalize_text(line):
                continue
            summary = summarize(line, max_new_tokens, min_new_tokens, num_beams, length_penalty)
            results.append({"source": line.strip(), "summary": summary})

        # One unique file per run: a fixed name in the shared temp dir would let
        # concurrent sessions overwrite (and download) each other's results.
        with tempfile.NamedTemporaryFile(
            "w",
            newline="",
            encoding="utf-8",
            prefix="batch_results_",
            suffix=".csv",
            delete=False,
        ) as f:
            writer = csv.DictWriter(f, fieldnames=["source", "summary"])
            writer.writeheader()
            writer.writerows(results)
        return f.name

    # ── Theme ─────────────────────────────────────────────────────────────────
    theme = gr.themes.Base(
        primary_hue=gr.themes.colors.lime,
        secondary_hue=gr.themes.colors.cyan,
        neutral_hue=gr.themes.colors.zinc,
    ).set(
        body_background_fill="#000000",
        block_background_fill="transparent",
        input_background_fill="rgba(255,255,255,0.02)",
        body_text_color="#FFFFFF",
        block_label_text_color="#52525B",
    )

    with gr.Blocks(title="Prism Studio", theme=theme) as demo:
        # Inject CSS via HTML since Gradio 6 moved css= to launch()
        gr.HTML(f"<style>{APP_CSS}</style>")

        # ── Hero Header ───────────────────────────────────────────────────────
        gr.HTML("""
<div class="wm-hero">
<h1>PRISM<br>STUDIO.</h1>
<p class="wm-sub">Neural Text Summarization · Engineered</p>
<span class="wm-accent">BART Fine-Tuned on XSum</span>
</div>
<div class="wm-divider"></div>
""")

        # ── Sidebar ───────────────────────────────────────────────────────────
        with gr.Sidebar(elem_classes=["wm-sidebar"]):
            gr.HTML("<h3>Control Panel</h3>")
            mode_selector = gr.Dropdown(
                choices=list(MODE_PRESETS.keys()),
                value=DEFAULT_MODE,
                label="Generation Preset",
            )
            with gr.Accordion("Advanced Tuning", open=False, elem_classes=["wm-accordion"]):
                max_new_tokens = gr.Slider(
                    32, 256, value=default_preset["max_new_tokens"], step=8, label="Max tokens"
                )
                min_new_tokens = gr.Slider(
                    8, 96, value=default_preset["min_new_tokens"], step=4, label="Min tokens"
                )
                num_beams = gr.Slider(
                    1, 8, value=default_preset["num_beams"], step=1, label="Beams"
                )
                length_penalty = gr.Slider(
                    0.6, 2.0, value=default_preset["length_penalty"], step=0.05, label="Length penalty"
                )
            gr.HTML("<div class='wm-model-info'></div>")
            gr.HTML("<h4>Active Model</h4>")
            gr.Markdown(load_model_info(model_reference))

        # ── Tabs ──────────────────────────────────────────────────────────────
        with gr.Tabs():
            # ── STUDIO TAB ────────────────────────────────────────────────────
            with gr.Tab("STUDIO"):
                with gr.Row(elem_classes=["wm-workspace"]):
                    # Left — Source
                    with gr.Column(elem_classes=["wm-pane"]):
                        gr.HTML("""
<div class="wm-pane-label">
<span class="wm-dot"></span> SOURCE DOCUMENT
</div>
""")
                        file_upload = gr.File(
                            label="Upload .txt or .pdf",
                            file_types=[".txt", ".pdf"],
                            elem_classes=["wm-upload"],
                        )
                        input_text = gr.Textbox(
                            show_label=False,
                            placeholder="Paste your document here...",
                            lines=16,
                            elem_classes=["wm-input"],
                        )
                        token_display = gr.HTML(render_token_count(0))
                        with gr.Row(elem_classes=["wm-actions"]):
                            clear_btn = gr.Button("CLEAR", elem_classes=["wm-btn-ghost"])
                            summarize_btn = gr.Button("SUMMARIZE →", elem_classes=["wm-btn-primary"])

                    # Right — Output
                    with gr.Column(elem_classes=["wm-pane"]):
                        gr.HTML("""
<div class="wm-pane-label">
<span class="wm-dot wm-dot-cyan"></span> GENERATED OUTPUT
</div>
""")
                        output_text = gr.Textbox(
                            show_label=False,
                            interactive=False,
                            lines=20,
                            elem_classes=["wm-output"],
                        )

            # ── BATCH TAB ─────────────────────────────────────────────────────
            with gr.Tab("BATCH"):
                gr.HTML("""
<div class="wm-pane-label" style="padding: 32px 0 8px;">
<span class="wm-dot"></span> BULK INFERENCE
</div>
""")
                gr.HTML("""
<div class="wm-batch-info">
<strong>TEMPLATE FORMAT</strong><br>
Line 1: First document to summarize.<br>
Line 2: Second document to summarize.<br>
Line 3: Third document to summarize.
</div>
""")
                batch_upload = gr.File(
                    label="Upload batch .txt",
                    file_types=[".txt"],
                    elem_classes=["wm-upload"],
                )
                batch_btn = gr.Button("RUN BATCH →", elem_classes=["wm-btn-primary"])
                batch_download = gr.File(label="Download CSV Results", interactive=False)

        # ── Event Wiring ──────────────────────────────────────────────────────
        def update_params(mode):
            """Return slider values for ``mode``, falling back to the default preset."""
            preset = MODE_PRESETS.get(mode, default_preset)
            return (
                preset["max_new_tokens"],
                preset["min_new_tokens"],
                preset["num_beams"],
                preset["length_penalty"],
            )

        mode_selector.change(
            update_params,
            inputs=[mode_selector],
            outputs=[max_new_tokens, min_new_tokens, num_beams, length_penalty],
        )
        file_upload.change(read_file_content, inputs=[file_upload], outputs=[input_text])
        input_text.change(count_tokens, inputs=[input_text], outputs=[token_display])
        summarize_btn.click(
            summarize,
            inputs=[input_text, max_new_tokens, min_new_tokens, num_beams, length_penalty],
            outputs=[output_text],
        )
        clear_btn.click(
            lambda: (None, "", render_token_count(0), ""),
            inputs=None,
            outputs=[file_upload, input_text, token_display, output_text],
        )
        batch_btn.click(
            batch_summarize,
            inputs=[batch_upload, max_new_tokens, min_new_tokens, num_beams, length_penalty],
            outputs=[batch_download],
        )

    return demo
# ── Entrypoint ────────────────────────────────────────────────────────────────
def main() -> None:
    """Configure logging, load the model once, then build and serve the Gradio app."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )
    cli = parse_args()
    ensure_project_dirs()

    reference = resolve_model_reference(cli.model_path, fallback=cli.fallback_model)
    target_device = default_device()

    LOGGER.info("Loading model from %s", reference)
    tokenizer = load_tokenizer(reference)
    model = AutoModelForSeq2SeqLM.from_pretrained(reference)

    # 20 is transformers' generic default max_length; clearing it presumably
    # keeps that default from interfering with the per-request max_new_tokens.
    generation_config = model.generation_config
    if getattr(generation_config, "max_length", None) == 20:
        generation_config.max_length = None

    model.to(target_device)
    model.eval()

    app = build_demo(model, tokenizer, reference, cli.max_input_length, target_device)
    app.queue().launch(
        server_name=cli.server_name,
        server_port=cli.server_port,
        share=cli.share,
    )


if __name__ == "__main__":
    main()