import os
from typing import Optional

from langchain_together import ChatTogether
from langgraph.graph import StateGraph, START, END, MessagesState
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.prebuilt import ToolNode
from pydantic import Field, SecretStr

# Try to import tools from tools.py
try:
    from .tools import get_tools as _get_tools  # package-style import
except Exception:
    try:
        from tools import get_tools as _get_tools  # script-style import
    except Exception:
        def _get_tools():  # fallback: run without tools
            return []

try:
    # Optional, used when an OpenAI-compatible API key is available
    from langchain_openai import ChatOpenAI
except Exception:  # pragma: no cover - optional dependency resolution
    ChatOpenAI = None  # type: ignore

if ChatOpenAI is not None:
    class ChatOpenRouter(ChatOpenAI):
        """ChatOpenAI configured to talk to the OpenRouter API."""

        openai_api_key: Optional[SecretStr] = Field(
            alias="api_key",
            # default_factory must be a callable, not the looked-up value
            default_factory=lambda: os.getenv("OPENROUTER_API_KEY"),
        )

        @property
        def lc_secrets(self) -> dict[str, str]:
            return {"openai_api_key": "OPENROUTER_API_KEY"}

        def __init__(self, openai_api_key: Optional[str] = None, **kwargs):
            openai_api_key = openai_api_key or os.getenv("OPENROUTER_API_KEY")
            super().__init__(
                base_url="https://openrouter.ai/api/v1",
                openai_api_key=openai_api_key,
                **kwargs,
            )
else:  # langchain_openai not installed; keep the module importable
    ChatOpenRouter = None  # type: ignore
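
# Usage sketch (hypothetical model id): ChatOpenRouter behaves like ChatOpenAI
# but routes requests through OpenRouter, reading OPENROUTER_API_KEY by default:
#   llm = ChatOpenRouter(model="meta-llama/llama-3.3-70b-instruct:free")
#   llm.invoke("Hello")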

class _EchoModel:
    """Simple stub model used when no API key / model is configured."""

    def __init__(self, prefix: str = "[stub]"):
        self.prefix = prefix

    def invoke(self, messages):
        last = messages[-1]
        content = getattr(last, "content", str(last))
        # Ensure the contract: always emit FINAL ANSWER:
        return AIMessage(content=f"{self.prefix} FINAL ANSWER: You asked: {content}")

class LangGraphAgent:
    """
    Minimal LangGraph agent template.

    Usage:
        agent = LangGraphAgent()
        answer = agent("What is the capital of France?")
    """

    def __init__(self, *, model: Optional[object] = None,
                 system_prompt: Optional[str] = None):
        # Guide the model to use tools and to output a clear final answer.
        self.system_prompt = system_prompt or "You are a helpful assistant. Keep answers concise."

        # Choose an LLM if not provided, falling through to the first
        # provider whose API key is configured.
        if model is None and os.getenv("TOGETHER_API_KEY"):
            # model = ChatGoogleGenerativeAI(
            #     model="gemma-3-27b-it",
            # )
            model = ChatTogether(
                model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
                api_key=os.getenv("TOGETHER_API_KEY"),
            )
        if model is None and ChatOpenAI is not None and os.getenv("OPENROUTER_API_KEY"):
            model = ChatOpenAI(
                api_key=os.getenv("OPENROUTER_API_KEY"),
                base_url=os.getenv("OPENROUTER_BASE_URL"),
                model="openai/gpt-oss-20b:free",
            )
        if model is None:
            model = _EchoModel()
        self.model = model

        # Load tools and bind them to the model if it supports tool calling.
        self.tools = _get_tools()
        self.llm = getattr(self.model, "bind_tools",
                           lambda _: self.model)(self.tools)
        # Build a tool-using LangGraph: agent -> (maybe) tools -> agent -> ... -> END
        def call_agent(state: MessagesState):
            msgs = [SystemMessage(content=self.system_prompt)] + list(state["messages"])
            ai = self.llm.invoke(msgs)
            return {"messages": [ai]}

        def should_call_tools(state: MessagesState):
            # If the last AI message includes tool calls, route to tools; else end.
            last = state["messages"][-1]
            if isinstance(last, AIMessage) and getattr(last, "tool_calls", None):
                print(f"Detected tool calls in last AI message: {last.tool_calls}")
                return "tools"
            return "end"

        builder = StateGraph(MessagesState)
        builder.add_node("agent", call_agent)
        builder.add_node("tools", ToolNode(self.tools))
        builder.add_edge(START, "agent")
        builder.add_edge("tools", "agent")
        builder.add_conditional_edges(
            "agent", should_call_tools, {"tools": "tools", "end": END})
        self.graph = builder.compile()
    @staticmethod
    def _extract_final_answer(text: str) -> str:
        # Return everything after the last "FINAL ANSWER:" marker, if present.
        key = "FINAL ANSWER:"
        idx = text.rfind(key)
        return text[idx + len(key):].strip() if idx != -1 else text.strip()
    def __call__(self, question: str) -> str:
        state = {"messages": [HumanMessage(content=question)]}
        result = self.graph.invoke(state, {"recursion_limit": 10})
        messages = result.get("messages", [])
        # Return only the content after "FINAL ANSWER:"
        for msg in reversed(messages):
            if isinstance(msg, AIMessage):
                return self._extract_final_answer(msg.content)
        return self._extract_final_answer(messages[-1].content) if messages else ""
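
# Minimal smoke test, assuming TOGETHER_API_KEY or OPENROUTER_API_KEY is set
# in the environment; without either, the _EchoModel stub answers instead.
if __name__ == "__main__":
    agent = LangGraphAgent()
    print(agent("What is the capital of France?"))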