michaeltangz committed on
Commit
574825e
·
1 Parent(s): 6cd2d8c

fix app.py to correct dtype parameter usage in model initialization and pipeline; remove redundant torch_dtype argument

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -9,12 +9,12 @@ import numpy as np
9
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
- torch_dtype = torch.float16
13
  MODEL_NAME = "openai/whisper-large-v3-turbo"
14
 
15
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
16
  MODEL_NAME,
17
- torch_dtype=torch_dtype,
18
  low_cpu_mem_usage=True,
19
  use_safetensors=True,
20
  attn_implementation="sdpa"
@@ -30,7 +30,6 @@ pipe = pipeline(
30
  tokenizer=tokenizer,
31
  feature_extractor=processor.feature_extractor,
32
  chunk_length_s=10,
33
- torch_dtype=torch_dtype,
34
  device=device,
35
  ignore_warning=True,
36
  )
@@ -100,8 +99,7 @@ with gr.Blocks() as microphone:
100
  input_audio_microphone.stream(
101
  stream_transcribe,
102
  inputs=[state, input_audio_microphone],
103
- outputs=[state, output, latency_textbox],
104
- stream_every=2
105
  )
106
  clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
107
 
 
9
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
13
  MODEL_NAME = "openai/whisper-large-v3-turbo"
14
 
15
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
16
  MODEL_NAME,
17
+ dtype=torch_dtype,
18
  low_cpu_mem_usage=True,
19
  use_safetensors=True,
20
  attn_implementation="sdpa"
 
30
  tokenizer=tokenizer,
31
  feature_extractor=processor.feature_extractor,
32
  chunk_length_s=10,
 
33
  device=device,
34
  ignore_warning=True,
35
  )
 
99
  input_audio_microphone.stream(
100
  stream_transcribe,
101
  inputs=[state, input_audio_microphone],
102
+ outputs=[state, output, latency_textbox]
 
103
  )
104
  clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
105