Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
src/f5_tts/infer/utils_infer.py
CHANGED
|
@@ -135,12 +135,10 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
|
|
| 135 |
asr_pipe = None
|
| 136 |
|
| 137 |
|
| 138 |
-
def initialize_asr_pipeline(device=device, dtype=None):
|
| 139 |
if dtype is None:
|
| 140 |
dtype = (
|
| 141 |
-
torch.float16
|
| 142 |
-
if torch.cuda.is_available() and torch.cuda.get_device_properties(device).major >= 6
|
| 143 |
-
else torch.float32
|
| 144 |
)
|
| 145 |
global asr_pipe
|
| 146 |
asr_pipe = pipeline(
|
|
@@ -170,12 +168,10 @@ def transcribe(ref_audio, language=None):
|
|
| 170 |
# load model checkpoint for inference
|
| 171 |
|
| 172 |
|
| 173 |
-
def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
|
| 174 |
if dtype is None:
|
| 175 |
dtype = (
|
| 176 |
-
torch.float16
|
| 177 |
-
if torch.cuda.is_available() and torch.cuda.get_device_properties(device).major >= 6
|
| 178 |
-
else torch.float32
|
| 179 |
)
|
| 180 |
model = model.to(dtype)
|
| 181 |
|
|
|
|
| 135 |
asr_pipe = None
|
| 136 |
|
| 137 |
|
| 138 |
+
def initialize_asr_pipeline(device: str = device, dtype=None):
|
| 139 |
if dtype is None:
|
| 140 |
dtype = (
|
| 141 |
+
torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
|
|
|
|
|
|
|
| 142 |
)
|
| 143 |
global asr_pipe
|
| 144 |
asr_pipe = pipeline(
|
|
|
|
| 168 |
# load model checkpoint for inference
|
| 169 |
|
| 170 |
|
| 171 |
+
def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
|
| 172 |
if dtype is None:
|
| 173 |
dtype = (
|
| 174 |
+
torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
|
|
|
|
|
|
|
| 175 |
)
|
| 176 |
model = model.to(dtype)
|
| 177 |
|