michaeltangz committed on
Commit
721ab04
·
1 Parent(s): 20eeccd

refactor app.py to remove flash attention installation logic and simplify attention implementation; enhance error handling in transcription functions

Browse files
Files changed (1) hide show
  1. app.py +12 -65
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import spaces
2
  import torch
3
  import gradio as gr
4
- import tempfile
5
  import os
6
  import uuid
7
  import scipy.io.wavfile
@@ -13,25 +12,12 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
13
  torch_dtype = torch.float16
14
  MODEL_NAME = "openai/whisper-large-v3-turbo"
15
 
16
- # Try to use flash attention, fall back to sdpa if not available
17
- try:
18
- import subprocess
19
- subprocess.run(
20
- "pip install flash-attn --no-build-isolation",
21
- env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
22
- shell=True,
23
- )
24
- from flash_attn import flash_attn_func
25
- attn_implementation = "flash_attention_2"
26
- except Exception:
27
- attn_implementation = "sdpa" # Use PyTorch's scaled dot product attention
28
-
29
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
30
  MODEL_NAME,
31
  torch_dtype=torch_dtype,
32
  low_cpu_mem_usage=True,
33
  use_safetensors=True,
34
- attn_implementation=attn_implementation
35
  )
36
  model.to(device)
37
 
@@ -46,6 +32,7 @@ pipe = pipeline(
46
  chunk_length_s=10,
47
  torch_dtype=torch_dtype,
48
  device=device,
 
49
  )
50
 
51
  @spaces.GPU
@@ -54,7 +41,6 @@ def stream_transcribe(stream, new_chunk):
54
  try:
55
  sr, y = new_chunk
56
 
57
- # Convert to mono if stereo
58
  if y.ndim > 1:
59
  y = y.mean(axis=1)
60
 
@@ -75,7 +61,7 @@ def stream_transcribe(stream, new_chunk):
75
  return stream, transcription, f"{latency:.2f}"
76
  except Exception as e:
77
  print(f"Error during Transcription: {e}")
78
- return stream, e, "Error"
79
 
80
  @spaces.GPU
81
  def transcribe(inputs, previous_transcription):
@@ -95,25 +81,6 @@ def transcribe(inputs, previous_transcription):
95
  print(f"Error during Transcription: {e}")
96
  return previous_transcription, "Error"
97
 
98
- @spaces.GPU
99
- def translate_and_transcribe(inputs, previous_transcription, target_language):
100
- start_time = time.time()
101
- try:
102
- filename = f"{uuid.uuid4().hex}.wav"
103
- sample_rate, audio_data = inputs
104
- scipy.io.wavfile.write(filename, sample_rate, audio_data)
105
-
106
- translation = pipe(filename, generate_kwargs={"task": "translate", "language": target_language, "condition_on_previous_text": False})["text"]
107
-
108
- previous_transcription += translation
109
-
110
- end_time = time.time()
111
- latency = end_time - start_time
112
- return previous_transcription, f"{latency:.2f}"
113
- except Exception as e:
114
- print(f"Error during Translation and Transcription: {e}")
115
- return previous_transcription, "Error"
116
-
117
def clear():
    """Reset helper for the UI: yields an empty string to blank the transcription textbox."""
    empty_text = ""
    return empty_text
119
 
@@ -122,7 +89,7 @@ def clear_state():
122
 
123
  with gr.Blocks() as microphone:
124
  with gr.Column():
125
- gr.Markdown(f"# Realtime Whisper Large V3 Turbo: \n Transcribe Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n Note: The first token takes about 5 seconds. After that, it works flawlessly.")
126
  with gr.Row():
127
  input_audio_microphone = gr.Audio(streaming=True)
128
  output = gr.Textbox(label="Transcription", value="")
@@ -130,12 +97,17 @@ with gr.Blocks() as microphone:
130
  with gr.Row():
131
  clear_button = gr.Button("Clear Output")
132
  state = gr.State()
133
- input_audio_microphone.stream(stream_transcribe, [state, input_audio_microphone], [state, output, latency_textbox], time_limit=30, stream_every=2, concurrency_limit=None)
 
 
 
 
 
134
  clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
135
 
136
  with gr.Blocks() as file:
137
  with gr.Column():
138
- gr.Markdown(f"# Realtime Whisper Large V3 Turbo: \n Transcribe Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n Note: The first token takes about 5 seconds. After that, it works flawlessly.")
139
  with gr.Row():
140
  input_audio_microphone = gr.Audio(sources="upload", type="numpy")
141
  output = gr.Textbox(label="Transcription", value="")
@@ -144,34 +116,9 @@ with gr.Blocks() as file:
144
  submit_button = gr.Button("Submit")
145
  clear_button = gr.Button("Clear Output")
146
 
147
- submit_button.click(transcribe, [input_audio_microphone, output], [output, latency_textbox], concurrency_limit=None)
148
  clear_button.click(clear, outputs=[output])
149
 
150
- # with gr.Blocks() as translate:
151
- # with gr.Column():
152
- # gr.Markdown(f"# Realtime Whisper Large V3 Turbo (Translation): \n Transcribe and Translate Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n Note: The first token takes about 5 seconds. After that, it works flawlessly.")
153
- # with gr.Row():
154
- # input_audio_microphone = gr.Audio(streaming=True)
155
- # output = gr.Textbox(label="Transcription and Translation", value="")
156
- # latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
157
- # target_language_dropdown = gr.Dropdown(
158
- # choices=["english", "french", "hindi", "spanish", "russian"],
159
- # label="Target Language",
160
- # value="<|es|>"
161
- # )
162
- # with gr.Row():
163
- # clear_button = gr.Button("Clear Output")
164
-
165
- # input_audio_microphone.stream(
166
- # translate_and_transcribe,
167
- # [input_audio_microphone, output, target_language_dropdown],
168
- # [output, latency_textbox],
169
- # time_limit=45,
170
- # stream_every=2,
171
- # concurrency_limit=None
172
- # )
173
- # clear_button.click(clear, outputs=[output])
174
-
175
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
176
  gr.TabbedInterface([microphone, file], ["Microphone", "Transcribe from file"])
177
 
 
1
  import spaces
2
  import torch
3
  import gradio as gr
 
4
  import os
5
  import uuid
6
  import scipy.io.wavfile
 
12
  torch_dtype = torch.float16
13
  MODEL_NAME = "openai/whisper-large-v3-turbo"
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
16
  MODEL_NAME,
17
  torch_dtype=torch_dtype,
18
  low_cpu_mem_usage=True,
19
  use_safetensors=True,
20
+ attn_implementation="sdpa"
21
  )
22
  model.to(device)
23
 
 
32
  chunk_length_s=10,
33
  torch_dtype=torch_dtype,
34
  device=device,
35
+ ignore_warning=True,
36
  )
37
 
38
  @spaces.GPU
 
41
  try:
42
  sr, y = new_chunk
43
 
 
44
  if y.ndim > 1:
45
  y = y.mean(axis=1)
46
 
 
61
  return stream, transcription, f"{latency:.2f}"
62
  except Exception as e:
63
  print(f"Error during Transcription: {e}")
64
+ return stream, str(e), "Error"
65
 
66
  @spaces.GPU
67
  def transcribe(inputs, previous_transcription):
 
81
  print(f"Error during Transcription: {e}")
82
  return previous_transcription, "Error"
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
def clear():
    """Return an empty string, used as the cleared value for output textboxes."""
    return str()
86
 
 
89
 
90
  with gr.Blocks() as microphone:
91
  with gr.Column():
92
+ gr.Markdown(f"# Realtime Whisper Large V3 Turbo\nTranscribe Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.")
93
  with gr.Row():
94
  input_audio_microphone = gr.Audio(streaming=True)
95
  output = gr.Textbox(label="Transcription", value="")
 
97
  with gr.Row():
98
  clear_button = gr.Button("Clear Output")
99
  state = gr.State()
100
+ input_audio_microphone.stream(
101
+ stream_transcribe,
102
+ inputs=[state, input_audio_microphone],
103
+ outputs=[state, output, latency_textbox],
104
+ stream_every=2
105
+ )
106
  clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
107
 
108
  with gr.Blocks() as file:
109
  with gr.Column():
110
+ gr.Markdown(f"# Realtime Whisper Large V3 Turbo\nTranscribe Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.")
111
  with gr.Row():
112
  input_audio_microphone = gr.Audio(sources="upload", type="numpy")
113
  output = gr.Textbox(label="Transcription", value="")
 
116
  submit_button = gr.Button("Submit")
117
  clear_button = gr.Button("Clear Output")
118
 
119
+ submit_button.click(transcribe, inputs=[input_audio_microphone, output], outputs=[output, latency_textbox])
120
  clear_button.click(clear, outputs=[output])
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
123
  gr.TabbedInterface([microphone, file], ["Microphone", "Transcribe from file"])
124