Spaces:
Sleeping
Sleeping
| """MedASR server for Hugging Face Spaces. | |
| Hosts Google MedASR with **CTC beam search + radiology hotwords** so the | |
| accuracy lifts well above the greedy-decoded floor (~6.6% WER) we hit when | |
| the model ran in the browser. Google's published 4.6% WER number on | |
| RAD-DICT uses beam-8 plus a 6-gram language model; we default to beam-4 | |
| (`DEFAULT_BEAM_WIDTH`) plus a radiology hotword list. Google's LM has not | |
| been publicly released yet, so we optionally apply our own radiology KenLM | |
| (`KENLM_PATH`), either via shallow fusion or via N-best rescoring. | |
| Endpoints: | |
| GET /health — model + decoder readiness | |
| POST /transcribe — multipart WebM/OGG audio (legacy path) | |
| POST /transcribe-pcm — raw Float32 / Int16 PCM mono @ 16 kHz (preferred) | |
| POST /openai-token — mint OpenAI Realtime ephemeral token | |
| POST /deepgram-token — mint Deepgram ephemeral token | |
| `HF_TOKEN` must be set as a Space secret (Google MedASR is gated). | |
| """ | |
| import json as _json | |
| import logging | |
| import os | |
| import re | |
| import subprocess | |
| import tempfile | |
| import time | |
| import urllib.error | |
| import urllib.request | |
| from io import BytesIO | |
| from typing import Literal | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| import uvicorn | |
| from fastapi import FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger("medasr") | |
| # --------------------------------------------------------------------------- | |
| # Spoken punctuation | |
| # --------------------------------------------------------------------------- | |
# Spoken phrase -> symbol table, applied top to bottom by
# _replace_spoken_punctuation(). Order matters: "forward slash" must be
# rewritten before the bare "slash" rule gets a chance to see it. Every
# phrase is matched case-insensitively and only on word boundaries.
_SPOKEN_PUNCTUATION: list[tuple[re.Pattern, str]] = [
    (re.compile(rf"\b{phrase}\b", re.IGNORECASE), symbol)
    for phrase, symbol in (
        (r"open paren(?:thesis)?", "("),
        (r"close paren(?:thesis)?", ")"),
        (r"open bracket", "["),
        (r"close bracket", "]"),
        (r"new line", "\n"),
        (r"forward slash", "/"),
        (r"comma", ","),
        (r"period", "."),
        (r"full stop", "."),
        (r"question mark", "?"),
        (r"exclamation (?:mark|point)", "!"),
        (r"colon", ":"),
        (r"semicolon", ";"),
        (r"hyphen", "-"),
        (r"dash", " — "),
        (r"slash", "/"),
        (r"plus", "+"),
        (r"ampersand", "&"),
    )
]
def _replace_spoken_punctuation(text: str) -> str:
    """Turn decoder output into display text with real punctuation.

    Steps, in order: rewrite the special tokens the decoder emits as literal
    text, unwrap brace-wrapped punctuation tokens, substitute spoken phrases
    using ``_SPOKEN_PUNCTUATION``, then tidy the spacing around punctuation.
    Newlines are deliberately preserved throughout.
    """
    # Sentence-boundary / unknown tokens arrive as literal text.
    for token, substitute in (("</s>", ". "), ("<s>", ""), ("<unk>", "")):
        text = text.replace(token, substitute)
    logger.info("Pre-replace raw text repr: %r", text[:200])

    # Punctuation tokens look like "{,}" and may carry the SentencePiece
    # word-prefix "▁" and alternative brace glyphs; the second pass catches
    # plain ASCII "{x}" left over after the first.
    text = re.sub(r"▁?[\{⦃{]▁?([,.:;!?/\-+])▁?[\}⦄}]", r"\1", text)
    text = re.sub(r"\{([,.:;!?/\-+])\}", r"\1", text)

    for spoken_form, symbol in _SPOKEN_PUNCTUATION:
        text = spoken_form.sub(symbol, text)

    # Whitespace clean-up. Each rule uses [ \t] (or a literal space) rather
    # than \s so a "\n" produced by "new line" survives to the client — even
    # when the whole segment is just that newline.
    cleanup_rules = (
        (r"[ \t]+([,.:;!?)\]])", r"\1"),   # no gap before closing punctuation
        (r"([([\[])[ \t]+", r"\1"),        # no gap after an opening bracket
        (r" +", " "),                      # collapse runs of spaces
        (r"^[ \t\r]+|[ \t\r]+$", ""),      # trim ends, but keep newlines
    )
    for rule, replacement in cleanup_rules:
        text = re.sub(rule, replacement, text)
    return text
| # --------------------------------------------------------------------------- | |
| # Radiology hotwords. Each term gets a likelihood boost during beam search, | |
| # which is what fixes the "edema -> aa" / "Borderline -> Remaining" failures | |
| # we saw in real dictation. Keep this list focused — too many hotwords pulls | |
| # accuracy back down on common non-medical words. | |
| # --------------------------------------------------------------------------- | |
# Base hotword vocabulary handed to the beam-search decoder, grouped by
# body region / report section. Further terms are appended to this list
# later in the module.
RADIOLOGY_HOTWORDS: list[str] = [
    # --- Cardiac / mediastinum ---
    "cardiomegaly", "borderline", "epicardial", "pericardial",
    "myocardial", "mediastinal", "hilar", "perihilar",
    "aortic", "aorta", "pulmonary",
    # --- Lung ---
    "edema", "effusion", "effusions", "consolidation",
    "atelectasis", "atelectatic", "pneumothorax", "pleural",
    "parenchyma", "parenchymal", "interstitial", "groundglass",
    "ground-glass", "opacity", "opacities", "nodule",
    "nodules", "mass", "empyema", "pneumonia",
    "hemothorax", "bronchiectasis", "fibrosis", "emphysema",
    "subpleural", "centrilobular", "tree-in-bud", "honeycombing",
    "reticulonodular",
    # --- Airways ---
    "tracheal", "trachea", "bronchial", "bronchovascular",
    "fissural", "fissure", "peribronchial", "peribronchovascular",
    # --- Chest wall / pleura / diaphragm ---
    "diaphragm", "diaphragmatic", "costophrenic", "subcutaneous",
    "thoracic", "subdiaphragmatic", "hemidiaphragm",
    # --- Abdominal anatomy ---
    "abdomen", "abdominal", "pelvis", "pelvic",
    "retroperitoneum", "retroperitoneal", "mesenteric", "mesentery",
    "paraaortic", "periaortic", "porta", "portal",
    "hepatic", "splenic", "renal", "adrenal",
    "pancreatic", "biliary", "gallbladder", "duodenal",
    "gastric", "intestinal", "colonic", "rectal",
    "uterine", "ovarian", "prostatic",
    # --- Brain / neuro ---
    "intracranial", "extracranial", "subdural", "epidural",
    "subarachnoid", "cerebral", "cerebellar", "brainstem",
    "thalamic", "lentiform", "caudate", "ventricular",
    "ventricle", "ventricles", "periventricular",
    # --- Spine ---
    "vertebra", "vertebrae", "vertebral", "lumbar", "cervical",
    # --- Vascular ---
    "stenosis", "occlusion", "occluded", "thrombosis",
    "embolism", "embolus", "aneurysm", "dissection",
    "atherosclerosis", "atherosclerotic", "calcified", "noncalcified",
    "calcification", "calcifications",
    # --- MR / CT signal terms ---
    "enhancement", "enhancing", "nonenhancing", "T1",
    "T2", "FLAIR", "DWI", "ADC",
    "STIR", "hypoechoic", "hyperechoic", "isoechoic",
    "anechoic", "echogenic", "hypoattenuating", "hyperattenuating",
    "hypodense", "hyperdense", "isointense", "hypointense",
    "hyperintense", "shadowing",
    # --- General findings ---
    "hemorrhage", "infarct", "infarction", "ischemia",
    "lesion", "lesions", "lymphadenopathy", "lymph",
    "hematoma", "fluid", "edematous", "swelling",
    # --- Common verbiage ---
    "unremarkable", "compatible", "consistent", "suggestive",
    "concerning", "noted", "demonstrates", "demonstrated",
    "evidence", "prominent",
]
| # --------------------------------------------------------------------------- | |
| # MedASR model + CTC decoder | |
| # --------------------------------------------------------------------------- | |
# Module-level inference state. The three handles start as None and are
# assigned by load_model(); health() reports readiness by checking them.
model = None
processor = None
decoder = None  # pyctcdecode beam-search decoder
DEVICE = "cpu"  # set by load_model() to "cuda" when available
# Decoding defaults passed to the decoder on every request.
DEFAULT_BEAM_WIDTH = 4
DEFAULT_HOTWORD_WEIGHT = 5.0
# Hotwords mined from the 731K-report corpus that weren't in the original
# RADIOLOGY_HOTWORDS list — high-frequency medical terms the decoder
# probably underweights today. Includes specific known-failure terms from
# the offline test set (homologue, aorticopulmonary, intercalated).
CORPUS_HOTWORDS: list[str] = [
    # Bigger anatomy / pathology nouns
    "indications", "indication", "abnormality", "abnormalities", "abnormal",
    "fracture", "fractures", "narrowing", "thickening", "dilatation",
    "enlargement", "enlarged", "compression", "protrusion", "obstruction",
    "dislocation", "distortion", "osteoarthritis", "osteophyte", "endplate",
    "ligament", "ligaments", "cartilage", "vasculature", "vascular",
    "arteries", "pancreas", "adrenals", "mediastinum", "meniscus",
    "shoulder", "foramina", "foraminal", "silhouette", "alignment",
    "paraspinal", "multilevel", "multiplanar", "multidetector",
    "hydronephrosis", "arthropathy", "hypertrophy", "adenopathy",
    "microcalcifications", "fibroglandular", "heterogeneously",
    # Modifiers + descriptors
    "bilateral", "moderate", "anterior", "posterior", "inferior", "sagittal",
    "proximal", "scattered", "visualized", "measuring", "diameter", "thickness",
    "reformatted", "reconstruction", "modulation", "administered", "supplemental",
    "interval", "diagnostic", "intravenous", "suspicious", "malignancy",
    "degenerative", "coronary", "sonographic", "ultrasound", "tomosynthesis",
    "mammogram", "mammography", "mammographic", "radiograph", "radiology",
    "architectural", "migrated",
    # Specific high-value terms missed on our offline test set
    "homologue", "homologous", "aorticopulmonary", "intercalated",
    "modic", "spondylolisthesis",
]

# Calendar months — the decoder badly mis-hears spoken dates ("july" came
# back as "ul"/"Ja" in testing). Boosting the month names is the cheapest
# lever to try before considering a fine-tune.
DATE_HOTWORDS: list[str] = [
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december",
]

# Append the corpus-mined and date terms to RADIOLOGY_HOTWORDS in place,
# keeping first-seen order and skipping anything already present.
# dict.fromkeys() dedupes within the additions themselves; the membership
# test filters terms the base list already has. Unlike a hand-written loop
# followed by `del` of the loop variable, this is safe when there is nothing
# to add, and it leaves no helper names behind in the module namespace.
_known = set(RADIOLOGY_HOTWORDS)
RADIOLOGY_HOTWORDS.extend(
    term
    for term in dict.fromkeys(CORPUS_HOTWORDS + DATE_HOTWORDS)
    if term not in _known
)
del _known
def _patch_lasr_feature_extractor():
    """Make ``LasrFeatureExtractor`` tolerate a ``center`` argument.

    If the installed ``_torch_extract_fbank_features`` has no ``center``
    parameter, it is wrapped with a version that accepts ``center`` and
    discards it. When the parameter already exists, or the Lasr module
    cannot be imported at all, nothing is changed.
    """
    try:
        from transformers.models.lasr.feature_extraction_lasr import (
            LasrFeatureExtractor as extractor_cls,
        )
        import inspect
    except ImportError:
        return

    wrapped = extractor_cls._torch_extract_fbank_features
    if "center" in inspect.signature(wrapped).parameters:
        return  # signature already accepts `center`; no shim needed

    def _ignore_center(self, waveform, device="cpu", center=True):
        # `center` is accepted only for call compatibility and is dropped.
        return wrapped(self, waveform, device)

    extractor_cls._torch_extract_fbank_features = _ignore_center
    logger.info("Applied LasrFeatureExtractor monkey-patch for 'center' arg")
def _ensure_kenlm():
    """Download a fresh copy of the KenLM binary on every startup.

    Any existing file at ``KENLM_PATH`` (default ``/app/radiology.bin``) is
    removed first and the file is re-fetched from ``KENLM_URL``. This is NOT
    a no-op when the file is already present: an earlier size-check approach
    ended up trusting stale local files, so startup time is traded for
    "the LM you uploaded is the LM you serve".

    Runs at startup instead of in the Dockerfile so that build-time network
    restrictions don't fail the image and the LM file can be swapped on the
    HF repo without rebuilding.

    The download streams to ``<KENLM_PATH>.part`` and is moved into place
    only after it completes. Each socket operation is bounded by
    ``KENLM_DOWNLOAD_TIMEOUT`` seconds (default 60) so a stalled connection
    cannot hang startup indefinitely. On any download failure the partial
    file is removed, a warning is logged, and the function returns without
    raising — the caller then proceeds with no LM file on disk.
    """
    import shutil

    kenlm_path = os.environ.get("KENLM_PATH", "/app/radiology.bin")
    url = os.environ.get(
        "KENLM_URL",
        "https://huggingface.co/chirag18/radiology-stt-assets/resolve/main/radiology.bin",
    )
    try:
        timeout_s = float(os.environ.get("KENLM_DOWNLOAD_TIMEOUT", "60"))
    except ValueError:
        timeout_s = 60.0

    if os.path.exists(kenlm_path):
        old_size = os.path.getsize(kenlm_path) / 1048576
        logger.info("Removing stale KenLM at %s (%.1f MB) for fresh download.",
                    kenlm_path, old_size)
        os.remove(kenlm_path)

    logger.info("Downloading KenLM from %s ...", url)
    t0 = time.monotonic()
    tmp = kenlm_path + ".part"
    try:
        with urllib.request.urlopen(url, timeout=timeout_s) as resp, open(tmp, "wb") as out:
            expected = resp.headers.get("Content-Length")
            shutil.copyfileobj(resp, out, 1024 * 1024)
        # Guard against a silently truncated transfer before publishing the file.
        if expected is not None and os.path.getsize(tmp) != int(expected):
            raise OSError(
                f"truncated download: got {os.path.getsize(tmp)} of {expected} bytes"
            )
        os.replace(tmp, kenlm_path)
    except Exception as e:
        # Deliberately best-effort: a missing LM degrades accuracy but must
        # not prevent the server from starting.
        if os.path.exists(tmp):
            os.remove(tmp)
        logger.warning("KenLM download failed (%s) — server will fall back to "
                       "non-LM beam search.", e)
        return

    size_mb = os.path.getsize(kenlm_path) / 1048576
    logger.info("KenLM downloaded: %.1f MB in %.1fs", size_mb, time.monotonic() - t0)
def _build_decoder():
    """Create the pyctcdecode beam-search decoder for the loaded model.

    Labels are taken from the tokenizer for ids ``0 .. vocab_size - 1`` of
    the model config, and the tokenizer's pad token (used as the CTC blank)
    is renamed to ``""``.

    Shallow fusion with KenLM is enabled only when both hold:
      * the file at ``KENLM_PATH`` (default ``/app/radiology.bin``) exists;
      * ``KENLM_ALPHA`` (LM weight, default ``0.05``) is greater than zero.
    ``KENLM_BETA`` (default ``1.0``) is passed through as pyctcdecode's
    ``beta``. In every other case a plain, LM-free decoder is returned.
    """
    from pyctcdecode import build_ctcdecoder

    # Size the label list from the model's CTC output dimension rather than
    # the tokenizer's full vocab: the tokenizer can carry special tokens that
    # the CTC head does not emit, and a label-count mismatch makes
    # pyctcdecode raise on every decode call.
    output_dim = model.config.vocab_size
    tokenizer = processor.tokenizer
    labels = tokenizer.convert_ids_to_tokens(list(range(output_dim)))

    # pyctcdecode appends its own blank unless it recognises one ("",
    # "<pad>" or "<blank>"). Renaming the pad token to "" keeps the label
    # count equal to the model's output dimension.
    blank_id = tokenizer.pad_token_id
    if blank_id is not None and 0 <= blank_id < len(labels):
        labels[blank_id] = ""
    logger.info("Decoder labels: %d, blank at id=%s, sample=%s", len(labels), blank_id, labels[:6])

    kenlm_path = os.environ.get("KENLM_PATH", "/app/radiology.bin")
    alpha = float(os.environ.get("KENLM_ALPHA", "0.05"))
    beta = float(os.environ.get("KENLM_BETA", "1.0"))

    if not os.path.exists(kenlm_path):
        logger.info("No KenLM at %s — using non-LM beam-search decoder.", kenlm_path)
        return build_ctcdecoder(labels)

    if not alpha > 0:
        # alpha=0 inside pyctcdecode still has LM-related side effects, so
        # "disabled" means building the decoder with no LM path at all.
        logger.info("KENLM_ALPHA<=0 — bypassing LM (non-LM beam-search decoder).")
        return build_ctcdecoder(labels)

    size_mb = os.path.getsize(kenlm_path) / 1048576
    logger.info("Loading KenLM (%.0f MB) from %s, alpha=%.2f, beta=%.2f",
                size_mb, kenlm_path, alpha, beta)
    return build_ctcdecoder(
        labels,
        kenlm_model_path=kenlm_path,
        alpha=alpha,
        beta=beta,
    )
def load_model():
    """Populate the module-level ``model``, ``processor``, ``decoder`` and
    ``DEVICE`` globals.

    Requires the ``HF_TOKEN`` environment variable; raises ``RuntimeError``
    when it is missing or empty. The model is moved to CUDA when available,
    otherwise torch is limited to 4 CPU threads.
    """
    global model, processor, decoder, DEVICE

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError("HF_TOKEN secret required for gated MedASR model")

    # Must run before the processor is instantiated.
    _patch_lasr_feature_extractor()
    from transformers import AutoModelForCTC, AutoProcessor

    logger.info("Loading MedASR model...")
    processor = AutoProcessor.from_pretrained("google/medasr", token=hf_token)
    model = AutoModelForCTC.from_pretrained("google/medasr", token=hf_token)
    model.eval()

    if torch.cuda.is_available():
        DEVICE = "cuda"
        model = model.to("cuda")
        logger.info("Model moved to CUDA (fp32). GPU=%s", torch.cuda.get_device_name(0))
    else:
        DEVICE = "cpu"
        torch.set_num_threads(4)
        logger.info("Running on CPU (4 threads)")

    logger.info("Building CTC beam-search decoder...")
    _ensure_kenlm()  # fetch the LM file before the decoder looks for it
    decoder = _build_decoder()

    vocab_size = len(processor.tokenizer.get_vocab())
    logger.info("MedASR ready (vocab=%d, beam=%d, hotwords=%d).",
                vocab_size, DEFAULT_BEAM_WIDTH,
                len(RADIOLOGY_HOTWORDS))
| # --------------------------------------------------------------------------- | |
| # Audio in -> logits -> text | |
| # --------------------------------------------------------------------------- | |
_rescore_lm = None  # cached kenlm.Model used for N-best rescoring


def _get_rescore_lm():
    """Return the cached KenLM model for N-best rescoring, loading it on
    first use.

    The model is read from ``KENLM_PATH`` (default ``/app/radiology.bin``).
    When that file does not exist, ``None`` is returned and nothing is
    cached, so the path is checked again on the next call. This model is
    separate from the beam-search decoder: it scores whole candidate strings.
    """
    global _rescore_lm
    if _rescore_lm is not None:
        return _rescore_lm

    import kenlm

    lm_file = os.environ.get("KENLM_PATH", "/app/radiology.bin")
    if os.path.exists(lm_file):
        _rescore_lm = kenlm.Model(lm_file)
        logger.info("N-best rescoring LM loaded from %s", lm_file)
    return _rescore_lm
def _decode_logits(logits_np: np.ndarray) -> str:
    """Decode one utterance's logits to text with hotword-boosted beam search.

    With ``RESCORE_ALPHA`` > 0 (default 0) and a rescoring LM available, up
    to ``RESCORE_N`` (default 8) beam hypotheses are re-ranked by
    ``logit_score + RESCORE_ALPHA * lm_score`` and the winner is returned.
    Otherwise — or if rescoring yields no winner — the decoder's single best
    hypothesis at ``DEFAULT_BEAM_WIDTH`` is used. The chosen text is stripped
    and passed through ``_replace_spoken_punctuation``.
    """
    rescore_alpha = float(os.environ.get("RESCORE_ALPHA", "0"))
    rescore_n = int(os.environ.get("RESCORE_N", "8"))

    scorer = _get_rescore_lm() if rescore_alpha > 0 else None
    if scorer is not None:
        candidates = decoder.decode_beams(
            logits_np,
            beam_width=max(rescore_n, DEFAULT_BEAM_WIDTH),
            hotwords=RADIOLOGY_HOTWORDS,
            hotword_weight=DEFAULT_HOTWORD_WEIGHT,
        )
        # Beam entries are (text, last_word_state, frames, logit_score,
        # lm_score); only the text and the acoustic logit_score are used —
        # the LM term comes from our own model.
        winner = None
        winner_score = -float("inf")
        for candidate in candidates[:rescore_n]:
            hypothesis = candidate[0]
            acoustic = candidate[3] if len(candidate) > 3 else 0.0
            # The LM is queried with the lowercased full string.
            lm_term = scorer.score(hypothesis.lower(), bos=True, eos=True)
            total = acoustic + rescore_alpha * lm_term
            if total > winner_score:
                winner, winner_score = hypothesis, total
        if winner is not None:
            return _replace_spoken_punctuation(winner.strip())

    greedy_best = decoder.decode(
        logits_np,
        beam_width=DEFAULT_BEAM_WIDTH,
        hotwords=RADIOLOGY_HOTWORDS,
        hotword_weight=DEFAULT_HOTWORD_WEIGHT,
    )
    return _replace_spoken_punctuation(greedy_best.strip())
def _samples_to_text(samples: np.ndarray, sample_rate: int) -> str:
    """Transcribe a mono waveform.

    Returns ``""`` for an empty array. Audio at any rate other than 16 kHz
    is resampled with librosa first. The samples then go through the
    processor and the model, and the first item's logits are decoded by
    ``_decode_logits``.
    """
    if samples.size == 0:
        return ""

    if sample_rate != 16000:
        # Feeding the model audio at the wrong rate would silently produce
        # gibberish. librosa is imported lazily: only non-16k clients need it.
        import librosa
        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)

    features = processor(samples, sampling_rate=16000, return_tensors="pt", padding=True)
    if DEVICE == "cuda":
        features = {name: tensor.to("cuda") for name, tensor in features.items()}

    with torch.inference_mode():
        batch_logits = model(**features).logits

    first_item = batch_logits[0].float().cpu().numpy()
    return _decode_logits(first_item)
def convert_to_wav(audio_bytes: bytes) -> bytes:
    """Transcode container audio (WebM/OGG/etc.) to 16 kHz mono WAV bytes.

    The input is written to a temporary ``.webm`` file and converted by
    ``ffmpeg`` (30 s timeout). Both temporary files are removed afterwards,
    whether or not the conversion succeeded.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited with a non-zero status.
        subprocess.TimeoutExpired: ffmpeg did not finish within 30 seconds.
    """
    with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as upload:
        upload.write(audio_bytes)
    source_file = upload.name
    wav_file = os.path.splitext(source_file)[0] + ".wav"

    command = [
        "ffmpeg", "-y",
        "-i", source_file,
        "-ar", "16000",
        "-ac", "1",
        "-f", "wav",
        wav_file,
    ]
    try:
        subprocess.run(command, capture_output=True, check=True, timeout=30)
        with open(wav_file, "rb") as converted:
            return converted.read()
    finally:
        # Best-effort clean-up; the output may not exist if ffmpeg failed.
        for leftover in (source_file, wav_file):
            try:
                os.unlink(leftover)
            except OSError:
                pass
| # --------------------------------------------------------------------------- | |
| # FastAPI | |
| # --------------------------------------------------------------------------- | |
# ASGI application object. CORS is fully open: any origin, method and header.
app = FastAPI(title="MedASR Server")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.on_event("startup")
def startup():
    """Load the model and build the decoder when the application starts.

    Registered as a FastAPI startup hook; without the registration nothing
    would ever call ``load_model()`` and every transcription request would
    fail on an unloaded model.
    """
    load_model()
@app.get("/health")
def health():
    """Report model / decoder readiness and the decoding configuration.

    ``kenlm_loaded`` reflects whether the LM file exists at ``KENLM_PATH``;
    it does not prove that the decoder was built with it. ``kenlm_alpha``
    and ``kenlm_beta`` echo the environment overrides, falling back to the
    same defaults ``_build_decoder`` uses (0.05 and 1.0), and are ``None``
    when no LM file is present.
    """
    kenlm_path = os.environ.get("KENLM_PATH", "/app/radiology.bin")
    # One stat call instead of exists() + getsize(), so the file vanishing
    # between the two (e.g. during a re-download) cannot raise here.
    try:
        kenlm_size_mb = round(os.path.getsize(kenlm_path) / 1048576, 1)
        kenlm_loaded = True
    except OSError:
        kenlm_size_mb = 0
        kenlm_loaded = False

    return {
        "status": "ok",
        "model_loaded": model is not None,
        "decoder_ready": decoder is not None,
        "beam_width": DEFAULT_BEAM_WIDTH,
        "hotwords": len(RADIOLOGY_HOTWORDS),
        "kenlm_loaded": kenlm_loaded,
        "kenlm_size_mb": kenlm_size_mb,
        "kenlm_alpha": float(os.environ.get("KENLM_ALPHA", "0.05")) if kenlm_loaded else None,
        "kenlm_beta": float(os.environ.get("KENLM_BETA", "1.0")) if kenlm_loaded else None,
    }
@app.post("/transcribe")
async def transcribe_audio(audio: UploadFile = File(...)):
    """Legacy endpoint: accepts WebM/OGG via FormData, decodes it with
    ffmpeg, then runs the same beam-search pipeline as /transcribe-pcm.

    Returns ``{"text", "duration_seconds"}``.

    Raises:
        HTTPException 400: empty upload, or audio that ffmpeg cannot decode
            (or cannot decode within its timeout).
        HTTPException 413: upload larger than 20 MB.
    """
    import asyncio

    max_bytes = 20 * 1024 * 1024
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large body into memory.
    contents = await audio.read(max_bytes + 1)
    if len(contents) == 0:
        raise HTTPException(400, "Empty audio file")
    if len(contents) > max_bytes:
        raise HTTPException(413, "Audio too large (max 20 MB)")

    def _decode_and_transcribe() -> str:
        wav_bytes = convert_to_wav(contents)
        waveform, sr = sf.read(BytesIO(wav_bytes), dtype="float32")
        if waveform.ndim > 1:
            # Down-mix multi-channel audio to mono.
            waveform = waveform.mean(axis=1)
        return _samples_to_text(waveform, sr)

    t0 = time.monotonic()
    try:
        # ffmpeg and model inference are blocking; run them in a worker
        # thread so the event loop keeps serving other requests (/health).
        text = await asyncio.to_thread(_decode_and_transcribe)
    except subprocess.CalledProcessError as e:
        raise HTTPException(400, "Could not decode audio (unsupported or corrupt file)") from e
    except subprocess.TimeoutExpired as e:
        raise HTTPException(400, "Audio decoding timed out") from e
    elapsed = time.monotonic() - t0

    logger.info("Transcribed (webm) in %.2fs: '%s'", elapsed, text[:100])
    return {"text": text, "duration_seconds": round(elapsed, 2)}
@app.post("/transcribe-pcm")
async def transcribe_pcm(
    audio: UploadFile = File(...),
    sample_rate: int = Form(16000),
    pcm_format: Literal["float32", "int16"] = Form("float32"),
):
    """Preferred endpoint: accepts raw mono PCM, ideally at 16 kHz.

    The browser sends the bytes of a Float32Array (or Int16Array) directly —
    no ffmpeg and no transcoding. Int16 input is scaled to [-1, 1).

    Returns ``{"text", "duration_seconds", "samples"}``.

    Raises:
        HTTPException 400: empty body, unsupported ``sample_rate``, a byte
            length that is not a whole number of samples, or non-finite
            float samples.
        HTTPException 413: body larger than 32 MB.
    """
    import asyncio

    max_bytes = 32 * 1024 * 1024
    # Bounds keep the resampler from being used as a memory amplifier:
    # a tiny declared rate would expand the buffer enormously at 16 kHz.
    min_rate, max_rate = 8000, 192000

    # Read at most one byte past the limit to detect oversized uploads.
    contents = await audio.read(max_bytes + 1)
    if len(contents) == 0:
        raise HTTPException(400, "Empty audio")
    if len(contents) > max_bytes:
        raise HTTPException(413, "PCM too large (max 32 MB)")
    if not min_rate <= sample_rate <= max_rate:
        raise HTTPException(
            400, f"sample_rate must be between {min_rate} and {max_rate} Hz"
        )

    sample_dtype = np.int16 if pcm_format == "int16" else np.float32
    if len(contents) % np.dtype(sample_dtype).itemsize:
        # np.frombuffer would raise ValueError (-> 500) on a ragged buffer.
        raise HTTPException(
            400, f"PCM byte length is not a whole number of {pcm_format} samples"
        )

    if pcm_format == "int16":
        samples = np.frombuffer(contents, dtype=np.int16).astype(np.float32) / 32768.0
    else:
        # frombuffer returns a read-only view over `contents`; copy it.
        samples = np.frombuffer(contents, dtype=np.float32).copy()
        if not np.isfinite(samples).all():
            raise HTTPException(400, "PCM contains NaN or infinite samples")

    t0 = time.monotonic()
    # Model inference is blocking; run it in a worker thread so the event
    # loop keeps serving other requests.
    text = await asyncio.to_thread(_samples_to_text, samples, sample_rate)
    elapsed = time.monotonic() - t0

    logger.info("Transcribed (pcm, %d samples @%d Hz) in %.2fs: '%s'",
                samples.size, sample_rate, elapsed, text[:100])
    return {"text": text, "duration_seconds": round(elapsed, 2), "samples": int(samples.size)}
| # --------------------------------------------------------------------------- | |
| # OpenAI Realtime ephemeral token minter (unchanged) | |
| # --------------------------------------------------------------------------- | |
# Vocabulary hint sent to OpenAI as the transcription prompt.
OPENAI_TRANSCRIPTION_PROMPT = (
    "Medical radiology dictation. Common terms include: lungs, chest, CT, MRI, "
    "X-ray, ultrasound, contrast, lesion, mass, nodule, opacity, consolidation, "
    "effusion, pneumothorax, atelectasis, lymphadenopathy, hilar, mediastinal, "
    "pulmonary, parenchymal, abdomen, pelvis, liver, spleen, kidney, hydronephrosis, "
    "cyst, fracture, displacement, alignment, vertebrae, lumbar, thoracic, cervical, "
    "spine, brain, intracranial, hemorrhage, infarct, edema, stenosis, occlusion, "
    "calcification, enhancement."
)


def _mint_ephemeral_token(provider: str, url: str, headers: dict, payload: dict):
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON reply.

    Shared by the token endpoints. ``provider`` is only used to label error
    messages.

    Raises:
        HTTPException: with the upstream status code and body when the
            provider answers with an HTTP error; with status 500 for any
            other failure (network error, timeout, invalid JSON).
    """
    req = urllib.request.Request(
        url,
        data=_json.dumps(payload).encode("utf-8"),
        headers=headers,
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=10) as resp:
            return _json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as e:
        detail = e.read().decode("utf-8", errors="replace")
        raise HTTPException(e.code, f"{provider} error: {detail}") from e
    except Exception as e:
        raise HTTPException(500, f"{provider} request failed: {e}") from e


# NOTE(review): both token endpoints perform no caller authentication and the
# app's CORS policy allows any origin, so anyone who can reach the Space can
# mint tokens against the configured API keys — confirm this is intended.
@app.post("/openai-token")
def openai_token():
    """Mint an OpenAI Realtime transcription-session token.

    Returns the provider's JSON response unchanged. Raises ``HTTPException``
    500 when ``OPENAI_API_KEY`` is not set; upstream failures are mapped as
    described in ``_mint_ephemeral_token``.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise HTTPException(500, "OPENAI_API_KEY not configured on the Space")
    return _mint_ephemeral_token(
        provider="OpenAI",
        url="https://api.openai.com/v1/realtime/transcription_sessions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "OpenAI-Beta": "realtime=v1",
        },
        payload={
            "input_audio_format": "pcm16",
            "input_audio_transcription": {
                "model": "gpt-4o-mini-transcribe",
                "language": "en",
                "prompt": OPENAI_TRANSCRIPTION_PROMPT,
            },
            "input_audio_noise_reduction": {"type": "near_field"},
            "turn_detection": {
                "type": "server_vad",
                "threshold": 0.4,
                "prefix_padding_ms": 200,
                "silence_duration_ms": 180,
            },
        },
    )


@app.post("/deepgram-token")
def deepgram_token():
    """Mint a Deepgram token with a 30-second TTL.

    Returns the provider's JSON response unchanged. Raises ``HTTPException``
    500 when ``DEEPGRAM_API_KEY`` is not set; upstream failures are mapped
    as described in ``_mint_ephemeral_token``.
    """
    api_key = os.environ.get("DEEPGRAM_API_KEY")
    if not api_key:
        raise HTTPException(500, "DEEPGRAM_API_KEY not configured on the Space")
    return _mint_ephemeral_token(
        provider="Deepgram",
        url="https://api.deepgram.com/v1/auth/grant",
        headers={
            "Authorization": f"Token {api_key}",
            "Content-Type": "application/json",
        },
        payload={"ttl_seconds": 30},
    )
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)