RishiRP commited on
Commit
2c209e6
·
verified ·
1 Parent(s): 5a71496

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +242 -217
app.py CHANGED
@@ -27,14 +27,15 @@ SPACE_CACHE = Path.home() / ".cache" / "huggingface"
27
  SPACE_CACHE.mkdir(parents=True, exist_ok=True)
28
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
 
 
30
  GEN_CONFIG = GenerationConfig(
31
  temperature=0.0,
32
  top_p=1.0,
33
  do_sample=False,
34
- max_new_tokens=96, # small for speed; adjust if needed
35
  )
36
 
37
- # Official UBS label set (strict)
38
  OFFICIAL_LABELS = [
39
  "plan_contact",
40
  "schedule_meeting",
@@ -47,98 +48,104 @@ OFFICIAL_LABELS = [
47
  ]
48
  OFFICIAL_LABELS_TEXT = "\n".join(OFFICIAL_LABELS)
49
 
50
- # Per-label keyword cues (static prompt context to improve recall)
51
- LABEL_KEYWORDS: Dict[str, List[str]] = {
52
- "plan_contact": [
53
- "call back", "get back to you", "i'll get back", "follow up",
54
- "reach out", "contact later", "check in", "touch base", "remind",
55
- "send a note", "drop you a note", "email you", "ping you"
56
- ],
57
- "schedule_meeting": [
58
- "meet", "let's meet", "meeting", "book a meeting", "set up a meeting",
59
- "schedule a call", "schedule something", "appointment", "calendar",
60
- "time slot", "slot", "next week", "tomorrow", "this afternoon",
61
- "find a time", "set a time", "book time"
62
- ],
63
- "update_contact_info_non_postal": [
64
- "phone change", "new phone", "changed phone", "email change", "new email",
65
- "update contact details", "update mobile", "alternate phone", "alternate email",
66
- "wrong email", "wrong phone", "new mobile"
67
- ],
68
- "update_contact_info_postal_address": [
69
- "moved to", "new address", "postal address", "mailing address",
70
- "change of address", "residential address", "address change"
71
- ],
72
- "update_kyc_activity": [
73
- "activity update", "economic activity", "employment status",
74
- "occupation", "job change", "changed jobs", "business activity"
75
- ],
76
- "update_kyc_origin_of_assets": [
77
- "source of funds", "origin of assets", "where money comes from",
78
- "inheritance", "salary", "business income", "asset origin",
79
- "gifted funds", "proceeds from sale"
80
- ],
81
- "update_kyc_purpose_of_businessrelation": [
82
- "purpose of relationship", "why the account", "reason for banking",
83
- "investment purpose", "relationship purpose", "purpose of the relationship"
84
- ],
85
- "update_kyc_total_assets": [
86
- "total assets", "net worth", "assets under ownership",
87
- "portfolio size", "how much you own", "aggregate assets"
88
- ],
89
  }
90
 
91
- # Regex cues to catch phrasing variants
92
- REGEX_CUES: Dict[str, List[str]] = {
93
- "schedule_meeting": [
94
- r"\b(let'?s\s+)?meet(s|ing)?\b",
95
- r"\bbook( a)? (time|slot|meeting)\b",
96
- r"\bschedule( a)? (call|meeting)\b",
97
- r"\b(next week|tomorrow|this (afternoon|evening|morning))\b",
98
- r"\bfind a time\b",
99
- ],
100
  "plan_contact": [
101
- r"\b(i'?ll|get|got)\s+back to you\b",
 
102
  r"\bfollow\s*up\b",
103
  r"\breach out\b",
104
  r"\btouch base\b",
105
- r"\bping you\b",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  ],
107
  }
108
 
109
  # =========================
110
- # Instructions (concise; concatenated to avoid string issues)
111
  # =========================
112
- SYSTEM_PROMPT = (
113
- "You are a precise banking assistant that extracts ACTIONABLE TASKS from "
114
- "client–advisor transcripts. Be conservative with hallucinations but "
115
- "prioritise RECALL: if unsure and the transcript plausibly implies an "
116
- "action, include the label and explain briefly.\n\n"
117
- "Output STRICT JSON only:\n\n"
118
- "{\n"
119
- ' "labels": ["<Label1>", "..."],\n'
120
- ' "tasks": [\n'
121
- ' {"label": "<Label1>", "explanation": "<why>", "evidence": "<quoted text/snippet>"}\n'
122
- " ]\n"
123
- "}\n\n"
124
- "Rules:\n"
125
- "- Use ONLY allowed labels supplied to you. Case-insensitive during reasoning, "
126
- " but output the canonical label text exactly.\n"
127
- "- If none truly apply, return empty lists.\n"
128
- "- Keep explanations concise; put the minimal evidence snippet that justifies the task.\n"
129
- )
130
-
131
  USER_PROMPT_TEMPLATE = (
132
- "Transcript (cleaned):\n"
133
  "```\n{transcript}\n```\n\n"
134
  "Allowed Labels (canonical; use only these):\n"
135
  "{allowed_labels_list}\n\n"
136
- "Context cues (keywords/phrases that often indicate each label):\n"
137
- "{keyword_context}\n\n"
138
- "Instructions:\n"
139
- "- Identify EVERY concrete task implied by the conversation.\n"
140
- "- Choose ONE label from Allowed Labels for each task (or none if truly inapplicable).\n"
141
- "- Return STRICT JSON only in the exact schema described by the system prompt.\n"
142
  )
143
 
144
  # =========================
@@ -171,14 +178,12 @@ def robust_json_extract(text: str) -> Dict[str, Any]:
171
  def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
172
  out = {"labels": [], "tasks": []}
173
  allowed_map = canonicalize_map(allowed)
174
- # labels
175
  filt_labels = []
176
  for l in pred.get("labels", []) or []:
177
  k = str(l).strip().lower()
178
  if k in allowed_map:
179
  filt_labels.append(allowed_map[k])
180
  filt_labels = normalize_labels(filt_labels)
181
- # tasks
182
  filt_tasks = []
183
  for t in pred.get("tasks", []) or []:
184
  if not isinstance(t, dict):
@@ -186,6 +191,11 @@ def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, A
186
  k = str(t.get("label", "")).strip().lower()
187
  if k in allowed_map:
188
  new_t = dict(t); new_t["label"] = allowed_map[k]
 
 
 
 
 
189
  filt_tasks.append(new_t)
190
  merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
191
  out["labels"] = merged
@@ -193,7 +203,7 @@ def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, A
193
  return out
194
 
195
  # =========================
196
- # Default pre-processing (toggleable)
197
  # =========================
198
  _DISCLAIMER_PATTERNS = [
199
  r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
@@ -206,7 +216,7 @@ _FOOTER_PATTERNS = [
206
  ]
207
  _TIMESTAMP_SPEAKER = [
208
  r"\[\d{1,2}:\d{2}(:\d{2})?\]", # [00:01] or [00:01:02]
209
- r"^\s*(advisor|client)\s*:\s*", # Advisor: / Client:
210
  r"^\s*(speaker\s*\d+)\s*:\s*", # Speaker 1:
211
  ]
212
 
@@ -214,7 +224,6 @@ def clean_transcript(text: str) -> str:
214
  if not text:
215
  return text
216
  s = text
217
- # remove timestamps/speaker prefixes line-wise
218
  lines = []
219
  for ln in s.splitlines():
220
  ln2 = ln
@@ -222,19 +231,15 @@ def clean_transcript(text: str) -> str:
222
  ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
223
  lines.append(ln2)
224
  s = "\n".join(lines)
225
- # remove top disclaimers
226
  for pat in _DISCLAIMER_PATTERNS:
227
  s = re.sub(pat, "", s).strip()
228
- # remove trailing footers
229
  for pat in _FOOTER_PATTERNS:
230
  s = re.sub(pat, "", s)
231
- # collapse whitespace
232
  s = re.sub(r"[ \t]+", " ", s)
233
  s = re.sub(r"\n{3,}", "\n\n", s).strip()
234
  return s
235
 
236
  def read_text_file_any(file_input) -> str:
237
- """Works for gr.File(type='filepath') and raw strings/Path and file-like."""
238
  if not file_input:
239
  return ""
240
  if isinstance(file_input, (str, Path)):
@@ -268,7 +273,7 @@ def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
268
  return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
269
 
270
  # =========================
271
- # HF model wrapper
272
  # =========================
273
  class ModelWrapper:
274
  def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool):
@@ -306,7 +311,7 @@ class ModelWrapper:
306
 
307
  @torch.inference_mode()
308
  def generate(self, system_prompt: str, user_prompt: str) -> str:
309
- # Build inputs as input_ids=... (avoid **tensor bug)
310
  if hasattr(self.tokenizer, "apply_chat_template"):
311
  messages = [
312
  {"role": "system", "content": system_prompt},
@@ -351,7 +356,7 @@ def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> Mode
351
  return _MODEL_CACHE[key]
352
 
353
  # =========================
354
- # Official evaluation (from README)
355
  # =========================
356
  def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
357
  ALLOWED_LABELS = OFFICIAL_LABELS
@@ -395,62 +400,32 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> fl
395
  return float(max(0.0, min(1.0, np.mean(per_sample))))
396
 
397
  # =========================
398
- # Fallback: regex + keywords if model returns empty
399
  # =========================
400
- def keyword_fallback(text: str, allowed: List[str]) -> Dict[str, Any]:
401
  low = text.lower()
402
- labels = []
403
- tasks = []
404
-
405
- # Regex first
406
  for lab in allowed:
407
- patterns = REGEX_CUES.get(lab, [])
408
- found = None
409
- for pat in patterns:
410
  m = re.search(pat, low)
411
  if m:
412
  i = m.start()
413
- start = max(0, i - 40); end = min(len(text), i + len(m.group(0)) + 40)
414
- found = text[start:end].strip()
 
 
 
 
 
 
415
  break
416
- if found:
417
- labels.append(lab)
418
- tasks.append({
419
- "label": lab,
420
- "explanation": "Regex cue matched in transcript.",
421
- "evidence": found
422
- })
423
-
424
- # Keyword contains() as backstop
425
- for lab in allowed:
426
- if lab in labels:
427
- continue
428
- hits = []
429
- for kw in LABEL_KEYWORDS.get(lab, []):
430
- k = kw.lower()
431
- i = low.find(k)
432
- if i != -1:
433
- start = max(0, i - 40); end = min(len(text), i + len(k) + 40)
434
- hits.append(text[start:end].strip())
435
- if hits:
436
- labels.append(lab)
437
- tasks.append({
438
- "label": lab,
439
- "explanation": "Keyword match in transcript.",
440
- "evidence": hits[0]
441
- })
442
-
443
  return {"labels": normalize_labels(labels), "tasks": tasks}
444
 
445
  # =========================
446
  # Inference helpers
447
  # =========================
448
- def build_keyword_context(allowed: List[str]) -> str:
449
- parts = []
450
- for lab in allowed:
451
- kws = LABEL_KEYWORDS.get(lab, [])
452
- parts.append(f"- {lab}: " + (", ".join(kws) if kws else "(no default cues)"))
453
- return "\n".join(parts)
454
 
455
  def warmup_model(model_repo: str, use_4bit: bool, hf_token: str) -> str:
456
  t0 = _now_ms()
@@ -463,12 +438,15 @@ def warmup_model(model_repo: str, use_4bit: bool, hf_token: str) -> str:
463
 
464
  def run_single(
465
  transcript_text: str,
466
- transcript_file, # filepath or file-like
467
  gt_json_text: str,
468
- gt_json_file, # filepath or file-like
469
  use_cleaning: bool,
470
- use_keyword_fallback: bool,
471
  allowed_labels_text: str,
 
 
 
472
  model_repo: str,
473
  use_4bit: bool,
474
  max_input_tokens: int,
@@ -477,7 +455,7 @@ def run_single(
477
 
478
  t0 = _now_ms()
479
 
480
- # Transcript
481
  raw_text = ""
482
  if transcript_file:
483
  raw_text = read_text_file_any(transcript_file)
@@ -487,10 +465,28 @@ def run_single(
487
 
488
  text = clean_transcript(raw_text) if use_cleaning else raw_text
489
 
490
- # Allowed labels (pre-filled defaults)
491
  user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
492
  allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
493
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
  # Model
495
  try:
496
  model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
@@ -501,12 +497,12 @@ def run_single(
501
  trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
502
 
503
  # Build prompt
 
504
  allowed_list_str = "\n".join(f"- {l}" for l in allowed)
505
- keyword_ctx = build_keyword_context(allowed)
506
  user_prompt = USER_PROMPT_TEMPLATE.format(
507
  transcript=trunc,
508
  allowed_labels_list=allowed_list_str,
509
- keyword_context=keyword_ctx,
510
  )
511
 
512
  # Token info + prompt preview
@@ -518,7 +514,7 @@ def run_single(
518
  # Generate
519
  t1 = _now_ms()
520
  try:
521
- out = model.generate(SYSTEM_PROMPT, user_prompt)
522
  except Exception as e:
523
  return "", "", f"Generation error: {e}", "", "", "", prompt_preview_text, token_info_text, ""
524
  t2 = _now_ms()
@@ -526,33 +522,27 @@ def run_single(
526
  parsed = robust_json_extract(out)
527
  filtered = restrict_to_allowed(parsed, allowed)
528
 
529
- # Fallback if empty
530
- if use_keyword_fallback and not filtered.get("labels"):
531
- fb = keyword_fallback(trunc, allowed)
532
  if fb["labels"]:
533
- filtered = fb
 
 
 
534
 
535
  # Diagnostics
536
  diag = "\n".join([
537
  f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
538
  f"Model: {model_repo}",
539
  f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
540
- f"Keyword fallback: {'Yes' if use_keyword_fallback else 'No'}",
541
- f"Tokens (input, approx): ≤ {max_input_tokens}",
542
  f"Latency: prep {t1-t0} ms, gen {t2-t1} ms, total {t2-t0} ms",
543
  f"Allowed labels: {', '.join(allowed)}",
544
  ])
545
 
546
- # Context & instructions preview shown in UI
547
- context_preview = (
548
- "### Allowed Labels\n"
549
- + "\n".join(f"- {l}" for l in allowed)
550
- + "\n\n### Keyword cues per label\n"
551
- + keyword_ctx
552
- )
553
- instructions_preview = "```\n" + SYSTEM_PROMPT + "\n```"
554
-
555
- # Summary & JSON
556
  labs = filtered.get("labels", [])
557
  tasks = filtered.get("tasks", [])
558
  summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
@@ -565,7 +555,7 @@ def run_single(
565
  summary += "\n\nTasks: (none)"
566
  json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
567
 
568
- # Optional single-file scoring if GT provided
569
  metrics = ""
570
  if gt_json_file or (gt_json_text and gt_json_text.strip()):
571
  truth_obj = None
@@ -598,6 +588,10 @@ def run_single(
598
  else:
599
  metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
600
 
 
 
 
 
601
  return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics, prompt_preview_text, token_info_text
602
 
603
  # =========================
@@ -612,9 +606,12 @@ def read_zip_from_path(path: str, exdir: Path) -> List[Path]:
612
  return [p for p in exdir.rglob("*") if p.is_file()]
613
 
614
  def run_batch(
615
- zip_path, # filepath string
616
  use_cleaning: bool,
617
- use_keyword_fallback: bool,
 
 
 
618
  model_repo: str,
619
  use_4bit: bool,
620
  max_input_tokens: int,
@@ -625,6 +622,25 @@ def run_batch(
625
  if not zip_path:
626
  return ("No ZIP provided.", "", pd.DataFrame(), "")
627
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
628
  work = Path("/tmp/batch")
629
  if work.exists():
630
  for p in sorted(work.rglob("*"), reverse=True):
@@ -650,14 +666,15 @@ def run_batch(
650
  if not stems:
651
  return ("No .txt transcripts found in ZIP.", "", pd.DataFrame(), "")
652
 
 
653
  try:
654
  model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
655
  except Exception as e:
656
  return (f"Model load failed: {e}", "", pd.DataFrame(), "")
657
 
658
  allowed = OFFICIAL_LABELS[:]
 
659
  allowed_list_str = "\n".join(f"- {l}" for l in allowed)
660
- keyword_ctx = build_keyword_context(allowed)
661
 
662
  y_true, y_pred = [], []
663
  rows = []
@@ -666,25 +683,29 @@ def run_batch(
666
  for stem in stems:
667
  raw = txts[stem].read_text(encoding="utf-8", errors="ignore")
668
  text = clean_transcript(raw) if use_cleaning else raw
 
669
  trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
670
 
671
  user_prompt = USER_PROMPT_TEMPLATE.format(
672
  transcript=trunc,
673
  allowed_labels_list=allowed_list_str,
674
- keyword_context=keyword_ctx,
675
  )
676
 
677
  t0 = _now_ms()
678
- out = model.generate(SYSTEM_PROMPT, user_prompt)
679
  t1 = _now_ms()
680
 
681
  parsed = robust_json_extract(out)
682
  filtered = restrict_to_allowed(parsed, allowed)
683
 
684
- if use_keyword_fallback and not filtered.get("labels"):
685
- fb = keyword_fallback(trunc, allowed)
686
  if fb["labels"]:
687
- filtered = fb
 
 
 
688
 
689
  pred_labels = filtered.get("labels", [])
690
  y_pred.append(pred_labels)
@@ -721,8 +742,8 @@ def run_batch(
721
  f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
722
  f"Model: {model_repo}",
723
  f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
724
- f"Keyword fallback: {'Yes' if use_keyword_fallback else 'No'}",
725
- f"Tokens (input, approx): ≤ {max_input_tokens}",
726
  f"Batch time: {_now_ms()-t_start} ms",
727
  ]
728
  if have_truth and score is not None:
@@ -739,7 +760,6 @@ def run_batch(
739
  ]
740
  diag_str = "\n".join(diag)
741
 
742
- # save CSV for download
743
  out_csv = Path("/tmp/batch_results.csv")
744
  df.to_csv(out_csv, index=False, encoding="utf-8")
745
  return ("Batch done.", diag_str, df, str(out_csv))
@@ -748,24 +768,26 @@ def run_batch(
748
  # UI
749
  # =========================
750
  MODEL_CHOICES = [
751
- "swiss-ai/Apertus-8B-Instruct-2509",
752
- "meta-llama/Meta-Llama-3-8B-Instruct",
753
- "mistralai/Mistral-7B-Instruct-v0.3",
754
  ]
755
 
 
756
  custom_css = """
757
  :root { --radius: 14px; }
758
- .gradio-container { font-family: Inter, ui-sans-serif, system-ui; }
759
- .card { border: 1px solid rgba(255,255,255,.08); border-radius: var(--radius); padding: 14px 16px; background: rgba(255,255,255,.02); box-shadow: 0 1px 10px rgba(0,0,0,.12) inset; }
760
- .header { font-weight: 700; font-size: 22px; margin-bottom: 4px; }
761
- .subtle { color: rgba(255,255,255,.65); font-size: 14px; margin-bottom: 12px; }
762
- hr.sep { border: none; border-top: 1px solid rgba(255,255,255,.08); margin: 10px 0 16px; }
763
  .gr-button { border-radius: 12px !important; }
 
764
  """
765
 
766
  with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
767
- gr.Markdown("<div class='header'>Talk2Task — Task Extraction (UBS Challenge)</div>")
768
- gr.Markdown("<div class='subtle'>False negatives are penalised more than false positives in the official score. This UI biases for recall, shows the exact instructions & context, and supports single or batch evaluation.</div>")
769
 
770
  with gr.Tab("Single transcript"):
771
  with gr.Row():
@@ -776,7 +798,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
776
  file_types=[".txt", ".md", ".json"],
777
  type="filepath",
778
  )
779
- text = gr.Textbox(label="Or paste transcript", lines=10)
780
  gr.Markdown("<hr class='sep'/>")
781
 
782
  gr.Markdown("<div class='header'>Ground truth JSON (optional)</div>")
@@ -788,26 +810,22 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
788
  gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{\"labels\": [\"schedule_meeting\"]}')
789
  gr.Markdown("</div>") # close card
790
 
791
- gr.Markdown("<div class='card'><div class='header'>Preprocessing & heuristics</div>")
792
- use_cleaning = gr.Checkbox(
793
- label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)",
794
- value=True,
795
- )
796
- use_keyword_fallback = gr.Checkbox(
797
- label="Keyword fallback if model returns empty",
798
- value=True,
799
- )
800
  gr.Markdown("</div>")
801
 
802
  gr.Markdown("<div class='card'><div class='header'>Allowed labels</div>")
803
- labels_text = gr.Textbox(
804
- label="Allowed Labels (one per line)",
805
- value=OFFICIAL_LABELS_TEXT, # prefilled
806
- lines=8,
807
- )
808
  reset_btn = gr.Button("Reset to official labels")
809
  gr.Markdown("</div>")
810
 
 
 
 
 
 
 
811
  with gr.Column(scale=2):
812
  gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
813
  repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
@@ -830,48 +848,48 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
830
  with gr.Row():
831
  with gr.Column():
832
  with gr.Accordion("Instructions used (system prompt)", open=False):
833
- instr_md = gr.Markdown("```\n" + SYSTEM_PROMPT + "\n```")
834
  with gr.Column():
835
- with gr.Accordion("Context used (allowed labels + keyword cues)", open=True):
836
  context_md = gr.Markdown("")
837
 
838
- # reset button behavior
839
  def _reset_labels():
840
  return OFFICIAL_LABELS_TEXT
841
  reset_btn.click(fn=_reset_labels, inputs=None, outputs=labels_text)
842
 
843
- # warm-up
844
- warm_btn.click(
845
- fn=warmup_model,
846
- inputs=[repo, use_4bit, hf_token],
847
- outputs=diag,
848
- )
 
 
 
 
 
849
 
850
- # single run
851
- def _pack_context_md(allowed: str) -> str:
852
- allowed_list = [ln.strip() for ln in (allowed or OFFICIAL_LABELS_TEXT).splitlines() if ln.strip()]
853
- ctx = build_keyword_context(allowed_list)
854
- return "### Allowed Labels\n" + "\n".join(f"- {l}" for l in allowed_list) + "\n\n### Keyword cues per label\n" + ctx
855
 
 
856
  run_btn.click(
857
  fn=run_single,
858
  inputs=[
859
- text, file, gt_text, gt_file, use_cleaning, use_keyword_fallback,
860
- labels_text, repo, use_4bit, max_tokens, hf_token
 
861
  ],
862
  outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False), prompt_preview, token_info],
863
  )
864
 
865
- # initial context preview
866
- context_md.value = _pack_context_md(OFFICIAL_LABELS_TEXT)
867
-
868
  with gr.Tab("Batch evaluation"):
869
  with gr.Row():
870
  with gr.Column(scale=3):
871
  gr.Markdown("<div class='card'><div class='header'>ZIP input</div>")
872
  zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
873
  use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
874
- use_keyword_fallback_b = gr.Checkbox(label="Keyword fallback if model returns empty", value=True)
875
  gr.Markdown("</div>")
876
  with gr.Column(scale=2):
877
  gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
@@ -879,6 +897,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
879
  use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
880
  max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
881
  hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
 
 
 
882
  limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
883
  run_batch_btn = gr.Button("Run Batch", variant="primary")
884
  gr.Markdown("</div>")
@@ -893,7 +914,11 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
893
 
894
  run_batch_btn.click(
895
  fn=run_batch,
896
- inputs=[zip_in, use_cleaning_b, use_keyword_fallback_b, repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files],
 
 
 
 
897
  outputs=[status, diag_b, df_out, csv_out],
898
  )
899
 
 
27
  SPACE_CACHE.mkdir(parents=True, exist_ok=True)
28
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
 
30
+ # Fast, deterministic, compact outputs for lower latency
31
  GEN_CONFIG = GenerationConfig(
32
  temperature=0.0,
33
  top_p=1.0,
34
  do_sample=False,
35
+ max_new_tokens=128, # increase if your JSON is getting truncated
36
  )
37
 
38
+ # Official UBS labels (canonical)
39
  OFFICIAL_LABELS = [
40
  "plan_contact",
41
  "schedule_meeting",
 
48
  ]
49
  OFFICIAL_LABELS_TEXT = "\n".join(OFFICIAL_LABELS)
50
 
51
+ # =========================
52
+ # Editable defaults (shown in UI)
53
+ # =========================
54
+ DEFAULT_SYSTEM_INSTRUCTIONS = (
55
+ "You extract ACTIONABLE TASKS from client–advisor transcripts. "
56
+ "The transcript may be in German, French, Italian, or English. "
57
+ "Prioritize RECALL: if a label plausibly applies, include it. "
58
+ "Use ONLY the canonical labels provided. "
59
+ "Return STRICT JSON only with keys 'labels' and 'tasks'. "
60
+ "Each task must include 'label', a brief 'explanation', and a short 'evidence' quote from the transcript."
61
+ )
62
+
63
+ # Very short, language-agnostic semantics to keep prompt small
64
+ DEFAULT_LABEL_GLOSSARY = {
65
+ "plan_contact": "Commitment to contact later (advisor/client will reach out, follow-up promised).",
66
+ "schedule_meeting": "Scheduling or confirming a meeting/call/appointment (time/date/slot/virtual).",
67
+ "update_contact_info_non_postal": "Change or confirmation of phone/email (non-postal contact details).",
68
+ "update_contact_info_postal_address": "Change or confirmation of postal/residential/mailing address.",
69
+ "update_kyc_activity": "Change/confirmation of occupation, employment status, or economic activity.",
70
+ "update_kyc_origin_of_assets": "Discussion/confirmation of source of funds / origin of assets.",
71
+ "update_kyc_purpose_of_businessrelation": "Purpose of the banking relationship/account usage.",
72
+ "update_kyc_total_assets": "Discussion/confirmation of total assets/net worth.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  }
74
 
75
+ # Tiny multilingual fallback rules (optional) to guarantee recall if model is empty.
76
+ # Keep small to avoid false positives and keep maintenance low.
77
+ DEFAULT_FALLBACK_CUES = {
 
 
 
 
 
 
78
  "plan_contact": [
79
+ # EN
80
+ r"\b(get|got|will|we'?ll|i'?ll)\s+back to you\b",
81
  r"\bfollow\s*up\b",
82
  r"\breach out\b",
83
  r"\btouch base\b",
84
+ r"\bcontact (you|me|us)\b",
85
+ # DE
86
+ r"\bin verbindung setzen\b",
87
+ r"\brückmeldung\b",
88
+ r"\bich\s+melde\b|\bwir\s+melden\b",
89
+ r"\bnachfassen\b",
90
+ # FR
91
+ r"\bje vous recontacte\b|\bnous vous recontacterons\b",
92
+ r"\bprendre contact\b|\breprendre contact\b",
93
+ # IT
94
+ r"\bla ricontatter[oò]\b|\bci metteremo in contatto\b",
95
+ r"\btenersi in contatto\b",
96
+ ],
97
+ "schedule_meeting": [
98
+ # EN
99
+ r"\b(let'?s\s+)?meet(ing|s)?\b",
100
+ r"\bschedule( a)? (call|meeting|appointment)\b",
101
+ r"\bbook( a)? (slot|time|meeting)\b",
102
+ r"\b(next week|tomorrow|this (afternoon|morning|evening))\b",
103
+ r"\bconfirm( the)? (time|meeting|appointment)\b",
104
+ # DE
105
+ r"\btermin(e|s)?\b|\bvereinbaren\b|\bansetzen\b|\babstimmen\b|\bbesprechung(en)?\b|\bvirtuell(e|en)?\b",
106
+ r"\bnächste(n|r)? woche\b|\b(dienstag|montag|mittwoch|donnerstag|freitag)\b|\bnachmittag|vormittag|morgen\b",
107
+ # FR
108
+ r"\brendez[- ]?vous\b|\bréunion\b|\bfixer\b|\bplanifier\b|\bcalendrier\b|\bse rencontrer\b|\bse voir\b",
109
+ r"\bla semaine prochaine\b|\bdemain\b|\bcet (après-midi|apres-midi|après midi|apres midi|matin|soir)\b",
110
+ # IT
111
+ r"\bappuntamento\b|\briunione\b|\borganizzare\b|\bprogrammare\b|\bincontrarci\b|\bcalendario\b",
112
+ r"\bla prossima settimana\b|\bdomani\b|\b(questo|questa)\s*(pomeriggio|mattina|sera)\b",
113
+ ],
114
+ "update_kyc_origin_of_assets": [
115
+ # EN
116
+ r"\bsource of funds\b|\borigin of assets\b|\bproof of (funds|assets)\b",
117
+ # DE
118
+ r"\bvermögensursprung(e|s)?\b|\bherkunft der mittel\b|\bnachweis\b",
119
+ # FR
120
+ r"\borigine des fonds\b|\borigine du patrimoine\b|\bjustificatif(s)?\b",
121
+ # IT
122
+ r"\borigine dei fondi\b|\borigine del patrimonio\b|\bprova dei fondi\b|\bgiustificativo\b",
123
+ ],
124
+ "update_kyc_activity": [
125
+ # EN
126
+ r"\bemployment status\b|\boccupation\b|\bjob change\b|\bsalary history\b",
127
+ # DE
128
+ r"\bbeschäftigungsstatus\b|\bberuf\b|\bjobwechsel\b|\bgehaltshistorie\b|\btätigkeit\b",
129
+ # FR
130
+ r"\bstatut professionnel\b|\bprofession\b|\bchangement d'emploi\b|\bhistorique salarial\b|\bactivité\b",
131
+ # IT
132
+ r"\bstato occupazionale\b|\bprofessione\b|\bcambio di lavoro\b|\bstoria salariale\b|\battivit[aà]\b",
133
  ],
134
  }
135
 
136
  # =========================
137
+ # Prompt templates (minimal multilingual)
138
  # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  USER_PROMPT_TEMPLATE = (
140
+ "Transcript (may be DE/FR/IT/EN):\n"
141
  "```\n{transcript}\n```\n\n"
142
  "Allowed Labels (canonical; use only these):\n"
143
  "{allowed_labels_list}\n\n"
144
+ "Label Glossary (concise semantics):\n"
145
+ "{glossary}\n\n"
146
+ "Return STRICT JSON ONLY in this exact schema:\n"
147
+ '{\n "labels": ["<Label1>", "..."],\n'
148
+ ' "tasks": [{"label": "<Label1>", "explanation": "<why>", "evidence": "<quote>"}]\n}\n'
 
149
  )
150
 
151
  # =========================
 
178
  def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
179
  out = {"labels": [], "tasks": []}
180
  allowed_map = canonicalize_map(allowed)
 
181
  filt_labels = []
182
  for l in pred.get("labels", []) or []:
183
  k = str(l).strip().lower()
184
  if k in allowed_map:
185
  filt_labels.append(allowed_map[k])
186
  filt_labels = normalize_labels(filt_labels)
 
187
  filt_tasks = []
188
  for t in pred.get("tasks", []) or []:
189
  if not isinstance(t, dict):
 
191
  k = str(t.get("label", "")).strip().lower()
192
  if k in allowed_map:
193
  new_t = dict(t); new_t["label"] = allowed_map[k]
194
+ new_t = {
195
+ "label": new_t["label"],
196
+ "explanation": str(new_t.get("explanation", ""))[:300],
197
+ "evidence": str(new_t.get("evidence", ""))[:300],
198
+ }
199
  filt_tasks.append(new_t)
200
  merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
201
  out["labels"] = merged
 
203
  return out
204
 
205
  # =========================
206
+ # Pre-processing
207
  # =========================
208
  _DISCLAIMER_PATTERNS = [
209
  r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
 
216
  ]
217
  _TIMESTAMP_SPEAKER = [
218
  r"\[\d{1,2}:\d{2}(:\d{2})?\]", # [00:01] or [00:01:02]
219
+ r"^\s*(advisor|client|client advisor)\s*:\s*", # Advisor:, Client:
220
  r"^\s*(speaker\s*\d+)\s*:\s*", # Speaker 1:
221
  ]
222
 
 
224
  if not text:
225
  return text
226
  s = text
 
227
  lines = []
228
  for ln in s.splitlines():
229
  ln2 = ln
 
231
  ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
232
  lines.append(ln2)
233
  s = "\n".join(lines)
 
234
  for pat in _DISCLAIMER_PATTERNS:
235
  s = re.sub(pat, "", s).strip()
 
236
  for pat in _FOOTER_PATTERNS:
237
  s = re.sub(pat, "", s)
 
238
  s = re.sub(r"[ \t]+", " ", s)
239
  s = re.sub(r"\n{3,}", "\n\n", s).strip()
240
  return s
241
 
242
  def read_text_file_any(file_input) -> str:
 
243
  if not file_input:
244
  return ""
245
  if isinstance(file_input, (str, Path)):
 
273
  return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
274
 
275
  # =========================
276
+ # HF model wrapper (main LLM)
277
  # =========================
278
  class ModelWrapper:
279
  def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool):
 
311
 
312
  @torch.inference_mode()
313
  def generate(self, system_prompt: str, user_prompt: str) -> str:
314
+ # Build inputs as input_ids=... (avoid earlier **tensor bug)
315
  if hasattr(self.tokenizer, "apply_chat_template"):
316
  messages = [
317
  {"role": "system", "content": system_prompt},
 
356
  return _MODEL_CACHE[key]
357
 
358
  # =========================
359
+ # Evaluation (official weighted score)
360
  # =========================
361
  def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
362
  ALLOWED_LABELS = OFFICIAL_LABELS
 
400
  return float(max(0.0, min(1.0, np.mean(per_sample))))
401
 
402
  # =========================
403
+ # Multilingual fallback (regex on original text)
404
  # =========================
405
def multilingual_fallback(text: str, allowed: List[str], cues: Dict[str, List[str]]) -> Dict[str, Any]:
    """Rule-based label detection used as a recall fallback for the LLM pass.

    Scans the transcript with per-label regex cues and returns the same shape
    as the model output: ``{"labels": [...], "tasks": [...]}``. Each task
    carries a short evidence window around the first cue hit for its label.

    Parameters:
        text: transcript text to scan.
        allowed: labels eligible for detection.
        cues: mapping label -> list of regex patterns (matched case-insensitively).

    Returns:
        dict with normalized ``labels`` and one ``tasks`` entry per hit label.
    """
    labels: List[str] = []
    tasks: List[Dict[str, Any]] = []
    for lab in allowed:
        for pat in cues.get(lab, []):
            # Search the ORIGINAL text case-insensitively instead of lowering
            # it first: str.lower() can change string length for some Unicode
            # characters, which would misalign the evidence offsets below.
            try:
                m = re.search(pat, text, flags=re.IGNORECASE)
            except re.error:
                # Cue lists are user-editable JSON; an invalid regex should
                # be skipped, not crash the whole extraction run.
                continue
            if m:
                i = m.start()
                # Evidence window: ~60 chars of context on each side of the hit.
                start = max(0, i - 60)
                end = min(len(text), i + len(m.group(0)) + 60)
                if lab not in labels:
                    labels.append(lab)
                tasks.append({
                    "label": lab,
                    "explanation": "Rule hit (multilingual fallback)",
                    "evidence": text[start:end].strip(),
                })
                break  # first matching cue per label is enough
    return {"labels": normalize_labels(labels), "tasks": tasks}
423
 
424
  # =========================
425
  # Inference helpers
426
  # =========================
427
def build_glossary_str(glossary: Dict[str, str], allowed: List[str]) -> str:
    """Render the label glossary as one bullet line per allowed label.

    Labels missing from the glossary still get a bullet with an empty
    definition, so the prompt always lists every allowed label.
    """
    bullets = []
    for label in allowed:
        definition = glossary.get(label, "")
        bullets.append(f"- {label}: {definition}")
    return "\n".join(bullets)
 
 
 
 
429
 
430
  def warmup_model(model_repo: str, use_4bit: bool, hf_token: str) -> str:
431
  t0 = _now_ms()
 
438
 
439
  def run_single(
440
  transcript_text: str,
441
+ transcript_file,
442
  gt_json_text: str,
443
+ gt_json_file,
444
  use_cleaning: bool,
445
+ use_fallback: bool,
446
  allowed_labels_text: str,
447
+ sys_instructions_text: str,
448
+ glossary_json_text: str,
449
+ fallback_json_text: str,
450
  model_repo: str,
451
  use_4bit: bool,
452
  max_input_tokens: int,
 
455
 
456
  t0 = _now_ms()
457
 
458
+ # Load transcript
459
  raw_text = ""
460
  if transcript_file:
461
  raw_text = read_text_file_any(transcript_file)
 
465
 
466
  text = clean_transcript(raw_text) if use_cleaning else raw_text
467
 
468
+ # Allowed labels
469
  user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
470
  allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
471
 
472
+ # Editable configs
473
+ try:
474
+ sys_instructions = (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS).strip()
475
+ if not sys_instructions:
476
+ sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
477
+ except Exception:
478
+ sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
479
+
480
+ try:
481
+ label_glossary = json.loads(glossary_json_text) if glossary_json_text else DEFAULT_LABEL_GLOSSARY
482
+ except Exception:
483
+ label_glossary = DEFAULT_LABEL_GLOSSARY
484
+
485
+ try:
486
+ fallback_cues = json.loads(fallback_json_text) if fallback_json_text else DEFAULT_FALLBACK_CUES
487
+ except Exception:
488
+ fallback_cues = DEFAULT_FALLBACK_CUES
489
+
490
  # Model
491
  try:
492
  model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
 
497
  trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
498
 
499
  # Build prompt
500
+ glossary_str = build_glossary_str(label_glossary, allowed)
501
  allowed_list_str = "\n".join(f"- {l}" for l in allowed)
 
502
  user_prompt = USER_PROMPT_TEMPLATE.format(
503
  transcript=trunc,
504
  allowed_labels_list=allowed_list_str,
505
+ glossary=glossary_str,
506
  )
507
 
508
  # Token info + prompt preview
 
514
  # Generate
515
  t1 = _now_ms()
516
  try:
517
+ out = model.generate(sys_instructions, user_prompt)
518
  except Exception as e:
519
  return "", "", f"Generation error: {e}", "", "", "", prompt_preview_text, token_info_text, ""
520
  t2 = _now_ms()
 
522
  parsed = robust_json_extract(out)
523
  filtered = restrict_to_allowed(parsed, allowed)
524
 
525
+ # Fallback (multilingual rules) on original text; merge for recall if enabled
526
+ if use_fallback:
527
+ fb = multilingual_fallback(trunc, allowed, fallback_cues)
528
  if fb["labels"]:
529
+ merged_labels = sorted(list(set(filtered.get("labels", [])) | set(fb["labels"])))
530
+ existing = {tt.get("label") for tt in filtered.get("tasks", [])}
531
+ merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
532
+ filtered = {"labels": merged_labels, "tasks": merged_tasks}
533
 
534
  # Diagnostics
535
  diag = "\n".join([
536
  f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
537
  f"Model: {model_repo}",
538
  f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
539
+ f"Fallback rules: {'Yes' if use_fallback else 'No'}",
540
+ f"Tokens (input limit): ≤ {max_input_tokens}",
541
  f"Latency: prep {t1-t0} ms, gen {t2-t1} ms, total {t2-t0} ms",
542
  f"Allowed labels: {', '.join(allowed)}",
543
  ])
544
 
545
+ # Summaries
 
 
 
 
 
 
 
 
 
546
  labs = filtered.get("labels", [])
547
  tasks = filtered.get("tasks", [])
548
  summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
 
555
  summary += "\n\nTasks: (none)"
556
  json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
557
 
558
+ # Single-file scoring if GT provided
559
  metrics = ""
560
  if gt_json_file or (gt_json_text and gt_json_text.strip()):
561
  truth_obj = None
 
588
  else:
589
  metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
590
 
591
+ # For UI: show effective context (glossary) and instructions
592
+ context_preview = "### Label Glossary (used)\n" + "\n".join(f"- {k}: {v}" for k, v in label_glossary.items() if k in allowed)
593
+ instructions_preview = "```\n" + sys_instructions + "\n```"
594
+
595
  return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics, prompt_preview_text, token_info_text
596
 
597
  # =========================
 
606
  return [p for p in exdir.rglob("*") if p.is_file()]
607
 
608
  def run_batch(
609
+ zip_path,
610
  use_cleaning: bool,
611
+ use_fallback: bool,
612
+ sys_instructions_text: str,
613
+ glossary_json_text: str,
614
+ fallback_json_text: str,
615
  model_repo: str,
616
  use_4bit: bool,
617
  max_input_tokens: int,
 
622
  if not zip_path:
623
  return ("No ZIP provided.", "", pd.DataFrame(), "")
624
 
625
+ # Editable configs
626
+ try:
627
+ sys_instructions = (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS).strip()
628
+ if not sys_instructions:
629
+ sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
630
+ except Exception:
631
+ sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
632
+
633
+ try:
634
+ label_glossary = json.loads(glossary_json_text) if glossary_json_text else DEFAULT_LABEL_GLOSSARY
635
+ except Exception:
636
+ label_glossary = DEFAULT_LABEL_GLOSSARY
637
+
638
+ try:
639
+ fallback_cues = json.loads(fallback_json_text) if fallback_json_text else DEFAULT_FALLBACK_CUES
640
+ except Exception:
641
+ fallback_cues = DEFAULT_FALLBACK_CUES
642
+
643
+ # Prepare workspace
644
  work = Path("/tmp/batch")
645
  if work.exists():
646
  for p in sorted(work.rglob("*"), reverse=True):
 
666
  if not stems:
667
  return ("No .txt transcripts found in ZIP.", "", pd.DataFrame(), "")
668
 
669
+ # Model
670
  try:
671
  model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
672
  except Exception as e:
673
  return (f"Model load failed: {e}", "", pd.DataFrame(), "")
674
 
675
  allowed = OFFICIAL_LABELS[:]
676
+ glossary_str = build_glossary_str(label_glossary, allowed)
677
  allowed_list_str = "\n".join(f"- {l}" for l in allowed)
 
678
 
679
  y_true, y_pred = [], []
680
  rows = []
 
683
  for stem in stems:
684
  raw = txts[stem].read_text(encoding="utf-8", errors="ignore")
685
  text = clean_transcript(raw) if use_cleaning else raw
686
+
687
  trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
688
 
689
  user_prompt = USER_PROMPT_TEMPLATE.format(
690
  transcript=trunc,
691
  allowed_labels_list=allowed_list_str,
692
+ glossary=glossary_str,
693
  )
694
 
695
  t0 = _now_ms()
696
+ out = model.generate(sys_instructions, user_prompt)
697
  t1 = _now_ms()
698
 
699
  parsed = robust_json_extract(out)
700
  filtered = restrict_to_allowed(parsed, allowed)
701
 
702
+ if use_fallback:
703
+ fb = multilingual_fallback(trunc, allowed, fallback_cues)
704
  if fb["labels"]:
705
+ merged_labels = sorted(list(set(filtered.get("labels", [])) | set(fb["labels"])))
706
+ existing = {tt.get("label") for tt in filtered.get("tasks", [])}
707
+ merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
708
+ filtered = {"labels": merged_labels, "tasks": merged_tasks}
709
 
710
  pred_labels = filtered.get("labels", [])
711
  y_pred.append(pred_labels)
 
742
  f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
743
  f"Model: {model_repo}",
744
  f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
745
+ f"Fallback rules: {'Yes' if use_fallback else 'No'}",
746
+ f"Tokens (input limit): ≤ {max_input_tokens}",
747
  f"Batch time: {_now_ms()-t_start} ms",
748
  ]
749
  if have_truth and score is not None:
 
760
  ]
761
  diag_str = "\n".join(diag)
762
 
 
763
  out_csv = Path("/tmp/batch_results.csv")
764
  df.to_csv(out_csv, index=False, encoding="utf-8")
765
  return ("Batch done.", diag_str, df, str(out_csv))
 
768
  # UI
769
  # =========================
770
# Candidate HF model repos selectable in the UI dropdowns; index 0 is the
# default for both the single-transcript and batch tabs.
MODEL_CHOICES = [
    "swiss-ai/Apertus-8B-Instruct-2509",    # multilingual
    "meta-llama/Meta-Llama-3-8B-Instruct",  # strong generalist
    "mistralai/Mistral-7B-Instruct-v0.3",   # light/fast
]
775
 
776
# Light, modern UI (white background, neutral accents).
# Injected via gr.Blocks(css=custom_css); selectors cover Gradio's container
# plus the custom .card/.header/.subtle classes emitted in Markdown HTML.
custom_css = """
:root { --radius: 14px; }
.gradio-container { font-family: Inter, ui-sans-serif, system-ui; background: #ffffff; color: #111827; }
.card { border: 1px solid #e5e7eb; border-radius: var(--radius); padding: 14px 16px; background: #ffffff; box-shadow: 0 1px 2px rgba(0,0,0,.03); }
.header { font-weight: 700; font-size: 22px; margin-bottom: 4px; color: #0f172a; }
.subtle { color: #475569; font-size: 14px; margin-bottom: 12px; }
hr.sep { border: none; border-top: 1px solid #e5e7eb; margin: 10px 0 16px; }
.gr-button { border-radius: 12px !important; }
a, .prose a { color: #0ea5e9; }
"""
787
 
788
  with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
789
+ gr.Markdown("<div class='header'>Talk2Task — Multilingual Task Extraction (UBS Challenge)</div>")
790
+ gr.Markdown("<div class='subtle'>Single-pass multilingual extraction (DE/FR/IT/EN) with compact prompts. Optional rule fallback ensures recall. Batch evaluation & scoring included.</div>")
791
 
792
  with gr.Tab("Single transcript"):
793
  with gr.Row():
 
798
  file_types=[".txt", ".md", ".json"],
799
  type="filepath",
800
  )
801
+ text = gr.Textbox(label="Or paste transcript", lines=10, placeholder="Paste transcript in DE/FR/IT/EN…")
802
  gr.Markdown("<hr class='sep'/>")
803
 
804
  gr.Markdown("<div class='header'>Ground truth JSON (optional)</div>")
 
810
  gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{\"labels\": [\"schedule_meeting\"]}')
811
  gr.Markdown("</div>") # close card
812
 
813
+ gr.Markdown("<div class='card'><div class='header'>Processing options</div>")
814
+ use_cleaning = gr.Checkbox(label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)", value=True)
815
+ use_fallback = gr.Checkbox(label="Enable multilingual fallback rule layer", value=True)
 
 
 
 
 
 
816
  gr.Markdown("</div>")
817
 
818
  gr.Markdown("<div class='card'><div class='header'>Allowed labels</div>")
819
+ labels_text = gr.Textbox(label="Allowed Labels (one per line)", value=OFFICIAL_LABELS_TEXT, lines=8)
 
 
 
 
820
  reset_btn = gr.Button("Reset to official labels")
821
  gr.Markdown("</div>")
822
 
823
+ gr.Markdown("<div class='card'><div class='header'>Editable instructions & context</div>")
824
+ sys_instr_tb = gr.Textbox(label="System Instructions (editable)", value=DEFAULT_SYSTEM_INSTRUCTIONS, lines=5)
825
+ glossary_tb = gr.Code(label="Label Glossary (JSON; editable)", value=json.dumps(DEFAULT_LABEL_GLOSSARY, indent=2), language="json")
826
+ fallback_tb = gr.Code(label="Fallback Cues (Multilingual, JSON; editable)", value=json.dumps(DEFAULT_FALLBACK_CUES, indent=2), language="json")
827
+ gr.Markdown("</div>")
828
+
829
  with gr.Column(scale=2):
830
  gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
831
  repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
 
848
  with gr.Row():
849
  with gr.Column():
850
  with gr.Accordion("Instructions used (system prompt)", open=False):
851
+ instr_md = gr.Markdown("```\n" + DEFAULT_SYSTEM_INSTRUCTIONS + "\n```")
852
  with gr.Column():
853
+ with gr.Accordion("Context used (glossary)", open=True):
854
  context_md = gr.Markdown("")
855
 
856
+ # Reset labels to official
857
  def _reset_labels():
858
  return OFFICIAL_LABELS_TEXT
859
  reset_btn.click(fn=_reset_labels, inputs=None, outputs=labels_text)
860
 
861
+ # Warm-up
862
+ warm_btn.click(fn=warmup_model, inputs=[repo, use_4bit, hf_token], outputs=diag)
863
+
864
+ # For initial context preview
865
+ def _pack_context_md(glossary_json, allowed_text):
866
+ try:
867
+ glossary = json.loads(glossary_json) if glossary_json else DEFAULT_LABEL_GLOSSARY
868
+ except Exception:
869
+ glossary = DEFAULT_LABEL_GLOSSARY
870
+ allowed_list = [ln.strip() for ln in (allowed_text or OFFICIAL_LABELS_TEXT).splitlines() if ln.strip()]
871
+ return "### Label Glossary (used)\n" + "\n".join(f"- {k}: {glossary.get(k,'')}" for k in allowed_list)
872
 
873
+ context_md.value = _pack_context_md(json.dumps(DEFAULT_LABEL_GLOSSARY), OFFICIAL_LABELS_TEXT)
 
 
 
 
874
 
875
+ # Single run
876
  run_btn.click(
877
  fn=run_single,
878
  inputs=[
879
+ text, file, gt_text, gt_file, use_cleaning, use_fallback,
880
+ labels_text, sys_instr_tb, glossary_tb, fallback_tb,
881
+ repo, use_4bit, max_tokens, hf_token
882
  ],
883
  outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False), prompt_preview, token_info],
884
  )
885
 
 
 
 
886
  with gr.Tab("Batch evaluation"):
887
  with gr.Row():
888
  with gr.Column(scale=3):
889
  gr.Markdown("<div class='card'><div class='header'>ZIP input</div>")
890
  zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
891
  use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
892
+ use_fallback_b = gr.Checkbox(label="Enable multilingual fallback rule layer", value=True)
893
  gr.Markdown("</div>")
894
  with gr.Column(scale=2):
895
  gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
 
897
  use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
898
  max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
899
  hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
900
+ sys_instr_tb_b = gr.Textbox(label="System Instructions (editable for batch)", value=DEFAULT_SYSTEM_INSTRUCTIONS, lines=4)
901
+ glossary_tb_b = gr.Code(label="Label Glossary (JSON; editable for batch)", value=json.dumps(DEFAULT_LABEL_GLOSSARY, indent=2), language="json")
902
+ fallback_tb_b = gr.Code(label="Fallback Cues (Multilingual, JSON; editable for batch)", value=json.dumps(DEFAULT_FALLBACK_CUES, indent=2), language="json")
903
  limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
904
  run_batch_btn = gr.Button("Run Batch", variant="primary")
905
  gr.Markdown("</div>")
 
914
 
915
  run_batch_btn.click(
916
  fn=run_batch,
917
+ inputs=[
918
+ zip_in, use_cleaning_b, use_fallback_b,
919
+ sys_instr_tb_b, glossary_tb_b, fallback_tb_b,
920
+ repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files
921
+ ],
922
  outputs=[status, diag_b, df_out, csv_out],
923
  )
924