michaeltangz committed on
Commit
6fdae11
·
1 Parent(s): 7b34cad

refactor app.py to streamline flash-attn installation and model loading; update requirements.txt to remove unnecessary dependencies

Browse files
Files changed (3) hide show
  1. .gitattributes +0 -1
  2. app.py +56 -33
  3. requirements.txt +0 -11
.gitattributes CHANGED
@@ -25,7 +25,6 @@
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
  *.wasm filter=lfs diff=lfs merge=lfs -text
 
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
28
  *.tflite filter=lfs diff=lfs merge=lfs -text
29
  *.tgz filter=lfs diff=lfs merge=lfs -text
30
  *.wasm filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -9,39 +9,19 @@ import time
9
  import numpy as np
10
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
11
  import subprocess
12
-
13
- # Install flash-attn if possible, but don't fail if it doesn't work
14
- try:
15
- subprocess.run(
16
- "pip install flash-attn --no-build-isolation",
17
- env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
18
- shell=True,
19
- )
20
- except Exception as e:
21
- print(f"Flash attention installation skipped: {e}")
22
 
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
25
  MODEL_NAME = "openai/whisper-large-v3-turbo"
26
 
27
- # Load model with flash attention if available
28
- try:
29
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
30
- MODEL_NAME,
31
- torch_dtype=torch_dtype,
32
- low_cpu_mem_usage=True,
33
- use_safetensors=True,
34
- attn_implementation="flash_attention_2"
35
- )
36
- except Exception as e:
37
- print(f"Could not load with flash_attention_2, falling back to default: {e}")
38
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
39
- MODEL_NAME,
40
- torch_dtype=torch_dtype,
41
- low_cpu_mem_usage=True,
42
- use_safetensors=True
43
- )
44
-
45
  model.to(device)
46
 
47
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
@@ -52,7 +32,7 @@ pipe = pipeline(
52
  model=model,
53
  tokenizer=tokenizer,
54
  feature_extractor=processor.feature_extractor,
55
- chunk_length_s=30,
56
  torch_dtype=torch_dtype,
57
  device=device,
58
  )
@@ -82,7 +62,7 @@ def stream_transcribe(stream, new_chunk):
82
  return stream, transcription, f"{latency:.2f}"
83
  except Exception as e:
84
  print(f"Error during Transcription: {e}")
85
- return stream, str(e), "Error"
86
 
87
  @spaces.GPU
88
  def transcribe(inputs, previous_transcription):
@@ -102,6 +82,25 @@ def transcribe(inputs, previous_transcription):
102
  print(f"Error during Transcription: {e}")
103
  return previous_transcription, "Error"
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def clear():
106
  return ""
107
 
@@ -135,8 +134,32 @@ with gr.Blocks() as file:
135
  submit_button.click(transcribe, [input_audio_microphone, output], [output, latency_textbox], concurrency_limit=None)
136
  clear_button.click(clear, outputs=[output])
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
139
  gr.TabbedInterface([microphone, file], ["Microphone", "Transcribe from file"])
140
 
141
- if __name__ == "__main__":
142
- demo.launch(share=False)
 
9
  import numpy as np
10
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
11
  import subprocess
12
+ subprocess.run(
13
+ "pip install flash-attn --no-build-isolation",
14
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
15
+ shell=True,
16
+ )
 
 
 
 
 
17
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ torch_dtype = torch.float16
20
  MODEL_NAME = "openai/whisper-large-v3-turbo"
21
 
22
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
23
+ MODEL_NAME, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2"
24
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  model.to(device)
26
 
27
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
 
32
  model=model,
33
  tokenizer=tokenizer,
34
  feature_extractor=processor.feature_extractor,
35
+ chunk_length_s=10,
36
  torch_dtype=torch_dtype,
37
  device=device,
38
  )
 
62
  return stream, transcription, f"{latency:.2f}"
63
  except Exception as e:
64
  print(f"Error during Transcription: {e}")
65
+ return stream, e, "Error"
66
 
67
  @spaces.GPU
68
  def transcribe(inputs, previous_transcription):
 
82
  print(f"Error during Transcription: {e}")
83
  return previous_transcription, "Error"
84
 
85
+ @spaces.GPU
86
+ def translate_and_transcribe(inputs, previous_transcription, target_language):
87
+ start_time = time.time()
88
+ try:
89
+ filename = f"{uuid.uuid4().hex}.wav"
90
+ sample_rate, audio_data = inputs
91
+ scipy.io.wavfile.write(filename, sample_rate, audio_data)
92
+
93
+ translation = pipe(filename, generate_kwargs={"task": "translate", "language": target_language} )["text"]
94
+
95
+ previous_transcription += translation
96
+
97
+ end_time = time.time()
98
+ latency = end_time - start_time
99
+ return previous_transcription, f"{latency:.2f}"
100
+ except Exception as e:
101
+ print(f"Error during Translation and Transcription: {e}")
102
+ return previous_transcription, "Error"
103
+
104
  def clear():
105
  return ""
106
 
 
134
  submit_button.click(transcribe, [input_audio_microphone, output], [output, latency_textbox], concurrency_limit=None)
135
  clear_button.click(clear, outputs=[output])
136
 
137
+ # with gr.Blocks() as translate:
138
+ # with gr.Column():
139
+ # gr.Markdown(f"# Realtime Whisper Large V3 Turbo (Translation): \n Transcribe and Translate Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n Note: The first token takes about 5 seconds. After that, it works flawlessly.")
140
+ # with gr.Row():
141
+ # input_audio_microphone = gr.Audio(streaming=True)
142
+ # output = gr.Textbox(label="Transcription and Translation", value="")
143
+ # latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
144
+ # target_language_dropdown = gr.Dropdown(
145
+ # choices=["english", "french", "hindi", "spanish", "russian"],
146
+ # label="Target Language",
147
+ # value="<|es|>"
148
+ # )
149
+ # with gr.Row():
150
+ # clear_button = gr.Button("Clear Output")
151
+
152
+ # input_audio_microphone.stream(
153
+ # translate_and_transcribe,
154
+ # [input_audio_microphone, output, target_language_dropdown],
155
+ # [output, latency_textbox],
156
+ # time_limit=45,
157
+ # stream_every=2,
158
+ # concurrency_limit=None
159
+ # )
160
+ # clear_button.click(clear, outputs=[output])
161
+
162
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
163
  gr.TabbedInterface([microphone, file], ["Microphone", "Transcribe from file"])
164
 
165
+ demo.launch()
 
requirements.txt CHANGED
@@ -1,14 +1,3 @@
1
- torch==2.6.0
2
- gradio==4.44.1
3
- numpy==1.24.3
4
- spaces>=0.20.0
5
- accelerate>=0.24.0
6
- safetensors>=0.4.0
7
- sentencepiece>=0.1.99
8
- protobuf>=3.20.0
9
- webrtcvad
10
- librosa
11
- flash-attn
12
  transformers
13
  scipy
14
  accelerate
 
 
 
 
 
 
 
 
 
 
 
 
1
  transformers
2
  scipy
3
  accelerate