# --- HuggingFace Spaces page residue (not module source) ---
# Spaces: / Runtime error / File size: 3,838 Bytes
# blame hashes: 8a46dc1 8ce1b44 8a46dc1 60d1fd6 0242ef6 8a46dc1 2f66bef 0242ef6
#   8a46dc1 0242ef6 603a029 2f66bef 8a46dc1 2f66bef 603a029 0242ef6 2f66bef
#   2b9cce2 2f66bef 8a46dc1 2f66bef 8ce1b44 (web-scrape gutter; lines 1-119)
# -----------------------------------------------------------
from getpass import getpass
import os
from typing import Literal, cast
from langchain_core.tools import BaseTool
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.runnables import Runnable
from langchain_core.messages import BaseMessage
from langchain_core.language_models.base import LanguageModelInput
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_deepseek import ChatDeepSeek
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field, SecretStr
from agent.prompts import get_system_prompt
from agent.state import State
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.prebuilt import ToolNode
import backoff
import openai
import re
from langchain_core.messages.utils import trim_messages, count_tokens_approximately
from agent.config import API_BASE_URL, MAX_TOKENS, MODEL_NAME, API_KEY_ENV_VAR, MODEL_TEMPERATURE
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment.
load_dotenv()
# Interactive fallback: if the configured API key variable is absent,
# prompt for it once at import time (input is not echoed).
if API_KEY_ENV_VAR not in os.environ:
    print(f"Please set the environment variable {API_KEY_ENV_VAR}.")
    os.environ[API_KEY_ENV_VAR] = getpass(f"Enter your {API_KEY_ENV_VAR} (will not be echoed): ")
### Helper functions ###
def _get_model() -> BaseChatModel:
    """Construct the chat model used by the agent.

    Reads the API key from the environment variable named by
    ``API_KEY_ENV_VAR`` and builds a ``ChatOpenAI`` client pointed at
    ``API_BASE_URL`` with model ``MODEL_NAME``.

    Returns:
        A configured ``ChatOpenAI`` instance (a ``BaseChatModel``).
    """
    api_key = os.getenv(API_KEY_ENV_VAR)
    return ChatOpenAI(
        api_key=SecretStr(api_key) if api_key else None,
        base_url=API_BASE_URL,
        model=MODEL_NAME,
        # A falsy MODEL_TEMPERATURE (None or 0.0) collapses to deterministic 0.0.
        temperature=MODEL_TEMPERATURE or 0.0,
        metadata={
            "reasoning": {
                "effort": "high"  # Use high reasoning effort
            }
        },
    )
def _get_tools() -> list[BaseTool]:
    """Return every tool the agent may call.

    The import is deferred to call time so that importing this module does
    not require the ``tools`` package (and avoids circular imports).
    """
    from tools import get_all_tools

    available_tools = get_all_tools()
    return available_tools
def _bind_model(model: BaseChatModel) -> Runnable[LanguageModelInput, BaseMessage]:
    """Attach the agent's tool set to *model* so it can emit tool calls."""
    agent_tools = _get_tools()
    return model.bind_tools(agent_tools)
### NODES ###
# Call model node
@backoff.on_exception(
    backoff.runtime,
    (openai.RateLimitError, openai.InternalServerError),
    # Honor the provider's suggested back-off when the error message contains
    # "... try again in <seconds>s"; otherwise wait a flat 10 seconds.
    value=lambda e: float(m.group(1)) if (m := re.search(r'try again in (\d+(?:\.\d+)?)s', str(e))) else 10.0,
    max_tries=200,
)
def call_model(state: State, config) -> dict[str, list[BaseMessage]]:
    """LangGraph node: invoke the tool-bound chat model on the conversation.

    Args:
        state: Graph state; ``state["messages"]`` holds the conversation so far.
        config: Per-invocation LangGraph config. Currently unused; kept to
            satisfy the node-signature contract.

    Returns:
        ``{"messages": [response]}`` with the model's reply, to be merged into
        the graph state by its message reducer.
    """
    if MAX_TOKENS:
        # Keep only the most recent messages that fit the token budget.
        # NOTE(review): the system prompt is prepended *after* trimming, so the
        # effective prompt can exceed MAX_TOKENS by the prompt's size — confirm
        # this is acceptable for the target model's context window.
        messages = trim_messages(
            state["messages"],
            strategy="last",
            token_counter=count_tokens_approximately,
            allow_partial=True,
            max_tokens=MAX_TOKENS,
            start_on="human",
            end_on=("human", "tool"),
        )
    else:
        messages = state["messages"]
    # Ensure the conversation starts with the system prompt (trimming above, or
    # an un-seeded state, may have left it out).
    if not messages or messages[0].type != "system":
        system_message: BaseMessage = SystemMessage(content=get_system_prompt())
        messages = [system_message] + list(messages)
    model = _bind_model(_get_model())
    response = model.invoke(messages)
    return {"messages": [response]}
# Tool node
# Module-level ToolNode that executes any tool calls emitted by call_model.
tool_node = ToolNode(tools=_get_tools())
# (end-of-file scrape artifact)