Anjan9320 commited on
Commit
d23f87e
·
verified ·
1 Parent(s): 879ec61

Update f5_tts/infer/utils_infer.py

Browse files
Files changed (1) hide show
  1. f5_tts/infer/utils_infer.py +2 -6
f5_tts/infer/utils_infer.py CHANGED
@@ -138,9 +138,7 @@ asr_pipe = None
138
  def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
  if "cuda" in device:
141
- if torch.cuda.is_bf16_supported():
142
- dtype = torch.bfloat16
143
- elif torch.cuda.get_device_properties(device).major >= 6:
144
  dtype = torch.float16
145
  else:
146
  dtype = torch.float32
@@ -177,9 +175,7 @@ def transcribe(ref_audio, language=None):
177
  def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
178
  if dtype is None:
179
  if "cuda" in device:
180
- if torch.cuda.is_bf16_supported():
181
- dtype = torch.bfloat16
182
- elif torch.cuda.get_device_properties(device).major >= 6:
183
  dtype = torch.float16
184
  else:
185
  dtype = torch.float32
 
138
  def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
  if "cuda" in device:
141
+ if torch.cuda.get_device_properties(device).major >= 6:
 
 
142
  dtype = torch.float16
143
  else:
144
  dtype = torch.float32
 
175
  def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
176
  if dtype is None:
177
  if "cuda" in device:
178
+ if torch.cuda.get_device_properties(device).major >= 6:
 
 
179
  dtype = torch.float16
180
  else:
181
  dtype = torch.float32