Spaces:
Runtime error
Runtime error
Commit
·
ca02c10
1
Parent(s):
29d19bd
repo. change
Browse files
app.py
CHANGED
|
@@ -49,8 +49,8 @@ speed = 1.0
|
|
| 49 |
# fix_duration = 27 # None or float (duration in seconds)
|
| 50 |
fix_duration = None
|
| 51 |
|
| 52 |
-
def load_model(exp_name, model_cls, model_cfg, ckpt_step):
|
| 53 |
-
checkpoint = torch.load(str(cached_path(f"hf://SWivid/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
|
| 54 |
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
| 55 |
model = CFM(
|
| 56 |
transformer=model_cls(
|
|
@@ -79,13 +79,13 @@ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
|
|
| 79 |
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
| 80 |
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
| 81 |
|
| 82 |
-
F5TTS_ema_model, F5TTS_base_model = load_model("
|
| 83 |
-
E2TTS_ema_model, E2TTS_base_model = load_model("
|
| 84 |
|
| 85 |
@spaces.GPU
|
| 86 |
def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
|
| 87 |
print(gen_text)
|
| 88 |
-
if model.predict(gen_text)['toxicity'] > 0.8:
|
| 89 |
print("Flagged for toxicity:", gen_text)
|
| 90 |
raise gr.Error("Your text was flagged for toxicity, please try again with a different text.")
|
| 91 |
gr.Info("Converting audio...")
|
|
|
|
| 49 |
# fix_duration = 27 # None or float (duration in seconds)
|
| 50 |
fix_duration = None
|
| 51 |
|
| 52 |
+
def load_model(exp_name, model_cls, model_cfg, ckpt_step,repoid):
|
| 53 |
+
checkpoint = torch.load(str(cached_path(f"hf://SWivid/{repoid}/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
|
| 54 |
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
| 55 |
model = CFM(
|
| 56 |
transformer=model_cls(
|
|
|
|
| 79 |
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
| 80 |
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
| 81 |
|
| 82 |
+
F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000, "F5-TTS")
|
| 83 |
+
E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000, "E2-TTS")
|
| 84 |
|
| 85 |
@spaces.GPU
|
| 86 |
def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
|
| 87 |
print(gen_text)
|
| 88 |
+
if model.predict(gen_text)['toxicity'] > 0.8:
|
| 89 |
print("Flagged for toxicity:", gen_text)
|
| 90 |
raise gr.Error("Your text was flagged for toxicity, please try again with a different text.")
|
| 91 |
gr.Info("Converting audio...")
|