Spaces:
Runtime error
Runtime error
Aakash jammula committed on
Commit ·
88edcbb
1
Parent(s): 241379e
init
Browse files
app.py
CHANGED
|
@@ -1,14 +1,15 @@
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
from kokoro import KPipeline
|
| 4 |
-
from langchain_core.messages import SystemMessage, HumanMessage
|
| 5 |
from langchain_core.tools import Tool
|
| 6 |
from langgraph.graph import MessagesState, StateGraph, START
|
| 7 |
from langgraph.prebuilt import tools_condition, ToolNode
|
| 8 |
from langgraph.checkpoint.memory import MemorySaver
|
| 9 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 10 |
from tavily import TavilyClient
|
| 11 |
-
import warnings
|
|
|
|
| 12 |
from fastapi import FastAPI, HTTPException
|
| 13 |
from fastapi.responses import StreamingResponse
|
| 14 |
from pydantic import BaseModel
|
|
@@ -33,61 +34,32 @@ assert GOOGLE_API_KEY and TAVILY_API_KEY, "Missing API keys in environment."
|
|
| 33 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 34 |
print(f"Using device: {device}")
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
tavily = TavilyClient(api_key=TAVILY_API_KEY)
|
| 41 |
-
search_tool = Tool.from_function(name="TavilySearch",
|
| 42 |
-
func=lambda q: tavily.search(q, max_results=3),
|
| 43 |
-
description="Fetch factual current information from the web. Input should be a search query.")
|
| 44 |
-
|
| 45 |
-
llm_tools = llm.bind_tools([search_tool])
|
| 46 |
-
|
| 47 |
-
print("Loading Kokoro TTS model...")
|
| 48 |
-
tts_pipeline = KPipeline(lang_code="a", device=device, repo_id="hexgrad/Kokoro-82M")
|
| 49 |
-
print("Kokoro TTS model loaded.")
|
| 50 |
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
if not msgs:
|
| 55 |
-
return state
|
| 56 |
-
sys = SystemMessage(
|
| 57 |
-
"You are Jarvis, a helpful and concise AI assistant. "
|
| 58 |
-
"Your responses should be brief, informative, and directly answer the user's query. "
|
| 59 |
-
"Aim for responses around 1 to 3 sentences. Maximum 60 tokens output. " # Increased slightly for tool summarization
|
| 60 |
-
"If you use tools, search web for any factual information needed, or any information which is time sensitive, like today's news and latest model releases. "
|
| 61 |
-
"After receiving tool results, you MUST summarize them or state you couldn't find the information based on the tool output."
|
| 62 |
-
)
|
| 63 |
-
resp = llm_tools.invoke([sys] + msgs[-5:])
|
| 64 |
-
state["messages"] = msgs + [resp]
|
| 65 |
-
return state
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
builder.add_edge(START, "assistant")
|
| 71 |
-
builder.add_conditional_edges("assistant", tools_condition)
|
| 72 |
-
builder.add_edge("tools", "assistant")
|
| 73 |
-
agent_graph = builder.compile(MemorySaver())
|
| 74 |
-
|
| 75 |
-
def query_agent(prompt: str, thread_id: str) -> str:
|
| 76 |
-
payload = {"messages": [HumanMessage(content=prompt)]}
|
| 77 |
-
cfg = {"configurable": {"thread_id": thread_id}}
|
| 78 |
-
out = agent_graph.invoke(payload, cfg)
|
| 79 |
-
return out["messages"][-1].content.strip()
|
| 80 |
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
duration_ms = 100
|
| 87 |
num_samples = int(KOKORO_SAMPLE_RATE * (duration_ms / 1000.0))
|
| 88 |
audio_data_int16 = np.zeros(num_samples, dtype=np.int16)
|
| 89 |
else:
|
| 90 |
-
print(f"Generating speech for: '{text}'")
|
| 91 |
audio_chunks = []
|
| 92 |
try:
|
| 93 |
for _, _, audio_segment in tts_pipeline(text, voice="af_heart", speed=1.3):
|
|
@@ -98,7 +70,7 @@ def generate_speech_audio(text: str) -> io.BytesIO:
|
|
| 98 |
print(f"Warning: TTS produced an empty or non-1D audio segment. Shape: {audio_segment.shape if hasattr(audio_segment, 'shape') else 'N/A'}")
|
| 99 |
else:
|
| 100 |
print(f"Warning: TTS produced None audio segment for part of text: '{text}'")
|
| 101 |
-
|
| 102 |
if not audio_chunks:
|
| 103 |
print(f"Warning: TTS produced no valid audio for text: '{text}'. Generating silence.")
|
| 104 |
duration_ms = 100
|
|
@@ -121,12 +93,129 @@ def generate_speech_audio(text: str) -> io.BytesIO:
|
|
| 121 |
wf.setsampwidth(2)
|
| 122 |
wf.setframerate(KOKORO_SAMPLE_RATE)
|
| 123 |
wf.writeframes(audio_data_int16.tobytes())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
wav_buffer.seek(0)
|
| 125 |
-
print(f"Generated WAV audio of size: {len(wav_buffer.getvalue())} bytes")
|
| 126 |
return wav_buffer
|
| 127 |
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
app.add_middleware(
|
| 132 |
CORSMiddleware,
|
|
@@ -146,7 +235,7 @@ async def process_command(payload: TextInput):
|
|
| 146 |
|
| 147 |
if not payload.text or payload.text.isspace():
|
| 148 |
print("Empty text received, responding with silent audio.")
|
| 149 |
-
silent_audio_wav = await asyncio.to_thread(
|
| 150 |
return StreamingResponse(silent_audio_wav, media_type="audio/wav")
|
| 151 |
|
| 152 |
try:
|
|
@@ -154,10 +243,10 @@ async def process_command(payload: TextInput):
|
|
| 154 |
print(f"LLM response: '{llm_response_text}'")
|
| 155 |
|
| 156 |
if not llm_response_text or llm_response_text.isspace():
|
| 157 |
-
print("LLM returned empty response, generating
|
| 158 |
llm_response_text = "I don't have a response for that."
|
| 159 |
|
| 160 |
-
audio_wav_buffer = await asyncio.to_thread(
|
| 161 |
return StreamingResponse(audio_wav_buffer, media_type="audio/wav")
|
| 162 |
|
| 163 |
except Exception as e:
|
|
@@ -165,9 +254,14 @@ async def process_command(payload: TextInput):
|
|
| 165 |
traceback.print_exc()
|
| 166 |
error_message_text = "Sorry, I encountered an error."
|
| 167 |
try:
|
| 168 |
-
error_audio_buffer = await asyncio.to_thread(
|
| 169 |
return StreamingResponse(error_audio_buffer, media_type="audio/wav", status_code=500)
|
| 170 |
except Exception as audio_err:
|
| 171 |
print(f"Critical error generating error audio: {audio_err}")
|
|
|
|
| 172 |
raise HTTPException(status_code=500, detail="Internal server error during audio generation")
|
| 173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
from kokoro import KPipeline
|
| 4 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
| 5 |
from langchain_core.tools import Tool
|
| 6 |
from langgraph.graph import MessagesState, StateGraph, START
|
| 7 |
from langgraph.prebuilt import tools_condition, ToolNode
|
| 8 |
from langgraph.checkpoint.memory import MemorySaver
|
| 9 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 10 |
from tavily import TavilyClient
|
| 11 |
+
import warnings
|
| 12 |
+
import logging
|
| 13 |
from fastapi import FastAPI, HTTPException
|
| 14 |
from fastapi.responses import StreamingResponse
|
| 15 |
from pydantic import BaseModel
|
|
|
|
| 34 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 35 |
print(f"Using device: {device}")
|
| 36 |
|
| 37 |
+
# In-memory TTS cache: maps input text -> encoded WAV bytes, so frequently
# repeated phrases (and silence) skip the synthesis round-trip entirely.
TTS_CACHE = {}
# Output sample rate (Hz) used when writing WAV headers for Kokoro audio.
KOKORO_SAMPLE_RATE = 24000
tts_pipeline: KPipeline  # To be initialized later
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
+
def generate_speech_audio_optimized(text: str) -> io.BytesIO:
|
| 42 |
+
global TTS_CACHE, KOKORO_SAMPLE_RATE, tts_pipeline
|
| 43 |
|
| 44 |
+
cache_key_to_check = text
|
| 45 |
+
is_silent_request = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
+
if not text or text.isspace():
|
| 48 |
+
cache_key_to_check = "SILENT_AUDIO_OPTIMIZED"
|
| 49 |
+
is_silent_request = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
+
if cache_key_to_check in TTS_CACHE:
|
| 52 |
+
cached_buffer = io.BytesIO(TTS_CACHE[cache_key_to_check])
|
| 53 |
+
cached_buffer.seek(0)
|
| 54 |
+
return cached_buffer
|
| 55 |
|
| 56 |
+
audio_data_int16: np.ndarray
|
| 57 |
+
|
| 58 |
+
if is_silent_request:
|
| 59 |
duration_ms = 100
|
| 60 |
num_samples = int(KOKORO_SAMPLE_RATE * (duration_ms / 1000.0))
|
| 61 |
audio_data_int16 = np.zeros(num_samples, dtype=np.int16)
|
| 62 |
else:
|
|
|
|
| 63 |
audio_chunks = []
|
| 64 |
try:
|
| 65 |
for _, _, audio_segment in tts_pipeline(text, voice="af_heart", speed=1.3):
|
|
|
|
| 70 |
print(f"Warning: TTS produced an empty or non-1D audio segment. Shape: {audio_segment.shape if hasattr(audio_segment, 'shape') else 'N/A'}")
|
| 71 |
else:
|
| 72 |
print(f"Warning: TTS produced None audio segment for part of text: '{text}'")
|
| 73 |
+
|
| 74 |
if not audio_chunks:
|
| 75 |
print(f"Warning: TTS produced no valid audio for text: '{text}'. Generating silence.")
|
| 76 |
duration_ms = 100
|
|
|
|
| 93 |
wf.setsampwidth(2)
|
| 94 |
wf.setframerate(KOKORO_SAMPLE_RATE)
|
| 95 |
wf.writeframes(audio_data_int16.tobytes())
|
| 96 |
+
|
| 97 |
+
wav_data_bytes = wav_buffer.getvalue()
|
| 98 |
+
|
| 99 |
+
texts_to_always_cache = [
|
| 100 |
+
"Sorry, I encountered an error.",
|
| 101 |
+
"I don't have a response for that."
|
| 102 |
+
]
|
| 103 |
+
current_cache_target_key = "SILENT_AUDIO_OPTIMIZED" if is_silent_request else text
|
| 104 |
+
|
| 105 |
+
if current_cache_target_key == "SILENT_AUDIO_OPTIMIZED" and current_cache_target_key not in TTS_CACHE:
|
| 106 |
+
TTS_CACHE[current_cache_target_key] = wav_data_bytes
|
| 107 |
+
elif current_cache_target_key in texts_to_always_cache and current_cache_target_key not in TTS_CACHE:
|
| 108 |
+
TTS_CACHE[current_cache_target_key] = wav_data_bytes
|
| 109 |
+
|
| 110 |
wav_buffer.seek(0)
|
|
|
|
| 111 |
return wav_buffer
|
| 112 |
|
| 113 |
+
def _silent_int16(duration_ms: int = 100) -> np.ndarray:
    """Return duration_ms of mono int16 silence at KOKORO_SAMPLE_RATE."""
    num_samples = int(KOKORO_SAMPLE_RATE * (duration_ms / 1000.0))
    return np.zeros(num_samples, dtype=np.int16)


def _int16_to_wav_bytes(audio_data_int16: np.ndarray) -> bytes:
    """Encode mono int16 PCM samples as in-memory WAV bytes (KOKORO_SAMPLE_RATE Hz)."""
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, 'wb') as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 2 bytes per sample == int16
        wf.setframerate(KOKORO_SAMPLE_RATE)
        wf.writeframes(audio_data_int16.tobytes())
    return wav_buffer.getvalue()


def pre_populate_tts_cache_on_startup():
    """Warm TTS_CACHE with canned phrases and a silent clip at startup.

    Pre-renders the fixed error/fallback responses and a short silence so
    those paths are served from cache without a TTS call at request time.

    Side effects: mutates the module-level TTS_CACHE dict and prints
    progress. Synthesis failures are logged and skipped, never raised.
    """
    global TTS_CACHE

    # `callable` is the idiomatic spelling of hasattr(x, '__call__').
    if not callable(tts_pipeline):
        print("Error: TTS pipeline not available for pre-caching.")
        return

    print("Pre-populating TTS cache...")

    common_texts_to_pre_cache = [
        "Sorry, I encountered an error.",
        "I don't have a response for that."
    ]

    for text_to_cache in common_texts_to_pre_cache:
        if text_to_cache in TTS_CACHE:
            continue
        try:
            audio_chunks = []
            for _, _, audio_segment in tts_pipeline(text_to_cache, voice="af_heart", speed=1.3):
                # Keep only non-empty 1-D segments; Kokoro may yield None or
                # empty tensors for some text fragments.
                if audio_segment is not None and audio_segment.ndim == 1 and audio_segment.numel() > 0:
                    audio_chunks.append(audio_segment)

            if not audio_chunks:
                # Nothing usable came back; cache a short silence instead.
                audio_data_int16 = _silent_int16()
            else:
                full_audio_data_tensor = torch.cat(audio_chunks)
                # Scale float waveform in [-1, 1] to int16 PCM range.
                audio_data_int16 = (full_audio_data_tensor.cpu().numpy() * 32767).astype(np.int16)

            TTS_CACHE[text_to_cache] = _int16_to_wav_bytes(audio_data_int16)
        except Exception as e:
            # Pre-caching is best-effort: log and move on to the next phrase.
            print(f"Error pre-caching TTS for '{text_to_cache}': {e}")
            traceback.print_exc()

    if "SILENT_AUDIO_OPTIMIZED" not in TTS_CACHE:
        try:
            TTS_CACHE["SILENT_AUDIO_OPTIMIZED"] = _int16_to_wav_bytes(_silent_int16())
        except Exception as e:
            print(f"Error pre-caching silent audio: {e}")
            traceback.print_exc()

    print("TTS cache pre-population finished.")
|
| 170 |
+
|
| 171 |
+
# Gemini chat model for the agent; max_tokens keeps spoken replies short.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash",
                             temperature=0.6, max_tokens=60,
                             google_api_key=GOOGLE_API_KEY)

# Tavily web-search client wrapped as a LangChain tool the LLM can call.
tavily = TavilyClient(api_key=TAVILY_API_KEY)
search_tool = Tool.from_function(name="TavilySearch",
                                 func=lambda q: tavily.search(q, max_results=3),
                                 description="Fetch factual current information from the web. Input should be a search query.")

# Model handle with the search tool bound so responses may emit tool calls.
llm_tools = llm.bind_tools([search_tool])
|
| 181 |
+
|
| 182 |
+
def assistant_node(state: MessagesState) -> MessagesState:
    """LLM step of the agent graph.

    Prepends the Jarvis system prompt to the recent conversation, invokes
    the tool-bound model, and appends its reply to the state's messages.
    """
    history = state.get("messages", [])
    if not history:
        # Nothing to respond to; hand the state back untouched.
        return state

    system_msg = SystemMessage(
        "You are Jarvis, a helpful and concise AI assistant. "
        "Your responses should be brief, informative, and directly answer the user's query. "
        "Aim for responses around 1 to 3 sentences. Maximum 60 tokens output. "
        "If you use tools, search web for any factual information needed, or any information which is time sensitive, like today's news and latest model releases. "
        "After receiving tool results, you MUST summarize them or state you couldn't find the information based on the tool output."
    )

    # Only the last five messages go to the model, keeping the prompt small.
    reply = llm_tools.invoke([system_msg] + history[-5:])
    state["messages"] = history + [reply]
    return state
|
| 197 |
+
|
| 198 |
+
# Assemble the LangGraph agent: the assistant node may request tool calls,
# tools_condition routes to the tool node when it does (otherwise ends),
# and tool results loop back to the assistant for summarization.
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant_node)
builder.add_node("tools", ToolNode([search_tool]))
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
# MemorySaver checkpointer keeps per-thread_id conversation history.
agent_graph = builder.compile(checkpointer=MemorySaver())
|
| 205 |
+
|
| 206 |
+
def query_agent(prompt: str, thread_id: str) -> str:
    """Run one conversational turn through the agent graph.

    thread_id selects the checkpointer thread, so consecutive calls with
    the same id continue one conversation. Returns the final reply text,
    stripped of surrounding whitespace.
    """
    run_config = {"configurable": {"thread_id": thread_id}}
    initial_state = {"messages": [HumanMessage(content=prompt)]}
    result = agent_graph.invoke(initial_state, run_config)
    final_message = result["messages"][-1]
    return final_message.content.strip()
|
| 211 |
+
|
| 212 |
+
# Load the Kokoro TTS model eagerly at import time so the first request
# does not pay the model-load cost.
print("Loading Kokoro TTS model...")
tts_pipeline = KPipeline(lang_code="a", device=device, repo_id="hexgrad/Kokoro-82M")
print("Kokoro TTS model loaded.")

# Warm the TTS cache with canned phrases now that the pipeline exists.
pre_populate_tts_cache_on_startup()

app = FastAPI()
|
| 219 |
|
| 220 |
app.add_middleware(
|
| 221 |
CORSMiddleware,
|
|
|
|
| 235 |
|
| 236 |
if not payload.text or payload.text.isspace():
|
| 237 |
print("Empty text received, responding with silent audio.")
|
| 238 |
+
silent_audio_wav = await asyncio.to_thread(generate_speech_audio_optimized, "")
|
| 239 |
return StreamingResponse(silent_audio_wav, media_type="audio/wav")
|
| 240 |
|
| 241 |
try:
|
|
|
|
| 243 |
print(f"LLM response: '{llm_response_text}'")
|
| 244 |
|
| 245 |
if not llm_response_text or llm_response_text.isspace():
|
| 246 |
+
print("LLM returned empty response, generating canned audio.")
|
| 247 |
llm_response_text = "I don't have a response for that."
|
| 248 |
|
| 249 |
+
audio_wav_buffer = await asyncio.to_thread(generate_speech_audio_optimized, llm_response_text)
|
| 250 |
return StreamingResponse(audio_wav_buffer, media_type="audio/wav")
|
| 251 |
|
| 252 |
except Exception as e:
|
|
|
|
| 254 |
traceback.print_exc()
|
| 255 |
error_message_text = "Sorry, I encountered an error."
|
| 256 |
try:
|
| 257 |
+
error_audio_buffer = await asyncio.to_thread(generate_speech_audio_optimized, error_message_text)
|
| 258 |
return StreamingResponse(error_audio_buffer, media_type="audio/wav", status_code=500)
|
| 259 |
except Exception as audio_err:
|
| 260 |
print(f"Critical error generating error audio: {audio_err}")
|
| 261 |
+
traceback.print_exc()
|
| 262 |
raise HTTPException(status_code=500, detail="Internal server error during audio generation")
|
| 263 |
|
| 264 |
+
# Local development entry point; a hosted deployment would normally start
# the ASGI app via its own server instead of this guard.
if __name__ == "__main__":
    import uvicorn
    print("Starting FastAPI server with Uvicorn...")
    uvicorn.run(app, host="0.0.0.0", port=8000)
|