Commit 0c0c279
Parent(s): 148a917

Added support for streaming gemini responses

Files changed:
- main.py (+5, -4)
- src/manager/manager.py (+46, -34)
main.py CHANGED

@@ -151,10 +151,11 @@ css = """
 
 def run_model(message, history):
     if 'text' in message:
-        history.append({
-            "role": "user",
-            "content": message['text']
-        })
+        if message['text'].strip() != "":
+            history.append({
+                "role": "user",
+                "content": message['text']
+            })
     if 'files' in message:
         for file in message['files']:
             history.append({
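The new guard keeps a files-only submission from appending an empty user bubble to the chat history. A minimal sketch of the same check, assuming a Gradio-style multimodal message dict; the helper name is illustrative, not part of the commit:

def append_user_text(message: dict, history: list) -> list:
    # Skip whitespace-only text so a files-only message adds no empty entry
    if message.get('text', '').strip() != "":
        history.append({"role": "user", "content": message['text']})
    return history

print(append_user_text({'text': '   ', 'files': ['a.png']}, []))  # -> []
print(append_user_text({'text': 'hi', 'files': []}, []))          # -> [{'role': 'user', 'content': 'hi'}]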
src/manager/manager.py CHANGED

@@ -80,12 +80,12 @@ class GeminiManager:
         return mode in self.modes
 
     @backoff.on_exception(backoff.expo,
-
-
-
+                          APIError,
+                          max_tries=3,
+                          jitter=None)
    def generate_response(self, messages):
         tools = self.toolsLoader.getTools()
-        return self.client.models.generate_content(
+        return self.client.models.generate_content_stream(
             model=self.model_name,
             contents=messages,
             config=types.GenerateContentConfig(

@@ -95,10 +95,10 @@ class GeminiManager:
             ),
         )
 
-    def handle_tool_calls(self, response):
+    def handle_tool_calls(self, function_calls):
         parts = []
         i = 0
-        for function_call in response.function_calls:
+        for function_call in function_calls:
             title = ""
             thinking = ""
             toolResponse = None
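generate_response now returns an iterator of partial GenerateContentResponse chunks instead of a single response, with backoff retrying up to three times on APIError. A minimal standalone sketch of the same retry-plus-streaming pattern with the google-genai SDK; the model name and prompt are placeholders:

import backoff
from google import genai
from google.genai import errors

client = genai.Client()  # reads the API key from the environment

@backoff.on_exception(backoff.expo, errors.APIError, max_tries=3, jitter=None)
def stream_reply(prompt: str):
    # Returns an iterator of partial GenerateContentResponse chunks
    return client.models.generate_content_stream(
        model="gemini-2.0-flash",
        contents=prompt,
    )

for chunk in stream_reply("Say hello in three languages."):
    if chunk.text:
        print(chunk.text, end="", flush=True)

Because the decorated function returns the iterator rather than consuming it, the retry only guards obtaining the stream; depending on when the SDK actually issues the request, an APIError raised mid-iteration can bypass the backoff entirely.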
@@ -199,9 +199,9 @@ class GeminiManager:
                 parts = [types.Part.from_text(
                     text=message.get("content", ""))]
             case "memories":
-                role = "
+                role = "user"
                 parts = [types.Part.from_text(
-                    text="
+                    text="Here are the relevant memories for the user's query: "+message.get("content", ""))]
             case "tool":
                 role = "tool"
                 formatted_history.append(
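Gemini's content API only accepts conversation roles such as "user" and "model", so internal "memories" entries are remapped to a user-role part with an explanatory prefix. A toy sketch of that mapping; the message dict mirrors the project's history entries:

from google.genai import types

message = {"role": "memories", "content": "The user prefers metric units."}

role = "user"
parts = [types.Part.from_text(
    text="Here are the relevant memories for the user's query: " + message.get("content", ""))]

print(role, parts[0].text)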
@@ -273,8 +273,39 @@ class GeminiManager:
     def invoke_manager(self, messages):
         chat_history = self.format_chat_history(messages)
         logger.debug(f"Chat history: {chat_history}")
+        print(f"Chat history: {chat_history}")
         try:
-            response = suppress_output(self.generate_response)(chat_history)
+            response_stream = suppress_output(
+                self.generate_response)(chat_history)
+            full_text = ""  # Accumulate the text from the stream
+            function_calls = []
+            function_call_requests = []
+            for chunk in response_stream:
+                print(chunk)
+                if chunk.text:
+                    full_text += chunk.text
+                    yield messages + [{
+                        "role": "assistant",
+                        "content": chunk.text
+                    }]
+                for candidate in chunk.candidates:
+                    if candidate.content and candidate.content.parts:
+                        # messages.append(response.candidates[0].content)
+                        function_call_requests.append({
+                            "role": "function_call",
+                            "content": repr(candidate.content),
+                        })
+                        for part in candidate.content.parts:
+                            if part.function_call:
+                                function_calls.append(part.function_call)
+            if full_text.strip() != "":
+                messages.append({
+                    "role": "assistant",
+                    "content": full_text,
+                })
+            if function_call_requests:
+                messages = messages + function_call_requests
+            yield messages
         except Exception as e:
             messages.append({
                 "role": "assistant",
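Inside the loop each text delta is yielded as a transient assistant message so the UI updates immediately, while the deltas are folded into full_text and any function_call parts are collected; only after the stream ends are the accumulated results committed to messages. A self-contained sketch of that accumulate-and-collect pattern over stand-in chunk objects (not the SDK types):

from dataclasses import dataclass, field

# Stand-in shapes for streamed chunks; illustrative, not the google-genai types
@dataclass
class Part:
    function_call: dict | None = None

@dataclass
class Content:
    parts: list = field(default_factory=list)

@dataclass
class Candidate:
    content: Content | None = None

@dataclass
class Chunk:
    text: str | None = None
    candidates: list = field(default_factory=list)

def drain(stream):
    full_text, function_calls = "", []
    for chunk in stream:
        if chunk.text:
            full_text += chunk.text          # fold deltas into the final text
        for candidate in chunk.candidates:
            if candidate.content and candidate.content.parts:
                for part in candidate.content.parts:
                    if part.function_call:   # collect tool requests as they arrive
                        function_calls.append(part.function_call)
    return full_text, function_calls

chunks = [Chunk(text="Hel"), Chunk(text="lo"),
          Chunk(candidates=[Candidate(Content([Part({"name": "get_time"})]))])]
print(drain(chunks))  # -> ('Hello', [{'name': 'get_time'}])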
@@ -284,41 +315,22 @@ class GeminiManager:
             logger.error(f"Error generating response{e}")
             yield messages
             return messages
-        logger.debug(f"Response: {response}")
 
-        if
+        # Check if any text was received
+        if not full_text and len(function_calls) == 0:
             messages.append({
                 "role": "assistant",
                 "content": "No response from the model.",
                 "metadata": {"title": "No response from the model."}
             })
-            print(response)
-            yield messages
-            return messages
-
-        # Attach the llm response to the messages
-        if response.text is not None and response.text != "":
-            messages.append({
-                "role": "assistant",
-                "content": response.text
-            })
             yield messages
 
-
-
-        if candidate.content and candidate.content.parts:
-            # messages.append(response.candidates[0].content)
-            messages.append({
-                "role": "function_call",
-                "content": repr(candidate.content),
-            })
-
-        # Invoke the function calls if any and attach the response to the messages
-        if response.function_calls:
-            for call in self.handle_tool_calls(response):
+        if function_calls and len(function_calls) > 0:
+            for call in self.handle_tool_calls(function_calls):
                 yield messages + [call]
                 if (call.get("role") == "tool"
                     or (call.get("role") == "assistant" and call.get("metadata", {}).get("status") == "done")):
                     messages.append(call)
                     yield from self.invoke_manager(messages)
-
+        else:
+            yield messages
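Once the stream is drained the control flow forks: with no text and no calls a "No response" notice is yielded; with pending calls, handle_tool_calls dispatches them and invoke_manager recurses so the model can continue from the tool output; otherwise the final message list is yielded. A toy sketch of that recurse-after-tools loop; next_turn and run_tool are stand-ins for the real model and tool plumbing:

def run_tool(call):
    # Pretend every tool call resolves immediately (stand-in)
    return {"role": "tool", "content": f"result of {call}"}

def next_turn(messages):
    # Toy model: request one tool on the first turn, answer on the next
    if any(m["role"] == "tool" for m in messages):
        return "All done.", []
    return "", ["get_time"]

def invoke(messages):
    text, calls = next_turn(messages)
    if text:
        messages.append({"role": "assistant", "content": text})
    if calls:
        for call in calls:
            messages.append(run_tool(call))
        # Recurse so the next turn sees the tool results
        yield from invoke(messages)
    else:
        yield messages

for state in invoke([{"role": "user", "content": "What time is it?"}]):
    print(state)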