Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -27,14 +27,15 @@ SPACE_CACHE = Path.home() / ".cache" / "huggingface"
|
|
| 27 |
SPACE_CACHE.mkdir(parents=True, exist_ok=True)
|
| 28 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
|
|
|
|
| 30 |
GEN_CONFIG = GenerationConfig(
|
| 31 |
temperature=0.0,
|
| 32 |
top_p=1.0,
|
| 33 |
do_sample=False,
|
| 34 |
-
max_new_tokens=
|
| 35 |
)
|
| 36 |
|
| 37 |
-
# Official UBS
|
| 38 |
OFFICIAL_LABELS = [
|
| 39 |
"plan_contact",
|
| 40 |
"schedule_meeting",
|
|
@@ -47,98 +48,104 @@ OFFICIAL_LABELS = [
|
|
| 47 |
]
|
| 48 |
OFFICIAL_LABELS_TEXT = "\n".join(OFFICIAL_LABELS)
|
| 49 |
|
| 50 |
-
#
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
"
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
"
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
"update_kyc_activity": [
|
| 73 |
-
"activity update", "economic activity", "employment status",
|
| 74 |
-
"occupation", "job change", "changed jobs", "business activity"
|
| 75 |
-
],
|
| 76 |
-
"update_kyc_origin_of_assets": [
|
| 77 |
-
"source of funds", "origin of assets", "where money comes from",
|
| 78 |
-
"inheritance", "salary", "business income", "asset origin",
|
| 79 |
-
"gifted funds", "proceeds from sale"
|
| 80 |
-
],
|
| 81 |
-
"update_kyc_purpose_of_businessrelation": [
|
| 82 |
-
"purpose of relationship", "why the account", "reason for banking",
|
| 83 |
-
"investment purpose", "relationship purpose", "purpose of the relationship"
|
| 84 |
-
],
|
| 85 |
-
"update_kyc_total_assets": [
|
| 86 |
-
"total assets", "net worth", "assets under ownership",
|
| 87 |
-
"portfolio size", "how much you own", "aggregate assets"
|
| 88 |
-
],
|
| 89 |
}
|
| 90 |
|
| 91 |
-
#
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
r"\b(let'?s\s+)?meet(s|ing)?\b",
|
| 95 |
-
r"\bbook( a)? (time|slot|meeting)\b",
|
| 96 |
-
r"\bschedule( a)? (call|meeting)\b",
|
| 97 |
-
r"\b(next week|tomorrow|this (afternoon|evening|morning))\b",
|
| 98 |
-
r"\bfind a time\b",
|
| 99 |
-
],
|
| 100 |
"plan_contact": [
|
| 101 |
-
|
|
|
|
| 102 |
r"\bfollow\s*up\b",
|
| 103 |
r"\breach out\b",
|
| 104 |
r"\btouch base\b",
|
| 105 |
-
r"\
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
],
|
| 107 |
}
|
| 108 |
|
| 109 |
# =========================
|
| 110 |
-
#
|
| 111 |
# =========================
|
| 112 |
-
SYSTEM_PROMPT = (
|
| 113 |
-
"You are a precise banking assistant that extracts ACTIONABLE TASKS from "
|
| 114 |
-
"client–advisor transcripts. Be conservative with hallucinations but "
|
| 115 |
-
"prioritise RECALL: if unsure and the transcript plausibly implies an "
|
| 116 |
-
"action, include the label and explain briefly.\n\n"
|
| 117 |
-
"Output STRICT JSON only:\n\n"
|
| 118 |
-
"{\n"
|
| 119 |
-
' "labels": ["<Label1>", "..."],\n'
|
| 120 |
-
' "tasks": [\n'
|
| 121 |
-
' {"label": "<Label1>", "explanation": "<why>", "evidence": "<quoted text/snippet>"}\n'
|
| 122 |
-
" ]\n"
|
| 123 |
-
"}\n\n"
|
| 124 |
-
"Rules:\n"
|
| 125 |
-
"- Use ONLY allowed labels supplied to you. Case-insensitive during reasoning, "
|
| 126 |
-
" but output the canonical label text exactly.\n"
|
| 127 |
-
"- If none truly apply, return empty lists.\n"
|
| 128 |
-
"- Keep explanations concise; put the minimal evidence snippet that justifies the task.\n"
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
USER_PROMPT_TEMPLATE = (
|
| 132 |
-
"Transcript (
|
| 133 |
"```\n{transcript}\n```\n\n"
|
| 134 |
"Allowed Labels (canonical; use only these):\n"
|
| 135 |
"{allowed_labels_list}\n\n"
|
| 136 |
-
"
|
| 137 |
-
"{
|
| 138 |
-
"
|
| 139 |
-
"
|
| 140 |
-
"
|
| 141 |
-
"- Return STRICT JSON only in the exact schema described by the system prompt.\n"
|
| 142 |
)
|
| 143 |
|
| 144 |
# =========================
|
|
@@ -171,14 +178,12 @@ def robust_json_extract(text: str) -> Dict[str, Any]:
|
|
| 171 |
def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
|
| 172 |
out = {"labels": [], "tasks": []}
|
| 173 |
allowed_map = canonicalize_map(allowed)
|
| 174 |
-
# labels
|
| 175 |
filt_labels = []
|
| 176 |
for l in pred.get("labels", []) or []:
|
| 177 |
k = str(l).strip().lower()
|
| 178 |
if k in allowed_map:
|
| 179 |
filt_labels.append(allowed_map[k])
|
| 180 |
filt_labels = normalize_labels(filt_labels)
|
| 181 |
-
# tasks
|
| 182 |
filt_tasks = []
|
| 183 |
for t in pred.get("tasks", []) or []:
|
| 184 |
if not isinstance(t, dict):
|
|
@@ -186,6 +191,11 @@ def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, A
|
|
| 186 |
k = str(t.get("label", "")).strip().lower()
|
| 187 |
if k in allowed_map:
|
| 188 |
new_t = dict(t); new_t["label"] = allowed_map[k]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
filt_tasks.append(new_t)
|
| 190 |
merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
|
| 191 |
out["labels"] = merged
|
|
@@ -193,7 +203,7 @@ def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, A
|
|
| 193 |
return out
|
| 194 |
|
| 195 |
# =========================
|
| 196 |
-
#
|
| 197 |
# =========================
|
| 198 |
_DISCLAIMER_PATTERNS = [
|
| 199 |
r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
|
|
@@ -206,7 +216,7 @@ _FOOTER_PATTERNS = [
|
|
| 206 |
]
|
| 207 |
_TIMESTAMP_SPEAKER = [
|
| 208 |
r"\[\d{1,2}:\d{2}(:\d{2})?\]", # [00:01] or [00:01:02]
|
| 209 |
-
r"^\s*(advisor|client)\s*:\s*", # Advisor
|
| 210 |
r"^\s*(speaker\s*\d+)\s*:\s*", # Speaker 1:
|
| 211 |
]
|
| 212 |
|
|
@@ -214,7 +224,6 @@ def clean_transcript(text: str) -> str:
|
|
| 214 |
if not text:
|
| 215 |
return text
|
| 216 |
s = text
|
| 217 |
-
# remove timestamps/speaker prefixes line-wise
|
| 218 |
lines = []
|
| 219 |
for ln in s.splitlines():
|
| 220 |
ln2 = ln
|
|
@@ -222,19 +231,15 @@ def clean_transcript(text: str) -> str:
|
|
| 222 |
ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
|
| 223 |
lines.append(ln2)
|
| 224 |
s = "\n".join(lines)
|
| 225 |
-
# remove top disclaimers
|
| 226 |
for pat in _DISCLAIMER_PATTERNS:
|
| 227 |
s = re.sub(pat, "", s).strip()
|
| 228 |
-
# remove trailing footers
|
| 229 |
for pat in _FOOTER_PATTERNS:
|
| 230 |
s = re.sub(pat, "", s)
|
| 231 |
-
# collapse whitespace
|
| 232 |
s = re.sub(r"[ \t]+", " ", s)
|
| 233 |
s = re.sub(r"\n{3,}", "\n\n", s).strip()
|
| 234 |
return s
|
| 235 |
|
| 236 |
def read_text_file_any(file_input) -> str:
|
| 237 |
-
"""Works for gr.File(type='filepath') and raw strings/Path and file-like."""
|
| 238 |
if not file_input:
|
| 239 |
return ""
|
| 240 |
if isinstance(file_input, (str, Path)):
|
|
@@ -268,7 +273,7 @@ def truncate_tokens(tokenizer, text: str, max_tokens: int) -> str:
|
|
| 268 |
return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
|
| 269 |
|
| 270 |
# =========================
|
| 271 |
-
# HF model wrapper
|
| 272 |
# =========================
|
| 273 |
class ModelWrapper:
|
| 274 |
def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool):
|
|
@@ -306,7 +311,7 @@ class ModelWrapper:
|
|
| 306 |
|
| 307 |
@torch.inference_mode()
|
| 308 |
def generate(self, system_prompt: str, user_prompt: str) -> str:
|
| 309 |
-
# Build inputs as input_ids=... (avoid **tensor bug)
|
| 310 |
if hasattr(self.tokenizer, "apply_chat_template"):
|
| 311 |
messages = [
|
| 312 |
{"role": "system", "content": system_prompt},
|
|
@@ -351,7 +356,7 @@ def get_model(repo_id: str, hf_token: Optional[str], load_in_4bit: bool) -> Mode
|
|
| 351 |
return _MODEL_CACHE[key]
|
| 352 |
|
| 353 |
# =========================
|
| 354 |
-
#
|
| 355 |
# =========================
|
| 356 |
def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
|
| 357 |
ALLOWED_LABELS = OFFICIAL_LABELS
|
|
@@ -395,62 +400,32 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> fl
|
|
| 395 |
return float(max(0.0, min(1.0, np.mean(per_sample))))
|
| 396 |
|
| 397 |
# =========================
|
| 398 |
-
#
|
| 399 |
# =========================
|
| 400 |
-
def
|
| 401 |
low = text.lower()
|
| 402 |
-
labels = []
|
| 403 |
-
tasks = []
|
| 404 |
-
|
| 405 |
-
# Regex first
|
| 406 |
for lab in allowed:
|
| 407 |
-
|
| 408 |
-
found = None
|
| 409 |
-
for pat in patterns:
|
| 410 |
m = re.search(pat, low)
|
| 411 |
if m:
|
| 412 |
i = m.start()
|
| 413 |
-
start = max(0, i -
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
break
|
| 416 |
-
if found:
|
| 417 |
-
labels.append(lab)
|
| 418 |
-
tasks.append({
|
| 419 |
-
"label": lab,
|
| 420 |
-
"explanation": "Regex cue matched in transcript.",
|
| 421 |
-
"evidence": found
|
| 422 |
-
})
|
| 423 |
-
|
| 424 |
-
# Keyword contains() as backstop
|
| 425 |
-
for lab in allowed:
|
| 426 |
-
if lab in labels:
|
| 427 |
-
continue
|
| 428 |
-
hits = []
|
| 429 |
-
for kw in LABEL_KEYWORDS.get(lab, []):
|
| 430 |
-
k = kw.lower()
|
| 431 |
-
i = low.find(k)
|
| 432 |
-
if i != -1:
|
| 433 |
-
start = max(0, i - 40); end = min(len(text), i + len(k) + 40)
|
| 434 |
-
hits.append(text[start:end].strip())
|
| 435 |
-
if hits:
|
| 436 |
-
labels.append(lab)
|
| 437 |
-
tasks.append({
|
| 438 |
-
"label": lab,
|
| 439 |
-
"explanation": "Keyword match in transcript.",
|
| 440 |
-
"evidence": hits[0]
|
| 441 |
-
})
|
| 442 |
-
|
| 443 |
return {"labels": normalize_labels(labels), "tasks": tasks}
|
| 444 |
|
| 445 |
# =========================
|
| 446 |
# Inference helpers
|
| 447 |
# =========================
|
| 448 |
-
def
|
| 449 |
-
|
| 450 |
-
for lab in allowed:
|
| 451 |
-
kws = LABEL_KEYWORDS.get(lab, [])
|
| 452 |
-
parts.append(f"- {lab}: " + (", ".join(kws) if kws else "(no default cues)"))
|
| 453 |
-
return "\n".join(parts)
|
| 454 |
|
| 455 |
def warmup_model(model_repo: str, use_4bit: bool, hf_token: str) -> str:
|
| 456 |
t0 = _now_ms()
|
|
@@ -463,12 +438,15 @@ def warmup_model(model_repo: str, use_4bit: bool, hf_token: str) -> str:
|
|
| 463 |
|
| 464 |
def run_single(
|
| 465 |
transcript_text: str,
|
| 466 |
-
transcript_file,
|
| 467 |
gt_json_text: str,
|
| 468 |
-
gt_json_file,
|
| 469 |
use_cleaning: bool,
|
| 470 |
-
|
| 471 |
allowed_labels_text: str,
|
|
|
|
|
|
|
|
|
|
| 472 |
model_repo: str,
|
| 473 |
use_4bit: bool,
|
| 474 |
max_input_tokens: int,
|
|
@@ -477,7 +455,7 @@ def run_single(
|
|
| 477 |
|
| 478 |
t0 = _now_ms()
|
| 479 |
|
| 480 |
-
#
|
| 481 |
raw_text = ""
|
| 482 |
if transcript_file:
|
| 483 |
raw_text = read_text_file_any(transcript_file)
|
|
@@ -487,10 +465,28 @@ def run_single(
|
|
| 487 |
|
| 488 |
text = clean_transcript(raw_text) if use_cleaning else raw_text
|
| 489 |
|
| 490 |
-
# Allowed labels
|
| 491 |
user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
|
| 492 |
allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
|
| 493 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
# Model
|
| 495 |
try:
|
| 496 |
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
|
@@ -501,12 +497,12 @@ def run_single(
|
|
| 501 |
trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
|
| 502 |
|
| 503 |
# Build prompt
|
|
|
|
| 504 |
allowed_list_str = "\n".join(f"- {l}" for l in allowed)
|
| 505 |
-
keyword_ctx = build_keyword_context(allowed)
|
| 506 |
user_prompt = USER_PROMPT_TEMPLATE.format(
|
| 507 |
transcript=trunc,
|
| 508 |
allowed_labels_list=allowed_list_str,
|
| 509 |
-
|
| 510 |
)
|
| 511 |
|
| 512 |
# Token info + prompt preview
|
|
@@ -518,7 +514,7 @@ def run_single(
|
|
| 518 |
# Generate
|
| 519 |
t1 = _now_ms()
|
| 520 |
try:
|
| 521 |
-
out = model.generate(
|
| 522 |
except Exception as e:
|
| 523 |
return "", "", f"Generation error: {e}", "", "", "", prompt_preview_text, token_info_text, ""
|
| 524 |
t2 = _now_ms()
|
|
@@ -526,33 +522,27 @@ def run_single(
|
|
| 526 |
parsed = robust_json_extract(out)
|
| 527 |
filtered = restrict_to_allowed(parsed, allowed)
|
| 528 |
|
| 529 |
-
# Fallback if
|
| 530 |
-
if
|
| 531 |
-
fb =
|
| 532 |
if fb["labels"]:
|
| 533 |
-
|
|
|
|
|
|
|
|
|
|
| 534 |
|
| 535 |
# Diagnostics
|
| 536 |
diag = "\n".join([
|
| 537 |
f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
|
| 538 |
f"Model: {model_repo}",
|
| 539 |
f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
|
| 540 |
-
f"
|
| 541 |
-
f"Tokens (input
|
| 542 |
f"Latency: prep {t1-t0} ms, gen {t2-t1} ms, total {t2-t0} ms",
|
| 543 |
f"Allowed labels: {', '.join(allowed)}",
|
| 544 |
])
|
| 545 |
|
| 546 |
-
#
|
| 547 |
-
context_preview = (
|
| 548 |
-
"### Allowed Labels\n"
|
| 549 |
-
+ "\n".join(f"- {l}" for l in allowed)
|
| 550 |
-
+ "\n\n### Keyword cues per label\n"
|
| 551 |
-
+ keyword_ctx
|
| 552 |
-
)
|
| 553 |
-
instructions_preview = "```\n" + SYSTEM_PROMPT + "\n```"
|
| 554 |
-
|
| 555 |
-
# Summary & JSON
|
| 556 |
labs = filtered.get("labels", [])
|
| 557 |
tasks = filtered.get("tasks", [])
|
| 558 |
summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
|
|
@@ -565,7 +555,7 @@ def run_single(
|
|
| 565 |
summary += "\n\nTasks: (none)"
|
| 566 |
json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
|
| 567 |
|
| 568 |
-
#
|
| 569 |
metrics = ""
|
| 570 |
if gt_json_file or (gt_json_text and gt_json_text.strip()):
|
| 571 |
truth_obj = None
|
|
@@ -598,6 +588,10 @@ def run_single(
|
|
| 598 |
else:
|
| 599 |
metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
|
| 600 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics, prompt_preview_text, token_info_text
|
| 602 |
|
| 603 |
# =========================
|
|
@@ -612,9 +606,12 @@ def read_zip_from_path(path: str, exdir: Path) -> List[Path]:
|
|
| 612 |
return [p for p in exdir.rglob("*") if p.is_file()]
|
| 613 |
|
| 614 |
def run_batch(
|
| 615 |
-
zip_path,
|
| 616 |
use_cleaning: bool,
|
| 617 |
-
|
|
|
|
|
|
|
|
|
|
| 618 |
model_repo: str,
|
| 619 |
use_4bit: bool,
|
| 620 |
max_input_tokens: int,
|
|
@@ -625,6 +622,25 @@ def run_batch(
|
|
| 625 |
if not zip_path:
|
| 626 |
return ("No ZIP provided.", "", pd.DataFrame(), "")
|
| 627 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
work = Path("/tmp/batch")
|
| 629 |
if work.exists():
|
| 630 |
for p in sorted(work.rglob("*"), reverse=True):
|
|
@@ -650,14 +666,15 @@ def run_batch(
|
|
| 650 |
if not stems:
|
| 651 |
return ("No .txt transcripts found in ZIP.", "", pd.DataFrame(), "")
|
| 652 |
|
|
|
|
| 653 |
try:
|
| 654 |
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
| 655 |
except Exception as e:
|
| 656 |
return (f"Model load failed: {e}", "", pd.DataFrame(), "")
|
| 657 |
|
| 658 |
allowed = OFFICIAL_LABELS[:]
|
|
|
|
| 659 |
allowed_list_str = "\n".join(f"- {l}" for l in allowed)
|
| 660 |
-
keyword_ctx = build_keyword_context(allowed)
|
| 661 |
|
| 662 |
y_true, y_pred = [], []
|
| 663 |
rows = []
|
|
@@ -666,25 +683,29 @@ def run_batch(
|
|
| 666 |
for stem in stems:
|
| 667 |
raw = txts[stem].read_text(encoding="utf-8", errors="ignore")
|
| 668 |
text = clean_transcript(raw) if use_cleaning else raw
|
|
|
|
| 669 |
trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
|
| 670 |
|
| 671 |
user_prompt = USER_PROMPT_TEMPLATE.format(
|
| 672 |
transcript=trunc,
|
| 673 |
allowed_labels_list=allowed_list_str,
|
| 674 |
-
|
| 675 |
)
|
| 676 |
|
| 677 |
t0 = _now_ms()
|
| 678 |
-
out = model.generate(
|
| 679 |
t1 = _now_ms()
|
| 680 |
|
| 681 |
parsed = robust_json_extract(out)
|
| 682 |
filtered = restrict_to_allowed(parsed, allowed)
|
| 683 |
|
| 684 |
-
if
|
| 685 |
-
fb =
|
| 686 |
if fb["labels"]:
|
| 687 |
-
|
|
|
|
|
|
|
|
|
|
| 688 |
|
| 689 |
pred_labels = filtered.get("labels", [])
|
| 690 |
y_pred.append(pred_labels)
|
|
@@ -721,8 +742,8 @@ def run_batch(
|
|
| 721 |
f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
|
| 722 |
f"Model: {model_repo}",
|
| 723 |
f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
|
| 724 |
-
f"
|
| 725 |
-
f"Tokens (input
|
| 726 |
f"Batch time: {_now_ms()-t_start} ms",
|
| 727 |
]
|
| 728 |
if have_truth and score is not None:
|
|
@@ -739,7 +760,6 @@ def run_batch(
|
|
| 739 |
]
|
| 740 |
diag_str = "\n".join(diag)
|
| 741 |
|
| 742 |
-
# save CSV for download
|
| 743 |
out_csv = Path("/tmp/batch_results.csv")
|
| 744 |
df.to_csv(out_csv, index=False, encoding="utf-8")
|
| 745 |
return ("Batch done.", diag_str, df, str(out_csv))
|
|
@@ -748,24 +768,26 @@ def run_batch(
|
|
| 748 |
# UI
|
| 749 |
# =========================
|
| 750 |
MODEL_CHOICES = [
|
| 751 |
-
"swiss-ai/Apertus-8B-Instruct-2509",
|
| 752 |
-
"meta-llama/Meta-Llama-3-8B-Instruct",
|
| 753 |
-
"mistralai/Mistral-7B-Instruct-v0.3",
|
| 754 |
]
|
| 755 |
|
|
|
|
| 756 |
custom_css = """
|
| 757 |
:root { --radius: 14px; }
|
| 758 |
-
.gradio-container { font-family: Inter, ui-sans-serif, system-ui; }
|
| 759 |
-
.card { border: 1px solid
|
| 760 |
-
.header { font-weight: 700; font-size: 22px; margin-bottom: 4px; }
|
| 761 |
-
.subtle { color:
|
| 762 |
-
hr.sep { border: none; border-top: 1px solid
|
| 763 |
.gr-button { border-radius: 12px !important; }
|
|
|
|
| 764 |
"""
|
| 765 |
|
| 766 |
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
|
| 767 |
-
gr.Markdown("<div class='header'>Talk2Task — Task Extraction (UBS Challenge)</div>")
|
| 768 |
-
gr.Markdown("<div class='subtle'>
|
| 769 |
|
| 770 |
with gr.Tab("Single transcript"):
|
| 771 |
with gr.Row():
|
|
@@ -776,7 +798,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
|
|
| 776 |
file_types=[".txt", ".md", ".json"],
|
| 777 |
type="filepath",
|
| 778 |
)
|
| 779 |
-
text = gr.Textbox(label="Or paste transcript", lines=10)
|
| 780 |
gr.Markdown("<hr class='sep'/>")
|
| 781 |
|
| 782 |
gr.Markdown("<div class='header'>Ground truth JSON (optional)</div>")
|
|
@@ -788,26 +810,22 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
|
|
| 788 |
gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{\"labels\": [\"schedule_meeting\"]}')
|
| 789 |
gr.Markdown("</div>") # close card
|
| 790 |
|
| 791 |
-
gr.Markdown("<div class='card'><div class='header'>
|
| 792 |
-
use_cleaning = gr.Checkbox(
|
| 793 |
-
|
| 794 |
-
value=True,
|
| 795 |
-
)
|
| 796 |
-
use_keyword_fallback = gr.Checkbox(
|
| 797 |
-
label="Keyword fallback if model returns empty",
|
| 798 |
-
value=True,
|
| 799 |
-
)
|
| 800 |
gr.Markdown("</div>")
|
| 801 |
|
| 802 |
gr.Markdown("<div class='card'><div class='header'>Allowed labels</div>")
|
| 803 |
-
labels_text = gr.Textbox(
|
| 804 |
-
label="Allowed Labels (one per line)",
|
| 805 |
-
value=OFFICIAL_LABELS_TEXT, # prefilled
|
| 806 |
-
lines=8,
|
| 807 |
-
)
|
| 808 |
reset_btn = gr.Button("Reset to official labels")
|
| 809 |
gr.Markdown("</div>")
|
| 810 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 811 |
with gr.Column(scale=2):
|
| 812 |
gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
|
| 813 |
repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
|
|
@@ -830,48 +848,48 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
|
|
| 830 |
with gr.Row():
|
| 831 |
with gr.Column():
|
| 832 |
with gr.Accordion("Instructions used (system prompt)", open=False):
|
| 833 |
-
instr_md = gr.Markdown("```\n" +
|
| 834 |
with gr.Column():
|
| 835 |
-
with gr.Accordion("Context used (
|
| 836 |
context_md = gr.Markdown("")
|
| 837 |
|
| 838 |
-
#
|
| 839 |
def _reset_labels():
|
| 840 |
return OFFICIAL_LABELS_TEXT
|
| 841 |
reset_btn.click(fn=_reset_labels, inputs=None, outputs=labels_text)
|
| 842 |
|
| 843 |
-
#
|
| 844 |
-
warm_btn.click(
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 849 |
|
| 850 |
-
|
| 851 |
-
def _pack_context_md(allowed: str) -> str:
|
| 852 |
-
allowed_list = [ln.strip() for ln in (allowed or OFFICIAL_LABELS_TEXT).splitlines() if ln.strip()]
|
| 853 |
-
ctx = build_keyword_context(allowed_list)
|
| 854 |
-
return "### Allowed Labels\n" + "\n".join(f"- {l}" for l in allowed_list) + "\n\n### Keyword cues per label\n" + ctx
|
| 855 |
|
|
|
|
| 856 |
run_btn.click(
|
| 857 |
fn=run_single,
|
| 858 |
inputs=[
|
| 859 |
-
text, file, gt_text, gt_file, use_cleaning,
|
| 860 |
-
labels_text,
|
|
|
|
| 861 |
],
|
| 862 |
outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False), prompt_preview, token_info],
|
| 863 |
)
|
| 864 |
|
| 865 |
-
# initial context preview
|
| 866 |
-
context_md.value = _pack_context_md(OFFICIAL_LABELS_TEXT)
|
| 867 |
-
|
| 868 |
with gr.Tab("Batch evaluation"):
|
| 869 |
with gr.Row():
|
| 870 |
with gr.Column(scale=3):
|
| 871 |
gr.Markdown("<div class='card'><div class='header'>ZIP input</div>")
|
| 872 |
zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
|
| 873 |
use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
|
| 874 |
-
|
| 875 |
gr.Markdown("</div>")
|
| 876 |
with gr.Column(scale=2):
|
| 877 |
gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
|
|
@@ -879,6 +897,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
|
|
| 879 |
use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
|
| 880 |
max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
|
| 881 |
hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
|
|
|
|
|
|
|
|
|
|
| 882 |
limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
|
| 883 |
run_batch_btn = gr.Button("Run Batch", variant="primary")
|
| 884 |
gr.Markdown("</div>")
|
|
@@ -893,7 +914,11 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
|
|
| 893 |
|
| 894 |
run_batch_btn.click(
|
| 895 |
fn=run_batch,
|
| 896 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 897 |
outputs=[status, diag_b, df_out, csv_out],
|
| 898 |
)
|
| 899 |
|
|
|
|
| 27 |
SPACE_CACHE.mkdir(parents=True, exist_ok=True)
|
| 28 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
|
| 30 |
+
# Fast, deterministic, compact outputs for lower latency
|
| 31 |
GEN_CONFIG = GenerationConfig(
|
| 32 |
temperature=0.0,
|
| 33 |
top_p=1.0,
|
| 34 |
do_sample=False,
|
| 35 |
+
max_new_tokens=128, # increase if your JSON is getting truncated
|
| 36 |
)
|
| 37 |
|
| 38 |
+
# Official UBS labels (canonical)
|
| 39 |
OFFICIAL_LABELS = [
|
| 40 |
"plan_contact",
|
| 41 |
"schedule_meeting",
|
|
|
|
| 48 |
]
|
| 49 |
OFFICIAL_LABELS_TEXT = "\n".join(OFFICIAL_LABELS)
|
| 50 |
|
| 51 |
+
# =========================
|
| 52 |
+
# Editable defaults (shown in UI)
|
| 53 |
+
# =========================
|
| 54 |
+
DEFAULT_SYSTEM_INSTRUCTIONS = (
|
| 55 |
+
"You extract ACTIONABLE TASKS from client–advisor transcripts. "
|
| 56 |
+
"The transcript may be in German, French, Italian, or English. "
|
| 57 |
+
"Prioritize RECALL: if a label plausibly applies, include it. "
|
| 58 |
+
"Use ONLY the canonical labels provided. "
|
| 59 |
+
"Return STRICT JSON only with keys 'labels' and 'tasks'. "
|
| 60 |
+
"Each task must include 'label', a brief 'explanation', and a short 'evidence' quote from the transcript."
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Very short, language-agnostic semantics to keep prompt small
|
| 64 |
+
DEFAULT_LABEL_GLOSSARY = {
|
| 65 |
+
"plan_contact": "Commitment to contact later (advisor/client will reach out, follow-up promised).",
|
| 66 |
+
"schedule_meeting": "Scheduling or confirming a meeting/call/appointment (time/date/slot/virtual).",
|
| 67 |
+
"update_contact_info_non_postal": "Change or confirmation of phone/email (non-postal contact details).",
|
| 68 |
+
"update_contact_info_postal_address": "Change or confirmation of postal/residential/mailing address.",
|
| 69 |
+
"update_kyc_activity": "Change/confirmation of occupation, employment status, or economic activity.",
|
| 70 |
+
"update_kyc_origin_of_assets": "Discussion/confirmation of source of funds / origin of assets.",
|
| 71 |
+
"update_kyc_purpose_of_businessrelation": "Purpose of the banking relationship/account usage.",
|
| 72 |
+
"update_kyc_total_assets": "Discussion/confirmation of total assets/net worth.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
}
|
| 74 |
|
| 75 |
+
# Tiny multilingual fallback rules (optional) to guarantee recall if model is empty.
|
| 76 |
+
# Keep small to avoid false positives and keep maintenance low.
|
| 77 |
+
DEFAULT_FALLBACK_CUES = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
"plan_contact": [
|
| 79 |
+
# EN
|
| 80 |
+
r"\b(get|got|will|we'?ll|i'?ll)\s+back to you\b",
|
| 81 |
r"\bfollow\s*up\b",
|
| 82 |
r"\breach out\b",
|
| 83 |
r"\btouch base\b",
|
| 84 |
+
r"\bcontact (you|me|us)\b",
|
| 85 |
+
# DE
|
| 86 |
+
r"\bin verbindung setzen\b",
|
| 87 |
+
r"\brückmeldung\b",
|
| 88 |
+
r"\bich\s+melde\b|\bwir\s+melden\b",
|
| 89 |
+
r"\bnachfassen\b",
|
| 90 |
+
# FR
|
| 91 |
+
r"\bje vous recontacte\b|\bnous vous recontacterons\b",
|
| 92 |
+
r"\bprendre contact\b|\breprendre contact\b",
|
| 93 |
+
# IT
|
| 94 |
+
r"\bla ricontatter[oò]\b|\bci metteremo in contatto\b",
|
| 95 |
+
r"\btenersi in contatto\b",
|
| 96 |
+
],
|
| 97 |
+
"schedule_meeting": [
|
| 98 |
+
# EN
|
| 99 |
+
r"\b(let'?s\s+)?meet(ing|s)?\b",
|
| 100 |
+
r"\bschedule( a)? (call|meeting|appointment)\b",
|
| 101 |
+
r"\bbook( a)? (slot|time|meeting)\b",
|
| 102 |
+
r"\b(next week|tomorrow|this (afternoon|morning|evening))\b",
|
| 103 |
+
r"\bconfirm( the)? (time|meeting|appointment)\b",
|
| 104 |
+
# DE
|
| 105 |
+
r"\btermin(e|s)?\b|\bvereinbaren\b|\bansetzen\b|\babstimmen\b|\bbesprechung(en)?\b|\bvirtuell(e|en)?\b",
|
| 106 |
+
r"\bnächste(n|r)? woche\b|\b(dienstag|montag|mittwoch|donnerstag|freitag)\b|\bnachmittag|vormittag|morgen\b",
|
| 107 |
+
# FR
|
| 108 |
+
r"\brendez[- ]?vous\b|\bréunion\b|\bfixer\b|\bplanifier\b|\bcalendrier\b|\bse rencontrer\b|\bse voir\b",
|
| 109 |
+
r"\bla semaine prochaine\b|\bdemain\b|\bcet (après-midi|apres-midi|après midi|apres midi|matin|soir)\b",
|
| 110 |
+
# IT
|
| 111 |
+
r"\bappuntamento\b|\briunione\b|\borganizzare\b|\bprogrammare\b|\bincontrarci\b|\bcalendario\b",
|
| 112 |
+
r"\bla prossima settimana\b|\bdomani\b|\b(questo|questa)\s*(pomeriggio|mattina|sera)\b",
|
| 113 |
+
],
|
| 114 |
+
"update_kyc_origin_of_assets": [
|
| 115 |
+
# EN
|
| 116 |
+
r"\bsource of funds\b|\borigin of assets\b|\bproof of (funds|assets)\b",
|
| 117 |
+
# DE
|
| 118 |
+
r"\bvermögensursprung(e|s)?\b|\bherkunft der mittel\b|\bnachweis\b",
|
| 119 |
+
# FR
|
| 120 |
+
r"\borigine des fonds\b|\borigine du patrimoine\b|\bjustificatif(s)?\b",
|
| 121 |
+
# IT
|
| 122 |
+
r"\borigine dei fondi\b|\borigine del patrimonio\b|\bprova dei fondi\b|\bgiustificativo\b",
|
| 123 |
+
],
|
| 124 |
+
"update_kyc_activity": [
|
| 125 |
+
# EN
|
| 126 |
+
r"\bemployment status\b|\boccupation\b|\bjob change\b|\bsalary history\b",
|
| 127 |
+
# DE
|
| 128 |
+
r"\bbeschäftigungsstatus\b|\bberuf\b|\bjobwechsel\b|\bgehaltshistorie\b|\btätigkeit\b",
|
| 129 |
+
# FR
|
| 130 |
+
r"\bstatut professionnel\b|\bprofession\b|\bchangement d'emploi\b|\bhistorique salarial\b|\bactivité\b",
|
| 131 |
+
# IT
|
| 132 |
+
r"\bstato occupazionale\b|\bprofessione\b|\bcambio di lavoro\b|\bstoria salariale\b|\battivit[aà]\b",
|
| 133 |
],
|
| 134 |
}
|
| 135 |
|
| 136 |
# =========================
|
| 137 |
+
# Prompt templates (minimal multilingual)
|
| 138 |
# =========================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
USER_PROMPT_TEMPLATE = (
|
| 140 |
+
"Transcript (may be DE/FR/IT/EN):\n"
|
| 141 |
"```\n{transcript}\n```\n\n"
|
| 142 |
"Allowed Labels (canonical; use only these):\n"
|
| 143 |
"{allowed_labels_list}\n\n"
|
| 144 |
+
"Label Glossary (concise semantics):\n"
|
| 145 |
+
"{glossary}\n\n"
|
| 146 |
+
"Return STRICT JSON ONLY in this exact schema:\n"
|
| 147 |
+
'{\n "labels": ["<Label1>", "..."],\n'
|
| 148 |
+
' "tasks": [{"label": "<Label1>", "explanation": "<why>", "evidence": "<quote>"}]\n}\n'
|
|
|
|
| 149 |
)
|
| 150 |
|
| 151 |
# =========================
|
|
|
|
| 178 |
def restrict_to_allowed(pred: Dict[str, Any], allowed: List[str]) -> Dict[str, Any]:
|
| 179 |
out = {"labels": [], "tasks": []}
|
| 180 |
allowed_map = canonicalize_map(allowed)
|
|
|
|
| 181 |
filt_labels = []
|
| 182 |
for l in pred.get("labels", []) or []:
|
| 183 |
k = str(l).strip().lower()
|
| 184 |
if k in allowed_map:
|
| 185 |
filt_labels.append(allowed_map[k])
|
| 186 |
filt_labels = normalize_labels(filt_labels)
|
|
|
|
| 187 |
filt_tasks = []
|
| 188 |
for t in pred.get("tasks", []) or []:
|
| 189 |
if not isinstance(t, dict):
|
|
|
|
| 191 |
k = str(t.get("label", "")).strip().lower()
|
| 192 |
if k in allowed_map:
|
| 193 |
new_t = dict(t); new_t["label"] = allowed_map[k]
|
| 194 |
+
new_t = {
|
| 195 |
+
"label": new_t["label"],
|
| 196 |
+
"explanation": str(new_t.get("explanation", ""))[:300],
|
| 197 |
+
"evidence": str(new_t.get("evidence", ""))[:300],
|
| 198 |
+
}
|
| 199 |
filt_tasks.append(new_t)
|
| 200 |
merged = normalize_labels(list(set(filt_labels) | {tt["label"] for tt in filt_tasks}))
|
| 201 |
out["labels"] = merged
|
|
|
|
| 203 |
return out
|
| 204 |
|
| 205 |
# =========================
|
| 206 |
+
# Pre-processing
|
| 207 |
# =========================
|
| 208 |
_DISCLAIMER_PATTERNS = [
|
| 209 |
r"(?is)^\s*(?:disclaimer|legal notice|confidentiality notice).+?(?:\n{2,}|$)",
|
|
|
|
| 216 |
]
|
| 217 |
_TIMESTAMP_SPEAKER = [
|
| 218 |
r"\[\d{1,2}:\d{2}(:\d{2})?\]", # [00:01] or [00:01:02]
|
| 219 |
+
r"^\s*(advisor|client|client advisor)\s*:\s*", # Advisor:, Client:
|
| 220 |
r"^\s*(speaker\s*\d+)\s*:\s*", # Speaker 1:
|
| 221 |
]
|
| 222 |
|
|
|
|
| 224 |
if not text:
|
| 225 |
return text
|
| 226 |
s = text
|
|
|
|
| 227 |
lines = []
|
| 228 |
for ln in s.splitlines():
|
| 229 |
ln2 = ln
|
|
|
|
| 231 |
ln2 = re.sub(pat, "", ln2, flags=re.IGNORECASE)
|
| 232 |
lines.append(ln2)
|
| 233 |
s = "\n".join(lines)
|
|
|
|
| 234 |
for pat in _DISCLAIMER_PATTERNS:
|
| 235 |
s = re.sub(pat, "", s).strip()
|
|
|
|
| 236 |
for pat in _FOOTER_PATTERNS:
|
| 237 |
s = re.sub(pat, "", s)
|
|
|
|
| 238 |
s = re.sub(r"[ \t]+", " ", s)
|
| 239 |
s = re.sub(r"\n{3,}", "\n\n", s).strip()
|
| 240 |
return s
|
| 241 |
|
| 242 |
def read_text_file_any(file_input) -> str:
|
|
|
|
| 243 |
if not file_input:
|
| 244 |
return ""
|
| 245 |
if isinstance(file_input, (str, Path)):
|
|
|
|
| 273 |
return tokenizer.decode(toks[-max_tokens:], skip_special_tokens=True)
|
| 274 |
|
| 275 |
# =========================
|
| 276 |
+
# HF model wrapper (main LLM)
|
| 277 |
# =========================
|
| 278 |
class ModelWrapper:
|
| 279 |
def __init__(self, repo_id: str, hf_token: Optional[str], load_in_4bit: bool):
|
|
|
|
| 311 |
|
| 312 |
@torch.inference_mode()
|
| 313 |
def generate(self, system_prompt: str, user_prompt: str) -> str:
|
| 314 |
+
# Build inputs as input_ids=... (avoid earlier **tensor bug)
|
| 315 |
if hasattr(self.tokenizer, "apply_chat_template"):
|
| 316 |
messages = [
|
| 317 |
{"role": "system", "content": system_prompt},
|
|
|
|
| 356 |
return _MODEL_CACHE[key]
|
| 357 |
|
| 358 |
# =========================
|
| 359 |
+
# Evaluation (official weighted score)
|
| 360 |
# =========================
|
| 361 |
def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
|
| 362 |
ALLOWED_LABELS = OFFICIAL_LABELS
|
|
|
|
| 400 |
return float(max(0.0, min(1.0, np.mean(per_sample))))
|
| 401 |
|
| 402 |
# =========================
|
| 403 |
+
# Multilingual fallback (regex on original text)
|
| 404 |
# =========================
|
| 405 |
+
def multilingual_fallback(text: str, allowed: List[str], cues: Dict[str, List[str]]) -> Dict[str, Any]:
|
| 406 |
low = text.lower()
|
| 407 |
+
labels, tasks = [], []
|
|
|
|
|
|
|
|
|
|
| 408 |
for lab in allowed:
|
| 409 |
+
for pat in cues.get(lab, []):
|
|
|
|
|
|
|
| 410 |
m = re.search(pat, low)
|
| 411 |
if m:
|
| 412 |
i = m.start()
|
| 413 |
+
start = max(0, i - 60); end = min(len(text), i + len(m.group(0)) + 60)
|
| 414 |
+
if lab not in labels:
|
| 415 |
+
labels.append(lab)
|
| 416 |
+
tasks.append({
|
| 417 |
+
"label": lab,
|
| 418 |
+
"explanation": "Rule hit (multilingual fallback)",
|
| 419 |
+
"evidence": text[start:end].strip()
|
| 420 |
+
})
|
| 421 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
return {"labels": normalize_labels(labels), "tasks": tasks}
|
| 423 |
|
| 424 |
# =========================
|
| 425 |
# Inference helpers
|
| 426 |
# =========================
|
| 427 |
+
def build_glossary_str(glossary: Dict[str, str], allowed: List[str]) -> str:
|
| 428 |
+
return "\n".join([f"- {lab}: {glossary.get(lab, '')}" for lab in allowed])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
def warmup_model(model_repo: str, use_4bit: bool, hf_token: str) -> str:
|
| 431 |
t0 = _now_ms()
|
|
|
|
| 438 |
|
| 439 |
def run_single(
|
| 440 |
transcript_text: str,
|
| 441 |
+
transcript_file,
|
| 442 |
gt_json_text: str,
|
| 443 |
+
gt_json_file,
|
| 444 |
use_cleaning: bool,
|
| 445 |
+
use_fallback: bool,
|
| 446 |
allowed_labels_text: str,
|
| 447 |
+
sys_instructions_text: str,
|
| 448 |
+
glossary_json_text: str,
|
| 449 |
+
fallback_json_text: str,
|
| 450 |
model_repo: str,
|
| 451 |
use_4bit: bool,
|
| 452 |
max_input_tokens: int,
|
|
|
|
| 455 |
|
| 456 |
t0 = _now_ms()
|
| 457 |
|
| 458 |
+
# Load transcript
|
| 459 |
raw_text = ""
|
| 460 |
if transcript_file:
|
| 461 |
raw_text = read_text_file_any(transcript_file)
|
|
|
|
| 465 |
|
| 466 |
text = clean_transcript(raw_text) if use_cleaning else raw_text
|
| 467 |
|
| 468 |
+
# Allowed labels
|
| 469 |
user_allowed = [ln.strip() for ln in (allowed_labels_text or "").splitlines() if ln.strip()]
|
| 470 |
allowed = normalize_labels(user_allowed or OFFICIAL_LABELS)
|
| 471 |
|
| 472 |
+
# Editable configs
|
| 473 |
+
try:
|
| 474 |
+
sys_instructions = (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS).strip()
|
| 475 |
+
if not sys_instructions:
|
| 476 |
+
sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
|
| 477 |
+
except Exception:
|
| 478 |
+
sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
|
| 479 |
+
|
| 480 |
+
try:
|
| 481 |
+
label_glossary = json.loads(glossary_json_text) if glossary_json_text else DEFAULT_LABEL_GLOSSARY
|
| 482 |
+
except Exception:
|
| 483 |
+
label_glossary = DEFAULT_LABEL_GLOSSARY
|
| 484 |
+
|
| 485 |
+
try:
|
| 486 |
+
fallback_cues = json.loads(fallback_json_text) if fallback_json_text else DEFAULT_FALLBACK_CUES
|
| 487 |
+
except Exception:
|
| 488 |
+
fallback_cues = DEFAULT_FALLBACK_CUES
|
| 489 |
+
|
| 490 |
# Model
|
| 491 |
try:
|
| 492 |
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
|
|
|
| 497 |
trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
|
| 498 |
|
| 499 |
# Build prompt
|
| 500 |
+
glossary_str = build_glossary_str(label_glossary, allowed)
|
| 501 |
allowed_list_str = "\n".join(f"- {l}" for l in allowed)
|
|
|
|
| 502 |
user_prompt = USER_PROMPT_TEMPLATE.format(
|
| 503 |
transcript=trunc,
|
| 504 |
allowed_labels_list=allowed_list_str,
|
| 505 |
+
glossary=glossary_str,
|
| 506 |
)
|
| 507 |
|
| 508 |
# Token info + prompt preview
|
|
|
|
| 514 |
# Generate
|
| 515 |
t1 = _now_ms()
|
| 516 |
try:
|
| 517 |
+
out = model.generate(sys_instructions, user_prompt)
|
| 518 |
except Exception as e:
|
| 519 |
return "", "", f"Generation error: {e}", "", "", "", prompt_preview_text, token_info_text, ""
|
| 520 |
t2 = _now_ms()
|
|
|
|
| 522 |
parsed = robust_json_extract(out)
|
| 523 |
filtered = restrict_to_allowed(parsed, allowed)
|
| 524 |
|
| 525 |
+
# Fallback (multilingual rules) on original text; merge for recall if enabled
|
| 526 |
+
if use_fallback:
|
| 527 |
+
fb = multilingual_fallback(trunc, allowed, fallback_cues)
|
| 528 |
if fb["labels"]:
|
| 529 |
+
merged_labels = sorted(list(set(filtered.get("labels", [])) | set(fb["labels"])))
|
| 530 |
+
existing = {tt.get("label") for tt in filtered.get("tasks", [])}
|
| 531 |
+
merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
|
| 532 |
+
filtered = {"labels": merged_labels, "tasks": merged_tasks}
|
| 533 |
|
| 534 |
# Diagnostics
|
| 535 |
diag = "\n".join([
|
| 536 |
f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
|
| 537 |
f"Model: {model_repo}",
|
| 538 |
f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
|
| 539 |
+
f"Fallback rules: {'Yes' if use_fallback else 'No'}",
|
| 540 |
+
f"Tokens (input limit): ≤ {max_input_tokens}",
|
| 541 |
f"Latency: prep {t1-t0} ms, gen {t2-t1} ms, total {t2-t0} ms",
|
| 542 |
f"Allowed labels: {', '.join(allowed)}",
|
| 543 |
])
|
| 544 |
|
| 545 |
+
# Summaries
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
labs = filtered.get("labels", [])
|
| 547 |
tasks = filtered.get("tasks", [])
|
| 548 |
summary = "Detected labels:\n" + ("\n".join(f"- {l}" for l in labs) if labs else "(none)")
|
|
|
|
| 555 |
summary += "\n\nTasks: (none)"
|
| 556 |
json_out = json.dumps(filtered, indent=2, ensure_ascii=False)
|
| 557 |
|
| 558 |
+
# Single-file scoring if GT provided
|
| 559 |
metrics = ""
|
| 560 |
if gt_json_file or (gt_json_text and gt_json_text.strip()):
|
| 561 |
truth_obj = None
|
|
|
|
| 588 |
else:
|
| 589 |
metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
|
| 590 |
|
| 591 |
+
# For UI: show effective context (glossary) and instructions
|
| 592 |
+
context_preview = "### Label Glossary (used)\n" + "\n".join(f"- {k}: {v}" for k, v in label_glossary.items() if k in allowed)
|
| 593 |
+
instructions_preview = "```\n" + sys_instructions + "\n```"
|
| 594 |
+
|
| 595 |
return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics, prompt_preview_text, token_info_text
|
| 596 |
|
| 597 |
# =========================
|
|
|
|
| 606 |
return [p for p in exdir.rglob("*") if p.is_file()]
|
| 607 |
|
| 608 |
def run_batch(
|
| 609 |
+
zip_path,
|
| 610 |
use_cleaning: bool,
|
| 611 |
+
use_fallback: bool,
|
| 612 |
+
sys_instructions_text: str,
|
| 613 |
+
glossary_json_text: str,
|
| 614 |
+
fallback_json_text: str,
|
| 615 |
model_repo: str,
|
| 616 |
use_4bit: bool,
|
| 617 |
max_input_tokens: int,
|
|
|
|
| 622 |
if not zip_path:
|
| 623 |
return ("No ZIP provided.", "", pd.DataFrame(), "")
|
| 624 |
|
| 625 |
+
# Editable configs
|
| 626 |
+
try:
|
| 627 |
+
sys_instructions = (sys_instructions_text or DEFAULT_SYSTEM_INSTRUCTIONS).strip()
|
| 628 |
+
if not sys_instructions:
|
| 629 |
+
sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
|
| 630 |
+
except Exception:
|
| 631 |
+
sys_instructions = DEFAULT_SYSTEM_INSTRUCTIONS
|
| 632 |
+
|
| 633 |
+
try:
|
| 634 |
+
label_glossary = json.loads(glossary_json_text) if glossary_json_text else DEFAULT_LABEL_GLOSSARY
|
| 635 |
+
except Exception:
|
| 636 |
+
label_glossary = DEFAULT_LABEL_GLOSSARY
|
| 637 |
+
|
| 638 |
+
try:
|
| 639 |
+
fallback_cues = json.loads(fallback_json_text) if fallback_json_text else DEFAULT_FALLBACK_CUES
|
| 640 |
+
except Exception:
|
| 641 |
+
fallback_cues = DEFAULT_FALLBACK_CUES
|
| 642 |
+
|
| 643 |
+
# Prepare workspace
|
| 644 |
work = Path("/tmp/batch")
|
| 645 |
if work.exists():
|
| 646 |
for p in sorted(work.rglob("*"), reverse=True):
|
|
|
|
| 666 |
if not stems:
|
| 667 |
return ("No .txt transcripts found in ZIP.", "", pd.DataFrame(), "")
|
| 668 |
|
| 669 |
+
# Model
|
| 670 |
try:
|
| 671 |
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
| 672 |
except Exception as e:
|
| 673 |
return (f"Model load failed: {e}", "", pd.DataFrame(), "")
|
| 674 |
|
| 675 |
allowed = OFFICIAL_LABELS[:]
|
| 676 |
+
glossary_str = build_glossary_str(label_glossary, allowed)
|
| 677 |
allowed_list_str = "\n".join(f"- {l}" for l in allowed)
|
|
|
|
| 678 |
|
| 679 |
y_true, y_pred = [], []
|
| 680 |
rows = []
|
|
|
|
| 683 |
for stem in stems:
|
| 684 |
raw = txts[stem].read_text(encoding="utf-8", errors="ignore")
|
| 685 |
text = clean_transcript(raw) if use_cleaning else raw
|
| 686 |
+
|
| 687 |
trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
|
| 688 |
|
| 689 |
user_prompt = USER_PROMPT_TEMPLATE.format(
|
| 690 |
transcript=trunc,
|
| 691 |
allowed_labels_list=allowed_list_str,
|
| 692 |
+
glossary=glossary_str,
|
| 693 |
)
|
| 694 |
|
| 695 |
t0 = _now_ms()
|
| 696 |
+
out = model.generate(sys_instructions, user_prompt)
|
| 697 |
t1 = _now_ms()
|
| 698 |
|
| 699 |
parsed = robust_json_extract(out)
|
| 700 |
filtered = restrict_to_allowed(parsed, allowed)
|
| 701 |
|
| 702 |
+
if use_fallback:
|
| 703 |
+
fb = multilingual_fallback(trunc, allowed, fallback_cues)
|
| 704 |
if fb["labels"]:
|
| 705 |
+
merged_labels = sorted(list(set(filtered.get("labels", [])) | set(fb["labels"])))
|
| 706 |
+
existing = {tt.get("label") for tt in filtered.get("tasks", [])}
|
| 707 |
+
merged_tasks = filtered.get("tasks", []) + [t for t in fb["tasks"] if t["label"] not in existing]
|
| 708 |
+
filtered = {"labels": merged_labels, "tasks": merged_tasks}
|
| 709 |
|
| 710 |
pred_labels = filtered.get("labels", [])
|
| 711 |
y_pred.append(pred_labels)
|
|
|
|
| 742 |
f"Device: {DEVICE} (4-bit: {'Yes' if (use_4bit and DEVICE=='cuda') else 'No'})",
|
| 743 |
f"Model: {model_repo}",
|
| 744 |
f"Input cleaned: {'Yes' if use_cleaning else 'No'}",
|
| 745 |
+
f"Fallback rules: {'Yes' if use_fallback else 'No'}",
|
| 746 |
+
f"Tokens (input limit): ≤ {max_input_tokens}",
|
| 747 |
f"Batch time: {_now_ms()-t_start} ms",
|
| 748 |
]
|
| 749 |
if have_truth and score is not None:
|
|
|
|
| 760 |
]
|
| 761 |
diag_str = "\n".join(diag)
|
| 762 |
|
|
|
|
| 763 |
out_csv = Path("/tmp/batch_results.csv")
|
| 764 |
df.to_csv(out_csv, index=False, encoding="utf-8")
|
| 765 |
return ("Batch done.", diag_str, df, str(out_csv))
|
|
|
|
| 768 |
# UI
|
| 769 |
# =========================
|
| 770 |
MODEL_CHOICES = [
|
| 771 |
+
"swiss-ai/Apertus-8B-Instruct-2509", # multilingual
|
| 772 |
+
"meta-llama/Meta-Llama-3-8B-Instruct", # strong generalist
|
| 773 |
+
"mistralai/Mistral-7B-Instruct-v0.3", # light/fast
|
| 774 |
]
|
| 775 |
|
| 776 |
+
# Light, modern UI (white background, neutral accents)
|
| 777 |
custom_css = """
|
| 778 |
:root { --radius: 14px; }
|
| 779 |
+
.gradio-container { font-family: Inter, ui-sans-serif, system-ui; background: #ffffff; color: #111827; }
|
| 780 |
+
.card { border: 1px solid #e5e7eb; border-radius: var(--radius); padding: 14px 16px; background: #ffffff; box-shadow: 0 1px 2px rgba(0,0,0,.03); }
|
| 781 |
+
.header { font-weight: 700; font-size: 22px; margin-bottom: 4px; color: #0f172a; }
|
| 782 |
+
.subtle { color: #475569; font-size: 14px; margin-bottom: 12px; }
|
| 783 |
+
hr.sep { border: none; border-top: 1px solid #e5e7eb; margin: 10px 0 16px; }
|
| 784 |
.gr-button { border-radius: 12px !important; }
|
| 785 |
+
a, .prose a { color: #0ea5e9; }
|
| 786 |
"""
|
| 787 |
|
| 788 |
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
|
| 789 |
+
gr.Markdown("<div class='header'>Talk2Task — Multilingual Task Extraction (UBS Challenge)</div>")
|
| 790 |
+
gr.Markdown("<div class='subtle'>Single-pass multilingual extraction (DE/FR/IT/EN) with compact prompts. Optional rule fallback ensures recall. Batch evaluation & scoring included.</div>")
|
| 791 |
|
| 792 |
with gr.Tab("Single transcript"):
|
| 793 |
with gr.Row():
|
|
|
|
| 798 |
file_types=[".txt", ".md", ".json"],
|
| 799 |
type="filepath",
|
| 800 |
)
|
| 801 |
+
text = gr.Textbox(label="Or paste transcript", lines=10, placeholder="Paste transcript in DE/FR/IT/EN…")
|
| 802 |
gr.Markdown("<hr class='sep'/>")
|
| 803 |
|
| 804 |
gr.Markdown("<div class='header'>Ground truth JSON (optional)</div>")
|
|
|
|
| 810 |
gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{\"labels\": [\"schedule_meeting\"]}')
|
| 811 |
gr.Markdown("</div>") # close card
|
| 812 |
|
| 813 |
+
gr.Markdown("<div class='card'><div class='header'>Processing options</div>")
|
| 814 |
+
use_cleaning = gr.Checkbox(label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)", value=True)
|
| 815 |
+
use_fallback = gr.Checkbox(label="Enable multilingual fallback rule layer", value=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 816 |
gr.Markdown("</div>")
|
| 817 |
|
| 818 |
gr.Markdown("<div class='card'><div class='header'>Allowed labels</div>")
|
| 819 |
+
labels_text = gr.Textbox(label="Allowed Labels (one per line)", value=OFFICIAL_LABELS_TEXT, lines=8)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 820 |
reset_btn = gr.Button("Reset to official labels")
|
| 821 |
gr.Markdown("</div>")
|
| 822 |
|
| 823 |
+
gr.Markdown("<div class='card'><div class='header'>Editable instructions & context</div>")
|
| 824 |
+
sys_instr_tb = gr.Textbox(label="System Instructions (editable)", value=DEFAULT_SYSTEM_INSTRUCTIONS, lines=5)
|
| 825 |
+
glossary_tb = gr.Code(label="Label Glossary (JSON; editable)", value=json.dumps(DEFAULT_LABEL_GLOSSARY, indent=2), language="json")
|
| 826 |
+
fallback_tb = gr.Code(label="Fallback Cues (Multilingual, JSON; editable)", value=json.dumps(DEFAULT_FALLBACK_CUES, indent=2), language="json")
|
| 827 |
+
gr.Markdown("</div>")
|
| 828 |
+
|
| 829 |
with gr.Column(scale=2):
|
| 830 |
gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
|
| 831 |
repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
|
|
|
|
| 848 |
with gr.Row():
|
| 849 |
with gr.Column():
|
| 850 |
with gr.Accordion("Instructions used (system prompt)", open=False):
|
| 851 |
+
instr_md = gr.Markdown("```\n" + DEFAULT_SYSTEM_INSTRUCTIONS + "\n```")
|
| 852 |
with gr.Column():
|
| 853 |
+
with gr.Accordion("Context used (glossary)", open=True):
|
| 854 |
context_md = gr.Markdown("")
|
| 855 |
|
| 856 |
+
# Reset labels to official
|
| 857 |
def _reset_labels():
|
| 858 |
return OFFICIAL_LABELS_TEXT
|
| 859 |
reset_btn.click(fn=_reset_labels, inputs=None, outputs=labels_text)
|
| 860 |
|
| 861 |
+
# Warm-up
|
| 862 |
+
warm_btn.click(fn=warmup_model, inputs=[repo, use_4bit, hf_token], outputs=diag)
|
| 863 |
+
|
| 864 |
+
# For initial context preview
|
| 865 |
+
def _pack_context_md(glossary_json, allowed_text):
|
| 866 |
+
try:
|
| 867 |
+
glossary = json.loads(glossary_json) if glossary_json else DEFAULT_LABEL_GLOSSARY
|
| 868 |
+
except Exception:
|
| 869 |
+
glossary = DEFAULT_LABEL_GLOSSARY
|
| 870 |
+
allowed_list = [ln.strip() for ln in (allowed_text or OFFICIAL_LABELS_TEXT).splitlines() if ln.strip()]
|
| 871 |
+
return "### Label Glossary (used)\n" + "\n".join(f"- {k}: {glossary.get(k,'')}" for k in allowed_list)
|
| 872 |
|
| 873 |
+
context_md.value = _pack_context_md(json.dumps(DEFAULT_LABEL_GLOSSARY), OFFICIAL_LABELS_TEXT)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 874 |
|
| 875 |
+
# Single run
|
| 876 |
run_btn.click(
|
| 877 |
fn=run_single,
|
| 878 |
inputs=[
|
| 879 |
+
text, file, gt_text, gt_file, use_cleaning, use_fallback,
|
| 880 |
+
labels_text, sys_instr_tb, glossary_tb, fallback_tb,
|
| 881 |
+
repo, use_4bit, max_tokens, hf_token
|
| 882 |
],
|
| 883 |
outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False), prompt_preview, token_info],
|
| 884 |
)
|
| 885 |
|
|
|
|
|
|
|
|
|
|
| 886 |
with gr.Tab("Batch evaluation"):
|
| 887 |
with gr.Row():
|
| 888 |
with gr.Column(scale=3):
|
| 889 |
gr.Markdown("<div class='card'><div class='header'>ZIP input</div>")
|
| 890 |
zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
|
| 891 |
use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
|
| 892 |
+
use_fallback_b = gr.Checkbox(label="Enable multilingual fallback rule layer", value=True)
|
| 893 |
gr.Markdown("</div>")
|
| 894 |
with gr.Column(scale=2):
|
| 895 |
gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
|
|
|
|
| 897 |
use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
|
| 898 |
max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
|
| 899 |
hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
|
| 900 |
+
sys_instr_tb_b = gr.Textbox(label="System Instructions (editable for batch)", value=DEFAULT_SYSTEM_INSTRUCTIONS, lines=4)
|
| 901 |
+
glossary_tb_b = gr.Code(label="Label Glossary (JSON; editable for batch)", value=json.dumps(DEFAULT_LABEL_GLOSSARY, indent=2), language="json")
|
| 902 |
+
fallback_tb_b = gr.Code(label="Fallback Cues (Multilingual, JSON; editable for batch)", value=json.dumps(DEFAULT_FALLBACK_CUES, indent=2), language="json")
|
| 903 |
limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
|
| 904 |
run_batch_btn = gr.Button("Run Batch", variant="primary")
|
| 905 |
gr.Markdown("</div>")
|
|
|
|
| 914 |
|
| 915 |
run_batch_btn.click(
|
| 916 |
fn=run_batch,
|
| 917 |
+
inputs=[
|
| 918 |
+
zip_in, use_cleaning_b, use_fallback_b,
|
| 919 |
+
sys_instr_tb_b, glossary_tb_b, fallback_tb_b,
|
| 920 |
+
repo_b, use_4bit_b, max_tokens_b, hf_token_b, limit_files
|
| 921 |
+
],
|
| 922 |
outputs=[status, diag_b, df_out, csv_out],
|
| 923 |
)
|
| 924 |
|