Spaces:
Running
Running
Refactor Open Voice Agent: Transition to Hatch for build system, restructure agent components, and streamline conversation graph
Browse files- .gitignore +3 -1
- main.py +0 -84
- pyproject.toml +7 -0
- src/agent/__init__.py +4 -4
- testing/livekit_custom.py → src/agent/agent.py +20 -66
- src/agent/graph.py +18 -112
- src/agent/llm_factory.py +0 -163
- src/agent/prompts.py +0 -41
- src/agent/state.py +0 -10
- src/api/__init__.py +0 -5
- src/api/main.py +0 -69
- src/core/__init__.py +4 -0
- src/core/settings.py +13 -45
- src/models/__init__.py +0 -1
- src/models/voice/__init__.py +0 -18
- src/models/voice/base.py +0 -53
- src/models/voice/types.py +0 -43
- src/plugins/pocket_tts/tts.py +0 -5
- src/streamlit_app.py +0 -288
- testing/asr_moonshine.py +0 -48
- testing/nvidia_.py +0 -4
- testing/pocket_tts_test.py +0 -13
- uv.lock +1 -1
.gitignore
CHANGED
|
@@ -119,7 +119,9 @@ venv.bak/
|
|
| 119 |
dev/
|
| 120 |
nvidia_services/cache/asr/
|
| 121 |
nvidia_services/cache/tts/
|
| 122 |
-
|
|
|
|
|
|
|
| 123 |
# Spyder project settings
|
| 124 |
.spyderproject
|
| 125 |
.spyproject
|
|
|
|
| 119 |
dev/
|
| 120 |
nvidia_services/cache/asr/
|
| 121 |
nvidia_services/cache/tts/
|
| 122 |
+
.claude/
|
| 123 |
+
.cursor/
|
| 124 |
+
.pytest_cache/
|
| 125 |
# Spyder project settings
|
| 126 |
.spyderproject
|
| 127 |
.spyproject
|
main.py
DELETED
|
@@ -1,84 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import multiprocessing
|
| 3 |
-
import subprocess
|
| 4 |
-
import sys
|
| 5 |
-
|
| 6 |
-
from src.core.logger import logger
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def run_api():
|
| 10 |
-
logger.info("Starting FastAPI server...")
|
| 11 |
-
import uvicorn
|
| 12 |
-
from src.core.settings import settings
|
| 13 |
-
|
| 14 |
-
uvicorn.run(
|
| 15 |
-
"src.api.main:app",
|
| 16 |
-
host=settings.api.API_HOST,
|
| 17 |
-
port=settings.api.API_PORT,
|
| 18 |
-
workers=settings.api.API_WORKERS,
|
| 19 |
-
reload=True,
|
| 20 |
-
log_level="info",
|
| 21 |
-
)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def run_streamlit():
|
| 25 |
-
logger.info("Starting Streamlit UI...")
|
| 26 |
-
subprocess.run([
|
| 27 |
-
sys.executable,
|
| 28 |
-
"-m",
|
| 29 |
-
"streamlit",
|
| 30 |
-
"run",
|
| 31 |
-
"src/streamlit_app.py",
|
| 32 |
-
"--server.port=8501",
|
| 33 |
-
"--server.address=localhost",
|
| 34 |
-
])
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def run_both():
|
| 38 |
-
logger.info("Starting both FastAPI server and Streamlit UI...")
|
| 39 |
-
|
| 40 |
-
api_process = multiprocessing.Process(target=run_api, name="FastAPI")
|
| 41 |
-
streamlit_process = multiprocessing.Process(target=run_streamlit, name="Streamlit")
|
| 42 |
-
|
| 43 |
-
try:
|
| 44 |
-
api_process.start()
|
| 45 |
-
streamlit_process.start()
|
| 46 |
-
|
| 47 |
-
api_process.join()
|
| 48 |
-
streamlit_process.join()
|
| 49 |
-
|
| 50 |
-
except KeyboardInterrupt:
|
| 51 |
-
logger.info("Shutting down...")
|
| 52 |
-
api_process.terminate()
|
| 53 |
-
streamlit_process.terminate()
|
| 54 |
-
api_process.join()
|
| 55 |
-
streamlit_process.join()
|
| 56 |
-
logger.info("Shutdown complete")
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def main():
|
| 60 |
-
parser = argparse.ArgumentParser(
|
| 61 |
-
description="Open Voice Agent - Real-time AI voice conversations"
|
| 62 |
-
)
|
| 63 |
-
parser.add_argument(
|
| 64 |
-
"mode",
|
| 65 |
-
choices=["api", "streamlit", "both"],
|
| 66 |
-
default="both",
|
| 67 |
-
nargs="?",
|
| 68 |
-
help="Run mode: 'api' (FastAPI server), 'streamlit' (UI), or 'both' (default)",
|
| 69 |
-
)
|
| 70 |
-
|
| 71 |
-
args = parser.parse_args()
|
| 72 |
-
|
| 73 |
-
logger.info(f"Starting Open Voice Agent in '{args.mode}' mode...")
|
| 74 |
-
|
| 75 |
-
if args.mode == "api":
|
| 76 |
-
run_api()
|
| 77 |
-
elif args.mode == "streamlit":
|
| 78 |
-
run_streamlit()
|
| 79 |
-
else:
|
| 80 |
-
run_both()
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
if __name__ == "__main__":
|
| 84 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pyproject.toml
CHANGED
|
@@ -1,3 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
[project]
|
| 2 |
name = "open-voice-agent"
|
| 3 |
version = "0.1.0"
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["hatchling"]
|
| 3 |
+
build-backend = "hatchling.build"
|
| 4 |
+
|
| 5 |
+
[tool.hatch.build.targets.wheel]
|
| 6 |
+
packages = ["src"]
|
| 7 |
+
|
| 8 |
[project]
|
| 9 |
name = "open-voice-agent"
|
| 10 |
version = "0.1.0"
|
src/agent/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
-
from src.agent.graph import
|
| 4 |
-
from src.agent.
|
| 5 |
|
| 6 |
-
__all__ = ["
|
|
|
|
| 1 |
+
"""LiveKit voice agent using LangGraph."""
|
| 2 |
|
| 3 |
+
from src.agent.graph import create_graph
|
| 4 |
+
from src.agent.agent import Assistant
|
| 5 |
|
| 6 |
+
__all__ = ["create_graph", "Assistant"]
|
testing/livekit_custom.py → src/agent/agent.py
RENAMED
|
@@ -1,64 +1,14 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
import sys
|
| 5 |
-
|
| 6 |
-
project_root = Path(__file__).resolve().parents[1]
|
| 7 |
-
if str(project_root) not in sys.path:
|
| 8 |
-
sys.path.insert(0, str(project_root))
|
| 9 |
-
|
| 10 |
-
from dotenv import load_dotenv
|
| 11 |
-
|
| 12 |
from livekit import agents, rtc
|
| 13 |
-
from livekit.agents import AgentServer,AgentSession, Agent, room_io
|
| 14 |
from livekit.plugins import noise_cancellation, silero
|
| 15 |
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
| 16 |
-
|
| 17 |
-
import os
|
| 18 |
-
from langgraph.graph import StateGraph, MessagesState, START, END
|
| 19 |
-
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
| 20 |
from livekit.plugins import langchain
|
| 21 |
|
| 22 |
-
from
|
| 23 |
-
from huggingface_hub import InferenceClient
|
| 24 |
-
import io
|
| 25 |
-
import wave
|
| 26 |
from src.plugins.moonshine_stt import MoonshineSTT
|
| 27 |
-
from src.
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
# Simple LangGraph workflow with NVIDIA LLM
|
| 32 |
-
def create_nvidia_workflow():
|
| 33 |
-
"""Create a simple LangGraph workflow with NVIDIA ChatNVIDIA"""
|
| 34 |
-
|
| 35 |
-
# Initialize NVIDIA LLM
|
| 36 |
-
nvidia_llm = ChatNVIDIA(
|
| 37 |
-
model="meta/llama-3.1-8b-instruct",
|
| 38 |
-
api_key=os.getenv("NVIDIA_API_KEY"),
|
| 39 |
-
temperature=0.7,
|
| 40 |
-
max_tokens=150
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
# Define the conversation node
|
| 44 |
-
def call_model(state: MessagesState):
|
| 45 |
-
"""Simple node that calls the NVIDIA LLM"""
|
| 46 |
-
response = nvidia_llm.invoke(state["messages"])
|
| 47 |
-
return {"messages": [response]}
|
| 48 |
-
|
| 49 |
-
# Build the graph
|
| 50 |
-
workflow = StateGraph(MessagesState)
|
| 51 |
-
|
| 52 |
-
# Add the single node
|
| 53 |
-
workflow.add_node("agent", call_model)
|
| 54 |
-
|
| 55 |
-
# Define the flow: START -> agent -> END
|
| 56 |
-
workflow.add_edge(START, "agent")
|
| 57 |
-
workflow.add_edge("agent", END)
|
| 58 |
-
|
| 59 |
-
# Compile and return
|
| 60 |
-
return workflow.compile()
|
| 61 |
-
|
| 62 |
|
| 63 |
|
| 64 |
class Assistant(Agent):
|
|
@@ -70,32 +20,36 @@ class Assistant(Agent):
|
|
| 70 |
You are curious, friendly, and have a sense of humor.""",
|
| 71 |
)
|
| 72 |
|
|
|
|
| 73 |
server = AgentServer()
|
| 74 |
|
|
|
|
| 75 |
@server.rtc_session()
|
| 76 |
-
async def
|
| 77 |
session = AgentSession(
|
| 78 |
-
stt=MoonshineSTT(model_id=
|
| 79 |
-
llm=langchain.LLMAdapter(
|
| 80 |
-
tts=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
vad=silero.VAD.load(),
|
| 82 |
turn_detection=MultilingualModel(),
|
| 83 |
)
|
| 84 |
-
|
| 85 |
await session.start(
|
| 86 |
room=ctx.room,
|
| 87 |
agent=Assistant(),
|
| 88 |
room_options=room_io.RoomOptions(
|
| 89 |
audio_input=room_io.AudioInputOptions(
|
| 90 |
-
noise_cancellation=lambda params: noise_cancellation.BVCTelephony()
|
|
|
|
|
|
|
| 91 |
),
|
| 92 |
),
|
| 93 |
)
|
| 94 |
-
|
| 95 |
-
await session.generate_reply(
|
| 96 |
-
instructions="Greet the user and offer your assistance."
|
| 97 |
-
)
|
| 98 |
|
| 99 |
|
| 100 |
if __name__ == "__main__":
|
| 101 |
-
agents.cli.run_app(server)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from livekit import agents, rtc
|
| 2 |
+
from livekit.agents import AgentServer, AgentSession, Agent, room_io
|
| 3 |
from livekit.plugins import noise_cancellation, silero
|
| 4 |
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from livekit.plugins import langchain
|
| 6 |
|
| 7 |
+
from src.agent.graph import create_graph
|
|
|
|
|
|
|
|
|
|
| 8 |
from src.plugins.moonshine_stt import MoonshineSTT
|
| 9 |
+
from src.plugins.pocket_tts import PocketTTS
|
| 10 |
+
from src.core.settings import settings
|
| 11 |
+
from src.core.logger import logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class Assistant(Agent):
|
|
|
|
| 20 |
You are curious, friendly, and have a sense of humor.""",
|
| 21 |
)
|
| 22 |
|
| 23 |
+
|
| 24 |
server = AgentServer()
|
| 25 |
|
| 26 |
+
|
| 27 |
@server.rtc_session()
|
| 28 |
+
async def session_handler(ctx: agents.JobContext) -> None:
|
| 29 |
session = AgentSession(
|
| 30 |
+
stt=MoonshineSTT(model_id=settings.voice.MOONSHINE_MODEL_ID),
|
| 31 |
+
llm=langchain.LLMAdapter(create_graph()),
|
| 32 |
+
tts=PocketTTS(
|
| 33 |
+
voice=settings.voice.POCKET_TTS_VOICE,
|
| 34 |
+
temperature=settings.voice.POCKET_TTS_TEMPERATURE,
|
| 35 |
+
lsd_decode_steps=settings.voice.POCKET_TTS_LSD_DECODE_STEPS,
|
| 36 |
+
),
|
| 37 |
vad=silero.VAD.load(),
|
| 38 |
turn_detection=MultilingualModel(),
|
| 39 |
)
|
|
|
|
| 40 |
await session.start(
|
| 41 |
room=ctx.room,
|
| 42 |
agent=Assistant(),
|
| 43 |
room_options=room_io.RoomOptions(
|
| 44 |
audio_input=room_io.AudioInputOptions(
|
| 45 |
+
noise_cancellation=lambda params: noise_cancellation.BVCTelephony()
|
| 46 |
+
if params.participant.kind == rtc.ParticipantKind.PARTICIPANT_KIND_SIP
|
| 47 |
+
else noise_cancellation.BVC(),
|
| 48 |
),
|
| 49 |
),
|
| 50 |
)
|
| 51 |
+
await session.generate_reply(instructions="Greet the user and offer your assistance.")
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
if __name__ == "__main__":
|
| 55 |
+
agents.cli.run_app(server)
|
src/agent/graph.py
CHANGED
|
@@ -1,117 +1,23 @@
|
|
| 1 |
-
from
|
|
|
|
| 2 |
|
| 3 |
-
from
|
| 4 |
-
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
| 5 |
-
from langgraph.graph import StateGraph, END
|
| 6 |
-
from langgraph.checkpoint.memory import MemorySaver
|
| 7 |
|
| 8 |
-
from src.agent.llm_factory import LLMFactory
|
| 9 |
-
from src.agent.prompts import get_system_prompt
|
| 10 |
-
from src.agent.state import ConversationState
|
| 11 |
-
from src.core.logger import logger
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
def
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
if not transcript:
|
| 18 |
-
logger.debug("No transcript to process")
|
| 19 |
-
return state
|
| 20 |
-
|
| 21 |
-
logger.info(f"Processing user input: {transcript}")
|
| 22 |
-
|
| 23 |
-
messages = state.get("messages", [])
|
| 24 |
-
|
| 25 |
-
if not messages:
|
| 26 |
-
messages.append(SystemMessage(content=get_system_prompt()))
|
| 27 |
-
|
| 28 |
-
messages.append(HumanMessage(content=transcript))
|
| 29 |
-
|
| 30 |
-
return {
|
| 31 |
-
**state,
|
| 32 |
-
"messages": messages,
|
| 33 |
-
"current_transcript": "",
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def generate_response(state: ConversationState, llm: Optional[BaseLanguageModel] = None) -> ConversationState:
|
| 38 |
-
if llm is None:
|
| 39 |
-
llm = LLMFactory.create_llm()
|
| 40 |
-
|
| 41 |
-
messages = state.get("messages", [])
|
| 42 |
-
|
| 43 |
-
if not messages:
|
| 44 |
-
logger.warning("No messages to generate response from")
|
| 45 |
-
return state
|
| 46 |
-
|
| 47 |
-
logger.info("Generating AI response...")
|
| 48 |
-
|
| 49 |
-
try:
|
| 50 |
-
response = llm.invoke(messages)
|
| 51 |
-
|
| 52 |
-
if hasattr(response, "content"):
|
| 53 |
-
content = response.content
|
| 54 |
-
else:
|
| 55 |
-
content = str(response)
|
| 56 |
-
|
| 57 |
-
logger.info(f"Generated response: {content[:100]}...")
|
| 58 |
-
|
| 59 |
-
messages.append(AIMessage(content=content))
|
| 60 |
-
|
| 61 |
-
return {
|
| 62 |
-
**state,
|
| 63 |
-
"messages": messages,
|
| 64 |
-
}
|
| 65 |
-
|
| 66 |
-
except Exception as e:
|
| 67 |
-
logger.error(f"Error generating response: {e}")
|
| 68 |
-
fallback = "I'm sorry, I encountered an error. Could you please repeat that?"
|
| 69 |
-
messages.append(AIMessage(content=fallback))
|
| 70 |
-
return {
|
| 71 |
-
**state,
|
| 72 |
-
"messages": messages,
|
| 73 |
-
}
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def should_respond(state: ConversationState) -> Literal["generate", "wait"]:
|
| 77 |
-
turn_active = state.get("turn_active", False)
|
| 78 |
-
current_transcript = state.get("current_transcript", "").strip()
|
| 79 |
-
|
| 80 |
-
if turn_active:
|
| 81 |
-
logger.debug("Turn still active, waiting...")
|
| 82 |
-
return "wait"
|
| 83 |
-
|
| 84 |
-
if current_transcript:
|
| 85 |
-
logger.debug("Turn complete with transcript, generating response")
|
| 86 |
-
return "generate"
|
| 87 |
-
|
| 88 |
-
logger.debug("No action needed, waiting...")
|
| 89 |
-
return "wait"
|
| 90 |
-
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
workflow
|
| 96 |
-
|
| 97 |
-
workflow.add_node("process_input", process_user_input)
|
| 98 |
-
workflow.add_node("generate_response", generate_response)
|
| 99 |
-
|
| 100 |
-
workflow.set_entry_point("process_input")
|
| 101 |
-
|
| 102 |
-
workflow.add_conditional_edges(
|
| 103 |
-
"process_input",
|
| 104 |
-
should_respond,
|
| 105 |
-
{
|
| 106 |
-
"generate": "generate_response",
|
| 107 |
-
"wait": END,
|
| 108 |
-
},
|
| 109 |
-
)
|
| 110 |
-
|
| 111 |
-
workflow.add_edge("generate_response", END)
|
| 112 |
-
|
| 113 |
-
memory = MemorySaver()
|
| 114 |
-
graph = workflow.compile(checkpointer=memory)
|
| 115 |
-
|
| 116 |
-
logger.info("Conversation graph created successfully")
|
| 117 |
-
return graph
|
|
|
|
| 1 |
+
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
| 2 |
+
from langgraph.graph import StateGraph, MessagesState, START, END
|
| 3 |
|
| 4 |
+
from src.core.settings import settings
|
|
|
|
|
|
|
|
|
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
def create_graph():
|
| 8 |
+
"""Create a single-node LangGraph workflow using NVIDIA ChatNVIDIA."""
|
| 9 |
+
llm = ChatNVIDIA(
|
| 10 |
+
model=settings.llm.NVIDIA_MODEL,
|
| 11 |
+
api_key=settings.llm.NVIDIA_API_KEY,
|
| 12 |
+
temperature=settings.llm.LLM_TEMPERATURE,
|
| 13 |
+
max_tokens=settings.llm.LLM_MAX_TOKENS,
|
| 14 |
+
)
|
| 15 |
|
| 16 |
+
def call_model(state: MessagesState) -> dict:
|
| 17 |
+
return {"messages": [llm.invoke(state["messages"])]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
workflow = StateGraph(MessagesState)
|
| 20 |
+
workflow.add_node("agent", call_model)
|
| 21 |
+
workflow.add_edge(START, "agent")
|
| 22 |
+
workflow.add_edge("agent", END)
|
| 23 |
+
return workflow.compile()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/agent/llm_factory.py
DELETED
|
@@ -1,163 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
from typing import Any, Union
|
| 3 |
-
|
| 4 |
-
from huggingface_hub import InferenceClient
|
| 5 |
-
from transformers import pipeline
|
| 6 |
-
#from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
|
| 7 |
-
|
| 8 |
-
#from kokoro import KPipeline
|
| 9 |
-
import torch
|
| 10 |
-
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
| 11 |
-
|
| 12 |
-
from src.core.logger import logger
|
| 13 |
-
from src.core.settings import settings
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class LLMFactory:
|
| 17 |
-
@staticmethod
|
| 18 |
-
def create_nvidia_llm(
|
| 19 |
-
model: str = settings.llm.NVIDIA_MODEL,
|
| 20 |
-
temperature: float = settings.llm.LLM_TEMPERATURE,
|
| 21 |
-
max_tokens: int = settings.llm.LLM_MAX_TOKENS,
|
| 22 |
-
) -> ChatNVIDIA:
|
| 23 |
-
logger.info(f"Initializing NVIDIA LLM: {model}")
|
| 24 |
-
|
| 25 |
-
if not settings.llm.NVIDIA_API_KEY:
|
| 26 |
-
raise ValueError("NVIDIA_API_KEY must be set to use the NVIDIA LLM provider.")
|
| 27 |
-
|
| 28 |
-
return ChatNVIDIA(
|
| 29 |
-
model=model,
|
| 30 |
-
api_key=settings.llm.NVIDIA_API_KEY,
|
| 31 |
-
temperature=temperature,
|
| 32 |
-
max_completion_tokens=max_tokens,
|
| 33 |
-
)
|
| 34 |
-
|
| 35 |
-
# @staticmethod
|
| 36 |
-
# def create_huggingface_llm(
|
| 37 |
-
# model_id: str,
|
| 38 |
-
# provider: str = "auto",
|
| 39 |
-
# temperature: float = settings.llm.LLM_TEMPERATURE,
|
| 40 |
-
# max_tokens: int = settings.llm.LLM_MAX_TOKENS,
|
| 41 |
-
# run_local: bool = False,
|
| 42 |
-
# ) -> ChatHuggingFace:
|
| 43 |
-
# if run_local:
|
| 44 |
-
# logger.info(f"Initializing local HuggingFace LLM: {model_id}")
|
| 45 |
-
# llm = HuggingFacePipeline.from_model_id(
|
| 46 |
-
# model_id=model_id,
|
| 47 |
-
# task="text-generation",
|
| 48 |
-
# pipeline_kwargs={
|
| 49 |
-
# "temperature": temperature,
|
| 50 |
-
# "max_new_tokens": max_tokens,
|
| 51 |
-
# },
|
| 52 |
-
# )
|
| 53 |
-
# return ChatHuggingFace(llm=llm)
|
| 54 |
-
|
| 55 |
-
# token = (settings.llm.HF_TOKEN or "").strip()
|
| 56 |
-
# if not token:
|
| 57 |
-
# raise ValueError("HF_TOKEN must be set to use the HuggingFace LLM provider.")
|
| 58 |
-
|
| 59 |
-
# logger.info(f"Initializing HuggingFace LLM: {model_id} via provider={provider}")
|
| 60 |
-
|
| 61 |
-
# llm = HuggingFaceEndpoint(
|
| 62 |
-
# repo_id=model_id,
|
| 63 |
-
# provider=provider,
|
| 64 |
-
# huggingfacehub_api_token=token,
|
| 65 |
-
# temperature=temperature,
|
| 66 |
-
# max_new_tokens=max_tokens,
|
| 67 |
-
# )
|
| 68 |
-
# return ChatHuggingFace(llm=llm)
|
| 69 |
-
|
| 70 |
-
@staticmethod
|
| 71 |
-
def create_huggingface_stt(
|
| 72 |
-
model_id: str | None = None, run_local: bool = False
|
| 73 |
-
) -> Union[InferenceClient, Any]:
|
| 74 |
-
if run_local:
|
| 75 |
-
logger.info(f"Initializing local HuggingFace STT: {model_id or 'default'}")
|
| 76 |
-
return pipeline("automatic-speech-recognition", model=model_id)
|
| 77 |
-
|
| 78 |
-
token = (settings.llm.HF_TOKEN or "").strip()
|
| 79 |
-
if not token:
|
| 80 |
-
raise ValueError("HF_TOKEN must be set to use the HuggingFace STT provider.")
|
| 81 |
-
|
| 82 |
-
logger.info(f"Initializing HuggingFace STT: {model_id or 'default'}")
|
| 83 |
-
|
| 84 |
-
return InferenceClient(model=model_id, token=token)
|
| 85 |
-
|
| 86 |
-
@staticmethod
|
| 87 |
-
def create_huggingface_tts(
|
| 88 |
-
model_id: str | None = None, run_local: bool = False
|
| 89 |
-
) -> Union[InferenceClient, Any]:
|
| 90 |
-
if run_local:
|
| 91 |
-
logger.info(f"Initializing local HuggingFace TTS: {model_id or 'default'}")
|
| 92 |
-
return pipeline("text-to-speech", model=model_id)
|
| 93 |
-
|
| 94 |
-
token = (settings.llm.HF_TOKEN or "").strip()
|
| 95 |
-
if not token:
|
| 96 |
-
raise ValueError("HF_TOKEN must be set to use the HuggingFace TTS provider.")
|
| 97 |
-
|
| 98 |
-
logger.info(f"Initializing HuggingFace TTS: {model_id or 'default'}")
|
| 99 |
-
|
| 100 |
-
return InferenceClient(model=model_id, token=token)
|
| 101 |
-
|
| 102 |
-
@staticmethod
|
| 103 |
-
def create_kokoro_tts(lang_code: str = "a") -> Any:
|
| 104 |
-
if KPipeline is None:
|
| 105 |
-
raise ImportError(
|
| 106 |
-
"kokoro library not found. Please install it (pip install kokoro>=0.9.4) to use Kokoro TTS."
|
| 107 |
-
)
|
| 108 |
-
|
| 109 |
-
logger.info(f"Initializing Kokoro TTS Pipeline with lang_code: {lang_code}")
|
| 110 |
-
return KPipeline(lang_code=lang_code, repo_id="hexgrad/Kokoro-82M")
|
| 111 |
-
|
| 112 |
-
@staticmethod
|
| 113 |
-
def create_moonshine_stt(
|
| 114 |
-
model_size: str = "base",
|
| 115 |
-
language: str = "en",
|
| 116 |
-
) -> "MoonshineSTT":
|
| 117 |
-
"""Initialize Moonshine ONNX STT plugin.
|
| 118 |
-
|
| 119 |
-
Args:
|
| 120 |
-
model_size: "tiny" (26MB) or "base" (57MB), or language variants (e.g., "base-es", "tiny-ar")
|
| 121 |
-
language: Currently only "en" supported
|
| 122 |
-
|
| 123 |
-
Returns:
|
| 124 |
-
MoonshineSTT plugin instance
|
| 125 |
-
"""
|
| 126 |
-
logger.info(f"Initializing Moonshine ONNX STT: {model_size}")
|
| 127 |
-
from src.plugins.moonshine_stt import MoonshineSTT
|
| 128 |
-
return MoonshineSTT(model_size=model_size, language=language)
|
| 129 |
-
|
| 130 |
-
@staticmethod
|
| 131 |
-
def create_pocket_tts(
|
| 132 |
-
voice: str | None = None,
|
| 133 |
-
temperature: float | None = None,
|
| 134 |
-
lsd_decode_steps: int | None = None,
|
| 135 |
-
) -> "PocketTTS":
|
| 136 |
-
"""Initialize Pocket TTS plugin.
|
| 137 |
-
|
| 138 |
-
Args:
|
| 139 |
-
voice: Voice name (alba, marius, etc.) or path to audio file.
|
| 140 |
-
If None, uses settings.voice.POCKET_TTS_VOICE
|
| 141 |
-
temperature: Sampling temperature (0.0-2.0).
|
| 142 |
-
If None, uses settings.voice.POCKET_TTS_TEMPERATURE
|
| 143 |
-
lsd_decode_steps: LSD decoding steps for quality.
|
| 144 |
-
If None, uses settings.voice.POCKET_TTS_LSD_DECODE_STEPS
|
| 145 |
-
|
| 146 |
-
Returns:
|
| 147 |
-
PocketTTS plugin instance
|
| 148 |
-
"""
|
| 149 |
-
from src.plugins.pocket_tts import PocketTTS
|
| 150 |
-
|
| 151 |
-
if voice is None:
|
| 152 |
-
voice = settings.voice.POCKET_TTS_VOICE
|
| 153 |
-
if temperature is None:
|
| 154 |
-
temperature = settings.voice.POCKET_TTS_TEMPERATURE
|
| 155 |
-
if lsd_decode_steps is None:
|
| 156 |
-
lsd_decode_steps = settings.voice.POCKET_TTS_LSD_DECODE_STEPS
|
| 157 |
-
|
| 158 |
-
logger.info(f"Initializing Pocket TTS: voice={voice}, temp={temperature}, lsd_steps={lsd_decode_steps}")
|
| 159 |
-
return PocketTTS(
|
| 160 |
-
voice=voice,
|
| 161 |
-
temperature=temperature,
|
| 162 |
-
lsd_decode_steps=lsd_decode_steps,
|
| 163 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/agent/prompts.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
from typing import Any, Optional
|
| 2 |
-
from enum import Enum
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class PromptVersion(str, Enum):
|
| 6 |
-
V1 = "v1"
|
| 7 |
-
DEFAULT = "v1"
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class PromptTemplate:
|
| 11 |
-
def __init__(self, template: str, version: PromptVersion = PromptVersion.DEFAULT):
|
| 12 |
-
self.template = template
|
| 13 |
-
self.version = version
|
| 14 |
-
|
| 15 |
-
def render(self, **kwargs: Any) -> str:
|
| 16 |
-
return self.template.format(**kwargs)
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
SYSTEM_PROMPT_V1 = """You are a helpful AI voice assistant. You engage in natural, conversational dialogue with users.
|
| 20 |
-
|
| 21 |
-
Guidelines:
|
| 22 |
-
- Keep responses concise and natural for voice interaction
|
| 23 |
-
- Be friendly and engaging
|
| 24 |
-
- Ask clarifying questions when needed
|
| 25 |
-
- Acknowledge what the user says before responding
|
| 26 |
-
- Keep your responses focused and to the point (2-3 sentences typically)
|
| 27 |
-
"""
|
| 28 |
-
|
| 29 |
-
SYSTEM_PROMPTS = {
|
| 30 |
-
PromptVersion.V1: PromptTemplate(SYSTEM_PROMPT_V1, PromptVersion.V1),
|
| 31 |
-
}
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def get_system_prompt(version: Optional[PromptVersion] = None) -> str:
|
| 35 |
-
version = version or PromptVersion.DEFAULT
|
| 36 |
-
return SYSTEM_PROMPTS[version].render()
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def get_custom_prompt(template: str, **context: Any) -> str:
|
| 40 |
-
prompt = PromptTemplate(template)
|
| 41 |
-
return prompt.render(**context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/agent/state.py
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
from typing import Any, TypedDict
|
| 2 |
-
|
| 3 |
-
from langchain_core.messages import BaseMessage
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class ConversationState(TypedDict):
|
| 7 |
-
messages: list[BaseMessage]
|
| 8 |
-
current_transcript: str
|
| 9 |
-
context: dict[str, Any]
|
| 10 |
-
turn_active: bool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/api/__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
"""FastAPI application and WebSocket handlers."""
|
| 2 |
-
|
| 3 |
-
from src.api.main import app
|
| 4 |
-
|
| 5 |
-
__all__ = ["app"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/api/main.py
DELETED
|
@@ -1,69 +0,0 @@
|
|
| 1 |
-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 2 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
-
|
| 4 |
-
from src.api.websocket import VoiceWebSocketHandler
|
| 5 |
-
from src.core.logger import logger
|
| 6 |
-
from src.core.settings import settings
|
| 7 |
-
|
| 8 |
-
app = FastAPI(
|
| 9 |
-
title="Open Voice Agent API",
|
| 10 |
-
description="Real-time voice conversation agent with WebSocket support",
|
| 11 |
-
version="0.1.0",
|
| 12 |
-
)
|
| 13 |
-
|
| 14 |
-
app.add_middleware(
|
| 15 |
-
CORSMiddleware,
|
| 16 |
-
allow_origins=settings.api.API_CORS_ORIGINS,
|
| 17 |
-
allow_credentials=True,
|
| 18 |
-
allow_methods=["*"],
|
| 19 |
-
allow_headers=["*"],
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
@app.get("/health")
|
| 24 |
-
async def health_check():
|
| 25 |
-
return {
|
| 26 |
-
"status": "healthy",
|
| 27 |
-
"service": "open-voice-agent",
|
| 28 |
-
"version": "0.1.0",
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
@app.websocket("/ws/voice")
|
| 33 |
-
async def websocket_voice_endpoint(websocket: WebSocket):
|
| 34 |
-
handler = VoiceWebSocketHandler(websocket)
|
| 35 |
-
|
| 36 |
-
try:
|
| 37 |
-
await handler.connect()
|
| 38 |
-
await handler.handle_conversation()
|
| 39 |
-
except WebSocketDisconnect:
|
| 40 |
-
logger.info("Client disconnected")
|
| 41 |
-
except Exception as e:
|
| 42 |
-
logger.error(f"WebSocket error: {e}", exc_info=True)
|
| 43 |
-
await handler.send_error(str(e))
|
| 44 |
-
finally:
|
| 45 |
-
await handler.disconnect()
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
@app.on_event("startup")
|
| 49 |
-
async def startup_event():
|
| 50 |
-
logger.info("Starting Open Voice Agent API...")
|
| 51 |
-
logger.info(f"Voice provider: {settings.voice.VOICE_PROVIDER}")
|
| 52 |
-
logger.info(f"LLM provider: {settings.llm.LLM_PROVIDER}")
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
@app.on_event("shutdown")
|
| 56 |
-
async def shutdown_event():
|
| 57 |
-
logger.info("Shutting down Open Voice Agent API...")
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
if __name__ == "__main__":
|
| 61 |
-
import uvicorn
|
| 62 |
-
|
| 63 |
-
uvicorn.run(
|
| 64 |
-
"src.api.main:app",
|
| 65 |
-
host=settings.api.API_HOST,
|
| 66 |
-
port=settings.api.API_PORT,
|
| 67 |
-
workers=settings.api.API_WORKERS,
|
| 68 |
-
reload=True,
|
| 69 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/core/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.core.settings import settings
|
| 2 |
+
from src.core.logger import logger
|
| 3 |
+
|
| 4 |
+
__all__ = ["settings", "logger"]
|
src/core/settings.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import json
|
| 2 |
from pathlib import Path
|
| 3 |
-
from typing import
|
| 4 |
|
| 5 |
from pydantic import Field, ValidationError
|
| 6 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
@@ -15,10 +15,10 @@ load_dotenv(ENV_FILE, override=True)
|
|
| 15 |
logger.info(f"Loaded environment from: {ENV_FILE}")
|
| 16 |
|
| 17 |
|
| 18 |
-
def mask_sensitive_data(data: dict
|
| 19 |
masked = {}
|
| 20 |
sensitive_keys = ["key", "token", "secret", "password"]
|
| 21 |
-
|
| 22 |
for key, value in data.items():
|
| 23 |
if isinstance(value, dict):
|
| 24 |
masked[key] = mask_sensitive_data(value)
|
|
@@ -31,7 +31,7 @@ def mask_sensitive_data(data: dict[str, Any]) -> dict[str, Any]:
|
|
| 31 |
masked[key] = f"{value[:4]}...{value[-4:]}"
|
| 32 |
else:
|
| 33 |
masked[key] = value
|
| 34 |
-
|
| 35 |
return masked
|
| 36 |
|
| 37 |
|
|
@@ -46,80 +46,48 @@ class CoreSettings(BaseSettings):
|
|
| 46 |
|
| 47 |
|
| 48 |
class VoiceSettings(CoreSettings):
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
NVIDIA_VOICE_NAME: str = Field(default="Magpie-Multilingual.EN-US.Aria")
|
| 53 |
-
NVIDIA_TTS_MODEL: str = Field(default="magpie-tts-multilingual")
|
| 54 |
-
NVIDIA_TTS_ENDPOINT: str = Field(default="")
|
| 55 |
-
|
| 56 |
-
SAMPLE_RATE_OUTPUT: int = Field(default=48000, gt=0)
|
| 57 |
-
CHUNK_DURATION_MS: int = Field(default=80, gt=0)
|
| 58 |
-
|
| 59 |
-
VAD_THRESHOLD: float = Field(default=0.5, ge=0.0, le=1.0)
|
| 60 |
-
VAD_HORIZON_INDEX: int = Field(default=2, ge=0)
|
| 61 |
-
|
| 62 |
-
# STT (Speech-to-Text) Settings
|
| 63 |
-
STT_PROVIDER: str = Field(
|
| 64 |
-
default="moonshine",
|
| 65 |
-
description="STT provider (moonshine, assemblyai, etc)"
|
| 66 |
-
)
|
| 67 |
-
MOONSHINE_MODEL_SIZE: str = Field(
|
| 68 |
-
default="small",
|
| 69 |
-
description="Moonshine model size: tiny, base, or small"
|
| 70 |
)
|
| 71 |
-
|
| 72 |
-
# TTS (Text-to-Speech) Settings - Pocket TTS
|
| 73 |
POCKET_TTS_VOICE: str = Field(
|
| 74 |
default="alba",
|
| 75 |
-
description="Default voice (alba, marius, javert, jean, fantine, cosette, eponine, azelma) or path to audio file"
|
| 76 |
)
|
|
|
|
| 77 |
POCKET_TTS_TEMPERATURE: float = Field(
|
| 78 |
default=0.7,
|
| 79 |
ge=0.0,
|
| 80 |
le=2.0,
|
| 81 |
-
description="Sampling temperature for generation"
|
| 82 |
)
|
| 83 |
POCKET_TTS_LSD_DECODE_STEPS: int = Field(
|
| 84 |
default=1,
|
| 85 |
ge=1,
|
| 86 |
-
description="LSD decoding steps (higher = better quality, slower)"
|
| 87 |
)
|
| 88 |
|
| 89 |
|
| 90 |
class LLMSettings(CoreSettings):
|
| 91 |
NVIDIA_API_KEY: Optional[str] = Field(default=None)
|
| 92 |
NVIDIA_MODEL: str = Field(default="meta/llama-3.1-8b-instruct")
|
| 93 |
-
NVIDIA_BASE_URL: str = Field(default="https://integrate.api.nvidia.com/v1")
|
| 94 |
-
|
| 95 |
-
HF_TOKEN: Optional[str] = Field(default=None)
|
| 96 |
|
| 97 |
LLM_TEMPERATURE: float = Field(default=0.7, ge=0.0, le=2.0)
|
| 98 |
LLM_MAX_TOKENS: int = Field(default=1024, gt=0)
|
| 99 |
|
| 100 |
|
| 101 |
-
class APISettings(CoreSettings):
|
| 102 |
-
API_HOST: str = Field(default="0.0.0.0")
|
| 103 |
-
API_PORT: int = Field(default=8000, gt=0, lt=65536)
|
| 104 |
-
API_WORKERS: int = Field(default=1, gt=0)
|
| 105 |
-
API_CORS_ORIGINS: list[str] = Field(
|
| 106 |
-
default=["http://localhost:8501", "http://localhost:3000"]
|
| 107 |
-
)
|
| 108 |
-
|
| 109 |
-
|
| 110 |
class Settings(CoreSettings):
|
| 111 |
voice: VoiceSettings = Field(default_factory=VoiceSettings)
|
| 112 |
llm: LLMSettings = Field(default_factory=LLMSettings)
|
| 113 |
-
api: APISettings = Field(default_factory=APISettings)
|
| 114 |
|
| 115 |
|
| 116 |
try:
|
| 117 |
settings = Settings()
|
| 118 |
-
|
| 119 |
settings_dict = settings.model_dump()
|
| 120 |
masked_settings = mask_sensitive_data(settings_dict)
|
| 121 |
logger.info(f"Settings loaded: {json.dumps(masked_settings, indent=2)}")
|
| 122 |
-
|
| 123 |
except ValidationError as e:
|
| 124 |
logger.exception(f"Error validating settings: {e.json()}")
|
| 125 |
raise
|
|
|
|
| 1 |
import json
|
| 2 |
from pathlib import Path
|
| 3 |
+
from typing import Optional
|
| 4 |
|
| 5 |
from pydantic import Field, ValidationError
|
| 6 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
| 15 |
logger.info(f"Loaded environment from: {ENV_FILE}")
|
| 16 |
|
| 17 |
|
| 18 |
+
def mask_sensitive_data(data: dict) -> dict:
|
| 19 |
masked = {}
|
| 20 |
sensitive_keys = ["key", "token", "secret", "password"]
|
| 21 |
+
|
| 22 |
for key, value in data.items():
|
| 23 |
if isinstance(value, dict):
|
| 24 |
masked[key] = mask_sensitive_data(value)
|
|
|
|
| 31 |
masked[key] = f"{value[:4]}...{value[-4:]}"
|
| 32 |
else:
|
| 33 |
masked[key] = value
|
| 34 |
+
|
| 35 |
return masked
|
| 36 |
|
| 37 |
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
class VoiceSettings(CoreSettings):
|
| 49 |
+
MOONSHINE_MODEL_ID: str = Field(
|
| 50 |
+
default="usefulsensors/moonshine-streaming-medium",
|
| 51 |
+
description="Moonshine model size: tiny, base, or small",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
)
|
|
|
|
|
|
|
| 53 |
POCKET_TTS_VOICE: str = Field(
|
| 54 |
default="alba",
|
| 55 |
+
description="Default voice (alba, marius, javert, jean, fantine, cosette, eponine, azelma) or path to audio file",
|
| 56 |
)
|
| 57 |
+
SAMPLE_RATE_OUTPUT: int = Field(default=48000, gt=0)
|
| 58 |
POCKET_TTS_TEMPERATURE: float = Field(
|
| 59 |
default=0.7,
|
| 60 |
ge=0.0,
|
| 61 |
le=2.0,
|
| 62 |
+
description="Sampling temperature for generation",
|
| 63 |
)
|
| 64 |
POCKET_TTS_LSD_DECODE_STEPS: int = Field(
|
| 65 |
default=1,
|
| 66 |
ge=1,
|
| 67 |
+
description="LSD decoding steps (higher = better quality, slower)",
|
| 68 |
)
|
| 69 |
|
| 70 |
|
| 71 |
class LLMSettings(CoreSettings):
|
| 72 |
NVIDIA_API_KEY: Optional[str] = Field(default=None)
|
| 73 |
NVIDIA_MODEL: str = Field(default="meta/llama-3.1-8b-instruct")
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
LLM_TEMPERATURE: float = Field(default=0.7, ge=0.0, le=2.0)
|
| 76 |
LLM_MAX_TOKENS: int = Field(default=1024, gt=0)
|
| 77 |
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
class Settings(CoreSettings):
|
| 80 |
voice: VoiceSettings = Field(default_factory=VoiceSettings)
|
| 81 |
llm: LLMSettings = Field(default_factory=LLMSettings)
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
try:
|
| 85 |
settings = Settings()
|
| 86 |
+
|
| 87 |
settings_dict = settings.model_dump()
|
| 88 |
masked_settings = mask_sensitive_data(settings_dict)
|
| 89 |
logger.info(f"Settings loaded: {json.dumps(masked_settings, indent=2)}")
|
| 90 |
+
|
| 91 |
except ValidationError as e:
|
| 92 |
logger.exception(f"Error validating settings: {e.json()}")
|
| 93 |
raise
|
src/models/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
"""Models package for voice agents and data structures."""
|
|
|
|
|
|
src/models/voice/__init__.py
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
"""Voice provider interfaces and implementations."""
|
| 2 |
-
|
| 3 |
-
from src.models.voice.base import BaseVoiceProvider, VoiceProviderConfig
|
| 4 |
-
from src.models.voice.types import (
|
| 5 |
-
AudioFormat,
|
| 6 |
-
VADInfo,
|
| 7 |
-
VoiceMessage,
|
| 8 |
-
TranscriptionResult,
|
| 9 |
-
)
|
| 10 |
-
|
| 11 |
-
__all__ = [
|
| 12 |
-
"BaseVoiceProvider",
|
| 13 |
-
"VoiceProviderConfig",
|
| 14 |
-
"AudioFormat",
|
| 15 |
-
"VADInfo",
|
| 16 |
-
"VoiceMessage",
|
| 17 |
-
"TranscriptionResult",
|
| 18 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/voice/base.py
DELETED
|
@@ -1,53 +0,0 @@
|
|
| 1 |
-
from abc import ABC, abstractmethod
|
| 2 |
-
from typing import AsyncIterator, Optional
|
| 3 |
-
|
| 4 |
-
from pydantic import BaseModel
|
| 5 |
-
|
| 6 |
-
from src.models.voice.types import TranscriptionResult, VADInfo
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class VoiceProviderConfig(BaseModel):
|
| 10 |
-
provider_name: str
|
| 11 |
-
sample_rate_input: int = 24000
|
| 12 |
-
sample_rate_output: int = 48000
|
| 13 |
-
chunk_duration_ms: int = 80
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class BaseVoiceProvider(ABC):
|
| 17 |
-
def __init__(self, config: VoiceProviderConfig):
|
| 18 |
-
self.config = config
|
| 19 |
-
self._connected = False
|
| 20 |
-
|
| 21 |
-
@abstractmethod
|
| 22 |
-
async def connect(self) -> None:
|
| 23 |
-
pass
|
| 24 |
-
|
| 25 |
-
@abstractmethod
|
| 26 |
-
async def disconnect(self) -> None:
|
| 27 |
-
pass
|
| 28 |
-
|
| 29 |
-
@abstractmethod
|
| 30 |
-
async def text_to_speech(
|
| 31 |
-
self, text: str, stream: bool = True
|
| 32 |
-
) -> AsyncIterator[bytes]:
|
| 33 |
-
pass
|
| 34 |
-
|
| 35 |
-
async def speech_to_text(
|
| 36 |
-
self, audio_stream: AsyncIterator[bytes]
|
| 37 |
-
) -> AsyncIterator[TranscriptionResult]:
|
| 38 |
-
raise NotImplementedError("Speech-to-text not supported by this provider")
|
| 39 |
-
|
| 40 |
-
@abstractmethod
|
| 41 |
-
async def get_vad_info(self) -> Optional[VADInfo]:
|
| 42 |
-
pass
|
| 43 |
-
|
| 44 |
-
@property
|
| 45 |
-
def is_connected(self) -> bool:
|
| 46 |
-
return self._connected
|
| 47 |
-
|
| 48 |
-
async def __aenter__(self):
|
| 49 |
-
await self.connect()
|
| 50 |
-
return self
|
| 51 |
-
|
| 52 |
-
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
| 53 |
-
await self.disconnect()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/voice/types.py
DELETED
|
@@ -1,43 +0,0 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
from enum import Enum
|
| 3 |
-
from typing import Optional
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class AudioFormat(str, Enum):
|
| 7 |
-
PCM = "pcm"
|
| 8 |
-
WAV = "wav"
|
| 9 |
-
OPUS = "opus"
|
| 10 |
-
ULAW_8000 = "ulaw_8000"
|
| 11 |
-
ALAW_8000 = "alaw_8000"
|
| 12 |
-
PCM_16000 = "pcm_16000"
|
| 13 |
-
PCM_24000 = "pcm_24000"
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
@dataclass
|
| 17 |
-
class VoiceMessage:
|
| 18 |
-
type: str
|
| 19 |
-
content: str | bytes
|
| 20 |
-
timestamp: Optional[float] = None
|
| 21 |
-
metadata: Optional[dict] = None
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
@dataclass
|
| 25 |
-
class VADInfo:
|
| 26 |
-
inactivity_prob: float
|
| 27 |
-
horizon_s: float
|
| 28 |
-
step_idx: int
|
| 29 |
-
total_duration_s: float
|
| 30 |
-
|
| 31 |
-
@property
|
| 32 |
-
def is_turn_complete(self, threshold: float = 0.5) -> bool:
|
| 33 |
-
return self.inactivity_prob > threshold
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
@dataclass
|
| 37 |
-
class TranscriptionResult:
|
| 38 |
-
text: str
|
| 39 |
-
start_s: float
|
| 40 |
-
stop_s: Optional[float] = None
|
| 41 |
-
is_final: bool = False
|
| 42 |
-
confidence: Optional[float] = None
|
| 43 |
-
stream_id: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/plugins/pocket_tts/tts.py
CHANGED
|
@@ -199,7 +199,6 @@ class PocketSynthesizeStream(tts.SynthesizeStream):
|
|
| 199 |
):
|
| 200 |
audio_bytes = self._tensor_to_pcm_bytes(audio_chunk)
|
| 201 |
chunks.append(audio_bytes)
|
| 202 |
-
logger.debug(f"Generated chunk: {len(audio_bytes)} bytes")
|
| 203 |
logger.info(f"Total chunks generated: {len(chunks)}")
|
| 204 |
return chunks
|
| 205 |
|
|
@@ -210,10 +209,6 @@ class PocketSynthesizeStream(tts.SynthesizeStream):
|
|
| 210 |
|
| 211 |
# Push raw PCM bytes to the emitter
|
| 212 |
for i, chunk in enumerate(audio_chunks):
|
| 213 |
-
num_samples = len(chunk) // 2 # int16 = 2 bytes per sample
|
| 214 |
-
logger.debug(
|
| 215 |
-
f"Pushing chunk {i+1}/{len(audio_chunks)}: {len(chunk)} bytes ({num_samples} samples @ {self._tts._output_sample_rate}Hz)"
|
| 216 |
-
)
|
| 217 |
output_emitter.push(chunk)
|
| 218 |
|
| 219 |
logger.info(f"Successfully pushed all {len(audio_chunks)} chunks")
|
|
|
|
| 199 |
):
|
| 200 |
audio_bytes = self._tensor_to_pcm_bytes(audio_chunk)
|
| 201 |
chunks.append(audio_bytes)
|
|
|
|
| 202 |
logger.info(f"Total chunks generated: {len(chunks)}")
|
| 203 |
return chunks
|
| 204 |
|
|
|
|
| 209 |
|
| 210 |
# Push raw PCM bytes to the emitter
|
| 211 |
for i, chunk in enumerate(audio_chunks):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
output_emitter.push(chunk)
|
| 213 |
|
| 214 |
logger.info(f"Successfully pushed all {len(audio_chunks)} chunks")
|
src/streamlit_app.py
DELETED
|
@@ -1,288 +0,0 @@
|
|
| 1 |
-
"""Streamlit UI for voice agent with audio I/O."""
|
| 2 |
-
|
| 3 |
-
import asyncio
|
| 4 |
-
import base64
|
| 5 |
-
import json
|
| 6 |
-
from threading import Thread
|
| 7 |
-
from queue import Queue
|
| 8 |
-
from typing import Optional
|
| 9 |
-
|
| 10 |
-
import streamlit as st
|
| 11 |
-
import websockets
|
| 12 |
-
|
| 13 |
-
from src.core.logger import logger
|
| 14 |
-
|
| 15 |
-
# Page configuration
|
| 16 |
-
st.set_page_config(
|
| 17 |
-
page_title="Open Voice Agent",
|
| 18 |
-
page_icon="🎙️",
|
| 19 |
-
layout="wide",
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
# Initialize session state
|
| 23 |
-
if "messages" not in st.session_state:
|
| 24 |
-
st.session_state.messages = []
|
| 25 |
-
if "ws_connected" not in st.session_state:
|
| 26 |
-
st.session_state.ws_connected = False
|
| 27 |
-
if "current_transcript" not in st.session_state:
|
| 28 |
-
st.session_state.current_transcript = ""
|
| 29 |
-
if "processing" not in st.session_state:
|
| 30 |
-
st.session_state.processing = False
|
| 31 |
-
if "response_queue" not in st.session_state:
|
| 32 |
-
st.session_state.response_queue = Queue()
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def send_audio_to_websocket(ws_url: str, audio_data: str, response_queue: Queue):
|
| 36 |
-
"""Send audio to WebSocket and receive responses in background thread.
|
| 37 |
-
|
| 38 |
-
Args:
|
| 39 |
-
ws_url: WebSocket URL
|
| 40 |
-
audio_data: Base64 encoded audio data
|
| 41 |
-
response_queue: Queue to put responses
|
| 42 |
-
"""
|
| 43 |
-
async def communicate():
|
| 44 |
-
try:
|
| 45 |
-
async with websockets.connect(ws_url) as websocket:
|
| 46 |
-
# Send audio message
|
| 47 |
-
await websocket.send(json.dumps({
|
| 48 |
-
"type": "audio",
|
| 49 |
-
"data": audio_data
|
| 50 |
-
}))
|
| 51 |
-
logger.info("Audio sent to WebSocket")
|
| 52 |
-
|
| 53 |
-
# Send end turn signal
|
| 54 |
-
await websocket.send(json.dumps({"type": "end_turn"}))
|
| 55 |
-
logger.info("End turn signal sent")
|
| 56 |
-
|
| 57 |
-
# Receive responses
|
| 58 |
-
while True:
|
| 59 |
-
try:
|
| 60 |
-
message = await asyncio.wait_for(websocket.recv(), timeout=30.0)
|
| 61 |
-
response = json.loads(message)
|
| 62 |
-
response_queue.put(response)
|
| 63 |
-
|
| 64 |
-
# Stop if response is complete
|
| 65 |
-
if response.get("type") == "response_complete":
|
| 66 |
-
logger.info("Response complete received")
|
| 67 |
-
break
|
| 68 |
-
|
| 69 |
-
except asyncio.TimeoutError:
|
| 70 |
-
logger.warning("WebSocket receive timeout")
|
| 71 |
-
break
|
| 72 |
-
except Exception as e:
|
| 73 |
-
logger.error(f"Error receiving message: {e}")
|
| 74 |
-
response_queue.put({"type": "error", "message": str(e)})
|
| 75 |
-
break
|
| 76 |
-
|
| 77 |
-
except Exception as e:
|
| 78 |
-
logger.error(f"WebSocket error: {e}")
|
| 79 |
-
response_queue.put({"type": "error", "message": str(e)})
|
| 80 |
-
|
| 81 |
-
# Run async function
|
| 82 |
-
asyncio.run(communicate())
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
# Title and description
|
| 86 |
-
st.title("🎙️ Open Voice Agent")
|
| 87 |
-
st.markdown(
|
| 88 |
-
"""
|
| 89 |
-
Voice conversation with AI using NVIDIA API for speech synthesis
|
| 90 |
-
and LangGraph for conversation management.
|
| 91 |
-
"""
|
| 92 |
-
)
|
| 93 |
-
|
| 94 |
-
# Sidebar configuration
|
| 95 |
-
with st.sidebar:
|
| 96 |
-
st.header("⚙️ Settings")
|
| 97 |
-
|
| 98 |
-
# WebSocket connection settings
|
| 99 |
-
ws_url = st.text_input(
|
| 100 |
-
"WebSocket URL",
|
| 101 |
-
value="ws://localhost:8000/ws/voice",
|
| 102 |
-
help="URL of the FastAPI WebSocket server",
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
# Connection status
|
| 106 |
-
status_color = "🟢" if st.session_state.ws_connected else "🔴"
|
| 107 |
-
st.markdown(f"**Status:** {status_color} {'Connected' if st.session_state.ws_connected else 'Disconnected'}")
|
| 108 |
-
st.markdown("**Voice Provider:** NVIDIA API (configured on the server)")
|
| 109 |
-
|
| 110 |
-
# Test connection button
|
| 111 |
-
if st.button("🔍 Test Connection"):
|
| 112 |
-
try:
|
| 113 |
-
import requests
|
| 114 |
-
response = requests.get("http://localhost:8000/health", timeout=2)
|
| 115 |
-
if response.status_code == 200:
|
| 116 |
-
st.success("✅ Server is running!")
|
| 117 |
-
st.session_state.ws_connected = True
|
| 118 |
-
else:
|
| 119 |
-
st.error("❌ Server not responding correctly")
|
| 120 |
-
st.session_state.ws_connected = False
|
| 121 |
-
except Exception as e:
|
| 122 |
-
st.error(f"❌ Cannot connect to server: {e}")
|
| 123 |
-
st.session_state.ws_connected = False
|
| 124 |
-
|
| 125 |
-
st.divider()
|
| 126 |
-
|
| 127 |
-
if st.button("🗑️ Clear Conversation"):
|
| 128 |
-
st.session_state.messages = []
|
| 129 |
-
st.session_state.current_transcript = ""
|
| 130 |
-
st.rerun()
|
| 131 |
-
|
| 132 |
-
# Download conversation
|
| 133 |
-
if st.session_state.messages:
|
| 134 |
-
transcript = "\n\n".join(
|
| 135 |
-
[f"{msg['role'].upper()}: {msg['content']}" for msg in st.session_state.messages]
|
| 136 |
-
)
|
| 137 |
-
st.download_button(
|
| 138 |
-
label="📥 Download Transcript",
|
| 139 |
-
data=transcript,
|
| 140 |
-
file_name="conversation_transcript.txt",
|
| 141 |
-
mime="text/plain",
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
-
# Main content area
|
| 145 |
-
col1, col2 = st.columns([2, 1])
|
| 146 |
-
|
| 147 |
-
with col1:
|
| 148 |
-
st.subheader("💬 Conversation")
|
| 149 |
-
|
| 150 |
-
# Chat container
|
| 151 |
-
chat_container = st.container(height=500)
|
| 152 |
-
|
| 153 |
-
with chat_container:
|
| 154 |
-
# Display messages
|
| 155 |
-
for message in st.session_state.messages:
|
| 156 |
-
with st.chat_message(message["role"]):
|
| 157 |
-
st.write(message["content"])
|
| 158 |
-
|
| 159 |
-
# Display current transcript being captured
|
| 160 |
-
if st.session_state.current_transcript:
|
| 161 |
-
with st.chat_message("user"):
|
| 162 |
-
st.write(f"*{st.session_state.current_transcript}...*")
|
| 163 |
-
|
| 164 |
-
with col2:
|
| 165 |
-
st.subheader("🎤 Audio Controls")
|
| 166 |
-
|
| 167 |
-
# Check if server is running
|
| 168 |
-
if not st.session_state.ws_connected:
|
| 169 |
-
st.warning("⚠️ Please test connection first")
|
| 170 |
-
|
| 171 |
-
# Audio recorder using built-in st.audio_input
|
| 172 |
-
st.markdown("**Record your message:**")
|
| 173 |
-
|
| 174 |
-
audio_data = st.audio_input(
|
| 175 |
-
"Click to start/stop recording",
|
| 176 |
-
key="audio_input",
|
| 177 |
-
help="Click to start recording, click again to stop"
|
| 178 |
-
)
|
| 179 |
-
|
| 180 |
-
if audio_data is not None:
|
| 181 |
-
st.success("✓ Audio recorded!")
|
| 182 |
-
|
| 183 |
-
# Show audio player
|
| 184 |
-
st.audio(audio_data)
|
| 185 |
-
|
| 186 |
-
# Send button
|
| 187 |
-
if st.button("📤 Send Audio", disabled=st.session_state.processing):
|
| 188 |
-
if not st.session_state.ws_connected:
|
| 189 |
-
st.error("❌ Not connected to server")
|
| 190 |
-
else:
|
| 191 |
-
st.session_state.processing = True
|
| 192 |
-
|
| 193 |
-
# Get audio bytes
|
| 194 |
-
audio_bytes = audio_data.getvalue()
|
| 195 |
-
|
| 196 |
-
# Encode to base64
|
| 197 |
-
encoded_audio = base64.b64encode(audio_bytes).decode('utf-8')
|
| 198 |
-
|
| 199 |
-
# Show processing indicator
|
| 200 |
-
with st.spinner("🔊 Processing audio..."):
|
| 201 |
-
# Start WebSocket communication in background thread
|
| 202 |
-
thread = Thread(
|
| 203 |
-
target=send_audio_to_websocket,
|
| 204 |
-
args=(ws_url, encoded_audio, st.session_state.response_queue),
|
| 205 |
-
daemon=True
|
| 206 |
-
)
|
| 207 |
-
thread.start()
|
| 208 |
-
|
| 209 |
-
# Wait for thread to complete (with timeout)
|
| 210 |
-
thread.join(timeout=30)
|
| 211 |
-
|
| 212 |
-
# Process responses from queue
|
| 213 |
-
transcript_text = ""
|
| 214 |
-
response_text = ""
|
| 215 |
-
audio_chunks = []
|
| 216 |
-
|
| 217 |
-
while not st.session_state.response_queue.empty():
|
| 218 |
-
response = st.session_state.response_queue.get()
|
| 219 |
-
msg_type = response.get("type")
|
| 220 |
-
|
| 221 |
-
if msg_type == "transcript":
|
| 222 |
-
transcript_text += " " + response.get("text", "")
|
| 223 |
-
elif msg_type == "response_text":
|
| 224 |
-
response_text = response.get("text", "")
|
| 225 |
-
elif msg_type == "audio":
|
| 226 |
-
audio_chunks.append(response.get("data", ""))
|
| 227 |
-
elif msg_type == "error":
|
| 228 |
-
st.error(f"Error: {response.get('message')}")
|
| 229 |
-
|
| 230 |
-
# Add messages to conversation
|
| 231 |
-
if transcript_text.strip():
|
| 232 |
-
st.session_state.messages.append({
|
| 233 |
-
"role": "user",
|
| 234 |
-
"content": transcript_text.strip()
|
| 235 |
-
})
|
| 236 |
-
|
| 237 |
-
if response_text:
|
| 238 |
-
st.session_state.messages.append({
|
| 239 |
-
"role": "assistant",
|
| 240 |
-
"content": response_text
|
| 241 |
-
})
|
| 242 |
-
|
| 243 |
-
st.session_state.processing = False
|
| 244 |
-
st.success("✅ Processing complete!")
|
| 245 |
-
st.rerun()
|
| 246 |
-
|
| 247 |
-
# Processing indicator
|
| 248 |
-
if st.session_state.processing:
|
| 249 |
-
st.info("⏳ Processing your message...")
|
| 250 |
-
|
| 251 |
-
# Instructions
|
| 252 |
-
with st.expander("📖 How to Use"):
|
| 253 |
-
st.markdown(
|
| 254 |
-
"""
|
| 255 |
-
1. **Test connection** first to ensure server is running
|
| 256 |
-
2. **Click** the audio input to start recording
|
| 257 |
-
3. **Speak** your message clearly
|
| 258 |
-
4. **Click again** to stop recording
|
| 259 |
-
5. **Review** the recorded audio (optional)
|
| 260 |
-
6. **Send** the audio for processing
|
| 261 |
-
7. Wait for the AI response (text will appear in chat)
|
| 262 |
-
|
| 263 |
-
**Tips:**
|
| 264 |
-
- Speak clearly and at a normal pace
|
| 265 |
-
- Wait for the response before recording again
|
| 266 |
-
- Keep messages concise for better results
|
| 267 |
-
- Use headphones to avoid echo
|
| 268 |
-
"""
|
| 269 |
-
)
|
| 270 |
-
|
| 271 |
-
# System info
|
| 272 |
-
with st.expander("ℹ️ System Info"):
|
| 273 |
-
st.markdown(f"""
|
| 274 |
-
**Voice Provider:** NVIDIA API
|
| 275 |
-
**WebSocket:** `{ws_url}`
|
| 276 |
-
**Messages:** {len(st.session_state.messages)}
|
| 277 |
-
""")
|
| 278 |
-
|
| 279 |
-
# Footer
|
| 280 |
-
st.markdown("---")
|
| 281 |
-
st.markdown(
|
| 282 |
-
"""
|
| 283 |
-
<div style='text-align: center'>
|
| 284 |
-
<small>Powered by NVIDIA API (Voice) + LangGraph (Conversations) + Streamlit (UI)</small>
|
| 285 |
-
</div>
|
| 286 |
-
""",
|
| 287 |
-
unsafe_allow_html=True,
|
| 288 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
testing/asr_moonshine.py
DELETED
|
@@ -1,48 +0,0 @@
|
|
| 1 |
-
import io
|
| 2 |
-
import math
|
| 3 |
-
import numpy as np
|
| 4 |
-
import soundfile as sf
|
| 5 |
-
from scipy.signal import resample_poly
|
| 6 |
-
import torch
|
| 7 |
-
from transformers import AutoProcessor, MoonshineStreamingForConditionalGeneration
|
| 8 |
-
|
| 9 |
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 10 |
-
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 11 |
-
|
| 12 |
-
model_id = "usefulsensors/moonshine-streaming-small"
|
| 13 |
-
|
| 14 |
-
model = MoonshineStreamingForConditionalGeneration.from_pretrained(model_id).to(
|
| 15 |
-
device, torch_dtype
|
| 16 |
-
)
|
| 17 |
-
processor = AutoProcessor.from_pretrained(model_id)
|
| 18 |
-
|
| 19 |
-
# Read audio file
|
| 20 |
-
with open("dev/kokoro_tts.wav", "rb") as f:
|
| 21 |
-
audio_bytes = f.read()
|
| 22 |
-
|
| 23 |
-
# Load audio using soundfile
|
| 24 |
-
audio_np, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")
|
| 25 |
-
|
| 26 |
-
if audio_np.ndim > 1:
|
| 27 |
-
audio_np = np.mean(audio_np, axis=1)
|
| 28 |
-
|
| 29 |
-
if sr != 16000:
|
| 30 |
-
ratio_gcd = math.gcd(sr, 16000)
|
| 31 |
-
up = 16000 // ratio_gcd
|
| 32 |
-
down = sr // ratio_gcd
|
| 33 |
-
print(f"Resampling from {sr}Hz to 16000Hz")
|
| 34 |
-
audio_np = resample_poly(audio_np, up=up, down=down)
|
| 35 |
-
|
| 36 |
-
inputs = processor(
|
| 37 |
-
audio_np,
|
| 38 |
-
return_tensors="pt",
|
| 39 |
-
sampling_rate=16000,
|
| 40 |
-
).to(device, torch_dtype)
|
| 41 |
-
|
| 42 |
-
token_limit_factor = 6.5 / 16000
|
| 43 |
-
max_length = int((inputs.attention_mask.sum() * token_limit_factor).max().item())
|
| 44 |
-
|
| 45 |
-
generated_ids = model.generate(**inputs, max_length=max_length)
|
| 46 |
-
transcription = processor.decode(generated_ids[0], skip_special_tokens=True)
|
| 47 |
-
|
| 48 |
-
print(f"Transcription: {transcription}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
testing/nvidia_.py
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
import nemo.collections.asr as nemo_asr
|
| 2 |
-
asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
|
| 3 |
-
|
| 4 |
-
transcriptions = asr_model.transcribe(["dev/kokoro_tts.wav"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
testing/pocket_tts_test.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
from pocket_tts import TTSModel
|
| 2 |
-
import scipy.io.wavfile
|
| 3 |
-
|
| 4 |
-
tts_model = TTSModel.load_model()
|
| 5 |
-
voice_state = tts_model.get_state_for_audio_prompt(
|
| 6 |
-
"alba" # One of the pre-made voices, see above
|
| 7 |
-
# You can also use any voice file you have locally or from Hugging Face:
|
| 8 |
-
# "./some_audio.wav"
|
| 9 |
-
# or "hf://kyutai/tts-voices/expresso/ex01-ex02_default_001_channel2_198s.wav"
|
| 10 |
-
)
|
| 11 |
-
audio = tts_model.generate_audio(voice_state, "Hello world, this is a test.")
|
| 12 |
-
# Audio is a 1D torch tensor containing PCM data.
|
| 13 |
-
scipy.io.wavfile.write("dev/pocket_tts.wav", tts_model.sample_rate, audio.numpy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uv.lock
CHANGED
|
@@ -2081,7 +2081,7 @@ wheels = [
|
|
| 2081 |
[[package]]
|
| 2082 |
name = "open-voice-agent"
|
| 2083 |
version = "0.1.0"
|
| 2084 |
-
source = {
|
| 2085 |
dependencies = [
|
| 2086 |
{ name = "langgraph" },
|
| 2087 |
{ name = "lhotse" },
|
|
|
|
| 2081 |
[[package]]
|
| 2082 |
name = "open-voice-agent"
|
| 2083 |
version = "0.1.0"
|
| 2084 |
+
source = { editable = "." }
|
| 2085 |
dependencies = [
|
| 2086 |
{ name = "langgraph" },
|
| 2087 |
{ name = "lhotse" },
|