chore: Update TTS dependencies and add MeloTTS support
Browse files- kitt/core/__init__.py +4 -1
- kitt/core/tts.py +29 -9
- kitt/skills/weather.py +1 -1
- main.py +5 -4
kitt/core/__init__.py
CHANGED
|
@@ -6,7 +6,7 @@ from typing import List
|
|
| 6 |
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
| 9 |
-
from TTS.api import TTS
|
| 10 |
|
| 11 |
os.environ["COQUI_TOS_AGREED"] = "1"
|
| 12 |
|
|
@@ -17,6 +17,9 @@ Voice = namedtuple("voice", ["name", "neutral", "angry", "speed"])
|
|
| 17 |
file_full_path = pathlib.Path(os.path.realpath(__file__)).parent
|
| 18 |
|
| 19 |
voices = [
|
|
|
|
|
|
|
|
|
|
| 20 |
Voice(
|
| 21 |
"Attenborough",
|
| 22 |
neutral=f"{file_full_path}/audio/attenborough/neutral.wav",
|
|
|
|
| 6 |
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
| 9 |
+
# from TTS.api import TTS
|
| 10 |
|
| 11 |
os.environ["COQUI_TOS_AGREED"] = "1"
|
| 12 |
|
|
|
|
| 17 |
file_full_path = pathlib.Path(os.path.realpath(__file__)).parent
|
| 18 |
|
| 19 |
voices = [
|
| 20 |
+
Voice(
|
| 21 |
+
"Fast", neutral=None, angry=None, speed=1.0,
|
| 22 |
+
),
|
| 23 |
Voice(
|
| 24 |
"Attenborough",
|
| 25 |
neutral=f"{file_full_path}/audio/attenborough/neutral.wav",
|
kitt/core/tts.py
CHANGED
|
@@ -3,15 +3,21 @@ from replicate import Client
|
|
| 3 |
from loguru import logger
|
| 4 |
from kitt.skills.common import config
|
| 5 |
import torch
|
| 6 |
-
|
|
|
|
| 7 |
from transformers import AutoTokenizer, set_seed
|
| 8 |
import soundfile as sf
|
|
|
|
|
|
|
| 9 |
|
| 10 |
replicate = Client(api_token=config.REPLICATE_API_KEY)
|
| 11 |
|
| 12 |
Voice = namedtuple("voice", ["name", "neutral", "angry", "speed"])
|
| 13 |
|
| 14 |
voices_replicate = [
|
|
|
|
|
|
|
|
|
|
| 15 |
Voice(
|
| 16 |
"Attenborough",
|
| 17 |
neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/attenborough-neutral.wav",
|
|
@@ -44,6 +50,7 @@ voices_replicate = [
|
|
| 44 |
),
|
| 45 |
]
|
| 46 |
|
|
|
|
| 47 |
def voice_from_text(voice, voices):
|
| 48 |
for v in voices:
|
| 49 |
if voice == f"{v.name} - Neutral":
|
|
@@ -64,11 +71,7 @@ def speed_from_text(voice, voices):
|
|
| 64 |
def run_tts_replicate(text: str, voice_character: str):
|
| 65 |
voice = voice_from_text(voice_character, voices_replicate)
|
| 66 |
|
| 67 |
-
input = {
|
| 68 |
-
"text": text,
|
| 69 |
-
"speaker": voice,
|
| 70 |
-
"cleanup_voice": True
|
| 71 |
-
}
|
| 72 |
|
| 73 |
output = replicate.run(
|
| 74 |
# "afiaka87/tortoise-tts:e9658de4b325863c4fcdc12d94bb7c9b54cbfe351b7ca1b36860008172b91c71",
|
|
@@ -82,12 +85,13 @@ def run_tts_replicate(text: str, voice_character: str):
|
|
| 82 |
def get_fast_tts():
|
| 83 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 84 |
|
| 85 |
-
model = ParlerTTSForConditionalGeneration.from_pretrained(
|
|
|
|
|
|
|
| 86 |
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")
|
| 87 |
return model, tokenizer, device
|
| 88 |
|
| 89 |
|
| 90 |
-
|
| 91 |
fast_tts = get_fast_tts()
|
| 92 |
|
| 93 |
|
|
@@ -100,4 +104,20 @@ def run_tts_fast(text: str):
|
|
| 100 |
|
| 101 |
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
|
| 102 |
audio_arr = generation.cpu().numpy().squeeze()
|
| 103 |
-
return model.config.sampling_rate, audio_arr, dict(text=text, voice="Thomas")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from loguru import logger
|
| 4 |
from kitt.skills.common import config
|
| 5 |
import torch
|
| 6 |
+
|
| 7 |
+
# from parler_tts import ParlerTTSForConditionalGeneration
|
| 8 |
from transformers import AutoTokenizer, set_seed
|
| 9 |
import soundfile as sf
|
| 10 |
+
from melo.api import TTS as MeloTTS
|
| 11 |
+
|
| 12 |
|
| 13 |
replicate = Client(api_token=config.REPLICATE_API_KEY)
|
| 14 |
|
| 15 |
Voice = namedtuple("voice", ["name", "neutral", "angry", "speed"])
|
| 16 |
|
| 17 |
voices_replicate = [
|
| 18 |
+
Voice(
|
| 19 |
+
"Fast", neutral=None, angry=None, speed=1.0,
|
| 20 |
+
),
|
| 21 |
Voice(
|
| 22 |
"Attenborough",
|
| 23 |
neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/attenborough-neutral.wav",
|
|
|
|
| 50 |
),
|
| 51 |
]
|
| 52 |
|
| 53 |
+
|
| 54 |
def voice_from_text(voice, voices):
|
| 55 |
for v in voices:
|
| 56 |
if voice == f"{v.name} - Neutral":
|
|
|
|
| 71 |
def run_tts_replicate(text: str, voice_character: str):
|
| 72 |
voice = voice_from_text(voice_character, voices_replicate)
|
| 73 |
|
| 74 |
+
input = {"text": text, "speaker": voice, "cleanup_voice": True}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
output = replicate.run(
|
| 77 |
# "afiaka87/tortoise-tts:e9658de4b325863c4fcdc12d94bb7c9b54cbfe351b7ca1b36860008172b91c71",
|
|
|
|
| 85 |
def get_fast_tts():
|
| 86 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 87 |
|
| 88 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained(
|
| 89 |
+
"parler-tts/parler-tts-mini-expresso"
|
| 90 |
+
).to(device)
|
| 91 |
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")
|
| 92 |
return model, tokenizer, device
|
| 93 |
|
| 94 |
|
|
|
|
| 95 |
fast_tts = get_fast_tts()
|
| 96 |
|
| 97 |
|
|
|
|
| 104 |
|
| 105 |
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
|
| 106 |
audio_arr = generation.cpu().numpy().squeeze()
|
| 107 |
+
return (model.config.sampling_rate, audio_arr), dict(text=text, voice="Thomas")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def load_melo_tts():
|
| 111 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 112 |
+
model = MeloTTS(language="EN", device=device)
|
| 113 |
+
return model
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
melo_tts = load_melo_tts()
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def run_melo_tts(text: str, voice: str):
|
| 120 |
+
speed = 1.0
|
| 121 |
+
speaker_ids = melo_tts.hps.data.spk2id
|
| 122 |
+
audio = melo_tts.tts_to_file(text, speaker_ids["EN-Default"], None, speed=speed)
|
| 123 |
+
return melo_tts.hps.data.sampling_rate, audio
|
kitt/skills/weather.py
CHANGED
|
@@ -129,7 +129,7 @@ def get_forecast(city_name: str = "", when=0, **kwargs):
|
|
| 129 |
number_str = f"in {when-1} days"
|
| 130 |
|
| 131 |
# Generate a sentence for the day's forecast
|
| 132 |
-
forecast_sentence = f"On {date} ({number_str}) in {city_name}, the weather will be {conditions} with a high of {max_temp_c}
|
| 133 |
|
| 134 |
# number = number + 1
|
| 135 |
# Add the sentence to the result
|
|
|
|
| 129 |
number_str = f"in {when-1} days"
|
| 130 |
|
| 131 |
# Generate a sentence for the day's forecast
|
| 132 |
+
forecast_sentence = f"On {date} ({number_str}) in {city_name}, the weather will be {conditions} with a high of {max_temp_c}C and a low of {min_temp_c}C. There's a {chance_of_rain}% chance of rain. "
|
| 133 |
|
| 134 |
# number = number + 1
|
| 135 |
# Add the sentence to the result
|
main.py
CHANGED
|
@@ -8,7 +8,7 @@ import typer
|
|
| 8 |
|
| 9 |
from kitt.skills.common import config, vehicle
|
| 10 |
from kitt.skills.routing import calculate_route
|
| 11 |
-
from kitt.core.tts import run_tts_replicate, run_tts_fast
|
| 12 |
import ollama
|
| 13 |
|
| 14 |
from langchain.tools.base import StructuredTool
|
|
@@ -196,7 +196,7 @@ def run_nexusraven_model(query, voice_character, state):
|
|
| 196 |
|
| 197 |
if type(output_text) == tuple:
|
| 198 |
output_text = output_text[0]
|
| 199 |
-
gr.Info(f"Output text: {output_text}
|
| 200 |
return (
|
| 201 |
output_text,
|
| 202 |
tts_gradio(output_text, voice_character, speaker_embedding_cache)[0],
|
|
@@ -216,11 +216,12 @@ def run_llama3_model(query, voice_character, state):
|
|
| 216 |
functions=functions,
|
| 217 |
backend=state["llm_backend"],
|
| 218 |
)
|
| 219 |
-
gr.Info(f"Output text: {output_text}
|
| 220 |
voice_out = None
|
| 221 |
if state["tts_enabled"]:
|
| 222 |
# voice_out = run_tts_replicate(output_text, voice_character)
|
| 223 |
-
voice_out = run_tts_fast(output_text)[0]
|
|
|
|
| 224 |
# voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
|
| 225 |
return (
|
| 226 |
output_text,
|
|
|
|
| 8 |
|
| 9 |
from kitt.skills.common import config, vehicle
|
| 10 |
from kitt.skills.routing import calculate_route
|
| 11 |
+
from kitt.core.tts import run_tts_replicate, run_tts_fast, run_melo_tts
|
| 12 |
import ollama
|
| 13 |
|
| 14 |
from langchain.tools.base import StructuredTool
|
|
|
|
| 196 |
|
| 197 |
if type(output_text) == tuple:
|
| 198 |
output_text = output_text[0]
|
| 199 |
+
gr.Info(f"Output text: {output_text}\nGenerating voice output...")
|
| 200 |
return (
|
| 201 |
output_text,
|
| 202 |
tts_gradio(output_text, voice_character, speaker_embedding_cache)[0],
|
|
|
|
| 216 |
functions=functions,
|
| 217 |
backend=state["llm_backend"],
|
| 218 |
)
|
| 219 |
+
gr.Info(f"Output text: {output_text}\nGenerating voice output...")
|
| 220 |
voice_out = None
|
| 221 |
if state["tts_enabled"]:
|
| 222 |
# voice_out = run_tts_replicate(output_text, voice_character)
|
| 223 |
+
# voice_out = run_tts_fast(output_text)[0]
|
| 224 |
+
voice_out = run_melo_tts(output_text, voice_character)
|
| 225 |
# voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
|
| 226 |
return (
|
| 227 |
output_text,
|