mimishanmi committed on
Commit
fd226a9
·
verified ·
1 Parent(s): 0953b49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -103
app.py CHANGED
@@ -1,17 +1,13 @@
1
-
2
-
3
  import os
4
  import asyncio
5
  import whisper
6
  import gradio as gr
7
  import torch
8
- import shutil
9
  import logging
10
  from pathlib import Path
11
  import ffmpeg
12
  import re
13
- import threading
14
- from tqdm.notebook import tqdm
15
  from cryptography.fernet import Fernet
16
  from pyannote.audio import Pipeline
17
  from pyannote.core import Segment
@@ -19,35 +15,78 @@ import numpy as np
19
  import sounddevice as sd
20
  import soundfile as sf
21
  import time
 
22
 
23
- # --- Configuration ---
24
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
25
  logger = logging.getLogger(__name__)
26
 
27
  TEMP_FOLDER = 'temp/'
28
- SUPPORTED_FORMATS = ['.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a',
29
- '.mp4', '.avi', '.mov', '.mkv', '.webm', '.3gp']
30
- MAX_AUDIO_LENGTH = 600 # 10 minutes in seconds
31
- DIARIZATION_MODEL = "pyannote/speaker-diarization@2.1"
32
- HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN") # Get your Hugging Face auth token
33
 
34
- # --- Encryption ---
35
- # ... (Code for generate_key, encrypt_file, and decrypt_file remains the same)
36
 
37
- # --- File Handling ---
38
- # ... (Code for create_folders, is_supported_format, convert_to_wav, and delete_temp_file remains the same)
 
 
 
39
 
40
- # --- Whisper Model Cache ---
41
- class WhisperModelCache:
42
- # ... (Code for WhisperModelCache, including efficient model loading, remains the same)
 
 
 
 
 
 
 
 
 
43
 
44
- # --- Transcription and Diarization ---
45
- async def transcribe_audio(audio_path, language, task='transcribe',
46
- initial_prompt=None, temperature=0.5,
47
- num_speakers=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  try:
49
  model = WhisperModelCache.get_instance().load_model()
50
-
51
  result = await asyncio.to_thread(
52
  model.transcribe,
53
  audio_path,
@@ -56,42 +95,49 @@ async def transcribe_audio(audio_path, language, task='transcribe',
56
  initial_prompt=initial_prompt,
57
  temperature=temperature,
58
  )
59
-
60
  if num_speakers > 1:
61
  diarization = await perform_diarization(audio_path, num_speakers)
62
  result['text'] = apply_diarization(result, diarization)
63
-
64
  return result['text']
65
-
66
  except Exception as e:
67
  logger.error(f"Error transcribing {audio_path}: {str(e)}")
68
  return f"Error during transcription: {str(e)}"
69
 
70
  async def perform_diarization(audio_path, num_speakers):
71
- """Performs speaker diarization using Pyannote Audio."""
72
- pipeline = Pipeline.from_pretrained(DIARIZATION_MODEL, use_auth_token=HF_AUTH_TOKEN)
73
- diarization = pipeline(audio_path, num_speakers=num_speakers)
74
- return diarization
75
 
76
  def apply_diarization(whisper_result, diarization):
77
- """Applies speaker labels from diarization to Whisper segments."""
78
  speaker_segments = []
79
  for turn, _, speaker in diarization.itertracks(yield_label=True):
80
  speaker_segments.append((turn.start, turn.end, speaker))
81
-
82
  diarized_text = ""
83
  for segment in whisper_result['segments']:
84
- start, end, text = segment['start'], segment['end'], segment['text']
85
- speaker = next((s_label for s_start, s_end, s_label in speaker_segments
86
- if Segment(start, end).intersects(Segment(s_start, s_end))), "Unknown")
 
 
 
 
 
 
 
87
  diarized_text += f"[{speaker}]: {text}\n"
88
-
89
  return diarized_text
90
 
91
- # --- Anonymization ---
92
- # ... (Code for anonymize_text remains the same)
 
 
 
 
93
 
94
- # --- Real-Time Transcription ---
95
  class RealTimeTranscriber:
96
  def __init__(self, language, task, initial_prompt, temperature):
97
  self.language = language
@@ -99,19 +145,16 @@ class RealTimeTranscriber:
99
  self.initial_prompt = initial_prompt
100
  self.temperature = temperature
101
  self.model = WhisperModelCache.get_instance().load_model()
102
- self.audio_buffer = np.array([], dtype=np.float32)
103
  self.is_recording = False
104
  self.transcription = ""
105
- self.chunk_duration = 2 # Process audio in 2-second chunks
106
 
107
  async def start_recording(self):
108
  self.is_recording = True
109
  threading.Thread(target=self._record_audio, daemon=True).start()
110
  while self.is_recording:
111
- await asyncio.sleep(self.chunk_duration)
112
- if len(self.audio_buffer) >= self.chunk_duration * 16000:
113
- audio_chunk = self.audio_buffer[:int(self.chunk_duration * 16000)]
114
- self.audio_buffer = self.audio_buffer[int(self.chunk_duration * 16000):]
115
  result = await asyncio.to_thread(
116
  self.model.transcribe,
117
  audio_chunk,
@@ -121,6 +164,7 @@ class RealTimeTranscriber:
121
  temperature=self.temperature
122
  )
123
  self.transcription += result['text'] + " "
 
124
  return self.transcription
125
 
126
  def stop_recording(self):
@@ -134,12 +178,10 @@ class RealTimeTranscriber:
134
  def _audio_callback(self, indata, frames, time, status):
135
  if status:
136
  logger.warning(f"Audio callback status: {status}")
137
- self.audio_buffer = np.append(self.audio_buffer, np.frombuffer(indata, dtype=np.float32))
 
138
 
139
- # --- Main Processing Function ---
140
- async def process_audio(file, language, task, anonymize,
141
- initial_prompt, temperature,
142
- encryption_key, num_speakers):
143
  try:
144
  if not file:
145
  return "Error: Please upload an audio or video file."
@@ -147,7 +189,6 @@ async def process_audio(file, language, task, anonymize,
147
  if not is_supported_format(file):
148
  return f"Error: Unsupported file format: {file.name}"
149
 
150
- # --- ENCRYPTION ---
151
  if encryption_key:
152
  try:
153
  encrypt_file(encryption_key.encode(), file.name)
@@ -156,38 +197,24 @@ async def process_audio(file, language, task, anonymize,
156
  logger.error(f"Encryption failed: {str(e)}")
157
  return f"Error: Encryption failed: {str(e)}"
158
 
159
- # Convert to WAV (if necessary)
160
- temp_audio_path = convert_to_wav(file.name) if not file.name.lower().endswith('.wav') else file.name
161
  if not temp_audio_path:
162
  return f"Error: Failed to convert {file.name} to WAV format."
163
 
164
- # Check audio length
165
- probe = ffmpeg.probe(temp_audio_path)
166
- audio_duration = float(probe['format']['duration'])
167
- if audio_duration > MAX_AUDIO_LENGTH:
168
- return f"Error: Audio file is too long. Maximum duration is {MAX_AUDIO_LENGTH} seconds."
169
-
170
- # Transcribe (with progress bar)
171
- with tqdm(total=100, desc="Transcribing", unit="%", position=0, leave=True) as pbar:
172
- transcription = await transcribe_audio(
173
- temp_audio_path,
174
- language,
175
- task=task,
176
- initial_prompt=initial_prompt,
177
- temperature=temperature,
178
- num_speakers=num_speakers,
179
- progress_bar=pbar # Pass the progress bar to transcribe_audio
180
- )
181
-
182
- # Clean up the temporary WAV file (if it was converted)
183
- if temp_audio_path != file.name:
184
- delete_temp_file(temp_audio_path)
185
-
186
- # Anonymize if selected
187
  if anonymize:
188
  transcription = anonymize_text(transcription)
189
 
190
- # --- DECRYPTION ---
191
  if encryption_key:
192
  try:
193
  decrypt_file(encryption_key.encode(), file.name)
@@ -202,12 +229,11 @@ async def process_audio(file, language, task, anonymize,
202
  logger.error(f"Error processing audio: {e}")
203
  return f"Error: {str(e)}"
204
 
205
- # --- Gradio UI ---
206
  def create_ui():
207
  languages = {
208
  "en": "English", "es": "Spanish", "fr": "French", "de": "German", "it": "Italian",
209
  "pt": "Portuguese", "nl": "Dutch", "ru": "Russian", "zh": "Chinese", "ja": "Japanese",
210
- "ko": "Korean", "ar": "Arabic", "he": "Hebrew", "hi": "Hindi", "bn": "Bengali", "ur": "Urdu",
211
  "te": "Telugu", "ta": "Tamil", "mr": "Marathi", "gu": "Gujarati", "kn": "Kannada"
212
  }
213
 
@@ -215,9 +241,9 @@ def create_ui():
215
  gr.Markdown(
216
  """
217
  # 🎙️ Advanced Whisper Transcription App
218
-
219
  Transcribe or translate your audio and video files with ease, now with real-time processing!
220
-
221
  ## Features:
222
  - Support for multiple audio and video formats
223
  - Speaker diarization for multi-speaker audio
@@ -226,7 +252,7 @@ def create_ui():
226
  - File encryption for enhanced security
227
  """
228
  )
229
-
230
  with gr.Tabs():
231
  with gr.TabItem("File Upload"):
232
  with gr.Row():
@@ -265,16 +291,16 @@ def create_ui():
265
  )
266
  encryption_key = gr.Textbox(label="Encryption Key (Optional)", type="password")
267
  process_button = gr.Button("Process Audio", variant="primary")
268
-
269
  with gr.Column(scale=3):
270
  output_text = gr.Textbox(label="Transcription Output", lines=20)
271
-
272
  process_button.click(
273
  fn=process_audio,
274
  inputs=[file_input, language_dropdown, task_dropdown, anonymize_checkbox, prompt_input, temperature_slider, encryption_key, num_speakers],
275
  outputs=output_text
276
  )
277
-
278
  with gr.TabItem("Real-time Transcription"):
279
  with gr.Row():
280
  with gr.Column(scale=2):
@@ -302,26 +328,17 @@ def create_ui():
302
  )
303
  rt_start_button = gr.Button("Start Real-time Transcription", variant="primary")
304
  rt_stop_button = gr.Button("Stop Transcription", variant="secondary")
305
-
306
  with gr.Column(scale=3):
307
  rt_output_text = gr.Textbox(label="Real-time Transcription Output", lines=20)
308
 
309
- rt_transcriber = None # Store the transcriber object to stop it later
310
-
311
  async def start_real_time_transcription(language, task, prompt, temperature):
312
- global rt_transcriber
313
- rt_transcriber = RealTimeTranscriber(language, task, prompt, temperature)
314
- transcription = await rt_transcriber.start_recording()
315
  return transcription
316
 
317
  def stop_real_time_transcription():
318
- global rt_transcriber
319
- if rt_transcriber is not None:
320
- rt_transcriber.stop_recording()
321
- rt_transcriber = None
322
- return "Transcription stopped."
323
- else:
324
- return "No active transcription."
325
 
326
  rt_start_button.click(
327
  fn=start_real_time_transcription,
@@ -346,16 +363,15 @@ def create_ui():
346
  - Click "Process Audio" and wait for the results.
347
  3. For Real-time Transcription:
348
  - Select the language and task.
349
- - Optionally, provide a prompt and adjust the temperature.
350
- - Click "Start Real-time Transcription". Speak into your microphone.
351
- - Click "Stop Transcription" to end the process.
352
  """
353
  )
354
 
355
  return interface
356
 
357
- # --- Main Execution ---
358
  if __name__ == "__main__":
359
  create_folders()
360
  iface = create_ui()
361
- iface.queue().launch(debug=True)
 
 
 
1
  import os
2
  import asyncio
3
  import whisper
4
  import gradio as gr
5
  import torch
 
6
  import logging
7
  from pathlib import Path
8
  import ffmpeg
9
  import re
10
+ from tqdm import tqdm
 
11
  from cryptography.fernet import Fernet
12
  from pyannote.audio import Pipeline
13
  from pyannote.core import Segment
 
15
  import sounddevice as sd
16
  import soundfile as sf
17
  import time
18
+ import threading
19
 
 
20
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
21
  logger = logging.getLogger(__name__)
22
 
23
  TEMP_FOLDER = 'temp/'
24
+ SUPPORTED_FORMATS = ['.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.mp4', '.avi', '.mov', '.mkv', '.webm']
25
+ MAX_AUDIO_LENGTH = 600
 
 
 
26
 
27
class WhisperModelCache:
    """Process-wide holder for one loaded Whisper model.

    Use ``WhisperModelCache.get_instance()`` to obtain the shared object and
    ``load_model()`` to get the model, which is loaded lazily on first use.
    """

    # The single shared instance, created on the first get_instance() call.
    _instance = None

    @staticmethod
    def get_instance():
        """Return the shared cache object, creating it on first use."""
        if WhisperModelCache._instance is None:
            WhisperModelCache._instance = WhisperModelCache()
        return WhisperModelCache._instance

    def __init__(self):
        self.model = None
        # Size name of the model currently held in ``self.model``.
        self.model_size = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_size="medium"):
        """Return the cached model, loading it when needed.

        Args:
            model_size: Whisper checkpoint name passed to
                ``whisper.load_model``.

        Returns:
            The loaded Whisper model.
        """
        # Previously any cached model was returned regardless of the size
        # requested; now a request for a different size triggers a reload.
        # A model whose size was never recorded is still treated as a hit so
        # that externally assigned models keep working.
        cached_size = getattr(self, "model_size", None)
        size_changed = cached_size is not None and cached_size != model_size
        if self.model is None or size_changed:
            logger.info(f"Loading Whisper model: {model_size} on {self.device}")
            self.model = whisper.load_model(model_size, device=self.device)
            self.model_size = model_size
        return self.model
45
+
46
def create_folders():
    """Create the temporary working directory if it is not already there."""
    temp_dir = Path(TEMP_FOLDER)
    temp_dir.mkdir(exist_ok=True)
48
 
49
def is_supported_format(file):
    """Tell whether an uploaded file has one of the supported extensions.

    Args:
        file: Object exposing a ``name`` attribute, or ``None``.

    Returns:
        ``True`` when the lower-cased name ends with an entry of
        ``SUPPORTED_FORMATS``; ``False`` otherwise, including for ``None``.
    """
    if file is None:
        return False
    lowered_name = file.name.lower()
    # str.endswith accepts a tuple of candidate suffixes.
    return lowered_name.endswith(tuple(SUPPORTED_FORMATS))
51
+
52
def convert_to_wav(original_file_path):
    """Convert a media file to a mono, 16 kHz, 16-bit PCM WAV in TEMP_FOLDER.

    The output name is the input's base name plus a random suffix, so two
    requests converting identically named uploads never write to (or delete)
    the same temporary file.

    Args:
        original_file_path: Path of the audio or video file to convert.

    Returns:
        Path of the generated WAV file, or ``None`` if ffmpeg reported an
        error (the error is logged).
    """
    import uuid

    stem = os.path.splitext(os.path.basename(original_file_path))[0]
    output_path = os.path.join(TEMP_FOLDER, f"{stem}_{uuid.uuid4().hex}.wav")
    try:
        (
            ffmpeg
            .input(original_file_path)
            .output(output_path, acodec='pcm_s16le', ac=1, ar='16k')
            .overwrite_output()
            .run(capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        # stderr may be absent; fall back to the exception text.
        details = e.stderr.decode(errors='replace') if e.stderr else str(e)
        logger.error(f'Error converting {original_file_path}: {details}')
        # Do not leave a partially written file behind.
        if os.path.exists(output_path):
            os.remove(output_path)
        return None
    return output_path
66
+
67
def generate_key():
    """Return a freshly generated Fernet key."""
    new_key = Fernet.generate_key()
    return new_key
69
+
70
def encrypt_file(key, filename):
    """Encrypt a file in place with Fernet.

    Args:
        key: Fernet key used to build the cipher.
        filename: Path of the file whose contents are replaced by their
            encrypted form.
    """
    target = Path(filename)
    cipher = Fernet(key)
    plaintext = target.read_bytes()
    # Encrypt fully in memory before the original contents are overwritten.
    token = cipher.encrypt(plaintext)
    target.write_bytes(token)
77
+
78
def decrypt_file(key, filename):
    """Decrypt a Fernet-encrypted file in place.

    Args:
        key: Fernet key used to build the cipher.
        filename: Path of the file whose contents are replaced by their
            decrypted form.
    """
    target = Path(filename)
    cipher = Fernet(key)
    token = target.read_bytes()
    # Decrypt fully in memory before the encrypted contents are overwritten.
    plaintext = cipher.decrypt(token)
    target.write_bytes(plaintext)
85
+
86
+ async def transcribe_audio(audio_path, language, task='transcribe', initial_prompt=None, temperature=0.5, num_speakers=1):
87
  try:
88
  model = WhisperModelCache.get_instance().load_model()
89
+
90
  result = await asyncio.to_thread(
91
  model.transcribe,
92
  audio_path,
 
95
  initial_prompt=initial_prompt,
96
  temperature=temperature,
97
  )
98
+
99
  if num_speakers > 1:
100
  diarization = await perform_diarization(audio_path, num_speakers)
101
  result['text'] = apply_diarization(result, diarization)
102
+
103
  return result['text']
 
104
  except Exception as e:
105
  logger.error(f"Error transcribing {audio_path}: {str(e)}")
106
  return f"Error during transcription: {str(e)}"
107
 
108
  async def perform_diarization(audio_path, num_speakers):
109
+ pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
110
+ use_auth_token="YOUR_HF_AUTH_TOKEN")
111
+ return pipeline(audio_path, num_speakers=num_speakers)
 
112
 
113
  def apply_diarization(whisper_result, diarization):
 
114
  speaker_segments = []
115
  for turn, _, speaker in diarization.itertracks(yield_label=True):
116
  speaker_segments.append((turn.start, turn.end, speaker))
117
+
118
  diarized_text = ""
119
  for segment in whisper_result['segments']:
120
+ start_time = segment['start']
121
+ end_time = segment['end']
122
+ text = segment['text']
123
+
124
+ speaker = "Unknown"
125
+ for s_start, s_end, s_label in speaker_segments:
126
+ if Segment(start_time, end_time).intersects(Segment(s_start, s_end)):
127
+ speaker = s_label
128
+ break
129
+
130
  diarized_text += f"[{speaker}]: {text}\n"
131
+
132
  return diarized_text
133
 
134
def anonymize_text(text):
    """Replace name-like, e-mail-like and phone-like substrings with tags.

    Two consecutive capitalised words become ``[NAME]``, tokens containing
    ``@`` become ``[EMAIL]``, and ten-digit numbers optionally separated by
    ``-`` or ``.`` become ``[PHONE]``.

    Args:
        text: Text to scrub.

    Returns:
        The text with every match replaced by its placeholder.
    """
    def _placeholder(match):
        found = match.group()
        if re.match(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', found):
            return '[NAME]'
        if '@' in found:
            return '[EMAIL]'
        return '[PHONE]'

    scrubbed = re.sub(
        r'\b[A-Z][a-z]+ [A-Z][a-z]+\b|\S+@\S+|\d{3}[-.]?\d{3}[-.]?\d{4}',
        _placeholder,
        text,
    )
    return scrubbed
140
 
 
141
  class RealTimeTranscriber:
142
  def __init__(self, language, task, initial_prompt, temperature):
143
  self.language = language
 
145
  self.initial_prompt = initial_prompt
146
  self.temperature = temperature
147
  self.model = WhisperModelCache.get_instance().load_model()
148
+ self.audio_queue = asyncio.Queue()
149
  self.is_recording = False
150
  self.transcription = ""
 
151
 
152
  async def start_recording(self):
153
  self.is_recording = True
154
  threading.Thread(target=self._record_audio, daemon=True).start()
155
  while self.is_recording:
156
+ audio_chunk = await self.audio_queue.get()
157
+ if audio_chunk is not None:
 
 
158
  result = await asyncio.to_thread(
159
  self.model.transcribe,
160
  audio_chunk,
 
164
  temperature=self.temperature
165
  )
166
  self.transcription += result['text'] + " "
167
+ await asyncio.sleep(0.1)
168
  return self.transcription
169
 
170
  def stop_recording(self):
 
178
  def _audio_callback(self, indata, frames, time, status):
179
  if status:
180
  logger.warning(f"Audio callback status: {status}")
181
+ audio_chunk = np.frombuffer(indata, dtype=np.float32)
182
+ asyncio.run_coroutine_threadsafe(self.audio_queue.put(audio_chunk), asyncio.get_event_loop())
183
 
184
+ async def process_audio(file, language, task, anonymize, initial_prompt, temperature, encryption_key, num_speakers):
 
 
 
185
  try:
186
  if not file:
187
  return "Error: Please upload an audio or video file."
 
189
  if not is_supported_format(file):
190
  return f"Error: Unsupported file format: {file.name}"
191
 
 
192
  if encryption_key:
193
  try:
194
  encrypt_file(encryption_key.encode(), file.name)
 
197
  logger.error(f"Encryption failed: {str(e)}")
198
  return f"Error: Encryption failed: {str(e)}"
199
 
200
+ temp_audio_path = convert_to_wav(file.name)
 
201
  if not temp_audio_path:
202
  return f"Error: Failed to convert {file.name} to WAV format."
203
 
204
+ transcription = await transcribe_audio(
205
+ temp_audio_path,
206
+ language,
207
+ task=task,
208
+ initial_prompt=initial_prompt,
209
+ temperature=temperature,
210
+ num_speakers=num_speakers
211
+ )
212
+
213
+ os.remove(temp_audio_path)
214
+
 
 
 
 
 
 
 
 
 
 
 
 
215
  if anonymize:
216
  transcription = anonymize_text(transcription)
217
 
 
218
  if encryption_key:
219
  try:
220
  decrypt_file(encryption_key.encode(), file.name)
 
229
  logger.error(f"Error processing audio: {e}")
230
  return f"Error: {str(e)}"
231
 
 
232
  def create_ui():
233
  languages = {
234
  "en": "English", "es": "Spanish", "fr": "French", "de": "German", "it": "Italian",
235
  "pt": "Portuguese", "nl": "Dutch", "ru": "Russian", "zh": "Chinese", "ja": "Japanese",
236
+ "ko": "Korean", "ar": "Arabic", "hi": "Hindi", "bn": "Bengali", "ur": "Urdu",
237
  "te": "Telugu", "ta": "Tamil", "mr": "Marathi", "gu": "Gujarati", "kn": "Kannada"
238
  }
239
 
 
241
  gr.Markdown(
242
  """
243
  # 🎙️ Advanced Whisper Transcription App
244
+
245
  Transcribe or translate your audio and video files with ease, now with real-time processing!
246
+
247
  ## Features:
248
  - Support for multiple audio and video formats
249
  - Speaker diarization for multi-speaker audio
 
252
  - File encryption for enhanced security
253
  """
254
  )
255
+
256
  with gr.Tabs():
257
  with gr.TabItem("File Upload"):
258
  with gr.Row():
 
291
  )
292
  encryption_key = gr.Textbox(label="Encryption Key (Optional)", type="password")
293
  process_button = gr.Button("Process Audio", variant="primary")
294
+
295
  with gr.Column(scale=3):
296
  output_text = gr.Textbox(label="Transcription Output", lines=20)
297
+
298
  process_button.click(
299
  fn=process_audio,
300
  inputs=[file_input, language_dropdown, task_dropdown, anonymize_checkbox, prompt_input, temperature_slider, encryption_key, num_speakers],
301
  outputs=output_text
302
  )
303
+
304
  with gr.TabItem("Real-time Transcription"):
305
  with gr.Row():
306
  with gr.Column(scale=2):
 
328
  )
329
  rt_start_button = gr.Button("Start Real-time Transcription", variant="primary")
330
  rt_stop_button = gr.Button("Stop Transcription", variant="secondary")
331
+
332
  with gr.Column(scale=3):
333
  rt_output_text = gr.Textbox(label="Real-time Transcription Output", lines=20)
334
 
 
 
335
  async def start_real_time_transcription(language, task, prompt, temperature):
336
+ transcriber = RealTimeTranscriber(language, task, prompt, temperature)
337
+ transcription = await transcriber.start_recording()
 
338
  return transcription
339
 
340
  def stop_real_time_transcription():
341
+ return "Transcription stopped."
 
 
 
 
 
 
342
 
343
  rt_start_button.click(
344
  fn=start_real_time_transcription,
 
363
  - Click "Process Audio" and wait for the results.
364
  3. For Real-time Transcription:
365
  - Select the language and task.
366
+ - Optionally, provide an initial prompt and adjust the temperature.
367
+ - Click "Start Real-time Transcription" and speak into your microphone.
368
+ - Click "Stop Transcription" when you're done.
369
  """
370
  )
371
 
372
  return interface
373
 
 
374
  if __name__ == "__main__":
375
  create_folders()
376
  iface = create_ui()
377
+ iface.launch()