Update f5_tts/infer/utils_infer.py
Browse files
f5_tts/infer/utils_infer.py
CHANGED
|
@@ -138,10 +138,7 @@ asr_pipe = None
|
|
| 138 |
def initialize_asr_pipeline(device: str = device, dtype=None):
|
| 139 |
if dtype is None:
|
| 140 |
if "cuda" in device:
|
| 141 |
-
|
| 142 |
-
dtype = torch.float16
|
| 143 |
-
else:
|
| 144 |
-
dtype = torch.float32
|
| 145 |
else:
|
| 146 |
dtype = torch.float32
|
| 147 |
global asr_pipe
|
|
@@ -175,10 +172,7 @@ def transcribe(ref_audio, language=None):
|
|
| 175 |
def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
|
| 176 |
if dtype is None:
|
| 177 |
if "cuda" in device:
|
| 178 |
-
|
| 179 |
-
dtype = torch.float16
|
| 180 |
-
else:
|
| 181 |
-
dtype = torch.float32
|
| 182 |
else:
|
| 183 |
dtype = torch.float32
|
| 184 |
model = model.to(dtype)
|
|
|
|
| 138 |
def initialize_asr_pipeline(device: str = device, dtype=None):
|
| 139 |
if dtype is None:
|
| 140 |
if "cuda" in device:
|
| 141 |
+
dtype = torch.float32
|
|
|
|
|
|
|
|
|
|
| 142 |
else:
|
| 143 |
dtype = torch.float32
|
| 144 |
global asr_pipe
|
|
|
|
| 172 |
def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
|
| 173 |
if dtype is None:
|
| 174 |
if "cuda" in device:
|
| 175 |
+
dtype = torch.float32
|
|
|
|
|
|
|
|
|
|
| 176 |
else:
|
| 177 |
dtype = torch.float32
|
| 178 |
model = model.to(dtype)
|