# abtsousa — commit 0242ef6: "Update configuration and enhance tool functionality"
from getpass import getpass
import os
from typing import Literal, cast
from langchain_core.tools import BaseTool
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.runnables import Runnable
from langchain_core.messages import BaseMessage
from langchain_core.language_models.base import LanguageModelInput
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_deepseek import ChatDeepSeek
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field, SecretStr
from agent.prompts import get_system_prompt
from agent.state import State
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.prebuilt import ToolNode
import backoff
import openai
import re
from langchain_core.messages.utils import trim_messages, count_tokens_approximately
from agent.config import API_BASE_URL, MAX_TOKENS, MODEL_NAME, API_KEY_ENV_VAR, MODEL_TEMPERATURE
from dotenv import load_dotenv
load_dotenv()
# Ensure the API key is present before any model construction: if the
# configured env var is missing, prompt for it interactively (input hidden
# via getpass) and store it in the process environment.
if API_KEY_ENV_VAR not in os.environ:
    print(f"Please set the environment variable {API_KEY_ENV_VAR}.")
    os.environ[API_KEY_ENV_VAR] = getpass(f"Enter your {API_KEY_ENV_VAR} (will not be echoed): ")
### Helper functions ###
def _get_model() -> BaseChatModel:
    """Build the chat model used by the agent.

    Reads the API key from the environment variable named by
    ``API_KEY_ENV_VAR`` (populated at import time, prompting if absent) and
    returns an OpenAI-compatible chat model configured from the constants in
    ``agent.config``.

    Returns:
        A ``ChatOpenAI`` instance pointed at ``API_BASE_URL``.
    """
    api_key = os.getenv(API_KEY_ENV_VAR)
    return ChatOpenAI(
        api_key=SecretStr(api_key) if api_key else None,
        base_url=API_BASE_URL,
        model=MODEL_NAME,
        # MODEL_TEMPERATURE may be unset (None/0); default to deterministic output.
        temperature=MODEL_TEMPERATURE if MODEL_TEMPERATURE else 0.0,
        metadata={
            "reasoning": {
                "effort": "high"  # use high reasoning effort
            }
        },
    )
def _get_tools() -> list[BaseTool]:
    """Collect every tool exposed to the model.

    The import is deferred to call time to avoid a circular import between
    this module and ``tools`` at load time.
    """
    from tools import get_all_tools

    all_tools = get_all_tools()
    return all_tools
def _bind_model(model: BaseChatModel) -> Runnable[LanguageModelInput, BaseMessage]:
    """Attach the agent's tool set to *model* so it can emit tool calls."""
    tools = _get_tools()
    return model.bind_tools(tools)
### NODES ###
# Call model node
# Retry on rate-limit / transient server errors, sleeping for the interval
# the API suggests ("try again in Xs" in the error text) or 10s by default.
@backoff.on_exception(
    backoff.runtime,
    (openai.RateLimitError, openai.InternalServerError),
    value=lambda e: float(match.group(1)) if (match := re.search(r'try again in (\d+(?:\.\d+)?)s', str(e))) else 10.0,
    max_tries=200,
)
def call_model(state: State, config) -> dict[str, list[BaseMessage]]:
    """LangGraph node: invoke the tool-bound chat model on the conversation.

    Trims the message history to ``MAX_TOKENS`` when configured, prepends the
    system prompt if the (possibly trimmed) history lacks one, and invokes the
    model.

    Args:
        state: Graph state containing the ``messages`` list.
        config: LangGraph runnable config (currently unused by this node).

    Returns:
        ``{"messages": [response]}`` for the graph's message accumulator.
    """
    if MAX_TOKENS:
        # Keep the most recent messages that fit the token budget; the
        # start/end constraints keep the window aligned on valid turns.
        messages = trim_messages(
            state["messages"],
            strategy="last",
            token_counter=count_tokens_approximately,
            allow_partial=True,
            max_tokens=MAX_TOKENS,
            start_on="human",
            end_on=("human", "tool"),
        )
    else:
        messages = state["messages"]
    # Prepend the system prompt when it is missing (e.g. first turn, or it
    # was dropped by trimming).
    if not messages or messages[0].type != "system":
        system_message: BaseMessage = SystemMessage(content=get_system_prompt())
        messages = [system_message] + list(messages)
    model = _bind_model(_get_model())
    response = model.invoke(messages)
    return {"messages": [response]}
# Tool node: executes any tool calls emitted by the model during a graph step.
tool_node = ToolNode(tools=_get_tools())