michaeltangz commited on
Commit
62fccb4
·
1 Parent(s): ae149f3

refactor app.py to streamline flash attention installation and model loading; enhance voice activity detection and transcription accuracy

Browse files
Files changed (1) hide show
  1. app.py +162 -69
app.py CHANGED
@@ -9,35 +9,19 @@ import time
9
  import numpy as np
10
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
11
  import subprocess
12
-
13
- # Try to install flash-attn (optional)
14
- try:
15
- subprocess.run(
16
- "pip install flash-attn --no-build-isolation",
17
- env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
18
- shell=True,
19
- )
20
- except Exception as e:
21
- print(f"Flash attention installation failed: {e}")
22
 
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
25
  MODEL_NAME = "openai/whisper-large-v3-turbo"
26
 
27
- # Try to load with flash attention, fallback if it fails
28
- try:
29
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
30
- MODEL_NAME, dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2"
31
- )
32
- print("Loaded with flash_attention_2")
33
- except Exception as e:
34
- print(f"Could not load with flash_attention_2: {e}")
35
- print("Falling back to default attention...")
36
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
37
- MODEL_NAME, dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
38
- )
39
- print("Loaded with default attention")
40
-
41
  model.to(device)
42
 
43
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
@@ -48,11 +32,19 @@ pipe = pipeline(
48
  model=model,
49
  tokenizer=tokenizer,
50
  feature_extractor=processor.feature_extractor,
51
- chunk_length_s=10,
 
52
  device=device,
53
- ignore_warning=True,
54
  )
55
 
 
 
 
 
 
 
 
 
56
  @spaces.GPU
57
  def stream_transcribe(stream, new_chunk):
58
  start_time = time.time()
@@ -64,25 +56,52 @@ def stream_transcribe(stream, new_chunk):
64
  y = y.mean(axis=1)
65
 
66
  y = y.astype(np.float32)
67
- y /= np.max(np.abs(y))
 
 
 
 
 
 
 
68
 
69
  if stream is not None:
70
  stream = np.concatenate([stream, y])
71
  else:
72
  stream = y
73
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  transcription = pipe(
75
- {"sampling_rate": sr, "raw": stream},
76
- generate_kwargs={
77
- "language": "english",
78
- "condition_on_previous_text": False })['text']
 
 
 
 
79
  end_time = time.time()
80
  latency = end_time - start_time
81
 
82
  return stream, transcription, f"{latency:.2f}"
83
  except Exception as e:
84
  print(f"Error during Transcription: {e}")
85
- return stream, str(e), "Error"
 
 
86
 
87
  @spaces.GPU
88
  def transcribe(inputs, previous_transcription):
@@ -90,16 +109,42 @@ def transcribe(inputs, previous_transcription):
90
  try:
91
  filename = f"{uuid.uuid4().hex}.wav"
92
  sample_rate, audio_data = inputs
 
 
 
 
 
 
 
 
 
 
 
 
93
  scipy.io.wavfile.write(filename, sample_rate, audio_data)
94
 
95
- transcription = pipe(filename, generate_kwargs={"language": "english"})["text"]
 
 
 
 
 
 
 
 
96
  previous_transcription += transcription
 
 
 
 
97
 
98
  end_time = time.time()
99
  latency = end_time - start_time
100
  return previous_transcription, f"{latency:.2f}"
101
  except Exception as e:
102
  print(f"Error during Transcription: {e}")
 
 
103
  return previous_transcription, "Error"
104
 
105
  @spaces.GPU
@@ -110,15 +155,28 @@ def translate_and_transcribe(inputs, previous_transcription, target_language):
110
  sample_rate, audio_data = inputs
111
  scipy.io.wavfile.write(filename, sample_rate, audio_data)
112
 
113
- translation = pipe(filename, generate_kwargs={"task": "translate", "language": target_language} )["text"]
 
 
 
 
 
 
 
114
 
115
  previous_transcription += translation
 
 
 
 
116
 
117
  end_time = time.time()
118
  latency = end_time - start_time
119
  return previous_transcription, f"{latency:.2f}"
120
  except Exception as e:
121
  print(f"Error during Translation and Transcription: {e}")
 
 
122
  return previous_transcription, "Error"
123
 
124
  def clear():
@@ -129,7 +187,21 @@ def clear_state():
129
 
130
  with gr.Blocks() as microphone:
131
  with gr.Column():
132
- gr.Markdown(f"# Realtime Whisper Large V3 Turbo: \n Transcribe Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n Note: The first token takes about 5 seconds. After that, it works flawlessly.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  with gr.Row():
134
  input_audio_microphone = gr.Audio(streaming=True)
135
  output = gr.Textbox(label="Transcription", value="")
@@ -137,12 +209,25 @@ with gr.Blocks() as microphone:
137
  with gr.Row():
138
  clear_button = gr.Button("Clear Output")
139
  state = gr.State()
140
- input_audio_microphone.stream(stream_transcribe, [state, input_audio_microphone], [state, output, latency_textbox])
 
 
 
 
 
 
 
141
  clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
142
 
143
  with gr.Blocks() as file:
144
  with gr.Column():
145
- gr.Markdown(f"# Realtime Whisper Large V3 Turbo: \n Transcribe Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n Note: The first token takes about 5 seconds. After that, it works flawlessly.")
 
 
 
 
 
 
146
  with gr.Row():
147
  input_audio_microphone = gr.Audio(sources="upload", type="numpy")
148
  output = gr.Textbox(label="Transcription", value="")
@@ -154,33 +239,41 @@ with gr.Blocks() as file:
154
  submit_button.click(transcribe, [input_audio_microphone, output], [output, latency_textbox], concurrency_limit=None)
155
  clear_button.click(clear, outputs=[output])
156
 
157
- # with gr.Blocks() as translate:
158
- # with gr.Column():
159
- # gr.Markdown(f"# Realtime Whisper Large V3 Turbo (Translation): \n Transcribe and Translate Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n Note: The first token takes about 5 seconds. After that, it works flawlessly.")
160
- # with gr.Row():
161
- # input_audio_microphone = gr.Audio(streaming=True)
162
- # output = gr.Textbox(label="Transcription and Translation", value="")
163
- # latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
164
- # target_language_dropdown = gr.Dropdown(
165
- # choices=["english", "french", "hindi", "spanish", "russian"],
166
- # label="Target Language",
167
- # value="<|es|>"
168
- # )
169
- # with gr.Row():
170
- # clear_button = gr.Button("Clear Output")
171
-
172
- # input_audio_microphone.stream(
173
- # translate_and_transcribe,
174
- # [input_audio_microphone, output, target_language_dropdown],
175
- # [output, latency_textbox],
176
- # time_limit=45,
177
- # stream_every=2,
178
- # concurrency_limit=None
179
- # )
180
- # clear_button.click(clear, outputs=[output])
181
-
182
- with gr.Blocks() as demo:
183
- gr.TabbedInterface([microphone, file], ["Microphone", "Transcribe from file"])
184
-
185
- if __name__ == "__main__":
186
- demo.launch(share=False)
 
 
 
 
 
 
 
 
 
9
  import numpy as np
10
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
11
  import subprocess
12
+ subprocess.run(
13
+ "pip install flash-attn --no-build-isolation",
14
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
15
+ shell=True,
16
+ )
 
 
 
 
 
17
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ torch_dtype = torch.float16
20
  MODEL_NAME = "openai/whisper-large-v3-turbo"
21
 
22
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
23
+ MODEL_NAME, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2"
24
+ )
 
 
 
 
 
 
 
 
 
 
 
25
  model.to(device)
26
 
27
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
 
32
  model=model,
33
  tokenizer=tokenizer,
34
  feature_extractor=processor.feature_extractor,
35
+ chunk_length_s=30, # Increased from 10 for better context
36
+ torch_dtype=torch_dtype,
37
  device=device,
 
38
  )
39
 
40
# Voice Activity Detection
def detect_voice_activity(audio, threshold=0.01):
    """Return True if *audio* likely contains speech, judged by RMS energy.

    Parameters
    ----------
    audio : array-like of samples. Expected roughly in [-1, 1] for the
        default threshold. Input is cast to float64 before squaring:
        squaring raw integer PCM (e.g. int16 straight from gradio) in its
        native dtype overflows and silently corrupts the RMS.
    threshold : float
        RMS energy above which the clip counts as speech.
    """
    samples = np.asarray(audio, dtype=np.float64)
    if samples.size == 0:
        return False  # an empty clip can never contain speech
    rms = np.sqrt(np.mean(np.square(samples)))
    return bool(rms > threshold)
47
+
48
  @spaces.GPU
49
  def stream_transcribe(stream, new_chunk):
50
  start_time = time.time()
 
56
  y = y.mean(axis=1)
57
 
58
  y = y.astype(np.float32)
59
+
60
+ # FIX: Prevent division by zero
61
+ max_val = np.max(np.abs(y))
62
+ if max_val > 0:
63
+ y /= max_val
64
+ else:
65
+ # Silent audio, skip
66
+ return stream, "", "0.00"
67
 
68
  if stream is not None:
69
  stream = np.concatenate([stream, y])
70
  else:
71
  stream = y
72
+
73
+ # FIX: Limit buffer size to prevent memory issues and accumulated silence
74
+ MAX_BUFFER = sr * 30 # 30 seconds maximum
75
+ if len(stream) > MAX_BUFFER:
76
+ stream = stream[-MAX_BUFFER:]
77
+
78
+ # FIX: Check for voice activity before transcribing
79
+ if not detect_voice_activity(stream, threshold=0.01):
80
+ return stream, "", "0.00"
81
+
82
+ # FIX: Require minimum audio length
83
+ if len(stream) < sr * 1.0: # At least 1 second
84
+ return stream, "", "0.00"
85
+
86
+ # FIX: Add anti-hallucination parameters
87
  transcription = pipe(
88
+ {"sampling_rate": sr, "raw": stream},
89
+ generate_kwargs={
90
+ "language": "english",
91
+ "condition_on_previous_text": False, # Prevents hallucinations
92
+ "no_repeat_ngram_size": 3, # Prevents repetitive outputs
93
+ }
94
+ )["text"]
95
+
96
  end_time = time.time()
97
  latency = end_time - start_time
98
 
99
  return stream, transcription, f"{latency:.2f}"
100
  except Exception as e:
101
  print(f"Error during Transcription: {e}")
102
+ import traceback
103
+ traceback.print_exc()
104
+ return stream if stream is not None else np.array([]), "", "Error"
105
 
106
  @spaces.GPU
107
  def transcribe(inputs, previous_transcription):
 
109
  try:
110
  filename = f"{uuid.uuid4().hex}.wav"
111
  sample_rate, audio_data = inputs
112
+
113
+ # Convert to float for VAD check
114
+ audio_float = audio_data.astype(np.float32)
115
+ if audio_data.dtype == np.int16:
116
+ audio_float /= 32768.0
117
+ elif audio_data.dtype == np.int32:
118
+ audio_float /= 2147483648.0
119
+
120
+ # FIX: Check for voice activity before transcribing
121
+ if not detect_voice_activity(audio_float, threshold=0.01):
122
+ return previous_transcription + "\n[No speech detected in audio]", "0.00"
123
+
124
  scipy.io.wavfile.write(filename, sample_rate, audio_data)
125
 
126
+ # FIX: Add anti-hallucination parameters
127
+ transcription = pipe(
128
+ filename,
129
+ generate_kwargs={
130
+ "language": "english",
131
+ "condition_on_previous_text": False,
132
+ }
133
+ )["text"]
134
+
135
  previous_transcription += transcription
136
+
137
+ # Clean up temp file
138
+ if os.path.exists(filename):
139
+ os.remove(filename)
140
 
141
  end_time = time.time()
142
  latency = end_time - start_time
143
  return previous_transcription, f"{latency:.2f}"
144
  except Exception as e:
145
  print(f"Error during Transcription: {e}")
146
+ import traceback
147
+ traceback.print_exc()
148
  return previous_transcription, "Error"
149
 
150
  @spaces.GPU
 
155
  sample_rate, audio_data = inputs
156
  scipy.io.wavfile.write(filename, sample_rate, audio_data)
157
 
158
+ translation = pipe(
159
+ filename,
160
+ generate_kwargs={
161
+ "task": "translate",
162
+ "language": target_language,
163
+ "condition_on_previous_text": False,
164
+ }
165
+ )["text"]
166
 
167
  previous_transcription += translation
168
+
169
+ # Clean up temp file
170
+ if os.path.exists(filename):
171
+ os.remove(filename)
172
 
173
  end_time = time.time()
174
  latency = end_time - start_time
175
  return previous_transcription, f"{latency:.2f}"
176
  except Exception as e:
177
  print(f"Error during Translation and Transcription: {e}")
178
+ import traceback
179
+ traceback.print_exc()
180
  return previous_transcription, "Error"
181
 
182
  def clear():
 
187
 
188
  with gr.Blocks() as microphone:
189
  with gr.Column():
190
+ gr.Markdown(f"""
191
+ # 🎤 Realtime Whisper Large V3 Turbo
192
+
193
+ Transcribe Audio in Realtime with **Voice Activity Detection** to prevent hallucinations.
194
+
195
+ **Model:** [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})
196
+
197
+ **Features:**
198
+ - Flash Attention 2 for speed
199
+ - Voice Activity Detection (no "oh oh oh" hallucinations)
200
+ - 30-second context window
201
+ - Anti-repetition safeguards
202
+
203
+ **Note:** First transcription takes ~5 seconds. After that, it works flawlessly.
204
+ """)
205
  with gr.Row():
206
  input_audio_microphone = gr.Audio(streaming=True)
207
  output = gr.Textbox(label="Transcription", value="")
 
209
  with gr.Row():
210
  clear_button = gr.Button("Clear Output")
211
  state = gr.State()
212
+ input_audio_microphone.stream(
213
+ stream_transcribe,
214
+ [state, input_audio_microphone],
215
+ [state, output, latency_textbox],
216
+ time_limit=60, # Increased from 30
217
+ stream_every=2,
218
+ concurrency_limit=None
219
+ )
220
  clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
221
 
222
  with gr.Blocks() as file:
223
  with gr.Column():
224
+ gr.Markdown(f"""
225
+ # 🎤 Realtime Whisper Large V3 Turbo
226
+
227
+ Transcribe Audio Files.
228
+
229
+ **Model:** [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})
230
+ """)
231
  with gr.Row():
232
  input_audio_microphone = gr.Audio(sources="upload", type="numpy")
233
  output = gr.Textbox(label="Transcription", value="")
 
239
  submit_button.click(transcribe, [input_audio_microphone, output], [output, latency_textbox], concurrency_limit=None)
240
  clear_button.click(clear, outputs=[output])
241
 
242
# Translation tab.
with gr.Blocks() as translate:
    with gr.Column():
        gr.Markdown(f"""
# 🌍 Realtime Whisper Large V3 Turbo (Translation)

Transcribe and Translate Audio in Realtime.

**Model:** [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})

**Note:** First token takes ~5 seconds. After that, it works flawlessly.
""")
        with gr.Row():
            input_audio_microphone = gr.Audio(streaming=True)
            output = gr.Textbox(label="Transcription and Translation", value="")
            latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
            # NOTE(review): Whisper's "translate" task only outputs English;
            # the "language" generate kwarg names the *source* language —
            # verify what this dropdown is meant to select.
            target_language_dropdown = gr.Dropdown(
                choices=["english", "french", "hindi", "spanish", "russian"],
                label="Target Language",
                value="english"
            )
        with gr.Row():
            clear_button = gr.Button("Clear Output")

        # FIX: translate_and_transcribe takes (inputs, previous_transcription,
        # target_language) and returns (previous_transcription, latency).
        # Routing a gr.State as the first input would hand the accumulated
        # state to the (sample_rate, audio_data) unpacking and crash, and a
        # 3-component output list cannot match the 2-tuple return — wire the
        # audio, textbox and dropdown directly instead.
        input_audio_microphone.stream(
            translate_and_transcribe,
            [input_audio_microphone, output, target_language_dropdown],
            [output, latency_textbox],
            time_limit=60,
            stream_every=2,
            concurrency_limit=None
        )
        clear_button.click(clear, outputs=[output])
275
+
276
# Assemble the tabbed UI. Launch under a main guard so importing this
# module (e.g. from tests or tooling) does not start the server.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.TabbedInterface([microphone, file, translate], ["Microphone", "Upload File", "Translation"])

if __name__ == "__main__":
    demo.launch()