# Source: Hugging Face Space file view (app.py)
# Author: Ergashbek2004 — "Update app.py", commit 93db98b (verified)
# Page chrome (raw / history / blame, 7.33 kB) removed and preserved here
# as a comment so the file remains valid Python.
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import torchaudio
import numpy as np
import av # Ensure you have installed this: pip install av
# --- Configuration and Model Loading ---
model_id = "OvozifyLabs/whisper-small-uz-v1"

# Prefer the GPU when one is visible to PyTorch; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on device: {device}")

# Load the processor and model exactly once at process startup. If loading
# fails (network, disk, bad checkpoint), keep both as None so the UI handler
# can report the problem instead of the whole app crashing on import.
try:
    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
except Exception as e:
    print(f"Error loading model or processor: {e}")
    processor = None
    model = None
# --- Audio Loading Helper Function ---
def load_audio_file(file_path):
    """Load an audio file and return a ``(waveform, sample_rate)`` pair.

    Tries torchaudio's native loader first (WAV, FLAC, ...) and falls back
    to PyAV (the FFmpeg bindings) for container formats such as M4A or MP3.
    The result is mono, resampled to 16000 Hz — the rate Whisper models
    require — with shape ``(1, num_samples)``.

    Args:
        file_path: Path to the audio file. Must be non-empty.

    Returns:
        Tuple of (audio tensor of shape ``(1, N)``, sample rate, always 16000).

    Raises:
        FileNotFoundError: If ``file_path`` is empty or None.
        RuntimeError: If neither torchaudio nor PyAV can decode the file.
    """
    sr_target = 16000  # Target sampling rate for the Whisper model

    if not file_path:
        raise FileNotFoundError("Audio file path is empty.")

    # 1. Try torchaudio's built-in loader first (usually handles WAV, FLAC well)
    try:
        audio, sr = torchaudio.load(file_path)
    except Exception:
        # 2. Fallback to PyAV (FFmpeg wrapper) for formats like M4A, MP3
        return _load_audio_with_pyav(file_path, sr_target)

    # Resample if needed; the Resample transform requires float input.
    if sr != sr_target:
        if audio.dtype != torch.float32:
            audio = audio.float()
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sr_target)
        audio = resampler(audio)

    # Convert to mono if necessary (take the mean across channels)
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    return audio, sr_target


def _load_audio_with_pyav(file_path, sr_target):
    """Decode `file_path` via PyAV into a (1, N) float32 mono tensor at `sr_target` Hz."""
    try:
        frames = []
        with av.open(file_path) as container:
            stream = container.streams.audio[0]
            # Resampler forces 32-bit float, mono, 16 kHz output frames.
            resampler = av.AudioResampler(
                format='fltp',
                layout='mono',
                rate=sr_target,
            )
            for frame in container.decode(stream):
                for resampled_frame in resampler.resample(frame):
                    # to_ndarray() gives (channels, samples); [0] is the
                    # single mono channel's data.
                    frames.append(resampled_frame.to_ndarray()[0])
        if not frames:
            raise RuntimeError("Could not decode audio frames using PyAV.")
        # Concatenate the per-frame 1-D arrays into one continuous signal,
        # then add the leading channel dimension Whisper-side code expects.
        audio_np = np.concatenate(frames, axis=0)
        audio = torch.from_numpy(audio_np).unsqueeze(0).float()
        return audio, sr_target
    except Exception as av_e:
        raise RuntimeError(f"Failed to load audio file using both torchaudio and PyAV. Error: {av_e}")
# --- Transcription Function ---
def transcribe_audio(audio_file_path):
    """Transcribe an audio file with the pre-loaded Whisper model.

    Args:
        audio_file_path: Path to the recording (any format the loader
            supports), or None when no file was supplied.

    Returns:
        The decoded transcript, or a human-readable error message string
        when the model is unavailable or transcription fails.
    """
    if model is None:
        return "Error: Model was not loaded successfully at startup."
    if audio_file_path is None:
        return "Error: No audio file provided."

    try:
        # The robust loader yields a (1, N) tensor at 16 kHz mono.
        audio, sr = load_audio_file(audio_file_path)

        # Whisper's processor consumes a flat 1-D NumPy waveform.
        waveform = audio.squeeze().numpy()
        features = processor(waveform, sampling_rate=sr, return_tensors="pt")
        input_features = features.input_features.to(device)

        # Pin decoding to Uzbek transcription for this fine-tuned model.
        prompt_ids = processor.get_decoder_prompt_ids(language="uz", task="transcribe")

        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                forced_decoder_ids=prompt_ids,
                # Reasonable cap keeps latency and memory bounded.
                max_length=448,
            )

        # Turn the generated token IDs back into text.
        return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    except Exception as e:
        return f"An error occurred during transcription: {e}"
# --- Gradio Interface Setup ---

# 🖼️ Interface Description
title = "🇺🇿 Whisper Uz-Small v1: Audio Transcription"
description = (
    "A Gradio demo for the **OvozifyLabs/whisper-small-uz-v1** model for Uzbek ASR. "
    "Upload an audio file (M4A, MP3, WAV supported) or record directly."
)

# 🎤 Input: accept either a live microphone recording or an uploaded file;
# type="filepath" hands transcribe_audio a path rather than raw samples.
audio_input = gr.Audio(
    sources=["microphone", "upload"],
    type="filepath",
    label="Input Audio (M4A/MP3/WAV, etc.)",
)

# 📝 Output: plain textbox for the transcript.
text_output = gr.Textbox(label="Transcription Result")

# 🚀 Wire the transcription function into the UI.
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=audio_input,
    outputs=text_output,
    title=title,
    description=description,
)

# 💻 Launch the app only when run as a script.
if __name__ == "__main__":
    demo.launch()