# Hugging Face Space — runs on ZeroGPU hardware.
| """ | |
| NaijaMedModel v5 β ZeroGPU (H200) Gradio Space | |
| Lazy model loading: nothing heavy lives in CPU RAM at startup. | |
| Models are loaded directly onto GPU when the first request arrives, | |
| then kept in GPU RAM for subsequent requests (cached in module-level dicts). | |
| """ | |
import sys  # single import; the duplicate "import torch, sys" below is merged

import torch

# NOTE(review): an earlier revision apparently patched torch.ops.load_library
# to swallow the torchaudio CUDA-loader OSError on ZeroGPU (libcudart absent
# at startup); no such patch exists below — confirm whether it is still needed.

# Patch torch.load for pyannote compatibility: pyannote checkpoints contain
# arbitrary pickled objects, which torch >= 2.6 rejects by default
# (weights_only=True). We force weights_only=False for every torch.load call.
# SECURITY: this re-enables arbitrary-code pickle loading — only acceptable
# because all checkpoints here come from pinned, trusted HF repos.
_orig_load = torch.load


def _patched_load(f, *args, **kwargs):
    """torch.load wrapper that always disables the weights_only safe mode."""
    kwargs["weights_only"] = False
    return _orig_load(f, *args, **kwargs)


torch.load = _patched_load

# Patch huggingface_hub's legacy-argument shim so that the deprecated
# `use_auth_token` kwarg (still emitted by pyannote) is transparently
# renamed to `token` instead of raising/deprecation-warning.
try:
    from huggingface_hub.utils import _validators as _hf_v

    _orig_smooth = _hf_v.smoothly_deprecate_legacy_arguments

    def _patched_smooth(fn_name, kwargs):
        """Map use_auth_token -> token (without clobbering an explicit token)."""
        if "use_auth_token" in kwargs:
            v = kwargs.pop("use_auth_token")
            if "token" not in kwargs and v is not None:
                kwargs["token"] = v
        return _orig_smooth(fn_name, kwargs)

    _hf_v.smoothly_deprecate_legacy_arguments = _patched_smooth
except Exception:
    # huggingface_hub internals moved or are absent — patch is best-effort.
    pass
# ── Standard imports ───────────────────────────────────────────────────────
import os, tempfile
import numpy as np
import soundfile as sf
from scipy.signal import resample
import gradio as gr
import spaces  # must come after torch

# ── Config ─────────────────────────────────────────────────────────────────
# HF Hub model ids for each pipeline stage.
ASR_EN_MODEL = "Ephraimmm/asrfinetuned"            # fine-tuned English clinical ASR
ASR_YO_MODEL = "NCAIR1/Yoruba-ASR"                 # Yoruba speech-to-text
WHISPER_MODEL = "openai/whisper-large-v3"          # fallback any-language -> English
TRANSLATE_MODEL = "Helsinki-NLP/opus-mt-yo-en"     # Yoruba -> English MT
DIAR_MODEL = "pyannote/speaker-diarization-3.1"    # speaker diarization
SOAP_MODEL = "Edifon/SOAP_SFT_V1"                  # SOAP-note generator (causal LM)
# Token for gated repos (pyannote, SOAP); empty string when unset.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# ── Lazy model cache — populated on first GPU call ─────────────────────────
# Nothing is loaded here. Imports of heavy libraries are also deferred.
_models = {}
def _get_models():
    """
    Load every model directly onto the GPU the first time this is called and
    cache them in the module-level ``_models`` dict for subsequent requests.

    Must only be called from inside a @spaces.GPU function, where libcudart
    and the torchaudio CUDA extensions are actually available.

    Returns:
        dict: the populated model cache (the same object as ``_models``).
    """
    if _models:
        # Already populated by a previous request — models stay in VRAM.
        return _models

    # Heavy imports are deferred so app startup keeps CPU RAM light.
    from transformers import (
        pipeline as hf_pipeline,
        AutoProcessor,
        AutoModelForSpeechSeq2Seq,
        # Imported unconditionally: the SOAP model below requires it, so a
        # missing symbol should fail loudly here rather than surface later
        # as a confusing NameError (previously wrapped in try/except pass).
        AutoModelForCausalLM,
        MarianMTModel,
        MarianTokenizer,
        AutoTokenizer,
    )
    from pyannote.audio import Pipeline as DiarizationPipeline
    import whisper as _whisper

    print("⏳ Loading English ASR → GPU…")
    _models["asr_en"] = hf_pipeline(
        "automatic-speech-recognition",
        model=ASR_EN_MODEL,
        device="cuda",
        token=HF_TOKEN,
    )

    print("⏳ Loading Yoruba ASR → GPU…")
    _models["yo_processor"] = AutoProcessor.from_pretrained(ASR_YO_MODEL)
    _models["yo_model"] = AutoModelForSpeechSeq2Seq.from_pretrained(
        ASR_YO_MODEL, torch_dtype=torch.float16
    ).to("cuda")

    print("⏳ Loading Whisper large-v3 → GPU…")
    _models["whisper_pipe"] = hf_pipeline(
        "automatic-speech-recognition",
        model=WHISPER_MODEL,
        torch_dtype=torch.float16,
        device="cuda",
        generate_kwargs={"task": "translate", "language": None},
    )

    print("⏳ Loading Yoruba→English MT → GPU…")
    try:
        _models["mt_tokenizer"] = MarianTokenizer.from_pretrained(TRANSLATE_MODEL)
        _models["mt_model"] = MarianMTModel.from_pretrained(
            TRANSLATE_MODEL, torch_dtype=torch.float16
        ).to("cuda")
    except Exception as e:
        # MT is optional: run_pipeline falls back to Whisper's built-in
        # translation when these cache entries are None.
        print(f"⚠️ MT skipped: {e}")
        _models["mt_tokenizer"] = None
        _models["mt_model"] = None

    print("⏳ Loading diarization → GPU…")
    diar = DiarizationPipeline.from_pretrained(DIAR_MODEL, use_auth_token=HF_TOKEN)
    _models["diar"] = diar.to(torch.device("cuda"))

    print("⏳ Loading Whisper tiny lang detector → GPU…")
    _models["lang_model"] = _whisper.load_model("tiny", device="cuda")

    print("⏳ Loading SOAP model → GPU (bfloat16)…")
    _models["soap_processor"] = AutoTokenizer.from_pretrained(SOAP_MODEL, token=HF_TOKEN)
    _models["soap_model"] = AutoModelForCausalLM.from_pretrained(
        SOAP_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        token=HF_TOKEN,
    )

    # Keep a handle to the whisper library itself for pad_or_trim /
    # log_mel_spectrogram helpers used at inference time.
    _models["whisper_lib"] = _whisper
    print("✅ All models loaded on GPU.")
    return _models
| # ββ Helper βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _get_chunk(arr, start, end, sr=16000): | |
| return arr[int(start * sr): int(end * sr)] | |
# ── Single GPU function that does everything ───────────────────────────────
# One @spaces.GPU call per user request avoids the overhead of acquiring
# the GPU slot multiple times and keeps all models in VRAM together.
@spaces.GPU(duration=300)  # allow up to 5 min for long consultations
def run_pipeline(wav_path: str):
    """
    Full pipeline: diarize → per-segment language ID → transcribe/translate
    → SOAP-note generation.

    Args:
        wav_path: path to a 16 kHz mono WAV file on disk.

    Returns:
        str: the SOAP note as markdown, or a plain error message when no
        speakers / no speech could be processed. Always a single string,
        since this feeds one ``gr.Markdown`` output (previously the error
        branches returned 2-tuples, which would render as a tuple repr).
    """
    m = _get_models()
    _whisper = m["whisper_lib"]

    # ── Diarize ────────────────────────────────────────────────────────
    diar_result = m["diar"](wav_path)
    segs = [
        {"start": t.start, "end": t.end, "speaker": spk}
        for t, _, spk in diar_result.itertracks(yield_label=True)
    ]
    if not segs:
        return "No speakers detected."

    arr, _ = sf.read(wav_path)
    arr = arr.astype(np.float32)

    # ── Transcribe each segment ────────────────────────────────────────
    transcript = []
    for seg in segs:
        chunk = _get_chunk(arr, seg["start"], seg["end"])
        if len(chunk) < 16000 * 0.5:
            # Skip segments shorter than half a second — too little signal.
            continue

        # Language detection with Whisper tiny on a padded 30 s window.
        padded = _whisper.pad_or_trim(chunk)
        mel = _whisper.log_mel_spectrogram(padded).to("cuda")
        _, probs = m["lang_model"].detect_language(mel)
        lang = max(probs, key=probs.get)

        if lang == "yo":
            # Yoruba ASR → MarianMT translation
            inputs = m["yo_processor"](
                chunk, sampling_rate=16000, return_tensors="pt"
            )
            # Cast float32 features to fp16 to match the model weights;
            # integer tensors (e.g. attention masks) keep their dtype.
            inputs = {
                k: v.to("cuda", dtype=torch.float16)
                if v.dtype == torch.float32 else v.to("cuda")
                for k, v in inputs.items()
            }
            with torch.no_grad():
                ids = m["yo_model"].generate(**inputs)
            yo_text = m["yo_processor"].batch_decode(
                ids, skip_special_tokens=True
            )[0].strip()
            if m["mt_model"]:
                tokens = m["mt_tokenizer"]([yo_text], return_tensors="pt", padding=True)
                tokens = {k: v.to("cuda") for k, v in tokens.items()}
                with torch.no_grad():
                    out = m["mt_model"].generate(**tokens)
                en_text = m["mt_tokenizer"].batch_decode(
                    out, skip_special_tokens=True
                )[0].strip()
            else:
                # MT failed to load — fall back to Whisper translation.
                en_text = _run_whisper_translate(m, chunk)
            transcript.append({**seg,
                               "language": "yo", "original_text": yo_text,
                               "text": en_text, "translated": True})
        elif lang != "en":
            # Any other language: Whisper translates straight to English.
            en_text = _run_whisper_translate(m, chunk)
            transcript.append({**seg,
                               "language": lang, "original_text": en_text,
                               "text": en_text, "translated": True})
        else:
            # English ASR with the fine-tuned model.
            fe = m["asr_en"].feature_extractor
            inputs = fe(chunk, sampling_rate=16000, return_tensors="pt",
                        truncation=True, return_attention_mask=True)
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
            with torch.no_grad():
                ids = m["asr_en"].model.generate(
                    inputs["input_features"],
                    attention_mask=inputs["attention_mask"],
                    generation_config=m["asr_en"].model.generation_config,
                    language="english",
                )
            en_text = m["asr_en"].tokenizer.batch_decode(
                ids, skip_special_tokens=True
            )[0].strip()
            transcript.append({**seg,
                               "language": "en", "original_text": en_text,
                               "text": en_text, "translated": False})

    if not transcript:
        return "No speech transcribed."

    # ── Format transcript ──────────────────────────────────────────────
    # First two diarized speakers are labeled Doctor/Patient by convention.
    sp_ids = sorted({s["speaker"] for s in transcript})
    sp_map = {
        spk: (["Doctor", "Patient"][i] if i < 2 else f"Speaker {i+1}")
        for i, spk in enumerate(sp_ids)
    }
    lines, prev = [], None
    for seg in transcript:
        label = sp_map[seg["speaker"]]
        flag = " ⚑" if seg["translated"] else ""
        if label != prev:
            lines.append(f"\n{label}:\n  {seg['text']}{flag}")
        else:
            lines.append(f"  {seg['text']}{flag}")
        prev = label
    # NOTE(review): transcript_text is built but never returned or displayed
    # anywhere — confirm whether it should be part of the UI output.
    transcript_text = (
        "\n".join(lines)
        + "\n\n⚑ = translated segment (clinician review recommended)"
    )

    # ── SOAP note ──────────────────────────────────────────────────────
    flat = "\n".join(
        f"{sp_map[s['speaker']]}: {s['text']}" for s in transcript
    )
    tok = m["soap_processor"]  # this is now AutoTokenizer
    msgs = [
        {"role": "system", "content": (
            "You are an expert medical professor assisting in the creation of "
            "medically accurate SOAP summaries. Please ensure the response follows "
            "the structured format: S:, O:, A:, P: without using markdown or special formatting."
        )},
        {"role": "user", "content":
            f"Create a medical SOAP summary of this dialogue.\n### Dialogue:\n{flat}\n"
        },
    ]
    encoded = tok.apply_chat_template(
        msgs, add_generation_prompt=True, return_tensors="pt"
    )
    # Force to a plain tensor regardless of what apply_chat_template returns
    # (BatchEncoding, dict, or a bare id list depending on tokenizer version).
    if hasattr(encoded, "input_ids"):
        input_ids = encoded.input_ids.to("cuda")
    elif isinstance(encoded, dict):
        input_ids = encoded["input_ids"].to("cuda")
    else:
        input_ids = torch.tensor(encoded).unsqueeze(0).to("cuda")
    out_ids = m["soap_model"].generate(
        input_ids, max_new_tokens=2048, do_sample=False
    )
    # Decode only the newly generated tokens, skipping the prompt prefix.
    soap = tok.decode(out_ids[0][input_ids.shape[1]:], skip_special_tokens=True)

    # Format as clean markdown
    import re

    def to_markdown(text):
        """Replace bare S:/O:/A:/P: tags with markdown section headings."""
        sections = {"S:": "## 🟦 Subjective", "O:": "## 🟩 Objective",
                    "A:": "## 🟧 Assessment", "P:": "## 🟥 Plan"}
        for tag, heading in sections.items():
            # NOTE(review): the case-insensitive, unanchored match also hits
            # tags embedded mid-word (e.g. "plans:"); consider anchoring with
            # (?im)^\s* if that proves noisy.
            text = re.sub(rf"(?i){re.escape(tag)}", f"\n\n{heading}\n", text)
        return text.strip()

    return to_markdown(soap)
def _run_whisper_translate(m, chunk):
    """Translate one audio chunk to English with the Whisper large-v3 pipeline.

    Args:
        m: the loaded model cache (needs ``m["whisper_pipe"]``).
        chunk: float32 mono audio at 16 kHz.

    Returns:
        str: the stripped English translation of the chunk.
    """
    pipe = m["whisper_pipe"]
    features = pipe.feature_extractor(
        chunk, sampling_rate=16000, return_tensors="pt"
    ).input_features
    features = features.to("cuda", dtype=torch.float16)
    with torch.no_grad():
        generated = pipe.model.generate(
            features, task="translate", language=None, return_timestamps=False
        )
    decoded = pipe.tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0].strip()
# ── Top-level Gradio handler ───────────────────────────────────────────────
def process_audio(audio_filepath):
    """
    Gradio click handler: normalize the upload to 16 kHz mono WAV, run the
    GPU pipeline, and always clean up the temp file.

    Args:
        audio_filepath: path supplied by gr.Audio(type="filepath"), or None.

    Returns:
        str: markdown SOAP note or a user-facing error message.
    """
    if audio_filepath is None:
        return "Please upload an audio file."
    # mkstemp instead of the deprecated, race-prone tempfile.mktemp: the file
    # is created atomically. We only need the path, so close the fd at once;
    # sf.write / ffmpeg -y both happily overwrite the existing empty file.
    fd, tmp_wav = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        arr, sr = sf.read(audio_filepath, always_2d=False)
        if sr != 16000:
            arr = resample(arr, int(len(arr) * 16000 / sr))
        if arr.ndim > 1:
            # Downmix multichannel audio to mono.
            arr = arr.mean(axis=1)
        arr = arr.astype(np.float32)
        sf.write(tmp_wav, arr, 16000)
    except Exception:
        # soundfile could not decode the container (e.g. MP3/AAC/M4A) —
        # fall back to ffmpeg for the resample + downmix.
        import subprocess
        subprocess.run(
            ["ffmpeg", "-i", audio_filepath, "-ar", "16000", "-ac", "1",
             tmp_wav, "-y", "-loglevel", "quiet"],
            check=True,
        )
    try:
        result = run_pipeline(tmp_wav)
    finally:
        # Best-effort cleanup; never let a missing temp file mask the result.
        try:
            os.remove(tmp_wav)
        except Exception:
            pass
    return result
# ── Gradio UI ──────────────────────────────────────────────────────────────
with gr.Blocks(
    title="NaijaMedModel v5",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
    # Narrow, centered layout; dashed drop zone; hide the Gradio footer.
    css="""
    .gradio-container { max-width: 860px !important; margin: auto; }
    #upload-box { border: 2px dashed #4A90D9; border-radius: 12px; }
    #soap-out { font-size: 15px; line-height: 1.7; }
    footer { display: none !important; }
    """
) as demo:
    # Header and usage instructions.
    gr.Markdown("""
    # 🏥 NaijaMedModel v5
    ### Bilingual Clinical ASR · English + Yoruba → SOAP Note
    Upload a doctor-patient consultation recording. Yoruba and mixed segments are
    automatically detected, translated, and summarised into a structured SOAP note.
    """)
    with gr.Row():
        # type="filepath" hands the handler a path on disk, not raw samples.
        audio_in = gr.Audio(
            type="filepath",
            label="Upload Consultation Audio (MP3 / WAV / AAC / M4A)",
            elem_id="upload-box"
        )
    btn = gr.Button("⚕️ Generate SOAP Note", variant="primary", size="lg")
    # The generated SOAP note renders here as markdown.
    soap_out = gr.Markdown(elem_id="soap-out", value="*Your SOAP note will appear here after processing.*")
    gr.Markdown("""
    ---
    <center><sub>⏱ First request takes ~2 min while models load ·
    Built with pyannote · Whisper · MarianMT · ZeroGPU</sub></center>
    """)
    # Wire the single click handler: one audio input → one markdown output.
    btn.click(fn=process_audio, inputs=audio_in, outputs=soap_out)

if __name__ == "__main__":
    demo.launch()