# Hugging Face Space: Iban & Bukar Sadong ASR demo.
# NOTE(review): the Space page showed "Runtime error" when this source was
# captured; the page banner text has been replaced by this comment header.
| import os, io, requests | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq | |
# ---------------- CONFIG ----------------
# Catalog of selectable ASR models: UI label -> Hub repo + language tag.
# "language" is passed to get_decoder_prompt_ids() when language forcing is on.
MODEL_CATALOG = {
    "Iban (ASR)": {
        "repo_id": "mds04/iban_transcription_model",  # exact model repo ID
        "language": "iban",
    },
    "Bukar Sadong (ASR)": {
        "repo_id": "mds04/bukar_sadong_transcription",  # exact model repo ID
        "language": "bukar-sadong",
    },
}

DEFAULT_MODEL = "Iban (ASR)"       # initial dropdown selection
DEFAULT_FORCE_LANG = True          # initial state of the language-forcing checkbox
DEFAULT_MAX_TOKENS = 256           # initial slider value for max_new_tokens
# ----------------------------------------

# Read token from the environment; needed only if the model repos are private.
HF_TOKEN = os.getenv("HF_TOKEN", None)

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32  # use float32 for Whisper stability on GPU Zero

# Lazy cache: model_key -> (processor, model). Models are loaded on first use
# only, so an unused language never costs memory.
_MODEL_CACHE: dict[str, tuple[AutoProcessor, AutoModelForSpeechSeq2Seq]] = {}
def _load_bundle(model_key: str):
    """Return the (processor, model) pair for *model_key*, loading on first use.

    Results are memoized in ``_MODEL_CACHE`` so each model is downloaded and
    moved to the device at most once per process.
    """
    cached = _MODEL_CACHE.get(model_key)
    if cached is not None:
        return cached

    entry = MODEL_CATALOG[model_key]
    processor = AutoProcessor.from_pretrained(entry["repo_id"], token=HF_TOKEN)
    model = (
        AutoModelForSpeechSeq2Seq.from_pretrained(
            entry["repo_id"], token=HF_TOKEN, torch_dtype=dtype
        )
        .to(device)
        .eval()
    )
    _MODEL_CACHE[model_key] = (processor, model)
    print(f"Loaded model: {model_key}")
    return _MODEL_CACHE[model_key]
| def _resample_to_16k(x: np.ndarray, sr: int) -> np.ndarray: | |
| """Lightweight linear resampler to 16 kHz.""" | |
| if sr == 16000: | |
| return x.astype(np.float32) | |
| duration = x.shape[0] / sr | |
| t_old = np.linspace(0, duration, num=x.shape[0], endpoint=False) | |
| t_new = np.linspace(0, duration, num=int(duration * 16000), endpoint=False) | |
| return np.interp(t_new, t_old, x).astype(np.float32) | |
| def _read_audio_bytes(path_or_url: str) -> bytes: | |
| """Accept local path or remote URL.""" | |
| if path_or_url.startswith("http://") or path_or_url.startswith("https://"): | |
| r = requests.get(path_or_url, timeout=30) | |
| r.raise_for_status() | |
| return r.content | |
| with open(path_or_url, "rb") as f: | |
| return f.read() | |
def _load_audio_16k(input_obj) -> np.ndarray:
    """Coerce a Gradio value, local path, or URL into mono 16 kHz float32 audio.

    *input_obj* may be a plain string (path or URL) or a dict carrying a
    ``"path"`` key (one Gradio audio payload shape).
    """
    if isinstance(input_obj, str):
        source = input_obj
    elif isinstance(input_obj, dict) and "path" in input_obj:
        source = input_obj["path"]
    else:
        raise ValueError("Unsupported audio input format")

    payload = _read_audio_bytes(source)
    samples, rate = sf.read(io.BytesIO(payload))
    if samples.ndim == 2:
        # Down-mix stereo to mono by averaging the channels.
        samples = samples.mean(axis=1)
    return _resample_to_16k(samples, rate)
# ----------- GPU ZERO HANDLER -----------
def transcribe(model_choice, audio_input, force_lang, max_tokens):
    """Run ASR on *audio_input* with the selected model and return the text.

    Args:
        model_choice: key into MODEL_CATALOG.
        audio_input: Gradio audio value (filepath string or dict).
        force_lang: when True, force the catalog language via decoder prompt ids
            (Whisper-style processors only; silently skipped otherwise).
        max_tokens: cap on newly generated tokens.
    """
    if not audio_input:
        return "Please upload or record audio."

    proc, mdl = _load_bundle(model_choice)
    waveform = _load_audio_16k(audio_input)

    features = proc(audio=waveform, sampling_rate=16000, return_tensors="pt")
    features = {name: tensor.to(device) for name, tensor in features.items()}

    generate_opts = {"max_new_tokens": int(max_tokens), "do_sample": False}
    if force_lang and hasattr(proc, "get_decoder_prompt_ids"):
        # Only Whisper-style processors expose decoder prompt ids; a failure
        # here (e.g. unknown language tag) degrades to unforced decoding.
        try:
            generate_opts["forced_decoder_ids"] = proc.get_decoder_prompt_ids(
                language=MODEL_CATALOG[model_choice]["language"],
                task="transcribe",
            )
        except Exception as e:
            print("Language forcing skipped:", e)

    with torch.no_grad():
        token_ids = mdl.generate(**features, **generate_opts)
    return proc.batch_decode(token_ids, skip_special_tokens=True)[0]
# ---------------------------------------

# ---------- BUILD GRADIO UI ------------
# NOTE: component creation order inside the Blocks context determines the
# rendered layout, so the statements below are intentionally order-sensitive.
with gr.Blocks(title="Iban & Bukar Sadong ASR") as demo:
    gr.Markdown("## Iban & Bukar Sadong Transcription\nUpload or record audio, choose model, and transcribe.")
    # Model selector backed by the MODEL_CATALOG keys.
    model_choice = gr.Dropdown(
        choices=list(MODEL_CATALOG.keys()),
        value=DEFAULT_MODEL,
        label="Model"
    )
    # type="filepath" makes Gradio hand transcribe() a path string,
    # which _load_audio_16k() accepts directly.
    audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
    with gr.Row():
        force_lang = gr.Checkbox(
            value=DEFAULT_FORCE_LANG,
            label="Force model’s language prompt"
        )
        max_tokens = gr.Slider(
            64, 512, value=DEFAULT_MAX_TOKENS, step=16,
            label="Max new tokens"
        )
    out = gr.Textbox(label="Transcription", lines=4)
    btn = gr.Button("Transcribe")
    # Wire the button to transcribe(); input order matches its signature
    # (model_choice, audio_input, force_lang, max_tokens).
    btn.click(
        fn=transcribe,
        inputs=[model_choice, audio_in, force_lang, max_tokens],
        outputs=out
    )

# Enable request queuing so long-running inference calls don't time out.
demo.queue()

if __name__ == "__main__":
    demo.launch()