import os
import torch
import gradio as gr
import librosa
import numpy as np
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# ----------------------------
# Config
# ----------------------------
ASR_MODEL_ID = "openai/whisper-large-v3"
HF_TOKEN = os.getenv("HF_TOKEN")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# ----------------------------
# Load Whisper
# ----------------------------
processor = AutoProcessor.from_pretrained(
    ASR_MODEL_ID,
    token=HF_TOKEN
)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    ASR_MODEL_ID,
    torch_dtype=DTYPE,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    token=HF_TOKEN
).to(DEVICE)
model.eval()
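
# Startup sanity check (illustrative only, not required for the Space to run):
# confirm the checkpoint landed on the expected device and dtype before serving.
print(f"Loaded {ASR_MODEL_ID} on {DEVICE} ({next(model.parameters()).dtype})")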
# ----------------------------
# Audio preprocessing
# ----------------------------
def preprocess_audio(audio):
    if audio is None:
        return None
    sr, speech = audio
    # Stereo → mono
    if speech.ndim > 1:
        speech = np.mean(speech, axis=1)
    # Convert to float32
    speech = speech.astype(np.float32)
    # Normalize volume to unit RMS
    rms = np.sqrt(np.mean(speech ** 2))
    if rms > 0:
        speech = speech / rms
    # Trim leading/trailing silence
    speech, _ = librosa.effects.trim(speech, top_db=25)
    # Force 16 kHz (the sampling rate Whisper expects)
    if sr != 16000:
        speech = librosa.resample(speech, orig_sr=sr, target_sr=16000).astype(np.float32)
    return speech
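
# Optional offline check of the preprocessing path (a minimal sketch; DEBUG_PREPROCESS
# is a hypothetical env var, and the tuple mirrors what gr.Audio(type="numpy") passes in:
# a 1-second 440 Hz tone at 48 kHz should come back resampled to ~16000 samples).
if os.getenv("DEBUG_PREPROCESS"):
    _sr = 48000
    _t = np.linspace(0, 1, _sr, endpoint=False)
    _tone = (0.1 * np.sin(2 * np.pi * 440 * _t)).astype(np.float32)
    _out = preprocess_audio((_sr, _tone))
    print("preprocessed samples:", len(_out))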
# ----------------------------
# Transcription
# ----------------------------
def transcribe_audio(audio):
    speech = preprocess_audio(audio)
    if speech is None or len(speech) < 16000:
        return "Audio too short or unclear. Please speak clearly and try again."
    inputs = processor(
        speech,
        sampling_rate=16000,
        return_tensors="pt"
    )
    # Move tensors to the model's device and cast float features to the model dtype,
    # so float16 weights on GPU do not receive float32 inputs.
    inputs = {
        k: v.to(DEVICE, dtype=DTYPE) if torch.is_floating_point(v) else v.to(DEVICE)
        for k, v in inputs.items()
    }
    MAX_DECODER_TOKENS = 448  # Whisper's decoder context length
    START_TOKENS = 4          # reserved for the forced start-of-transcript tokens
    max_new_tokens = MAX_DECODER_TOKENS - START_TOKENS  # 444
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            task="transcribe",
            language="yo",
            max_new_tokens=max_new_tokens,
            temperature=0.0,
            no_repeat_ngram_size=3
        )
    text = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True
    )[0].strip()
    if len(text.split()) < 2:
        return "Speech unclear. Please repeat slowly in Yoruba."
    return text
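
# Example of exercising transcribe_audio outside Gradio (a sketch; "sample_yo.wav" is a
# hypothetical local file and DEBUG_TRANSCRIBE a hypothetical env var; librosa returns
# (channels, samples) for stereo, so transpose to match Gradio's (samples, channels) layout).
if os.getenv("DEBUG_TRANSCRIBE"):
    wav, native_sr = librosa.load("sample_yo.wav", sr=None, mono=False)
    if wav.ndim > 1:
        wav = wav.T
    print(transcribe_audio((native_sr, wav)))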
# ----------------------------
# Gradio UI
# ----------------------------
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(
        sources=["microphone", "upload"],
        type="numpy",
        label="Speak clearly or upload audio in Yoruba"
    ),
    outputs=gr.Textbox(label="Transcription"),
    title="Yoruba ASR (Whisper)",
    description="Speech-to-text system that transcribes only Yoruba"
)
if __name__ == "__main__":
    demo.launch(share=True)