Anjan9320 commited on
Commit
d630904
·
verified ·
1 Parent(s): b7e41a0

Update f5_tts/infer/utils_infer.py

Browse files
Files changed (1) hide show
  1. f5_tts/infer/utils_infer.py +18 -13
f5_tts/infer/utils_infer.py CHANGED
@@ -137,12 +137,15 @@ asr_pipe = None
137
 
138
  def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
- dtype = (
141
- torch.float16
142
- if "cuda" in device
143
- and torch.cuda.get_device_properties(device).major >= 6
144
- else torch.float32
145
- )
 
 
 
146
  global asr_pipe
147
  asr_pipe = pipeline(
148
  "automatic-speech-recognition",
@@ -173,13 +176,15 @@ def transcribe(ref_audio, language=None):
173
 
174
  def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
175
  if dtype is None:
176
- #dtype = torch.float32
177
- dtype = (
178
- torch.float16
179
- if "cuda" in device
180
- and torch.cuda.get_device_properties(device).major >= 6
181
- else torch.float32
182
- )
 
 
183
  model = model.to(dtype)
184
 
185
  ckpt_type = ckpt_path.split(".")[-1]
 
137
 
138
  def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
+ if "cuda" in device:
141
+ if torch.cuda.is_bf16_supported():
142
+ dtype = torch.bfloat16
143
+ elif torch.cuda.get_device_properties(device).major >= 6:
144
+ dtype = torch.float16
145
+ else:
146
+ dtype = torch.float32
147
+ else:
148
+ dtype = torch.float32
149
  global asr_pipe
150
  asr_pipe = pipeline(
151
  "automatic-speech-recognition",
 
176
 
177
  def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
178
  if dtype is None:
179
+ if "cuda" in device:
180
+ if torch.cuda.is_bf16_supported():
181
+ dtype = torch.bfloat16
182
+ elif torch.cuda.get_device_properties(device).major >= 6:
183
+ dtype = torch.float16
184
+ else:
185
+ dtype = torch.float32
186
+ else:
187
+ dtype = torch.float32
188
  model = model.to(dtype)
189
 
190
  ckpt_type = ckpt_path.split(".")[-1]