import os
from typing import Optional
from langchain_together import ChatTogether
from langgraph.graph import StateGraph, START, END, MessagesState
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.prebuilt import ToolNode
from pydantic import Field, SecretStr

# Try to import tools from tools.py
try:
    from .tools import get_tools as _get_tools  # package-style
except Exception:
    try:
        from tools import get_tools as _get_tools  # script-style
    except Exception:
        def _get_tools(): return []  # fallback
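
# A hypothetical sketch of the expected tools.py contract (the tool below is
# illustrative, not from this repo): get_tools() should return a list of
# LangChain tools that ToolNode can execute.
#
#     from langchain_core.tools import tool
#
#     @tool
#     def add(a: int, b: int) -> int:
#         """Add two integers."""
#         return a + b
#
#     def get_tools():
#         return [add]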

try:
    # Optional; used for the OpenRouter-backed fallback when the package is installed
    from langchain_openai import ChatOpenAI
except Exception:  # pragma: no cover - optional dependency resolution
    ChatOpenAI = None  # type: ignore


if ChatOpenAI is not None:
    class ChatOpenRouter(ChatOpenAI):
        """ChatOpenAI pointed at the OpenRouter endpoint."""

        openai_api_key: Optional[SecretStr] = Field(
            alias="api_key",
            # default_factory must be a callable, not the env value itself
            default_factory=lambda: os.getenv("OPENROUTER_API_KEY"),
        )

        @property
        def lc_secrets(self) -> dict[str, str]:
            return {"openai_api_key": "OPENROUTER_API_KEY"}

        def __init__(self,
                     openai_api_key: Optional[str] = None,
                     **kwargs):
            openai_api_key = (
                openai_api_key or os.getenv("OPENROUTER_API_KEY")
            )
            super().__init__(
                base_url="https://openrouter.ai/api/v1",
                openai_api_key=openai_api_key,
                **kwargs
            )
else:  # ChatOpenAI unavailable; subclassing None would break the import
    ChatOpenRouter = None  # type: ignore
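
# Example usage (a minimal sketch): assumes langchain_openai is installed and
# OPENROUTER_API_KEY is set; whether this model id is currently offered on
# OpenRouter is not guaranteed.
#
#     router_llm = ChatOpenRouter(model="openai/gpt-oss-20b:free")
#     print(router_llm.invoke("Say hello").content)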


class _EchoModel:
    """Simple stub model used when no API key / model is configured."""

    def __init__(self, prefix: str = "[stub]"):
        self.prefix = prefix

    def invoke(self, messages):
        last = messages[-1]
        content = getattr(last, "content", str(last))
        # Ensure the contract: always emit FINAL ANSWER:
        return AIMessage(content=f"{self.prefix} FINAL ANSWER: You asked: {content}")


class LangGraphAgent:
    """
    Minimal LangGraph agent template.

    Usage:
        agent = LangGraphAgent()
        answer = agent("What is the capital of France?")
    """

    def __init__(self, *, model: Optional[object] = None, system_prompt: Optional[str] = None):
        # Guide the model to use tools and to output a clear final answer.
        self.system_prompt = system_prompt or (
            "You are a helpful assistant. Use the available tools when they help, "
            "keep answers concise, and end with 'FINAL ANSWER: <your answer>'."
        )

        # Choose an LLM if not provided: Together first, then OpenRouter, then a stub.
        if model is None:
            # model = ChatGoogleGenerativeAI(
            #     model="gemma-3-27b-it",
            # )
            if os.getenv("TOGETHER_API_KEY"):
                model = ChatTogether(
                    model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
                    api_key=os.getenv("TOGETHER_API_KEY"),
                )
            elif ChatOpenAI is not None and os.getenv("OPENROUTER_API_KEY"):
                model = ChatOpenAI(
                    api_key=os.getenv("OPENROUTER_API_KEY"),
                    base_url=os.getenv("OPENROUTER_BASE_URL"),
                    model="openai/gpt-oss-20b:free",
                )
        if model is None:
            model = _EchoModel()
        self.model = model

        # Load tools and bind to the model if supported
        self.tools = _get_tools()
        self.llm = getattr(self.model, "bind_tools",
                           lambda _: self.model)(self.tools)

        # Build a tool-using LangGraph: agent -> (maybe) tools -> agent -> ... -> END
        def call_agent(state: MessagesState):
            msgs = [SystemMessage(content=self.system_prompt)
                    ] + list(state["messages"])
            ai = self.llm.invoke(msgs)
            return {"messages": [ai]}

        def should_call_tools(state: MessagesState):
            # If the last AI message includes tool calls, route to tools; else end.
            last = state["messages"][-1]
            if isinstance(last, AIMessage) and getattr(last, "tool_calls", None):
                print(
                    f"Detected tool calls in last AI message: {last.tool_calls}")
                return "tools"
            return "end"

        builder = StateGraph(MessagesState)
        builder.add_node("agent", call_agent)
        builder.add_node("tools", ToolNode(self.tools))
        builder.add_edge(START, "agent")
        builder.add_edge("tools", "agent")
        builder.add_conditional_edges("agent", should_call_tools, {
                                      "tools": "tools", "end": END})
        self.graph = builder.compile()

    @staticmethod
    def _extract_final_answer(text: str) -> str:
        key = "FINAL ANSWER:"
        idx = text.rfind(key)
        return text[idx + len(key):].strip() if idx != -1 else text.strip()

    def __call__(self, question: str) -> str:
        state = {"messages": [HumanMessage(content=question)]}
        result = self.graph.invoke(state, {"recursion_limit": 10})
        messages = result.get("messages", [])
        # Return only the content after "FINAL ANSWER:"
        for msg in reversed(messages):
            if isinstance(msg, AIMessage):
                return self._extract_final_answer(msg.content)
        return self._extract_final_answer(messages[-1].content) if messages else ""
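

# Minimal smoke test (a sketch, not part of the agent API). Without a
# TOGETHER_API_KEY or OPENROUTER_API_KEY configured, the agent falls back to
# the offline _EchoModel stub, so this runs without credentials.
if __name__ == "__main__":
    agent = LangGraphAgent()
    print(agent("What is the capital of France?"))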