Aakash jammula committed on
Commit
88edcbb
·
1 Parent(s): 241379e
Files changed (1) hide show
  1. app.py +149 -55
app.py CHANGED
@@ -1,14 +1,15 @@
1
  import os
2
  import torch
3
  from kokoro import KPipeline
4
- from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage # Ensure AIMessage and ToolMessage are available if needed directly
5
  from langchain_core.tools import Tool
6
  from langgraph.graph import MessagesState, StateGraph, START
7
  from langgraph.prebuilt import tools_condition, ToolNode
8
  from langgraph.checkpoint.memory import MemorySaver
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
  from tavily import TavilyClient
11
- import warnings, logging
 
12
  from fastapi import FastAPI, HTTPException
13
  from fastapi.responses import StreamingResponse
14
  from pydantic import BaseModel
@@ -33,61 +34,32 @@ assert GOOGLE_API_KEY and TAVILY_API_KEY, "Missing API keys in environment."
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
  print(f"Using device: {device}")
35
 
36
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash",
37
- temperature=0.6, max_tokens=60,
38
- google_api_key=GOOGLE_API_KEY)
39
-
40
- tavily = TavilyClient(api_key=TAVILY_API_KEY)
41
- search_tool = Tool.from_function(name="TavilySearch",
42
- func=lambda q: tavily.search(q, max_results=3),
43
- description="Fetch factual current information from the web. Input should be a search query.")
44
-
45
- llm_tools = llm.bind_tools([search_tool])
46
-
47
- print("Loading Kokoro TTS model...")
48
- tts_pipeline = KPipeline(lang_code="a", device=device, repo_id="hexgrad/Kokoro-82M")
49
- print("Kokoro TTS model loaded.")
50
 
 
 
51
 
52
- def assistant_node(state: MessagesState) -> MessagesState:
53
- msgs = state.get("messages", [])
54
- if not msgs:
55
- return state
56
- sys = SystemMessage(
57
- "You are Jarvis, a helpful and concise AI assistant. "
58
- "Your responses should be brief, informative, and directly answer the user's query. "
59
- "Aim for responses around 1 to 3 sentences. Maximum 60 tokens output. " # Increased slightly for tool summarization
60
- "If you use tools, search web for any factual information needed, or any information which is time sensitive, like today's news and latest model releases. "
61
- "After receiving tool results, you MUST summarize them or state you couldn't find the information based on the tool output."
62
- )
63
- resp = llm_tools.invoke([sys] + msgs[-5:])
64
- state["messages"] = msgs + [resp]
65
- return state
66
 
67
- builder = StateGraph(MessagesState)
68
- builder.add_node("assistant", assistant_node)
69
- builder.add_node("tools", ToolNode([search_tool]))
70
- builder.add_edge(START, "assistant")
71
- builder.add_conditional_edges("assistant", tools_condition)
72
- builder.add_edge("tools", "assistant")
73
- agent_graph = builder.compile(MemorySaver())
74
-
75
- def query_agent(prompt: str, thread_id: str) -> str:
76
- payload = {"messages": [HumanMessage(content=prompt)]}
77
- cfg = {"configurable": {"thread_id": thread_id}}
78
- out = agent_graph.invoke(payload, cfg)
79
- return out["messages"][-1].content.strip()
80
 
81
- KOKORO_SAMPLE_RATE = 24000
 
 
 
82
 
83
- def generate_speech_audio(text: str) -> io.BytesIO:
84
- if not text or text.isspace():
85
- print("Empty text received for TTS, generating silent audio.")
86
  duration_ms = 100
87
  num_samples = int(KOKORO_SAMPLE_RATE * (duration_ms / 1000.0))
88
  audio_data_int16 = np.zeros(num_samples, dtype=np.int16)
89
  else:
90
- print(f"Generating speech for: '{text}'")
91
  audio_chunks = []
92
  try:
93
  for _, _, audio_segment in tts_pipeline(text, voice="af_heart", speed=1.3):
@@ -98,7 +70,7 @@ def generate_speech_audio(text: str) -> io.BytesIO:
98
  print(f"Warning: TTS produced an empty or non-1D audio segment. Shape: {audio_segment.shape if hasattr(audio_segment, 'shape') else 'N/A'}")
99
  else:
100
  print(f"Warning: TTS produced None audio segment for part of text: '{text}'")
101
-
102
  if not audio_chunks:
103
  print(f"Warning: TTS produced no valid audio for text: '{text}'. Generating silence.")
104
  duration_ms = 100
@@ -121,12 +93,129 @@ def generate_speech_audio(text: str) -> io.BytesIO:
121
  wf.setsampwidth(2)
122
  wf.setframerate(KOKORO_SAMPLE_RATE)
123
  wf.writeframes(audio_data_int16.tobytes())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  wav_buffer.seek(0)
125
- print(f"Generated WAV audio of size: {len(wav_buffer.getvalue())} bytes")
126
  return wav_buffer
127
 
128
- app = FastAPI()
 
 
 
 
 
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  app.add_middleware(
132
  CORSMiddleware,
@@ -146,7 +235,7 @@ async def process_command(payload: TextInput):
146
 
147
  if not payload.text or payload.text.isspace():
148
  print("Empty text received, responding with silent audio.")
149
- silent_audio_wav = await asyncio.to_thread(generate_speech_audio, "")
150
  return StreamingResponse(silent_audio_wav, media_type="audio/wav")
151
 
152
  try:
@@ -154,10 +243,10 @@ async def process_command(payload: TextInput):
154
  print(f"LLM response: '{llm_response_text}'")
155
 
156
  if not llm_response_text or llm_response_text.isspace():
157
- print("LLM returned empty response, generating silent audio.")
158
  llm_response_text = "I don't have a response for that."
159
 
160
- audio_wav_buffer = await asyncio.to_thread(generate_speech_audio, llm_response_text)
161
  return StreamingResponse(audio_wav_buffer, media_type="audio/wav")
162
 
163
  except Exception as e:
@@ -165,9 +254,14 @@ async def process_command(payload: TextInput):
165
  traceback.print_exc()
166
  error_message_text = "Sorry, I encountered an error."
167
  try:
168
- error_audio_buffer = await asyncio.to_thread(generate_speech_audio, error_message_text)
169
  return StreamingResponse(error_audio_buffer, media_type="audio/wav", status_code=500)
170
  except Exception as audio_err:
171
  print(f"Critical error generating error audio: {audio_err}")
 
172
  raise HTTPException(status_code=500, detail="Internal server error during audio generation")
173
 
 
 
 
 
 
1
  import os
2
  import torch
3
  from kokoro import KPipeline
4
+ from langchain_core.messages import SystemMessage, HumanMessage
5
  from langchain_core.tools import Tool
6
  from langgraph.graph import MessagesState, StateGraph, START
7
  from langgraph.prebuilt import tools_condition, ToolNode
8
  from langgraph.checkpoint.memory import MemorySaver
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
  from tavily import TavilyClient
11
+ import warnings
12
+ import logging
13
  from fastapi import FastAPI, HTTPException
14
  from fastapi.responses import StreamingResponse
15
  from pydantic import BaseModel
 
34
# Pick GPU when available; the Kokoro TTS pipeline is created on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# In-memory cache mapping response text -> pre-rendered WAV bytes, so the
# canned error/fallback replies never pay TTS latency twice.
TTS_CACHE = {}
# Kokoro-82M output sample rate (Hz), used for WAV headers and silence sizing.
KOKORO_SAMPLE_RATE = 24000
# Module-level pipeline handle; assigned after the model is loaded below.
tts_pipeline: KPipeline  # To be initialized later
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ def generate_speech_audio_optimized(text: str) -> io.BytesIO:
42
+ global TTS_CACHE, KOKORO_SAMPLE_RATE, tts_pipeline
43
 
44
+ cache_key_to_check = text
45
+ is_silent_request = False
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ if not text or text.isspace():
48
+ cache_key_to_check = "SILENT_AUDIO_OPTIMIZED"
49
+ is_silent_request = True
 
 
 
 
 
 
 
 
 
 
50
 
51
+ if cache_key_to_check in TTS_CACHE:
52
+ cached_buffer = io.BytesIO(TTS_CACHE[cache_key_to_check])
53
+ cached_buffer.seek(0)
54
+ return cached_buffer
55
 
56
+ audio_data_int16: np.ndarray
57
+
58
+ if is_silent_request:
59
  duration_ms = 100
60
  num_samples = int(KOKORO_SAMPLE_RATE * (duration_ms / 1000.0))
61
  audio_data_int16 = np.zeros(num_samples, dtype=np.int16)
62
  else:
 
63
  audio_chunks = []
64
  try:
65
  for _, _, audio_segment in tts_pipeline(text, voice="af_heart", speed=1.3):
 
70
  print(f"Warning: TTS produced an empty or non-1D audio segment. Shape: {audio_segment.shape if hasattr(audio_segment, 'shape') else 'N/A'}")
71
  else:
72
  print(f"Warning: TTS produced None audio segment for part of text: '{text}'")
73
+
74
  if not audio_chunks:
75
  print(f"Warning: TTS produced no valid audio for text: '{text}'. Generating silence.")
76
  duration_ms = 100
 
93
  wf.setsampwidth(2)
94
  wf.setframerate(KOKORO_SAMPLE_RATE)
95
  wf.writeframes(audio_data_int16.tobytes())
96
+
97
+ wav_data_bytes = wav_buffer.getvalue()
98
+
99
+ texts_to_always_cache = [
100
+ "Sorry, I encountered an error.",
101
+ "I don't have a response for that."
102
+ ]
103
+ current_cache_target_key = "SILENT_AUDIO_OPTIMIZED" if is_silent_request else text
104
+
105
+ if current_cache_target_key == "SILENT_AUDIO_OPTIMIZED" and current_cache_target_key not in TTS_CACHE:
106
+ TTS_CACHE[current_cache_target_key] = wav_data_bytes
107
+ elif current_cache_target_key in texts_to_always_cache and current_cache_target_key not in TTS_CACHE:
108
+ TTS_CACHE[current_cache_target_key] = wav_data_bytes
109
+
110
  wav_buffer.seek(0)
 
111
  return wav_buffer
112
 
113
def pre_populate_tts_cache_on_startup():
    """Warm TTS_CACHE with the canned responses and a short silent clip.

    Rendering these once at startup means the error/fallback paths in the
    request handlers can serve audio straight from memory. Each item is
    best-effort: a failure is logged and skipped so a bad pre-cache never
    prevents the server from starting.
    """
    global TTS_CACHE, KOKORO_SAMPLE_RATE, tts_pipeline

    def _silence(duration_ms: int = 100) -> np.ndarray:
        # Short silent clip as 16-bit PCM samples.
        num_samples = int(KOKORO_SAMPLE_RATE * (duration_ms / 1000.0))
        return np.zeros(num_samples, dtype=np.int16)

    def _encode_wav(samples: np.ndarray) -> bytes:
        # Encode int16 mono samples as an in-memory WAV file.
        buf = io.BytesIO()
        with wave.open(buf, 'wb') as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)  # 16-bit PCM
            wf.setframerate(KOKORO_SAMPLE_RATE)
            wf.writeframes(samples.tobytes())
        return buf.getvalue()

    # `callable()` is the idiomatic check (was: hasattr(..., '__call__')).
    if not callable(tts_pipeline):
        print("Error: TTS pipeline not available for pre-caching.")
        return

    print("Pre-populating TTS cache...")

    common_texts_to_pre_cache = [
        "Sorry, I encountered an error.",
        "I don't have a response for that."
    ]

    for text_to_cache in common_texts_to_pre_cache:
        if text_to_cache in TTS_CACHE:
            continue
        try:
            audio_chunks = []
            for _, _, audio_segment in tts_pipeline(text_to_cache, voice="af_heart", speed=1.3):
                # Keep only non-empty 1-D tensors; the pipeline may yield None.
                if audio_segment is not None and audio_segment.ndim == 1 and audio_segment.numel() > 0:
                    audio_chunks.append(audio_segment)

            if not audio_chunks:
                # TTS produced nothing usable -> cache a short silence instead.
                audio_data_int16_precached = _silence()
            else:
                full_audio_data_tensor = torch.cat(audio_chunks)
                # Scale float [-1, 1] audio to the 16-bit PCM range.
                audio_data_int16_precached = (full_audio_data_tensor.cpu().numpy() * 32767).astype(np.int16)

            TTS_CACHE[text_to_cache] = _encode_wav(audio_data_int16_precached)
        except Exception as e:
            print(f"Error pre-caching TTS for '{text_to_cache}': {e}")
            traceback.print_exc()

    if "SILENT_AUDIO_OPTIMIZED" not in TTS_CACHE:
        try:
            TTS_CACHE["SILENT_AUDIO_OPTIMIZED"] = _encode_wav(_silence())
        except Exception as e:
            print(f"Error pre-caching silent audio: {e}")
            traceback.print_exc()
    print("TTS cache pre-population finished.")
170
+
171
# Gemini chat model; max_tokens kept small so spoken replies stay short.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash",
                             temperature=0.6, max_tokens=60,
                             google_api_key=GOOGLE_API_KEY)

# Tavily web search wrapped as a LangChain tool the agent can invoke.
tavily = TavilyClient(api_key=TAVILY_API_KEY)
search_tool = Tool.from_function(name="TavilySearch",
                                 func=lambda q: tavily.search(q, max_results=3),
                                 description="Fetch factual current information from the web. Input should be a search query.")

# LLM handle that can emit TavilySearch tool calls.
llm_tools = llm.bind_tools([search_tool])
181
+
182
def assistant_node(state: MessagesState) -> MessagesState:
    """LangGraph node: run the tool-aware LLM over the recent conversation.

    Prepends the Jarvis system prompt, invokes the model on the last five
    messages (bounds the context window), and appends the response to the
    state's message list. Returns state unchanged when there are no
    messages yet.
    """
    msgs = state.get("messages", [])
    if not msgs:
        return state
    sys = SystemMessage(
        "You are Jarvis, a helpful and concise AI assistant. "
        "Your responses should be brief, informative, and directly answer the user's query. "
        "Aim for responses around 1 to 3 sentences. Maximum 60 tokens output. "
        "If you use tools, search web for any factual information needed, or any information which is time sensitive, like today's news and latest model releases. "
        "After receiving tool results, you MUST summarize them or state you couldn't find the information based on the tool output."
    )
    resp = llm_tools.invoke([sys] + msgs[-5:])
    # `msgs` already holds state["messages"]; the previous second
    # state.get(...) lookup was redundant.
    state["messages"] = msgs + [resp]
    return state
197
+
198
# Wire the agent loop: the assistant may request a tool call, the tool
# result feeds back into the assistant until no further call is emitted.
graph_builder = StateGraph(MessagesState)
graph_builder.add_node("assistant", assistant_node)
graph_builder.add_node("tools", ToolNode([search_tool]))
graph_builder.add_edge(START, "assistant")
graph_builder.add_conditional_edges("assistant", tools_condition)
graph_builder.add_edge("tools", "assistant")
# MemorySaver checkpointer provides per-thread conversation memory.
agent_graph = graph_builder.compile(checkpointer=MemorySaver())
205
+
206
def query_agent(prompt: str, thread_id: str) -> str:
    """Run one user turn through the agent graph and return the reply text.

    The thread_id selects the MemorySaver checkpoint, so calls sharing an
    id share conversation history.
    """
    config = {"configurable": {"thread_id": thread_id}}
    result = agent_graph.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config,
    )
    return result["messages"][-1].content.strip()
211
+
212
# Load the Kokoro TTS model at import time so the first request does not
# pay model-load latency. lang_code="a" and the repo_id select Kokoro-82M.
print("Loading Kokoro TTS model...")
tts_pipeline = KPipeline(lang_code="a", device=device, repo_id="hexgrad/Kokoro-82M")
print("Kokoro TTS model loaded.")

# Render the canned/silent responses into TTS_CACHE before serving traffic.
pre_populate_tts_cache_on_startup()

app = FastAPI()
219
 
220
  app.add_middleware(
221
  CORSMiddleware,
 
235
 
236
  if not payload.text or payload.text.isspace():
237
  print("Empty text received, responding with silent audio.")
238
+ silent_audio_wav = await asyncio.to_thread(generate_speech_audio_optimized, "")
239
  return StreamingResponse(silent_audio_wav, media_type="audio/wav")
240
 
241
  try:
 
243
  print(f"LLM response: '{llm_response_text}'")
244
 
245
  if not llm_response_text or llm_response_text.isspace():
246
+ print("LLM returned empty response, generating canned audio.")
247
  llm_response_text = "I don't have a response for that."
248
 
249
+ audio_wav_buffer = await asyncio.to_thread(generate_speech_audio_optimized, llm_response_text)
250
  return StreamingResponse(audio_wav_buffer, media_type="audio/wav")
251
 
252
  except Exception as e:
 
254
  traceback.print_exc()
255
  error_message_text = "Sorry, I encountered an error."
256
  try:
257
+ error_audio_buffer = await asyncio.to_thread(generate_speech_audio_optimized, error_message_text)
258
  return StreamingResponse(error_audio_buffer, media_type="audio/wav", status_code=500)
259
  except Exception as audio_err:
260
  print(f"Critical error generating error audio: {audio_err}")
261
+ traceback.print_exc()
262
  raise HTTPException(status_code=500, detail="Internal server error during audio generation")
263
 
264
# Dev entry point: `python app.py` serves on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn
    print("Starting FastAPI server with Uvicorn...")
    uvicorn.run(app, host="0.0.0.0", port=8000)