Spaces:
Running
on
T4
Running
on
T4
Update InferenceInterfaces/ControllableInterface.py
Browse files
InferenceInterfaces/ControllableInterface.py
CHANGED
|
@@ -1,23 +1,31 @@
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
import torch
|
|
|
|
| 4 |
|
| 5 |
from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
|
| 6 |
from Modules.ControllabilityGAN.GAN import GanWrapper
|
| 7 |
-
from Utility.storage_config import MODELS_DIR
|
| 8 |
|
| 9 |
|
| 10 |
class ControllableInterface:
|
| 11 |
|
| 12 |
-
def __init__(self, gpu_id="cpu", available_artificial_voices=
|
| 13 |
if gpu_id == "cpu":
|
| 14 |
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
| 15 |
-
|
|
|
|
|
|
|
| 16 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
| 17 |
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
self.device = "cuda" if gpu_id != "cpu" else "cpu"
|
| 19 |
-
self.model = ToucanTTSInterface(device=self.device, tts_model_path=
|
| 20 |
-
self.wgan = GanWrapper(
|
| 21 |
self.generated_speaker_embeds = list()
|
| 22 |
self.available_artificial_voices = available_artificial_voices
|
| 23 |
self.current_language = ""
|
|
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
import torch
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
| 5 |
|
| 6 |
from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
|
| 7 |
from Modules.ControllabilityGAN.GAN import GanWrapper
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class ControllableInterface:
|
| 11 |
|
| 12 |
+
def __init__(self, gpu_id="cpu", available_artificial_voices=50, tts_model_path=None, vocoder_model_path=None, embedding_gan_path=None):
|
| 13 |
if gpu_id == "cpu":
|
| 14 |
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
| 15 |
+
elif gpu_id == "cuda":
|
| 16 |
+
pass
|
| 17 |
+
else: # in this case we hopefully got a number.
|
| 18 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
| 19 |
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
|
| 20 |
+
if tts_model_path is None:
|
| 21 |
+
tts_model_path = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="ToucanTTS.pt")
|
| 22 |
+
if vocoder_model_path is None:
|
| 23 |
+
vocoder_model_path = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="Vocoder.pt")
|
| 24 |
+
if embedding_gan_path is None:
|
| 25 |
+
embedding_gan_path = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="embedding_gan.pt")
|
| 26 |
self.device = "cuda" if gpu_id != "cpu" else "cpu"
|
| 27 |
+
self.model = ToucanTTSInterface(device=self.device, tts_model_path=tts_model_path, vocoder_model_path=vocoder_model_path)
|
| 28 |
+
self.wgan = GanWrapper(embedding_gan_path, num_cached_voices=available_artificial_voices, device=self.device)
|
| 29 |
self.generated_speaker_embeds = list()
|
| 30 |
self.available_artificial_voices = available_artificial_voices
|
| 31 |
self.current_language = ""
|