"""Agent model bootstrap: environment loading, API-key prompt, and model factory."""

import os
import re
from getpass import getpass
from typing import Literal, cast

import backoff
import openai
from dotenv import load_dotenv
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.utils import count_tokens_approximately, trim_messages
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_deepseek import ChatDeepSeek
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel, Field, SecretStr

from agent.config import (
    API_BASE_URL,
    API_KEY_ENV_VAR,
    MAX_TOKENS,
    MODEL_NAME,
    MODEL_TEMPERATURE,
)
from agent.prompts import get_system_prompt
from agent.state import State

load_dotenv()

# Prompt interactively for the API key when the environment does not provide it,
# so the rest of the module can assume the variable is set.
if API_KEY_ENV_VAR not in os.environ:
    print(f"Please set the environment variable {API_KEY_ENV_VAR}.")
    os.environ[API_KEY_ENV_VAR] = getpass(
        f"Enter your {API_KEY_ENV_VAR} (will not be echoed): "
    )


### Helper functions ###

def _get_model() -> BaseChatModel:
    """Build the chat model configured via ``agent.config``.

    Reads the API key from the environment (populated above if it was
    missing) and returns an OpenAI-compatible chat model pointed at
    ``API_BASE_URL``. Alternate providers (``ChatGoogleGenerativeAI``,
    ``ChatOllama``, ``ChatDeepSeek``) are imported above and can be
    substituted here without further changes.
    """
    api_key = os.getenv(API_KEY_ENV_VAR)
    return ChatOpenAI(
        api_key=SecretStr(api_key) if api_key else None,
        base_url=API_BASE_URL,
        model=MODEL_NAME,
        # Falls back to 0.0 when MODEL_TEMPERATURE is unset (an explicit 0
        # is falsy too, but that yields the same 0.0 either way).
        temperature=MODEL_TEMPERATURE if MODEL_TEMPERATURE else 0.0,
        metadata={
            "reasoning": {
                "effort": "high"  # Use high reasoning effort
            }
        },
    )
def _get_tools() -> list[BaseTool]:
    """Return the agent's full tool list (imported lazily to avoid import cycles)."""
    from tools import get_all_tools

    return get_all_tools()


def _bind_model(model: BaseChatModel) -> Runnable[LanguageModelInput, BaseMessage]:
    """Bind the agent's tools to *model* so the LLM can emit tool calls."""
    return model.bind_tools(_get_tools())


### NODES ###

def _retry_wait_seconds(exc: Exception) -> float:
    """Parse the server-suggested wait ("try again in Ns") from *exc*; default 10s."""
    match = re.search(r"try again in (\d+(?:\.\d+)?)s", str(exc))
    return float(match.group(1)) if match else 10.0


# Call model node
@backoff.on_exception(
    backoff.runtime,
    (openai.RateLimitError, openai.InternalServerError),
    value=_retry_wait_seconds,
    max_tries=200,
)
def call_model(state: State, config) -> dict[str, list[BaseMessage]]:
    """LangGraph node: trim history, ensure a system prompt, and invoke the model.

    Retries on OpenAI rate-limit / server errors, waiting the number of
    seconds the error message suggests (10s when none is given).

    Args:
        state: Graph state; ``state["messages"]`` is the conversation so far.
        config: LangGraph runnable config (currently unused here).

    Returns:
        A partial state update containing only the model's response message.
    """
    if MAX_TOKENS:
        # Keep only the most recent messages under the token budget; start on a
        # human turn and end on a human/tool turn so the transcript stays valid.
        messages = trim_messages(
            state["messages"],
            strategy="last",
            token_counter=count_tokens_approximately,
            allow_partial=True,
            max_tokens=MAX_TOKENS,
            start_on="human",
            end_on=("human", "tool"),
        )
    else:
        messages = state["messages"]

    # Prepend the system prompt unless one is already present.
    if not messages or messages[0].type != "system":
        system_message: BaseMessage = SystemMessage(content=get_system_prompt())
        messages = [system_message] + list(messages)

    model = _bind_model(_get_model())
    response = model.invoke(messages)
    return {"messages": [response]}


# Tool node
tool_node = ToolNode(tools=_get_tools())