Adive01 committed on
Commit
ce2bcea
·
verified ·
1 Parent(s): 30653a0

Upload mlplo/app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mlplo/app.py +722 -0
mlplo/app.py ADDED
@@ -0,0 +1,722 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import csv
5
+ import logging
6
+ import tempfile
7
+ from pathlib import Path
8
+
9
+ import gradio as gr
10
+ import torch
11
+ from transformers import AutoModelForSeq2SeqLM
12
+
13
+ from .common import (
14
+ DEFAULT_APP_FALLBACK_MODEL,
15
+ DEFAULT_INPUT_MAX_LENGTH,
16
+ default_device,
17
+ ensure_project_dirs,
18
+ existing_default_checkpoint,
19
+ load_json,
20
+ load_tokenizer,
21
+ normalize_text,
22
+ resolve_model_reference,
23
+ )
24
+
25
# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)

# PyPDF2 is optional: HAS_PYPDF2 records whether the import succeeded.
try:
    import PyPDF2
except ImportError:
    HAS_PYPDF2 = False
else:
    HAS_PYPDF2 = True
33
+
34
+ # ── Generation Presets ────────────────────────────────────────────────────────
35
# Preset name -> the four generation settings that preset supplies.
MODE_PRESETS = {
    "QUICK PULSE": dict(
        max_new_tokens=72, min_new_tokens=18, num_beams=4, length_penalty=1.25
    ),
    "KEY NOTES": dict(
        max_new_tokens=104, min_new_tokens=24, num_beams=5, length_penalty=1.05
    ),
    "DEEP CONTEXT": dict(
        max_new_tokens=152, min_new_tokens=34, num_beams=6, length_penalty=0.92
    ),
}

# Preset selected when the UI first loads; must be a key of MODE_PRESETS.
DEFAULT_MODE = "QUICK PULSE"
57
+
58
+ # ── Wonder Makers-inspired CSS ────────────────────────────────────────────────
59
+ APP_CSS = """
60
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800;900&family=JetBrains+Mono:wght@400;500&display=swap');
61
+
62
+ :root {
63
+ --black: #000000;
64
+ --white: #FFFFFF;
65
+ --lime: #D4FF00;
66
+ --lime-dim: rgba(212, 255, 0, 0.15);
67
+ --lime-glow: rgba(212, 255, 0, 0.08);
68
+ --grey-100: #F5F5F5;
69
+ --grey-400: #9CA3AF;
70
+ --grey-600: #52525B;
71
+ --grey-800: #27272A;
72
+ --grey-900: #18181B;
73
+ --border: rgba(255, 255, 255, 0.06);
74
+ --border-hover: rgba(255, 255, 255, 0.12);
75
+ --fn: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
76
+ --mono: 'JetBrains Mono', monospace;
77
+ --ease: cubic-bezier(0.16, 1, 0.3, 1);
78
+ }
79
+
80
+ /* ─── Global Reset ─── */
81
+ *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
82
+
83
+ body {
84
+ background: var(--black) !important;
85
+ color: var(--white) !important;
86
+ font-family: var(--fn) !important;
87
+ -webkit-font-smoothing: antialiased;
88
+ -moz-osx-font-smoothing: grayscale;
89
+ overflow-x: hidden;
90
+ }
91
+
92
+ /* Ambient glow — subtle purple/blue vignette like Wonder Makers */
93
+ body::before {
94
+ content: '';
95
+ position: fixed;
96
+ inset: 0;
97
+ background:
98
+ radial-gradient(ellipse 50% 50% at 0% 0%, rgba(120, 80, 255, 0.06), transparent 70%),
99
+ radial-gradient(ellipse 40% 40% at 100% 100%, rgba(212, 255, 0, 0.03), transparent 60%);
100
+ pointer-events: none;
101
+ z-index: -1;
102
+ }
103
+
104
+ /* ─── Gradio Container Overrides ─── */
105
+ .gradio-container {
106
+ max-width: 1100px !important;
107
+ margin: 0 auto !important;
108
+ padding: 0 !important;
109
+ background: transparent !important;
110
+ }
111
+
112
+ footer { display: none !important; }
113
+
114
+ /* Kill ALL default Gradio backgrounds */
115
+ .gradio-container, .gradio-container *,
116
+ .gr-box, .gr-panel, .gr-form, .gr-block,
117
+ [class*="block"], [class*="form"], [class*="panel"],
118
+ [class*="accordion"], [class*="markdown"] {
119
+ background: transparent !important;
120
+ color: var(--white) !important;
121
+ }
122
+
123
+ /* ─── HERO HEADER ─── */
124
+ .wm-hero {
125
+ text-align: center;
126
+ padding: 64px 24px 48px;
127
+ position: relative;
128
+ }
129
+ .wm-hero h1 {
130
+ font-family: var(--fn) !important;
131
+ font-size: 3.2rem !important;
132
+ font-weight: 900 !important;
133
+ letter-spacing: -0.04em !important;
134
+ text-transform: uppercase !important;
135
+ line-height: 1.05 !important;
136
+ margin: 0 0 16px 0 !important;
137
+ background: linear-gradient(135deg, var(--white) 60%, var(--grey-400));
138
+ -webkit-background-clip: text;
139
+ -webkit-text-fill-color: transparent;
140
+ background-clip: text;
141
+ }
142
+ .wm-hero .wm-sub {
143
+ font-size: 0.95rem;
144
+ color: var(--grey-400);
145
+ font-weight: 400;
146
+ letter-spacing: 0.08em;
147
+ text-transform: uppercase;
148
+ margin-bottom: 0;
149
+ }
150
+ .wm-hero .wm-accent {
151
+ display: inline-block;
152
+ background: var(--lime);
153
+ color: var(--black);
154
+ font-weight: 700;
155
+ font-size: 0.7rem;
156
+ letter-spacing: 0.15em;
157
+ text-transform: uppercase;
158
+ padding: 6px 18px;
159
+ border-radius: 100px;
160
+ margin-top: 20px;
161
+ }
162
+
163
+ /* ─── DIVIDER LINE ─── */
164
+ .wm-divider {
165
+ height: 1px;
166
+ background: var(--border);
167
+ margin: 0 32px;
168
+ }
169
+
170
+ /* ─── WORKSPACE ─── */
171
+ .wm-workspace {
172
+ display: grid !important;
173
+ grid-template-columns: 1fr 1fr;
174
+ gap: 2px;
175
+ padding: 0 !important;
176
+ margin: 0 !important;
177
+ }
178
+
179
+ .wm-pane {
180
+ padding: 40px 36px !important;
181
+ min-height: 480px;
182
+ display: flex;
183
+ flex-direction: column;
184
+ background: transparent !important;
185
+ border: none !important;
186
+ border-radius: 0 !important;
187
+ position: relative;
188
+ }
189
+
190
+ /* Vertical separator between panes */
191
+ .wm-pane:first-child {
192
+ border-right: 1px solid var(--border) !important;
193
+ }
194
+
195
+ .wm-pane-label {
196
+ font-size: 0.65rem !important;
197
+ font-weight: 600 !important;
198
+ letter-spacing: 0.2em !important;
199
+ text-transform: uppercase !important;
200
+ color: var(--grey-600) !important;
201
+ margin-bottom: 24px !important;
202
+ display: flex;
203
+ align-items: center;
204
+ gap: 10px;
205
+ }
206
+ .wm-pane-label .wm-dot {
207
+ width: 6px;
208
+ height: 6px;
209
+ border-radius: 50%;
210
+ background: var(--lime);
211
+ box-shadow: 0 0 8px var(--lime);
212
+ }
213
+ .wm-pane-label .wm-dot-cyan {
214
+ background: #06b6d4;
215
+ box-shadow: 0 0 8px rgba(6, 182, 212, 0.6);
216
+ }
217
+
218
+ /* ─── TEXT AREAS ─── */
219
+ .wm-input textarea, .wm-output textarea {
220
+ background: rgba(255, 255, 255, 0.02) !important;
221
+ border: 1px solid var(--border) !important;
222
+ border-radius: 12px !important;
223
+ color: var(--white) !important;
224
+ font-family: var(--fn) !important;
225
+ font-size: 0.95rem !important;
226
+ line-height: 1.8 !important;
227
+ padding: 20px 24px !important;
228
+ resize: none !important;
229
+ transition: border-color 0.4s var(--ease), box-shadow 0.4s var(--ease) !important;
230
+ }
231
+ .wm-input textarea:focus {
232
+ border-color: rgba(212, 255, 0, 0.3) !important;
233
+ box-shadow: 0 0 0 4px var(--lime-glow), inset 0 1px 4px rgba(0,0,0,0.3) !important;
234
+ outline: none !important;
235
+ }
236
+ .wm-input textarea::placeholder {
237
+ color: var(--grey-600) !important;
238
+ font-style: italic;
239
+ }
240
+
241
+ /* ─── BUTTONS ─── */
242
+ .wm-btn-primary {
243
+ background: var(--lime) !important;
244
+ color: var(--black) !important;
245
+ font-family: var(--fn) !important;
246
+ font-weight: 700 !important;
247
+ font-size: 0.75rem !important;
248
+ letter-spacing: 0.12em !important;
249
+ text-transform: uppercase !important;
250
+ border: none !important;
251
+ border-radius: 100px !important;
252
+ padding: 16px 40px !important;
253
+ cursor: pointer !important;
254
+ transition: transform 0.3s var(--ease), box-shadow 0.3s var(--ease), background 0.3s !important;
255
+ }
256
+ .wm-btn-primary:hover {
257
+ transform: translateY(-2px) !important;
258
+ box-shadow: 0 8px 32px rgba(212, 255, 0, 0.25) !important;
259
+ background: #e0ff33 !important;
260
+ }
261
+ .wm-btn-primary:active {
262
+ transform: translateY(0) !important;
263
+ }
264
+
265
+ .wm-btn-ghost {
266
+ background: transparent !important;
267
+ color: var(--grey-400) !important;
268
+ font-family: var(--fn) !important;
269
+ font-weight: 500 !important;
270
+ font-size: 0.75rem !important;
271
+ letter-spacing: 0.1em !important;
272
+ text-transform: uppercase !important;
273
+ border: 1px solid var(--border) !important;
274
+ border-radius: 100px !important;
275
+ padding: 14px 28px !important;
276
+ cursor: pointer !important;
277
+ transition: all 0.3s var(--ease) !important;
278
+ }
279
+ .wm-btn-ghost:hover {
280
+ border-color: var(--grey-400) !important;
281
+ color: var(--white) !important;
282
+ }
283
+
284
+ /* ─── ACTION ROW ─── */
285
+ .wm-actions {
286
+ display: flex;
287
+ gap: 12px;
288
+ margin-top: 20px;
289
+ align-items: center;
290
+ }
291
+
292
+ /* ─── TOKEN COUNTER ─── */
293
+ .wm-tokens {
294
+ font-family: var(--mono) !important;
295
+ font-size: 0.7rem !important;
296
+ letter-spacing: 0.05em;
297
+ margin-top: 12px;
298
+ }
299
+ .wm-tokens-normal { color: var(--grey-600) !important; }
300
+ .wm-tokens-warning {
301
+ color: #FF6B6B !important;
302
+ text-shadow: 0 0 12px rgba(255, 107, 107, 0.3);
303
+ }
304
+
305
+ /* ─── SIDEBAR ─── */
306
+ .wm-sidebar {
307
+ background: rgba(0, 0, 0, 0.95) !important;
308
+ border-right: 1px solid var(--border) !important;
309
+ padding: 32px 24px !important;
310
+ }
311
+ .wm-sidebar h3, .wm-sidebar h4 {
312
+ font-size: 0.6rem !important;
313
+ font-weight: 600 !important;
314
+ letter-spacing: 0.2em !important;
315
+ text-transform: uppercase !important;
316
+ color: var(--grey-600) !important;
317
+ margin-bottom: 16px !important;
318
+ }
319
+
320
+ /* ─── FILE UPLOAD ─── */
321
+ .wm-upload [data-testid="dropzone"] {
322
+ border: 1px dashed var(--border) !important;
323
+ border-radius: 12px !important;
324
+ background: transparent !important;
325
+ padding: 24px !important;
326
+ transition: border-color 0.3s var(--ease) !important;
327
+ }
328
+ .wm-upload [data-testid="dropzone"]:hover {
329
+ border-color: rgba(212, 255, 0, 0.3) !important;
330
+ }
331
+
332
+ /* ─── TABS ─── */
333
+ .tabs { border: none !important; }
334
+ button.tab-nav {
335
+ font-family: var(--fn) !important;
336
+ font-size: 0.65rem !important;
337
+ font-weight: 600 !important;
338
+ letter-spacing: 0.18em !important;
339
+ text-transform: uppercase !important;
340
+ color: var(--grey-600) !important;
341
+ border: none !important;
342
+ background: transparent !important;
343
+ padding: 12px 24px !important;
344
+ transition: color 0.3s !important;
345
+ }
346
+ button.tab-nav.selected {
347
+ color: var(--white) !important;
348
+ border-bottom: 2px solid var(--lime) !important;
349
+ }
350
+ button.tab-nav:hover { color: var(--white) !important; }
351
+
352
+ /* ─── ACCORDION ─── */
353
+ .wm-accordion button {
354
+ font-family: var(--fn) !important;
355
+ font-size: 0.65rem !important;
356
+ letter-spacing: 0.15em !important;
357
+ text-transform: uppercase !important;
358
+ color: var(--grey-400) !important;
359
+ background: transparent !important;
360
+ border: 1px solid var(--border) !important;
361
+ border-radius: 8px !important;
362
+ }
363
+
364
+ /* ─── MODEL INFO ─── */
365
+ .wm-model-info {
366
+ padding: 20px 0;
367
+ border-top: 1px solid var(--border);
368
+ margin-top: 24px;
369
+ }
370
+ .wm-model-info p, .wm-model-info li {
371
+ font-size: 0.8rem !important;
372
+ color: var(--grey-400) !important;
373
+ line-height: 1.7 !important;
374
+ }
375
+ .wm-model-info strong {
376
+ color: var(--white) !important;
377
+ }
378
+
379
+ /* ─── BATCH TAB ─── */
380
+ .wm-batch-info {
381
+ background: rgba(212, 255, 0, 0.04);
382
+ border: 1px solid rgba(212, 255, 0, 0.1);
383
+ border-radius: 12px;
384
+ padding: 20px 24px;
385
+ font-family: var(--mono);
386
+ font-size: 0.8rem;
387
+ line-height: 1.8;
388
+ color: var(--grey-400);
389
+ margin: 16px 0 24px;
390
+ }
391
+ .wm-batch-info strong {
392
+ color: var(--lime);
393
+ font-weight: 600;
394
+ }
395
+
396
+ /* ─── SLIDERS ─── */
397
+ input[type="range"] {
398
+ accent-color: var(--lime) !important;
399
+ }
400
+
401
+ /* ─── RESPONSIVE ─── */
402
+ @media (max-width: 768px) {
403
+ .wm-workspace { grid-template-columns: 1fr !important; }
404
+ .wm-pane:first-child {
405
+ border-right: none !important;
406
+ border-bottom: 1px solid var(--border) !important;
407
+ }
408
+ .wm-hero h1 { font-size: 2rem !important; }
409
+ }
410
+ """
411
+
412
+
413
+ # ── CLI ───────────────────────────────────────────────────────────────────────
414
def parse_args() -> argparse.Namespace:
    """Parse the command-line options used to launch the summarization UI.

    Returns:
        Namespace with ``model_path``, ``fallback_model``, ``max_input_length``,
        ``server_name``, ``server_port`` and ``share``.
    """
    cli = argparse.ArgumentParser(description="Launch the ML summarization UI.")
    # (flag, add_argument keyword arguments), registered in declaration order.
    option_specs = (
        ("--model-path", {"default": existing_default_checkpoint()}),
        ("--fallback-model", {"default": DEFAULT_APP_FALLBACK_MODEL}),
        ("--max-input-length", {"type": int, "default": DEFAULT_INPUT_MAX_LENGTH}),
        ("--server-name", {"default": "127.0.0.1"}),
        ("--server-port", {"type": int, "default": 7860}),
        ("--share", {"action": "store_true"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
423
+
424
+
425
def load_model_info(model_path: str) -> str:
    """Build the Markdown blurb that describes the active model.

    Args:
        model_path: A local checkpoint directory, or any other model reference.

    Returns:
        A "Hub Model" line when ``model_path`` does not exist on disk. Otherwise
        a "Checkpoint" line, followed by the ROUGE-1 / ROUGE-L scores when
        ``metrics/test_metrics.json`` inside the checkpoint can be loaded.
    """
    path = Path(model_path)
    if not path.exists():
        return f"**Hub Model** — `{model_path}`"

    info = f"**Checkpoint** — `{path.name}`\n"
    metrics_path = path / "metrics" / "test_metrics.json"
    if not metrics_path.exists():
        return info

    try:
        metrics = load_json(metrics_path)
        rouge1 = float(metrics.get("test_rouge1", 0))
        rouge_l = float(metrics.get("test_rougeL", 0))
    except Exception as exc:
        # Metrics are optional decoration, so a bad file must not break the UI;
        # log the reason instead of silently dropping it.
        LOGGER.warning("Could not load metrics from %s: %s", metrics_path, exc)
        return info

    return info + f"- ROUGE-1: **{rouge1:.4f}**\n- ROUGE-L: **{rouge_l:.4f}**\n"
440
+
441
+
442
def read_file_content(file_obj) -> str:
    """Return the text of an uploaded ``.txt`` or ``.pdf`` file.

    Args:
        file_obj: The upload, given either as a path string or as an object
            whose ``name`` attribute holds the path. ``None`` means no upload.

    Returns:
        The file's text, or ``""`` when ``file_obj`` is ``None``. For a PDF,
        the extracted text of each page joined with newlines.

    Raises:
        gr.Error: If a PDF is given but PyPDF2 is unavailable, or the file
            cannot be read.
    """
    if file_obj is None:
        return ""

    # Accept a plain path as well as a wrapper object exposing ``.name``.
    file_path = Path(getattr(file_obj, "name", file_obj))

    if file_path.suffix.lower() != ".pdf":
        try:
            return file_path.read_text(encoding="utf-8")
        except Exception as e:
            raise gr.Error(f"Failed to read file: {e}") from e

    if not HAS_PYPDF2:
        raise gr.Error("PyPDF2 is not installed. Run `pip install pypdf2` for PDF support.")
    try:
        with open(file_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            # Treat a missing extraction result as empty text so one page
            # without text cannot fail the whole join.
            return "\n".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        raise gr.Error(f"Failed to read PDF: {e}") from e
460
+
461
+
462
+ # ── Build the UI ──────────────────────────────────────────────────────────────
463
def build_demo(
    model, tokenizer, model_reference: str, max_input_length: int, device: torch.device
) -> gr.Blocks:
    """Assemble the Prism Studio Gradio interface around a loaded model.

    Args:
        model: Model whose ``generate`` method produces the summaries.
        tokenizer: Tokenizer used to count, encode and decode text.
        model_reference: Model path or identifier shown in the sidebar.
        max_input_length: Token budget; inputs are truncated to this length
            before generation.
        device: Device the encoded inputs are moved to.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    default_preset = MODE_PRESETS[DEFAULT_MODE]

    def token_badge(inner_html: str) -> str:
        # Every counter state keeps the wm-tokens wrapper so its styling
        # survives updates, not just the initial render.
        return f"<div class='wm-tokens'>{inner_html}</div>"

    empty_badge = token_badge(
        f"<span class='wm-tokens-normal'>000 / {max_input_length} TOKENS</span>"
    )

    def count_tokens(text: str) -> str:
        """Render the token counter for the current input text."""
        cleaned = normalize_text(text)
        if not cleaned:
            return empty_badge
        count = len(tokenizer(cleaned, truncation=False)["input_ids"])
        if count > max_input_length:
            return token_badge(
                f"<span class='wm-tokens-warning'>⚠ {count:,} / {max_input_length} TOKENS "
                f"— INPUT WILL BE TRUNCATED</span>"
            )
        return token_badge(
            f"<span class='wm-tokens-normal'>{count:,} / {max_input_length} TOKENS</span>"
        )

    @torch.inference_mode()
    def summarize(text, max_new_tokens, min_new_tokens, num_beams, length_penalty):
        """Summarize one document; raises gr.Error on empty input or failure."""
        cleaned_text = normalize_text(text)
        if not cleaned_text:
            raise gr.Error("Please enter a document to summarize.")

        # Slider values may arrive as floats; generate() needs integer counts.
        max_new_tokens = int(max_new_tokens)
        # The sliders allow min > max, so clamp to keep the request consistent.
        min_new_tokens = min(int(min_new_tokens), max_new_tokens)
        num_beams = int(num_beams)

        tokenized = tokenizer(
            cleaned_text, return_tensors="pt", truncation=True, max_length=max_input_length
        ).to(device)

        try:
            generated = model.generate(
                **tokenized,
                max_new_tokens=max_new_tokens,
                # Use the "new tokens" bound so it pairs with max_new_tokens.
                min_new_tokens=min_new_tokens,
                num_beams=num_beams,
                length_penalty=float(length_penalty),
                no_repeat_ngram_size=3,
                early_stopping=True,
                max_time=45.0,
            )
        except torch.cuda.OutOfMemoryError as e:
            raise gr.Error(
                "CUDA Out of Memory. Reduce input length or beam count."
            ) from e
        except Exception as e:
            raise gr.Error(f"Generation failed: {e}") from e

        return tokenizer.decode(generated[0], skip_special_tokens=True).strip()

    def batch_summarize(file_obj, max_new_tokens, min_new_tokens, num_beams, length_penalty):
        """Summarize each non-blank line of a text file; return the CSV path."""
        if file_obj is None:
            raise gr.Error("Upload a .txt file with one document per line.")
        try:
            # Accept a plain path as well as a wrapper object exposing ``.name``.
            source_path = Path(getattr(file_obj, "name", file_obj))
            lines = source_path.read_text(encoding="utf-8").splitlines()
        except Exception as e:
            raise gr.Error(f"Failed to read file: {e}") from e

        results = []
        for line in lines:
            if not line.strip():
                continue
            summary = summarize(line, max_new_tokens, min_new_tokens, num_beams, length_penalty)
            results.append({"source": line.strip(), "summary": summary})

        # A unique file per run keeps concurrent batches from overwriting
        # each other's results.
        with tempfile.NamedTemporaryFile(
            "w",
            newline="",
            encoding="utf-8",
            prefix="batch_results_",
            suffix=".csv",
            delete=False,
        ) as f:
            writer = csv.DictWriter(f, fieldnames=["source", "summary"])
            writer.writeheader()
            writer.writerows(results)
        return f.name

    # ── Theme ─────────────────────────────────────────────────────────────────
    theme = gr.themes.Base(
        primary_hue=gr.themes.colors.lime,
        secondary_hue=gr.themes.colors.cyan,
        neutral_hue=gr.themes.colors.zinc,
    ).set(
        body_background_fill="#000000",
        block_background_fill="transparent",
        input_background_fill="rgba(255,255,255,0.02)",
        body_text_color="#FFFFFF",
        block_label_text_color="#52525B",
    )

    with gr.Blocks(title="Prism Studio", theme=theme) as demo:

        # Inject CSS via HTML since Gradio 6 moved css= to launch()
        gr.HTML(f"<style>{APP_CSS}</style>")

        # ── Hero Header ──────────────────────────────────────────────────────
        gr.HTML("""
        <div class="wm-hero">
        <h1>PRISM<br>STUDIO.</h1>
        <p class="wm-sub">Neural Text Summarization · Engineered</p>
        <span class="wm-accent">BART Fine-Tuned on XSum</span>
        </div>
        <div class="wm-divider"></div>
        """)

        # ── Sidebar ──────────────────────────────────────────────────────────
        with gr.Sidebar(elem_classes=["wm-sidebar"]):
            gr.HTML("<h3>Control Panel</h3>")
            mode_selector = gr.Dropdown(
                choices=list(MODE_PRESETS.keys()),
                value=DEFAULT_MODE,
                label="Generation Preset",
            )

            with gr.Accordion("Advanced Tuning", open=False, elem_classes=["wm-accordion"]):
                max_new_tokens = gr.Slider(
                    32, 256, value=default_preset["max_new_tokens"], step=8, label="Max tokens"
                )
                min_new_tokens = gr.Slider(
                    8, 96, value=default_preset["min_new_tokens"], step=4, label="Min tokens"
                )
                num_beams = gr.Slider(
                    1, 8, value=default_preset["num_beams"], step=1, label="Beams"
                )
                length_penalty = gr.Slider(
                    0.6, 2.0, value=default_preset["length_penalty"], step=0.05, label="Length penalty"
                )

            gr.HTML("<div class='wm-model-info'></div>")
            gr.HTML("<h4>Active Model</h4>")
            gr.Markdown(load_model_info(model_reference))

        # ── Tabs ─────────────────────────────────────────────────────────────
        with gr.Tabs():
            # ── STUDIO TAB ───────────────────────────────────────────────────
            with gr.Tab("STUDIO"):
                with gr.Row(elem_classes=["wm-workspace"]):
                    # Left — Source
                    with gr.Column(elem_classes=["wm-pane"]):
                        gr.HTML("""
                        <div class="wm-pane-label">
                        <span class="wm-dot"></span> SOURCE DOCUMENT
                        </div>
                        """)
                        file_upload = gr.File(
                            label="Upload .txt or .pdf",
                            file_types=[".txt", ".pdf"],
                            elem_classes=["wm-upload"],
                        )
                        input_text = gr.Textbox(
                            show_label=False,
                            placeholder="Paste your document here...",
                            lines=16,
                            elem_classes=["wm-input"],
                        )
                        token_display = gr.HTML(empty_badge)
                        with gr.Row(elem_classes=["wm-actions"]):
                            clear_btn = gr.Button("CLEAR", elem_classes=["wm-btn-ghost"])
                            summarize_btn = gr.Button("SUMMARIZE →", elem_classes=["wm-btn-primary"])

                    # Right — Output
                    with gr.Column(elem_classes=["wm-pane"]):
                        gr.HTML("""
                        <div class="wm-pane-label">
                        <span class="wm-dot wm-dot-cyan"></span> GENERATED OUTPUT
                        </div>
                        """)
                        output_text = gr.Textbox(
                            show_label=False,
                            interactive=False,
                            lines=20,
                            elem_classes=["wm-output"],
                        )

            # ── BATCH TAB ────────────────────────────────────────────────────
            with gr.Tab("BATCH"):
                gr.HTML("""
                <div class="wm-pane-label" style="padding: 32px 0 8px;">
                <span class="wm-dot"></span> BULK INFERENCE
                </div>
                """)
                gr.HTML("""
                <div class="wm-batch-info">
                <strong>TEMPLATE FORMAT</strong><br>
                Line 1: First document to summarize.<br>
                Line 2: Second document to summarize.<br>
                Line 3: Third document to summarize.
                </div>
                """)
                batch_upload = gr.File(
                    label="Upload batch .txt",
                    file_types=[".txt"],
                    elem_classes=["wm-upload"],
                )
                batch_btn = gr.Button("RUN BATCH →", elem_classes=["wm-btn-primary"])
                batch_download = gr.File(label="Download CSV Results", interactive=False)

        # ── Event Wiring ─────────────────────────────────────────────────────
        def update_params(mode):
            # Push the chosen preset's values into the four tuning sliders.
            p = MODE_PRESETS[mode]
            return p["max_new_tokens"], p["min_new_tokens"], p["num_beams"], p["length_penalty"]

        mode_selector.change(
            update_params,
            inputs=[mode_selector],
            outputs=[max_new_tokens, min_new_tokens, num_beams, length_penalty],
        )
        file_upload.change(read_file_content, inputs=[file_upload], outputs=[input_text])
        input_text.change(count_tokens, inputs=[input_text], outputs=[token_display])
        summarize_btn.click(
            summarize,
            inputs=[input_text, max_new_tokens, min_new_tokens, num_beams, length_penalty],
            outputs=[output_text],
        )
        clear_btn.click(
            # Reset upload, input text, token counter and output in one step.
            lambda: (None, "", empty_badge, ""),
            inputs=None,
            outputs=[file_upload, input_text, token_display, output_text],
        )
        batch_btn.click(
            batch_summarize,
            inputs=[batch_upload, max_new_tokens, min_new_tokens, num_beams, length_penalty],
            outputs=[batch_download],
        )

    return demo
691
+
692
+
693
+ # ── Entrypoint ────────────────────────────────────────────────────────────────
694
def main() -> None:
    """Parse CLI options, load the tokenizer and model, then serve the UI."""
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=logging.INFO,
    )
    cli = parse_args()
    ensure_project_dirs()

    model_reference = resolve_model_reference(cli.model_path, fallback=cli.fallback_model)
    device = default_device()

    LOGGER.info("Loading model from %s", model_reference)
    tokenizer = load_tokenizer(model_reference)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_reference)

    # Clear a generation max_length of exactly 20.
    # NOTE(review): presumably the library default, which would otherwise cap
    # summaries — confirm.
    generation_config = model.generation_config
    if getattr(generation_config, "max_length", None) == 20:
        generation_config.max_length = None

    model.to(device)
    model.eval()

    app = build_demo(model, tokenizer, model_reference, cli.max_input_length, device)
    launch_options = {
        "server_name": cli.server_name,
        "server_port": cli.server_port,
        "share": cli.share,
    }
    app.queue().launch(**launch_options)


if __name__ == "__main__":
    main()