Spaces:
Paused
Paused
frdel commited on
Commit ·
f3f8ca5
1
Parent(s): 551c95a
Merge branch 'pr/491' into development
Browse files- agent.py +75 -63
- initialize.py +22 -4
- models.py +435 -352
- preload.py +5 -1
- prompts/agent0/agent.system.tool.response.md +1 -0
- prompts/default/agent.system.tool.call_sub.md +10 -47
- python/extensions/reasoning_stream/.gitkeep +0 -0
- python/extensions/reasoning_stream/_10_log_from_stream.py +29 -0
- python/extensions/response_stream/_10_log_from_stream.py +25 -6
- python/helpers/document_query.py +22 -13
- python/helpers/history.py +11 -4
- python/helpers/memory.py +1 -2
- python/tools/browser_agent.py +19 -15
- requirements.txt +3 -9
- test.py +74 -0
- webui/index.css +5 -0
- webui/js/messages.js +7 -5
agent.py
CHANGED
|
@@ -343,19 +343,28 @@ class Agent:
|
|
| 343 |
# create log message right away, more responsive
|
| 344 |
self.loop_data.params_temporary["log_item_generating"] = (
|
| 345 |
self.context.log.log(
|
| 346 |
-
type="agent", heading=f"{self.agent_name}:
|
| 347 |
)
|
| 348 |
)
|
| 349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
async def stream_callback(chunk: str, full: str):
|
| 351 |
# output the agent response stream
|
| 352 |
-
if chunk:
|
| 353 |
-
printer.
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
|
|
|
|
|
|
|
|
|
| 359 |
|
| 360 |
await self.handle_intervention(agent_response)
|
| 361 |
|
|
@@ -409,7 +418,7 @@ class Agent:
|
|
| 409 |
# call monologue_end extensions
|
| 410 |
await self.call_extensions("monologue_end", loop_data=self.loop_data) # type: ignore
|
| 411 |
|
| 412 |
-
async def prepare_prompt(self, loop_data: LoopData) ->
|
| 413 |
self.context.log.set_progress("Building prompt")
|
| 414 |
|
| 415 |
# call extensions before setting prompts
|
|
@@ -422,12 +431,10 @@ class Agent:
|
|
| 422 |
# and allow extensions to edit them
|
| 423 |
await self.call_extensions("message_loop_prompts_after", loop_data=loop_data)
|
| 424 |
|
| 425 |
-
#
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
#
|
| 429 |
-
# for extra in loop_data.extras_temporary.values():
|
| 430 |
-
# extras += history.Message(False, content=extra).output()
|
| 431 |
extras = history.Message(
|
| 432 |
False,
|
| 433 |
content=self.read_prompt(
|
|
@@ -444,28 +451,23 @@ class Agent:
|
|
| 444 |
loop_data.history_output + extras
|
| 445 |
)
|
| 446 |
|
| 447 |
-
# build
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
# AIMessage(content="JSON:"), # force the LLM to start with json
|
| 454 |
-
]
|
| 455 |
-
)
|
| 456 |
|
| 457 |
# store as last context window content
|
| 458 |
self.set_data(
|
| 459 |
Agent.DATA_NAME_CTX_WINDOW,
|
| 460 |
{
|
| 461 |
-
"text":
|
| 462 |
-
"tokens":
|
| 463 |
-
+ tokens.approximate_tokens(system_text)
|
| 464 |
-
+ tokens.approximate_tokens(history.output_text(extras)),
|
| 465 |
},
|
| 466 |
)
|
| 467 |
|
| 468 |
-
return
|
| 469 |
|
| 470 |
def handle_critical_exception(self, exception: Exception):
|
| 471 |
if isinstance(exception, HandledException):
|
|
@@ -586,24 +588,21 @@ class Agent:
|
|
| 586 |
return self.history.output_text(human_label="user", ai_label="assistant")
|
| 587 |
|
| 588 |
def get_chat_model(self):
|
| 589 |
-
return models.
|
| 590 |
-
models.ModelType.CHAT,
|
| 591 |
self.config.chat_model.provider,
|
| 592 |
self.config.chat_model.name,
|
| 593 |
**self.config.chat_model.kwargs,
|
| 594 |
)
|
| 595 |
|
| 596 |
def get_utility_model(self):
|
| 597 |
-
return models.
|
| 598 |
-
models.ModelType.CHAT,
|
| 599 |
self.config.utility_model.provider,
|
| 600 |
self.config.utility_model.name,
|
| 601 |
**self.config.utility_model.kwargs,
|
| 602 |
)
|
| 603 |
|
| 604 |
def get_embedding_model(self):
|
| 605 |
-
return models.
|
| 606 |
-
models.ModelType.EMBEDDING,
|
| 607 |
self.config.embeddings_model.provider,
|
| 608 |
self.config.embeddings_model.name,
|
| 609 |
**self.config.embeddings_model.kwargs,
|
|
@@ -616,36 +615,37 @@ class Agent:
|
|
| 616 |
callback: Callable[[str], Awaitable[None]] | None = None,
|
| 617 |
background: bool = False,
|
| 618 |
):
|
| 619 |
-
prompt = ChatPromptTemplate.from_messages(
|
| 620 |
-
[SystemMessage(content=system), HumanMessage(content=message)]
|
| 621 |
-
)
|
| 622 |
-
|
| 623 |
-
response = ""
|
| 624 |
-
|
| 625 |
-
# model class
|
| 626 |
model = self.get_utility_model()
|
| 627 |
|
| 628 |
# rate limiter
|
| 629 |
limiter = await self.rate_limiter(
|
| 630 |
-
self.config.utility_model,
|
| 631 |
)
|
| 632 |
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
limiter.add(output=tokens.approximate_tokens(content))
|
| 638 |
-
response += content
|
| 639 |
|
|
|
|
|
|
|
| 640 |
if callback:
|
| 641 |
-
await callback(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 642 |
|
| 643 |
return response
|
| 644 |
|
| 645 |
async def call_chat_model(
|
| 646 |
self,
|
| 647 |
-
|
| 648 |
-
|
|
|
|
| 649 |
):
|
| 650 |
response = ""
|
| 651 |
|
|
@@ -653,19 +653,24 @@ class Agent:
|
|
| 653 |
model = self.get_chat_model()
|
| 654 |
|
| 655 |
# rate limiter
|
| 656 |
-
limiter = await self.rate_limiter(
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
await self.handle_intervention() # wait for intervention and handle it, if paused
|
| 660 |
-
|
| 661 |
-
content = models.parse_chunk(chunk)
|
| 662 |
-
limiter.add(output=tokens.approximate_tokens(content))
|
| 663 |
-
response += content
|
| 664 |
|
| 665 |
-
|
| 666 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 667 |
|
| 668 |
-
return response
|
| 669 |
|
| 670 |
async def rate_limiter(
|
| 671 |
self, model_config: ModelConfig, input: str, background: bool = False
|
|
@@ -786,6 +791,13 @@ class Agent:
|
|
| 786 |
content=f"{self.agent_name}: Message misformat, no valid tool request found.",
|
| 787 |
)
|
| 788 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 789 |
async def handle_response_stream(self, stream: str):
|
| 790 |
try:
|
| 791 |
if len(stream) < 25:
|
|
|
|
| 343 |
# create log message right away, more responsive
|
| 344 |
self.loop_data.params_temporary["log_item_generating"] = (
|
| 345 |
self.context.log.log(
|
| 346 |
+
type="agent", heading=f"{self.agent_name}: Generating..."
|
| 347 |
)
|
| 348 |
)
|
| 349 |
|
| 350 |
+
async def reasoning_callback(chunk: str, full: str):
|
| 351 |
+
if chunk == full:
|
| 352 |
+
printer.print("Reasoning: ") # start of reasoning
|
| 353 |
+
printer.stream(chunk)
|
| 354 |
+
await self.handle_reasoning_stream(full)
|
| 355 |
+
|
| 356 |
async def stream_callback(chunk: str, full: str):
|
| 357 |
# output the agent response stream
|
| 358 |
+
if chunk == full:
|
| 359 |
+
printer.print("Response: ") # start of response
|
| 360 |
+
printer.stream(chunk)
|
| 361 |
+
await self.handle_response_stream(full)
|
| 362 |
+
|
| 363 |
+
agent_response, _reasoning = await self.call_chat_model(
|
| 364 |
+
messages=prompt,
|
| 365 |
+
response_callback=stream_callback,
|
| 366 |
+
reasoning_callback=reasoning_callback,
|
| 367 |
+
)
|
| 368 |
|
| 369 |
await self.handle_intervention(agent_response)
|
| 370 |
|
|
|
|
| 418 |
# call monologue_end extensions
|
| 419 |
await self.call_extensions("monologue_end", loop_data=self.loop_data) # type: ignore
|
| 420 |
|
| 421 |
+
async def prepare_prompt(self, loop_data: LoopData) -> list[BaseMessage]:
|
| 422 |
self.context.log.set_progress("Building prompt")
|
| 423 |
|
| 424 |
# call extensions before setting prompts
|
|
|
|
| 431 |
# and allow extensions to edit them
|
| 432 |
await self.call_extensions("message_loop_prompts_after", loop_data=loop_data)
|
| 433 |
|
| 434 |
+
# concatenate system prompt
|
| 435 |
+
system_text = "\n\n".join(loop_data.system)
|
| 436 |
+
|
| 437 |
+
# join extras
|
|
|
|
|
|
|
| 438 |
extras = history.Message(
|
| 439 |
False,
|
| 440 |
content=self.read_prompt(
|
|
|
|
| 451 |
loop_data.history_output + extras
|
| 452 |
)
|
| 453 |
|
| 454 |
+
# build full prompt from system prompt, message history and extrS
|
| 455 |
+
full_prompt: list[BaseMessage] = [
|
| 456 |
+
SystemMessage(content=system_text),
|
| 457 |
+
*history_langchain,
|
| 458 |
+
]
|
| 459 |
+
full_text = ChatPromptTemplate.from_messages(full_prompt).format()
|
|
|
|
|
|
|
|
|
|
| 460 |
|
| 461 |
# store as last context window content
|
| 462 |
self.set_data(
|
| 463 |
Agent.DATA_NAME_CTX_WINDOW,
|
| 464 |
{
|
| 465 |
+
"text": full_text,
|
| 466 |
+
"tokens": tokens.approximate_tokens(full_text),
|
|
|
|
|
|
|
| 467 |
},
|
| 468 |
)
|
| 469 |
|
| 470 |
+
return full_prompt
|
| 471 |
|
| 472 |
def handle_critical_exception(self, exception: Exception):
|
| 473 |
if isinstance(exception, HandledException):
|
|
|
|
| 588 |
return self.history.output_text(human_label="user", ai_label="assistant")
|
| 589 |
|
| 590 |
def get_chat_model(self):
|
| 591 |
+
return models.get_chat_model(
|
|
|
|
| 592 |
self.config.chat_model.provider,
|
| 593 |
self.config.chat_model.name,
|
| 594 |
**self.config.chat_model.kwargs,
|
| 595 |
)
|
| 596 |
|
| 597 |
def get_utility_model(self):
|
| 598 |
+
return models.get_chat_model(
|
|
|
|
| 599 |
self.config.utility_model.provider,
|
| 600 |
self.config.utility_model.name,
|
| 601 |
**self.config.utility_model.kwargs,
|
| 602 |
)
|
| 603 |
|
| 604 |
def get_embedding_model(self):
|
| 605 |
+
return models.get_embedding_model(
|
|
|
|
| 606 |
self.config.embeddings_model.provider,
|
| 607 |
self.config.embeddings_model.name,
|
| 608 |
**self.config.embeddings_model.kwargs,
|
|
|
|
| 615 |
callback: Callable[[str], Awaitable[None]] | None = None,
|
| 616 |
background: bool = False,
|
| 617 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 618 |
model = self.get_utility_model()
|
| 619 |
|
| 620 |
# rate limiter
|
| 621 |
limiter = await self.rate_limiter(
|
| 622 |
+
self.config.utility_model, f"SYSTEM: {system}\nUSER: {message}", background
|
| 623 |
)
|
| 624 |
|
| 625 |
+
# add output tokens to rate limiter in tokens callback
|
| 626 |
+
async def tokens_callback(delta: str, tokens: int):
|
| 627 |
+
await self.handle_intervention()
|
| 628 |
+
limiter.add(output=tokens)
|
|
|
|
|
|
|
| 629 |
|
| 630 |
+
# propagate stream to callback if set
|
| 631 |
+
async def stream_callback(chunk: str, total: str):
|
| 632 |
if callback:
|
| 633 |
+
await callback(chunk)
|
| 634 |
+
|
| 635 |
+
response, _reasoning = await model.unified_call(
|
| 636 |
+
system_message=system,
|
| 637 |
+
user_message=message,
|
| 638 |
+
response_callback=stream_callback,
|
| 639 |
+
tokens_callback=tokens_callback,
|
| 640 |
+
)
|
| 641 |
|
| 642 |
return response
|
| 643 |
|
| 644 |
async def call_chat_model(
|
| 645 |
self,
|
| 646 |
+
messages: list[BaseMessage],
|
| 647 |
+
response_callback: Callable[[str, str], Awaitable[None]] | None = None,
|
| 648 |
+
reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
|
| 649 |
):
|
| 650 |
response = ""
|
| 651 |
|
|
|
|
| 653 |
model = self.get_chat_model()
|
| 654 |
|
| 655 |
# rate limiter
|
| 656 |
+
limiter = await self.rate_limiter(
|
| 657 |
+
self.config.chat_model, ChatPromptTemplate.from_messages(messages).format()
|
| 658 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
|
| 660 |
+
# add output tokens to rate limiter in tokens callback
|
| 661 |
+
async def tokens_callback(delta: str, tokens: int):
|
| 662 |
+
await self.handle_intervention()
|
| 663 |
+
limiter.add(output=tokens)
|
| 664 |
+
|
| 665 |
+
# call model
|
| 666 |
+
response, reasoning = await model.unified_call(
|
| 667 |
+
messages=messages,
|
| 668 |
+
reasoning_callback=reasoning_callback,
|
| 669 |
+
response_callback=response_callback,
|
| 670 |
+
tokens_callback=tokens_callback,
|
| 671 |
+
)
|
| 672 |
|
| 673 |
+
return response, reasoning
|
| 674 |
|
| 675 |
async def rate_limiter(
|
| 676 |
self, model_config: ModelConfig, input: str, background: bool = False
|
|
|
|
| 791 |
content=f"{self.agent_name}: Message misformat, no valid tool request found.",
|
| 792 |
)
|
| 793 |
|
| 794 |
+
async def handle_reasoning_stream(self, stream: str):
|
| 795 |
+
await self.call_extensions(
|
| 796 |
+
"reasoning_stream",
|
| 797 |
+
loop_data=self.loop_data,
|
| 798 |
+
text=stream,
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
async def handle_response_stream(self, stream: str):
|
| 802 |
try:
|
| 803 |
if len(stream) < 25:
|
initialize.py
CHANGED
|
@@ -7,6 +7,24 @@ from python.helpers.print_style import PrintStyle
|
|
| 7 |
def initialize_agent():
|
| 8 |
current_settings = settings.get_settings()
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
# chat model from user settings
|
| 11 |
chat_llm = ModelConfig(
|
| 12 |
provider=models.ModelProvider[current_settings["chat_model_provider"]],
|
|
@@ -16,7 +34,7 @@ def initialize_agent():
|
|
| 16 |
limit_requests=current_settings["chat_model_rl_requests"],
|
| 17 |
limit_input=current_settings["chat_model_rl_input"],
|
| 18 |
limit_output=current_settings["chat_model_rl_output"],
|
| 19 |
-
kwargs=current_settings["chat_model_kwargs"],
|
| 20 |
)
|
| 21 |
|
| 22 |
# utility model from user settings
|
|
@@ -27,21 +45,21 @@ def initialize_agent():
|
|
| 27 |
limit_requests=current_settings["util_model_rl_requests"],
|
| 28 |
limit_input=current_settings["util_model_rl_input"],
|
| 29 |
limit_output=current_settings["util_model_rl_output"],
|
| 30 |
-
kwargs=current_settings["util_model_kwargs"],
|
| 31 |
)
|
| 32 |
# embedding model from user settings
|
| 33 |
embedding_llm = ModelConfig(
|
| 34 |
provider=models.ModelProvider[current_settings["embed_model_provider"]],
|
| 35 |
name=current_settings["embed_model_name"],
|
| 36 |
limit_requests=current_settings["embed_model_rl_requests"],
|
| 37 |
-
kwargs=current_settings["embed_model_kwargs"],
|
| 38 |
)
|
| 39 |
# browser model from user settings
|
| 40 |
browser_llm = ModelConfig(
|
| 41 |
provider=models.ModelProvider[current_settings["browser_model_provider"]],
|
| 42 |
name=current_settings["browser_model_name"],
|
| 43 |
vision=current_settings["browser_model_vision"],
|
| 44 |
-
kwargs=current_settings["browser_model_kwargs"],
|
| 45 |
)
|
| 46 |
# agent configuration
|
| 47 |
config = AgentConfig(
|
|
|
|
| 7 |
def initialize_agent():
|
| 8 |
current_settings = settings.get_settings()
|
| 9 |
|
| 10 |
+
def _normalize_model_kwargs(kwargs: dict) -> dict:
|
| 11 |
+
# convert string values that represent valid Python numbers to numeric types
|
| 12 |
+
result = {}
|
| 13 |
+
for key, value in kwargs.items():
|
| 14 |
+
if isinstance(value, str):
|
| 15 |
+
# try to convert string to number if it's a valid Python number
|
| 16 |
+
try:
|
| 17 |
+
# try int first, then float
|
| 18 |
+
result[key] = int(value)
|
| 19 |
+
except ValueError:
|
| 20 |
+
try:
|
| 21 |
+
result[key] = float(value)
|
| 22 |
+
except ValueError:
|
| 23 |
+
result[key] = value
|
| 24 |
+
else:
|
| 25 |
+
result[key] = value
|
| 26 |
+
return result
|
| 27 |
+
|
| 28 |
# chat model from user settings
|
| 29 |
chat_llm = ModelConfig(
|
| 30 |
provider=models.ModelProvider[current_settings["chat_model_provider"]],
|
|
|
|
| 34 |
limit_requests=current_settings["chat_model_rl_requests"],
|
| 35 |
limit_input=current_settings["chat_model_rl_input"],
|
| 36 |
limit_output=current_settings["chat_model_rl_output"],
|
| 37 |
+
kwargs=_normalize_model_kwargs(current_settings["chat_model_kwargs"]),
|
| 38 |
)
|
| 39 |
|
| 40 |
# utility model from user settings
|
|
|
|
| 45 |
limit_requests=current_settings["util_model_rl_requests"],
|
| 46 |
limit_input=current_settings["util_model_rl_input"],
|
| 47 |
limit_output=current_settings["util_model_rl_output"],
|
| 48 |
+
kwargs=_normalize_model_kwargs(current_settings["util_model_kwargs"]),
|
| 49 |
)
|
| 50 |
# embedding model from user settings
|
| 51 |
embedding_llm = ModelConfig(
|
| 52 |
provider=models.ModelProvider[current_settings["embed_model_provider"]],
|
| 53 |
name=current_settings["embed_model_name"],
|
| 54 |
limit_requests=current_settings["embed_model_rl_requests"],
|
| 55 |
+
kwargs=_normalize_model_kwargs(current_settings["embed_model_kwargs"]),
|
| 56 |
)
|
| 57 |
# browser model from user settings
|
| 58 |
browser_llm = ModelConfig(
|
| 59 |
provider=models.ModelProvider[current_settings["browser_model_provider"]],
|
| 60 |
name=current_settings["browser_model_name"],
|
| 61 |
vision=current_settings["browser_model_vision"],
|
| 62 |
+
kwargs=_normalize_model_kwargs(current_settings["browser_model_kwargs"]),
|
| 63 |
)
|
| 64 |
# agent configuration
|
| 65 |
config = AgentConfig(
|
models.py
CHANGED
|
@@ -1,38 +1,37 @@
|
|
| 1 |
from enum import Enum
|
| 2 |
import os
|
| 3 |
-
from typing import
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
)
|
| 12 |
-
from langchain_community.llms.ollama import Ollama
|
| 13 |
-
from langchain_ollama import ChatOllama
|
| 14 |
-
from langchain_community.embeddings import OllamaEmbeddings
|
| 15 |
-
from langchain_anthropic import ChatAnthropic
|
| 16 |
-
from langchain_groq import ChatGroq
|
| 17 |
-
from langchain_huggingface import (
|
| 18 |
-
HuggingFaceEmbeddings,
|
| 19 |
-
ChatHuggingFace,
|
| 20 |
-
HuggingFaceEndpoint,
|
| 21 |
-
)
|
| 22 |
-
from langchain_google_genai import (
|
| 23 |
-
ChatGoogleGenerativeAI,
|
| 24 |
-
HarmBlockThreshold,
|
| 25 |
-
HarmCategory,
|
| 26 |
-
embeddings as google_embeddings,
|
| 27 |
-
)
|
| 28 |
-
from langchain_mistralai import ChatMistralAI
|
| 29 |
|
| 30 |
-
|
| 31 |
-
from python.helpers import dotenv
|
| 32 |
from python.helpers.dotenv import load_dotenv
|
| 33 |
from python.helpers.rate_limiter import RateLimiter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
-
# environment variables
|
| 36 |
load_dotenv()
|
| 37 |
|
| 38 |
|
|
@@ -52,40 +51,71 @@ class ModelProvider(Enum):
|
|
| 52 |
MISTRALAI = "Mistral AI"
|
| 53 |
OLLAMA = "Ollama"
|
| 54 |
OPENAI = "OpenAI"
|
| 55 |
-
|
| 56 |
OPENROUTER = "OpenRouter"
|
| 57 |
SAMBANOVA = "Sambanova"
|
| 58 |
OTHER = "Other"
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
rate_limiters: dict[str, RateLimiter] = {}
|
| 62 |
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
return (
|
| 67 |
dotenv.get_dotenv_value(f"API_KEY_{service.upper()}")
|
| 68 |
or dotenv.get_dotenv_value(f"{service.upper()}_API_KEY")
|
| 69 |
-
or dotenv.get_dotenv_value(
|
| 70 |
-
f"{service.upper()}_API_TOKEN"
|
| 71 |
-
) # Added for CHUTES_API_TOKEN
|
| 72 |
or "None"
|
| 73 |
)
|
| 74 |
|
| 75 |
|
| 76 |
-
def get_model(type: ModelType, provider: ModelProvider, name: str, **kwargs):
|
| 77 |
-
fnc_name = f"get_{provider.name.lower()}_{type.name.lower()}" # function name of model getter
|
| 78 |
-
model = globals()[fnc_name](name, **kwargs) # call function by name
|
| 79 |
-
return model
|
| 80 |
-
|
| 81 |
-
|
| 82 |
def get_rate_limiter(
|
| 83 |
provider: ModelProvider, name: str, requests: int, input: int, output: int
|
| 84 |
) -> RateLimiter:
|
| 85 |
-
# get or create
|
| 86 |
key = f"{provider.name}\\{name}"
|
| 87 |
rate_limiters[key] = limiter = rate_limiters.get(key, RateLimiter(seconds=60))
|
| 88 |
-
# always update
|
| 89 |
limiter.limits["requests"] = requests or 0
|
| 90 |
limiter.limits["input"] = input or 0
|
| 91 |
limiter.limits["output"] = output or 0
|
|
@@ -102,332 +132,385 @@ def parse_chunk(chunk: Any):
|
|
| 102 |
return content
|
| 103 |
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
num_ctx=8192,
|
| 117 |
-
**kwargs,
|
| 118 |
-
):
|
| 119 |
-
if not base_url:
|
| 120 |
-
base_url = get_ollama_base_url()
|
| 121 |
-
return ChatOllama(
|
| 122 |
-
model=model_name,
|
| 123 |
-
base_url=base_url,
|
| 124 |
-
num_ctx=num_ctx,
|
| 125 |
-
**kwargs,
|
| 126 |
)
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
base_url=None,
|
| 132 |
-
num_ctx=8192,
|
| 133 |
-
**kwargs,
|
| 134 |
-
):
|
| 135 |
-
if not base_url:
|
| 136 |
-
base_url = get_ollama_base_url()
|
| 137 |
-
return OllamaEmbeddings(
|
| 138 |
-
model=model_name, base_url=base_url, num_ctx=num_ctx, **kwargs
|
| 139 |
)
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
model_name: str
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
)
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
#
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
)
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
def
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
)
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
if not api_key:
|
| 283 |
-
api_key = get_api_key("google")
|
| 284 |
-
return google_embeddings.GoogleGenerativeAIEmbeddings(model=model_name, google_api_key=api_key, **kwargs) # type: ignore
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
# Mistral models
|
| 288 |
-
def get_mistralai_chat(
|
| 289 |
-
model_name: str,
|
| 290 |
-
api_key=None,
|
| 291 |
-
**kwargs,
|
| 292 |
-
):
|
| 293 |
-
if not api_key:
|
| 294 |
-
api_key = get_api_key("mistral")
|
| 295 |
-
return ChatMistralAI(model=model_name, api_key=api_key, **kwargs) # type: ignore
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
# Groq models
|
| 299 |
-
def get_groq_chat(
|
| 300 |
-
model_name: str,
|
| 301 |
-
api_key=None,
|
| 302 |
-
**kwargs,
|
| 303 |
-
):
|
| 304 |
-
if not api_key:
|
| 305 |
-
api_key = get_api_key("groq")
|
| 306 |
-
return ChatGroq(model_name=model_name, api_key=api_key, **kwargs) # type: ignore
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
# DeepSeek models
|
| 310 |
-
def get_deepseek_chat(
|
| 311 |
-
model_name: str,
|
| 312 |
-
api_key=None,
|
| 313 |
-
base_url=None,
|
| 314 |
-
**kwargs,
|
| 315 |
-
):
|
| 316 |
-
if not api_key:
|
| 317 |
-
api_key = get_api_key("deepseek")
|
| 318 |
-
if not base_url:
|
| 319 |
-
base_url = (
|
| 320 |
-
dotenv.get_dotenv_value("DEEPSEEK_BASE_URL") or "https://api.deepseek.com"
|
| 321 |
)
|
| 322 |
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
)
|
| 340 |
-
|
| 341 |
-
api_key=api_key, # type: ignore
|
| 342 |
-
model=model_name,
|
| 343 |
-
base_url=base_url,
|
| 344 |
-
stream_usage=True,
|
| 345 |
-
model_kwargs={
|
| 346 |
-
"extra_headers": {
|
| 347 |
-
"HTTP-Referer": "https://agent-zero.ai",
|
| 348 |
-
"X-Title": "Agent Zero",
|
| 349 |
-
}
|
| 350 |
-
},
|
| 351 |
-
**kwargs,
|
| 352 |
-
)
|
| 353 |
|
| 354 |
|
| 355 |
-
def
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
**kwargs,
|
| 360 |
):
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
):
|
| 379 |
-
if not api_key:
|
| 380 |
-
api_key = get_api_key("sambanova")
|
| 381 |
-
if not base_url:
|
| 382 |
-
base_url = (
|
| 383 |
-
dotenv.get_dotenv_value("SAMBANOVA_BASE_URL")
|
| 384 |
-
or "https://fast-api.snova.ai/v1"
|
| 385 |
-
)
|
| 386 |
-
return ChatOpenAI(api_key=api_key, model=model_name, base_url=base_url, max_tokens=max_tokens, **kwargs) # type: ignore
|
| 387 |
|
| 388 |
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
if not base_url:
|
| 399 |
-
base_url = (
|
| 400 |
-
dotenv.get_dotenv_value("SAMBANOVA_BASE_URL")
|
| 401 |
-
or "https://fast-api.snova.ai/v1"
|
| 402 |
-
)
|
| 403 |
-
return OpenAIEmbeddings(model=model_name, api_key=api_key, base_url=base_url, **kwargs) # type: ignore
|
| 404 |
|
| 405 |
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
**kwargs
|
| 412 |
-
|
| 413 |
-
return ChatOpenAI(api_key=api_key, model=model_name, base_url=base_url, **kwargs) # type: ignore
|
| 414 |
|
| 415 |
|
| 416 |
-
def
|
| 417 |
-
return
|
| 418 |
|
| 419 |
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
model_name: str,
|
| 423 |
-
api_key=None,
|
| 424 |
-
base_url=None,
|
| 425 |
-
**kwargs,
|
| 426 |
-
):
|
| 427 |
-
if not api_key:
|
| 428 |
-
api_key = get_api_key("chutes")
|
| 429 |
-
if not base_url:
|
| 430 |
-
base_url = (
|
| 431 |
-
dotenv.get_dotenv_value("CHUTES_BASE_URL") or "https://llm.chutes.ai/v1"
|
| 432 |
-
)
|
| 433 |
-
return ChatOpenAI(api_key=api_key, model=model_name, base_url=base_url, **kwargs) # type: ignore
|
|
|
|
| 1 |
from enum import Enum
|
| 2 |
import os
|
| 3 |
+
from typing import (
|
| 4 |
+
Any,
|
| 5 |
+
Awaitable,
|
| 6 |
+
Callable,
|
| 7 |
+
List,
|
| 8 |
+
Optional,
|
| 9 |
+
Iterator,
|
| 10 |
+
AsyncIterator,
|
| 11 |
+
Tuple,
|
| 12 |
+
TypedDict,
|
| 13 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
from litellm import completion, acompletion, embedding
|
| 16 |
+
from python.helpers import dotenv
|
| 17 |
from python.helpers.dotenv import load_dotenv
|
| 18 |
from python.helpers.rate_limiter import RateLimiter
|
| 19 |
+
from python.helpers.tokens import approximate_tokens
|
| 20 |
+
|
| 21 |
+
from langchain_core.language_models.chat_models import SimpleChatModel
|
| 22 |
+
from langchain_core.outputs.chat_generation import ChatGenerationChunk
|
| 23 |
+
from langchain_core.callbacks.manager import (
|
| 24 |
+
CallbackManagerForLLMRun,
|
| 25 |
+
AsyncCallbackManagerForLLMRun,
|
| 26 |
+
)
|
| 27 |
+
from langchain_core.messages import (
|
| 28 |
+
BaseMessage,
|
| 29 |
+
AIMessageChunk,
|
| 30 |
+
HumanMessage,
|
| 31 |
+
SystemMessage,
|
| 32 |
+
)
|
| 33 |
+
from langchain.embeddings.base import Embeddings
|
| 34 |
|
|
|
|
| 35 |
load_dotenv()
|
| 36 |
|
| 37 |
|
|
|
|
| 51 |
MISTRALAI = "Mistral AI"
|
| 52 |
OLLAMA = "Ollama"
|
| 53 |
OPENAI = "OpenAI"
|
| 54 |
+
AZURE = "OpenAI Azure"
|
| 55 |
OPENROUTER = "OpenRouter"
|
| 56 |
SAMBANOVA = "Sambanova"
|
| 57 |
OTHER = "Other"
|
| 58 |
|
| 59 |
|
| 60 |
+
class ChatChunk(TypedDict):
|
| 61 |
+
"""Simplified response chunk for chat models."""
|
| 62 |
+
|
| 63 |
+
response_delta: str
|
| 64 |
+
reasoning_delta: str
|
| 65 |
+
|
| 66 |
+
|
| 67 |
rate_limiters: dict[str, RateLimiter] = {}
|
| 68 |
|
| 69 |
|
| 70 |
+
def configure_litellm_environment():
|
| 71 |
+
env_mappings = {
|
| 72 |
+
"API_KEY_OPENAI": "OPENAI_API_KEY",
|
| 73 |
+
"API_KEY_ANTHROPIC": "ANTHROPIC_API_KEY",
|
| 74 |
+
"API_KEY_GROQ": "GROQ_API_KEY",
|
| 75 |
+
"API_KEY_GOOGLE": "GOOGLE_API_KEY",
|
| 76 |
+
"API_KEY_MISTRAL": "MISTRAL_API_KEY",
|
| 77 |
+
"API_KEY_OLLAMA": "OLLAMA_API_KEY",
|
| 78 |
+
"API_KEY_HUGGINGFACE": "HUGGINGFACE_API_KEY",
|
| 79 |
+
"API_KEY_OPENAI_AZURE": "AZURE_API_KEY",
|
| 80 |
+
"API_KEY_DEEPSEEK": "DEEPSEEK_API_KEY",
|
| 81 |
+
"API_KEY_SAMBANOVA": "SAMBANOVA_API_KEY",
|
| 82 |
+
}
|
| 83 |
+
base_url_mappings = {
|
| 84 |
+
"OPENAI_BASE_URL": "OPENAI_API_BASE",
|
| 85 |
+
"ANTHROPIC_BASE_URL": "ANTHROPIC_API_BASE",
|
| 86 |
+
"GROQ_BASE_URL": "GROQ_API_BASE",
|
| 87 |
+
"GOOGLE_BASE_URL": "GOOGLE_API_BASE",
|
| 88 |
+
"MISTRAL_BASE_URL": "MISTRAL_API_BASE",
|
| 89 |
+
"OLLAMA_BASE_URL": "OLLAMA_API_BASE",
|
| 90 |
+
"HUGGINGFACE_BASE_URL": "HUGGINGFACE_API_BASE",
|
| 91 |
+
"AZURE_BASE_URL": "AZURE_API_BASE",
|
| 92 |
+
"DEEPSEEK_BASE_URL": "DEEPSEEK_API_BASE",
|
| 93 |
+
"SAMBANOVA_BASE_URL": "SAMBANOVA_API_BASE",
|
| 94 |
+
}
|
| 95 |
+
for a0, llm in env_mappings.items():
|
| 96 |
+
val = dotenv.get_dotenv_value(a0)
|
| 97 |
+
if val and not os.getenv(llm):
|
| 98 |
+
os.environ[llm] = val
|
| 99 |
+
for a0_base, llm_base in base_url_mappings.items():
|
| 100 |
+
val = dotenv.get_dotenv_value(a0_base)
|
| 101 |
+
if val and not os.getenv(llm_base):
|
| 102 |
+
os.environ[llm_base] = val
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_api_key(service: str) -> str:
|
| 106 |
return (
|
| 107 |
dotenv.get_dotenv_value(f"API_KEY_{service.upper()}")
|
| 108 |
or dotenv.get_dotenv_value(f"{service.upper()}_API_KEY")
|
| 109 |
+
or dotenv.get_dotenv_value(f"{service.upper()}_API_TOKEN")
|
|
|
|
|
|
|
| 110 |
or "None"
|
| 111 |
)
|
| 112 |
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
def get_rate_limiter(
|
| 115 |
provider: ModelProvider, name: str, requests: int, input: int, output: int
|
| 116 |
) -> RateLimiter:
|
|
|
|
| 117 |
key = f"{provider.name}\\{name}"
|
| 118 |
rate_limiters[key] = limiter = rate_limiters.get(key, RateLimiter(seconds=60))
|
|
|
|
| 119 |
limiter.limits["requests"] = requests or 0
|
| 120 |
limiter.limits["input"] = input or 0
|
| 121 |
limiter.limits["output"] = output or 0
|
|
|
|
| 132 |
return content
|
| 133 |
|
| 134 |
|
| 135 |
+
def _parse_chunk(chunk: Any) -> ChatChunk:
|
| 136 |
+
delta = chunk["choices"][0].get("delta", {})
|
| 137 |
+
message = chunk["choices"][0].get("model_extra", {}).get("message", {})
|
| 138 |
+
response_delta = (
|
| 139 |
+
delta.get("content", "")
|
| 140 |
+
if isinstance(delta, dict)
|
| 141 |
+
else getattr(delta, "content", "")
|
| 142 |
+
) or (
|
| 143 |
+
message.get("content", "")
|
| 144 |
+
if isinstance(message, dict)
|
| 145 |
+
else getattr(message, "content", "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
)
|
| 147 |
+
reasoning_delta = (
|
| 148 |
+
delta.get("reasoning_content", "")
|
| 149 |
+
if isinstance(delta, dict)
|
| 150 |
+
else getattr(delta, "reasoning_content", "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
)
|
| 152 |
+
return ChatChunk(reasoning_delta=reasoning_delta, response_delta=response_delta)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class LiteLLMChatWrapper(SimpleChatModel):
    """LangChain-compatible chat model that delegates to LiteLLM.

    The model is addressed as "<provider>/<model>", which LiteLLM uses to
    route the request to the correct backend. Extra constructor kwargs are
    forwarded to every completion call.
    """

    model_name: str
    provider: str
    kwargs: dict = {}

    def __init__(self, model: str, provider: str, **kwargs: Any):
        model_value = f"{provider}/{model}"
        super().__init__(model_name=model_value, provider=provider, kwargs=kwargs)  # type: ignore

    @property
    def _llm_type(self) -> str:
        return "litellm-chat"

    def _convert_messages(self, messages: List[BaseMessage]) -> List[dict]:
        """Convert LangChain message objects to LiteLLM/OpenAI-style dicts.

        Maps roles, converts assistant tool calls (arguments serialized to
        JSON strings) and carries tool_call_id for tool-result messages.
        """
        import json

        # Map LangChain message types to LiteLLM roles
        role_mapping = {
            "human": "user",
            "ai": "assistant",
            "system": "system",
            "tool": "tool",
        }
        result = []
        for m in messages:
            role = role_mapping.get(m.type, m.type)
            message_dict = {"role": role, "content": m.content}

            # Handle tool calls for AI messages
            tool_calls = getattr(m, "tool_calls", None)
            if tool_calls:
                new_tool_calls = []
                for tool_call in tool_calls:
                    # LiteLLM expects arguments as a JSON string
                    args = tool_call["args"]
                    args_str = json.dumps(args) if isinstance(args, dict) else str(args)
                    new_tool_calls.append(
                        {
                            "id": tool_call.get("id", ""),
                            "type": "function",
                            "function": {
                                "name": tool_call["name"],
                                "arguments": args_str,
                            },
                        }
                    )
                message_dict["tool_calls"] = new_tool_calls

            # Handle tool call ID for ToolMessage
            tool_call_id = getattr(m, "tool_call_id", None)
            if tool_call_id:
                message_dict["tool_call_id"] = tool_call_id

            result.append(message_dict)
        return result

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Synchronous non-streaming completion; returns the response text."""
        msgs = self._convert_messages(messages)
        resp = completion(
            model=self.model_name, messages=msgs, stop=stop, **{**self.kwargs, **kwargs}
        )
        parsed = _parse_chunk(resp)
        return parsed["response_delta"]

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Synchronous streaming completion yielding response-text chunks."""
        msgs = self._convert_messages(messages)
        for chunk in completion(
            model=self.model_name,
            messages=msgs,
            stream=True,
            stop=stop,
            **{**self.kwargs, **kwargs},
        ):
            parsed = _parse_chunk(chunk)
            # Only yield chunks with non-empty content
            if parsed["response_delta"]:
                yield ChatGenerationChunk(
                    message=AIMessageChunk(content=parsed["response_delta"])
                )

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async streaming completion yielding response-text chunks."""
        msgs = self._convert_messages(messages)
        response = await acompletion(
            model=self.model_name,
            messages=msgs,
            stream=True,
            stop=stop,
            **{**self.kwargs, **kwargs},
        )
        async for chunk in response:  # type: ignore
            parsed = _parse_chunk(chunk)
            # Only yield chunks with non-empty content
            if parsed["response_delta"]:
                yield ChatGenerationChunk(
                    message=AIMessageChunk(content=parsed["response_delta"])
                )

    async def unified_call(
        self,
        system_message="",
        user_message="",
        messages: List[BaseMessage] | None = None,
        response_callback: Callable[[str, str], Awaitable[None]] | None = None,
        reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
        tokens_callback: Callable[[str, int], Awaitable[None]] | None = None,
        **kwargs: Any,
    ) -> Tuple[str, str]:
        """Stream a completion, invoking callbacks per delta.

        Returns (response_text, reasoning_text).

        BUGFIX: the original signature used a mutable default
        ``messages=[]`` and mutated it in place, so system/user messages
        leaked between calls. The list is now copied per call.
        """
        # construct messages on a per-call copy
        msgs = list(messages) if messages else []
        if system_message:
            msgs.insert(0, SystemMessage(content=system_message))
        if user_message:
            msgs.append(HumanMessage(content=user_message))

        # convert to litellm format
        msgs_conv = self._convert_messages(msgs)

        # call model
        _completion = await acompletion(
            model=self.model_name,
            messages=msgs_conv,
            stream=True,
            **{**self.kwargs, **kwargs},
        )

        # results
        reasoning = ""
        response = ""

        # iterate over chunks
        async for chunk in _completion:  # type: ignore
            parsed = _parse_chunk(chunk)
            # collect reasoning delta and call callbacks
            if parsed["reasoning_delta"]:
                reasoning += parsed["reasoning_delta"]
                if reasoning_callback:
                    await reasoning_callback(parsed["reasoning_delta"], reasoning)
                if tokens_callback:
                    await tokens_callback(
                        parsed["reasoning_delta"],
                        approximate_tokens(parsed["reasoning_delta"]),
                    )
            # collect response delta and call callbacks
            if parsed["response_delta"]:
                response += parsed["response_delta"]
                if response_callback:
                    await response_callback(parsed["response_delta"], response)
                if tokens_callback:
                    await tokens_callback(
                        parsed["response_delta"],
                        approximate_tokens(parsed["response_delta"]),
                    )

        # return complete results
        return response, reasoning
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class BrowserCompatibleChatWrapper(LiteLLMChatWrapper):
    """Chat wrapper used by the browser agent.

    Serves as an interception point where messages can be filtered or
    sanitized before reaching the LLM; currently a pure pass-through.
    """

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Placeholder: message filtering logic can be inserted here later.
        return super()._call(messages, stop, run_manager, **kwargs)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        # Placeholder: message filtering logic can be inserted here later.
        async for chunk in super()._astream(messages, stop, run_manager, **kwargs):
            yield chunk
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class LiteLLMEmbeddingWrapper(Embeddings):
    """Embeddings backend routed through LiteLLM's embedding API."""

    model_name: str
    kwargs: dict = {}

    def __init__(self, model: str, provider: str, **kwargs: Any):
        # OpenAI models are passed bare; other providers need a prefix.
        self.model_name = model if provider == "openai" else f"{provider}/{model}"
        self.kwargs = kwargs

    @staticmethod
    def _vector(item: Any) -> List[float]:
        # response items can be dicts or objects depending on the backend
        return item.get("embedding") if isinstance(item, dict) else item.embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents; one vector per input text."""
        resp = embedding(model=self.model_name, input=texts, **self.kwargs)
        return [self._vector(item) for item in resp.data]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        resp = embedding(model=self.model_name, input=[text], **self.kwargs)
        return self._vector(resp.data[0])
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class LocalSentenceTransformerWrapper(Embeddings):
    """Local wrapper for sentence-transformers models to avoid HuggingFace API calls."""

    def __init__(self, model_name: str, **kwargs: Any):
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise ImportError(
                "sentence-transformers library is required for local embeddings. Install with: pip install sentence-transformers"
            )

        # The library expects the bare model name without the hub namespace.
        prefix = "sentence-transformers/"
        if model_name.startswith(prefix):
            model_name = model_name[len(prefix):]

        self.model = SentenceTransformer(model_name, **kwargs)
        self.model_name = model_name

    @staticmethod
    def _as_list(vec: Any) -> Any:
        # encode() may return numpy arrays; normalize to plain lists
        return vec.tolist() if hasattr(vec, "tolist") else vec

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents locally."""
        vectors = self.model.encode(texts, convert_to_tensor=False)
        return self._as_list(vectors)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string locally."""
        vectors = self.model.encode([text], convert_to_tensor=False)
        return self._as_list(vectors[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
|
| 414 |
+
def _get_litellm_chat(
    cls: type = LiteLLMChatWrapper,
    model_name: str = "",
    provider_name: str = "",
    **kwargs: Any,
):
    """Construct a LiteLLM chat wrapper of class *cls*.

    Resolves the API key from kwargs or the dotenv store. When a custom
    base URL is configured for the provider, no api_key is forwarded
    (LiteLLM resolves the endpoint from the environment); otherwise a real
    key (not a "None"/"NA" placeholder) is passed through.
    """
    provider_name = provider_name.lower()

    configure_litellm_environment()
    # api_key is popped here, so kwargs can no longer contain it below.
    api_key = kwargs.pop("api_key", None) or get_api_key(provider_name)

    # litellm will pick up base_url from env. We just need to control the api_key.
    base_url = dotenv.get_dotenv_value(f"{provider_name.upper()}_BASE_URL")

    # Only pass an API key if no base_url is set and the key is not a
    # placeholder. (The original also tried `del kwargs["api_key"]` in the
    # base_url branch, but that was dead code: the key was popped above.)
    if not base_url and api_key and api_key not in ("None", "NA"):
        kwargs["api_key"] = api_key

    # for openrouter add app reference
    if provider_name == "openrouter":
        kwargs["extra_headers"] = {
            "HTTP-Referer": "https://agent-zero.ai",
            "X-Title": "Agent Zero",
        }

    return cls(model=model_name, provider=provider_name, **kwargs)
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def get_litellm_embedding(model_name: str, provider: str, **kwargs: Any):
    """Return an embeddings model for *provider*/*model_name*.

    sentence-transformers models under the huggingface provider are run
    locally; everything else is routed through LiteLLM.
    """
    # Check if this is a local sentence-transformers model
    if provider == "huggingface" and model_name.startswith("sentence-transformers/"):
        # Use local sentence-transformers instead of LiteLLM for local models
        return LocalSentenceTransformerWrapper(model_name=model_name, **kwargs)

    configure_litellm_environment()
    # api_key is popped here, so kwargs can no longer contain it below.
    api_key = kwargs.pop("api_key", None) or get_api_key(provider)

    # litellm will pick up base_url from env. We just need to control the api_key.
    base_url = dotenv.get_dotenv_value(f"{provider.upper()}_BASE_URL")

    # Only pass an API key if no base_url is set and the key is not a
    # placeholder. (The original's `del kwargs["api_key"]` in the base_url
    # branch was dead code: the key was already popped above.)
    if not base_url and api_key and api_key not in ("None", "NA"):
        kwargs["api_key"] = api_key

    return LiteLLMEmbeddingWrapper(model=model_name, provider=provider, **kwargs)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def get_model(type: ModelType, provider: ModelProvider, name: str, **kwargs: Any):
    """Factory dispatching on ModelType to the chat/embedding builders."""
    provider_name = provider.name.lower()
    kwargs = _normalize_chat_kwargs(kwargs)
    if type == ModelType.CHAT:
        return _get_litellm_chat(LiteLLMChatWrapper, name, provider_name, **kwargs)
    if type == ModelType.EMBEDDING:
        return get_litellm_embedding(name, provider_name, **kwargs)
    raise ValueError(f"Unsupported model type: {type}")
|
| 480 |
|
| 481 |
|
| 482 |
+
def get_chat_model(
    provider: ModelProvider, name: str, **kwargs: Any
) -> LiteLLMChatWrapper:
    """Build a standard chat model wrapper for the given provider/name."""
    return _get_litellm_chat(
        LiteLLMChatWrapper,
        name,
        provider.name.lower(),
        **_normalize_chat_kwargs(kwargs),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
|
| 490 |
|
| 491 |
+
def get_browser_model(
    provider: ModelProvider, name: str, **kwargs: Any
) -> BrowserCompatibleChatWrapper:
    """Build a browser-agent-compatible chat model for the given provider/name."""
    return _get_litellm_chat(
        BrowserCompatibleChatWrapper,
        name,
        provider.name.lower(),
        **_normalize_chat_kwargs(kwargs),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
|
| 501 |
|
| 502 |
+
def get_embedding_model(
    provider: ModelProvider, name: str, **kwargs: Any
) -> LiteLLMEmbeddingWrapper | LocalSentenceTransformerWrapper:
    """Build an embedding model (LiteLLM-routed or local) for provider/name."""
    return get_litellm_embedding(
        name, provider.name.lower(), **_normalize_embedding_kwargs(kwargs)
    )
|
|
|
|
| 509 |
|
| 510 |
|
| 511 |
+
def _normalize_chat_kwargs(kwargs: Any) -> Any:
|
| 512 |
+
return kwargs
|
| 513 |
|
| 514 |
|
| 515 |
+
def _normalize_embedding_kwargs(kwargs: Any) -> Any:
|
| 516 |
+
return kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
preload.py
CHANGED
|
@@ -22,7 +22,11 @@ async def preload():
|
|
| 22 |
async def preload_embedding():
|
| 23 |
if set["embed_model_provider"] == models.ModelProvider.HUGGINGFACE.name:
|
| 24 |
try:
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
emb_txt = await emb_mod.aembed_query("test")
|
| 27 |
return emb_txt
|
| 28 |
except Exception as e:
|
|
|
|
| 22 |
async def preload_embedding():
|
| 23 |
if set["embed_model_provider"] == models.ModelProvider.HUGGINGFACE.name:
|
| 24 |
try:
|
| 25 |
+
# Use the new LiteLLM-based model system
|
| 26 |
+
emb_mod = models.get_embedding_model(
|
| 27 |
+
models.ModelProvider.HUGGINGFACE,
|
| 28 |
+
set["embed_model_name"]
|
| 29 |
+
)
|
| 30 |
emb_txt = await emb_mod.aembed_query("test")
|
| 31 |
return emb_txt
|
| 32 |
except Exception as e:
|
prompts/agent0/agent.system.tool.response.md
CHANGED
|
@@ -3,6 +3,7 @@ final answer to user
|
|
| 3 |
ends task processing use only when done or no task active
|
| 4 |
put result in text arg
|
| 5 |
always use markdown formatting headers bold text lists
|
|
|
|
| 6 |
use emojis as icons improve readability
|
| 7 |
prefer using tables
|
| 8 |
focus nice structured output key selling point
|
|
|
|
| 3 |
ends task processing use only when done or no task active
|
| 4 |
put result in text arg
|
| 5 |
always use markdown formatting headers bold text lists
|
| 6 |
+
full message is automatically markdown do not wrap ~~~markdown
|
| 7 |
use emojis as icons improve readability
|
| 8 |
prefer using tables
|
| 9 |
focus nice structured output key selling point
|
prompts/default/agent.system.tool.call_sub.md
CHANGED
|
@@ -1,63 +1,26 @@
|
|
| 1 |
### call_subordinate
|
| 2 |
|
| 3 |
you can use subordinates for subtasks
|
| 4 |
-
subordinates can be
|
| 5 |
-
message field: always describe task details goal overview
|
| 6 |
delegate specific subtasks not entire task
|
| 7 |
reset arg usage:
|
| 8 |
"true": spawn new subordinate
|
| 9 |
-
"false": continue
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
#### if you are superior
|
| 13 |
-
- identify new tasks which your main task's completion depends upon
|
| 14 |
-
- break down your main task into subtasks if possible. If the task can not be split execute it yourself
|
| 15 |
-
- only let subtasks and new depended-upon tasks of your main task be handled by subordinates
|
| 16 |
-
- never forward your entire task to a subordinate to avoid endless delegation loops
|
| 17 |
-
|
| 18 |
-
#### if you are subordinate:
|
| 19 |
-
- superior is {{agent_name}} minus 1
|
| 20 |
-
- execute the task you were assigned
|
| 21 |
-
- delegate further if asked
|
| 22 |
-
- break down tasks and delegate if necessary
|
| 23 |
-
- do not delegate tasks you can accomplish yourself without refining them
|
| 24 |
-
- only subtasks of your current main task are allowed to be delegated. Never delegate your entire task to prevent endless loops.
|
| 25 |
-
|
| 26 |
-
#### Arguments:
|
| 27 |
-
- message (string): always describe task details goal overview important details for new subordinate
|
| 28 |
-
- reset (boolean): true: spawn new subordinate, false: continue current conversation
|
| 29 |
-
- prompt_profile (string): defines specialization, only available prompt profiles below, can omit when reset false
|
| 30 |
-
|
| 31 |
-
##### Prompt Profiles available
|
| 32 |
-
{{prompt_profiles}}
|
| 33 |
-
|
| 34 |
-
#### example usage
|
| 35 |
-
~~~json
|
| 36 |
-
{
|
| 37 |
-
"thoughts": [
|
| 38 |
-
"This task is challenging and requires a data analyst",
|
| 39 |
-
"The research_agent profile supports data analysis",
|
| 40 |
-
],
|
| 41 |
-
"headline": "Delegating coding fix to subordinate agent",
|
| 42 |
-
"tool_name": "call_subordinate",
|
| 43 |
-
"tool_args": {
|
| 44 |
-
"message": "...",
|
| 45 |
-
"reset": "true",
|
| 46 |
-
"prompt_profile": "research_agent",
|
| 47 |
-
}
|
| 48 |
-
}
|
| 49 |
-
~~~
|
| 50 |
|
|
|
|
| 51 |
~~~json
|
| 52 |
{
|
| 53 |
"thoughts": [
|
| 54 |
-
"The
|
| 55 |
-
"I will ask a subordinate to
|
| 56 |
],
|
| 57 |
"tool_name": "call_subordinate",
|
| 58 |
"tool_args": {
|
| 59 |
"message": "...",
|
| 60 |
-
"reset": "
|
| 61 |
}
|
| 62 |
}
|
| 63 |
-
~~~
|
|
|
|
| 1 |
### call_subordinate
|
| 2 |
|
| 3 |
you can use subordinates for subtasks
|
| 4 |
+
subordinates can be scientist coder engineer etc
|
| 5 |
+
message field: always describe role, task details goal overview for new subordinate
|
| 6 |
delegate specific subtasks not entire task
|
| 7 |
reset arg usage:
|
| 8 |
"true": spawn new subordinate
|
| 9 |
+
"false": continue existing subordinate
|
| 10 |
+
if superior, orchestrate
|
| 11 |
+
respond to existing subordinates using call_subordinate tool with reset false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
example usage
|
| 14 |
~~~json
|
| 15 |
{
|
| 16 |
"thoughts": [
|
| 17 |
+
"The result seems to be ok but...",
|
| 18 |
+
"I will ask a coder subordinate to fix...",
|
| 19 |
],
|
| 20 |
"tool_name": "call_subordinate",
|
| 21 |
"tool_args": {
|
| 22 |
"message": "...",
|
| 23 |
+
"reset": "true"
|
| 24 |
}
|
| 25 |
}
|
| 26 |
+
~~~
|
python/extensions/reasoning_stream/.gitkeep
ADDED
|
File without changes
|
python/extensions/reasoning_stream/_10_log_from_stream.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from python.helpers import persist_chat, tokens
|
| 2 |
+
from python.helpers.extension import Extension
|
| 3 |
+
from agent import LoopData
|
| 4 |
+
import asyncio
|
| 5 |
+
from python.helpers.log import LogItem
|
| 6 |
+
from python.helpers import log
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class LogFromStream(Extension):
    """Streams reasoning text into a cached "Reasoning (...)" log item."""

    async def execute(
        self,
        loop_data: LoopData | None = None,
        text: str = "",
        **kwargs,
    ):
        # BUGFIX: the original signature used a mutable default
        # `loop_data=LoopData()` shared across all calls; create a fresh
        # instance per call instead.
        if loop_data is None:
            loop_data = LoopData()

        # thought length indicator, rounded up to the nearest 10
        length = math.ceil(len(text) / 10) * 10
        heading = f"{self.agent.agent_name}: Reasoning ({length})..."

        # create log message once and cache it in loop data temporary params
        if "log_item_generating" not in loop_data.params_temporary:
            loop_data.params_temporary["log_item_generating"] = (
                self.agent.context.log.log(
                    type="agent",
                    heading=heading,
                )
            )

        # update log message
        log_item = loop_data.params_temporary["log_item_generating"]
        log_item.update(heading=heading, reasoning=text)
|
python/extensions/response_stream/_10_log_from_stream.py
CHANGED
|
@@ -4,6 +4,7 @@ from agent import LoopData
|
|
| 4 |
import asyncio
|
| 5 |
from python.helpers.log import LogItem
|
| 6 |
from python.helpers import log
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
class LogFromStream(Extension):
|
|
@@ -13,20 +14,38 @@ class LogFromStream(Extension):
|
|
| 13 |
loop_data: LoopData = LoopData(),
|
| 14 |
text: str = "",
|
| 15 |
parsed: dict = {},
|
| 16 |
-
**kwargs
|
| 17 |
):
|
| 18 |
|
| 19 |
-
heading = f"{self.agent.agent_name}:
|
| 20 |
if "headline" in parsed:
|
| 21 |
heading = f"{self.agent.agent_name}: {parsed['headline']}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# create log message and store it in loop data temporary params
|
| 24 |
if "log_item_generating" not in loop_data.params_temporary:
|
| 25 |
-
loop_data.params_temporary["log_item_generating"] =
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
| 28 |
)
|
| 29 |
|
| 30 |
# update log message
|
| 31 |
log_item = loop_data.params_temporary["log_item_generating"]
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import asyncio
|
| 5 |
from python.helpers.log import LogItem
|
| 6 |
from python.helpers import log
|
| 7 |
+
import math
|
| 8 |
|
| 9 |
|
| 10 |
class LogFromStream(Extension):
|
|
|
|
| 14 |
loop_data: LoopData = LoopData(),
|
| 15 |
text: str = "",
|
| 16 |
parsed: dict = {},
|
| 17 |
+
**kwargs,
|
| 18 |
):
|
| 19 |
|
| 20 |
+
heading = f"{self.agent.agent_name}: Generating..."
|
| 21 |
if "headline" in parsed:
|
| 22 |
heading = f"{self.agent.agent_name}: {parsed['headline']}"
|
| 23 |
+
elif "thoughts" in parsed:
|
| 24 |
+
# thought length indicator
|
| 25 |
+
thoughts = "\n".join(parsed["thoughts"])
|
| 26 |
+
length = math.ceil(len(thoughts) / 10) * 10
|
| 27 |
+
heading = f"{self.agent.agent_name}: Thinking ({length})..."
|
| 28 |
+
|
| 29 |
+
if "tool_name" in parsed:
|
| 30 |
+
heading += f" ({parsed['tool_name']})"
|
| 31 |
|
| 32 |
# create log message and store it in loop data temporary params
|
| 33 |
if "log_item_generating" not in loop_data.params_temporary:
|
| 34 |
+
loop_data.params_temporary["log_item_generating"] = (
|
| 35 |
+
self.agent.context.log.log(
|
| 36 |
+
type="agent",
|
| 37 |
+
heading=heading,
|
| 38 |
+
)
|
| 39 |
)
|
| 40 |
|
| 41 |
# update log message
|
| 42 |
log_item = loop_data.params_temporary["log_item_generating"]
|
| 43 |
+
|
| 44 |
+
# keep reasoning from previous logs in kvps
|
| 45 |
+
kvps = {}
|
| 46 |
+
if log_item.kvps is not None and "reasoning" in log_item.kvps:
|
| 47 |
+
kvps["reasoning"] = log_item.kvps["reasoning"]
|
| 48 |
+
kvps.update(parsed)
|
| 49 |
+
|
| 50 |
+
# update the log item
|
| 51 |
+
log_item.update(heading=heading, content=text, kvps=kvps)
|
python/helpers/document_query.py
CHANGED
|
@@ -42,6 +42,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
| 42 |
|
| 43 |
DEFAULT_SEARCH_THRESHOLD = 0.5
|
| 44 |
|
|
|
|
| 45 |
class DocumentQueryStore:
|
| 46 |
"""
|
| 47 |
FAISS Store for document query results.
|
|
@@ -85,7 +86,7 @@ class DocumentQueryStore:
|
|
| 85 |
Normalized URI
|
| 86 |
"""
|
| 87 |
# Convert to lowercase
|
| 88 |
-
normalized = uri.strip()
|
| 89 |
|
| 90 |
# Parse the URL to get scheme
|
| 91 |
parsed = urlparse(normalized)
|
|
@@ -368,7 +369,9 @@ class DocumentQueryStore:
|
|
| 368 |
|
| 369 |
class DocumentQueryHelper:
|
| 370 |
|
| 371 |
-
def __init__(
|
|
|
|
|
|
|
| 372 |
self.agent = agent
|
| 373 |
self.store = DocumentQueryStore.get(agent)
|
| 374 |
self.progress_callback = progress_callback or (lambda x: None)
|
|
@@ -414,30 +417,34 @@ class DocumentQueryHelper:
|
|
| 414 |
content = f"!!! No content found for document: {document_uri} matching queries: {json.dumps(questions)}"
|
| 415 |
return False, content
|
| 416 |
|
| 417 |
-
self.progress_callback(
|
|
|
|
|
|
|
| 418 |
|
| 419 |
questions_str = "\n".join([f" * {question}" for question in questions])
|
| 420 |
-
content = "\n\n----\n\n".join(
|
|
|
|
|
|
|
| 421 |
|
| 422 |
qa_system_message = self.agent.parse_prompt(
|
| 423 |
"fw.document_query.system_prompt.md"
|
| 424 |
)
|
| 425 |
qa_user_message = f"# Document:\n{content}\n\n# Queries:\n{questions_str}"
|
| 426 |
|
| 427 |
-
ai_response = await self.agent.call_chat_model(
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
]
|
| 433 |
-
)
|
| 434 |
)
|
| 435 |
|
| 436 |
self.progress_callback(f"Q&A process completed")
|
| 437 |
|
| 438 |
return True, str(ai_response)
|
| 439 |
|
| 440 |
-
async def document_get_content(
|
|
|
|
|
|
|
| 441 |
self.progress_callback(f"Fetching document content")
|
| 442 |
url = urlparse(document_uri)
|
| 443 |
scheme = url.scheme or "file"
|
|
@@ -518,7 +525,9 @@ class DocumentQueryHelper:
|
|
| 518 |
)
|
| 519 |
if add_to_db:
|
| 520 |
self.progress_callback(f"Indexing document")
|
| 521 |
-
success, ids = await self.store.add_document(
|
|
|
|
|
|
|
| 522 |
if not success:
|
| 523 |
self.progress_callback(f"Failed to index document")
|
| 524 |
raise ValueError(
|
|
|
|
| 42 |
|
| 43 |
DEFAULT_SEARCH_THRESHOLD = 0.5
|
| 44 |
|
| 45 |
+
|
| 46 |
class DocumentQueryStore:
|
| 47 |
"""
|
| 48 |
FAISS Store for document query results.
|
|
|
|
| 86 |
Normalized URI
|
| 87 |
"""
|
| 88 |
# Convert to lowercase
|
| 89 |
+
normalized = uri.strip() # uri.lower()
|
| 90 |
|
| 91 |
# Parse the URL to get scheme
|
| 92 |
parsed = urlparse(normalized)
|
|
|
|
| 369 |
|
| 370 |
class DocumentQueryHelper:
|
| 371 |
|
| 372 |
+
def __init__(
|
| 373 |
+
self, agent: Agent, progress_callback: Callable[[str], None] | None = None
|
| 374 |
+
):
|
| 375 |
self.agent = agent
|
| 376 |
self.store = DocumentQueryStore.get(agent)
|
| 377 |
self.progress_callback = progress_callback or (lambda x: None)
|
|
|
|
| 417 |
content = f"!!! No content found for document: {document_uri} matching queries: {json.dumps(questions)}"
|
| 418 |
return False, content
|
| 419 |
|
| 420 |
+
self.progress_callback(
|
| 421 |
+
f"Processing {len(questions)} questions in context of {len(selected_chunks)} chunks"
|
| 422 |
+
)
|
| 423 |
|
| 424 |
questions_str = "\n".join([f" * {question}" for question in questions])
|
| 425 |
+
content = "\n\n----\n\n".join(
|
| 426 |
+
[chunk.page_content for chunk in selected_chunks.values()]
|
| 427 |
+
)
|
| 428 |
|
| 429 |
qa_system_message = self.agent.parse_prompt(
|
| 430 |
"fw.document_query.system_prompt.md"
|
| 431 |
)
|
| 432 |
qa_user_message = f"# Document:\n{content}\n\n# Queries:\n{questions_str}"
|
| 433 |
|
| 434 |
+
ai_response, _reasoning = await self.agent.call_chat_model(
|
| 435 |
+
messages=[
|
| 436 |
+
SystemMessage(content=qa_system_message),
|
| 437 |
+
HumanMessage(content=qa_user_message),
|
| 438 |
+
]
|
|
|
|
|
|
|
| 439 |
)
|
| 440 |
|
| 441 |
self.progress_callback(f"Q&A process completed")
|
| 442 |
|
| 443 |
return True, str(ai_response)
|
| 444 |
|
| 445 |
+
async def document_get_content(
|
| 446 |
+
self, document_uri: str, add_to_db: bool = False
|
| 447 |
+
) -> str:
|
| 448 |
self.progress_callback(f"Fetching document content")
|
| 449 |
url = urlparse(document_uri)
|
| 450 |
scheme = url.scheme or "file"
|
|
|
|
| 525 |
)
|
| 526 |
if add_to_db:
|
| 527 |
self.progress_callback(f"Indexing document")
|
| 528 |
+
success, ids = await self.store.add_document(
|
| 529 |
+
document_content, document_uri_norm
|
| 530 |
+
)
|
| 531 |
if not success:
|
| 532 |
self.progress_callback(f"Failed to index document")
|
| 533 |
raise ValueError(
|
python/helpers/history.py
CHANGED
|
@@ -534,10 +534,17 @@ def _merge_outputs(a: MessageContent, b: MessageContent) -> MessageContent:
|
|
| 534 |
if isinstance(a, str) and isinstance(b, str):
|
| 535 |
return a + "\n" + b
|
| 536 |
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
|
| 542 |
return cast(MessageContent, a + b)
|
| 543 |
|
|
|
|
| 534 |
if isinstance(a, str) and isinstance(b, str):
|
| 535 |
return a + "\n" + b
|
| 536 |
|
| 537 |
+
def make_list(obj: MessageContent) -> list[MessageContent]:
|
| 538 |
+
if isinstance(obj, list):
|
| 539 |
+
return obj # type: ignore
|
| 540 |
+
if isinstance(obj, dict):
|
| 541 |
+
return [obj]
|
| 542 |
+
if isinstance(obj, str):
|
| 543 |
+
return [{"type": "text", "text": obj}]
|
| 544 |
+
return [obj]
|
| 545 |
+
|
| 546 |
+
a = make_list(a)
|
| 547 |
+
b = make_list(b)
|
| 548 |
|
| 549 |
return cast(MessageContent, a + b)
|
| 550 |
|
python/helpers/memory.py
CHANGED
|
@@ -117,8 +117,7 @@ class Memory:
|
|
| 117 |
os.makedirs(em_dir, exist_ok=True)
|
| 118 |
store = LocalFileStore(em_dir)
|
| 119 |
|
| 120 |
-
embeddings_model = models.
|
| 121 |
-
models.ModelType.EMBEDDING,
|
| 122 |
model_config.provider,
|
| 123 |
model_config.name,
|
| 124 |
**model_config.kwargs,
|
|
|
|
| 117 |
os.makedirs(em_dir, exist_ok=True)
|
| 118 |
store = LocalFileStore(em_dir)
|
| 119 |
|
| 120 |
+
embeddings_model = models.get_embedding_model(
|
|
|
|
| 121 |
model_config.provider,
|
| 122 |
model_config.name,
|
| 123 |
**model_config.kwargs,
|
python/tools/browser_agent.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import asyncio
|
| 2 |
-
import json
|
| 3 |
import time
|
| 4 |
from typing import Optional
|
| 5 |
from agent import Agent, InterventionException
|
|
@@ -57,6 +56,8 @@ class State:
|
|
| 57 |
screen={"width": 1024, "height": 2048},
|
| 58 |
viewport={"width": 1024, "height": 2048},
|
| 59 |
args=["--headless=new"],
|
|
|
|
|
|
|
| 60 |
)
|
| 61 |
)
|
| 62 |
|
|
@@ -118,25 +119,28 @@ class State:
|
|
| 118 |
)
|
| 119 |
return result
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
provider=self.agent.config.browser_model.provider,
|
| 124 |
name=self.agent.config.browser_model.name,
|
| 125 |
**self.agent.config.browser_model.kwargs,
|
| 126 |
)
|
| 127 |
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
self.iter_no = get_iter_no(self.agent)
|
| 142 |
|
|
|
|
| 1 |
import asyncio
|
|
|
|
| 2 |
import time
|
| 3 |
from typing import Optional
|
| 4 |
from agent import Agent, InterventionException
|
|
|
|
| 56 |
screen={"width": 1024, "height": 2048},
|
| 57 |
viewport={"width": 1024, "height": 2048},
|
| 58 |
args=["--headless=new"],
|
| 59 |
+
# Use a unique user data directory to avoid conflicts
|
| 60 |
+
user_data_dir=str(Path.home() / ".config" / "browseruse" / "profiles" / f"agent_{self.agent.context.id}"),
|
| 61 |
)
|
| 62 |
)
|
| 63 |
|
|
|
|
| 119 |
)
|
| 120 |
return result
|
| 121 |
|
| 122 |
+
|
| 123 |
+
model = models.get_browser_model(
|
| 124 |
provider=self.agent.config.browser_model.provider,
|
| 125 |
name=self.agent.config.browser_model.name,
|
| 126 |
**self.agent.config.browser_model.kwargs,
|
| 127 |
)
|
| 128 |
|
| 129 |
+
try:
|
| 130 |
+
self.use_agent = browser_use.Agent(
|
| 131 |
+
task=task,
|
| 132 |
+
browser_session=self.browser_session,
|
| 133 |
+
llm=model,
|
| 134 |
+
use_vision=self.agent.config.browser_model.vision,
|
| 135 |
+
extend_system_message=self.agent.read_prompt(
|
| 136 |
+
"prompts/browser_agent.system.md"
|
| 137 |
+
),
|
| 138 |
+
controller=controller,
|
| 139 |
+
enable_memory=False, # Disable memory to avoid state conflicts
|
| 140 |
+
# available_file_paths=[],
|
| 141 |
+
)
|
| 142 |
+
except Exception as e:
|
| 143 |
+
raise Exception(f"Browser agent initialization failed. This might be due to model compatibility issues. Error: {e}") from e
|
| 144 |
|
| 145 |
self.iter_no = get_iter_no(self.agent)
|
| 146 |
|
requirements.txt
CHANGED
|
@@ -10,15 +10,7 @@ flask-basicauth==0.2.0
|
|
| 10 |
flaredantic==0.1.4
|
| 11 |
GitPython==3.1.43
|
| 12 |
inputimeout==1.0.4
|
| 13 |
-
langchain-
|
| 14 |
-
langchain-community==0.3.19
|
| 15 |
-
langchain-google-genai==2.1.2
|
| 16 |
-
langchain-groq==0.2.2
|
| 17 |
-
langchain-huggingface==0.1.2
|
| 18 |
-
langchain-mistralai==0.2.4
|
| 19 |
-
langchain-ollama==0.3.0
|
| 20 |
-
langchain-openai==0.3.11
|
| 21 |
-
langchain-unstructured[all-docs]==0.1.6
|
| 22 |
openai-whisper==20240930
|
| 23 |
lxml_html_clean==0.3.1
|
| 24 |
markdown==3.7
|
|
@@ -35,6 +27,8 @@ unstructured[all-docs]==0.16.23
|
|
| 35 |
unstructured-client==0.31.0
|
| 36 |
webcolors==24.6.0
|
| 37 |
nest-asyncio==1.6.0
|
|
|
|
|
|
|
| 38 |
markdownify==1.1.0
|
| 39 |
pymupdf==1.25.3
|
| 40 |
pytesseract==0.3.13
|
|
|
|
| 10 |
flaredantic==0.1.4
|
| 11 |
GitPython==3.1.43
|
| 12 |
inputimeout==1.0.4
|
| 13 |
+
langchain-core==0.3.49
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
openai-whisper==20240930
|
| 15 |
lxml_html_clean==0.3.1
|
| 16 |
markdown==3.7
|
|
|
|
| 27 |
unstructured-client==0.31.0
|
| 28 |
webcolors==24.6.0
|
| 29 |
nest-asyncio==1.6.0
|
| 30 |
+
crontab==1.0.1
|
| 31 |
+
litellm==1.72.4
|
| 32 |
markdownify==1.1.0
|
| 33 |
pymupdf==1.25.3
|
| 34 |
pytesseract==0.3.13
|
test.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from os import sep
|
| 3 |
+
from langchain_core.messages import HumanMessage, SystemMessage
|
| 4 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 5 |
+
import models
|
| 6 |
+
from python.helpers import dotenv
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
async def test():
|
| 10 |
+
|
| 11 |
+
dotenv.load_dotenv()
|
| 12 |
+
|
| 13 |
+
# model_name = "moonshotai/kimi-dev-72b:free"
|
| 14 |
+
# model_name = "qwen/qwq-32b"
|
| 15 |
+
# model_name = "qwen/qwen3-32b"
|
| 16 |
+
# model_name = "anthropic/claude-3.7-sonnet:thinking"
|
| 17 |
+
model_name = "openai/gpt-4.1-nano"
|
| 18 |
+
system = ""
|
| 19 |
+
message = "hello"
|
| 20 |
+
|
| 21 |
+
model = models.get_chat_model(models.ModelProvider.OPENROUTER, model_name)
|
| 22 |
+
|
| 23 |
+
async def response_callback(chunk: str, full: str):
|
| 24 |
+
if chunk == full:
|
| 25 |
+
print("\n")
|
| 26 |
+
print("Response:")
|
| 27 |
+
print(chunk, end="", flush=True)
|
| 28 |
+
|
| 29 |
+
async def reasoning_callback(chunk: str, full: str):
|
| 30 |
+
if chunk == full:
|
| 31 |
+
print("\n")
|
| 32 |
+
print("Reasoning:")
|
| 33 |
+
print(chunk, end="", flush=True)
|
| 34 |
+
|
| 35 |
+
response, reasoning = await model.unified_call(
|
| 36 |
+
system_message=system,
|
| 37 |
+
user_message=message,
|
| 38 |
+
response_callback=response_callback,
|
| 39 |
+
reasoning_callback=reasoning_callback,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
print("\n")
|
| 43 |
+
print("Final:")
|
| 44 |
+
print("Reasoning:", reasoning)
|
| 45 |
+
print("Response:", response)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
async def test2():
|
| 49 |
+
|
| 50 |
+
dotenv.load_dotenv()
|
| 51 |
+
|
| 52 |
+
import initialize
|
| 53 |
+
config = initialize.initialize_agent()
|
| 54 |
+
|
| 55 |
+
model = models.get_browser_model(
|
| 56 |
+
provider=config.browser_model.provider,
|
| 57 |
+
name=config.browser_model.name,
|
| 58 |
+
**config.browser_model.kwargs,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
response, reasoning = await model.unified_call(
|
| 62 |
+
system_message="",
|
| 63 |
+
user_message="hi",
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
print("\n")
|
| 67 |
+
print("Final:")
|
| 68 |
+
print("Reasoning:", reasoning)
|
| 69 |
+
print("Response:", response)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
if __name__ == "__main__":
|
| 73 |
+
# asyncio.run(test())
|
| 74 |
+
asyncio.run(test2())
|
webui/index.css
CHANGED
|
@@ -1572,6 +1572,11 @@ input:checked + .slider:before {
|
|
| 1572 |
display: auto;
|
| 1573 |
}
|
| 1574 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1575 |
.msg-content {
|
| 1576 |
margin-bottom: 0;
|
| 1577 |
}
|
|
|
|
| 1572 |
display: auto;
|
| 1573 |
}
|
| 1574 |
|
| 1575 |
+
.msg-thoughts .kvps-val {
|
| 1576 |
+
max-height: 20em;
|
| 1577 |
+
overflow: auto;
|
| 1578 |
+
}
|
| 1579 |
+
|
| 1580 |
.msg-content {
|
| 1581 |
margin-bottom: 0;
|
| 1582 |
}
|
webui/js/messages.js
CHANGED
|
@@ -537,7 +537,7 @@ function drawKvps(container, kvps, latex) {
|
|
| 537 |
for (let [key, value] of Object.entries(kvps)) {
|
| 538 |
const row = table.insertRow();
|
| 539 |
row.classList.add("kvps-row");
|
| 540 |
-
if (key === "thoughts" || key === "
|
| 541 |
row.classList.add("msg-thoughts");
|
| 542 |
|
| 543 |
const th = row.insertCell();
|
|
@@ -545,6 +545,9 @@ function drawKvps(container, kvps, latex) {
|
|
| 545 |
th.classList.add("kvps-key");
|
| 546 |
|
| 547 |
const td = row.insertCell();
|
|
|
|
|
|
|
|
|
|
| 548 |
|
| 549 |
if (Array.isArray(value)) {
|
| 550 |
for (const item of value) {
|
|
@@ -562,7 +565,7 @@ function drawKvps(container, kvps, latex) {
|
|
| 562 |
imgElement.classList.add("kvps-img");
|
| 563 |
imgElement.src = value.replace("img://", "/image_get?path=");
|
| 564 |
imgElement.alt = "Image Attachment";
|
| 565 |
-
|
| 566 |
|
| 567 |
// Add click handler and cursor change
|
| 568 |
imgElement.style.cursor = "pointer";
|
|
@@ -570,15 +573,14 @@ function drawKvps(container, kvps, latex) {
|
|
| 570 |
openImageModal(imgElement.src, 1000);
|
| 571 |
});
|
| 572 |
|
| 573 |
-
td.appendChild(imgElement);
|
| 574 |
} else {
|
| 575 |
const pre = document.createElement("pre");
|
| 576 |
-
pre.classList.add("kvps-val");
|
| 577 |
// if (row.classList.contains("msg-thoughts")) {
|
| 578 |
const span = document.createElement("span");
|
| 579 |
span.innerHTML = convertHTML(value);
|
| 580 |
pre.appendChild(span);
|
| 581 |
-
|
| 582 |
addCopyButtonToElement(row);
|
| 583 |
|
| 584 |
// Add click handler
|
|
|
|
| 537 |
for (let [key, value] of Object.entries(kvps)) {
|
| 538 |
const row = table.insertRow();
|
| 539 |
row.classList.add("kvps-row");
|
| 540 |
+
if (key === "thoughts" || key === "reasoning") // TODO: find a better way to determine special class assignment
|
| 541 |
row.classList.add("msg-thoughts");
|
| 542 |
|
| 543 |
const th = row.insertCell();
|
|
|
|
| 545 |
th.classList.add("kvps-key");
|
| 546 |
|
| 547 |
const td = row.insertCell();
|
| 548 |
+
const tdiv = document.createElement("div");
|
| 549 |
+
tdiv.classList.add("kvps-val");
|
| 550 |
+
td.appendChild(tdiv);
|
| 551 |
|
| 552 |
if (Array.isArray(value)) {
|
| 553 |
for (const item of value) {
|
|
|
|
| 565 |
imgElement.classList.add("kvps-img");
|
| 566 |
imgElement.src = value.replace("img://", "/image_get?path=");
|
| 567 |
imgElement.alt = "Image Attachment";
|
| 568 |
+
tdiv.appendChild(imgElement);
|
| 569 |
|
| 570 |
// Add click handler and cursor change
|
| 571 |
imgElement.style.cursor = "pointer";
|
|
|
|
| 573 |
openImageModal(imgElement.src, 1000);
|
| 574 |
});
|
| 575 |
|
|
|
|
| 576 |
} else {
|
| 577 |
const pre = document.createElement("pre");
|
| 578 |
+
// pre.classList.add("kvps-val");
|
| 579 |
// if (row.classList.contains("msg-thoughts")) {
|
| 580 |
const span = document.createElement("span");
|
| 581 |
span.innerHTML = convertHTML(value);
|
| 582 |
pre.appendChild(span);
|
| 583 |
+
tdiv.appendChild(pre);
|
| 584 |
addCopyButtonToElement(row);
|
| 585 |
|
| 586 |
// Add click handler
|