michaeltangz committed on
Commit
8f2a46b
·
1 Parent(s): f8af19e

refactor app.py to streamline flash attention installation and model loading; remove fallback mechanisms and enhance transcription parameters

Browse files
Files changed (1) hide show
  1. app.py +46 -174
app.py CHANGED
@@ -9,44 +9,19 @@ import time
9
  import numpy as np
10
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
11
  import subprocess
12
-
13
- # Try to install flash-attn (optional, will fallback if fails)
14
- try:
15
- subprocess.run(
16
- "pip install flash-attn --no-build-isolation",
17
- env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
18
- shell=True,
19
- timeout=60,
20
- )
21
- print("✅ Flash Attention installed")
22
- except Exception as e:
23
- print(f"⚠️ Flash Attention installation failed (will use default): {e}")
24
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
27
  MODEL_NAME = "openai/whisper-large-v3-turbo"
28
 
29
- # Try to load model with flash attention, fallback to default if it fails
30
- try:
31
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
32
- MODEL_NAME,
33
- dtype=torch_dtype,
34
- low_cpu_mem_usage=True,
35
- use_safetensors=True,
36
- attn_implementation="flash_attention_2"
37
- )
38
- print("✅ Model loaded with Flash Attention 2")
39
- except Exception as e:
40
- print(f"⚠️ Could not load with Flash Attention 2: {e}")
41
- print("Loading with default attention implementation...")
42
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
43
- MODEL_NAME,
44
- dtype=torch_dtype,
45
- low_cpu_mem_usage=True,
46
- use_safetensors=True
47
- )
48
- print("✅ Model loaded with default attention")
49
-
50
  model.to(device)
51
 
52
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
@@ -57,19 +32,11 @@ pipe = pipeline(
57
  model=model,
58
  tokenizer=tokenizer,
59
  feature_extractor=processor.feature_extractor,
60
- chunk_length_s=30, # Increased from 10 for better context
 
61
  device=device,
62
- ignore_warning=True,
63
  )
64
 
65
- # Voice Activity Detection
66
- def detect_voice_activity(audio, threshold=0.01):
67
- """Detect if audio contains speech based on energy."""
68
- if len(audio) == 0:
69
- return False
70
- rms = np.sqrt(np.mean(audio**2))
71
- return rms > threshold
72
-
73
  @spaces.GPU
74
  def stream_transcribe(stream, new_chunk):
75
  start_time = time.time()
@@ -81,51 +48,23 @@ def stream_transcribe(stream, new_chunk):
81
  y = y.mean(axis=1)
82
 
83
  y = y.astype(np.float32)
84
-
85
- # FIX: Prevent division by zero
86
  max_val = np.max(np.abs(y))
87
  if max_val > 0:
88
  y /= max_val
89
- else:
90
- # Silent audio, skip
91
- return stream, "", "0.00"
92
 
93
  if stream is not None:
94
  stream = np.concatenate([stream, y])
95
  else:
96
  stream = y
97
-
98
- # FIX: Limit buffer size to prevent memory issues and accumulated silence
99
- MAX_BUFFER = sr * 30 # 30 seconds maximum
100
- if len(stream) > MAX_BUFFER:
101
- stream = stream[-MAX_BUFFER:]
102
-
103
- # FIX: Check for voice activity before transcribing
104
- if not detect_voice_activity(stream, threshold=0.01):
105
- return stream, "", "0.00"
106
-
107
- # FIX: Require minimum audio length
108
- if len(stream) < sr * 1.0: # At least 1 second
109
- return stream, "", "0.00"
110
-
111
- # FIX: Add anti-hallucination parameters
112
- transcription = pipe(
113
- {"sampling_rate": sr, "raw": stream},
114
- generate_kwargs={
115
- "language": "english",
116
- "no_repeat_ngram_size": 3, # Prevents repetitive outputs
117
- }
118
- )["text"]
119
-
120
  end_time = time.time()
121
  latency = end_time - start_time
122
 
123
  return stream, transcription, f"{latency:.2f}"
124
  except Exception as e:
125
  print(f"Error during Transcription: {e}")
126
- import traceback
127
- traceback.print_exc()
128
- return stream if stream is not None else np.array([]), "", "Error"
129
 
130
  @spaces.GPU
131
  def transcribe(inputs, previous_transcription):
@@ -133,41 +72,16 @@ def transcribe(inputs, previous_transcription):
133
  try:
134
  filename = f"{uuid.uuid4().hex}.wav"
135
  sample_rate, audio_data = inputs
136
-
137
- # Convert to float for VAD check
138
- audio_float = audio_data.astype(np.float32)
139
- if audio_data.dtype == np.int16:
140
- audio_float /= 32768.0
141
- elif audio_data.dtype == np.int32:
142
- audio_float /= 2147483648.0
143
-
144
- # FIX: Check for voice activity before transcribing
145
- if not detect_voice_activity(audio_float, threshold=0.01):
146
- return previous_transcription + "\n[No speech detected in audio]", "0.00"
147
-
148
  scipy.io.wavfile.write(filename, sample_rate, audio_data)
149
 
150
- # FIX: Add anti-hallucination parameters
151
- transcription = pipe(
152
- filename,
153
- generate_kwargs={
154
- "language": "english",
155
- }
156
- )["text"]
157
-
158
  previous_transcription += transcription
159
-
160
- # Clean up temp file
161
- if os.path.exists(filename):
162
- os.remove(filename)
163
 
164
  end_time = time.time()
165
  latency = end_time - start_time
166
  return previous_transcription, f"{latency:.2f}"
167
  except Exception as e:
168
  print(f"Error during Transcription: {e}")
169
- import traceback
170
- traceback.print_exc()
171
  return previous_transcription, "Error"
172
 
173
  @spaces.GPU
@@ -178,27 +92,15 @@ def translate_and_transcribe(inputs, previous_transcription, target_language):
178
  sample_rate, audio_data = inputs
179
  scipy.io.wavfile.write(filename, sample_rate, audio_data)
180
 
181
- translation = pipe(
182
- filename,
183
- generate_kwargs={
184
- "task": "translate",
185
- "language": target_language,
186
- }
187
- )["text"]
188
 
189
  previous_transcription += translation
190
-
191
- # Clean up temp file
192
- if os.path.exists(filename):
193
- os.remove(filename)
194
 
195
  end_time = time.time()
196
  latency = end_time - start_time
197
  return previous_transcription, f"{latency:.2f}"
198
  except Exception as e:
199
  print(f"Error during Translation and Transcription: {e}")
200
- import traceback
201
- traceback.print_exc()
202
  return previous_transcription, "Error"
203
 
204
  def clear():
@@ -209,21 +111,7 @@ def clear_state():
209
 
210
  with gr.Blocks() as microphone:
211
  with gr.Column():
212
- gr.Markdown(f"""
213
- # 🎤 Realtime Whisper Large V3 Turbo
214
-
215
- Transcribe Audio in Realtime with **Voice Activity Detection** to prevent hallucinations.
216
-
217
- **Model:** [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})
218
-
219
- **Features:**
220
- - Flash Attention 2 for speed
221
- - Voice Activity Detection (no "oh oh oh" hallucinations)
222
- - 30-second context window
223
- - Anti-repetition safeguards
224
-
225
- **Note:** First transcription takes ~5 seconds. After that, it works flawlessly.
226
- """)
227
  with gr.Row():
228
  input_audio_microphone = gr.Audio(streaming=True)
229
  output = gr.Textbox(label="Transcription", value="")
@@ -231,22 +119,12 @@ with gr.Blocks() as microphone:
231
  with gr.Row():
232
  clear_button = gr.Button("Clear Output")
233
  state = gr.State()
234
- input_audio_microphone.stream(
235
- stream_transcribe,
236
- [state, input_audio_microphone],
237
- [state, output, latency_textbox]
238
- )
239
  clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
240
 
241
  with gr.Blocks() as file:
242
  with gr.Column():
243
- gr.Markdown(f"""
244
- # 🎤 Realtime Whisper Large V3 Turbo
245
-
246
- Transcribe Audio Files.
247
-
248
- **Model:** [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})
249
- """)
250
  with gr.Row():
251
  input_audio_microphone = gr.Audio(sources="upload", type="numpy")
252
  output = gr.Textbox(label="Transcription", value="")
@@ -258,38 +136,32 @@ with gr.Blocks() as file:
258
  submit_button.click(transcribe, [input_audio_microphone, output], [output, latency_textbox], concurrency_limit=None)
259
  clear_button.click(clear, outputs=[output])
260
 
261
- with gr.Blocks() as translate:
262
- with gr.Column():
263
- gr.Markdown(f"""
264
- # 🌍 Realtime Whisper Large V3 Turbo (Translation)
265
-
266
- Transcribe and Translate Audio in Realtime.
267
-
268
- **Model:** [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})
269
-
270
- **Note:** First token takes ~5 seconds. After that, it works flawlessly.
271
- """)
272
- with gr.Row():
273
- input_audio_microphone = gr.Audio(streaming=True)
274
- output = gr.Textbox(label="Transcription and Translation", value="")
275
- latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
276
- target_language_dropdown = gr.Dropdown(
277
- choices=["english", "french", "hindi", "spanish", "russian"],
278
- label="Target Language",
279
- value="english"
280
- )
281
- with gr.Row():
282
- clear_button = gr.Button("Clear Output")
283
-
284
- state = gr.State()
285
- input_audio_microphone.stream(
286
- translate_and_transcribe,
287
- [state, input_audio_microphone, target_language_dropdown],
288
- [state, output, latency_textbox]
289
- )
290
- clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
291
-
292
- with gr.Blocks() as demo:
293
- gr.TabbedInterface([microphone, file, translate], ["Microphone", "Upload File", "Translation"])
294
 
295
  demo.launch()
 
9
  import numpy as np
10
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
11
  import subprocess
12
+ subprocess.run(
13
+ "pip install flash-attn --no-build-isolation",
14
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
15
+ shell=True,
16
+ )
 
 
 
 
 
 
 
17
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ torch_dtype = torch.float16
20
  MODEL_NAME = "openai/whisper-large-v3-turbo"
21
 
22
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
23
+ MODEL_NAME, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2"
24
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  model.to(device)
26
 
27
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
 
32
  model=model,
33
  tokenizer=tokenizer,
34
  feature_extractor=processor.feature_extractor,
35
+ chunk_length_s=10,
36
+ torch_dtype=torch_dtype,
37
  device=device,
 
38
  )
39
 
 
 
 
 
 
 
 
 
40
  @spaces.GPU
41
  def stream_transcribe(stream, new_chunk):
42
  start_time = time.time()
 
48
  y = y.mean(axis=1)
49
 
50
  y = y.astype(np.float32)
 
 
51
  max_val = np.max(np.abs(y))
52
  if max_val > 0:
53
  y /= max_val
 
 
 
54
 
55
  if stream is not None:
56
  stream = np.concatenate([stream, y])
57
  else:
58
  stream = y
59
+
60
+ transcription = pipe({"sampling_rate": sr, "raw": stream}, generate_kwargs={"condition_on_previous_text": False})["text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  end_time = time.time()
62
  latency = end_time - start_time
63
 
64
  return stream, transcription, f"{latency:.2f}"
65
  except Exception as e:
66
  print(f"Error during Transcription: {e}")
67
+ return stream, e, "Error"
 
 
68
 
69
  @spaces.GPU
70
  def transcribe(inputs, previous_transcription):
 
72
  try:
73
  filename = f"{uuid.uuid4().hex}.wav"
74
  sample_rate, audio_data = inputs
 
 
 
 
 
 
 
 
 
 
 
 
75
  scipy.io.wavfile.write(filename, sample_rate, audio_data)
76
 
77
+ transcription = pipe(filename, generate_kwargs={"condition_on_previous_text": False})["text"]
 
 
 
 
 
 
 
78
  previous_transcription += transcription
 
 
 
 
79
 
80
  end_time = time.time()
81
  latency = end_time - start_time
82
  return previous_transcription, f"{latency:.2f}"
83
  except Exception as e:
84
  print(f"Error during Transcription: {e}")
 
 
85
  return previous_transcription, "Error"
86
 
87
  @spaces.GPU
 
92
  sample_rate, audio_data = inputs
93
  scipy.io.wavfile.write(filename, sample_rate, audio_data)
94
 
95
+ translation = pipe(filename, generate_kwargs={"task": "translate", "language": target_language, "condition_on_previous_text": False})["text"]
 
 
 
 
 
 
96
 
97
  previous_transcription += translation
 
 
 
 
98
 
99
  end_time = time.time()
100
  latency = end_time - start_time
101
  return previous_transcription, f"{latency:.2f}"
102
  except Exception as e:
103
  print(f"Error during Translation and Transcription: {e}")
 
 
104
  return previous_transcription, "Error"
105
 
106
  def clear():
 
111
 
112
  with gr.Blocks() as microphone:
113
  with gr.Column():
114
+ gr.Markdown(f"# Realtime Whisper Large V3 Turbo: \n Transcribe Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n Note: The first token takes about 5 seconds. After that, it works flawlessly.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  with gr.Row():
116
  input_audio_microphone = gr.Audio(streaming=True)
117
  output = gr.Textbox(label="Transcription", value="")
 
119
  with gr.Row():
120
  clear_button = gr.Button("Clear Output")
121
  state = gr.State()
122
+ input_audio_microphone.stream(stream_transcribe, [state, input_audio_microphone], [state, output, latency_textbox], time_limit=30, stream_every=2, concurrency_limit=None)
 
 
 
 
123
  clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
124
 
125
  with gr.Blocks() as file:
126
  with gr.Column():
127
+ gr.Markdown(f"# Realtime Whisper Large V3 Turbo: \n Transcribe Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n Note: The first token takes about 5 seconds. After that, it works flawlessly.")
 
 
 
 
 
 
128
  with gr.Row():
129
  input_audio_microphone = gr.Audio(sources="upload", type="numpy")
130
  output = gr.Textbox(label="Transcription", value="")
 
136
  submit_button.click(transcribe, [input_audio_microphone, output], [output, latency_textbox], concurrency_limit=None)
137
  clear_button.click(clear, outputs=[output])
138
 
139
+ # with gr.Blocks() as translate:
140
+ # with gr.Column():
141
+ # gr.Markdown(f"# Realtime Whisper Large V3 Turbo (Translation): \n Transcribe and Translate Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n Note: The first token takes about 5 seconds. After that, it works flawlessly.")
142
+ # with gr.Row():
143
+ # input_audio_microphone = gr.Audio(streaming=True)
144
+ # output = gr.Textbox(label="Transcription and Translation", value="")
145
+ # latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
146
+ # target_language_dropdown = gr.Dropdown(
147
+ # choices=["english", "french", "hindi", "spanish", "russian"],
148
+ # label="Target Language",
149
+ # value="<|es|>"
150
+ # )
151
+ # with gr.Row():
152
+ # clear_button = gr.Button("Clear Output")
153
+
154
+ # input_audio_microphone.stream(
155
+ # translate_and_transcribe,
156
+ # [input_audio_microphone, output, target_language_dropdown],
157
+ # [output, latency_textbox],
158
+ # time_limit=45,
159
+ # stream_every=2,
160
+ # concurrency_limit=None
161
+ # )
162
+ # clear_button.click(clear, outputs=[output])
163
+
164
+ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
165
+ gr.TabbedInterface([microphone, file], ["Microphone", "Transcribe from file"])
 
 
 
 
 
 
166
 
167
  demo.launch()