import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import torchaudio
import numpy as np
import av # Ensure you have installed this: pip install av
# --- Configuration and Model Loading ---
model_id = "OvozifyLabs/whisper-small-uz-v1"
# Check for GPU and set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on device: {device}")
# Load the processor and model (only runs once at startup)
try:
    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
except Exception as e:
    print(f"Error loading model or processor: {e}")
    # Handle the error gracefully if the model cannot be loaded
    processor = None
    model = None
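# Optional sanity check (an assumption, not part of the original app): confirm
# the checkpoint actually loaded before serving requests.
if model is not None:
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Loaded {model_id} with {n_params:,} parameters")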
# --- Audio Loading Helper Function ---
def load_audio_file(file_path):
    """
    Loads an audio file (handles M4A, MP3, WAV, etc.) and ensures it is
    resampled to 16000 Hz and converted to mono, which Whisper models require.
    """
    sr_target = 16000  # Target sampling rate for the Whisper model
    if not file_path:
        raise FileNotFoundError("Audio file path is empty.")
    try:
        # 1. Try torchaudio's built-in loader first (usually handles WAV, FLAC well)
        audio, sr = torchaudio.load(file_path)
        # Resample if needed
        if sr != sr_target:
            if audio.dtype != torch.float32:
                audio = audio.float()
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sr_target)
            audio = resampler(audio)
        # Convert to mono if necessary (take the mean across channels)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)
        return audio, sr_target
    except Exception as torchaudio_e:
        # 2. Fall back to PyAV (FFmpeg bindings) for formats like M4A, MP3
        try:
            audio_data_list = []
            with av.open(file_path) as container:
                stream = container.streams.audio[0]
                # Set up a resampler to produce 16 kHz float mono output
                resampler = av.AudioResampler(
                    format='fltp',   # 32-bit floating point (planar)
                    layout='mono',   # Force mono output
                    rate=sr_target   # Target sampling rate of 16000 Hz
                )
                # Decode the audio stream and resample each frame
                for frame in container.decode(stream):
                    for resampled_frame in resampler.resample(frame):
                        # to_ndarray() returns a (channels, samples) array;
                        # [0] selects the single mono channel's samples.
                        audio_data_list.append(resampled_frame.to_ndarray()[0])
            if not audio_data_list:
                raise RuntimeError("Could not decode audio frames using PyAV.")
            # Concatenate the per-frame 1D arrays into one continuous array
            audio_np = np.concatenate(audio_data_list, axis=0)
            # Convert back to a (1, N) float tensor to match the torchaudio path
            audio = torch.from_numpy(audio_np).unsqueeze(0).float()
            return audio, sr_target
        except Exception as av_e:
            raise RuntimeError(
                f"Failed to load audio with both torchaudio ({torchaudio_e}) "
                f"and PyAV ({av_e})."
            )
# --- Transcription Function ---
def transcribe_audio(audio_file_path):
    """
    Transcribes an audio file using the pre-loaded Whisper model.
    """
    if model is None:
        return "Error: Model was not loaded successfully at startup."
    if audio_file_path is None:
        return "Error: No audio file provided."
    try:
        # Load audio using the robust loader and get the 16kHz mono tensor
        audio, sr = load_audio_file(audio_file_path)
        # The processor expects a 1D NumPy array for raw audio input;
        # audio.squeeze().numpy() converts the (1, N) tensor to a (N,) array
        inputs = processor(audio.squeeze().numpy(), sampling_rate=sr, return_tensors="pt")
        # Move inputs to the appropriate device
        input_features = inputs.input_features.to(device)
        with torch.no_grad():
            # Use generation arguments to specify language and task for the Uz-Small model
            predicted_ids = model.generate(
                input_features,
                forced_decoder_ids=processor.get_decoder_prompt_ids(language="uz", task="transcribe"),
                max_length=448  # A reasonable cap for speed/resource management
            )
        # Decode the generated token IDs to get the text transcript
        text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        return text
    except Exception as e:
        return f"An error occurred during transcription: {e}"
# --- Gradio Interface Setup ---
# 🖼️ Interface Description
title = "🇺🇿 Whisper Uz-Small v1: Audio Transcription"
description = "A Gradio demo for the **OvozifyLabs/whisper-small-uz-v1** model for Uzbek ASR. Upload an audio file (M4A, MP3, WAV supported) or record directly."
# 🎤 Input Component
audio_input = gr.Audio(
    sources=["microphone", "upload"],
    type="filepath",
    label="Input Audio (M4A/MP3/WAV, etc.)"
)
# 📝 Output Component
text_output = gr.Textbox(label="Transcription Result")
# 🚀 Create the Interface
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=audio_input,
    outputs=text_output,
    title=title,
    description=description,
    # 'allow_flagging' raises a TypeError in recent Gradio releases; to hide
    # the flag button there, pass flagging_mode="never" instead.
    # theme=gr.themes.Soft()
)
# 💻 Launch the App
if __name__ == "__main__":
    demo.launch()
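# Running locally instead of on Spaces? demo.launch(share=True) serves a
# temporary public URL; the default above binds to localhost only.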