File size: 7,334 Bytes
dbebd1a
 
 
 
 
93db98b
dbebd1a
93db98b
dbebd1a
 
93db98b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbebd1a
93db98b
 
dbebd1a
93db98b
 
 
 
 
 
 
 
dbebd1a
93db98b
 
 
 
 
dbebd1a
93db98b
 
 
 
 
 
 
 
 
 
 
 
 
 
dbebd1a
93db98b
 
 
dbebd1a
93db98b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbebd1a
93db98b
 
 
 
 
 
 
 
 
 
 
 
 
 
dbebd1a
93db98b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbebd1a
93db98b
 
 
 
 
 
dbebd1a
93db98b
dbebd1a
 
93db98b
 
dbebd1a
93db98b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbebd1a
93db98b
dbebd1a
 
93db98b
 
 
 
 
 
 
 
dbebd1a
 
93db98b
dbebd1a
93db98b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import torchaudio
import numpy as np
import av # Ensure you have installed this: pip install av

# --- Configuration and Model Loading ---
# Hugging Face Hub checkpoint: Whisper-small fine-tuned for Uzbek speech recognition.
model_id = "OvozifyLabs/whisper-small-uz-v1"

# Prefer the GPU whenever PyTorch can see CUDA; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on device: {device}")

# Load the processor and model exactly once, at import time. On any failure
# both globals are reset to None so the app still starts; transcribe_audio()
# checks `model is None` and reports the problem instead of crashing.
try:
    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
except Exception as load_error:
    print(f"Error loading model or processor: {load_error}")
    processor = model = None


# --- Audio Loading Helper Function ---

def load_audio_file(file_path):
    """
    Load an audio file as a 16 kHz mono float32 tensor, as Whisper expects.

    torchaudio is tried first (it usually handles WAV/FLAC well). If it fails
    for any reason, the file is decoded with PyAV (FFmpeg bindings), which
    covers compressed containers such as M4A and MP3.

    Args:
        file_path: Path to the audio file.

    Returns:
        A tuple ``(audio, sample_rate)``. ``audio`` is a float32 tensor of
        shape ``(1, num_samples)`` and ``sample_rate`` is always 16000.

    Raises:
        FileNotFoundError: If ``file_path`` is empty or None.
        RuntimeError: If neither torchaudio nor PyAV can decode the file. The
            message includes both underlying errors.
    """
    sr_target = 16000  # Target sampling rate for the Whisper model

    if not file_path:
        raise FileNotFoundError("Audio file path is empty.")

    try:
        return _load_with_torchaudio(file_path, sr_target)
    except Exception as torchaudio_error:
        # torchaudio could not handle this file/format; fall back to PyAV.
        try:
            return _load_with_pyav(file_path, sr_target)
        except Exception as av_error:
            raise RuntimeError(
                "Failed to load audio file using both torchaudio and PyAV. "
                f"torchaudio error: {torchaudio_error}; PyAV error: {av_error}"
            ) from av_error


def _load_with_torchaudio(file_path, sr_target):
    """Decode with torchaudio, downmix to mono, and resample to ``sr_target``."""
    audio, sr = torchaudio.load(file_path)

    # Cast up front: torch.mean and the resampler both need a float tensor.
    if audio.dtype != torch.float32:
        audio = audio.float()

    # Downmix first so the resampler only has to process a single channel.
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    if sr != sr_target:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sr_target)
        audio = resampler(audio)

    return audio, sr_target


def _load_with_pyav(file_path, sr_target):
    """Decode the first audio stream with PyAV as mono float32 at ``sr_target``."""
    chunks = []
    with av.open(file_path) as container:
        stream = container.streams.audio[0]
        # Planar 32-bit float, mono, target rate: the resampler does both the
        # channel downmix and the rate conversion.
        resampler = av.AudioResampler(format="fltp", layout="mono", rate=sr_target)

        for frame in container.decode(stream):
            for out_frame in resampler.resample(frame):
                # Mono planar frame: row 0 holds the single channel's samples.
                chunks.append(out_frame.to_ndarray()[0])

        # Flush samples still buffered inside the resampler; without this the
        # tail end of the recording is silently dropped.
        for out_frame in resampler.resample(None):
            chunks.append(out_frame.to_ndarray()[0])

    if not chunks:
        raise RuntimeError("Could not decode audio frames using PyAV.")

    # Join the per-frame 1-D arrays and add the leading channel axis -> (1, N).
    audio_np = np.concatenate(chunks, axis=0)
    return torch.from_numpy(audio_np).unsqueeze(0).float(), sr_target


# --- Transcription Function ---

def transcribe_audio(audio_file_path):
    """
    Run the pre-loaded Whisper model on one audio file and return the transcript.

    Args:
        audio_file_path: File path handed over by the Gradio audio input, or
            None when nothing was recorded or uploaded.

    Returns:
        The decoded transcript text, or a human-readable error message. Errors
        are returned as strings (not raised) so they show up in the output
        textbox.
    """
    # Guard clauses: return early with a message the UI can display.
    if model is None:
        return "Error: Model was not loaded successfully at startup."
    if audio_file_path is None:
        return "Error: No audio file provided."

    try:
        waveform, sample_rate = load_audio_file(audio_file_path)

        # The processor wants a flat 1-D array, so squeeze the (1, N) tensor
        # down to (N,) before converting it to NumPy.
        features = processor(
            waveform.squeeze().numpy(),
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).input_features.to(device)

        # Pin the decoder to Uzbek transcription (not translation).
        prompt_ids = processor.get_decoder_prompt_ids(language="uz", task="transcribe")
        with torch.no_grad():
            token_ids = model.generate(
                features,
                forced_decoder_ids=prompt_ids,
                max_length=448,  # upper bound on the generated sequence length
            )

        transcripts = processor.batch_decode(token_ids, skip_special_tokens=True)
        return transcripts[0]
    except Exception as exc:
        return f"An error occurred during transcription: {exc}"


# --- Gradio Interface Setup ---
# 🖼️ Interface Description
# (The title previously contained mojibake, UTF-8 emoji mis-decoded as a
# single-byte codepage; it is restored to the Uzbekistan flag emoji.)
title = "🇺🇿 Whisper Uz-Small v1: Audio Transcription"
description = "A Gradio demo for the **OvozifyLabs/whisper-small-uz-v1** model for Uzbek ASR. Upload an audio file (M4A, MP3, WAV supported) or record directly."

# 🎤 Input Component
# type="filepath" hands transcribe_audio a file path rather than raw samples;
# the path is then decoded by load_audio_file.
audio_input = gr.Audio(
    sources=["microphone", "upload"],
    type="filepath",
    label="Input Audio (M4A/MP3/WAV, etc.)",
)

# 📝 Output Component
text_output = gr.Textbox(label="Transcription Result")

# 🚀 Create the Interface
# NOTE: passing 'allow_flagging' raised a TypeError with the installed Gradio
# version, so flagging is left at Gradio's defaults.
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=audio_input,
    outputs=text_output,
    title=title,
    description=description,
)

# 💻 Launch the App
if __name__ == "__main__":
    demo.launch()