# Naija_MedModel / app(stable2).py — renamed from app.py (commit cf3f184)
"""
NaijaMedModel v5 β€” ZeroGPU (H200) Gradio Space
Lazy model loading: nothing heavy lives in CPU RAM at startup.
Models are loaded directly onto GPU when the first request arrives,
then kept in GPU RAM for subsequent requests (cached in module-level dicts).
"""
import sys
import torch

# Patch torch.load for pyannote compatibility: recent torch versions default
# to weights_only=True, which rejects pyannote's pickled checkpoints.
# Supply weights_only=False only when the caller did not choose explicitly,
# so an explicit weights_only=True still takes effect.
_orig_load = torch.load

def _patched_load(f, *args, **kwargs):
    """torch.load wrapper that defaults ``weights_only`` to False."""
    kwargs.setdefault("weights_only", False)
    return _orig_load(f, *args, **kwargs)

torch.load = _patched_load
# Map the deprecated `use_auth_token` kwarg onto `token` before
# huggingface_hub's deprecation shim processes it (pyannote still passes
# use_auth_token). Best-effort: this touches a private hub API.
try:
    from huggingface_hub.utils import _validators as _hf_v

    _orig_smooth = _hf_v.smoothly_deprecate_legacy_arguments

    def _patched_smooth(fn_name, kwargs):
        """Rewrite use_auth_token -> token, then defer to the original shim."""
        legacy = kwargs.pop("use_auth_token", None)
        if legacy is not None and "token" not in kwargs:
            kwargs["token"] = legacy
        return _orig_smooth(fn_name, kwargs)

    _hf_v.smoothly_deprecate_legacy_arguments = _patched_smooth
except Exception:
    # Private API: tolerate any hub version where it moved or disappeared.
    pass
# ── Standard imports ───────────────────────────────────────────────────────
import os, tempfile
import numpy as np
import soundfile as sf
from scipy.signal import resample
import gradio as gr
import spaces # must come after torch
# ── Config ─────────────────────────────────────────────────────────────────
# Hugging Face model IDs for every stage of the pipeline.
ASR_EN_MODEL = "Ephraimmm/asrfinetuned"          # fine-tuned English ASR
ASR_YO_MODEL = "NCAIR1/Yoruba-ASR"               # Yoruba ASR (speech seq2seq)
WHISPER_MODEL = "openai/whisper-large-v3"        # fallback translate-to-English ASR
TRANSLATE_MODEL = "Helsinki-NLP/opus-mt-yo-en"   # Yoruba -> English text MT
DIAR_MODEL = "pyannote/speaker-diarization-3.1"  # speaker diarization
SOAP_MODEL = "Edifon/SOAP_SFT_V1"                # SOAP-note generator (causal LM)
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # empty string when env var is unset
# ── Lazy model cache — populated on first GPU call ─────────────────────────
# Nothing is loaded here. Imports of heavy libraries are also deferred.
_models = {}  # role name -> loaded model/tokenizer/pipeline; filled by _get_models()
def _get_models():
    """
    Load all models directly onto the GPU the first time this is called and
    cache them in the module-level ``_models`` dict for later requests.

    Must only be called from inside a @spaces.GPU function where libcudart
    and torchaudio CUDA extensions are available.

    Returns:
        dict: models keyed by role — "asr_en", "yo_processor", "yo_model",
        "whisper_pipe", "mt_tokenizer", "mt_model", "diar", "lang_model",
        "soap_processor", "soap_model", "whisper_lib". "mt_tokenizer" and
        "mt_model" are None when the MT model failed to load.
    """
    if _models:
        return _models
    # Heavy imports are deferred so nothing large is resident at startup.
    # Bug fix: AutoModelForCausalLM was previously imported inside a
    # try/except ImportError that silently passed, which would only surface
    # later as a confusing NameError when the SOAP model loads. It is a
    # standard transformers export, so import it unconditionally.
    from transformers import (
        pipeline as hf_pipeline,
        AutoProcessor,
        AutoModelForCausalLM,
        AutoModelForSpeechSeq2Seq,
        MarianMTModel,
        MarianTokenizer,
        AutoTokenizer,
    )
    from pyannote.audio import Pipeline as DiarizationPipeline
    import whisper as _whisper

    print("⏳ Loading English ASR → GPU…")
    _models["asr_en"] = hf_pipeline(
        "automatic-speech-recognition",
        model=ASR_EN_MODEL,
        device="cuda",
        token=HF_TOKEN,
    )

    print("⏳ Loading Yoruba ASR → GPU…")
    _models["yo_processor"] = AutoProcessor.from_pretrained(ASR_YO_MODEL)
    _models["yo_model"] = AutoModelForSpeechSeq2Seq.from_pretrained(
        ASR_YO_MODEL, torch_dtype=torch.float16
    ).to("cuda")

    print("⏳ Loading Whisper large-v3 → GPU…")
    _models["whisper_pipe"] = hf_pipeline(
        "automatic-speech-recognition",
        model=WHISPER_MODEL,
        torch_dtype=torch.float16,
        device="cuda",
        generate_kwargs={"task": "translate", "language": None},
    )

    print("⏳ Loading Yoruba→English MT → GPU…")
    try:
        _models["mt_tokenizer"] = MarianTokenizer.from_pretrained(TRANSLATE_MODEL)
        _models["mt_model"] = MarianMTModel.from_pretrained(
            TRANSLATE_MODEL, torch_dtype=torch.float16
        ).to("cuda")
    except Exception as e:
        # MT is optional — run_pipeline falls back to Whisper translation.
        print(f"⚠️ MT skipped: {e}")
        _models["mt_tokenizer"] = None
        _models["mt_model"] = None

    print("⏳ Loading diarization → GPU…")
    # use_auth_token is rewritten to token by the hub patch at module top.
    diar = DiarizationPipeline.from_pretrained(DIAR_MODEL, use_auth_token=HF_TOKEN)
    _models["diar"] = diar.to(torch.device("cuda"))

    print("⏳ Loading Whisper tiny lang detector → GPU…")
    _models["lang_model"] = _whisper.load_model("tiny", device="cuda")

    print("⏳ Loading SOAP model → GPU (bfloat16)…")
    _models["soap_processor"] = AutoTokenizer.from_pretrained(SOAP_MODEL, token=HF_TOKEN)
    _models["soap_model"] = AutoModelForCausalLM.from_pretrained(
        SOAP_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        token=HF_TOKEN,
    )
    # Keep the whisper module itself handy for pad_or_trim / mel helpers.
    _models["whisper_lib"] = _whisper
    print("✅ All models loaded on GPU.")
    return _models
# ── Helper ─────────────────────────────────────────────────────────────────
def _get_chunk(arr, start, end, sr=16000):
return arr[int(start * sr): int(end * sr)]
# ── Single GPU function that does everything ───────────────────────────────
# One @spaces.GPU call per user request avoids the overhead of acquiring
# the GPU slot multiple times and keeps all models in VRAM together.
@spaces.GPU(duration=300)  # allow up to 5 min for long consultations
def run_pipeline(wav_path: str):
    """
    Full consultation pipeline: diarize → per-segment language detection →
    transcription (Yoruba and other languages translated to English) →
    SOAP-note generation.

    Args:
        wav_path: path to a 16 kHz mono WAV file on disk.

    Returns:
        str: markdown-formatted SOAP note, or a short status message when no
        speakers / no speech were found. Always a single string — the Gradio
        UI binds exactly one output component (previously the early exits
        returned 2-tuples, which the single-output UI cannot render).
    """
    m = _get_models()
    _whisper = m["whisper_lib"]

    # ── Diarize ────────────────────────────────────────────────────────
    diar_result = m["diar"](wav_path)
    segs = [
        {"start": t.start, "end": t.end, "speaker": spk}
        for t, _, spk in diar_result.itertracks(yield_label=True)
    ]
    if not segs:
        return "No speakers detected."
    arr, _ = sf.read(wav_path)
    arr = arr.astype(np.float32)

    # ── Transcribe each segment ────────────────────────────────────────
    transcript = []
    for seg in segs:
        chunk = _get_chunk(arr, seg["start"], seg["end"])
        if len(chunk) < 16000 * 0.5:  # skip segments shorter than 0.5 s
            continue
        # Per-segment language detection with Whisper tiny.
        padded = _whisper.pad_or_trim(chunk)
        mel = _whisper.log_mel_spectrogram(padded).to("cuda")
        _, probs = m["lang_model"].detect_language(mel)
        lang = max(probs, key=probs.get)
        if lang == "yo":
            # Yoruba ASR, then MarianMT translation to English.
            inputs = m["yo_processor"](
                chunk, sampling_rate=16000, return_tensors="pt"
            )
            # fp16 model: cast float32 features, move everything to GPU.
            inputs = {
                k: v.to("cuda", dtype=torch.float16)
                if v.dtype == torch.float32 else v.to("cuda")
                for k, v in inputs.items()
            }
            with torch.no_grad():
                ids = m["yo_model"].generate(**inputs)
            yo_text = m["yo_processor"].batch_decode(
                ids, skip_special_tokens=True
            )[0].strip()
            if m["mt_model"]:
                tokens = m["mt_tokenizer"]([yo_text], return_tensors="pt", padding=True)
                tokens = {k: v.to("cuda") for k, v in tokens.items()}
                with torch.no_grad():
                    out = m["mt_model"].generate(**tokens)
                en_text = m["mt_tokenizer"].batch_decode(
                    out, skip_special_tokens=True
                )[0].strip()
            else:
                # MT unavailable — fall back to Whisper's built-in translation.
                en_text = _run_whisper_translate(m, chunk)
            transcript.append({**seg,
                               "language": "yo", "original_text": yo_text,
                               "text": en_text, "translated": True})
        elif lang != "en":
            # Any other detected language: Whisper translate → English.
            en_text = _run_whisper_translate(m, chunk)
            transcript.append({**seg,
                               "language": lang, "original_text": en_text,
                               "text": en_text, "translated": True})
        else:
            # English: fine-tuned English ASR model.
            fe = m["asr_en"].feature_extractor
            inputs = fe(chunk, sampling_rate=16000, return_tensors="pt",
                        truncation=True, return_attention_mask=True)
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
            with torch.no_grad():
                ids = m["asr_en"].model.generate(
                    inputs["input_features"],
                    attention_mask=inputs["attention_mask"],
                    generation_config=m["asr_en"].model.generation_config,
                    language="english",
                )
            en_text = m["asr_en"].tokenizer.batch_decode(
                ids, skip_special_tokens=True
            )[0].strip()
            transcript.append({**seg,
                               "language": "en", "original_text": en_text,
                               "text": en_text, "translated": False})
    if not transcript:
        return "No speech transcribed."

    # ── Format transcript ──────────────────────────────────────────────
    # First two speakers are labelled Doctor/Patient by order of appearance;
    # any further speakers get generic labels.
    sp_ids = sorted({s["speaker"] for s in transcript})
    sp_map = {
        spk: (["Doctor", "Patient"][i] if i < 2 else f"Speaker {i+1}")
        for i, spk in enumerate(sp_ids)
    }
    lines, prev = [], None
    for seg in transcript:
        label = sp_map[seg["speaker"]]
        flag = " ⚡" if seg["translated"] else ""
        if label != prev:
            lines.append(f"\n{label}:\n {seg['text']}{flag}")
        else:
            lines.append(f" {seg['text']}{flag}")
        prev = label
    # NOTE: transcript_text is currently not returned to the UI; it is kept
    # for debugging / a future transcript output component.
    transcript_text = (
        "\n".join(lines)
        + "\n\n⚡ = translated segment (clinician review recommended)"
    )

    # ── SOAP note ──────────────────────────────────────────────────────
    flat = "\n".join(
        f"{sp_map[s['speaker']]}: {s['text']}" for s in transcript
    )
    tok = m["soap_processor"]  # AutoTokenizer for the SOAP causal LM
    msgs = [
        {"role": "system", "content": (
            "You are an expert medical professor assisting in the creation of "
            "medically accurate SOAP summaries. Please ensure the response follows "
            "the structured format: S:, O:, A:, P: without using markdown or special formatting."
        )},
        {"role": "user", "content":
            f"Create a medical SOAP summary of this dialogue.\n### Dialogue:\n{flat}\n"
         },
    ]
    encoded = tok.apply_chat_template(
        msgs, add_generation_prompt=True, return_tensors="pt"
    )
    # Force to a plain tensor regardless of what apply_chat_template returns
    # (BatchEncoding, dict, or a raw id list depending on tokenizer version).
    if hasattr(encoded, "input_ids"):
        input_ids = encoded.input_ids.to("cuda")
    elif isinstance(encoded, dict):
        input_ids = encoded["input_ids"].to("cuda")
    else:
        input_ids = torch.tensor(encoded).unsqueeze(0).to("cuda")
    out_ids = m["soap_model"].generate(
        input_ids, max_new_tokens=2048, do_sample=False
    )
    # Decode only the newly generated tokens (strip the prompt prefix).
    soap = tok.decode(out_ids[0][input_ids.shape[1]:], skip_special_tokens=True)

    # Format the plain S:/O:/A:/P: sections as markdown headings.
    import re

    def to_markdown(text):
        sections = {"S:": "## 🟦 Subjective", "O:": "## 🟩 Objective",
                    "A:": "## 🟧 Assessment", "P:": "## 🟥 Plan"}
        for tag, heading in sections.items():
            text = re.sub(rf"(?i){re.escape(tag)}", f"\n\n{heading}\n", text)
        return text.strip()

    return to_markdown(soap)
def _run_whisper_translate(m, chunk):
    """Translate one 16 kHz audio chunk to English text with Whisper large-v3."""
    pipe = m["whisper_pipe"]
    features = pipe.feature_extractor(
        chunk, sampling_rate=16000, return_tensors="pt"
    ).input_features
    features = features.to("cuda", dtype=torch.float16)
    with torch.no_grad():
        generated = pipe.model.generate(
            features, task="translate", language=None, return_timestamps=False
        )
    texts = pipe.tokenizer.batch_decode(generated, skip_special_tokens=True)
    return texts[0].strip()
# ── Top-level Gradio handler ───────────────────────────────────────────────
def process_audio(audio_filepath):
if audio_filepath is None:
return "Please upload an audio file."
tmp_wav = tempfile.mktemp(suffix=".wav")
try:
arr, sr = sf.read(audio_filepath, always_2d=False)
if sr != 16000:
arr = resample(arr, int(len(arr) * 16000 / sr))
if arr.ndim > 1:
arr = arr.mean(axis=1)
arr = arr.astype(np.float32)
sf.write(tmp_wav, arr, 16000)
except Exception:
import subprocess
subprocess.run(
["ffmpeg","-i",audio_filepath,"-ar","16000","-ac","1",
tmp_wav,"-y","-loglevel","quiet"],
check=True,
)
try:
result = run_pipeline(tmp_wav)
finally:
try:
os.remove(tmp_wav)
except Exception:
pass
return result
# ── Gradio UI ──────────────────────────────────────────────────────────────
# Build the Gradio app: one audio input, one button, one markdown output.
with gr.Blocks(
    title="NaijaMedModel v5",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
    # Inline CSS: center the app, style the upload box, hide the footer.
    css="""
.gradio-container { max-width: 860px !important; margin: auto; }
#upload-box { border: 2px dashed #4A90D9; border-radius: 12px; }
#soap-out { font-size: 15px; line-height: 1.7; }
footer { display: none !important; }
"""
) as demo:
    # Header / intro text.
    gr.Markdown("""
# πŸ₯ NaijaMedModel v5
### Bilingual Clinical ASR Β· English + Yoruba β†’ SOAP Note
Upload a doctor-patient consultation recording. Yoruba and mixed segments are
automatically detected, translated, and summarised into a structured SOAP note.
""")
    with gr.Row():
        # type="filepath" so the handler receives an on-disk path.
        audio_in = gr.Audio(
            type="filepath",
            label="Upload Consultation Audio (MP3 / WAV / AAC / M4A)",
            elem_id="upload-box"
        )
    btn = gr.Button("βš•οΈ Generate SOAP Note", variant="primary", size="lg")
    # Markdown pane that receives the SOAP note from process_audio.
    soap_out = gr.Markdown(elem_id="soap-out", value="*Your SOAP note will appear here after processing.*")
    gr.Markdown("""
---
<center><sub>⏱ First request takes ~2 min while models load ·
Built with pyannote Β· Whisper Β· MarianMT Β· ZeroGPU</sub></center>
""")
    # Wire the button to the top-level handler (single output component).
    btn.click(fn=process_audio, inputs=audio_in, outputs=soap_out)

# Script entry point (Spaces also launches via this module).
if __name__ == "__main__":
    demo.launch()