Spaces:
Runtime error
Runtime error
| from getpass import getpass | |
| import os | |
| from typing import Literal, cast | |
| from langchain_core.tools import BaseTool | |
| from langchain_core.language_models.chat_models import BaseChatModel | |
| from langchain_core.runnables import Runnable | |
| from langchain_core.messages import BaseMessage | |
| from langchain_core.language_models.base import LanguageModelInput | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_openai import ChatOpenAI | |
| from langchain_deepseek import ChatDeepSeek | |
| from langchain_ollama import ChatOllama | |
| from pydantic import BaseModel, Field, SecretStr | |
| from agent.prompts import get_system_prompt | |
| from agent.state import State | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langgraph.prebuilt import ToolNode | |
| import backoff | |
| import openai | |
| import re | |
| from langchain_core.messages.utils import trim_messages, count_tokens_approximately | |
| from agent.config import API_BASE_URL, MAX_TOKENS, MODEL_NAME, API_KEY_ENV_VAR, MODEL_TEMPERATURE | |
| from dotenv import load_dotenv | |
# Pull settings from a local .env file (if present) into the process
# environment before the API key lookup below.
load_dotenv()

# When the configured API-key variable is absent, ask for it on the terminal
# and store the answer in os.environ so later os.getenv() calls can see it.
if os.environ.get(API_KEY_ENV_VAR) is None:
    print(f"Please set the environment variable {API_KEY_ENV_VAR}.")
    _prompt = f"Enter your {API_KEY_ENV_VAR} (will not be echoed): "
    os.environ[API_KEY_ENV_VAR] = getpass(_prompt)
| ### Helper functions ### | |
def _get_model() -> BaseChatModel:
    """Construct the chat model used by the agent.

    Reads the API key from the environment variable named by
    ``API_KEY_ENV_VAR`` and returns a ``ChatOpenAI`` client pointed at
    ``API_BASE_URL`` for ``MODEL_NAME``.

    Returns:
        A freshly constructed ``ChatOpenAI`` instance (no tools bound yet).
    """
    # Other backends imported by this module (ChatGoogleGenerativeAI,
    # ChatOllama, ChatDeepSeek) are not used here; only ChatOpenAI is built.
    raw_key = os.getenv(API_KEY_ENV_VAR)

    # A missing or empty key is forwarded as None rather than wrapped.
    secret_key = SecretStr(raw_key) if raw_key else None

    # Any falsy configured temperature (None, 0, ...) becomes 0.0.
    temperature = MODEL_TEMPERATURE or 0.0

    # NOTE(review): the intent appears to be "high reasoning effort", but this
    # is passed through `metadata`, which presumably only tags traces and is
    # not sent to the model API. Confirm whether `reasoning_effort` (or
    # equivalent) was meant.
    reasoning_metadata = {
        "reasoning": {
            "effort": "high",
        },
    }

    return ChatOpenAI(
        api_key=secret_key,
        base_url=API_BASE_URL,
        model=MODEL_NAME,
        temperature=temperature,
        metadata=reasoning_metadata,
    )
def _get_tools() -> list[BaseTool]:
    """Return every tool exposed by the project's ``tools`` module.

    Returns:
        Whatever ``tools.get_all_tools()`` produces.
    """
    # Imported at call time rather than at module top.
    # NOTE(review): presumably deferred to avoid a circular import - confirm.
    from tools import get_all_tools

    available_tools = get_all_tools()
    return available_tools
def _bind_model(model: BaseChatModel) -> Runnable[LanguageModelInput, BaseMessage]:
    """Attach the project's tools to ``model``.

    Args:
        model: Chat model to which the tools from ``_get_tools()`` are bound.

    Returns:
        The runnable returned by ``model.bind_tools(...)``.
    """
    toolset = _get_tools()
    tool_aware_model = model.bind_tools(toolset)
    return tool_aware_model
| ### NODES ### | |
| # Call model node | |
def call_model(state: State, config) -> dict[str, list[BaseMessage]]:
    """Run one model turn over the conversation held in ``state``.

    The message history is optionally trimmed to ``MAX_TOKENS``, a system
    prompt is prepended when the first message is not already a system
    message, and a tool-bound model is invoked on the result.

    Args:
        state: Graph state; only ``state["messages"]`` is read.
        config: Accepted to keep the node's call signature; not read here.

    Returns:
        ``{"messages": [response]}`` where ``response`` is the model's reply.
    """
    if MAX_TOKENS:
        # Keep the most recent messages that fit the (approximate) token
        # budget; the kept window starts on a human message and ends on a
        # human or tool message.
        messages = trim_messages(
            state["messages"],
            strategy="last",
            token_counter=count_tokens_approximately,
            allow_partial=True,
            max_tokens=MAX_TOKENS,
            start_on="human",
            end_on=("human", "tool"),
        )
    else:
        messages = state["messages"]

    # Prepend the system prompt when the (possibly trimmed) history is empty
    # or does not already begin with a system message.
    if not messages or messages[0].type != "system":
        system_message: BaseMessage = SystemMessage(content=get_system_prompt())
        messages = [system_message] + list(messages)

    # A new model is built and tools are bound on every call.
    model = _bind_model(_get_model())
    response = model.invoke(messages)
    return {"messages": [response]}
# Tool node: built once at import time from the tools that _get_tools()
# returns.
tool_node = ToolNode(
    tools=_get_tools(),
)