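"""Gradio demo for Uzbek speech recognition with OvozifyLabs/whisper-small-uz-v1.

Recordings longer than 30 seconds are split into overlapping chunks,
transcribed chunk by chunk, and joined into a single transcript.
"""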
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig
import torch
import torchaudio
import numpy as np
import av
# --- Configuration and Model Loading ---
model_id = "OvozifyLabs/whisper-small-uz-v1"

# Check for GPU and set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on device: {device}")

# Load the processor and model (only runs once at startup)
try:
    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
except Exception as e:
    print(f"Error loading model or processor: {e}")
    processor = None
    model = None
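# Optional (not enabled here): on CUDA, loading the model in half precision
# roughly halves GPU memory use, e.g.
#   WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)
# float32 is kept above for simplicity and CPU compatibility.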
# --- Audio Loading Helper Function ---
def load_audio_file(file_path):
    """
    Loads an audio file (handles M4A, MP3, WAV, etc.) and ensures it is
    resampled to 16000 Hz and converted to mono, which Whisper models require.
    """
    sr_target = 16000
    if not file_path:
        raise FileNotFoundError("Audio file path is empty.")

    try:
        # Try torchaudio's built-in loader first
        audio, sr = torchaudio.load(file_path)

        # Resample if needed
        if sr != sr_target:
            if audio.dtype != torch.float32:
                audio = audio.float()
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sr_target)
            audio = resampler(audio)

        # Convert to mono if necessary
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)
        return audio, sr_target
    except Exception as torchaudio_e:
        # Fall back to PyAV for formats torchaudio cannot open (e.g. M4A)
        try:
            audio_data_list = []
            with av.open(file_path) as container:
                stream = container.streams.audio[0]
                resampler = av.AudioResampler(
                    format='fltp',
                    layout='mono',
                    rate=sr_target
                )
                for frame in container.decode(stream):
                    for resampled_frame in resampler.resample(frame):
                        audio_data_list.append(resampled_frame.to_ndarray()[0])
            if not audio_data_list:
                raise RuntimeError("Could not decode audio frames using PyAV.")
            audio_np = np.concatenate(audio_data_list, axis=0)
            audio = torch.from_numpy(audio_np).unsqueeze(0).float()
            return audio, sr_target
        except Exception as av_e:
            raise RuntimeError(
                f"Failed to load audio file with both torchaudio ({torchaudio_e}) "
                f"and PyAV ({av_e})."
            ) from av_e
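# Quick standalone check (hypothetical filename, for illustration only):
#   wav, sr = load_audio_file("sample.m4a")
#   print(wav.shape, sr)  # expected: torch.Size([1, N]) 16000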
# --- Audio Chunking Function ---
def chunk_audio(audio_tensor, sampling_rate, chunk_length_s=30, overlap_s=5):
    """
    Splits audio into overlapping chunks.

    Args:
        audio_tensor: torch.Tensor of shape (1, num_samples) - mono audio
        sampling_rate: int - sampling rate of the audio
        chunk_length_s: float - length of each chunk in seconds
        overlap_s: float - overlap between consecutive chunks in seconds

    Returns:
        List of audio chunks (torch.Tensors)
    """
    chunk_samples = int(chunk_length_s * sampling_rate)
    overlap_samples = int(overlap_s * sampling_rate)
    stride = chunk_samples - overlap_samples
    audio_length = audio_tensor.shape[1]

    # If the audio is shorter than one chunk, return it as a single chunk
    if audio_length <= chunk_samples:
        return [audio_tensor]

    # Split into chunks with overlap
    chunks = []
    start = 0
    while start < audio_length:
        end = min(start + chunk_samples, audio_length)
        chunks.append(audio_tensor[:, start:end])
        # Stop once the final chunk reaches the end of the audio
        if end >= audio_length:
            break
        start += stride
    return chunks
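# Worked example: at 16 kHz, chunk_length_s=30 and overlap_s=5 give
# chunk_samples=480000, overlap_samples=80000, and stride=400000 (25 s),
# so a 60 s clip yields chunks covering 0-30 s, 25-55 s, and 50-60 s.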
# --- Transcription Function ---
def run_whisper(waveform, sr, language):
    """Runs a single (<= 30 s) waveform through the model and returns its text."""
    inputs = processor(waveform.squeeze().numpy(), sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to(device)
    forced_ids = processor.get_decoder_prompt_ids(language=language, task="transcribe")
    gen_config = GenerationConfig(
        forced_decoder_ids=forced_ids,
        max_length=448
    )
    with torch.no_grad():
        predicted_ids = model.generate(
            input_features,
            generation_config=gen_config
        )
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]


def transcribe_audio(audio_file_path, language):
    """
    Transcribes an audio file using the pre-loaded Whisper model.
    Automatically chunks audio longer than 30 seconds.
    """
    if model is None:
        return "Error: Model was not loaded successfully at startup."
    if audio_file_path is None:
        return "Error: No audio file provided."

    lang_dict = {
        "Uzbek": "uz",
        "Russian": "ru",
        "English": "en"
    }
    language = lang_dict[language]

    try:
        # Load audio using the robust loader
        audio, sr = load_audio_file(audio_file_path)
        duration_s = audio.shape[1] / sr

        # Chunk long audio; Whisper's encoder expects at most 30 s of input
        if duration_s > 30:
            print(f"Audio duration: {duration_s:.2f}s - Chunking into segments...")
            chunks = chunk_audio(audio, sr, chunk_length_s=30, overlap_s=5)

            # Transcribe each chunk and join the results
            transcriptions = []
            for i, chunk in enumerate(chunks):
                print(f"Processing chunk {i+1}/{len(chunks)}...")
                transcriptions.append(run_whisper(chunk, sr, language))
            full_transcription = " ".join(transcriptions)
            return f"[Audio duration: {duration_s:.2f}s - Processed in {len(chunks)} chunks]\n\n{full_transcription}"
        else:
            # Short audio fits in a single pass
            print(f"Audio duration: {duration_s:.2f}s - Processing as single segment...")
            return run_whisper(audio, sr, language)
    except Exception as e:
        return f"An error occurred during transcription: {e}"
# --- Gradio Interface Setup ---
title = "Whisper Small Uz v1: Multilingual audio transcription"
description = """A Gradio demo of the **OvozifyLabs/whisper-small-uz-v1** model for Uzbek ASR.
Upload an audio file (M4A, MP3, and WAV are supported) or record directly."""

language_input = gr.Dropdown(
    label="Select Language",
    choices=["Uzbek", "English", "Russian"],
    value="Uzbek"
)
audio_input = gr.Audio(
    sources=["microphone", "upload"],
    type="filepath",
    label="Input Audio (M4A/MP3/WAV, etc.)"
)
text_output = gr.Textbox(label="Transcription Result", lines=6, max_lines=25)

demo = gr.Interface(
    fn=transcribe_audio,
    inputs=[audio_input, language_input],
    outputs=text_output,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()
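    # When running locally, demo.launch(share=True) would additionally create
    # a temporary public URL through Gradio's tunneling service.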