dvalle08 commited on
Commit
9af190b
·
1 Parent(s): ffad511

Refactor Open Voice Agent: Transition to Hatch for build system, restructure agent components, and streamline conversation graph

Browse files
.gitignore CHANGED
@@ -119,7 +119,9 @@ venv.bak/
119
  dev/
120
  nvidia_services/cache/asr/
121
  nvidia_services/cache/tts/
122
- test/
 
 
123
  # Spyder project settings
124
  .spyderproject
125
  .spyproject
 
119
  dev/
120
  nvidia_services/cache/asr/
121
  nvidia_services/cache/tts/
122
+ .claude/
123
+ .cursor/
124
+ .pytest_cache/
125
  # Spyder project settings
126
  .spyderproject
127
  .spyproject
main.py DELETED
@@ -1,84 +0,0 @@
1
- import argparse
2
- import multiprocessing
3
- import subprocess
4
- import sys
5
-
6
- from src.core.logger import logger
7
-
8
-
9
- def run_api():
10
- logger.info("Starting FastAPI server...")
11
- import uvicorn
12
- from src.core.settings import settings
13
-
14
- uvicorn.run(
15
- "src.api.main:app",
16
- host=settings.api.API_HOST,
17
- port=settings.api.API_PORT,
18
- workers=settings.api.API_WORKERS,
19
- reload=True,
20
- log_level="info",
21
- )
22
-
23
-
24
- def run_streamlit():
25
- logger.info("Starting Streamlit UI...")
26
- subprocess.run([
27
- sys.executable,
28
- "-m",
29
- "streamlit",
30
- "run",
31
- "src/streamlit_app.py",
32
- "--server.port=8501",
33
- "--server.address=localhost",
34
- ])
35
-
36
-
37
- def run_both():
38
- logger.info("Starting both FastAPI server and Streamlit UI...")
39
-
40
- api_process = multiprocessing.Process(target=run_api, name="FastAPI")
41
- streamlit_process = multiprocessing.Process(target=run_streamlit, name="Streamlit")
42
-
43
- try:
44
- api_process.start()
45
- streamlit_process.start()
46
-
47
- api_process.join()
48
- streamlit_process.join()
49
-
50
- except KeyboardInterrupt:
51
- logger.info("Shutting down...")
52
- api_process.terminate()
53
- streamlit_process.terminate()
54
- api_process.join()
55
- streamlit_process.join()
56
- logger.info("Shutdown complete")
57
-
58
-
59
- def main():
60
- parser = argparse.ArgumentParser(
61
- description="Open Voice Agent - Real-time AI voice conversations"
62
- )
63
- parser.add_argument(
64
- "mode",
65
- choices=["api", "streamlit", "both"],
66
- default="both",
67
- nargs="?",
68
- help="Run mode: 'api' (FastAPI server), 'streamlit' (UI), or 'both' (default)",
69
- )
70
-
71
- args = parser.parse_args()
72
-
73
- logger.info(f"Starting Open Voice Agent in '{args.mode}' mode...")
74
-
75
- if args.mode == "api":
76
- run_api()
77
- elif args.mode == "streamlit":
78
- run_streamlit()
79
- else:
80
- run_both()
81
-
82
-
83
- if __name__ == "__main__":
84
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pyproject.toml CHANGED
@@ -1,3 +1,10 @@
 
 
 
 
 
 
 
1
  [project]
2
  name = "open-voice-agent"
3
  version = "0.1.0"
 
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [tool.hatch.build.targets.wheel]
6
+ packages = ["src"]
7
+
8
  [project]
9
  name = "open-voice-agent"
10
  version = "0.1.0"
src/agent/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
- """Conversation agent using LangGraph."""
2
 
3
- from src.agent.graph import create_conversation_graph
4
- from src.agent.state import ConversationState
5
 
6
- __all__ = ["create_conversation_graph", "ConversationState"]
 
1
+ """LiveKit voice agent using LangGraph."""
2
 
3
+ from src.agent.graph import create_graph
4
+ from src.agent.agent import Assistant
5
 
6
+ __all__ = ["create_graph", "Assistant"]
testing/livekit_custom.py → src/agent/agent.py RENAMED
@@ -1,64 +1,14 @@
1
- from __future__ import annotations
2
-
3
- from pathlib import Path
4
- import sys
5
-
6
- project_root = Path(__file__).resolve().parents[1]
7
- if str(project_root) not in sys.path:
8
- sys.path.insert(0, str(project_root))
9
-
10
- from dotenv import load_dotenv
11
-
12
  from livekit import agents, rtc
13
- from livekit.agents import AgentServer,AgentSession, Agent, room_io
14
  from livekit.plugins import noise_cancellation, silero
15
  from livekit.plugins.turn_detector.multilingual import MultilingualModel
16
-
17
- import os
18
- from langgraph.graph import StateGraph, MessagesState, START, END
19
- from langchain_nvidia_ai_endpoints import ChatNVIDIA
20
  from livekit.plugins import langchain
21
 
22
- from livekit.agents import stt, tts
23
- from huggingface_hub import InferenceClient
24
- import io
25
- import wave
26
  from src.plugins.moonshine_stt import MoonshineSTT
27
- from src.agent.llm_factory import LLMFactory
28
-
29
- load_dotenv(".env")
30
-
31
- # Simple LangGraph workflow with NVIDIA LLM
32
- def create_nvidia_workflow():
33
- """Create a simple LangGraph workflow with NVIDIA ChatNVIDIA"""
34
-
35
- # Initialize NVIDIA LLM
36
- nvidia_llm = ChatNVIDIA(
37
- model="meta/llama-3.1-8b-instruct",
38
- api_key=os.getenv("NVIDIA_API_KEY"),
39
- temperature=0.7,
40
- max_tokens=150
41
- )
42
-
43
- # Define the conversation node
44
- def call_model(state: MessagesState):
45
- """Simple node that calls the NVIDIA LLM"""
46
- response = nvidia_llm.invoke(state["messages"])
47
- return {"messages": [response]}
48
-
49
- # Build the graph
50
- workflow = StateGraph(MessagesState)
51
-
52
- # Add the single node
53
- workflow.add_node("agent", call_model)
54
-
55
- # Define the flow: START -> agent -> END
56
- workflow.add_edge(START, "agent")
57
- workflow.add_edge("agent", END)
58
-
59
- # Compile and return
60
- return workflow.compile()
61
-
62
 
63
 
64
  class Assistant(Agent):
@@ -70,32 +20,36 @@ class Assistant(Agent):
70
  You are curious, friendly, and have a sense of humor.""",
71
  )
72
 
 
73
  server = AgentServer()
74
 
 
75
  @server.rtc_session()
76
- async def my_agent(ctx: agents.JobContext):
77
  session = AgentSession(
78
- stt=MoonshineSTT(model_id="UsefulSensors/moonshine-streaming-medium"),
79
- llm=langchain.LLMAdapter(create_nvidia_workflow()),
80
- tts=LLMFactory.create_pocket_tts(voice="alba"),
 
 
 
 
81
  vad=silero.VAD.load(),
82
  turn_detection=MultilingualModel(),
83
  )
84
-
85
  await session.start(
86
  room=ctx.room,
87
  agent=Assistant(),
88
  room_options=room_io.RoomOptions(
89
  audio_input=room_io.AudioInputOptions(
90
- noise_cancellation=lambda params: noise_cancellation.BVCTelephony() if params.participant.kind == rtc.ParticipantKind.PARTICIPANT_KIND_SIP else noise_cancellation.BVC(),
 
 
91
  ),
92
  ),
93
  )
94
-
95
- await session.generate_reply(
96
- instructions="Greet the user and offer your assistance."
97
- )
98
 
99
 
100
  if __name__ == "__main__":
101
- agents.cli.run_app(server)
 
 
 
 
 
 
 
 
 
 
 
 
1
  from livekit import agents, rtc
2
+ from livekit.agents import AgentServer, AgentSession, Agent, room_io
3
  from livekit.plugins import noise_cancellation, silero
4
  from livekit.plugins.turn_detector.multilingual import MultilingualModel
 
 
 
 
5
  from livekit.plugins import langchain
6
 
7
+ from src.agent.graph import create_graph
 
 
 
8
  from src.plugins.moonshine_stt import MoonshineSTT
9
+ from src.plugins.pocket_tts import PocketTTS
10
+ from src.core.settings import settings
11
+ from src.core.logger import logger
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  class Assistant(Agent):
 
20
  You are curious, friendly, and have a sense of humor.""",
21
  )
22
 
23
+
24
  server = AgentServer()
25
 
26
+
27
  @server.rtc_session()
28
+ async def session_handler(ctx: agents.JobContext) -> None:
29
  session = AgentSession(
30
+ stt=MoonshineSTT(model_id=settings.voice.MOONSHINE_MODEL_ID),
31
+ llm=langchain.LLMAdapter(create_graph()),
32
+ tts=PocketTTS(
33
+ voice=settings.voice.POCKET_TTS_VOICE,
34
+ temperature=settings.voice.POCKET_TTS_TEMPERATURE,
35
+ lsd_decode_steps=settings.voice.POCKET_TTS_LSD_DECODE_STEPS,
36
+ ),
37
  vad=silero.VAD.load(),
38
  turn_detection=MultilingualModel(),
39
  )
 
40
  await session.start(
41
  room=ctx.room,
42
  agent=Assistant(),
43
  room_options=room_io.RoomOptions(
44
  audio_input=room_io.AudioInputOptions(
45
+ noise_cancellation=lambda params: noise_cancellation.BVCTelephony()
46
+ if params.participant.kind == rtc.ParticipantKind.PARTICIPANT_KIND_SIP
47
+ else noise_cancellation.BVC(),
48
  ),
49
  ),
50
  )
51
+ await session.generate_reply(instructions="Greet the user and offer your assistance.")
 
 
 
52
 
53
 
54
  if __name__ == "__main__":
55
+ agents.cli.run_app(server)
src/agent/graph.py CHANGED
@@ -1,117 +1,23 @@
1
- from typing import Literal, Optional
 
2
 
3
- from langchain_core.language_models import BaseLanguageModel
4
- from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
5
- from langgraph.graph import StateGraph, END
6
- from langgraph.checkpoint.memory import MemorySaver
7
 
8
- from src.agent.llm_factory import LLMFactory
9
- from src.agent.prompts import get_system_prompt
10
- from src.agent.state import ConversationState
11
- from src.core.logger import logger
12
 
 
 
 
 
 
 
 
 
13
 
14
- def process_user_input(state: ConversationState) -> ConversationState:
15
- transcript = state.get("current_transcript", "").strip()
16
-
17
- if not transcript:
18
- logger.debug("No transcript to process")
19
- return state
20
-
21
- logger.info(f"Processing user input: {transcript}")
22
-
23
- messages = state.get("messages", [])
24
-
25
- if not messages:
26
- messages.append(SystemMessage(content=get_system_prompt()))
27
-
28
- messages.append(HumanMessage(content=transcript))
29
-
30
- return {
31
- **state,
32
- "messages": messages,
33
- "current_transcript": "",
34
- }
35
-
36
-
37
- def generate_response(state: ConversationState, llm: Optional[BaseLanguageModel] = None) -> ConversationState:
38
- if llm is None:
39
- llm = LLMFactory.create_llm()
40
-
41
- messages = state.get("messages", [])
42
-
43
- if not messages:
44
- logger.warning("No messages to generate response from")
45
- return state
46
-
47
- logger.info("Generating AI response...")
48
-
49
- try:
50
- response = llm.invoke(messages)
51
-
52
- if hasattr(response, "content"):
53
- content = response.content
54
- else:
55
- content = str(response)
56
-
57
- logger.info(f"Generated response: {content[:100]}...")
58
-
59
- messages.append(AIMessage(content=content))
60
-
61
- return {
62
- **state,
63
- "messages": messages,
64
- }
65
-
66
- except Exception as e:
67
- logger.error(f"Error generating response: {e}")
68
- fallback = "I'm sorry, I encountered an error. Could you please repeat that?"
69
- messages.append(AIMessage(content=fallback))
70
- return {
71
- **state,
72
- "messages": messages,
73
- }
74
-
75
-
76
- def should_respond(state: ConversationState) -> Literal["generate", "wait"]:
77
- turn_active = state.get("turn_active", False)
78
- current_transcript = state.get("current_transcript", "").strip()
79
-
80
- if turn_active:
81
- logger.debug("Turn still active, waiting...")
82
- return "wait"
83
-
84
- if current_transcript:
85
- logger.debug("Turn complete with transcript, generating response")
86
- return "generate"
87
-
88
- logger.debug("No action needed, waiting...")
89
- return "wait"
90
-
91
 
92
- def create_conversation_graph() -> StateGraph:
93
- logger.info("Creating conversation graph...")
94
-
95
- workflow = StateGraph(ConversationState)
96
-
97
- workflow.add_node("process_input", process_user_input)
98
- workflow.add_node("generate_response", generate_response)
99
-
100
- workflow.set_entry_point("process_input")
101
-
102
- workflow.add_conditional_edges(
103
- "process_input",
104
- should_respond,
105
- {
106
- "generate": "generate_response",
107
- "wait": END,
108
- },
109
- )
110
-
111
- workflow.add_edge("generate_response", END)
112
-
113
- memory = MemorySaver()
114
- graph = workflow.compile(checkpointer=memory)
115
-
116
- logger.info("Conversation graph created successfully")
117
- return graph
 
1
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA
2
+ from langgraph.graph import StateGraph, MessagesState, START, END
3
 
4
+ from src.core.settings import settings
 
 
 
5
 
 
 
 
 
6
 
7
+ def create_graph():
8
+ """Create a single-node LangGraph workflow using NVIDIA ChatNVIDIA."""
9
+ llm = ChatNVIDIA(
10
+ model=settings.llm.NVIDIA_MODEL,
11
+ api_key=settings.llm.NVIDIA_API_KEY,
12
+ temperature=settings.llm.LLM_TEMPERATURE,
13
+ max_tokens=settings.llm.LLM_MAX_TOKENS,
14
+ )
15
 
16
+ def call_model(state: MessagesState) -> dict:
17
+ return {"messages": [llm.invoke(state["messages"])]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ workflow = StateGraph(MessagesState)
20
+ workflow.add_node("agent", call_model)
21
+ workflow.add_edge(START, "agent")
22
+ workflow.add_edge("agent", END)
23
+ return workflow.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/agent/llm_factory.py DELETED
@@ -1,163 +0,0 @@
1
-
2
- from typing import Any, Union
3
-
4
- from huggingface_hub import InferenceClient
5
- from transformers import pipeline
6
- #from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
7
-
8
- #from kokoro import KPipeline
9
- import torch
10
- from langchain_nvidia_ai_endpoints import ChatNVIDIA
11
-
12
- from src.core.logger import logger
13
- from src.core.settings import settings
14
-
15
-
16
- class LLMFactory:
17
- @staticmethod
18
- def create_nvidia_llm(
19
- model: str = settings.llm.NVIDIA_MODEL,
20
- temperature: float = settings.llm.LLM_TEMPERATURE,
21
- max_tokens: int = settings.llm.LLM_MAX_TOKENS,
22
- ) -> ChatNVIDIA:
23
- logger.info(f"Initializing NVIDIA LLM: {model}")
24
-
25
- if not settings.llm.NVIDIA_API_KEY:
26
- raise ValueError("NVIDIA_API_KEY must be set to use the NVIDIA LLM provider.")
27
-
28
- return ChatNVIDIA(
29
- model=model,
30
- api_key=settings.llm.NVIDIA_API_KEY,
31
- temperature=temperature,
32
- max_completion_tokens=max_tokens,
33
- )
34
-
35
- # @staticmethod
36
- # def create_huggingface_llm(
37
- # model_id: str,
38
- # provider: str = "auto",
39
- # temperature: float = settings.llm.LLM_TEMPERATURE,
40
- # max_tokens: int = settings.llm.LLM_MAX_TOKENS,
41
- # run_local: bool = False,
42
- # ) -> ChatHuggingFace:
43
- # if run_local:
44
- # logger.info(f"Initializing local HuggingFace LLM: {model_id}")
45
- # llm = HuggingFacePipeline.from_model_id(
46
- # model_id=model_id,
47
- # task="text-generation",
48
- # pipeline_kwargs={
49
- # "temperature": temperature,
50
- # "max_new_tokens": max_tokens,
51
- # },
52
- # )
53
- # return ChatHuggingFace(llm=llm)
54
-
55
- # token = (settings.llm.HF_TOKEN or "").strip()
56
- # if not token:
57
- # raise ValueError("HF_TOKEN must be set to use the HuggingFace LLM provider.")
58
-
59
- # logger.info(f"Initializing HuggingFace LLM: {model_id} via provider={provider}")
60
-
61
- # llm = HuggingFaceEndpoint(
62
- # repo_id=model_id,
63
- # provider=provider,
64
- # huggingfacehub_api_token=token,
65
- # temperature=temperature,
66
- # max_new_tokens=max_tokens,
67
- # )
68
- # return ChatHuggingFace(llm=llm)
69
-
70
- @staticmethod
71
- def create_huggingface_stt(
72
- model_id: str | None = None, run_local: bool = False
73
- ) -> Union[InferenceClient, Any]:
74
- if run_local:
75
- logger.info(f"Initializing local HuggingFace STT: {model_id or 'default'}")
76
- return pipeline("automatic-speech-recognition", model=model_id)
77
-
78
- token = (settings.llm.HF_TOKEN or "").strip()
79
- if not token:
80
- raise ValueError("HF_TOKEN must be set to use the HuggingFace STT provider.")
81
-
82
- logger.info(f"Initializing HuggingFace STT: {model_id or 'default'}")
83
-
84
- return InferenceClient(model=model_id, token=token)
85
-
86
- @staticmethod
87
- def create_huggingface_tts(
88
- model_id: str | None = None, run_local: bool = False
89
- ) -> Union[InferenceClient, Any]:
90
- if run_local:
91
- logger.info(f"Initializing local HuggingFace TTS: {model_id or 'default'}")
92
- return pipeline("text-to-speech", model=model_id)
93
-
94
- token = (settings.llm.HF_TOKEN or "").strip()
95
- if not token:
96
- raise ValueError("HF_TOKEN must be set to use the HuggingFace TTS provider.")
97
-
98
- logger.info(f"Initializing HuggingFace TTS: {model_id or 'default'}")
99
-
100
- return InferenceClient(model=model_id, token=token)
101
-
102
- @staticmethod
103
- def create_kokoro_tts(lang_code: str = "a") -> Any:
104
- if KPipeline is None:
105
- raise ImportError(
106
- "kokoro library not found. Please install it (pip install kokoro>=0.9.4) to use Kokoro TTS."
107
- )
108
-
109
- logger.info(f"Initializing Kokoro TTS Pipeline with lang_code: {lang_code}")
110
- return KPipeline(lang_code=lang_code, repo_id="hexgrad/Kokoro-82M")
111
-
112
- @staticmethod
113
- def create_moonshine_stt(
114
- model_size: str = "base",
115
- language: str = "en",
116
- ) -> "MoonshineSTT":
117
- """Initialize Moonshine ONNX STT plugin.
118
-
119
- Args:
120
- model_size: "tiny" (26MB) or "base" (57MB), or language variants (e.g., "base-es", "tiny-ar")
121
- language: Currently only "en" supported
122
-
123
- Returns:
124
- MoonshineSTT plugin instance
125
- """
126
- logger.info(f"Initializing Moonshine ONNX STT: {model_size}")
127
- from src.plugins.moonshine_stt import MoonshineSTT
128
- return MoonshineSTT(model_size=model_size, language=language)
129
-
130
- @staticmethod
131
- def create_pocket_tts(
132
- voice: str | None = None,
133
- temperature: float | None = None,
134
- lsd_decode_steps: int | None = None,
135
- ) -> "PocketTTS":
136
- """Initialize Pocket TTS plugin.
137
-
138
- Args:
139
- voice: Voice name (alba, marius, etc.) or path to audio file.
140
- If None, uses settings.voice.POCKET_TTS_VOICE
141
- temperature: Sampling temperature (0.0-2.0).
142
- If None, uses settings.voice.POCKET_TTS_TEMPERATURE
143
- lsd_decode_steps: LSD decoding steps for quality.
144
- If None, uses settings.voice.POCKET_TTS_LSD_DECODE_STEPS
145
-
146
- Returns:
147
- PocketTTS plugin instance
148
- """
149
- from src.plugins.pocket_tts import PocketTTS
150
-
151
- if voice is None:
152
- voice = settings.voice.POCKET_TTS_VOICE
153
- if temperature is None:
154
- temperature = settings.voice.POCKET_TTS_TEMPERATURE
155
- if lsd_decode_steps is None:
156
- lsd_decode_steps = settings.voice.POCKET_TTS_LSD_DECODE_STEPS
157
-
158
- logger.info(f"Initializing Pocket TTS: voice={voice}, temp={temperature}, lsd_steps={lsd_decode_steps}")
159
- return PocketTTS(
160
- voice=voice,
161
- temperature=temperature,
162
- lsd_decode_steps=lsd_decode_steps,
163
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/agent/prompts.py DELETED
@@ -1,41 +0,0 @@
1
- from typing import Any, Optional
2
- from enum import Enum
3
-
4
-
5
- class PromptVersion(str, Enum):
6
- V1 = "v1"
7
- DEFAULT = "v1"
8
-
9
-
10
- class PromptTemplate:
11
- def __init__(self, template: str, version: PromptVersion = PromptVersion.DEFAULT):
12
- self.template = template
13
- self.version = version
14
-
15
- def render(self, **kwargs: Any) -> str:
16
- return self.template.format(**kwargs)
17
-
18
-
19
- SYSTEM_PROMPT_V1 = """You are a helpful AI voice assistant. You engage in natural, conversational dialogue with users.
20
-
21
- Guidelines:
22
- - Keep responses concise and natural for voice interaction
23
- - Be friendly and engaging
24
- - Ask clarifying questions when needed
25
- - Acknowledge what the user says before responding
26
- - Keep your responses focused and to the point (2-3 sentences typically)
27
- """
28
-
29
- SYSTEM_PROMPTS = {
30
- PromptVersion.V1: PromptTemplate(SYSTEM_PROMPT_V1, PromptVersion.V1),
31
- }
32
-
33
-
34
- def get_system_prompt(version: Optional[PromptVersion] = None) -> str:
35
- version = version or PromptVersion.DEFAULT
36
- return SYSTEM_PROMPTS[version].render()
37
-
38
-
39
- def get_custom_prompt(template: str, **context: Any) -> str:
40
- prompt = PromptTemplate(template)
41
- return prompt.render(**context)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/agent/state.py DELETED
@@ -1,10 +0,0 @@
1
- from typing import Any, TypedDict
2
-
3
- from langchain_core.messages import BaseMessage
4
-
5
-
6
- class ConversationState(TypedDict):
7
- messages: list[BaseMessage]
8
- current_transcript: str
9
- context: dict[str, Any]
10
- turn_active: bool
 
 
 
 
 
 
 
 
 
 
 
src/api/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- """FastAPI application and WebSocket handlers."""
2
-
3
- from src.api.main import app
4
-
5
- __all__ = ["app"]
 
 
 
 
 
 
src/api/main.py DELETED
@@ -1,69 +0,0 @@
1
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
2
- from fastapi.middleware.cors import CORSMiddleware
3
-
4
- from src.api.websocket import VoiceWebSocketHandler
5
- from src.core.logger import logger
6
- from src.core.settings import settings
7
-
8
- app = FastAPI(
9
- title="Open Voice Agent API",
10
- description="Real-time voice conversation agent with WebSocket support",
11
- version="0.1.0",
12
- )
13
-
14
- app.add_middleware(
15
- CORSMiddleware,
16
- allow_origins=settings.api.API_CORS_ORIGINS,
17
- allow_credentials=True,
18
- allow_methods=["*"],
19
- allow_headers=["*"],
20
- )
21
-
22
-
23
- @app.get("/health")
24
- async def health_check():
25
- return {
26
- "status": "healthy",
27
- "service": "open-voice-agent",
28
- "version": "0.1.0",
29
- }
30
-
31
-
32
- @app.websocket("/ws/voice")
33
- async def websocket_voice_endpoint(websocket: WebSocket):
34
- handler = VoiceWebSocketHandler(websocket)
35
-
36
- try:
37
- await handler.connect()
38
- await handler.handle_conversation()
39
- except WebSocketDisconnect:
40
- logger.info("Client disconnected")
41
- except Exception as e:
42
- logger.error(f"WebSocket error: {e}", exc_info=True)
43
- await handler.send_error(str(e))
44
- finally:
45
- await handler.disconnect()
46
-
47
-
48
- @app.on_event("startup")
49
- async def startup_event():
50
- logger.info("Starting Open Voice Agent API...")
51
- logger.info(f"Voice provider: {settings.voice.VOICE_PROVIDER}")
52
- logger.info(f"LLM provider: {settings.llm.LLM_PROVIDER}")
53
-
54
-
55
- @app.on_event("shutdown")
56
- async def shutdown_event():
57
- logger.info("Shutting down Open Voice Agent API...")
58
-
59
-
60
- if __name__ == "__main__":
61
- import uvicorn
62
-
63
- uvicorn.run(
64
- "src.api.main:app",
65
- host=settings.api.API_HOST,
66
- port=settings.api.API_PORT,
67
- workers=settings.api.API_WORKERS,
68
- reload=True,
69
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/core/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from src.core.settings import settings
2
+ from src.core.logger import logger
3
+
4
+ __all__ = ["settings", "logger"]
src/core/settings.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  from pathlib import Path
3
- from typing import Any, Optional
4
 
5
  from pydantic import Field, ValidationError
6
  from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -15,10 +15,10 @@ load_dotenv(ENV_FILE, override=True)
15
  logger.info(f"Loaded environment from: {ENV_FILE}")
16
 
17
 
18
- def mask_sensitive_data(data: dict[str, Any]) -> dict[str, Any]:
19
  masked = {}
20
  sensitive_keys = ["key", "token", "secret", "password"]
21
-
22
  for key, value in data.items():
23
  if isinstance(value, dict):
24
  masked[key] = mask_sensitive_data(value)
@@ -31,7 +31,7 @@ def mask_sensitive_data(data: dict[str, Any]) -> dict[str, Any]:
31
  masked[key] = f"{value[:4]}...{value[-4:]}"
32
  else:
33
  masked[key] = value
34
-
35
  return masked
36
 
37
 
@@ -46,80 +46,48 @@ class CoreSettings(BaseSettings):
46
 
47
 
48
  class VoiceSettings(CoreSettings):
49
- VOICE_PROVIDER: str = Field(default="nvidia")
50
-
51
- NVIDIA_VOICE_LANGUAGE: str = Field(default="en-US")
52
- NVIDIA_VOICE_NAME: str = Field(default="Magpie-Multilingual.EN-US.Aria")
53
- NVIDIA_TTS_MODEL: str = Field(default="magpie-tts-multilingual")
54
- NVIDIA_TTS_ENDPOINT: str = Field(default="")
55
-
56
- SAMPLE_RATE_OUTPUT: int = Field(default=48000, gt=0)
57
- CHUNK_DURATION_MS: int = Field(default=80, gt=0)
58
-
59
- VAD_THRESHOLD: float = Field(default=0.5, ge=0.0, le=1.0)
60
- VAD_HORIZON_INDEX: int = Field(default=2, ge=0)
61
-
62
- # STT (Speech-to-Text) Settings
63
- STT_PROVIDER: str = Field(
64
- default="moonshine",
65
- description="STT provider (moonshine, assemblyai, etc)"
66
- )
67
- MOONSHINE_MODEL_SIZE: str = Field(
68
- default="small",
69
- description="Moonshine model size: tiny, base, or small"
70
  )
71
-
72
- # TTS (Text-to-Speech) Settings - Pocket TTS
73
  POCKET_TTS_VOICE: str = Field(
74
  default="alba",
75
- description="Default voice (alba, marius, javert, jean, fantine, cosette, eponine, azelma) or path to audio file"
76
  )
 
77
  POCKET_TTS_TEMPERATURE: float = Field(
78
  default=0.7,
79
  ge=0.0,
80
  le=2.0,
81
- description="Sampling temperature for generation"
82
  )
83
  POCKET_TTS_LSD_DECODE_STEPS: int = Field(
84
  default=1,
85
  ge=1,
86
- description="LSD decoding steps (higher = better quality, slower)"
87
  )
88
 
89
 
90
  class LLMSettings(CoreSettings):
91
  NVIDIA_API_KEY: Optional[str] = Field(default=None)
92
  NVIDIA_MODEL: str = Field(default="meta/llama-3.1-8b-instruct")
93
- NVIDIA_BASE_URL: str = Field(default="https://integrate.api.nvidia.com/v1")
94
-
95
- HF_TOKEN: Optional[str] = Field(default=None)
96
 
97
  LLM_TEMPERATURE: float = Field(default=0.7, ge=0.0, le=2.0)
98
  LLM_MAX_TOKENS: int = Field(default=1024, gt=0)
99
 
100
 
101
- class APISettings(CoreSettings):
102
- API_HOST: str = Field(default="0.0.0.0")
103
- API_PORT: int = Field(default=8000, gt=0, lt=65536)
104
- API_WORKERS: int = Field(default=1, gt=0)
105
- API_CORS_ORIGINS: list[str] = Field(
106
- default=["http://localhost:8501", "http://localhost:3000"]
107
- )
108
-
109
-
110
  class Settings(CoreSettings):
111
  voice: VoiceSettings = Field(default_factory=VoiceSettings)
112
  llm: LLMSettings = Field(default_factory=LLMSettings)
113
- api: APISettings = Field(default_factory=APISettings)
114
 
115
 
116
  try:
117
  settings = Settings()
118
-
119
  settings_dict = settings.model_dump()
120
  masked_settings = mask_sensitive_data(settings_dict)
121
  logger.info(f"Settings loaded: {json.dumps(masked_settings, indent=2)}")
122
-
123
  except ValidationError as e:
124
  logger.exception(f"Error validating settings: {e.json()}")
125
  raise
 
1
  import json
2
  from pathlib import Path
3
+ from typing import Optional
4
 
5
  from pydantic import Field, ValidationError
6
  from pydantic_settings import BaseSettings, SettingsConfigDict
 
15
  logger.info(f"Loaded environment from: {ENV_FILE}")
16
 
17
 
18
+ def mask_sensitive_data(data: dict) -> dict:
19
  masked = {}
20
  sensitive_keys = ["key", "token", "secret", "password"]
21
+
22
  for key, value in data.items():
23
  if isinstance(value, dict):
24
  masked[key] = mask_sensitive_data(value)
 
31
  masked[key] = f"{value[:4]}...{value[-4:]}"
32
  else:
33
  masked[key] = value
34
+
35
  return masked
36
 
37
 
 
46
 
47
 
48
  class VoiceSettings(CoreSettings):
49
+ MOONSHINE_MODEL_ID: str = Field(
50
+ default="usefulsensors/moonshine-streaming-medium",
51
+ description="Moonshine model size: tiny, base, or small",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  )
 
 
53
  POCKET_TTS_VOICE: str = Field(
54
  default="alba",
55
+ description="Default voice (alba, marius, javert, jean, fantine, cosette, eponine, azelma) or path to audio file",
56
  )
57
+ SAMPLE_RATE_OUTPUT: int = Field(default=48000, gt=0)
58
  POCKET_TTS_TEMPERATURE: float = Field(
59
  default=0.7,
60
  ge=0.0,
61
  le=2.0,
62
+ description="Sampling temperature for generation",
63
  )
64
  POCKET_TTS_LSD_DECODE_STEPS: int = Field(
65
  default=1,
66
  ge=1,
67
+ description="LSD decoding steps (higher = better quality, slower)",
68
  )
69
 
70
 
71
  class LLMSettings(CoreSettings):
72
  NVIDIA_API_KEY: Optional[str] = Field(default=None)
73
  NVIDIA_MODEL: str = Field(default="meta/llama-3.1-8b-instruct")
 
 
 
74
 
75
  LLM_TEMPERATURE: float = Field(default=0.7, ge=0.0, le=2.0)
76
  LLM_MAX_TOKENS: int = Field(default=1024, gt=0)
77
 
78
 
 
 
 
 
 
 
 
 
 
79
  class Settings(CoreSettings):
80
  voice: VoiceSettings = Field(default_factory=VoiceSettings)
81
  llm: LLMSettings = Field(default_factory=LLMSettings)
 
82
 
83
 
84
  try:
85
  settings = Settings()
86
+
87
  settings_dict = settings.model_dump()
88
  masked_settings = mask_sensitive_data(settings_dict)
89
  logger.info(f"Settings loaded: {json.dumps(masked_settings, indent=2)}")
90
+
91
  except ValidationError as e:
92
  logger.exception(f"Error validating settings: {e.json()}")
93
  raise
src/models/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Models package for voice agents and data structures."""
 
 
src/models/voice/__init__.py DELETED
@@ -1,18 +0,0 @@
1
- """Voice provider interfaces and implementations."""
2
-
3
- from src.models.voice.base import BaseVoiceProvider, VoiceProviderConfig
4
- from src.models.voice.types import (
5
- AudioFormat,
6
- VADInfo,
7
- VoiceMessage,
8
- TranscriptionResult,
9
- )
10
-
11
- __all__ = [
12
- "BaseVoiceProvider",
13
- "VoiceProviderConfig",
14
- "AudioFormat",
15
- "VADInfo",
16
- "VoiceMessage",
17
- "TranscriptionResult",
18
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/voice/base.py DELETED
@@ -1,53 +0,0 @@
1
- from abc import ABC, abstractmethod
2
- from typing import AsyncIterator, Optional
3
-
4
- from pydantic import BaseModel
5
-
6
- from src.models.voice.types import TranscriptionResult, VADInfo
7
-
8
-
9
- class VoiceProviderConfig(BaseModel):
10
- provider_name: str
11
- sample_rate_input: int = 24000
12
- sample_rate_output: int = 48000
13
- chunk_duration_ms: int = 80
14
-
15
-
16
- class BaseVoiceProvider(ABC):
17
- def __init__(self, config: VoiceProviderConfig):
18
- self.config = config
19
- self._connected = False
20
-
21
- @abstractmethod
22
- async def connect(self) -> None:
23
- pass
24
-
25
- @abstractmethod
26
- async def disconnect(self) -> None:
27
- pass
28
-
29
- @abstractmethod
30
- async def text_to_speech(
31
- self, text: str, stream: bool = True
32
- ) -> AsyncIterator[bytes]:
33
- pass
34
-
35
- async def speech_to_text(
36
- self, audio_stream: AsyncIterator[bytes]
37
- ) -> AsyncIterator[TranscriptionResult]:
38
- raise NotImplementedError("Speech-to-text not supported by this provider")
39
-
40
- @abstractmethod
41
- async def get_vad_info(self) -> Optional[VADInfo]:
42
- pass
43
-
44
- @property
45
- def is_connected(self) -> bool:
46
- return self._connected
47
-
48
- async def __aenter__(self):
49
- await self.connect()
50
- return self
51
-
52
- async def __aexit__(self, exc_type, exc_val, exc_tb):
53
- await self.disconnect()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/voice/types.py DELETED
@@ -1,43 +0,0 @@
1
- from dataclasses import dataclass
2
- from enum import Enum
3
- from typing import Optional
4
-
5
-
6
- class AudioFormat(str, Enum):
7
- PCM = "pcm"
8
- WAV = "wav"
9
- OPUS = "opus"
10
- ULAW_8000 = "ulaw_8000"
11
- ALAW_8000 = "alaw_8000"
12
- PCM_16000 = "pcm_16000"
13
- PCM_24000 = "pcm_24000"
14
-
15
-
16
- @dataclass
17
- class VoiceMessage:
18
- type: str
19
- content: str | bytes
20
- timestamp: Optional[float] = None
21
- metadata: Optional[dict] = None
22
-
23
-
24
- @dataclass
25
- class VADInfo:
26
- inactivity_prob: float
27
- horizon_s: float
28
- step_idx: int
29
- total_duration_s: float
30
-
31
- @property
32
- def is_turn_complete(self, threshold: float = 0.5) -> bool:
33
- return self.inactivity_prob > threshold
34
-
35
-
36
- @dataclass
37
- class TranscriptionResult:
38
- text: str
39
- start_s: float
40
- stop_s: Optional[float] = None
41
- is_final: bool = False
42
- confidence: Optional[float] = None
43
- stream_id: Optional[int] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/plugins/pocket_tts/tts.py CHANGED
@@ -199,7 +199,6 @@ class PocketSynthesizeStream(tts.SynthesizeStream):
199
  ):
200
  audio_bytes = self._tensor_to_pcm_bytes(audio_chunk)
201
  chunks.append(audio_bytes)
202
- logger.debug(f"Generated chunk: {len(audio_bytes)} bytes")
203
  logger.info(f"Total chunks generated: {len(chunks)}")
204
  return chunks
205
 
@@ -210,10 +209,6 @@ class PocketSynthesizeStream(tts.SynthesizeStream):
210
 
211
  # Push raw PCM bytes to the emitter
212
  for i, chunk in enumerate(audio_chunks):
213
- num_samples = len(chunk) // 2 # int16 = 2 bytes per sample
214
- logger.debug(
215
- f"Pushing chunk {i+1}/{len(audio_chunks)}: {len(chunk)} bytes ({num_samples} samples @ {self._tts._output_sample_rate}Hz)"
216
- )
217
  output_emitter.push(chunk)
218
 
219
  logger.info(f"Successfully pushed all {len(audio_chunks)} chunks")
 
199
  ):
200
  audio_bytes = self._tensor_to_pcm_bytes(audio_chunk)
201
  chunks.append(audio_bytes)
 
202
  logger.info(f"Total chunks generated: {len(chunks)}")
203
  return chunks
204
 
 
209
 
210
  # Push raw PCM bytes to the emitter
211
  for i, chunk in enumerate(audio_chunks):
 
 
 
 
212
  output_emitter.push(chunk)
213
 
214
  logger.info(f"Successfully pushed all {len(audio_chunks)} chunks")
src/streamlit_app.py DELETED
@@ -1,288 +0,0 @@
1
- """Streamlit UI for voice agent with audio I/O."""
2
-
3
- import asyncio
4
- import base64
5
- import json
6
- from threading import Thread
7
- from queue import Queue
8
- from typing import Optional
9
-
10
- import streamlit as st
11
- import websockets
12
-
13
- from src.core.logger import logger
14
-
15
- # Page configuration
16
- st.set_page_config(
17
- page_title="Open Voice Agent",
18
- page_icon="🎙️",
19
- layout="wide",
20
- )
21
-
22
- # Initialize session state
23
- if "messages" not in st.session_state:
24
- st.session_state.messages = []
25
- if "ws_connected" not in st.session_state:
26
- st.session_state.ws_connected = False
27
- if "current_transcript" not in st.session_state:
28
- st.session_state.current_transcript = ""
29
- if "processing" not in st.session_state:
30
- st.session_state.processing = False
31
- if "response_queue" not in st.session_state:
32
- st.session_state.response_queue = Queue()
33
-
34
-
35
- def send_audio_to_websocket(ws_url: str, audio_data: str, response_queue: Queue):
36
- """Send audio to WebSocket and receive responses in background thread.
37
-
38
- Args:
39
- ws_url: WebSocket URL
40
- audio_data: Base64 encoded audio data
41
- response_queue: Queue to put responses
42
- """
43
- async def communicate():
44
- try:
45
- async with websockets.connect(ws_url) as websocket:
46
- # Send audio message
47
- await websocket.send(json.dumps({
48
- "type": "audio",
49
- "data": audio_data
50
- }))
51
- logger.info("Audio sent to WebSocket")
52
-
53
- # Send end turn signal
54
- await websocket.send(json.dumps({"type": "end_turn"}))
55
- logger.info("End turn signal sent")
56
-
57
- # Receive responses
58
- while True:
59
- try:
60
- message = await asyncio.wait_for(websocket.recv(), timeout=30.0)
61
- response = json.loads(message)
62
- response_queue.put(response)
63
-
64
- # Stop if response is complete
65
- if response.get("type") == "response_complete":
66
- logger.info("Response complete received")
67
- break
68
-
69
- except asyncio.TimeoutError:
70
- logger.warning("WebSocket receive timeout")
71
- break
72
- except Exception as e:
73
- logger.error(f"Error receiving message: {e}")
74
- response_queue.put({"type": "error", "message": str(e)})
75
- break
76
-
77
- except Exception as e:
78
- logger.error(f"WebSocket error: {e}")
79
- response_queue.put({"type": "error", "message": str(e)})
80
-
81
- # Run async function
82
- asyncio.run(communicate())
83
-
84
-
85
- # Title and description
86
- st.title("🎙️ Open Voice Agent")
87
- st.markdown(
88
- """
89
- Voice conversation with AI using NVIDIA API for speech synthesis
90
- and LangGraph for conversation management.
91
- """
92
- )
93
-
94
- # Sidebar configuration
95
- with st.sidebar:
96
- st.header("⚙️ Settings")
97
-
98
- # WebSocket connection settings
99
- ws_url = st.text_input(
100
- "WebSocket URL",
101
- value="ws://localhost:8000/ws/voice",
102
- help="URL of the FastAPI WebSocket server",
103
- )
104
-
105
- # Connection status
106
- status_color = "🟢" if st.session_state.ws_connected else "🔴"
107
- st.markdown(f"**Status:** {status_color} {'Connected' if st.session_state.ws_connected else 'Disconnected'}")
108
- st.markdown("**Voice Provider:** NVIDIA API (configured on the server)")
109
-
110
- # Test connection button
111
- if st.button("🔍 Test Connection"):
112
- try:
113
- import requests
114
- response = requests.get("http://localhost:8000/health", timeout=2)
115
- if response.status_code == 200:
116
- st.success("✅ Server is running!")
117
- st.session_state.ws_connected = True
118
- else:
119
- st.error("❌ Server not responding correctly")
120
- st.session_state.ws_connected = False
121
- except Exception as e:
122
- st.error(f"❌ Cannot connect to server: {e}")
123
- st.session_state.ws_connected = False
124
-
125
- st.divider()
126
-
127
- if st.button("🗑️ Clear Conversation"):
128
- st.session_state.messages = []
129
- st.session_state.current_transcript = ""
130
- st.rerun()
131
-
132
- # Download conversation
133
- if st.session_state.messages:
134
- transcript = "\n\n".join(
135
- [f"{msg['role'].upper()}: {msg['content']}" for msg in st.session_state.messages]
136
- )
137
- st.download_button(
138
- label="📥 Download Transcript",
139
- data=transcript,
140
- file_name="conversation_transcript.txt",
141
- mime="text/plain",
142
- )
143
-
144
- # Main content area
145
- col1, col2 = st.columns([2, 1])
146
-
147
- with col1:
148
- st.subheader("💬 Conversation")
149
-
150
- # Chat container
151
- chat_container = st.container(height=500)
152
-
153
- with chat_container:
154
- # Display messages
155
- for message in st.session_state.messages:
156
- with st.chat_message(message["role"]):
157
- st.write(message["content"])
158
-
159
- # Display current transcript being captured
160
- if st.session_state.current_transcript:
161
- with st.chat_message("user"):
162
- st.write(f"*{st.session_state.current_transcript}...*")
163
-
164
- with col2:
165
- st.subheader("🎤 Audio Controls")
166
-
167
- # Check if server is running
168
- if not st.session_state.ws_connected:
169
- st.warning("⚠️ Please test connection first")
170
-
171
- # Audio recorder using built-in st.audio_input
172
- st.markdown("**Record your message:**")
173
-
174
- audio_data = st.audio_input(
175
- "Click to start/stop recording",
176
- key="audio_input",
177
- help="Click to start recording, click again to stop"
178
- )
179
-
180
- if audio_data is not None:
181
- st.success("✓ Audio recorded!")
182
-
183
- # Show audio player
184
- st.audio(audio_data)
185
-
186
- # Send button
187
- if st.button("📤 Send Audio", disabled=st.session_state.processing):
188
- if not st.session_state.ws_connected:
189
- st.error("❌ Not connected to server")
190
- else:
191
- st.session_state.processing = True
192
-
193
- # Get audio bytes
194
- audio_bytes = audio_data.getvalue()
195
-
196
- # Encode to base64
197
- encoded_audio = base64.b64encode(audio_bytes).decode('utf-8')
198
-
199
- # Show processing indicator
200
- with st.spinner("🔊 Processing audio..."):
201
- # Start WebSocket communication in background thread
202
- thread = Thread(
203
- target=send_audio_to_websocket,
204
- args=(ws_url, encoded_audio, st.session_state.response_queue),
205
- daemon=True
206
- )
207
- thread.start()
208
-
209
- # Wait for thread to complete (with timeout)
210
- thread.join(timeout=30)
211
-
212
- # Process responses from queue
213
- transcript_text = ""
214
- response_text = ""
215
- audio_chunks = []
216
-
217
- while not st.session_state.response_queue.empty():
218
- response = st.session_state.response_queue.get()
219
- msg_type = response.get("type")
220
-
221
- if msg_type == "transcript":
222
- transcript_text += " " + response.get("text", "")
223
- elif msg_type == "response_text":
224
- response_text = response.get("text", "")
225
- elif msg_type == "audio":
226
- audio_chunks.append(response.get("data", ""))
227
- elif msg_type == "error":
228
- st.error(f"Error: {response.get('message')}")
229
-
230
- # Add messages to conversation
231
- if transcript_text.strip():
232
- st.session_state.messages.append({
233
- "role": "user",
234
- "content": transcript_text.strip()
235
- })
236
-
237
- if response_text:
238
- st.session_state.messages.append({
239
- "role": "assistant",
240
- "content": response_text
241
- })
242
-
243
- st.session_state.processing = False
244
- st.success("✅ Processing complete!")
245
- st.rerun()
246
-
247
- # Processing indicator
248
- if st.session_state.processing:
249
- st.info("⏳ Processing your message...")
250
-
251
- # Instructions
252
- with st.expander("📖 How to Use"):
253
- st.markdown(
254
- """
255
- 1. **Test connection** first to ensure server is running
256
- 2. **Click** the audio input to start recording
257
- 3. **Speak** your message clearly
258
- 4. **Click again** to stop recording
259
- 5. **Review** the recorded audio (optional)
260
- 6. **Send** the audio for processing
261
- 7. Wait for the AI response (text will appear in chat)
262
-
263
- **Tips:**
264
- - Speak clearly and at a normal pace
265
- - Wait for the response before recording again
266
- - Keep messages concise for better results
267
- - Use headphones to avoid echo
268
- """
269
- )
270
-
271
- # System info
272
- with st.expander("ℹ️ System Info"):
273
- st.markdown(f"""
274
- **Voice Provider:** NVIDIA API
275
- **WebSocket:** `{ws_url}`
276
- **Messages:** {len(st.session_state.messages)}
277
- """)
278
-
279
- # Footer
280
- st.markdown("---")
281
- st.markdown(
282
- """
283
- <div style='text-align: center'>
284
- <small>Powered by NVIDIA API (Voice) + LangGraph (Conversations) + Streamlit (UI)</small>
285
- </div>
286
- """,
287
- unsafe_allow_html=True,
288
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
testing/asr_moonshine.py DELETED
@@ -1,48 +0,0 @@
1
- import io
2
- import math
3
- import numpy as np
4
- import soundfile as sf
5
- from scipy.signal import resample_poly
6
- import torch
7
- from transformers import AutoProcessor, MoonshineStreamingForConditionalGeneration
8
-
9
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
10
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
11
-
12
- model_id = "usefulsensors/moonshine-streaming-small"
13
-
14
- model = MoonshineStreamingForConditionalGeneration.from_pretrained(model_id).to(
15
- device, torch_dtype
16
- )
17
- processor = AutoProcessor.from_pretrained(model_id)
18
-
19
- # Read audio file
20
- with open("dev/kokoro_tts.wav", "rb") as f:
21
- audio_bytes = f.read()
22
-
23
- # Load audio using soundfile
24
- audio_np, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")
25
-
26
- if audio_np.ndim > 1:
27
- audio_np = np.mean(audio_np, axis=1)
28
-
29
- if sr != 16000:
30
- ratio_gcd = math.gcd(sr, 16000)
31
- up = 16000 // ratio_gcd
32
- down = sr // ratio_gcd
33
- print(f"Resampling from {sr}Hz to 16000Hz")
34
- audio_np = resample_poly(audio_np, up=up, down=down)
35
-
36
- inputs = processor(
37
- audio_np,
38
- return_tensors="pt",
39
- sampling_rate=16000,
40
- ).to(device, torch_dtype)
41
-
42
- token_limit_factor = 6.5 / 16000
43
- max_length = int((inputs.attention_mask.sum() * token_limit_factor).max().item())
44
-
45
- generated_ids = model.generate(**inputs, max_length=max_length)
46
- transcription = processor.decode(generated_ids[0], skip_special_tokens=True)
47
-
48
- print(f"Transcription: {transcription}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
testing/nvidia_.py DELETED
@@ -1,4 +0,0 @@
1
- import nemo.collections.asr as nemo_asr
2
- asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
3
-
4
- transcriptions = asr_model.transcribe(["dev/kokoro_tts.wav"])
 
 
 
 
 
testing/pocket_tts_test.py DELETED
@@ -1,13 +0,0 @@
1
- from pocket_tts import TTSModel
2
- import scipy.io.wavfile
3
-
4
- tts_model = TTSModel.load_model()
5
- voice_state = tts_model.get_state_for_audio_prompt(
6
- "alba" # One of the pre-made voices, see above
7
- # You can also use any voice file you have locally or from Hugging Face:
8
- # "./some_audio.wav"
9
- # or "hf://kyutai/tts-voices/expresso/ex01-ex02_default_001_channel2_198s.wav"
10
- )
11
- audio = tts_model.generate_audio(voice_state, "Hello world, this is a test.")
12
- # Audio is a 1D torch tensor containing PCM data.
13
- scipy.io.wavfile.write("dev/pocket_tts.wav", tts_model.sample_rate, audio.numpy())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
uv.lock CHANGED
@@ -2081,7 +2081,7 @@ wheels = [
2081
  [[package]]
2082
  name = "open-voice-agent"
2083
  version = "0.1.0"
2084
- source = { virtual = "." }
2085
  dependencies = [
2086
  { name = "langgraph" },
2087
  { name = "lhotse" },
 
2081
  [[package]]
2082
  name = "open-voice-agent"
2083
  version = "0.1.0"
2084
+ source = { editable = "." }
2085
  dependencies = [
2086
  { name = "langgraph" },
2087
  { name = "lhotse" },