RishiRP commited on
Commit
41b65ed
·
verified ·
1 Parent(s): dfeaa23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +295 -925
app.py CHANGED
@@ -1,978 +1,348 @@
1
- # app.py
2
  import os
3
- import re
4
- import io
5
  import json
6
- import time
7
- import zipfile
8
- from pathlib import Path
9
- from typing import List, Dict, Any, Tuple, Optional
10
-
11
- import numpy as np
12
- import pandas as pd
13
  import gradio as gr
14
-
15
  import torch
16
- from transformers import (
17
- AutoTokenizer,
18
- AutoModelForCausalLM,
19
- BitsAndBytesConfig,
20
- GenerationConfig,
21
- )
22
-
23
- # =========================
24
- # Global config
25
- # =========================
26
- SPACE_CACHE = Path.home() / ".cache" / "huggingface"
27
- SPACE_CACHE.mkdir(parents=True, exist_ok=True)
28
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
-
30
- # Fast, deterministic, compact outputs for lower latency
31
- GEN_CONFIG = GenerationConfig(
32
- temperature=0.0,
33
- top_p=1.0,
34
- do_sample=False,
35
- max_new_tokens=128, # increase if your JSON is getting truncated
36
- )
37
-
38
- # Official UBS labels (canonical)
39
- OFFICIAL_LABELS = [
40
- "plan_contact",
41
- "schedule_meeting",
42
- "update_contact_info_non_postal",
43
- "update_contact_info_postal_address",
44
- "update_kyc_activity",
45
- "update_kyc_origin_of_assets",
46
- "update_kyc_purpose_of_businessrelation",
47
- "update_kyc_total_assets",
48
- ]
49
- OFFICIAL_LABELS_TEXT = "\n".join(OFFICIAL_LABELS)
50
-
51
- # =========================
52
- # Editable defaults (shown in UI)
53
- # =========================
54
- DEFAULT_SYSTEM_INSTRUCTIONS = (
55
- "You extract ACTIONABLE TASKS from client–advisor transcripts. "
56
- "The transcript may be in German, French, Italian, or English. "
57
- "Prioritize RECALL: if a label plausibly applies, include it. "
58
- "Use ONLY the canonical labels provided. "
59
- "Return STRICT JSON only with keys 'labels' and 'tasks'. "
60
- "Each task must include 'label', a brief 'explanation', and a short 'evidence' quote from the transcript."
61
- )
62
-
63
- # Very short, language-agnostic semantics to keep prompt small
64
- DEFAULT_LABEL_GLOSSARY = {
65
- "plan_contact": "Commitment to contact later (advisor/client will reach out, follow-up promised).",
66
- "schedule_meeting": "Scheduling or confirming a meeting/call/appointment (time/date/slot/virtual).",
67
- "update_contact_info_non_postal": "Change or confirmation of phone/email (non-postal contact details).",
68
- "update_contact_info_postal_address": "Change or confirmation of postal/residential/mailing address.",
69
- "update_kyc_activity": "Change/confirmation of occupation, employment status, or economic activity.",
70
- "update_kyc_origin_of_assets": "Discussion/confirmation of source of funds / origin of assets.",
71
- "update_kyc_purpose_of_businessrelation": "Purpose of the banking relationship/account usage.",
72
- "update_kyc_total_assets": "Discussion/confirmation of total assets/net worth.",
73
- }
74
-
75
- # Tiny multilingual fallback rules (optional) to guarantee recall if model is empty.
76
- DEFAULT_FALLBACK_CUES = {
77
- "plan_contact": [
78
- # EN
79
- r"\b(get|got|will|we'?ll|i'?ll)\s+back to you\b",
80
- r"\bfollow\s*up\b",
81
- r"\breach out\b",
82
- r"\btouch base\b",
83
- r"\bcontact (you|me|us)\b",
84
- # DE
85
- r"\bin verbindung setzen\b",
86
- r"\brückmeldung\b",
87
- r"\bich\s+melde\b|\bwir\s+melden\b",
88
- r"\bnachfassen\b",
89
- # FR
90
- r"\bje vous recontacte\b|\bnous vous recontacterons\b",
91
- r"\bprendre contact\b|\breprendre contact\b",
92
- # IT
93
- r"\bla ricontatter[oò]\b|\bci metteremo in contatto\b",
94
- r"\btenersi in contatto\b",
95
- ],
96
- "schedule_meeting": [
97
- # EN
98
- r"\b(let'?s\s+)?meet(ing|s)?\b",
99
- r"\bschedule( a)? (call|meeting|appointment)\b",
100
- r"\bbook( a)? (slot|time|meeting)\b",
101
- r"\b(next week|tomorrow|this (afternoon|morning|evening))\b",
102
- r"\bconfirm( the)? (time|meeting|appointment)\b",
103
- # DE
104
- r"\btermin(e|s)?\b|\bvereinbaren\b|\bansetzen\b|\babstimmen\b|\bbesprechung(en)?\b|\bvirtuell(e|en)?\b",
105
- r"\bnächste(n|r)? woche\b|\b(dienstag|montag|mittwoch|donnerstag|freitag)\b|\bnachmittag|vormittag|morgen\b",
106
- # FR
107
- r"\brendez[- ]?vous\b|\bréunion\b|\bfixer\b|\bplanifier\b|\bcalendrier\b|\bse rencontrer\b|\bse voir\b",
108
- r"\bla semaine prochaine\b|\bdemain\b|\bcet (après-midi|apres-midi|après midi|apres midi|matin|soir)\b",
109
- # IT
110
- r"\bappuntamento\b|\briunione\b|\borganizzare\b|\bprogrammare\b|\bincontrarci\b|\bcalendario\b",
111
- r"\bla prossima settimana\b|\bdomani\b|\b(questo|questa)\s*(pomeriggio|mattina|sera)\b",
112
- ],
113
- "update_kyc_origin_of_assets": [
114
- # EN
115
- r"\bsource of funds\b|\borigin of assets\b|\bproof of (funds|assets)\b",
116
- # DE
117
- r"\bvermögensursprung(e|s)?\b|\bherkunft der mittel\b|\bnachweis\b",
118
- # FR
119
- r"\borigine des fonds\b|\borigine du patrimoine\b|\bjustificatif(s)?\b",
120
- # IT
121
- r"\borigine dei fondi\b|\borigine del patrimonio\b|\bprova dei fondi\b|\bgiustificativo\b",
122
- ],
123
- "update_kyc_activity": [
124
- # EN
125
- r"\bemployment status\b|\boccupation\b|\bjob change\b|\bsalary history\b",
126
- # DE
127
- r"\bbeschäftigungsstatus\b|\bberuf\b|\bjobwechsel\b|\bgehaltshistorie\b|\btätigkeit\b",
128
- # FR
129
- r"\bstatut professionnel\b|\bprofession\b|\bchangement d'emploi\b|\bhistorique salarial\b|\bactivité\b",
130
- # IT
131
- r"\bstato occupazionale\b|\bprofessione\b|\bcambio di lavoro\b|\bstoria salariale\b|\battivit[aà]\b",
132
- ],
133
- }
134
 
135
  # =========================
136
- # Prompt templates (minimal multilingual)
137
  # =========================
138
- USER_PROMPT_TEMPLATE = (
139
- "Transcript (may be DE/FR/IT/EN):\n"
140
- "```\n{transcript}\n```\n\n"
141
- "Allowed Labels (canonical; use only these):\n"
142
- "{allowed_labels_list}\n\n"
143
- "Label Glossary (concise semantics):\n"
144
- "{glossary}\n\n"
145
- "Return STRICT JSON ONLY in this exact schema:\n"
146
- '{\n "labels": ["<Label1>", "..."],\n'
147
- ' "tasks": [{"label": "<Label1>", "explanation": "<why>", "evidence": "<quote>"}]\n}\n'
148
- )
149
 
150
- # =========================
151
- # Utilities
152
- # =========================
153
- def _now_ms() -> int:
154
- return int(time.time() * 1000)
155
-
156
- def normalize_labels(labels: List[str]) -> List[str]:
157
- return list(dict.fromkeys([l.strip() for l in labels if isinstance(l, str) and l.strip()]))
158
-
159
- def canonicalize_map(allowed: List[str]) -> Dict[str, str]:
160
- return {lab.lower(): lab for lab in allowed}
161
-
162
- def robust_json_extract(text: str) -> Dict[str, Any]:
163
- if not text:
164
- return {"labels": [], "tasks": []}
165
- start, end = text.find("{"), text.rfind("}")
166
- candidate = text[start:end+1] if (start != -1 and end != -1 and end > start) else text
167
  try:
168
- return json.loads(candidate)
 
169
  except Exception:
170
- candidate = re.sub(r",\s*}", "}", candidate)
171
- candidate = re.sub(r",\s*]", "]", candidate)
172
- try:
173
- return json.loads(candidate)
174
- except Exception:
175
- return {"labels": [], "tasks": []}
176
-
177
- def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
178
- out = {"labels": [], "tasks": []}
179
- allowed_map = canonicalize_map(allowed)
180
- filt_labels = []
181
- for l in pred.get("labels", []) or []:
182
- k = str(l).strip().lower()
183
- if k in allowed_map:
184
- filt_labels.append(allowed_map[k])
185
- filt_labels = normalize_labels(filt_labels)
186
- filt_tasks = []
187
- for t in pred.get("tasks", []) or []:
188
- if not isinstance(t, dict):
189
- continue
190
- k = str(t.get("label", "")).strip().lower()
191
- if k in allowed_map:
192
- new_t = dict(t); new_t["label"] = allowed_map[k]
193
- new_t = {
194
- "label": new_t["label"],
195
- "explanation": str(new_t.get("explanation", ""))[:300],
196
- "evidence": str(new_t.get("evidence", ""))[:300],
197
- }
198
- filt_tasks.append(new_t)
199
- merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
200
- out["labels"] = merged
201
- out["tasks"] = filt_tasks
202
- return out
203
 
204
  # =========================
205
- # Pre-processing
206
  # =========================
207
- _DISCLAIMER_PATTERNS = [
208
- r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
209
- r"(?is)^\s*the information contained.+?(?:\n{2,}|$)",
210
- r"(?is)^\s*this message \(including any attachments\).+?(?:\n{2,}|$)",
211
- ]
212
- _FOOTER_PATTERNS = [
213
- r"(?is)\n+kind regards[^\n]*\n.*$", r"(?is)\n+best regards[^\n]*\n.*$",
214
- r"(?is)\n+sent from my.*$", r"(?is)\n+ubs ag.*$",
215
- ]
216
- _TIMESTAMP_SPEAKER = [
217
- r"\[\d{1,2}:\d{2}(:\d{2})?\]", # [00:01] or [00:01:02]
218
- r"^\s*(advisor|client|client advisor)\s*:\s*", # Advisor:, Client:
219
- r"^\s*(speaker\s*\d+)\s*:\s*", # Speaker 1:
220
- ]
221
-
222
- def clean_transcript(text: str) -> str:
223
- if not text:
224
- return text
225
- s = text
226
- lines = []
227
- for ln in s.splitlines():
228
- ln2 = ln
229
- for pat in _TIMESTAMP_SPEAKER:
230
- ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
231
- lines.append(ln2)
232
- s = "\n".join(lines)
233
- for pat in _DISCLAIMER_PATTERNS:
234
- s = re.sub(pat, "", s).strip()
235
- for pat in _FOOTER_PATTERNS:
236
- s = re.sub(pat, "", s)
237
- s = re.sub(r"[ \t]+", " ", s)
238
- s = re.sub(r"\n{3,}", "\n\n", s).strip()
239
- return s
240
-
241
- def read_text_file_any(file_input) -> str:
242
- if not file_input:
243
- return ""
244
- if isinstance(file_input, (str, Path)):
245
- try:
246
- return Path(file_input).read_text(encoding="utf-8", errors="ignore")
247
- except Exception:
248
- return ""
249
- try:
250
- data = file_input.read()
251
- return data.decode("utf-8", errors="ignore")
252
- except Exception:
253
- return ""
254
-
255
- def read_json_file_any(file_input) -> Optional[dict]:
256
- if not file_input:
257
  return None
258
- if isinstance(file_input, (str, Path)):
259
- try:
260
- return json.loads(Path(file_input).read_text(encoding="utf-8", errors="ignore"))
261
- except Exception:
262
- return None
263
  try:
264
- return json.loads(file_input.read().decode("utf-8", errors="ignore"))
 
265
  except Exception:
266
  return None
267
 
268
- def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
269
- toks = tokenizer(text, add_special_tokens=False)["input_ids"]
270
- if len(toks) <= max_tokens:
271
- return text
272
- return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
273
-
274
- # =========================
275
- # HF model wrapper (main LLM) – robust against meta tensor errors
276
- # =========================
277
- class ModelWrapper:
278
- def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool):
279
- self.repo_id = repo_id
280
- self.hf_token = hf_token
281
- self.load_in_4bit = load_in_4bit
282
- self.use_sdpa = use_sdpa
283
- self.tokenizer = None
284
- self.model = None
285
- self.load_path = "uninitialized"
286
-
287
- def load(self):
288
- # Build a BitsAndBytes config if needed
289
- qcfg = None
290
- if self.load_in_4bit and DEVICE == "cuda":
291
- qcfg = BitsAndBytesConfig(
292
- load_in_4bit=True,
293
- bnb_4bit_quant_type="nf4",
294
- bnb_4bit_compute_dtype=torch.float16,
295
- bnb_4bit_use_double_quant=True,
296
- )
297
-
298
- # Try a safe load first (no low_cpu_mem_usage, device_map="auto")
299
- errors = []
300
- for attempt in [
301
- # (desc, kwargs)
302
- ("auto_device_no_lowcpu" + ("_sdpa" if self.use_sdpa else ""),
303
- dict(
304
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
305
- device_map="auto" if DEVICE == "cuda" else None,
306
- low_cpu_mem_usage=False, # avoid meta init
307
- quantization_config=qcfg,
308
- trust_remote_code=True,
309
- cache_dir=str(SPACE_CACHE),
310
- attn_implementation=("sdpa" if (self.use_sdpa and DEVICE == "cuda") else None),
311
- )),
312
- ("auto_device_no_sdpa",
313
- dict(
314
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
315
- device_map="auto" if DEVICE == "cuda" else None,
316
- low_cpu_mem_usage=False,
317
- quantization_config=qcfg,
318
- trust_remote_code=True,
319
- cache_dir=str(SPACE_CACHE),
320
- # no attn_implementation key => let HF choose
321
- )),
322
- ("cpu_then_to_cuda" if DEVICE == "cuda" else "cpu_only",
323
- dict(
324
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
325
- device_map=None, # load on CPU
326
- low_cpu_mem_usage=False,
327
- quantization_config=None if DEVICE != "cuda" else qcfg, # if 4bit, keep qcfg
328
- trust_remote_code=True,
329
- cache_dir=str(SPACE_CACHE),
330
- )),
331
- ]:
332
- desc, kwargs = attempt
333
- try:
334
- tok = AutoTokenizer.from_pretrained(
335
- self.repo_id, token=self.hf_token,
336
- cache_dir=str(SPACE_CACHE), trust_remote_code=True, use_fast=True,
337
- )
338
- if tok.pad_token is None and tok.eos_token:
339
- tok.pad_token = tok.eos_token
340
-
341
- mdl = AutoModelForCausalLM.from_pretrained(
342
- self.repo_id, token=self.hf_token, **kwargs
343
- )
344
-
345
- # If we loaded on CPU and have CUDA, move model (non-meta) to CUDA
346
- if desc.startswith("cpu_then_to_cuda") and DEVICE == "cuda":
347
- mdl = mdl.to(torch.device("cuda"))
348
-
349
- self.tokenizer = tok
350
- self.model = mdl
351
- self.load_path = desc
352
- return
353
-
354
- except Exception as e:
355
- errors.append(f"{desc}: {e}")
356
-
357
- raise RuntimeError("All load attempts failed:\n" + "\n".join(errors))
358
-
359
- @torch.inference_mode()
360
- def generate(self, system_prompt: str, user_prompt: str) -> str:
361
- # Build inputs as input_ids=... (avoid **tensor kwargs mixing)
362
- if hasattr(self.tokenizer, "apply_chat_template"):
363
- messages = [
364
- {"role": "system", "content": system_prompt},
365
- {"role": "user", "content": user_prompt},
366
- ]
367
- input_ids = self.tokenizer.apply_chat_template(
368
- messages,
369
- tokenize=True,
370
- add_generation_prompt=True,
371
- return_tensors="pt",
372
- )
373
- input_ids = input_ids.to(self.model.device)
374
- gen_kwargs = dict(
375
- input_ids=input_ids,
376
- generation_config=GEN_CONFIG,
377
- eos_token_id=self.tokenizer.eos_token_id,
378
- pad_token_id=self.tokenizer.pad_token_id,
379
- )
380
- else:
381
- enc = self.tokenizer(
382
- f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n[USER]\n{user_prompt}\n[/USER]\n",
383
- return_tensors="pt"
384
- ).to(self.model.device)
385
- gen_kwargs = dict(
386
- **enc,
387
- generation_config=GEN_CONFIG,
388
- eos_token_id=self.tokenizer.eos_token_id,
389
- pad_token_id=self.tokenizer.pad_token_id,
390
- )
391
-
392
- with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
393
- out_ids = self.model.generate(**gen_kwargs)
394
- return self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
395
-
396
- _MODEL_CACHE: Dict[str, ModelWrapper] = {}
397
- def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool, use_sdpa: bool) -> ModelWrapper:
398
- key = f"{repo_id}::{'4bit' if (load_in_4bit and DEVICE=='cuda') else 'full'}::{'sdpa' if use_sdpa else 'nosdpa'}"
399
- if key not in _MODEL_CACHE:
400
- m = ModelWrapper(repo_id, hf_token, load_in_4bit, use_sdpa)
401
- m.load()
402
- _MODEL_CACHE[key] = m
403
- return _MODEL_CACHE[key]
404
-
405
- # =========================
406
- # Evaluation (official weighted score)
407
- # =========================
408
- def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
409
- ALLOWED_LABELS = OFFICIAL_LABELS
410
- LABEL_TO_IDX = {label: idx for idx, label in enumerate(ALLOWED_LABELS)}
411
-
412
- def _process_sample_labels(sample_labels: List[str], sample_name: str) -> List[str]:
413
- if not isinstance(sample_labels, list):
414
- raise ValueError(f"{sample_name} must be a list of strings, got {type(sample_labels)}")
415
- seen, uniq = set(), []
416
- for label in sample_labels:
417
- if not isinstance(label, str):
418
- raise ValueError(f"{sample_name} contains non-string: {label} (type: {type(label)})")
419
- if label in seen:
420
- raise ValueError(f"{sample_name} contains duplicate label: '{label}'")
421
- if label not in ALLOWED_LABELS:
422
- raise ValueError(f"{sample_name} contains invalid label: '{label}'. Allowed: {ALLOWED_LABELS}")
423
- seen.add(label); uniq.append(label)
424
- return uniq
425
-
426
- if len(y_true) != len(y_pred):
427
- raise ValueError(f"y_true and y_pred must have same length. Got {len(y_true)} vs {len(y_pred)}")
428
-
429
- n_samples = len(y_true)
430
- n_labels = len(OFFICIAL_LABELS)
431
- y_true_binary = np.zeros((n_samples, n_labels), dtype=int)
432
- y_pred_binary = np.zeros((n_samples, n_labels), dtype=int)
433
-
434
- for i, sample_labels in enumerate(y_true):
435
- for label in _process_sample_labels(sample_labels, f"y_true[{i}]"):
436
- y_true_binary[i, LABEL_TO_IDX[label]] = 1
437
-
438
- for i, sample_labels in enumerate(y_pred):
439
- for label in _process_sample_labels(sample_labels, f"y_pred[{i}]"):
440
- y_pred_binary[i, LABEL_TO_IDX[label]] = 1
441
-
442
- fn = np.sum((y_true_binary == 1) & (y_pred_binary == 0), axis=1) # penalty 2x
443
- fp = np.sum((y_true_binary == 0) & (y_pred_binary == 1), axis=1) # penalty 1x
444
- weighted = 2.0 * fn + 1.0 * fp
445
- max_err = 2.0 * np.sum(y_true_binary, axis=1) + 1.0 * (n_labels - np.sum(y_true_binary, axis=1))
446
- per_sample = np.where(max_err > 0, 1.0 - (weighted / max_err), 1.0)
447
- return float(max(0.0, min(1.0, np.mean(per_sample))))
448
 
449
- # =========================
450
- # Multilingual fallback (regex on original text)
451
- # =========================
452
- def multilingual_fallback(text: str, allowed: List[str], cues: Dict[str, List[str]]) -> Dict[str, Any]:
453
- low = text.lower()
454
- labels, tasks = [], []
455
- for lab in allowed:
456
- for pat in cues.get(lab, []):
457
- m = re.search(pat, low)
458
- if m:
459
- i = m.start()
460
- start = max(0, i - 60); end = min(len(text), i + len(m.group(0)) + 60)
461
- if lab not in labels:
462
- labels.append(lab)
463
- tasks.append({
464
- "label": lab,
465
- "explanation": "Rule hit (multilingual fallback)",
466
- "evidence": text[start:end].strip()
467
- })
468
- break
469
- return {"labels": normalize_labels(labels), "tasks": tasks}
470
 
471
  # =========================
472
- # Inference helpers
473
  # =========================
474
- def build_glossary_str(glossary: Dict[str, str], allowed: List[str]) -> str:
475
- return "\n".join([f"- {lab}: {glossary.get(lab, '')}" for lab in allowed])
476
-
477
- def warmup_model(model_repo: str, use_4bit: bool, use_sdpa: bool, hf_token: str) -> str:
478
- t0 = _now_ms()
479
- try:
480
- model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa)
481
- _ = model.generate("Return JSON only.", '{"labels": [], "tasks": []}')
482
- return f"Warm-up complete in {_now_ms() - t0} ms. Load path: {model.load_path}"
483
- except Exception as e:
484
- return f"Warm-up failed: {e}"
485
-
486
- def run_single(
487
- transcript_text: str,
488
- transcript_file,
489
- gt_json_text: str,
490
- gt_json_file,
491
- use_cleaning: bool,
492
- use_fallback: bool,
493
- allowed_labels_text: str,
494
- sys_instructions_text: str,
495
- glossary_json_text: str,
496
- fallback_json_text: str,
497
- model_repo: str,
498
- use_4bit: bool,
499
- use_sdpa: bool,
500
- max_input_tokens: int,
501
- hf_token: str,
502
- ) -> Tuple[str, str, str, str, str, str, str, str, str]:
503
-
504
- t0 = _now_ms()
505
-
506
- # Load transcript
507
- raw_text = ""
508
- if transcript_file:
509
- raw_text = read_text_file_any(transcript_file)
510
- raw_text = (raw_text or transcript_text or "").strip()
511
- if not raw_text:
512
- return "", "", "No transcript provided.", "", "", "", "", "", ""
513
-
514
- text = clean_transcript(raw_text) if use_cleaning else raw_text
515
-
516
- # Allowed labels
517
- user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
518
- allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
519
-
520
- # Editable configs
521
- try:
522
- sys_instructions = (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS).strip()
523
- if not sys_instructions:
524
- sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
525
- except Exception:
526
- sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
527
-
528
- try:
529
- label_glossary = json.loads(glossary_json_text) if glossary_json_text else DEFAULT_LABEL_GLOSSARY
530
- except Exception:
531
- label_glossary = DEFAULT_LABEL_GLOSSARY
532
-
533
- try:
534
- fallback_cues = json.loads(fallback_json_text) if fallback_json_text else DEFAULT_FALLBACK_CUES
535
- except Exception:
536
- fallback_cues = DEFAULT_FALLBACK_CUES
537
 
538
- # Model
539
  try:
540
- model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa)
541
  except Exception as e:
542
- return "", "", f"Model load failed: {e}", "", "", "", "", "", ""
543
-
544
- # Truncate
545
- trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
546
-
547
- # Build prompt
548
- glossary_str = build_glossary_str(label_glossary, allowed)
549
- allowed_list_str = "\n".join(f"- {l}" for l in allowed)
550
- user_prompt = USER_PROMPT_TEMPLATE.format(
551
- transcript=trunc,
552
- allowed_labels_list=allowed_list_str,
553
- glossary=glossary_str,
 
 
 
554
  )
555
 
556
- # Token info + prompt preview
557
- transcript_tokens = len(model.tokenizer(trunc, add_special_tokens=False)["input_ids"])
558
- prompt_tokens = len(model.tokenizer(user_prompt, add_special_tokens=False)["input_ids"])
559
- token_info_text = f"Transcript tokens: {transcript_tokens} | Prompt tokens: {prompt_tokens} | Load path: {model.load_path}"
560
- prompt_preview_text = "```\n" + user_prompt[:4000] + ("\n... (truncated)" if len(user_prompt) > 4000 else "") + "\n```"
561
-
562
- # Generate
563
- t1 = _now_ms()
564
  try:
565
- out = model.generate(sys_instructions, user_prompt)
 
 
 
 
 
 
 
 
 
 
566
  except Exception as e:
567
- return "", "", f"Generation error: {e}", "", "", "", prompt_preview_text, token_info_text, ""
568
- t2 = _now_ms()
569
-
570
- parsed = robust_json_extract(out)
571
- filtered = restrict_to_allowed(parsed, allowed)
572
-
573
- # Fallback (multilingual rules) on original text; merge for recall if enabled
574
- if use_fallback:
575
- fb = multilingual_fallback(trunc, allowed, fallback_cues)
576
- if fb["labels"]:
577
- merged_labels = sorted(list(set(filtered.get("labels", [])) | set(fb["labels"])))
578
- existing = {tt.get("label") for tt in filtered.get("tasks", [])}
579
- merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
580
- filtered = {"labels": merged_labels, "tasks": merged_tasks}
581
-
582
- # Diagnostics
583
- diag = "\n".join([
584
- f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
585
- f"Model: {model_repo}",
586
- f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
587
- f"Fallback rules: {'Yes' if use_fallback else 'No'}",
588
- f"SDPA attention: {'Yes' if use_sdpa else 'No'}",
589
- f"Tokens (input limit): ≤ {max_input_tokens}",
590
- f"Latency: prep {t1-t0} ms, gen {t2-t1} ms, total {t2-t0} ms",
591
- f"Allowed labels: {', '.join(allowed)}",
592
- ])
593
-
594
- # Summaries
595
- labs = filtered.get("labels", [])
596
- tasks = filtered.get("tasks", [])
597
- summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
598
- if tasks:
599
- summary += "\n\nTasks:\n" + "\n".join(
600
- f"• [{t['label']}] {t.get('explanation','')} | ev: {t.get('evidence','')[:140]}{'…' if len(t.get('evidence',''))>140 else ''}"
601
- for t in tasks
602
- )
603
- else:
604
- summary += "\n\nTasks: (none)"
605
- json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
606
-
607
- # Single-file scoring if GT provided
608
- metrics = ""
609
- if gt_json_file or (gt_json_text and gt_json_text.strip()):
610
- truth_obj = None
611
- if gt_json_file:
612
- truth_obj = read_json_file_any(gt_json_file)
613
- if (not truth_obj) and gt_json_text:
614
- try:
615
- truth_obj = json.loads(gt_json_text)
616
- except Exception:
617
- pass
618
- if isinstance(truth_obj, dict) and isinstance(truth_obj.get("labels"), list):
619
- true_labels = [x for x in truth_obj["labels"] if x in OFFICIAL_LABELS]
620
- pred_labels = labs
621
- try:
622
- score = evaluate_predictions([true_labels], [pred_labels])
623
- tp = len(set(true_labels) & set(pred_labels))
624
- fp = len(set(pred_labels) - set(true_labels))
625
- fn = len(set(true_labels) - set(pred_labels))
626
- recall = tp / (tp + fn) if (tp + fn) else 1.0
627
- precision = tp / (tp + fp) if (tp + fp) else 1.0
628
- f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 1.0
629
- metrics = (
630
- f"Weighted score: {score:.3f}\n"
631
- f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}\n"
632
- f"TP={tp} FP={fp} FN={fn}\n"
633
- f"Truth: {', '.join(true_labels)}"
634
- )
635
- except Exception as e:
636
- metrics = f"Scoring error: {e}"
637
- else:
638
- metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
639
 
640
- # For UI: show effective context (glossary) and instructions
641
- context_preview = "### Label Glossary (used)\n" + "\n".join(f"- {k}: {v}" for k, v in label_glossary.items() if k in allowed)
642
- instructions_preview = "```\n" + sys_instructions + "\n```"
 
 
643
 
644
- return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics, prompt_preview_text, token_info_text
 
 
 
645
 
646
  # =========================
647
- # Batch mode (ZIP with transcripts + truths)
648
  # =========================
649
- def read_zip_from_path(path: str, exdir: Path) -> List[Path]:
650
- exdir.mkdir(parents=True, exist_ok=True)
651
- with open(path, "rb") as f:
652
- data = f.read()
653
- with zipfile.ZipFile(io.BytesIO(data)) as zf:
654
- zf.extractall(exdir)
655
- return [p for p in exdir.rglob("*") if p.is_file()]
656
-
657
- def run_batch(
658
- zip_path,
659
- use_cleaning: bool,
660
- use_fallback: bool,
661
- sys_instructions_text: str,
662
- glossary_json_text: str,
663
- fallback_json_text: str,
664
- model_repo: str,
665
- use_4bit: bool,
666
- use_sdpa: bool,
667
- max_input_tokens: int,
668
- hf_token: str,
669
- limit_files: int,
670
- ) -> Tuple[str, str, pd.DataFrame, str]:
671
-
672
- if not zip_path:
673
- return ("No ZIP provided.", "", pd.DataFrame(), "")
674
-
675
- # Editable configs
676
- try:
677
- sys_instructions = (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS).strip()
678
- if not sys_instructions:
679
- sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
680
- except Exception:
681
- sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
682
-
683
- try:
684
- label_glossary = json.loads(glossary_json_text) if glossary_json_text else DEFAULT_LABEL_GLOSSARY
685
- except Exception:
686
- label_glossary = DEFAULT_LABEL_GLOSSARY
687
-
688
- try:
689
- fallback_cues = json.loads(fallback_json_text) if fallback_json_text else DEFAULT_FALLBACK_CUES
690
- except Exception:
691
- fallback_cues = DEFAULT_FALLBACK_CUES
692
-
693
- # Prepare workspace
694
- work = Path("/tmp/batch")
695
- if work.exists():
696
- for p in sorted(work.rglob("*"), reverse=True):
697
- try: p.unlink()
698
- except Exception: pass
699
- try: work.rmdir()
700
- except Exception: pass
701
- work.mkdir(parents=True, exist_ok=True)
702
-
703
- files = read_zip_from_path(zip_path, work)
704
-
705
- txts: Dict[str, Path] = {}
706
- gts: Dict[str, Path] = {}
707
- for p in files:
708
- if p.suffix.lower() == ".txt":
709
- txts[p.stem] = p
710
- elif p.suffix.lower() == ".json":
711
- gts[p.stem] = p
712
-
713
- stems = sorted(txts.keys())
714
- if limit_files > 0:
715
- stems = stems[:limit_files]
716
- if not stems:
717
- return ("No .txt transcripts found in ZIP.", "", pd.DataFrame(), "")
718
-
719
- # Model
720
- try:
721
- model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit, use_sdpa)
722
- except Exception as e:
723
- return (f"Model load failed: {e}", "", pd.DataFrame(), "")
724
-
725
- allowed = OFFICIAL_LABELS[:]
726
- glossary_str = build_glossary_str(label_glossary, allowed)
727
- allowed_list_str = "\n".join(f"- {l}" for l in allowed)
728
-
729
- y_true, y_pred = [], []
730
- rows = []
731
- t_start = _now_ms()
732
-
733
- for stem in stems:
734
- raw = txts[stem].read_text(encoding="utf-8", errors="ignore")
735
- text = clean_transcript(raw) if use_cleaning else raw
736
 
737
- trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
 
 
 
738
 
739
- user_prompt = USER_PROMPT_TEMPLATE.format(
740
- transcript=trunc,
741
- allowed_labels_list=allowed_list_str,
742
- glossary=glossary_str,
743
- )
744
 
745
- t0 = _now_ms()
746
- out = model.generate(sys_instructions, user_prompt)
747
- t1 = _now_ms()
748
-
749
- parsed = robust_json_extract(out)
750
- filtered = restrict_to_allowed(parsed, allowed)
751
-
752
- if use_fallback:
753
- fb = multilingual_fallback(trunc, allowed, fallback_cues)
754
- if fb["labels"]:
755
- merged_labels = sorted(list(set(filtered.get("labels", [])) | set(fb["labels"])))
756
- existing = {tt.get("label") for tt in filtered.get("tasks", [])}
757
- merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
758
- filtered = {"labels": merged_labels, "tasks": merged_tasks}
759
-
760
- pred_labels = filtered.get("labels", [])
761
- y_pred.append(pred_labels)
762
-
763
- gt_labels = []
764
- if stem in gts:
765
- try:
766
- gt_obj = json.loads(gts[stem].read_text(encoding="utf-8", errors="ignore"))
767
- if isinstance(gt_obj, dict) and isinstance(gt_obj.get("labels"), list):
768
- gt_labels = [x for x in gt_obj["labels"] if x in OFFICIAL_LABELS]
769
- except Exception:
770
- pass
771
- y_true.append(gt_labels)
772
-
773
- gt_set, pr_set = set(gt_labels), set(pred_labels)
774
- tp = sorted(gt_set & pr_set)
775
- fp = sorted(pr_set - gt_set)
776
- fn = sorted(gt_set - pr_set)
777
-
778
- rows.append({
779
- "file": stem,
780
- "true_labels": ", ".join(gt_labels),
781
- "pred_labels": ", ".join(pred_labels),
782
- "TP": len(tp), "FP": len(fp), "FN": len(fn),
783
- "gen_ms": t1 - t0
784
- })
785
-
786
- have_truth = any(len(v) > 0 for v in y_true)
787
- score = evaluate_predictions(y_true, y_pred) if have_truth else None
788
-
789
- df = pd.DataFrame(rows).sort_values(["FN", "FP", "file"])
790
- diag = [
791
- f"Processed files: {len(stems)}",
792
- f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
793
- f"Model: {model_repo}",
794
- f"Fallback rules: {'Yes' if use_fallback else 'No'}",
795
- f"SDPA attention: {'Yes' if use_sdpa else 'No'}",
796
- f"Tokens (input limit): ≤ {max_input_tokens}",
797
- f"Batch time: {_now_ms()-t_start} ms",
798
- ]
799
- if have_truth and score is not None:
800
- total_tp = int(df["TP"].sum())
801
- total_fp = int(df["FP"].sum())
802
- total_fn = int(df["FN"].sum())
803
- recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) else 1.0
804
- precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) else 1.0
805
- f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 1.0
806
- diag += [
807
- f"Official weighted score (0–1): {score:.3f}",
808
- f"Recall: {recall:.3f} | Precision: {precision:.3f} | F1: {f1:.3f}",
809
- f"Total TP={total_tp} FP={total_fp} FN={total_fn}",
810
- ]
811
- diag_str = "\n".join(diag)
812
 
813
- out_csv = Path("/tmp/batch_results.csv")
814
- df.to_csv(out_csv, index=False, encoding="utf-8")
815
- return ("Batch done.", diag_str, df, str(out_csv))
 
 
 
 
 
 
816
 
817
- # =========================
818
- # UI
819
- # =========================
820
- MODEL_CHOICES = [
821
- "swiss-ai/Apertus-8B-Instruct-2509", # multilingual
822
- "meta-llama/Meta-Llama-3-8B-Instruct", # strong generalist
823
- "mistralai/Mistral-7B-Instruct-v0.3", # light/fast
824
- ]
825
-
826
- # Light, modern UI (white background, neutral accents)
827
- custom_css = """
828
- :root { --radius: 14px; }
829
- .gradio-container { font-family: Inter, ui-sans-serif, system-ui; background: #ffffff; color: #111827; }
830
- .card { border: 1px solid #e5e7eb; border-radius: var(--radius); padding: 14px 16px; background: #ffffff; box-shadow: 0 1px 2px rgba(0,0,0,.03); }
831
- .header { font-weight: 700; font-size: 22px; margin-bottom: 4px; color: #0f172a; }
832
- .subtle { color: #475569; font-size: 14px; margin-bottom: 12px; }
833
- hr.sep { border: none; border-top: 1px solid #e5e7eb; margin: 10px 0 16px; }
834
- .gr-button { border-radius: 12px !important; }
835
- a, .prose a { color: #0ea5e9; }
836
  """
837
 
838
- with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
839
- gr.Markdown("<div class='header'>Talk2Task Multilingual Task Extraction (UBS Challenge)</div>")
840
- gr.Markdown("<div class='subtle'>Single-pass multilingual extraction (DE/FR/IT/EN) with compact prompts. Optional rule fallback ensures recall. Batch evaluation & scoring included.</div>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
841
 
842
- with gr.Tab("Single transcript"):
843
  with gr.Row():
844
- with gr.Column(scale=3):
845
- gr.Markdown("<div class='card'><div class='header'>Transcript</div>")
846
- file = gr.File(
847
- label="Drag & drop transcript (.txt / .md / .json)",
848
- file_types=[".txt", ".md", ".json"],
849
- type="filepath",
 
 
850
  )
851
- text = gr.Textbox(label="Or paste transcript", lines=10, placeholder="Paste transcript in DE/FR/IT/EN…")
852
- gr.Markdown("<hr class='sep'/>")
853
-
854
- gr.Markdown("<div class='header'>Ground truth JSON (optional)</div>")
855
- gt_file = gr.File(
856
- label="Upload ground truth JSON (expects {'labels': [...]})",
857
- file_types=[".json"],
858
- type="filepath",
859
  )
860
- gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{\"labels\": [\"schedule_meeting\"]}')
861
- gr.Markdown("</div>") # close card
862
-
863
- gr.Markdown("<div class='card'><div class='header'>Processing options</div>")
864
- use_cleaning = gr.Checkbox(label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)", value=True)
865
- use_fallback = gr.Checkbox(label="Enable multilingual fallback rule layer", value=True)
866
- gr.Markdown("</div>")
867
-
868
- gr.Markdown("<div class='card'><div class='header'>Allowed labels</div>")
869
- labels_text = gr.Textbox(label="Allowed Labels (one per line)", value=OFFICIAL_LABELS_TEXT, lines=8)
870
- reset_btn = gr.Button("Reset to official labels")
871
- gr.Markdown("</div>")
872
-
873
- gr.Markdown("<div class='card'><div class='header'>Editable instructions & context</div>")
874
- sys_instr_tb = gr.Textbox(label="System Instructions (editable)", value=DEFAULT_SYSTEM_INSTRUCTIONS, lines=5)
875
- glossary_tb = gr.Code(label="Label Glossary (JSON; editable)", value=json.dumps(DEFAULT_LABEL_GLOSSARY, indent=2), language="json")
876
- fallback_tb = gr.Code(label="Fallback Cues (Multilingual, JSON; editable)", value=json.dumps(DEFAULT_FALLBACK_CUES, indent=2), language="json")
877
- gr.Markdown("</div>")
878
-
879
- with gr.Column(scale=2):
880
- gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
881
- repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
882
- use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
883
- use_sdpa = gr.Checkbox(label="Use SDPA attention (faster on many GPUs)", value=True)
884
- max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
885
- hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
886
- warm_btn = gr.Button("Warm up model (load & compile kernels)")
887
- run_btn = gr.Button("Run Extraction", variant="primary")
888
- gr.Markdown("</div>")
889
-
890
- gr.Markdown("<div class='card'><div class='header'>Outputs</div>")
891
- summary = gr.Textbox(label="Summary", lines=12)
892
- json_out = gr.Code(label="Strict JSON Output", language="json")
893
- diag = gr.Textbox(label="Diagnostics", lines=10)
894
- raw = gr.Textbox(label="Raw Model Output", lines=8)
895
- prompt_preview = gr.Code(label="Prompt preview (user prompt sent)", language="markdown")
896
- token_info = gr.Textbox(label="Token counts (transcript / prompt / load path)", lines=2)
897
- gr.Markdown("</div>")
898
 
899
- with gr.Row():
900
- with gr.Column():
901
- with gr.Accordion("Instructions used (system prompt)", open=False):
902
- instr_md = gr.Markdown("```\n" + DEFAULT_SYSTEM_INSTRUCTIONS + "\n```")
903
- with gr.Column():
904
- with gr.Accordion("Context used (glossary)", open=True):
905
- context_md = gr.Markdown("")
906
-
907
- # Reset labels to official
908
- def _reset_labels():
909
- return OFFICIAL_LABELS_TEXT
910
- reset_btn.click(fn=_reset_labels, inputs=None, outputs=labels_text)
911
-
912
- # Warm-up
913
- warm_btn.click(fn=warmup_model, inputs=[repo, use_4bit, use_sdpa, hf_token], outputs=diag)
914
-
915
- # For initial context preview
916
- def _pack_context_md(glossary_json, allowed_text):
917
- try:
918
- glossary = json.loads(glossary_json) if glossary_json else DEFAULT_LABEL_GLOSSARY
919
- except Exception:
920
- glossary = DEFAULT_LABEL_GLOSSARY
921
- allowed_list = [ln.strip() for ln in (allowed_text or OFFICIAL_LABELS_TEXT).splitlines() if ln.strip()]
922
- return "### Label Glossary (used)\n" + "\n".join(f"- {k}: {glossary.get(k,'')}" for k in allowed_list)
923
-
924
- context_md.value = _pack_context_md(json.dumps(DEFAULT_LABEL_GLOSSARY), OFFICIAL_LABELS_TEXT)
925
-
926
- # Single run
927
  run_btn.click(
928
- fn=run_single,
929
- inputs=[
930
- text, file, gt_text, gt_file, use_cleaning, use_fallback,
931
- labels_text, sys_instr_tb, glossary_tb, fallback_tb,
932
- repo, use_4bit, use_sdpa, max_tokens, hf_token
933
- ],
934
- outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False), prompt_preview, token_info],
935
  )
936
 
937
- with gr.Tab("Batch evaluation"):
938
- with gr.Row():
939
- with gr.Column(scale=3):
940
- gr.Markdown("<div class='card'><div class='header'>ZIP input</div>")
941
- zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
942
- use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
943
- use_fallback_b = gr.Checkbox(label="Enable multilingual fallback rule layer", value=True)
944
- gr.Markdown("</div>")
945
- with gr.Column(scale=2):
946
- gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
947
- repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
948
- use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
949
- use_sdpa_b = gr.Checkbox(label="Use SDPA attention (faster on many GPUs)", value=True)
950
- max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
951
- hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
952
- sys_instr_tb_b = gr.Textbox(label="System Instructions (editable for batch)", value=DEFAULT_SYSTEM_INSTRUCTIONS, lines=4)
953
- glossary_tb_b = gr.Code(label="Label Glossary (JSON; editable for batch)", value=json.dumps(DEFAULT_LABEL_GLOSSARY, indent=2), language="json")
954
- fallback_tb_b = gr.Code(label="Fallback Cues (Multilingual, JSON; editable for batch)", value=json.dumps(DEFAULT_FALLBACK_CUES, indent=2), language="json")
955
- limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
956
- run_batch_btn = gr.Button("Run Batch", variant="primary")
957
- gr.Markdown("</div>")
958
 
959
- with gr.Row():
960
- gr.Markdown("<div class='card'><div class='header'>Batch outputs</div>")
961
- status = gr.Textbox(label="Status", lines=1)
962
- diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
963
- df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False)
964
- csv_out = gr.File(label="Download CSV", interactive=False)
965
- gr.Markdown("</div>")
966
-
967
- run_batch_btn.click(
968
- fn=run_batch,
969
- inputs=[
970
- zip_in, use_cleaning_b, use_fallback_b,
971
- sys_instr_tb_b, glossary_tb_b, fallback_tb_b,
972
- repo_b, use_4bit_b, use_sdpa_b, max_tokens_b, hf_token_b, limit_files
973
- ],
974
- outputs=[status, diag_b, df_out, csv_out],
975
- )
976
 
977
  if __name__ == "__main__":
 
978
  demo.launch()
 
 
1
  import os
 
 
2
  import json
 
 
 
 
 
 
 
3
  import gradio as gr
 
4
  import torch
5
+ from typing import Optional, Tuple, Dict, Any
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
# =========================
# Runtime / Model Defaults
# =========================
# Small, ungated default to avoid permission/download issues.
# You can switch at runtime via the dropdown or set MODEL_ID env var.
# NOTE(review): resolved once at import time — changing MODEL_ID later in the
# same process has no effect on this constant.
DEFAULT_MODEL_ID = os.environ.get("MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 
 
 
 
 
 
 
 
14
 
15
+ def _has_bnb_and_cuda() -> bool:
16
+ if not torch.cuda.is_available():
17
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  try:
19
+ import bitsandbytes as _bnb # noqa: F401
20
+ return True
21
  except Exception:
22
+ return False
23
+
24
+ USE_BNB = _has_bnb_and_cuda()
25
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
# =========================
# Model Load (safe + flexible)
# =========================
# Process-wide single-slot cache: one tokenizer/model pair at a time.
_tokenizer: Optional[AutoTokenizer] = None
_model: Optional[AutoModelForCausalLM] = None
_current_model_id: Optional[str] = None

def load_model(model_id: str) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """
    Load (or reuse) a model/tokenizer. Uses bitsandbytes 4-bit only if
    CUDA is available AND bnb is installed. Otherwise plain CPU/GPU.
    """
    global _tokenizer, _model, _current_model_id

    # Cache hit: same repo id and both objects still resident.
    cache_hit = (
        _current_model_id == model_id
        and _tokenizer is not None
        and _model is not None
    )
    if cache_hit:
        return _tokenizer, _model

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)

    if USE_BNB:
        # Deferred import: bitsandbytes config only exists when bnb is usable.
        from transformers import BitsAndBytesConfig
        quant_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quant_cfg,
            device_map="auto",
            trust_remote_code=True,
        )
    else:
        # fp16 on GPU, fp32 on CPU (fp16 CPU inference is poorly supported).
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).to(DEVICE)

    _tokenizer, _model, _current_model_id = tokenizer, model, model_id
    return tokenizer, model
66
+
67
+ # ======================================
68
+ # Helpers: Ingest TXT/JSON from Tabs box
69
+ # ======================================
70
+ def read_file(file_obj: Optional[gr.File]) -> Optional[str]:
71
+ if not file_obj:
 
 
 
 
 
 
 
 
72
  return None
 
 
 
 
 
73
  try:
74
+ with open(file_obj.name, "r", encoding="utf-8", errors="ignore") as f:
75
+ return f.read()
76
  except Exception:
77
  return None
78
 
79
def normalize_txt_input(paste_txt: str, upload_file: Any) -> str:
    """Resolve the TXT tab pair: pasted text wins, else uploaded file, else "".

    Fixes the original annotation (``gr.File`` is the component class, not the
    value the handler receives) and only touches the filesystem when the paste
    box is effectively empty.
    """
    if paste_txt and paste_txt.strip():
        return paste_txt
    return read_file(upload_file) or ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
def normalize_json_input(paste_json: str, upload_file: Any) -> str:
    """Resolve the JSON tab pair: stripped pasted JSON wins, else file text, else "".

    Note the asymmetry with ``normalize_txt_input`` (kept from the original):
    pasted JSON is stripped, pasted TXT is returned verbatim. Annotation fixed
    (``gr.File`` is the component class, not the handler's value type); the
    file is only read when the paste box is empty.
    """
    candidate = paste_json.strip() if paste_json else ""
    if candidate:
        return candidate
    return read_file(upload_file) or ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
# =========================
# Core Extraction (placeholder)
# =========================
def run_extraction(
    model_choice: str,
    params_checked: list,
    instructions_text: str,
    context_text: str,
    txt_paste: str,
    txt_upload: Any,
    json_paste: str,
    json_upload: Any,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> Tuple[str, str, str, str, str]:
    """
    Wire your real extraction here.

    Returns:
        tasks_out, entities_out, cleaned_out, summary_out, diagnostics

    Every failure path still returns a 5-tuple so the Gradio outputs stay
    populated; errors are surfaced through the diagnostics pane.
    """
    diagnostics_lines = []

    # Resolve inputs from single-box Tab controls
    input_txt = normalize_txt_input(txt_paste, txt_upload)
    input_json_raw = normalize_json_input(json_paste, json_upload)

    diagnostics_lines.append(f"Model: {model_choice}")
    diagnostics_lines.append(f"Params: {params_checked}")
    diagnostics_lines.append(f"Instructions length: {len(instructions_text)} chars")
    diagnostics_lines.append(f"Context length: {len(context_text)} chars")
    diagnostics_lines.append(f"TXT length: {len(input_txt)} chars")

    # Try parse JSON (optional); a parse failure is reported, not fatal.
    parsed_json: Dict[str, Any] = {}
    if input_json_raw:
        try:
            parsed_json = json.loads(input_json_raw)
            diagnostics_lines.append("JSON: parsed successfully")
        except Exception as e:
            diagnostics_lines.append(f"JSON parse error: {e}")

    # Load selected model (safe)
    try:
        tokenizer, model = load_model(model_choice)
    except Exception as e:
        # If model fails to load, still return diagnostics
        diag = "\n".join(diagnostics_lines + [f"Model load failed: {e}"])
        return "", "", "", "", diag

    # ---------- Dummy generation (replace with your real prompts) ----------
    # Build a prompt from inputs (very basic); inputs are truncated so a huge
    # paste cannot blow past the model's context window.
    user_prompt = (
        "You are an assistant that extracts tasks and entities.\n"
        f"Instructions: {instructions_text}\n"
        f"Context: {context_text}\n"
        "----\n"
        f"TEXT:\n{input_txt[:4000]}\n"
        "----\n"
        f"JSON:\n{json.dumps(parsed_json)[:2000]}\n"
        "Extract:\n- Tasks list\n- Entities list\n- Cleaned text (sanitized)\n- 1-2 line summary\n"
    )

    try:
        inputs = tokenizer(user_prompt, return_tensors="pt").to(DEVICE)
        # transformers rejects do_sample=True with temperature == 0, and the
        # UI slider allows 0.0 — fall back to greedy decoding in that case.
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": int(max_new_tokens),
            "pad_token_id": tokenizer.eos_token_id,
        }
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            gen_kwargs["do_sample"] = False
        with torch.no_grad():
            # BUG FIX: the original called the module-global `_model` here
            # instead of the locally returned `model`.
            outputs = model.generate(**inputs, **gen_kwargs)
        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        diag = "\n".join(diagnostics_lines + [f"Inference failed: {e}"])
        return "", "", "", "", diag

    # Very naive post-split (replace with your own structured parsing)
    tasks_out = "- Task 1\n- Task 2\n(Replace with your parser)"
    entities_out = "- Entity A\n- Entity B\n(Replace with your parser)"
    cleaned_out = "Cleaned text here… (Replace with your cleaning pipeline)"
    summary_out = "Short summary here… (Replace with your summarizer)"

    diagnostics_lines.append("Generation completed successfully.")
    # Surface the raw generation size so the decoded text is not silently discarded.
    diagnostics_lines.append(f"Raw generation: {len(full_text)} chars.")
    diagnostics = "\n".join(diagnostics_lines)

    return tasks_out, entities_out, cleaned_out, summary_out, diagnostics
181
 
182
  # =========================
183
+ # UI (Gradio Blocks)
184
  # =========================
185
+ THEME_CSS = """
186
+ /* Global colors: white background, black text */
187
+ :root {
188
+ --body-background-fill: #ffffff !important;
189
+ --body-text-color: #111111 !important;
190
+ --link-text-color: #0b63ce !important; /* blue */
191
+ --shadow-spread: 0px;
192
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
+ /* Ensure all text is readable (black-ish) */
195
+ .gradio-container, .prose, .prose * {
196
+ color: #111111 !important;
197
+ }
198
 
199
+ /* Accent elements in blue (no purple) */
200
+ label, .tabitem .label-wrap, .wrap .label-wrap {
201
+ color: #0b63ce !important;
202
+ }
 
203
 
204
+ /* Cards / Boxes */
205
+ .gr-box, .gr-panel, .gr-group, .gr-accordion {
206
+ border: 1px solid #e5e7eb !important; /* light gray border */
207
+ border-radius: 14px !important;
208
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
+ /* Red run button */
211
+ button#run-btn {
212
+ background: #e11900 !important;
213
+ color: #ffffff !important;
214
+ border: 1px solid #b50f00 !important;
215
+ }
216
+ button#run-btn:hover {
217
+ filter: brightness(0.95);
218
+ }
219
 
220
+ /* Inputs layout polish */
221
+ .input-card {
222
+ padding: 10px;
223
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  """
225
 
226
def build_interface() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire the run button to run_extraction.

    Layout: model/params header, side-by-side TXT/JSON input cards (each with
    Paste + Drag & Drop tabs), run button, collapsible instructions/context,
    symmetrical result panes, diagnostics. The order of `run_inputs` /
    `run_outputs` must match run_extraction's signature and return tuple.
    """
    with gr.Blocks(title="Talk2Task Demo", css=THEME_CSS) as demo:
        # 1) MODEL SELECTION (full width) + checklist embedded
        with gr.Group():
            gr.Markdown("### Model & Parameters", elem_id="model-header")
            with gr.Row(equal_height=True):
                # NOTE(review): the non-default choices below are gated/large
                # repos — selecting them may fail without HF auth; confirm.
                model_choice = gr.Dropdown(
                    label="Model",
                    choices=[
                        DEFAULT_MODEL_ID,
                        "mistralai/Mistral-7B-Instruct-v0.2",
                        "meta-llama/Llama-3.1-8B-Instruct",  # if accessible
                    ],
                    value=DEFAULT_MODEL_ID,
                    scale=3
                )
                # NOTE(review): these options are passed to run_extraction but
                # not yet acted upon there (placeholder pipeline).
                params_checked = gr.CheckboxGroup(
                    label="Options",
                    choices=[
                        "Default cleaning",
                        "Remove PII",
                        "Allow 4-bit (if available)",
                        "Detect language",
                    ],
                    value=["Default cleaning"],
                    scale=2
                )
            with gr.Row():
                # generation controls (kept compact)
                temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
                max_new_tokens = gr.Slider(32, 1024, value=200, step=8, label="Max new tokens")

        # 2) SINGLE "BOX" PER TYPE — via Tabs (Paste OR Drag & Drop) — side-by-side
        gr.Markdown("### Input", elem_id="input-header")
        with gr.Row(equal_height=True):
            with gr.Group(elem_classes=["input-card"]):
                gr.Markdown("**TXT Input** (Paste or Drag & Drop)", elem_id="txt-box-title")
                with gr.Tabs():
                    with gr.TabItem("Paste"):
                        txt_paste = gr.TextArea(
                            label="Paste TXT",
                            placeholder="Paste raw transcript or text here...",
                            lines=12,
                        )
                    with gr.TabItem("Drag & Drop"):
                        txt_upload = gr.File(
                            label="Upload .txt file",
                            file_types=[".txt"],
                        )

            with gr.Group(elem_classes=["input-card"]):
                gr.Markdown("**JSON Input** (Paste or Drag & Drop)", elem_id="json-box-title")
                with gr.Tabs():
                    with gr.TabItem("Paste"):
                        # NOTE(review): the `lines` kwarg on gr.Code is not
                        # accepted by every Gradio version — verify against
                        # the pinned gradio release.
                        json_paste = gr.Code(
                            label="Paste JSON",
                            language="json",
                            value="{\n \"example\": true\n}",
                            lines=12,
                        )
                    with gr.TabItem("Drag & Drop"):
                        json_upload = gr.File(
                            label="Upload .json file",
                            file_types=[".json"],
                        )

        # 3) RUN BUTTON (red via CSS on elem_id), then collapsible Instructions & Context
        run_btn = gr.Button("Run Extraction", elem_id="run-btn", variant="primary")

        with gr.Row():
            with gr.Accordion("Instructions (editable)", open=False):
                instructions_text = gr.TextArea(
                    label="Instructions",
                    value=(
                        "Extract tasks, entities, and a short summary. "
                        "Apply default cleaning unless unchecked."
                    ),
                    lines=5,
                )
            with gr.Accordion("Context (editable)", open=False):
                context_text = gr.TextArea(
                    label="Context",
                    value=(
                        "Use banking/consulting context if relevant. "
                        "Prefer concise actionable phrasing."
                    ),
                    lines=5,
                )

        # 4) OUTPUT LAYOUT — symmetrical boxes
        gr.Markdown("### Results", elem_id="results-header")
        with gr.Row(equal_height=True):
            tasks_out = gr.TextArea(label="Tasks", lines=10)
            entities_out = gr.TextArea(label="Entities", lines=10)
        with gr.Row(equal_height=True):
            cleaned_out = gr.TextArea(label="Cleaned Text", lines=10)
            summary_out = gr.TextArea(label="Summary", lines=10)

        gr.Markdown("### Diagnostics", elem_id="diagnostics-header")
        diagnostics = gr.TextArea(label="Diagnostics / Logs", lines=10)

        # Wire up button — ordering here is positional and must stay in sync
        # with run_extraction's parameters and return values.
        run_inputs = [
            model_choice, params_checked, instructions_text, context_text,
            txt_paste, txt_upload, json_paste, json_upload,
            max_new_tokens, temperature, top_p
        ]
        run_outputs = [tasks_out, entities_out, cleaned_out, summary_out, diagnostics]

        run_btn.click(
            fn=run_extraction,
            inputs=run_inputs,
            outputs=run_outputs
        )

    return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
# Built at import time so hosting platforms (e.g. HF Spaces) can serve `demo`
# without invoking this file as a script.
demo = build_interface()

if __name__ == "__main__":
    # Let Gradio/Spaces choose host & port; this keeps local runs easy too.
    demo.launch()