Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -28,10 +28,10 @@ SPACE_CACHE.mkdir(parents=True, exist_ok=True)
|
|
| 28 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
|
| 30 |
GEN_CONFIG = GenerationConfig(
|
| 31 |
-
temperature=0.
|
| 32 |
top_p=1.0,
|
| 33 |
do_sample=False,
|
| 34 |
-
max_new_tokens=
|
| 35 |
)
|
| 36 |
|
| 37 |
# Official UBS label set (strict)
|
|
@@ -48,7 +48,7 @@ OFFICIAL_LABELS = [
|
|
| 48 |
OFFICIAL_LABELS_TEXT = "\n".join(OFFICIAL_LABELS)
|
| 49 |
|
| 50 |
# Per-label keyword cues (static prompt context to improve recall)
|
| 51 |
-
LABEL_KEYWORDS = {
|
| 52 |
"plan_contact": [
|
| 53 |
"call back", "get back to you", "i'll get back", "follow up",
|
| 54 |
"reach out", "contact later", "check in", "touch base", "remind",
|
|
@@ -87,10 +87,9 @@ LABEL_KEYWORDS = {
|
|
| 87 |
"portfolio size", "how much you own", "aggregate assets"
|
| 88 |
],
|
| 89 |
}
|
| 90 |
-
|
| 91 |
-
# Regex cues
|
| 92 |
-
|
| 93 |
-
REGEX_CUES = {
|
| 94 |
"schedule_meeting": [
|
| 95 |
r"\b(let'?s\s+)?meet(s|ing)?\b",
|
| 96 |
r"\bbook( a)? (time|slot|meeting)\b",
|
|
@@ -99,17 +98,16 @@ REGEX_CUES = {
|
|
| 99 |
r"\bfind a time\b",
|
| 100 |
],
|
| 101 |
"plan_contact": [
|
| 102 |
-
r"\b(
|
| 103 |
r"\bfollow\s*up\b",
|
| 104 |
r"\breach out\b",
|
| 105 |
r"\btouch base\b",
|
| 106 |
r"\bping you\b",
|
| 107 |
],
|
| 108 |
-
# Add more regexes for other labels if useful
|
| 109 |
}
|
| 110 |
|
| 111 |
# =========================
|
| 112 |
-
# Instructions (
|
| 113 |
# =========================
|
| 114 |
SYSTEM_PROMPT = (
|
| 115 |
"You are a precise banking assistant that extracts ACTIONABLE TASKS from "
|
|
@@ -308,7 +306,7 @@ class ModelWrapper:
|
|
| 308 |
|
| 309 |
@torch.inference_mode()
|
| 310 |
def generate(self, system_prompt: str, user_prompt: str) -> str:
|
| 311 |
-
# Build inputs as input_ids=... (avoid **tensor bug
|
| 312 |
if hasattr(self.tokenizer, "apply_chat_template"):
|
| 313 |
messages = [
|
| 314 |
{"role": "system", "content": system_prompt},
|
|
@@ -397,19 +395,41 @@ def evaluate_predictions(y_true: List[List[str]], y_pred: List[List[str]]) -> fl
|
|
| 397 |
return float(max(0.0, min(1.0, np.mean(per_sample))))
|
| 398 |
|
| 399 |
# =========================
|
| 400 |
-
# Fallback:
|
| 401 |
# =========================
|
| 402 |
def keyword_fallback(text: str, allowed: List[str]) -> Dict[str, Any]:
|
| 403 |
low = text.lower()
|
| 404 |
labels = []
|
| 405 |
tasks = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
for lab in allowed:
|
|
|
|
|
|
|
| 407 |
hits = []
|
| 408 |
for kw in LABEL_KEYWORDS.get(lab, []):
|
| 409 |
k = kw.lower()
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
i = low.find(k)
|
| 413 |
start = max(0, i - 40); end = min(len(text), i + len(k) + 40)
|
| 414 |
hits.append(text[start:end].strip())
|
| 415 |
if hits:
|
|
@@ -419,6 +439,7 @@ def keyword_fallback(text: str, allowed: List[str]) -> Dict[str, Any]:
|
|
| 419 |
"explanation": "Keyword match in transcript.",
|
| 420 |
"evidence": hits[0]
|
| 421 |
})
|
|
|
|
| 422 |
return {"labels": normalize_labels(labels), "tasks": tasks}
|
| 423 |
|
| 424 |
# =========================
|
|
@@ -431,6 +452,15 @@ def build_keyword_context(allowed: List[str]) -> str:
|
|
| 431 |
parts.append(f"- {lab}: " + (", ".join(kws) if kws else "(no default cues)"))
|
| 432 |
return "\n".join(parts)
|
| 433 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
def run_single(
|
| 435 |
transcript_text: str,
|
| 436 |
transcript_file, # filepath or file-like
|
|
@@ -443,7 +473,7 @@ def run_single(
|
|
| 443 |
use_4bit: bool,
|
| 444 |
max_input_tokens: int,
|
| 445 |
hf_token: str,
|
| 446 |
-
) -> Tuple[str, str, str, str, str, str, str]:
|
| 447 |
|
| 448 |
t0 = _now_ms()
|
| 449 |
|
|
@@ -453,7 +483,7 @@ def run_single(
|
|
| 453 |
raw_text = read_text_file_any(transcript_file)
|
| 454 |
raw_text = (raw_text or transcript_text or "").strip()
|
| 455 |
if not raw_text:
|
| 456 |
-
return "", "", "No transcript provided.", "", "", "", ""
|
| 457 |
|
| 458 |
text = clean_transcript(raw_text) if use_cleaning else raw_text
|
| 459 |
|
|
@@ -465,7 +495,7 @@ def run_single(
|
|
| 465 |
try:
|
| 466 |
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
| 467 |
except Exception as e:
|
| 468 |
-
return "", "", f"Model load failed: {e}", "", "", "", ""
|
| 469 |
|
| 470 |
# Truncate
|
| 471 |
trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
|
|
@@ -479,12 +509,18 @@ def run_single(
|
|
| 479 |
keyword_context=keyword_ctx,
|
| 480 |
)
|
| 481 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 482 |
# Generate
|
| 483 |
t1 = _now_ms()
|
| 484 |
try:
|
| 485 |
out = model.generate(SYSTEM_PROMPT, user_prompt)
|
| 486 |
except Exception as e:
|
| 487 |
-
return "", "", f"Generation error: {e}", "", "", "", ""
|
| 488 |
t2 = _now_ms()
|
| 489 |
|
| 490 |
parsed = robust_json_extract(out)
|
|
@@ -531,7 +567,6 @@ def run_single(
|
|
| 531 |
|
| 532 |
# Optional single-file scoring if GT provided
|
| 533 |
metrics = ""
|
| 534 |
-
true_labels = None
|
| 535 |
if gt_json_file or (gt_json_text and gt_json_text.strip()):
|
| 536 |
truth_obj = None
|
| 537 |
if gt_json_file:
|
|
@@ -563,7 +598,7 @@ def run_single(
|
|
| 563 |
else:
|
| 564 |
metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
|
| 565 |
|
| 566 |
-
return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics
|
| 567 |
|
| 568 |
# =========================
|
| 569 |
# Batch mode (ZIP with transcripts + truths)
|
|
@@ -725,7 +760,6 @@ custom_css = """
|
|
| 725 |
.header { font-weight: 700; font-size: 22px; margin-bottom: 4px; }
|
| 726 |
.subtle { color: rgba(255,255,255,.65); font-size: 14px; margin-bottom: 12px; }
|
| 727 |
hr.sep { border: none; border-top: 1px solid rgba(255,255,255,.08); margin: 10px 0 16px; }
|
| 728 |
-
.accordion-title { font-weight: 600; }
|
| 729 |
.gr-button { border-radius: 12px !important; }
|
| 730 |
"""
|
| 731 |
|
|
@@ -736,7 +770,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
|
|
| 736 |
with gr.Tab("Single transcript"):
|
| 737 |
with gr.Row():
|
| 738 |
with gr.Column(scale=3):
|
| 739 |
-
gr.Markdown("<div class='card'><div class='header'>Transcript</div>"
|
| 740 |
file = gr.File(
|
| 741 |
label="Drag & drop transcript (.txt / .md / .json)",
|
| 742 |
file_types=[".txt", ".md", ".json"],
|
|
@@ -745,7 +779,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
|
|
| 745 |
text = gr.Textbox(label="Or paste transcript", lines=10)
|
| 746 |
gr.Markdown("<hr class='sep'/>")
|
| 747 |
|
| 748 |
-
gr.Markdown("<div class='header'>Ground truth JSON (optional)</div>"
|
| 749 |
gt_file = gr.File(
|
| 750 |
label="Upload ground truth JSON (expects {'labels': [...]})",
|
| 751 |
file_types=[".json"],
|
|
@@ -754,7 +788,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
|
|
| 754 |
gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{\"labels\": [\"schedule_meeting\"]}')
|
| 755 |
gr.Markdown("</div>") # close card
|
| 756 |
|
| 757 |
-
gr.Markdown("<div class='card'><div class='header'>Preprocessing & heuristics</div>"
|
| 758 |
use_cleaning = gr.Checkbox(
|
| 759 |
label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)",
|
| 760 |
value=True,
|
|
@@ -765,7 +799,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
|
|
| 765 |
)
|
| 766 |
gr.Markdown("</div>")
|
| 767 |
|
| 768 |
-
gr.Markdown("<div class='card'><div class='header'>Allowed labels</div>"
|
| 769 |
labels_text = gr.Textbox(
|
| 770 |
label="Allowed Labels (one per line)",
|
| 771 |
value=OFFICIAL_LABELS_TEXT, # prefilled
|
|
@@ -775,25 +809,28 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
|
|
| 775 |
gr.Markdown("</div>")
|
| 776 |
|
| 777 |
with gr.Column(scale=2):
|
| 778 |
-
gr.Markdown("<div class='card'><div class='header'>Model & run</div>"
|
| 779 |
repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
|
| 780 |
use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
|
| 781 |
-
max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=
|
| 782 |
hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
|
|
|
|
| 783 |
run_btn = gr.Button("Run Extraction", variant="primary")
|
| 784 |
gr.Markdown("</div>")
|
| 785 |
|
| 786 |
-
gr.Markdown("<div class='card'><div class='header'>Outputs</div>"
|
| 787 |
summary = gr.Textbox(label="Summary", lines=12)
|
| 788 |
json_out = gr.Code(label="Strict JSON Output", language="json")
|
| 789 |
diag = gr.Textbox(label="Diagnostics", lines=8)
|
| 790 |
raw = gr.Textbox(label="Raw Model Output", lines=8)
|
|
|
|
|
|
|
| 791 |
gr.Markdown("</div>")
|
| 792 |
|
| 793 |
with gr.Row():
|
| 794 |
with gr.Column():
|
| 795 |
with gr.Accordion("Instructions used (system prompt)", open=False):
|
| 796 |
-
instr_md = gr.Markdown("")
|
| 797 |
with gr.Column():
|
| 798 |
with gr.Accordion("Context used (allowed labels + keyword cues)", open=True):
|
| 799 |
context_md = gr.Markdown("")
|
|
@@ -803,6 +840,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
|
|
| 803 |
return OFFICIAL_LABELS_TEXT
|
| 804 |
reset_btn.click(fn=_reset_labels, inputs=None, outputs=labels_text)
|
| 805 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 806 |
# single run
|
| 807 |
def _pack_context_md(allowed: str) -> str:
|
| 808 |
allowed_list = [ln.strip() for ln in (allowed or OFFICIAL_LABELS_TEXT).splitlines() if ln.strip()]
|
|
@@ -815,33 +859,32 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo
|
|
| 815 |
text, file, gt_text, gt_file, use_cleaning, use_keyword_fallback,
|
| 816 |
labels_text, repo, use_4bit, max_tokens, hf_token
|
| 817 |
],
|
| 818 |
-
outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False)],
|
| 819 |
)
|
| 820 |
|
| 821 |
-
#
|
| 822 |
-
instr_md.value = "```\n" + SYSTEM_PROMPT + "\n```"
|
| 823 |
context_md.value = _pack_context_md(OFFICIAL_LABELS_TEXT)
|
| 824 |
|
| 825 |
with gr.Tab("Batch evaluation"):
|
| 826 |
with gr.Row():
|
| 827 |
with gr.Column(scale=3):
|
| 828 |
-
gr.Markdown("<div class='card'><div class='header'>ZIP input</div>"
|
| 829 |
zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
|
| 830 |
use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
|
| 831 |
use_keyword_fallback_b = gr.Checkbox(label="Keyword fallback if model returns empty", value=True)
|
| 832 |
gr.Markdown("</div>")
|
| 833 |
with gr.Column(scale=2):
|
| 834 |
-
gr.Markdown("<div class='card'><div class='header'>Model & run</div>"
|
| 835 |
repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
|
| 836 |
use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
|
| 837 |
-
max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=
|
| 838 |
hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
|
| 839 |
limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
|
| 840 |
run_batch_btn = gr.Button("Run Batch", variant="primary")
|
| 841 |
gr.Markdown("</div>")
|
| 842 |
|
| 843 |
with gr.Row():
|
| 844 |
-
gr.Markdown("<div class='card'><div class='header'>Batch outputs</div>"
|
| 845 |
status = gr.Textbox(label="Status", lines=1)
|
| 846 |
diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
|
| 847 |
df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False)
|
|
|
|
| 28 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
|
| 30 |
GEN_CONFIG = GenerationConfig(
|
| 31 |
+
temperature=0.0,
|
| 32 |
top_p=1.0,
|
| 33 |
do_sample=False,
|
| 34 |
+
max_new_tokens=96, # small for speed; adjust if needed
|
| 35 |
)
|
| 36 |
|
| 37 |
# Official UBS label set (strict)
|
|
|
|
| 48 |
OFFICIAL_LABELS_TEXT = "\n".join(OFFICIAL_LABELS)
|
| 49 |
|
| 50 |
# Per-label keyword cues (static prompt context to improve recall)
|
| 51 |
+
LABEL_KEYWORDS: Dict[str, List[str]] = {
|
| 52 |
"plan_contact": [
|
| 53 |
"call back", "get back to you", "i'll get back", "follow up",
|
| 54 |
"reach out", "contact later", "check in", "touch base", "remind",
|
|
|
|
| 87 |
"portfolio size", "how much you own", "aggregate assets"
|
| 88 |
],
|
| 89 |
}
|
| 90 |
+
|
| 91 |
+
# Regex cues to catch phrasing variants
|
| 92 |
+
REGEX_CUES: Dict[str, List[str]] = {
|
|
|
|
| 93 |
"schedule_meeting": [
|
| 94 |
r"\b(let'?s\s+)?meet(s|ing)?\b",
|
| 95 |
r"\bbook( a)? (time|slot|meeting)\b",
|
|
|
|
| 98 |
r"\bfind a time\b",
|
| 99 |
],
|
| 100 |
"plan_contact": [
|
| 101 |
+
r"\b(i'?ll|get|got)\s+back to you\b",
|
| 102 |
r"\bfollow\s*up\b",
|
| 103 |
r"\breach out\b",
|
| 104 |
r"\btouch base\b",
|
| 105 |
r"\bping you\b",
|
| 106 |
],
|
|
|
|
| 107 |
}
|
| 108 |
|
| 109 |
# =========================
|
| 110 |
+
# Instructions (concise; concatenated to avoid string issues)
|
| 111 |
# =========================
|
| 112 |
SYSTEM_PROMPT = (
|
| 113 |
"You are a precise banking assistant that extracts ACTIONABLE TASKS from "
|
|
|
|
| 306 |
|
| 307 |
@torch.inference_mode()
|
| 308 |
def generate(self, system_prompt: str, user_prompt: str) -> str:
|
| 309 |
+
# Build inputs as input_ids=... (avoid **tensor bug)
|
| 310 |
if hasattr(self.tokenizer, "apply_chat_template"):
|
| 311 |
messages = [
|
| 312 |
{"role": "system", "content": system_prompt},
|
|
|
|
| 395 |
return float(max(0.0, min(1.0, np.mean(per_sample))))
|
| 396 |
|
| 397 |
# =========================
|
| 398 |
+
# Fallback: regex + keywords if model returns empty
|
| 399 |
# =========================
|
| 400 |
def keyword_fallback(text: str, allowed: List[str]) -> Dict[str, Any]:
|
| 401 |
low = text.lower()
|
| 402 |
labels = []
|
| 403 |
tasks = []
|
| 404 |
+
|
| 405 |
+
# Regex first
|
| 406 |
+
for lab in allowed:
|
| 407 |
+
patterns = REGEX_CUES.get(lab, [])
|
| 408 |
+
found = None
|
| 409 |
+
for pat in patterns:
|
| 410 |
+
m = re.search(pat, low)
|
| 411 |
+
if m:
|
| 412 |
+
i = m.start()
|
| 413 |
+
start = max(0, i - 40); end = min(len(text), i + len(m.group(0)) + 40)
|
| 414 |
+
found = text[start:end].strip()
|
| 415 |
+
break
|
| 416 |
+
if found:
|
| 417 |
+
labels.append(lab)
|
| 418 |
+
tasks.append({
|
| 419 |
+
"label": lab,
|
| 420 |
+
"explanation": "Regex cue matched in transcript.",
|
| 421 |
+
"evidence": found
|
| 422 |
+
})
|
| 423 |
+
|
| 424 |
+
# Keyword contains() as backstop
|
| 425 |
for lab in allowed:
|
| 426 |
+
if lab in labels:
|
| 427 |
+
continue
|
| 428 |
hits = []
|
| 429 |
for kw in LABEL_KEYWORDS.get(lab, []):
|
| 430 |
k = kw.lower()
|
| 431 |
+
i = low.find(k)
|
| 432 |
+
if i != -1:
|
|
|
|
| 433 |
start = max(0, i - 40); end = min(len(text), i + len(k) + 40)
|
| 434 |
hits.append(text[start:end].strip())
|
| 435 |
if hits:
|
|
|
|
| 439 |
"explanation": "Keyword match in transcript.",
|
| 440 |
"evidence": hits[0]
|
| 441 |
})
|
| 442 |
+
|
| 443 |
return {"labels": normalize_labels(labels), "tasks": tasks}
|
| 444 |
|
| 445 |
# =========================
|
|
|
|
| 452 |
parts.append(f"- {lab}: " + (", ".join(kws) if kws else "(no default cues)"))
|
| 453 |
return "\n".join(parts)
|
| 454 |
|
| 455 |
+
def warmup_model(model_repo: str, use_4bit: bool, hf_token: str) -> str:
|
| 456 |
+
t0 = _now_ms()
|
| 457 |
+
try:
|
| 458 |
+
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
| 459 |
+
_ = model.generate("Return JSON only.", '{"labels": [], "tasks": []}')
|
| 460 |
+
return f"Warm-up complete in {_now_ms() - t0} ms."
|
| 461 |
+
except Exception as e:
|
| 462 |
+
return f"Warm-up failed: {e}"
|
| 463 |
+
|
| 464 |
def run_single(
|
| 465 |
transcript_text: str,
|
| 466 |
transcript_file, # filepath or file-like
|
|
|
|
| 473 |
use_4bit: bool,
|
| 474 |
max_input_tokens: int,
|
| 475 |
hf_token: str,
|
| 476 |
+
) -> Tuple[str, str, str, str, str, str, str, str, str]:
|
| 477 |
|
| 478 |
t0 = _now_ms()
|
| 479 |
|
|
|
|
| 483 |
raw_text = read_text_file_any(transcript_file)
|
| 484 |
raw_text = (raw_text or transcript_text or "").strip()
|
| 485 |
if not raw_text:
|
| 486 |
+
return "", "", "No transcript provided.", "", "", "", "", "", ""
|
| 487 |
|
| 488 |
text = clean_transcript(raw_text) if use_cleaning else raw_text
|
| 489 |
|
|
|
|
| 495 |
try:
|
| 496 |
model = get_model(model_repo, (hf_token or "").strip() or None, use_4bit)
|
| 497 |
except Exception as e:
|
| 498 |
+
return "", "", f"Model load failed: {e}", "", "", "", "", "", ""
|
| 499 |
|
| 500 |
# Truncate
|
| 501 |
trunc = truncate_tokens(model.tokenizer, text, max_input_tokens)
|
|
|
|
| 509 |
keyword_context=keyword_ctx,
|
| 510 |
)
|
| 511 |
|
| 512 |
+
# Token info + prompt preview
|
| 513 |
+
transcript_tokens = len(model.tokenizer(trunc, add_special_tokens=False)["input_ids"])
|
| 514 |
+
prompt_tokens = len(model.tokenizer(user_prompt, add_special_tokens=False)["input_ids"])
|
| 515 |
+
token_info_text = f"Transcript tokens: {transcript_tokens} | Prompt tokens: {prompt_tokens}"
|
| 516 |
+
prompt_preview_text = "```\n" + user_prompt[:4000] + ("\n... (truncated)" if len(user_prompt) > 4000 else "") + "\n```"
|
| 517 |
+
|
| 518 |
# Generate
|
| 519 |
t1 = _now_ms()
|
| 520 |
try:
|
| 521 |
out = model.generate(SYSTEM_PROMPT, user_prompt)
|
| 522 |
except Exception as e:
|
| 523 |
+
return "", "", f"Generation error: {e}", "", "", "", prompt_preview_text, token_info_text, ""
|
| 524 |
t2 = _now_ms()
|
| 525 |
|
| 526 |
parsed = robust_json_extract(out)
|
|
|
|
| 567 |
|
| 568 |
# Optional single-file scoring if GT provided
|
| 569 |
metrics = ""
|
|
|
|
| 570 |
if gt_json_file or (gt_json_text and gt_json_text.strip()):
|
| 571 |
truth_obj = None
|
| 572 |
if gt_json_file:
|
|
|
|
| 598 |
else:
|
| 599 |
metrics = "Ground truth JSON missing or invalid; expected {'labels': [...]}."
|
| 600 |
|
| 601 |
+
return summary, json_out, diag, out.strip(), context_preview, instructions_preview, metrics, prompt_preview_text, token_info_text
|
| 602 |
|
| 603 |
# =========================
|
| 604 |
# Batch mode (ZIP with transcripts + truths)
|
|
|
|
| 760 |
.header { font-weight: 700; font-size: 22px; margin-bottom: 4px; }
|
| 761 |
.subtle { color: rgba(255,255,255,.65); font-size: 14px; margin-bottom: 12px; }
|
| 762 |
hr.sep { border: none; border-top: 1px solid rgba(255,255,255,.08); margin: 10px 0 16px; }
|
|
|
|
| 763 |
.gr-button { border-radius: 12px !important; }
|
| 764 |
"""
|
| 765 |
|
|
|
|
| 770 |
with gr.Tab("Single transcript"):
|
| 771 |
with gr.Row():
|
| 772 |
with gr.Column(scale=3):
|
| 773 |
+
gr.Markdown("<div class='card'><div class='header'>Transcript</div>")
|
| 774 |
file = gr.File(
|
| 775 |
label="Drag & drop transcript (.txt / .md / .json)",
|
| 776 |
file_types=[".txt", ".md", ".json"],
|
|
|
|
| 779 |
text = gr.Textbox(label="Or paste transcript", lines=10)
|
| 780 |
gr.Markdown("<hr class='sep'/>")
|
| 781 |
|
| 782 |
+
gr.Markdown("<div class='header'>Ground truth JSON (optional)</div>")
|
| 783 |
gt_file = gr.File(
|
| 784 |
label="Upload ground truth JSON (expects {'labels': [...]})",
|
| 785 |
file_types=[".json"],
|
|
|
|
| 788 |
gt_text = gr.Textbox(label="Or paste ground truth JSON", lines=6, placeholder='{\"labels\": [\"schedule_meeting\"]}')
|
| 789 |
gr.Markdown("</div>") # close card
|
| 790 |
|
| 791 |
+
gr.Markdown("<div class='card'><div class='header'>Preprocessing & heuristics</div>")
|
| 792 |
use_cleaning = gr.Checkbox(
|
| 793 |
label="Apply default cleaning (remove disclaimers, timestamps, speakers, footers)",
|
| 794 |
value=True,
|
|
|
|
| 799 |
)
|
| 800 |
gr.Markdown("</div>")
|
| 801 |
|
| 802 |
+
gr.Markdown("<div class='card'><div class='header'>Allowed labels</div>")
|
| 803 |
labels_text = gr.Textbox(
|
| 804 |
label="Allowed Labels (one per line)",
|
| 805 |
value=OFFICIAL_LABELS_TEXT, # prefilled
|
|
|
|
| 809 |
gr.Markdown("</div>")
|
| 810 |
|
| 811 |
with gr.Column(scale=2):
|
| 812 |
+
gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
|
| 813 |
repo = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
|
| 814 |
use_4bit = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
|
| 815 |
+
max_tokens = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
|
| 816 |
hf_token = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
|
| 817 |
+
warm_btn = gr.Button("Warm up model (load & compile kernels)")
|
| 818 |
run_btn = gr.Button("Run Extraction", variant="primary")
|
| 819 |
gr.Markdown("</div>")
|
| 820 |
|
| 821 |
+
gr.Markdown("<div class='card'><div class='header'>Outputs</div>")
|
| 822 |
summary = gr.Textbox(label="Summary", lines=12)
|
| 823 |
json_out = gr.Code(label="Strict JSON Output", language="json")
|
| 824 |
diag = gr.Textbox(label="Diagnostics", lines=8)
|
| 825 |
raw = gr.Textbox(label="Raw Model Output", lines=8)
|
| 826 |
+
prompt_preview = gr.Code(label="Prompt preview (user prompt sent)", language="markdown")
|
| 827 |
+
token_info = gr.Textbox(label="Token counts (transcript / prompt)", lines=2)
|
| 828 |
gr.Markdown("</div>")
|
| 829 |
|
| 830 |
with gr.Row():
|
| 831 |
with gr.Column():
|
| 832 |
with gr.Accordion("Instructions used (system prompt)", open=False):
|
| 833 |
+
instr_md = gr.Markdown("```\n" + SYSTEM_PROMPT + "\n```")
|
| 834 |
with gr.Column():
|
| 835 |
with gr.Accordion("Context used (allowed labels + keyword cues)", open=True):
|
| 836 |
context_md = gr.Markdown("")
|
|
|
|
| 840 |
return OFFICIAL_LABELS_TEXT
|
| 841 |
reset_btn.click(fn=_reset_labels, inputs=None, outputs=labels_text)
|
| 842 |
|
| 843 |
+
# warm-up
|
| 844 |
+
warm_btn.click(
|
| 845 |
+
fn=warmup_model,
|
| 846 |
+
inputs=[repo, use_4bit, hf_token],
|
| 847 |
+
outputs=diag,
|
| 848 |
+
)
|
| 849 |
+
|
| 850 |
# single run
|
| 851 |
def _pack_context_md(allowed: str) -> str:
|
| 852 |
allowed_list = [ln.strip() for ln in (allowed or OFFICIAL_LABELS_TEXT).splitlines() if ln.strip()]
|
|
|
|
| 859 |
text, file, gt_text, gt_file, use_cleaning, use_keyword_fallback,
|
| 860 |
labels_text, repo, use_4bit, max_tokens, hf_token
|
| 861 |
],
|
| 862 |
+
outputs=[summary, json_out, diag, raw, context_md, instr_md, gr.Textbox(visible=False), prompt_preview, token_info],
|
| 863 |
)
|
| 864 |
|
| 865 |
+
# initial context preview
|
|
|
|
| 866 |
context_md.value = _pack_context_md(OFFICIAL_LABELS_TEXT)
|
| 867 |
|
| 868 |
with gr.Tab("Batch evaluation"):
|
| 869 |
with gr.Row():
|
| 870 |
with gr.Column(scale=3):
|
| 871 |
+
gr.Markdown("<div class='card'><div class='header'>ZIP input</div>")
|
| 872 |
zip_in = gr.File(label="ZIP with transcripts (.txt) and truths (.json)", file_types=[".zip"], type="filepath")
|
| 873 |
use_cleaning_b = gr.Checkbox(label="Apply default cleaning", value=True)
|
| 874 |
use_keyword_fallback_b = gr.Checkbox(label="Keyword fallback if model returns empty", value=True)
|
| 875 |
gr.Markdown("</div>")
|
| 876 |
with gr.Column(scale=2):
|
| 877 |
+
gr.Markdown("<div class='card'><div class='header'>Model & run</div>")
|
| 878 |
repo_b = gr.Dropdown(label="Model", choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
|
| 879 |
use_4bit_b = gr.Checkbox(label="Use 4-bit (GPU only)", value=True)
|
| 880 |
+
max_tokens_b = gr.Slider(label="Max input tokens", minimum=1024, maximum=8192, step=512, value=2048)
|
| 881 |
hf_token_b = gr.Textbox(label="HF_TOKEN (only for gated models)", type="password", value=os.environ.get("HF_TOKEN",""))
|
| 882 |
limit_files = gr.Slider(label="Process at most N files (0 = all)", minimum=0, maximum=2000, step=10, value=0)
|
| 883 |
run_batch_btn = gr.Button("Run Batch", variant="primary")
|
| 884 |
gr.Markdown("</div>")
|
| 885 |
|
| 886 |
with gr.Row():
|
| 887 |
+
gr.Markdown("<div class='card'><div class='header'>Batch outputs</div>")
|
| 888 |
status = gr.Textbox(label="Status", lines=1)
|
| 889 |
diag_b = gr.Textbox(label="Batch diagnostics & metrics", lines=12)
|
| 890 |
df_out = gr.Dataframe(label="Per-file results (TP/FP/FN, latency)", interactive=False)
|