Spaces:
Sleeping
Sleeping
File size: 38,470 Bytes
7ad44fe 52300cf 7ad44fe 52300cf 7ad44fe fc5ec37 52300cf 7ad44fe 52300cf 1279cfd 52300cf 7ad44fe 52300cf fc5ec37 52300cf e52390b b03b022 810b118 e62dce0 52300cf d23bfc4 52300cf d23bfc4 52300cf fc5ec37 02757e2 fc5ec37 02757e2 0816810 02757e2 fc5ec37 02757e2 fc5ec37 02757e2 fc5ec37 02757e2 fc5ec37 02757e2 fc5ec37 52300cf 7ad44fe 52300cf 7ad44fe c5fe792 7ad44fe 046871c 52300cf 4134cad ce538a0 a67f0ce ce538a0 a67f0ce 4134cad 52300cf 7ad44fe 52300cf 7ad44fe 52300cf 7ad44fe 52300cf d1fc078 cd70be3 98cb0ee cd70be3 98cb0ee cd70be3 d1fc078 7ad44fe 810b118 7ad44fe 5ab66c5 d6abe69 810b118 98cb0ee 810b118 98cb0ee 810b118 98cb0ee 7ad44fe 52300cf 7ad44fe 52300cf c5fe792 ea3dae0 c5fe792 7ad44fe d1fc078 7ad44fe 1f324e5 7ad44fe 1f324e5 7ad44fe c5fe792 ea3dae0 7ad44fe 8f824ff 52300cf 7ad44fe 52300cf 7ad44fe 810b118 7ad44fe 810b118 7ad44fe 52300cf 7ad44fe 52300cf 7ad44fe 52300cf 7ad44fe 52300cf 7ad44fe 52300cf 7ad44fe 52300cf 1279cfd fc5ec37 58bb624 fc5ec37 1279cfd fc5ec37 1279cfd 7007888 7ad44fe 7007888 031189e 7007888 9fbb166 52300cf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 
269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 
769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 | """MedASR server for Hugging Face Spaces.
Hosts Google MedASR with **CTC beam search + radiology hotwords** so the
accuracy lifts well above the greedy-decoded floor (~6.6% WER) we hit when
the model ran in the browser. Google's published 4.6% WER number on
RAD-DICT uses beam-8 plus a 6-gram language model; we use beam-4
(DEFAULT_BEAM_WIDTH) plus a ~260-term radiology hotword list (Google's LM
has not been publicly released yet — when it ships we can drop it into
`decoder.decode(lm=...)`). An in-domain KenLM (radiology.bin) is also
downloaded at startup for optional shallow fusion / N-best rescoring.
Endpoints:
    GET  /health              — model + decoder readiness
    POST /transcribe          — multipart WebM/OGG audio (legacy path)
    POST /transcribe-pcm      — raw Float32 / Int16 PCM mono @ 16 kHz (preferred)
    POST /transcribe-date-pcm — experimental date-only sidecar (needs OPENAI_API_KEY)
    POST /openai-token        — mint OpenAI Realtime ephemeral token
    POST /deepgram-token      — mint Deepgram ephemeral token
`HF_TOKEN` must be set as a Space secret (Google MedASR is gated).
"""
import json as _json
import logging
import os
import re
import subprocess
import tempfile
import time
import urllib.error
import urllib.request
from datetime import date as _date
from io import BytesIO
from typing import Literal
import numpy as np
import requests
import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("medasr")

_MONTHS = (
    "January", "February", "March", "April", "May", "June",
    "July", "August", "September", "October", "November", "December",
)

# ---------------------------------------------------------------------------
# Spoken punctuation
# ---------------------------------------------------------------------------
# (spoken phrase, replacement) pairs applied in order. Multi-word phrases are
# listed before their single-word tails ("forward slash" before "slash") so
# the longer phrase wins. Non-ASCII output is spelled as a \u escape
# (U+2014 em dash) so re-encoding this file cannot corrupt what reaches the
# transcript.
_SPOKEN_PUNCTUATION: list[tuple[re.Pattern, str]] = [
    (re.compile(r"\bopen paren(?:thesis)?\b", re.IGNORECASE), "("),
    (re.compile(r"\bclose paren(?:thesis)?\b", re.IGNORECASE), ")"),
    (re.compile(r"\bopen bracket\b", re.IGNORECASE), "["),
    (re.compile(r"\bclose bracket\b", re.IGNORECASE), "]"),
    (re.compile(r"\bnew line\b", re.IGNORECASE), "\n"),
    (re.compile(r"\bforward slash\b", re.IGNORECASE), "/"),
    (re.compile(r"\bcomma\b", re.IGNORECASE), ","),
    (re.compile(r"\bperiod\b", re.IGNORECASE), "."),
    (re.compile(r"\bfull stop\b", re.IGNORECASE), "."),
    (re.compile(r"\bquestion mark\b", re.IGNORECASE), "?"),
    (re.compile(r"\bexclamation (?:mark|point)\b", re.IGNORECASE), "!"),
    (re.compile(r"\bcolon\b", re.IGNORECASE), ":"),
    (re.compile(r"\bsemicolon\b", re.IGNORECASE), ";"),
    (re.compile(r"\bhyphen\b", re.IGNORECASE), "-"),
    (re.compile(r"\bdash\b", re.IGNORECASE), " \u2014 "),
    (re.compile(r"\bslash\b", re.IGNORECASE), "/"),
    (re.compile(r"\bplus\b", re.IGNORECASE), "+"),
    (re.compile(r"\bampersand\b", re.IGNORECASE), "&"),
]


def _replace_spoken_punctuation(text: str) -> str:
    """Turn MedASR's decoded text into punctuated dictation.

    Steps, in order:
      1. Rewrite MedASR's sentence-boundary tokens: ``</s>`` becomes ". ",
         ``<s>`` and ``<unk>`` are dropped.
      2. Collapse brace-wrapped punctuation tokens such as ``{,}`` to the bare
         mark. The token may carry the SentencePiece word prefix U+2581, and
         the braces may be ASCII, white curly (U+2983/U+2984) or fullwidth
         (U+FF5B/U+FF5D).
      3. Replace spoken punctuation words ("comma", "new line", ...) using
         ``_SPOKEN_PUNCTUATION``.
      4. Tidy spaces around punctuation while preserving inserted newlines.

    Returns the cleaned text. Leading/trailing spaces, tabs and CRs are
    trimmed, but newlines are kept.
    """
    # MedASR uses special tokens for sentence boundaries during dictation.
    # Pyctcdecode emits them as literal text; rewrite to real punctuation.
    text = text.replace("</s>", ". ")
    text = text.replace("<s>", "")
    text = text.replace("<unk>", "")
    # DEBUG rather than INFO: this dumps raw dictation text on every segment.
    logger.debug("Pre-replace raw text repr: %r", text[:200])
    # Non-ASCII characters are \u escapes (not literals) so the pattern cannot
    # be silently corrupted by a re-encoding of this file.
    text = re.sub(
        r"\u2581?[\{\u2983\uFF5B]\u2581?([,.:;!?/\-+])\u2581?[\}\u2984\uFF5D]",
        r"\1",
        text,
    )
    # Second ASCII pass catches a pair exposed by the first, e.g. "{{,}}".
    text = re.sub(r"\{([,.:;!?/\-+])\}", r"\1", text)
    for pattern, replacement in _SPOKEN_PUNCTUATION:
        text = pattern.sub(replacement, text)
    # NB: \s would also match \n that we just inserted via "new line" — use
    # [ \t] so the newline survives to the client.
    text = re.sub(r"[ \t]+([,.:;!?)\]])", r"\1", text)
    text = re.sub(r"([(\[])[ \t]+", r"\1", text)
    text = re.sub(r" +", " ", text)
    # Trim spaces/tabs/CR only — preserving \n means a segment that's just
    # "new line" (now "\n") doesn't get wiped to empty and discarded by the
    # client as a zero-length segment.
    return re.sub(r"^[ \t\r]+|[ \t\r]+$", "", text)
# ---------------------------------------------------------------------------
# Date-only structured extraction
# ---------------------------------------------------------------------------
# Regex gates shared by the date helpers below. They decide whether a
# transcript is worth sending to the LLM date extractor, and they
# sanity-check what the extractor returns against what was actually said.

# Cue words that commonly introduce a date in dictation ("comparison",
# "prior", "dated", ...).
_DATE_CUE_RE = re.compile(
    r"\b(?:date|dated|prior|comparison|compared|performed|exam|examination|"
    r"study|from|since|seen|imaged)\b",
    re.IGNORECASE,
)
# Full month names from _MONTHS plus common abbreviations, with an optional
# trailing period; case-insensitive. NOTE(review): this also matches the
# modal verb "may" and the verb "march".
_MONTH_RE = re.compile(r"\b(?:" + "|".join(_MONTHS) + r"|Jan|Feb|Mar|Apr|Jun|"
    r"Jul|Aug|Sep|Sept|Oct|Nov|Dec)\.?\b", re.IGNORECASE)
# Numeric dates separated by "/", "." or "-": 05/07/2025, 5-7-25, 05.07.2025.
_NUMERIC_DATE_RE = re.compile(r"\b\d{1,2}[\/.-]\d{1,2}[\/.-]\d{2,4}\b")
# Four-digit years 1900-2099.
_YEAR_RE = re.compile(r"\b(?:19|20)\d{2}\b")
# Written or spoken day-of-month: "5", "5th", "fifth", "twenty first", ...
# Any bare 1-2 digit number (e.g. a measurement) also matches.
_DAY_HINT_RE = re.compile(
    r"\b(?:\d{1,2}(?:st|nd|rd|th)?|first|second|third|fourth|fifth|sixth|"
    r"seventh|eighth|ninth|tenth|eleventh|twelfth|thirteenth|fourteenth|"
    r"fifteenth|sixteenth|seventeenth|eighteenth|nineteenth|twentieth|"
    r"twenty\s+(?:one|two|three|four|five|six|seven|eight|nine|first|second|"
    r"third|fourth|fifth|sixth|seventh|eighth|ninth)|thirtieth|"
    r"thirty\s+first)\b",
    re.IGNORECASE,
)
# Spoken years: "two thousand ...", "twenty twenty", "nineteen <word>".
# "twenty" + a units word ("twenty one") also matches, presumably to catch
# short-form spoken years like "May twenty one" — TODO confirm.
_YEAR_WORD_RE = re.compile(
    r"\b(?:two\s+thousand|twenty\s+(?:twenty|nineteen|eighteen|seventeen|"
    r"sixteen|fifteen|fourteen|thirteen|twelve|eleven|ten|nine|eight|seven|"
    r"six|five|four|three|two|one)|nineteen\s+\w+)\b",
    re.IGNORECASE,
)
def _looks_date_like_for_extractor(text: str) -> bool:
    """Return True when *text* plausibly contains a date worth extracting.

    A cheap pre-filter so the slower LLM date extractor only runs on likely
    date-bearing segments. A segment qualifies if it has any of:
      * a numeric date;
      * a year (digits or words) together with a date cue word;
      * a month name together with a cue word, a numeric year or a day hint.
    """
    if not text:
        return False
    if _NUMERIC_DATE_RE.search(text):
        return True
    has_cue = _DATE_CUE_RE.search(text) is not None
    has_year = _YEAR_RE.search(text) is not None
    if has_cue and (has_year or _YEAR_WORD_RE.search(text)):
        return True
    if not _MONTH_RE.search(text):
        return False
    return has_cue or has_year or _DAY_HINT_RE.search(text) is not None
def _safe_json_object(text: str) -> dict:
text = (text or "").strip()
if not text:
return {}
if text.startswith("```"):
text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.IGNORECASE)
text = re.sub(r"\s*```$", "", text)
try:
obj = _json.loads(text)
return obj if isinstance(obj, dict) else {}
except Exception:
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
if not match:
return {}
try:
obj = _json.loads(match.group(0))
return obj if isinstance(obj, dict) else {}
except Exception:
return {}
def _valid_ddmmyyyy(value: str) -> bool:
    """Return True iff *value* is a real calendar date written DD/MM/YYYY.

    Both day and month must be two digits, and the year must be in
    1900-2099. Impossible dates such as 31/04 or 29/02 in a non-leap year
    are rejected.
    """
    parts = re.fullmatch(r"(\d{2})/(\d{2})/((?:19|20)\d{2})", value or "")
    if parts is None:
        return False
    day, month, year = (int(group) for group in parts.groups())
    try:
        _date(year, month, day)
    except ValueError:
        return False
    return True
def _numeric_only_date_is_ambiguous(text: str) -> bool:
    """Return True when *text* has a numeric date that could be DD/MM or MM/DD.

    A numeric date (05/07/2025, 5-7-25, ...) counts as ambiguous when both
    leading fields are in 1..12. Any month name anywhere in the text disables
    the check.

    NOTE(review): _MONTH_RE is case-insensitive, so the modal verb "may" (or
    the verb "march") also disables it. Confirm that is acceptable.
    """
    haystack = text or ""
    if _MONTH_RE.search(haystack):
        return False
    numeric_dates = re.findall(r"\b(\d{1,2})[\/.-](\d{1,2})[\/.-](\d{2,4})\b", haystack)
    return any(
        1 <= int(first) <= 12 and 1 <= int(second) <= 12
        for first, second, _year in numeric_dates
    )
def _has_explicit_year_hint(text: str) -> bool:
    """Return True if *text* states a year explicitly.

    Accepts any of:
      * a four-digit year;
      * a spoken year ("twenty twenty", "two thousand ...");
      * a month name plus "of" and a two-digit year ("May of '17").

    Hyphens, en dashes and em dashes are turned into spaces first, so
    hyphenated spoken years ("twenty-twenty") match the whitespace-separated
    word patterns.
    """
    # Non-ASCII characters are \u escapes so re-encoding this file can't
    # corrupt them: \u2013 en dash, \u2014 em dash, \u2019 curly apostrophe.
    text = re.sub(r"[-\u2013\u2014]", " ", text or "")
    return (
        bool(_YEAR_RE.search(text))
        or bool(_YEAR_WORD_RE.search(text))
        or bool(_MONTH_RE.search(text) and re.search(r"\bof\s+['\u2019]?\d{2}\b", text, re.IGNORECASE))
    )
def _has_explicit_day_hint(text: str) -> bool:
    """Return True if *text* contains an explicit day-of-month.

    A numeric date (05/07/2025) counts immediately. Otherwise two-digit years
    are scrubbed first: month-year phrases ("May of '17", "July of 24",
    "of twenty twenty") and apostrophe years ("'17"). The remainder is then
    searched for a written or spoken day ("5th", "fifth", "twenty first", ...).
    """
    if _NUMERIC_DATE_RE.search(text or ""):
        return True
    # Month-year phrases such as "May of '17" and "July of 24" contain a
    # two-digit year, not a spoken day. Strip those before looking for day
    # hints so the structured extractor cannot promote them to full dates.
    # Apostrophes may be ASCII ' or the curly U+2019, spelled as an escape so
    # re-encoding this file can't corrupt the character class.
    scrubbed = re.sub(r"\bof\s+['\u2019]?\d{2}\b", " ", text or "", flags=re.IGNORECASE)
    scrubbed = re.sub(r"\bof\s+(?:twenty|nineteen)\s+\w+\b", " ", scrubbed, flags=re.IGNORECASE)
    scrubbed = re.sub(r"['\u2019]\d{2}\b", " ", scrubbed)
    return bool(_DAY_HINT_RE.search(scrubbed))
def _validate_date_candidate(candidate: dict, transcript: str) -> dict | None:
    """Vet a structured-date candidate returned by the LLM extractor.

    Returns None unless all of the following hold:
      * the candidate is a dict that claims a date;
      * it is not flagged ambiguous;
      * its confidence is finite and at least OPENAI_DATE_CONFIDENCE_MIN
        (default 0.70);
      * the transcript has no DD/MM vs MM/DD-ambiguous numeric date;
      * ``normalized`` is one of the accepted formats below.

    Accepted formats for ``normalized``:
      * ``DD/MM/YYYY`` — a real calendar date. The transcript itself must
        carry both a day hint and a year hint.
      * ``Month YYYY`` — the transcript must carry a year hint. The month is
        re-capitalised.
      * ``YYYY`` — the transcript must carry a year hint.

    Granularity ("day"/"month"/"year") is derived from the accepted format;
    the LLM's own ``granularity`` field is not trusted.

    Returns:
        A dict with ``normalized``, ``granularity``, ``source_text`` and
        ``confidence`` (rounded to 3 places), or None.
    """
    import math

    if not isinstance(candidate, dict):
        return None
    if not candidate.get("has_date", False):
        return None
    if candidate.get("ambiguous", False):
        return None
    try:
        confidence = float(candidate.get("confidence", 0))
    except (TypeError, ValueError):
        confidence = 0
    # float() accepts "nan"/"inf" strings, and json.loads accepts NaN/Infinity
    # literals. NaN compares False against the threshold below and would slip
    # through; an infinite confidence is meaningless. Reject both.
    if not math.isfinite(confidence):
        return None
    if confidence < float(os.environ.get("OPENAI_DATE_CONFIDENCE_MIN", "0.70")):
        return None
    normalized = str(candidate.get("normalized") or "").strip()
    source_text = str(candidate.get("source_text") or "").strip()
    if not normalized:
        return None
    if _numeric_only_date_is_ambiguous(transcript):
        return None
    if _valid_ddmmyyyy(normalized):
        # A full date needs both a day and a year actually present in the
        # transcript, so neither can have been invented by the model.
        if not (_has_explicit_day_hint(transcript) and _has_explicit_year_hint(transcript)):
            return None
        granularity = "day"
    elif re.fullmatch(r"(?:" + "|".join(_MONTHS) + r")\s+(?:19|20)\d{2}", normalized, re.IGNORECASE):
        if not _has_explicit_year_hint(transcript):
            return None
        mon, yy = normalized.split()
        normalized = f"{mon.capitalize()} {yy}"
        granularity = "month"
    elif re.fullmatch(r"(?:19|20)\d{2}", normalized):
        if not _has_explicit_year_hint(transcript):
            return None
        granularity = "year"
    else:
        return None
    return {
        "normalized": normalized,
        "granularity": granularity,
        "source_text": source_text,
        "confidence": round(confidence, 3),
    }
def _extract_structured_date(transcript: str, api_key: str) -> dict | None:
    """Ask an OpenAI chat model to pull a single date out of *transcript*.

    The network call is skipped entirely unless the transcript looks
    date-like. The model (OPENAI_DATE_EXTRACT_MODEL, default gpt-4o-mini) is
    queried at temperature 0 in JSON mode with the first 1000 characters of
    the transcript. Its reply is parsed leniently and then vetted by
    _validate_date_candidate. Transport errors, non-2xx responses and
    unexpected payloads are logged and yield None.
    """
    if not _looks_date_like_for_extractor(transcript):
        return None

    system_prompt = (
        "You extract dates from short radiology dictation ASR text. Return JSON "
        "only. Do not rewrite medical terminology. Do not infer a missing day, "
        "month, or year. If a date is ambiguous, return has_date false or "
        "ambiguous true. Output full dates as DD/MM/YYYY, month-level dates as "
        "'Month YYYY', and year-only dates as 'YYYY'. Treat slash-only dates "
        "where both first numbers are <= 12 as ambiguous unless a month name is "
        "also present in the text. Never default a missing day to 01 and never "
        "default a missing year to the current year."
    )
    json_shape = (
        '{"has_date": boolean, "normalized": string|null, "granularity": '
        '"day|month|year|null", "source_text": string|null, '
        '"ambiguous": boolean, "confidence": number}'
    )
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": os.environ.get("OPENAI_DATE_EXTRACT_MODEL", "gpt-4o-mini"),
        "temperature": 0,
        "response_format": {"type": "json_object"},
        "messages": [
            {"role": "system", "content": system_prompt + "\nJSON shape: " + json_shape},
            {"role": "user", "content": transcript[:1000]},
        ],
        "max_tokens": 180,
    }

    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=request_headers,
            json=payload,
            timeout=20,
        )
    except requests.RequestException as e:
        logger.warning("Structured date extractor request failed: %s", e)
        return None

    if not response.ok:
        logger.warning("Structured date extractor OpenAI error %s: %s",
                       response.status_code, response.text[:500])
        return None

    try:
        reply = response.json()["choices"][0]["message"]["content"]
    except Exception:
        logger.warning("Structured date extractor returned unexpected payload")
        return None
    return _validate_date_candidate(_safe_json_object(reply), transcript)
# ---------------------------------------------------------------------------
# Radiology hotwords. Each term gets a likelihood boost during beam search,
# which is what fixes the "edema -> aa" / "Borderline -> Remaining" failures
# we saw in real dictation. Keep this list focused — too many hotwords pull
# accuracy back down on common non-medical words.
# ---------------------------------------------------------------------------
# Base list, grouped by anatomy/topic. It is passed to every decode call with
# DEFAULT_HOTWORD_WEIGHT. At import time the corpus-mined and month-name
# terms are appended to this same list object further down.
RADIOLOGY_HOTWORDS: list[str] = [
    # Cardiac / mediastinum
    "cardiomegaly", "borderline", "epicardial", "pericardial", "myocardial",
    "mediastinal", "hilar", "perihilar", "aortic", "aorta", "pulmonary",
    # Lung
    "edema", "effusion", "effusions", "consolidation", "atelectasis", "atelectatic",
    "pneumothorax", "pleural", "parenchyma", "parenchymal", "interstitial",
    "groundglass", "ground-glass", "opacity", "opacities", "nodule", "nodules",
    "mass", "empyema", "pneumonia", "hemothorax", "bronchiectasis", "fibrosis",
    "emphysema", "subpleural", "centrilobular", "tree-in-bud", "honeycombing",
    "reticulonodular",
    # Airways
    "tracheal", "trachea", "bronchial", "bronchovascular", "fissural", "fissure",
    "peribronchial", "peribronchovascular",
    # Chest wall / pleura / diaphragm
    "diaphragm", "diaphragmatic", "costophrenic", "subcutaneous", "thoracic",
    "subdiaphragmatic", "hemidiaphragm",
    # Abdominal anatomy
    "abdomen", "abdominal", "pelvis", "pelvic", "retroperitoneum", "retroperitoneal",
    "mesenteric", "mesentery", "paraaortic", "periaortic", "porta", "portal",
    "hepatic", "splenic", "renal", "adrenal", "pancreatic", "biliary",
    "gallbladder", "duodenal", "gastric", "intestinal", "colonic", "rectal",
    "uterine", "ovarian", "prostatic",
    # Brain / neuro
    "intracranial", "extracranial", "subdural", "epidural", "subarachnoid",
    "cerebral", "cerebellar", "brainstem", "thalamic", "lentiform", "caudate",
    "ventricular", "ventricle", "ventricles", "periventricular",
    # Spine
    "vertebra", "vertebrae", "vertebral", "lumbar", "cervical",
    # Vascular
    "stenosis", "occlusion", "occluded", "thrombosis", "embolism", "embolus",
    "aneurysm", "dissection", "atherosclerosis", "atherosclerotic",
    "calcified", "noncalcified", "calcification", "calcifications",
    # MR / CT signal terms
    "enhancement", "enhancing", "nonenhancing", "T1", "T2", "FLAIR", "DWI",
    "ADC", "STIR", "hypoechoic", "hyperechoic", "isoechoic", "anechoic",
    "echogenic", "hypoattenuating", "hyperattenuating", "hypodense", "hyperdense",
    "isointense", "hypointense", "hyperintense", "shadowing",
    # General findings
    "hemorrhage", "infarct", "infarction", "ischemia", "lesion", "lesions",
    "lymphadenopathy", "lymph", "hematoma", "fluid", "edematous", "swelling",
    # Common verbiage
    "unremarkable", "compatible", "consistent", "suggestive", "concerning",
    "noted", "demonstrates", "demonstrated", "evidence", "prominent",
]
# ---------------------------------------------------------------------------
# MedASR model + CTC decoder
# ---------------------------------------------------------------------------
# All three are populated by load_model() at startup and are None until then.
# `decoder` is the pyctcdecode beam-search decoder.
model = processor = decoder = None
# Inference device; load_model() switches it to "cuda" when a GPU is present.
DEVICE = "cpu"
# Beam width and hotword boost used for every decode call.
DEFAULT_BEAM_WIDTH = 4
DEFAULT_HOTWORD_WEIGHT = 5.0
# Hotwords mined from the 731K-report corpus that weren't in the original
# RADIOLOGY_HOTWORDS list — high-frequency medical terms the decoder
# probably underweights today. Includes specific known-failure terms from
# the offline test set (homologue, aorticopulmonary, intercalated).
# Terms already present in RADIOLOGY_HOTWORDS are skipped when merged below.
CORPUS_HOTWORDS: list[str] = [
    # Bigger anatomy / pathology nouns
    "indications", "indication", "abnormality", "abnormalities", "abnormal",
    "fracture", "fractures", "narrowing", "thickening", "dilatation",
    "enlargement", "enlarged", "compression", "protrusion", "obstruction",
    "dislocation", "distortion", "osteoarthritis", "osteophyte", "endplate",
    "ligament", "ligaments", "cartilage", "vasculature", "vascular",
    "arteries", "pancreas", "adrenals", "mediastinum", "meniscus",
    "shoulder", "foramina", "foraminal", "silhouette", "alignment",
    "paraspinal", "multilevel", "multiplanar", "multidetector",
    "hydronephrosis", "arthropathy", "hypertrophy", "adenopathy",
    "microcalcifications", "fibroglandular", "heterogeneously",
    # Modifiers + descriptors
    "bilateral", "moderate", "anterior", "posterior", "inferior", "sagittal",
    "proximal", "scattered", "visualized", "measuring", "diameter", "thickness",
    "reformatted", "reconstruction", "modulation", "administered", "supplemental",
    "interval", "diagnostic", "intravenous", "suspicious", "malignancy",
    "degenerative", "coronary", "sonographic", "ultrasound", "tomosynthesis",
    "mammogram", "mammography", "mammographic", "radiograph", "radiology",
    "architectural", "migrated",
    # Specific high-value terms missed on our offline test set
    "homologue", "homologous", "aorticopulmonary", "intercalated",
    "modic", "spondylolisthesis",
]
# Calendar months — the decoder badly mis-hears spoken dates ("july" came
# back as "ul"/"Ja" in testing). Boosting the month names is the cheapest
# lever to try before considering a fine-tune. Lowercase, January..December.
DATE_HOTWORDS: list[str] = (
    "january february march april may june "
    "july august september october november december"
).split()
# Extend RADIOLOGY_HOTWORDS in place with the corpus-mined and month-name
# terms. Anything already in the base list is skipped, repeats among the
# additions are dropped, and first-seen order is kept. The temporary set is
# deleted afterwards so no helper names linger at module level.
_base_terms = set(RADIOLOGY_HOTWORDS)
RADIOLOGY_HOTWORDS.extend(
    dict.fromkeys(
        term for term in CORPUS_HOTWORDS + DATE_HOTWORDS if term not in _base_terms
    )
)
del _base_terms
def _patch_lasr_feature_extractor():
    """Let LasrFeatureExtractor accept a ``center`` argument on old transformers.

    transformers' Lasr feature extractor changed signatures across versions.
    The old ``_torch_extract_fbank_features`` took no ``center`` arg; the new
    one does. When the installed version lacks it, the method is wrapped so a
    ``center`` argument is accepted and ignored.

    This is a best-effort shim and does nothing when:
      * transformers has no Lasr module;
      * the method already accepts ``center``;
      * the private method no longer exists under that name.
    """
    try:
        from transformers.models.lasr.feature_extraction_lasr import LasrFeatureExtractor
    except ImportError:
        return
    import inspect

    try:
        original = LasrFeatureExtractor._torch_extract_fbank_features
    except AttributeError:
        # Private API renamed/removed in this transformers version; presumably
        # nothing calls the old name any more, so don't fail startup over it.
        logger.warning("LasrFeatureExtractor has no _torch_extract_fbank_features; "
                       "skipping 'center' monkey-patch")
        return
    if "center" in inspect.signature(original).parameters:
        return

    # Deliberately no functools.wraps: it would set __wrapped__, making
    # inspect.signature() report the old signature again.
    def _patched(self, waveform, device="cpu", center=True):
        return original(self, waveform, device)

    LasrFeatureExtractor._torch_extract_fbank_features = _patched
    logger.info("Applied LasrFeatureExtractor monkey-patch for 'center' arg")
def _ensure_kenlm():
    """Download the radiology KenLM binary to KENLM_PATH at startup.

    Source is KENLM_URL (default: radiology.bin in the
    chirag18/radiology-stt-assets HF repo). Destination is KENLM_PATH
    (default /app/radiology.bin).

    The file is re-downloaded on every startup, even when a copy exists, so
    "the LM you uploaded is the LM you serve". Doing this at startup rather
    than in the Dockerfile means build-time network restrictions can't fail
    the image, and the LM can be hot-swapped on the HF repo without
    rebuilding.

    The download is streamed to "<KENLM_PATH>.part" and atomically moved over
    the destination only once it completes. A failed download therefore
    keeps whatever file was already there instead of having deleted it up
    front. Socket operations time out after KENLM_DOWNLOAD_TIMEOUT seconds
    (default 300), so a stalled CDN can't hang startup indefinitely.
    Download failures are logged, never raised.
    """
    import shutil

    kenlm_path = os.environ.get("KENLM_PATH", "/app/radiology.bin")
    url = os.environ.get(
        "KENLM_URL",
        "https://huggingface.co/chirag18/radiology-stt-assets/resolve/main/radiology.bin",
    )
    timeout_s = float(os.environ.get("KENLM_DOWNLOAD_TIMEOUT", "300"))
    # Always re-download on startup. The earlier size-check approach was
    # flaky (urllib HEAD with HF's xet/CDN redirect chain was unreliable —
    # it ended up trusting the stale local file).
    previous_exists = os.path.exists(kenlm_path)
    if previous_exists:
        logger.info("Replacing existing KenLM at %s (%.1f MB) with a fresh download.",
                    kenlm_path, os.path.getsize(kenlm_path) / 1048576)
    logger.info("Downloading KenLM from %s ...", url)
    t0 = time.monotonic()
    tmp = kenlm_path + ".part"
    try:
        with urllib.request.urlopen(url, timeout=timeout_s) as resp, open(tmp, "wb") as out:
            shutil.copyfileobj(resp, out, 1 << 20)
        # Atomic swap: readers see either the old file or the complete new one.
        os.replace(tmp, kenlm_path)
    except Exception as e:
        try:
            os.remove(tmp)
        except OSError:
            pass
        if previous_exists:
            logger.warning("KenLM download failed (%s) — keeping the existing file at %s.",
                           e, kenlm_path)
        else:
            logger.warning("KenLM download failed (%s) — server will fall back to "
                           "non-LM beam search.", e)
        return
    size_mb = os.path.getsize(kenlm_path) / 1048576
    logger.info("KenLM downloaded: %.1f MB in %.1fs", size_mb, time.monotonic() - t0)
def _build_decoder():
    """Construct a pyctcdecode beam-search decoder from the model's vocab.

    Labels come from the model's CTC output dimension, with the tokenizer's
    pad token renamed to "" so pyctcdecode treats it as the CTC blank.

    If a KenLM binary is present at KENLM_PATH (default /app/radiology.bin)
    and KENLM_ALPHA > 0, it's used for shallow fusion at decode time —
    boosting candidate transcriptions that contain likely radiology word
    sequences. The LM was trained on 731K in-domain reports (~111M words).
    Tunable via env vars KENLM_ALPHA (LM weight, default 0.05) and
    KENLM_BETA (word-insertion bonus, default 1.0). Otherwise a plain
    non-LM beam-search decoder is returned.

    Requires the module-level ``model`` and ``processor`` to be loaded.
    """
    from pyctcdecode import build_ctcdecoder

    # Match the decoder labels to the model's actual CTC output dimension,
    # NOT the tokenizer's full vocab — the tokenizer often includes special
    # tokens (pad, bos, eos, ...) that aren't part of the CTC head. A label
    # count mismatch makes pyctcdecode raise on every decode call.
    output_dim = model.config.vocab_size
    labels = processor.tokenizer.convert_ids_to_tokens(list(range(output_dim)))
    # Pyctcdecode auto-inserts a CTC blank if it doesn't see one it
    # recognizes ("", "<pad>", or "<blank>"). MedASR's blank is the
    # tokenizer's pad token, often named something else; rename it to ""
    # so pyctcdecode treats it as blank instead of growing the vocab by 1.
    blank_id = processor.tokenizer.pad_token_id
    if blank_id is not None and 0 <= blank_id < len(labels):
        labels[blank_id] = ""
    logger.info("Decoder labels: %d, blank at id=%s, sample=%s", len(labels), blank_id, labels[:6])
    # Optional KenLM shallow fusion. Setting KENLM_ALPHA=0 (or any value <= 0)
    # COMPLETELY bypasses the LM — pyctcdecode's alpha=0 still applies LM-
    # related side effects on beam allocation/vocab, so to truly disable we
    # call build_ctcdecoder without kenlm_model_path at all.
    kenlm_path = os.environ.get("KENLM_PATH", "/app/radiology.bin")
    alpha = float(os.environ.get("KENLM_ALPHA", "0.05"))
    beta = float(os.environ.get("KENLM_BETA", "1.0"))
    if alpha > 0 and os.path.exists(kenlm_path):
        size_mb = os.path.getsize(kenlm_path) / 1048576
        logger.info("Loading KenLM (%.0f MB) from %s, alpha=%.2f, beta=%.2f",
                    size_mb, kenlm_path, alpha, beta)
        return build_ctcdecoder(labels, kenlm_model_path=kenlm_path,
                                alpha=alpha, beta=beta)
    if not os.path.exists(kenlm_path):
        logger.info("No KenLM at %s — using non-LM beam-search decoder.", kenlm_path)
    else:
        logger.info("KENLM_ALPHA<=0 — bypassing LM (non-LM beam-search decoder).")
    return build_ctcdecoder(labels)
def load_model():
    """Load MedASR weights and build the beam-search decoder.

    Populates the module globals ``processor``, ``model``, ``DEVICE`` and
    ``decoder``. Uses CUDA when available; otherwise runs on CPU with 4 torch
    threads. The KenLM file is fetched before the decoder is built.

    Raises:
        RuntimeError: if the HF_TOKEN secret is missing (google/medasr is gated).
    """
    global model, processor, decoder, DEVICE

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError("HF_TOKEN secret required for gated MedASR model")

    # Shim the Lasr feature-extractor signature before the processor is loaded.
    _patch_lasr_feature_extractor()
    from transformers import AutoModelForCTC, AutoProcessor

    logger.info("Loading MedASR model...")
    processor = AutoProcessor.from_pretrained("google/medasr", token=hf_token)
    model = AutoModelForCTC.from_pretrained("google/medasr", token=hf_token)
    model.eval()

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    if DEVICE != "cuda":
        torch.set_num_threads(4)
        logger.info("Running on CPU (4 threads)")
    else:
        model = model.to("cuda")
        logger.info("Model moved to CUDA (fp32). GPU=%s", torch.cuda.get_device_name(0))

    logger.info("Building CTC beam-search decoder...")
    _ensure_kenlm()  # fetch the LM file before the decoder looks for it
    decoder = _build_decoder()
    logger.info(
        "MedASR ready (vocab=%d, beam=%d, hotwords=%d).",
        len(processor.tokenizer.get_vocab()),
        DEFAULT_BEAM_WIDTH,
        len(RADIOLOGY_HOTWORDS),
    )
# ---------------------------------------------------------------------------
# Audio in -> logits -> text
# ---------------------------------------------------------------------------
_rescore_lm = None  # kenlm.Model loaded lazily for N-best rescoring
_rescore_lm_failed = False  # set once kenlm import/load fails; stops retries


def _get_rescore_lm():
    """Return the KenLM model used for N-best rescoring, loading it lazily.

    Independent of the decoder — we score complete candidate strings rather
    than influencing the beam search at every CTC frame (which empirically
    broke decoding by eating word-start characters).

    Returns None, so callers fall back to plain beam search, when:
      * the LM file at KENLM_PATH is absent (checked again on the next call);
      * the ``kenlm`` package can't be imported or the file can't be loaded
        (logged once and not retried).
    """
    global _rescore_lm, _rescore_lm_failed
    if _rescore_lm is not None or _rescore_lm_failed:
        return _rescore_lm
    path = os.environ.get("KENLM_PATH", "/app/radiology.bin")
    if not os.path.exists(path):
        return None
    try:
        import kenlm
        _rescore_lm = kenlm.Model(path)
    except Exception as e:
        # Best-effort feature: never let a missing package or corrupt LM file
        # turn every transcription request into a 500.
        _rescore_lm_failed = True
        logger.warning("N-best rescoring disabled: could not load KenLM from %s (%s)", path, e)
        return None
    logger.info("N-best rescoring LM loaded from %s", path)
    return _rescore_lm
def _decode_logits(logits_np: np.ndarray) -> str:
    """Beam-search decode one utterance's logits into punctuated text.

    The radiology hotword boost is always applied. When RESCORE_ALPHA > 0 and
    the rescoring LM is available, the top RESCORE_N (default 8) beam
    hypotheses are re-ranked by ``logit_score + RESCORE_ALPHA * kenlm_score``.
    Each hypothesis is a complete transcription the acoustic model finds
    plausible, and the best-ranked one wins. This sidesteps the
    shallow-fusion-with-CTC interference that broke per-frame integration.
    Otherwise, or if no beam is selected, a plain hotword beam-search decode
    is used.

    NOTE(review): the hotword bonus is presumably part of each beam's
    lm_score (entry[4]), which this ranking ignores, so hotwords only affect
    which beams survive, not the final pick. Also, if the decoder was built
    with KenLM (KENLM_ALPHA > 0), the LM shapes the search and is applied
    again here. Confirm both are intended.
    """
    alpha = float(os.environ.get("RESCORE_ALPHA", "0"))
    top_n = int(os.environ.get("RESCORE_N", "8"))
    lm = _get_rescore_lm() if alpha > 0 else None
    if lm is not None:
        candidates = decoder.decode_beams(
            logits_np,
            beam_width=max(top_n, DEFAULT_BEAM_WIDTH),
            hotwords=RADIOLOGY_HOTWORDS,
            hotword_weight=DEFAULT_HOTWORD_WEIGHT,
        )
        # Each entry: (text, last_word_state, frames, logit_score, lm_score)
        winner, winner_score = None, -float("inf")
        for beam in candidates[:top_n]:
            hypothesis = beam[0]
            acoustic = beam[3] if len(beam) > 3 else 0.0
            # Score the lowercased full string against the radiology LM.
            lm_score = lm.score(hypothesis.lower(), bos=True, eos=True)
            total = acoustic + alpha * lm_score
            if total > winner_score:
                winner, winner_score = hypothesis, total
        if winner is not None:
            return _replace_spoken_punctuation(winner.strip())
    plain = decoder.decode(
        logits_np,
        beam_width=DEFAULT_BEAM_WIDTH,
        hotwords=RADIOLOGY_HOTWORDS,
        hotword_weight=DEFAULT_HOTWORD_WEIGHT,
    )
    return _replace_spoken_punctuation(plain.strip())
def _samples_to_text(samples: np.ndarray, sample_rate: int) -> str:
    """Transcribe mono audio samples with MedASR; returns "" for empty input.

    Audio not at 16 kHz is first resampled with librosa, which is imported
    lazily on that path only. The forward pass runs under
    ``torch.inference_mode()`` on DEVICE. The first item's logits are moved
    to CPU as float32 for beam-search decoding.
    """
    if not samples.size:
        return ""
    if sample_rate != 16000:
        # Inputs at the wrong rate would silently produce gibberish — resample.
        import librosa  # only needed if a non-16k client sneaks in
        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
    features = processor(samples, sampling_rate=16000, return_tensors="pt", padding=True)
    if DEVICE == "cuda":
        features = {name: tensor.to("cuda") for name, tensor in features.items()}
    with torch.inference_mode():
        batch_logits = model(**features).logits
    first_item = batch_logits[0].float().cpu().numpy()
    return _decode_logits(first_item)
def convert_to_wav(audio_bytes: bytes) -> bytes:
    """Transcode container audio (WebM/OGG/...) to 16 kHz mono WAV via ffmpeg.

    The input is written to a temporary ``.webm`` file and converted to a
    sibling ``.wav`` file. Both temp files are removed afterwards, whether or
    not the conversion succeeds.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited non-zero.
        subprocess.TimeoutExpired: ffmpeg ran longer than 30 s.
    """
    with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as handle:
        handle.write(audio_bytes)
        input_path = handle.name
    output_path = input_path.rsplit(".", 1)[0] + ".wav"
    command = [
        "ffmpeg", "-y",
        "-i", input_path,
        "-ar", "16000",
        "-ac", "1",
        "-f", "wav",
        output_path,
    ]
    try:
        subprocess.run(command, capture_output=True, check=True, timeout=30)
        with open(output_path, "rb") as wav_file:
            wav_bytes = wav_file.read()
    finally:
        for leftover in (input_path, output_path):
            try:
                os.unlink(leftover)
            except OSError:
                pass
    return wav_bytes
# ---------------------------------------------------------------------------
# FastAPI
# ---------------------------------------------------------------------------
app = FastAPI(title="MedASR Server")
# NOTE(review): wildcard CORS lets any web origin call every endpoint,
# including the token-minting ones listed in the module docstring. Confirm
# this is intended, or restrict allow_origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
def startup():
    """Load MedASR, fetch the KenLM file and build the decoder once at app startup."""
    load_model()
@app.get("/health")
def health():
    """Report model/decoder readiness and the KenLM configuration.

    ``kenlm_loaded`` reflects whether the LM file exists at KENLM_PATH. It
    does not by itself mean the decoder fused it: KENLM_ALPHA <= 0 bypasses
    the LM in _build_decoder(). The alpha/beta defaults below mirror the ones
    _build_decoder() uses (0.05 / 1.0), so an unset env var reports the value
    actually in effect.
    """
    kenlm_path = os.environ.get("KENLM_PATH", "/app/radiology.bin")
    kenlm_loaded = os.path.exists(kenlm_path)
    return {
        "status": "ok",
        "model_loaded": model is not None,
        "decoder_ready": decoder is not None,
        "beam_width": DEFAULT_BEAM_WIDTH,
        "hotwords": len(RADIOLOGY_HOTWORDS),
        "kenlm_loaded": kenlm_loaded,
        "kenlm_size_mb": round(os.path.getsize(kenlm_path) / 1048576, 1) if kenlm_loaded else 0,
        "kenlm_alpha": float(os.environ.get("KENLM_ALPHA", "0.05")) if kenlm_loaded else None,
        "kenlm_beta": float(os.environ.get("KENLM_BETA", "1.0")) if kenlm_loaded else None,
    }
@app.post("/transcribe")
async def transcribe_audio(audio: UploadFile = File(...)):
    """Legacy endpoint: accepts WebM/OGG via FormData.

    Decodes via ffmpeg, then runs the same beam-search pipeline as
    /transcribe-pcm. Returns ``{"text", "duration_seconds"}``.

    Raises HTTPException 400 for an empty upload or audio that ffmpeg can't
    decode (or times out on), and 413 for uploads above 20 MB.
    """
    contents = await audio.read()
    if len(contents) == 0:
        raise HTTPException(400, "Empty audio file")
    if len(contents) > 20 * 1024 * 1024:
        raise HTTPException(413, "Audio too large (max 20 MB)")
    t0 = time.monotonic()
    try:
        wav_bytes = convert_to_wav(contents)
    except subprocess.CalledProcessError as e:
        # A corrupt or unsupported upload is a client error, not a 500.
        stderr_tail = (e.stderr or b"")[-300:].decode("utf-8", "replace")
        logger.warning("ffmpeg could not decode upload: %s", stderr_tail)
        raise HTTPException(400, "Could not decode audio") from e
    except subprocess.TimeoutExpired as e:
        logger.warning("ffmpeg timed out decoding a %d-byte upload", len(contents))
        raise HTTPException(400, "Audio decoding timed out") from e
    waveform, sr = sf.read(BytesIO(wav_bytes), dtype="float32")
    if waveform.ndim > 1:
        # Downmix any multi-channel result to mono.
        waveform = waveform.mean(axis=1)
    text = _samples_to_text(waveform, sr)
    elapsed = time.monotonic() - t0
    logger.info("Transcribed (webm) in %.2fs: '%s'", elapsed, text[:100])
    return {"text": text, "duration_seconds": round(elapsed, 2)}
@app.post("/transcribe-pcm")
async def transcribe_pcm(
    audio: UploadFile = File(...),
    sample_rate: int = Form(16000),
    pcm_format: Literal["float32", "int16"] = Form("float32"),
):
    """Preferred endpoint: accepts raw mono PCM at 16 kHz.

    The browser sends the bytes of a Float32Array directly — no ffmpeg, no
    encoder overhead, no transcoder lossiness. Per-segment latency is
    dominated by the model forward pass (~80–300 ms for typical sentence
    audio on the Space's CPU). Non-16 kHz input is resampled.

    Raises HTTPException 400 for empty audio, a non-positive sample_rate, or
    a byte length that isn't a whole number of samples for ``pcm_format``.
    Raises 413 for payloads above 32 MB.
    """
    contents = await audio.read()
    if len(contents) == 0:
        raise HTTPException(400, "Empty audio")
    if len(contents) > 32 * 1024 * 1024:
        raise HTTPException(413, "PCM too large (max 32 MB)")
    if sample_rate <= 0:
        raise HTTPException(400, "sample_rate must be positive")
    # np.frombuffer raises ValueError on a partial trailing sample, which
    # would otherwise surface as a 500 for what is a malformed request.
    bytes_per_sample = 2 if pcm_format == "int16" else 4
    if len(contents) % bytes_per_sample:
        raise HTTPException(
            400,
            f"PCM length {len(contents)} bytes is not a multiple of "
            f"{bytes_per_sample} ({pcm_format})",
        )
    if pcm_format == "int16":
        samples = np.frombuffer(contents, dtype=np.int16).astype(np.float32) / 32768.0
    else:
        # .copy(): frombuffer returns a read-only view over the request bytes.
        samples = np.frombuffer(contents, dtype=np.float32).copy()
    t0 = time.monotonic()
    text = _samples_to_text(samples, sample_rate)
    elapsed = time.monotonic() - t0
    logger.info("Transcribed (pcm, %d samples @%d Hz) in %.2fs: '%s'",
                samples.size, sample_rate, elapsed, text[:100])
    return {"text": text, "duration_seconds": round(elapsed, 2), "samples": int(samples.size)}
@app.post("/transcribe-date-pcm")
async def transcribe_date_pcm(
audio: UploadFile = File(...),
sample_rate: int = Form(16000),
pcm_format: Literal["float32", "int16"] = Form("float32"),
):
"""Experimental date-only sidecar transcription.
This endpoint deliberately does NOT replace /transcribe-pcm. It exists to
test a safer architecture: keep MedASR as the radiology transcript, but ask
a general ASR model to identify date phrases from the same short audio
segment. The client/server can then replace only a high-confidence date
span, leaving all radiology wording and measurements from MedASR untouched.
"""
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
raise HTTPException(500, "OPENAI_API_KEY not configured on the Space")
contents = await audio.read()
if len(contents) == 0:
raise HTTPException(400, "Empty audio")
if len(contents) > 32 * 1024 * 1024:
raise HTTPException(413, "PCM too large (max 32 MB)")
if pcm_format == "int16":
samples = np.frombuffer(contents, dtype=np.int16).astype(np.float32) / 32768.0
else:
samples = np.frombuffer(contents, dtype=np.float32).copy()
wav = BytesIO()
sf.write(wav, samples, sample_rate, format="WAV", subtype="PCM_16")
wav.seek(0)
prompt = (
"Radiology dictation. Transcribe dates carefully. Preserve month names, "
"days, and four-digit years. Examples: July twenty fifth twenty twenty "
"six; February fifth two thousand eighteen; prior study dated March 12 2025."
)
t0 = time.monotonic()
try:
r = requests.post(
"https://api.openai.com/v1/audio/transcriptions",
headers={"Authorization": f"Bearer {api_key}"},
data={
"model": os.environ.get("OPENAI_TRANSCRIBE_MODEL",
"gpt-4o-mini-transcribe"),
"response_format": "json",
"prompt": prompt,
},
files={"file": ("segment.wav", wav.getvalue(), "audio/wav")},
timeout=60,
)
except requests.RequestException as e:
raise HTTPException(502, f"OpenAI transcription request failed: {e}")
elapsed = time.monotonic() - t0
if not r.ok:
raise HTTPException(r.status_code, r.text[:1000])
data = r.json()
text = (data.get("text") or "").strip()
structured_date = None
if os.environ.get("OPENAI_DATE_EXTRACTOR", "0").lower() in {"1", "true", "on", "yes"}:
structured_date = _extract_structured_date(text, api_key)
logger.info("OpenAI date sidecar (pcm, %d samples @%d Hz) in %.2fs: '%s'",
samples.size, sample_rate, elapsed, text[:100])
return {
"text": text,
"date": structured_date,
"duration_seconds": round(elapsed, 2),
"samples": int(samples.size),
}
@app.post("/extract-date-text")
async def extract_date_text(payload: dict):
    """Date-only JSON extractor for offline evaluation of cached sidecar text.

    Expects a JSON object with a ``text`` field. Returns the stripped text
    together with the result of ``_extract_structured_date`` under ``date``.

    Raises:
        HTTPException(500): OPENAI_API_KEY is not configured.
        HTTPException(400): ``text`` is missing or blank.
    """
    from fastapi.concurrency import run_in_threadpool

    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise HTTPException(500, "OPENAI_API_KEY not configured on the Space")
    text = str((payload or {}).get("text") or "").strip()
    if not text:
        raise HTTPException(400, "Missing text")
    # _extract_structured_date is a synchronous function that receives the API
    # key, presumably to call OpenAI. Run it in the threadpool so a slow
    # upstream call does not block the event loop for every other request.
    date = await run_in_threadpool(_extract_structured_date, text, api_key)
    return {"text": text, "date": date}
# ---------------------------------------------------------------------------
# Ephemeral token minters (OpenAI Realtime transcription, Deepgram)
# ---------------------------------------------------------------------------
# Vocabulary hint sent with every Realtime transcription session: a fixed
# preamble followed by a comma-separated list of radiology terms.
OPENAI_TRANSCRIPTION_PROMPT = "Medical radiology dictation. Common terms include: " + ", ".join((
    "lungs", "chest", "CT", "MRI", "X-ray", "ultrasound", "contrast",
    "lesion", "mass", "nodule", "opacity", "consolidation", "effusion",
    "pneumothorax", "atelectasis", "lymphadenopathy", "hilar", "mediastinal",
    "pulmonary", "parenchymal", "abdomen", "pelvis", "liver", "spleen",
    "kidney", "hydronephrosis", "cyst", "fracture", "displacement",
    "alignment", "vertebrae", "lumbar", "thoracic", "cervical", "spine",
    "brain", "intracranial", "hemorrhage", "infarct", "edema", "stenosis",
    "occlusion", "calcification", "enhancement",
)) + "."
@app.post("/openai-token")
@app.get("/openai-token")
def openai_token():
    """Mint an ephemeral OpenAI Realtime transcription-session token.

    POSTs a session configuration to OpenAI using the server's
    OPENAI_API_KEY and returns OpenAI's JSON response unchanged. The
    configuration specifies pcm16 input, gpt-4o-mini-transcribe in English
    with the radiology prompt, near-field noise reduction, and server VAD.

    NOTE(review): there is no caller authentication on this route, so anyone
    who can reach it can mint sessions backed by the server's key.

    Raises:
        HTTPException(500): key not configured, or a non-HTTP failure
            (network error, unparseable response).
        HTTPException(<upstream status>): OpenAI returned an HTTP error.
    """
    secret = os.environ.get("OPENAI_API_KEY")
    if not secret:
        raise HTTPException(500, "OPENAI_API_KEY not configured on the Space")

    transcription_cfg = {
        "model": "gpt-4o-mini-transcribe",
        "language": "en",
        "prompt": OPENAI_TRANSCRIPTION_PROMPT,
    }
    vad_cfg = {
        "type": "server_vad",
        "threshold": 0.4,
        "prefix_padding_ms": 200,
        "silence_duration_ms": 180,
    }
    session_cfg = {
        "input_audio_format": "pcm16",
        "input_audio_transcription": transcription_cfg,
        "input_audio_noise_reduction": {"type": "near_field"},
        "turn_detection": vad_cfg,
    }
    session_request = urllib.request.Request(
        "https://api.openai.com/v1/realtime/transcription_sessions",
        data=_json.dumps(session_cfg).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {secret}",
            "Content-Type": "application/json",
            "OpenAI-Beta": "realtime=v1",
        },
        method="POST",
    )
    try:
        with urllib.request.urlopen(session_request, timeout=10) as response:
            raw = response.read()
            return _json.loads(raw.decode("utf-8"))
    except urllib.error.HTTPError as err:
        body_text = err.read().decode("utf-8", errors="replace")
        raise HTTPException(err.code, f"OpenAI error: {body_text}")
    except Exception as err:
        raise HTTPException(500, f"OpenAI request failed: {err}")
@app.post("/deepgram-token")
@app.get("/deepgram-token")
def deepgram_token():
    """Mint a short-lived Deepgram access token.

    Calls Deepgram's ``/v1/auth/grant`` with the server's DEEPGRAM_API_KEY,
    requesting a 30-second TTL, and returns Deepgram's JSON response as-is.

    NOTE(review): there is no caller authentication on this route, so anyone
    who can reach it can obtain tokens backed by the server's key.

    Raises:
        HTTPException(500): key not configured, or a non-HTTP failure
            (network error, unparseable response).
        HTTPException(<upstream status>): Deepgram returned an HTTP error.
    """
    secret = os.environ.get("DEEPGRAM_API_KEY")
    if not secret:
        raise HTTPException(500, "DEEPGRAM_API_KEY not configured on the Space")
    grant_request = urllib.request.Request(
        "https://api.deepgram.com/v1/auth/grant",
        data=_json.dumps({"ttl_seconds": 30}).encode("utf-8"),
        headers={
            "Authorization": f"Token {secret}",
            "Content-Type": "application/json",
        },
        method="POST",
    )
    try:
        with urllib.request.urlopen(grant_request, timeout=10) as response:
            raw = response.read()
            return _json.loads(raw.decode("utf-8"))
    except urllib.error.HTTPError as err:
        body_text = err.read().decode("utf-8", errors="replace")
        raise HTTPException(err.code, f"Deepgram error: {body_text}")
    except Exception as err:
        raise HTTPException(500, f"Deepgram request failed: {err}")
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces, port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )
|