abtsousa commited on
Commit
8a46dc1
·
1 Parent(s): 1f19061

Add AI model integration and system prompt handling

Browse files
Files changed (1) hide show
  1. agent/nodes.py +73 -0
agent/nodes.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from getpass import getpass
2
+ import os
3
+ from typing import Literal, cast
4
+ from langchain_core.tools import StructuredTool
5
+ from langchain_core.language_models.chat_models import BaseChatModel
6
+ from langchain_core.runnables import Runnable
7
+ from langchain_core.messages import BaseMessage
8
+ from langchain_core.language_models.base import LanguageModelInput
9
+ from langchain_google_genai import ChatGoogleGenerativeAI
10
+ from langchain_openai import ChatOpenAI
11
+ from pydantic import BaseModel, Field, SecretStr
12
+ from agent.prompts import get_system_prompt
13
+ from agent.state import State
14
+ from langchain_core.messages import SystemMessage, HumanMessage
15
+ from langgraph.prebuilt import ToolNode
16
+
17
+ API_BASE_URL = "https://api.openrouter.ai/v1"
18
+ MODEL_NAME = "qwen/qwen3-235b-a22b:free"
19
+ API_KEY_ENV_VAR = "OPENROUTER_API_KEY"
20
+ if API_KEY_ENV_VAR not in os.environ:
21
+ print(f"Please set the environment variable {API_KEY_ENV_VAR}.")
22
+ os.environ[API_KEY_ENV_VAR] = getpass(f"Enter your {API_KEY_ENV_VAR} (will not be echoed): ")
23
+
24
+ ### Helper functions ###
25
+
26
+ def _get_model() -> BaseChatModel:
27
+
28
+ # api_key = os.getenv("GOOGLE_API_KEY")
29
+ # return ChatGoogleGenerativeAI(
30
+ # api_key=SecretStr(api_key) if api_key else None,
31
+ # model="gemini-2.5-pro"
32
+ # )
33
+
34
+ api_key = os.getenv(API_KEY_ENV_VAR)
35
+ return ChatOpenAI(
36
+ api_key=SecretStr(api_key) if api_key else None,
37
+ base_url=API_BASE_URL,
38
+ model=MODEL_NAME,
39
+ metadata={
40
+ "reasoning": {
41
+ "effort": "high" # Use high reasoning effort
42
+ }
43
+ }
44
+ )
45
+
46
+ def _get_tools() -> list[StructuredTool]:
47
+ from tools import get_all_tools
48
+ return get_all_tools()
49
+
50
+ def _bind_model(model: BaseChatModel) -> Runnable[LanguageModelInput, BaseMessage]:
51
+ return model.bind_tools(_get_tools())
52
+
53
+ ### NODES ###
54
+ # Call model node
55
+ def call_model(state: State, config) -> dict[str, list[BaseMessage]]:
56
+ messages = state["messages"]
57
+ app_name = config.get('configurable', {}).get("app_name", "OracleBot")
58
+
59
+ # Add system prompt if not already present
60
+ if not messages or messages[0].type != "system":
61
+ # Use dynamic system prompt if sports are mentioned
62
+ system_prompt = get_system_prompt()
63
+ system_message: BaseMessage = SystemMessage(content=system_prompt)
64
+ messages = [system_message] + list(messages)
65
+
66
+ model = _get_model()
67
+ model = _bind_model(model)
68
+ response = model.invoke(messages)
69
+
70
+ return {"messages": [response]}
71
+
72
+ # Tool node
73
+ tool_node = ToolNode(tools=_get_tools())