File size: 3,838 Bytes
8a46dc1
 
 
8ce1b44
8a46dc1
 
 
 
 
 
60d1fd6
0242ef6
8a46dc1
 
 
 
 
2f66bef
 
 
0242ef6
8a46dc1
0242ef6
603a029
2f66bef
 
 
8a46dc1
 
 
 
 
 
 
 
 
 
 
 
 
 
2f66bef
603a029
0242ef6
 
 
 
 
 
 
 
 
 
2f66bef
 
 
 
2b9cce2
2f66bef
 
 
 
 
8a46dc1
 
2f66bef
 
 
 
 
 
8ce1b44
8a46dc1
 
 
 
 
 
 
 
2f66bef
 
0242ef6
2f66bef
0242ef6
2f66bef
0242ef6
8a46dc1
0242ef6
 
 
 
 
 
 
 
 
 
 
 
 
8a46dc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from getpass import getpass
import os
from typing import Literal, cast
from langchain_core.tools import BaseTool
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.runnables import Runnable
from langchain_core.messages import BaseMessage
from langchain_core.language_models.base import LanguageModelInput
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_deepseek import ChatDeepSeek
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field, SecretStr
from agent.prompts import get_system_prompt
from agent.state import State
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.prebuilt import ToolNode
import backoff
import openai
import re
from langchain_core.messages.utils import trim_messages, count_tokens_approximately

from agent.config import API_BASE_URL, MAX_TOKENS, MODEL_NAME, API_KEY_ENV_VAR, MODEL_TEMPERATURE

from dotenv import load_dotenv

# Pull variables from a local .env file before checking for the API key.
load_dotenv()

# No key in the environment: fall back to an interactive, non-echoed prompt.
if os.environ.get(API_KEY_ENV_VAR) is None:
    print(f"Please set the environment variable {API_KEY_ENV_VAR}.")
    os.environ[API_KEY_ENV_VAR] = getpass(f"Enter your {API_KEY_ENV_VAR} (will not be echoed): ")

### Helper functions ###

def _get_model() -> BaseChatModel:
    """Build the chat model client used by the agent.

    Reads the API key from the environment variable named by
    ``API_KEY_ENV_VAR`` and returns a ``ChatOpenAI`` client pointed at
    ``API_BASE_URL`` with ``MODEL_NAME``.

    Returns:
        A configured ``ChatOpenAI`` instance.
    """
    api_key = os.getenv(API_KEY_ENV_VAR)

    return ChatOpenAI(
        # A missing key is passed through as None rather than wrapping "".
        api_key=SecretStr(api_key) if api_key else None,
        base_url=API_BASE_URL,
        model=MODEL_NAME,
        # Falsy MODEL_TEMPERATURE (None or 0) collapses to 0.0 — deterministic output.
        temperature=MODEL_TEMPERATURE if MODEL_TEMPERATURE else 0.0,
        metadata={
            "reasoning": {
                "effort": "high"  # Use high reasoning effort
            }
        }
    )

def _get_tools() -> list[BaseTool]:
    """Return the full list of tools exposed to the agent."""
    # Imported lazily inside the function — presumably to avoid a circular
    # import with the tools package; confirm against the project layout.
    from tools import get_all_tools as _all_tools
    return _all_tools()

def _bind_model(model: BaseChatModel) -> Runnable[LanguageModelInput, BaseMessage]:
    """Attach the agent's tool set to *model* so it can emit tool calls."""
    tools = _get_tools()
    return model.bind_tools(tools)

### NODES ###
# Call model node
@backoff.on_exception(
    backoff.runtime,
    (openai.RateLimitError, openai.InternalServerError),
    # Honor the server-suggested delay ("try again in 3.5s"); fall back to 10s.
    value=lambda e: float(match.group(1)) if (match := re.search(r'try again in (\d+(?:\.\d+)?)s', str(e))) else 10.0,
    max_tries=200,
)
def call_model(state: State, config) -> dict[str, list[BaseMessage]]:
    """LangGraph node: trim the history, ensure a system prompt, invoke the model.

    Args:
        state: Graph state carrying the running message list under "messages".
        config: Per-invocation runnable config (unused here).

    Returns:
        Partial state update of the form {"messages": [model response]}.
    """
    if MAX_TOKENS:
        # Keep only the most recent messages that fit within the token budget.
        messages = trim_messages(
            state["messages"],
            strategy="last",
            token_counter=count_tokens_approximately,
            allow_partial=True,
            max_tokens=MAX_TOKENS,
            start_on="human",
            end_on=("human", "tool"),
        )
    else:
        messages = state["messages"]

    # Prepend the system prompt when trimming (or a fresh run) left it out.
    if not messages or messages[0].type != "system":
        system_prompt = get_system_prompt()
        system_message: BaseMessage = SystemMessage(content=system_prompt)
        messages = [system_message] + list(messages)

    model = _get_model()
    model = _bind_model(model)
    response = model.invoke(messages)

    return {"messages": [response]}

# Tool node
# Executes any tool calls the model emits, backed by the same tool set that
# call_model binds via _get_tools().
tool_node = ToolNode(tools=_get_tools())