Ergashbek2004 committed on
Commit 93db98b · verified · 1 Parent(s): 05ac736

Update app.py

Files changed (1)
  app.py  +166 -62
app.py CHANGED
@@ -3,87 +3,191 @@ from transformers import WhisperProcessor, WhisperForConditionalGeneration
 import torch
 import torchaudio
 import numpy as np

-# Load model and processor
 model_id = "OvozifyLabs/whisper-small-uz-v1"
-print("Loading model...")
-processor = WhisperProcessor.from_pretrained(model_id)
-model = WhisperForConditionalGeneration.from_pretrained(model_id)
-print("Model loaded successfully!")

-def transcribe_audio(audio_input):
     """
-    Transcribe audio file to Uzbek text
-
-    Args:
-        audio_input: tuple of (sample_rate, audio_data) from Gradio
-
-    Returns:
-        str: Transcribed text
     """
     try:
-        if audio_input is None:
-            return "Please upload or record an audio file."
-
-        # Gradio returns (sample_rate, audio_data)
-        sr, audio = audio_input
-
-        # Convert to torch tensor if numpy array
-        if isinstance(audio, np.ndarray):
-            # Handle stereo audio by taking mean of channels
-            if len(audio.shape) > 1:
-                audio = audio.mean(axis=1)
-            audio = torch.from_numpy(audio).float()
-
-        # Ensure audio is 1D
-        if len(audio.shape) > 1:
-            audio = audio.squeeze()
-
-        # Process audio
-        inputs = processor(
-            audio,
-            sampling_rate=sr,
-            return_tensors="pt"
-        )
-
-        # Generate transcription
         with torch.no_grad():
-            predicted_ids = model.generate(inputs.input_features)
-
-        # Decode to text
         text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

-        return text if text else "No speech detected in audio."
-
     except Exception as e:
-        return f"Error during transcription: {str(e)}"

-# Create Gradio interface
 demo = gr.Interface(
     fn=transcribe_audio,
-    inputs=gr.Audio(
-        sources=["upload", "microphone"],
-        type="numpy",
-        label="Upload Audio or Record"
-    ),
-    outputs=gr.Textbox(
-        label="Transcription (Uzbek)",
-        lines=5,
-        placeholder="Your transcribed text will appear here..."
-    ),
-    title="🎙️ Uzbek Speech-to-Text",
-    description="""
-    Upload an audio file or record your voice to transcribe Uzbek speech to text.
-    This app uses the Whisper Small model fine-tuned for Uzbek language.
-    """,
-    examples=[
-        # Add example audio files if you have them
-        # ["example1.wav"],
-        # ["example2.wav"],
-    ],
-    theme=gr.themes.Soft(),
-    allow_flagging="never"
 )

 if __name__ == "__main__":
-    demo.launch(share=True)
 
 import torch
 import torchaudio
 import numpy as np
+import av  # Ensure you have installed this: pip install av

+# --- Configuration and Model Loading ---
 model_id = "OvozifyLabs/whisper-small-uz-v1"

+# Check for GPU and set device
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+print(f"Loading model on device: {device}")
+
+# Load the processor and model (only runs once at startup)
+try:
+    processor = WhisperProcessor.from_pretrained(model_id)
+    model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
+except Exception as e:
+    print(f"Error loading model or processor: {e}")
+    # Handle the error gracefully if the model cannot be loaded
+    processor = None
+    model = None
+
+
+# --- Audio Loading Helper Function ---
+
+def load_audio_file(file_path):
     """
+    Loads an audio file (handles M4A, MP3, WAV, etc.) and ensures it is
+    resampled to 16000 Hz and converted to mono, which Whisper models require.
     """
+    sr_target = 16000  # Target sampling rate for the Whisper model
+
+    if not file_path:
+        raise FileNotFoundError("Audio file path is empty.")
+
+    audio_data_list = []
+    current_sr = sr_target  # Assume target SR initially
+
     try:
+        # 1. Try torchaudio's built-in loader first (usually handles WAV, FLAC well)
+        audio, sr = torchaudio.load(file_path)
+        current_sr = sr
+
+        # If torchaudio succeeds, perform the necessary post-loading processing
+
+        # Resample if needed
+        if current_sr != sr_target:
+            if audio.dtype != torch.float32:
+                audio = audio.float()
+
+            resampler = torchaudio.transforms.Resample(orig_freq=current_sr, new_freq=sr_target)
+            audio = resampler(audio)
+            current_sr = sr_target
+
+        # Convert to mono if necessary (take the mean across channels)
+        if audio.shape[0] > 1:
+            audio = torch.mean(audio, dim=0, keepdim=True)
+
+        return audio, current_sr
+
+    except Exception as torchaudio_e:
+        # 2. Fallback to PyAV (an FFmpeg wrapper) for formats like M4A and MP3
+        # print(f"Torchaudio failed. Falling back to PyAV. Error: {torchaudio_e}")
+        try:
+            with av.open(file_path) as container:
+                stream = container.streams.audio[0]
+
+                # Set up a resampler to ensure 16 kHz float mono output
+                resampler = av.AudioResampler(
+                    format='fltp',   # 32-bit floating point (planar)
+                    layout='mono',   # Force mono output
+                    rate=sr_target   # Target sampling rate 16000 Hz
+                )
+
+                # Decode the audio stream and resample frames
+                for frame in container.decode(stream):
+                    for resampled_frame in resampler.resample(frame):
+                        # to_ndarray() converts the frame to a (channels, samples) NumPy
+                        # array; for a mono stream, [0] selects the single channel's data
+                        audio_data_list.append(resampled_frame.to_ndarray()[0])
+
+            if not audio_data_list:
+                raise RuntimeError("Could not decode audio frames using PyAV.")
+
+            # Concatenate all the 1D NumPy arrays into a single, continuous array
+            audio_np = np.concatenate(audio_data_list, axis=0)
+            # Convert back to a (1, N) PyTorch tensor, i.e. a single mono channel
+            audio = torch.from_numpy(audio_np).unsqueeze(0).float()
+
+            return audio, sr_target
+
+        except Exception as av_e:
+            raise RuntimeError(f"Failed to load audio file using both torchaudio and PyAV. Error: {av_e}")
+
+
+# --- Transcription Function ---
+
+def transcribe_audio(audio_file_path):
+    """
+    Transcribes an audio file using the pre-loaded Whisper model.
+    """
+    if model is None:
+        return "Error: Model was not loaded successfully at startup."
+
+    if audio_file_path is None:
+        return "Error: No audio file provided."
+
+    try:
+        # Load audio using the robust loader and get the 16 kHz mono tensor
+        audio, sr = load_audio_file(audio_file_path)
+
+        # The processor expects a 1D NumPy array for raw audio input;
+        # audio.squeeze().numpy() converts the (1, N) torch tensor to a (N,) numpy array
+        inputs = processor(audio.squeeze().numpy(), sampling_rate=sr, return_tensors="pt")
+
+        # Move inputs to the appropriate device
+        input_features = inputs.input_features.to(device)
+
         with torch.no_grad():
+            # Use generation arguments to specify language and task for the Uz-Small model
+            predicted_ids = model.generate(
+                input_features,
+                forced_decoder_ids=processor.get_decoder_prompt_ids(language="uz", task="transcribe"),
+                max_length=448  # Use a reasonable max length for speed/resource management
+            )

+        # Decode the generated token IDs to get the text transcript
         text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

+        return text
+
     except Exception as e:
+        return f"An error occurred during transcription: {e}"
+
+
+# --- Gradio Interface Setup ---
+
+# 🖼️ Interface Description
+title = "🇺🇿 Whisper Uz-Small v1: Audio Transcription"
+description = "A Gradio demo for the **OvozifyLabs/whisper-small-uz-v1** model for Uzbek ASR. Upload an audio file (M4A, MP3, WAV supported) or record directly."
+
+# 🎀 Input Component
+audio_input = gr.Audio(
+    sources=["microphone", "upload"],
+    type="filepath",
+    label="Input Audio (M4A/MP3/WAV, etc.)"
+)
+
+# 📝 Output Component
+text_output = gr.Textbox(label="Transcription Result")

+# 🚀 Create the Interface
 demo = gr.Interface(
     fn=transcribe_audio,
+    inputs=audio_input,
+    outputs=text_output,
+    title=title,
+    description=description,
+    # The old 'allow_flagging' argument caused the TypeError on recent Gradio
+    # versions and was removed; its replacement there is flagging_mode="never"
+    # flagging_mode="never",
+    # theme=gr.themes.Soft()
 )

+# 💻 Launch the App
 if __name__ == "__main__":
+    demo.launch()
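
Reviewer's note: a quick way to sanity-check the new PyAV fallback in isolation, without launching Gradio or downloading the model, is a sketch like the one below. It mirrors the fallback branch of load_audio_file in this commit; the script name and the "sample.m4a" test file are illustrative assumptions, not part of app.py.

# smoke_test_pyav.py -- illustrative sketch, not part of this commit
# Assumes `pip install av numpy` and a local test file such as "sample.m4a".
import av
import numpy as np

def decode_with_pyav(file_path, sr_target=16000):
    """Decode any FFmpeg-supported file to mono 16 kHz float32 samples,
    mirroring the PyAV fallback branch of load_audio_file above."""
    chunks = []
    with av.open(file_path) as container:
        stream = container.streams.audio[0]
        resampler = av.AudioResampler(format="fltp", layout="mono", rate=sr_target)
        for frame in container.decode(stream):
            for resampled in resampler.resample(frame):
                chunks.append(resampled.to_ndarray()[0])
    if not chunks:
        raise RuntimeError("No audio frames decoded.")
    return np.concatenate(chunks)

if __name__ == "__main__":
    audio = decode_with_pyav("sample.m4a")  # hypothetical test file
    # Expect roughly 16000 samples per second of audio, values within [-1, 1]
    print(f"{audio.shape[0]} samples, min={audio.min():.3f}, max={audio.max():.3f}")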
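
Side note on the generation change: processor.get_decoder_prompt_ids(language="uz", task="transcribe") builds the forced prefix of special tokens (language, then task) that the decoder must start with, which keeps Whisper from auto-detecting the wrong language on short or noisy clips. A minimal way to inspect that prefix (an illustration, assuming the same model id; the printed layout may vary by transformers version):

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("OvozifyLabs/whisper-small-uz-v1")

# Returns (position, token_id) pairs for the forced prefix tokens,
# e.g. the <|uz|> language token and the <|transcribe|> task token.
prompt_ids = processor.get_decoder_prompt_ids(language="uz", task="transcribe")
print(prompt_ids)
print(processor.tokenizer.convert_ids_to_tokens([tid for _, tid in prompt_ids]))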