dvalle08 commited on
Commit
11c8a27
·
1 Parent(s): af17212

Add Pocket TTS plugin with configuration settings and integration into LLMFactory

Browse files
pyproject.toml CHANGED
@@ -14,4 +14,5 @@ dependencies = [
14
  "livekit-agents[silero,turn-detector]~=1.3",
15
  "livekit-plugins-noise-cancellation~=0.2",
16
  "langgraph>=1.0.8",
 
17
  ]
 
14
  "livekit-agents[silero,turn-detector]~=1.3",
15
  "livekit-plugins-noise-cancellation~=0.2",
16
  "langgraph>=1.0.8",
17
+ "pydantic-settings>=2.12.0",
18
  ]
src/agent/llm_factory.py CHANGED
@@ -3,9 +3,9 @@ from typing import Any, Union
3
 
4
  from huggingface_hub import InferenceClient
5
  from transformers import pipeline
6
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
7
 
8
- from kokoro import KPipeline
9
  import torch
10
  from langchain_nvidia_ai_endpoints import ChatNVIDIA
11
 
@@ -32,40 +32,40 @@ class LLMFactory:
32
  max_completion_tokens=max_tokens,
33
  )
34
 
35
- @staticmethod
36
- def create_huggingface_llm(
37
- model_id: str,
38
- provider: str = "auto",
39
- temperature: float = settings.llm.LLM_TEMPERATURE,
40
- max_tokens: int = settings.llm.LLM_MAX_TOKENS,
41
- run_local: bool = False,
42
- ) -> ChatHuggingFace:
43
- if run_local:
44
- logger.info(f"Initializing local HuggingFace LLM: {model_id}")
45
- llm = HuggingFacePipeline.from_model_id(
46
- model_id=model_id,
47
- task="text-generation",
48
- pipeline_kwargs={
49
- "temperature": temperature,
50
- "max_new_tokens": max_tokens,
51
- },
52
- )
53
- return ChatHuggingFace(llm=llm)
54
-
55
- token = (settings.llm.HF_TOKEN or "").strip()
56
- if not token:
57
- raise ValueError("HF_TOKEN must be set to use the HuggingFace LLM provider.")
58
-
59
- logger.info(f"Initializing HuggingFace LLM: {model_id} via provider={provider}")
60
-
61
- llm = HuggingFaceEndpoint(
62
- repo_id=model_id,
63
- provider=provider,
64
- huggingfacehub_api_token=token,
65
- temperature=temperature,
66
- max_new_tokens=max_tokens,
67
- )
68
- return ChatHuggingFace(llm=llm)
69
 
70
  @staticmethod
71
  def create_huggingface_stt(
@@ -126,3 +126,38 @@ class LLMFactory:
126
  logger.info(f"Initializing Moonshine ONNX STT: {model_size}")
127
  from src.plugins.moonshine_stt import MoonshineSTT
128
  return MoonshineSTT(model_size=model_size, language=language)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  from huggingface_hub import InferenceClient
5
  from transformers import pipeline
6
+ #from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
7
 
8
+ #from kokoro import KPipeline
9
  import torch
10
  from langchain_nvidia_ai_endpoints import ChatNVIDIA
11
 
 
32
  max_completion_tokens=max_tokens,
33
  )
34
 
35
+ # @staticmethod
36
+ # def create_huggingface_llm(
37
+ # model_id: str,
38
+ # provider: str = "auto",
39
+ # temperature: float = settings.llm.LLM_TEMPERATURE,
40
+ # max_tokens: int = settings.llm.LLM_MAX_TOKENS,
41
+ # run_local: bool = False,
42
+ # ) -> ChatHuggingFace:
43
+ # if run_local:
44
+ # logger.info(f"Initializing local HuggingFace LLM: {model_id}")
45
+ # llm = HuggingFacePipeline.from_model_id(
46
+ # model_id=model_id,
47
+ # task="text-generation",
48
+ # pipeline_kwargs={
49
+ # "temperature": temperature,
50
+ # "max_new_tokens": max_tokens,
51
+ # },
52
+ # )
53
+ # return ChatHuggingFace(llm=llm)
54
+
55
+ # token = (settings.llm.HF_TOKEN or "").strip()
56
+ # if not token:
57
+ # raise ValueError("HF_TOKEN must be set to use the HuggingFace LLM provider.")
58
+
59
+ # logger.info(f"Initializing HuggingFace LLM: {model_id} via provider={provider}")
60
+
61
+ # llm = HuggingFaceEndpoint(
62
+ # repo_id=model_id,
63
+ # provider=provider,
64
+ # huggingfacehub_api_token=token,
65
+ # temperature=temperature,
66
+ # max_new_tokens=max_tokens,
67
+ # )
68
+ # return ChatHuggingFace(llm=llm)
69
 
70
  @staticmethod
71
  def create_huggingface_stt(
 
126
  logger.info(f"Initializing Moonshine ONNX STT: {model_size}")
127
  from src.plugins.moonshine_stt import MoonshineSTT
128
  return MoonshineSTT(model_size=model_size, language=language)
129
+
130
+ @staticmethod
131
+ def create_pocket_tts(
132
+ voice: str | None = None,
133
+ temperature: float | None = None,
134
+ lsd_decode_steps: int | None = None,
135
+ ) -> "PocketTTS":
136
+ """Initialize Pocket TTS plugin.
137
+
138
+ Args:
139
+ voice: Voice name (alba, marius, etc.) or path to audio file.
140
+ If None, uses settings.voice.POCKET_TTS_VOICE
141
+ temperature: Sampling temperature (0.0-2.0).
142
+ If None, uses settings.voice.POCKET_TTS_TEMPERATURE
143
+ lsd_decode_steps: LSD decoding steps for quality.
144
+ If None, uses settings.voice.POCKET_TTS_LSD_DECODE_STEPS
145
+
146
+ Returns:
147
+ PocketTTS plugin instance
148
+ """
149
+ from src.plugins.pocket_tts import PocketTTS
150
+
151
+ if voice is None:
152
+ voice = settings.voice.POCKET_TTS_VOICE
153
+ if temperature is None:
154
+ temperature = settings.voice.POCKET_TTS_TEMPERATURE
155
+ if lsd_decode_steps is None:
156
+ lsd_decode_steps = settings.voice.POCKET_TTS_LSD_DECODE_STEPS
157
+
158
+ logger.info(f"Initializing Pocket TTS: voice={voice}, temp={temperature}, lsd_steps={lsd_decode_steps}")
159
+ return PocketTTS(
160
+ voice=voice,
161
+ temperature=temperature,
162
+ lsd_decode_steps=lsd_decode_steps,
163
+ )
src/core/settings.py CHANGED
@@ -69,6 +69,23 @@ class VoiceSettings(CoreSettings):
69
  description="Moonshine model size: tiny, base, or small"
70
  )
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  class LLMSettings(CoreSettings):
74
  NVIDIA_API_KEY: Optional[str] = Field(default=None)
 
69
  description="Moonshine model size: tiny, base, or small"
70
  )
71
 
72
+ # TTS (Text-to-Speech) Settings - Pocket TTS
73
+ POCKET_TTS_VOICE: str = Field(
74
+ default="alba",
75
+ description="Default voice (alba, marius, javert, jean, fantine, cosette, eponine, azelma) or path to audio file"
76
+ )
77
+ POCKET_TTS_TEMPERATURE: float = Field(
78
+ default=0.7,
79
+ ge=0.0,
80
+ le=2.0,
81
+ description="Sampling temperature for generation"
82
+ )
83
+ POCKET_TTS_LSD_DECODE_STEPS: int = Field(
84
+ default=1,
85
+ ge=1,
86
+ description="LSD decoding steps (higher = better quality, slower)"
87
+ )
88
+
89
 
90
  class LLMSettings(CoreSettings):
91
  NVIDIA_API_KEY: Optional[str] = Field(default=None)
src/plugins/pocket_tts/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Pocket TTS Plugin - Local text-to-speech using Kyutai's pocket-tts."""
2
+
3
+ from .tts import PocketTTS
4
+
5
+ __all__ = ["PocketTTS"]
src/plugins/pocket_tts/tts.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# tts.py - Pocket TTS Plugin for LiveKit Agents
from __future__ import annotations

import asyncio
import logging
import uuid
from typing import Any

import numpy as np
import torch
from pocket_tts import TTSModel

from livekit.agents import tts
from livekit.agents.types import APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS

from src.core.logger import logger

# Silence the chatty loggers inside the pocket_tts library so synthesis
# does not spam the console.
for _noisy_logger in (
    "pocket_tts",
    "pocket_tts.models.tts_model",
    "pocket_tts.utils.utils",
    "pocket_tts.conditioners.text",
):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
25
class PocketTTS(tts.TTS):
    """LiveKit TTS plugin backed by Kyutai's local pocket-tts model.

    Produces 24 kHz mono PCM audio and supports both batch synthesis and
    progressive streaming via PocketSynthesizeStream.
    """

    def __init__(
        self,
        *,
        voice: str = "alba",
        temperature: float = 0.7,
        lsd_decode_steps: int = 1,
    ) -> None:
        """Initialize Pocket TTS plugin.

        Args:
            voice: Voice name (alba, marius, javert, jean, fantine, cosette, eponine, azelma)
                or path to audio file for custom voice
            temperature: Sampling temperature (0.0-2.0)
            lsd_decode_steps: LSD decoding steps (higher = better quality, slower)

        Raises:
            ValueError: If the model fails to load, if the voice file is
                missing, or if neither the requested voice nor the "alba"
                fallback can be loaded.
        """
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=True, aligned_transcript=False),
            sample_rate=24000,
            num_channels=1,
        )

        self._voice = voice
        self._temperature = temperature
        self._lsd_decode_steps = lsd_decode_steps

        # Load the model in its own try block: previously a model-load
        # failure fell into the voice-fallback path with self._model unset,
        # surfacing a confusing AttributeError-derived message instead of
        # the real cause.
        try:
            logger.info(f"Loading Pocket TTS model: temp={temperature}, lsd_steps={lsd_decode_steps}")
            self._model = TTSModel.load_model(
                temp=temperature,
                lsd_decode_steps=lsd_decode_steps,
            )
            logger.info("Pocket TTS model loaded successfully")
        except Exception as e:
            raise ValueError(f"Failed to load Pocket TTS model: {e}") from e

        # Voice loading: a missing voice file is a hard error; any other
        # failure falls back to the built-in "alba" voice.
        try:
            logger.info(f"Loading voice state: {voice}")
            self._voice_state = self._model.get_state_for_audio_prompt(voice, truncate=True)
            logger.info(f"Voice state loaded for: {voice}")
        except FileNotFoundError as e:
            raise ValueError(f"Failed to load voice '{voice}': {e}") from e
        except Exception as e:
            logger.warning(f"Failed to load voice '{voice}': {e}, falling back to 'alba'")
            try:
                self._voice = "alba"
                self._voice_state = self._model.get_state_for_audio_prompt("alba", truncate=True)
                logger.info("Fallback to 'alba' voice successful")
            except Exception as fallback_error:
                raise ValueError(
                    f"Failed to load fallback voice 'alba': {fallback_error}"
                ) from fallback_error

    @property
    def model(self) -> str:
        """Model identifier reported to LiveKit."""
        return "pocket-tts"

    @property
    def provider(self) -> str:
        """Provider identifier reported to LiveKit."""
        return "kyutai"

    def synthesize(
        self,
        text: str,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> tts.ChunkedStream:
        """Synthesize text to speech using batch generation.

        Args:
            text: Text to synthesize
            conn_options: API connection options

        Returns:
            ChunkedStream for batch synthesis
        """
        # _synthesize_with_stream is presumably provided by the tts.TTS
        # base class (not visible here) — TODO confirm against livekit-agents.
        return self._synthesize_with_stream(text, conn_options=conn_options)

    def stream(
        self,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> tts.SynthesizeStream:
        """Create a streaming synthesis stream.

        Args:
            conn_options: API connection options

        Returns:
            PocketSynthesizeStream for progressive synthesis
        """
        return PocketSynthesizeStream(
            tts=self,
            conn_options=conn_options,
        )
117
+
118
class PocketSynthesizeStream(tts.SynthesizeStream):
    """Progressive synthesis stream for PocketTTS.

    Buffers incoming text, synthesizes a segment on each flush sentinel,
    and pushes 24 kHz mono int16 PCM to the audio emitter.
    """

    def __init__(
        self,
        *,
        tts: PocketTTS,
        conn_options: APIConnectOptions,
    ) -> None:
        """Initialize streaming synthesis stream.

        Args:
            tts: PocketTTS instance
            conn_options: API connection options
        """
        super().__init__(tts=tts, conn_options=conn_options)
        self._tts = tts

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """Process input stream and generate audio progressively.

        Args:
            output_emitter: Audio emitter for pushing generated audio
        """
        output_emitter.initialize(
            request_id=str(uuid.uuid4()),
            sample_rate=24000,
            num_channels=1,
            mime_type="audio/pcm",
            stream=True,
        )

        # Segments are opened lazily, only when there is text to speak.
        # The previous version started a segment eagerly and started a new
        # one after every flush, which emitted spurious empty segments when
        # a flush arrived with no following text or at end-of-input.
        segment_open = False
        text_buffer = ""

        def _open_segment() -> None:
            nonlocal segment_open
            if not segment_open:
                output_emitter.start_segment(segment_id=str(uuid.uuid4()))
                segment_open = True

        async for data in self._input_ch:
            if isinstance(data, self._FlushSentinel):
                if text_buffer.strip():
                    _open_segment()
                    await self._synthesize_segment(text_buffer, output_emitter)
                    text_buffer = ""
                if segment_open:
                    output_emitter.end_segment()
                    segment_open = False
                continue

            text_buffer += data

        # Flush any trailing text that arrived without a final sentinel.
        if text_buffer.strip():
            _open_segment()
            await self._synthesize_segment(text_buffer, output_emitter)

        if segment_open:
            output_emitter.end_segment()

    async def _synthesize_segment(
        self,
        text: str,
        output_emitter: tts.AudioEmitter,
    ) -> None:
        """Synthesize a text segment and push audio chunks to emitter.

        Args:
            text: Text segment to synthesize
            output_emitter: Audio emitter for pushing generated audio
        """
        try:
            def _generate_and_push() -> None:
                # copy_state=True keeps the shared voice state immutable so
                # successive segments start from the same prompt.
                for audio_chunk in self._tts._model.generate_audio_stream(
                    self._tts._voice_state,
                    text,
                    copy_state=True,
                ):
                    output_emitter.push(self._tensor_to_pcm_bytes(audio_chunk))

            # Generation is blocking/CPU-bound; run off the event loop.
            await asyncio.to_thread(_generate_and_push)

        except Exception as e:
            logger.error(f"Error synthesizing segment: {e}")
            raise

    def _tensor_to_pcm_bytes(self, audio_tensor: torch.Tensor) -> bytes:
        """Convert audio tensor to PCM bytes.

        Args:
            audio_tensor: Audio tensor with shape [samples] or [channels, samples]

        Returns:
            PCM audio bytes (int16)
        """
        # Downmix multi-channel audio to mono by averaging channels.
        if audio_tensor.ndim > 1:
            audio_tensor = audio_tensor.mean(dim=0)

        # Clamp to [-1, 1] then scale to int16 full range.
        audio_int16 = (audio_tensor.clamp(-1.0, 1.0) * 32767.0).short()

        return audio_int16.cpu().numpy().tobytes()
testing/livekit_custom.py CHANGED
@@ -24,6 +24,7 @@ from huggingface_hub import InferenceClient
24
  import io
25
  import wave
26
  from src.plugins.moonshine_stt import MoonshineSTT
 
27
 
28
  load_dotenv(".env")
29
 
@@ -76,7 +77,7 @@ async def my_agent(ctx: agents.JobContext):
76
  session = AgentSession(
77
  stt=MoonshineSTT(model_id="UsefulSensors/moonshine-streaming-medium"),
78
  llm=langchain.LLMAdapter(create_nvidia_workflow()),
79
- tts="cartesia/sonic-3:9626c31c-bec5-4cca-baa8-f8ba9e84c8bc",
80
  vad=silero.VAD.load(),
81
  turn_detection=MultilingualModel(),
82
  )
 
24
  import io
25
  import wave
26
  from src.plugins.moonshine_stt import MoonshineSTT
27
+ from src.agent.llm_factory import LLMFactory
28
 
29
  load_dotenv(".env")
30
 
 
77
  session = AgentSession(
78
  stt=MoonshineSTT(model_id="UsefulSensors/moonshine-streaming-medium"),
79
  llm=langchain.LLMAdapter(create_nvidia_workflow()),
80
+ tts=LLMFactory.create_pocket_tts(voice="alba"),
81
  vad=silero.VAD.load(),
82
  turn_detection=MultilingualModel(),
83
  )
uv.lock CHANGED
@@ -2088,6 +2088,7 @@ dependencies = [
2088
  { name = "livekit-agents", extra = ["silero", "turn-detector"] },
2089
  { name = "livekit-plugins-noise-cancellation" },
2090
  { name = "nemo-toolkit", extra = ["asr"] },
 
2091
  { name = "python-dotenv" },
2092
  { name = "torch" },
2093
  { name = "torchaudio" },
@@ -2101,6 +2102,7 @@ requires-dist = [
2101
  { name = "livekit-agents", extras = ["silero", "turn-detector"], specifier = "~=1.3" },
2102
  { name = "livekit-plugins-noise-cancellation", specifier = "~=0.2" },
2103
  { name = "nemo-toolkit", extras = ["asr"] },
 
2104
  { name = "python-dotenv", specifier = ">=1.2.1" },
2105
  { name = "torch", specifier = "==2.10.0" },
2106
  { name = "torchaudio", specifier = "==2.10.0" },
@@ -2650,6 +2652,20 @@ wheels = [
2650
  { url = "https://files.pythonhosted.org/packages/2f/02/8559b1f26ee0d502c74f9cca5c0d2fd97e967e083e006bbbb4e97f3a043a/pydantic_core-2.41.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d3a978c4f57a597908b7e697229d996d77a6d3c94901e9edee593adada95ce1a", size = 2147009, upload-time = "2025-11-04T13:43:23.286Z" },
2651
  ]
2652
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2653
  [[package]]
2654
  name = "pydub"
2655
  version = "0.25.1"
 
2088
  { name = "livekit-agents", extra = ["silero", "turn-detector"] },
2089
  { name = "livekit-plugins-noise-cancellation" },
2090
  { name = "nemo-toolkit", extra = ["asr"] },
2091
+ { name = "pydantic-settings" },
2092
  { name = "python-dotenv" },
2093
  { name = "torch" },
2094
  { name = "torchaudio" },
 
2102
  { name = "livekit-agents", extras = ["silero", "turn-detector"], specifier = "~=1.3" },
2103
  { name = "livekit-plugins-noise-cancellation", specifier = "~=0.2" },
2104
  { name = "nemo-toolkit", extras = ["asr"] },
2105
+ { name = "pydantic-settings", specifier = ">=2.12.0" },
2106
  { name = "python-dotenv", specifier = ">=1.2.1" },
2107
  { name = "torch", specifier = "==2.10.0" },
2108
  { name = "torchaudio", specifier = "==2.10.0" },
 
2652
  { url = "https://files.pythonhosted.org/packages/2f/02/8559b1f26ee0d502c74f9cca5c0d2fd97e967e083e006bbbb4e97f3a043a/pydantic_core-2.41.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d3a978c4f57a597908b7e697229d996d77a6d3c94901e9edee593adada95ce1a", size = 2147009, upload-time = "2025-11-04T13:43:23.286Z" },
2653
  ]
2654
 
2655
+ [[package]]
2656
+ name = "pydantic-settings"
2657
+ version = "2.12.0"
2658
+ source = { registry = "https://pypi.org/simple" }
2659
+ dependencies = [
2660
+ { name = "pydantic" },
2661
+ { name = "python-dotenv" },
2662
+ { name = "typing-inspection" },
2663
+ ]
2664
+ sdist = { url = "https://files.pythonhosted.org/packages/43/4b/ac7e0aae12027748076d72a8764ff1c9d82ca75a7a52622e67ed3f765c54/pydantic_settings-2.12.0.tar.gz", hash = "sha256:005538ef951e3c2a68e1c08b292b5f2e71490def8589d4221b95dab00dafcfd0", size = 194184, upload-time = "2025-11-10T14:25:47.013Z" }
2665
+ wheels = [
2666
+ { url = "https://files.pythonhosted.org/packages/c1/60/5d4751ba3f4a40a6891f24eec885f51afd78d208498268c734e256fb13c4/pydantic_settings-2.12.0-py3-none-any.whl", hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809", size = 51880, upload-time = "2025-11-10T14:25:45.546Z" },
2667
+ ]
2668
+
2669
  [[package]]
2670
  name = "pydub"
2671
  version = "0.25.1"