Anjan9320 commited on
Commit
2b92217
·
verified ·
1 Parent(s): 989d1ac

Update f5_tts/infer/utils_infer.py

Browse files
Files changed (1) hide show
  1. f5_tts/infer/utils_infer.py +2 -8
f5_tts/infer/utils_infer.py CHANGED
@@ -138,10 +138,7 @@ asr_pipe = None
138
  def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
  if "cuda" in device:
141
- if torch.cuda.get_device_properties(device).major >= 6:
142
- dtype = torch.float16
143
- else:
144
- dtype = torch.float32
145
  else:
146
  dtype = torch.float32
147
  global asr_pipe
@@ -175,10 +172,7 @@ def transcribe(ref_audio, language=None):
175
  def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
176
  if dtype is None:
177
  if "cuda" in device:
178
- if torch.cuda.get_device_properties(device).major >= 6:
179
- dtype = torch.float16
180
- else:
181
- dtype = torch.float32
182
  else:
183
  dtype = torch.float32
184
  model = model.to(dtype)
 
138
  def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
  if "cuda" in device:
141
+ dtype = torch.float32
 
 
 
142
  else:
143
  dtype = torch.float32
144
  global asr_pipe
 
172
  def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
173
  if dtype is None:
174
  if "cuda" in device:
175
+ dtype = torch.float32
 
 
 
176
  else:
177
  dtype = torch.float32
178
  model = model.to(dtype)