# mds04's picture
# Update app.py
# a7a6eef verified
import os, io, requests
import numpy as np
import soundfile as sf
import torch
import gradio as gr
import spaces
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# ---------------- CONFIG ----------------
# Maps each UI dropdown label to its Hugging Face model repo and the
# language tag that can be forced on the decoder at generation time.
MODEL_CATALOG = {
    "Iban (ASR)": {
        "repo_id": "mds04/iban_transcription_model",  # ← exact model repo ID
        "language": "iban",
    },
    "Bukar Sadong (ASR)": {
        "repo_id": "mds04/bukar_sadong_transcription",  # ← exact model repo ID
        "language": "bukar-sadong",
    },
}
DEFAULT_MODEL = "Iban (ASR)"   # initial dropdown selection
DEFAULT_FORCE_LANG = True      # initial state of the "force language" checkbox
DEFAULT_MAX_TOKENS = 256       # initial generation budget shown on the slider
# ----------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN", None)  # add secret if repos are private
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32  # use float32 for Whisper stability on GPU Zero
# lazy cache so both models aren't loaded at once
_MODEL_CACHE: dict[str, tuple[AutoProcessor, AutoModelForSpeechSeq2Seq]] = {}
def _load_bundle(model_key: str):
    """Return the (processor, model) pair for *model_key*, loading it on first use.

    Results are memoized in ``_MODEL_CACHE`` so each checkpoint is loaded at
    most once per process.
    """
    cached = _MODEL_CACHE.get(model_key)
    if cached is not None:
        return cached

    entry = MODEL_CATALOG[model_key]
    processor = AutoProcessor.from_pretrained(entry["repo_id"], token=HF_TOKEN)
    model = (
        AutoModelForSpeechSeq2Seq.from_pretrained(
            entry["repo_id"], token=HF_TOKEN, torch_dtype=dtype
        )
        .to(device)
        .eval()
    )

    bundle = (processor, model)
    _MODEL_CACHE[model_key] = bundle
    print(f"Loaded model: {model_key}")
    return bundle
def _resample_to_16k(x: np.ndarray, sr: int) -> np.ndarray:
"""Lightweight linear resampler to 16 kHz."""
if sr == 16000:
return x.astype(np.float32)
duration = x.shape[0] / sr
t_old = np.linspace(0, duration, num=x.shape[0], endpoint=False)
t_new = np.linspace(0, duration, num=int(duration * 16000), endpoint=False)
return np.interp(t_new, t_old, x).astype(np.float32)
def _read_audio_bytes(path_or_url: str) -> bytes:
"""Accept local path or remote URL."""
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
r = requests.get(path_or_url, timeout=30)
r.raise_for_status()
return r.content
with open(path_or_url, "rb") as f:
return f.read()
def _load_audio_16k(input_obj) -> np.ndarray:
    """Normalize any Gradio/URL/local input to mono 16 kHz float32 array."""
    # Gradio may hand us a plain path string or a {"path": ...} payload.
    if isinstance(input_obj, str):
        source = input_obj
    elif isinstance(input_obj, dict) and "path" in input_obj:
        source = input_obj["path"]
    else:
        raise ValueError("Unsupported audio input format")

    payload = _read_audio_bytes(source)
    samples, rate = sf.read(io.BytesIO(payload))
    # Downmix multichannel audio to mono by averaging the channels.
    if samples.ndim == 2:
        samples = samples.mean(axis=1)
    return _resample_to_16k(samples, rate)
# ----------- GPU ZERO HANDLER -----------
@spaces.GPU
def transcribe(model_choice, audio_input, force_lang, max_tokens):
    """Main ASR inference function."""
    if not audio_input:
        return "Please upload or record audio."

    processor, model = _load_bundle(model_choice)
    waveform = _load_audio_16k(audio_input)

    batch = processor(audio=waveform, sampling_rate=16000, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    gen_kwargs = {"max_new_tokens": int(max_tokens), "do_sample": False}

    # Force language if Whisper-style processor supports it
    if force_lang and hasattr(processor, "get_decoder_prompt_ids"):
        try:
            target = MODEL_CATALOG[model_choice]["language"]
            prompt_ids = processor.get_decoder_prompt_ids(
                language=target, task="transcribe"
            )
            gen_kwargs["forced_decoder_ids"] = prompt_ids
        except Exception as e:
            print("Language forcing skipped:", e)

    with torch.no_grad():
        token_ids = model.generate(**batch, **gen_kwargs)
    return processor.batch_decode(token_ids, skip_special_tokens=True)[0]
# ---------------------------------------
# ---------- BUILD GRADIO UI ------------
with gr.Blocks(title="Iban & Bukar Sadong ASR") as demo:
    gr.Markdown("## Iban & Bukar Sadong Transcription\nUpload or record audio, choose model, and transcribe.")
    # Which ASR checkpoint to run (keys of MODEL_CATALOG).
    model_choice = gr.Dropdown(
        choices=list(MODEL_CATALOG.keys()),
        value=DEFAULT_MODEL,
        label="Model"
    )
    # Accepts microphone recordings or uploaded files; delivered as a filepath.
    audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
    with gr.Row():
        # Toggles the forced_decoder_ids language prompt in transcribe().
        force_lang = gr.Checkbox(
            value=DEFAULT_FORCE_LANG,
            label="Force model’s language prompt"
        )
        # Generation budget passed through to model.generate(max_new_tokens=...).
        max_tokens = gr.Slider(
            64, 512, value=DEFAULT_MAX_TOKENS, step=16,
            label="Max new tokens"
        )
    out = gr.Textbox(label="Transcription", lines=4)
    btn = gr.Button("Transcribe")
    btn.click(
        fn=transcribe,
        inputs=[model_choice, audio_in, force_lang, max_tokens],
        outputs=out
    )
# Queueing is required for @spaces.GPU handlers on Spaces.
demo.queue()
if __name__ == "__main__":
    demo.launch()