|
|
import os |
|
|
import torch |
|
|
import gradio as gr |
|
|
import librosa |
|
|
import numpy as np |
|
|
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face checkpoint ID for the Whisper large-v3 ASR model.
ASR_MODEL_ID = "openai/whisper-large-v3"

# Optional Hugging Face access token (None when the env var is unset);
# forwarded to from_pretrained for gated/private downloads.
HF_TOKEN = os.getenv("HF_TOKEN")

# Prefer GPU when available; fall back to CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# fp16 halves memory and speeds up inference on GPU; CPU inference needs fp32.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the Whisper feature extractor + tokenizer pair.
# FIX: `use_auth_token` is deprecated in recent transformers releases
# (removed in v5); `token` is the supported replacement with the same value.
processor = AutoProcessor.from_pretrained(
    ASR_MODEL_ID,
    token=HF_TOKEN,
)

# Load the seq2seq ASR model in the chosen dtype and move it to DEVICE.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    ASR_MODEL_ID,
    torch_dtype=DTYPE,
    low_cpu_mem_usage=True,   # avoid materialising a full fp32 copy while loading
    use_safetensors=True,
    token=HF_TOKEN,
).to(DEVICE)

# Inference only: disable dropout and other train-mode behaviour.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_audio(audio):
    """Convert a Gradio audio tuple into 16 kHz mono float32 speech.

    Parameters
    ----------
    audio : tuple[int, np.ndarray] | None
        ``(sample_rate, samples)`` as produced by ``gr.Audio(type="numpy")``,
        or ``None`` when no audio was provided.

    Returns
    -------
    np.ndarray | None
        RMS-normalised, silence-trimmed mono waveform resampled to 16 kHz,
        or ``None`` when the input is missing or empty.
    """
    if audio is None:
        return None

    sr, speech = audio

    # FIX: an instant record/stop can yield a zero-length buffer; without this
    # guard librosa.effects.trim raises instead of failing gracefully like the
    # other "no usable audio" paths.
    if speech.size == 0:
        return None

    # Down-mix multi-channel audio — assumes (samples, channels) layout as
    # delivered by gr.Audio(type="numpy").
    if speech.ndim > 1:
        speech = np.mean(speech, axis=1)

    speech = speech.astype(np.float32)

    # Normalise loudness so quiet recordings are not under-weighted downstream.
    rms = np.sqrt(np.mean(speech ** 2))
    if rms > 0:
        speech = speech / rms

    # Strip leading/trailing silence (anything 25 dB below the peak).
    speech, _ = librosa.effects.trim(speech, top_db=25)

    # Whisper expects 16 kHz input.
    if sr != 16000:
        speech = librosa.resample(speech, orig_sr=sr, target_sr=16000).astype(np.float32)

    return speech
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transcribe_audio(audio):
    """Transcribe Yoruba speech to text with Whisper.

    Parameters
    ----------
    audio : tuple[int, np.ndarray] | None
        Raw ``(sample_rate, samples)`` pair from the Gradio audio component.

    Returns
    -------
    str
        The transcription, or a user-facing message when the audio is too
        short or the decoded text looks unusable.
    """
    speech = preprocess_audio(audio)

    # Require at least ~1 s of 16 kHz audio before attempting decoding.
    if speech is None or len(speech) < 16000:
        return "Audio too short or unclear. Please speak clearly and try again."

    inputs = processor(
        speech,
        sampling_rate=16000,
        return_tensors="pt"
    )
    # FIX: cast floating-point features to the model dtype as well as moving
    # them to the device — the fp16 model on CUDA rejects the processor's
    # fp32 features ("expected scalar type Half") without this cast.
    # Integer tensors (e.g. attention masks) are moved but not cast.
    inputs = {
        k: v.to(DEVICE, dtype=DTYPE) if v.is_floating_point() else v.to(DEVICE)
        for k, v in inputs.items()
    }

    # Whisper's decoder context is capped at 448 tokens; leave headroom for
    # the forced start-of-transcript / language / task tokens.
    MAX_DECODER_TOKENS = 448
    START_TOKENS = 4
    max_new_tokens = MAX_DECODER_TOKENS - START_TOKENS

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            task="transcribe",
            language="yo",              # force Yoruba; skip auto language detection
            max_new_tokens=max_new_tokens,
            temperature=0.0,            # greedy, deterministic decoding
            no_repeat_ngram_size=3      # suppress Whisper's repetition loops
        )

    text = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True
    )[0].strip()

    # Heuristic: one-word outputs are usually hallucinated noise.
    if len(text.split()) < 2:
        return "Speech unclear. Please repeat slowly in Yoruba."

    return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: a single audio input (mic or file upload, delivered as a
# (sample_rate, np.ndarray) tuple because type="numpy") wired to
# transcribe_audio, with the transcription shown in a textbox.
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(
        sources=["microphone", "upload"],
        type="numpy",
        label="Speak clearly or upload audio in Yoruba"
    ),
    outputs=gr.Textbox(label="Transcription"),
    title="Yoruba ASR (Whisper)",
    description="Speech-to-text system that transcribes only Yoruba"
)
|
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): share=True publishes a temporary public gradio.live URL —
    # anyone with the link can reach this app; confirm that is intended.
    demo.launch(share=True)
|
|
|