Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -15,7 +15,7 @@ from langchain_community.utilities.arxiv import ArxivAPIWrapper
|
|
| 15 |
|
| 16 |
from langgraph.graph import StateGraph, END
|
| 17 |
|
| 18 |
-
from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage, AIMessage
|
| 19 |
from langchain_openai import ChatOpenAI
|
| 20 |
|
| 21 |
# --- Constants ---
|
|
@@ -49,16 +49,31 @@ class AgentState(TypedDict):
|
|
| 49 |
next_action: Optional[str] # To decide if we need to call tools or respond
|
| 50 |
|
| 51 |
class LangGraphAgent:
|
| 52 |
-
def __init__(self):
|
| 53 |
-
print("LangGraphAgent initializing...")
|
| 54 |
if not OPENROUTER_API_KEY:
|
| 55 |
raise ValueError("OPENROUTER_API_KEY is not set. Cannot initialize LLM.")
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
self.tools_map = {tool.name: tool for tool in tools}
|
| 63 |
self.graph = self._build_graph()
|
| 64 |
print("LangGraphAgent initialized.")
|
|
@@ -127,41 +142,48 @@ class LangGraphAgent:
|
|
| 127 |
return {"messages": tool_messages}
|
| 128 |
|
| 129 |
def __call__(self, question: str) -> str:
|
| 130 |
-
print(f"Agent received question (first
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
-
# The GAIA prompt example suggests not including "FINAL ANSWER" and just replying with the answer.
|
| 134 |
-
# We need to ensure the LLM is prompted to provide a direct answer after tool use.
|
| 135 |
-
# For simplicity in this template, we will take the last AI message content as the answer.
|
| 136 |
-
# A more robust solution might involve a specific "final answer" node or prompt engineering.
|
| 137 |
-
|
| 138 |
final_graph_state = None
|
| 139 |
try:
|
| 140 |
for event in self.graph.stream(initial_state, {"recursion_limit": 100}): # Added recursion limit
|
| 141 |
-
# print(f"Graph event: {event}") # For debugging stream
|
| 142 |
if END in event:
|
| 143 |
final_graph_state = event[END]
|
| 144 |
break
|
| 145 |
-
# Update final_graph_state with the latest state from any node
|
| 146 |
-
# This ensures we have the latest messages even if END is not directly reached by llm
|
| 147 |
-
# (e.g. if recursion limit is hit)
|
| 148 |
for key in event:
|
| 149 |
if key != END:
|
| 150 |
final_graph_state = event[key]
|
| 151 |
|
| 152 |
if final_graph_state and final_graph_state["messages"]:
|
| 153 |
-
# Get the last AI message as the answer
|
| 154 |
for msg in reversed(final_graph_state["messages"]):
|
| 155 |
if isinstance(msg, AIMessage) and not msg.tool_calls:
|
| 156 |
answer = msg.content.strip()
|
| 157 |
-
#
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
print(f"Agent returning answer: {answer}")
|
| 161 |
return answer
|
| 162 |
-
# Fallback if no suitable AI message is found
|
| 163 |
print("No suitable AI message found for final answer. Returning last message content.")
|
| 164 |
-
# This might be a tool call or an intermediate thought, not ideal.
|
| 165 |
return str(final_graph_state["messages"][-1].content) if final_graph_state["messages"] else "Error: No messages in final state."
|
| 166 |
else:
|
| 167 |
print("Error: Agent did not reach a final state or no messages found.")
|
|
@@ -199,7 +221,8 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
| 199 |
submit_url = f"{api_url}/submit"
|
| 200 |
|
| 201 |
try:
|
| 202 |
-
|
|
|
|
| 203 |
except Exception as e:
|
| 204 |
print(f"Error instantiating agent: {e}")
|
| 205 |
return f"Error initializing agent: {e}", None
|
|
|
|
| 15 |
|
| 16 |
from langgraph.graph import StateGraph, END
|
| 17 |
|
| 18 |
+
from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage, AIMessage, SystemMessage
|
| 19 |
from langchain_openai import ChatOpenAI
|
| 20 |
|
| 21 |
# --- Constants ---
|
|
|
|
| 49 |
next_action: Optional[str] # To decide if we need to call tools or respond
|
| 50 |
|
| 51 |
class LangGraphAgent:
|
| 52 |
+
def __init__(self, llm_choice: str = "gemini"):
|
| 53 |
+
print(f"LangGraphAgent initializing with {llm_choice}...")
|
| 54 |
if not OPENROUTER_API_KEY:
|
| 55 |
raise ValueError("OPENROUTER_API_KEY is not set. Cannot initialize LLM.")
|
| 56 |
|
| 57 |
+
if llm_choice == "llama":
|
| 58 |
+
self.llm = ChatOpenAI(
|
| 59 |
+
model="meta-llama/llama-3.1-8b-instruct:free",
|
| 60 |
+
api_key=OPENROUTER_API_KEY,
|
| 61 |
+
base_url="https://openrouter.ai/api/v1",
|
| 62 |
+
temperature=0.1, # Llama models can be sensitive to temperature
|
| 63 |
+
# max_tokens=150 # Llama 8B might benefit from a smaller max_token for concise answers
|
| 64 |
+
)
|
| 65 |
+
print("Initialized Llama 3.1 8B Instruct.")
|
| 66 |
+
elif llm_choice == "gemini":
|
| 67 |
+
self.llm = ChatOpenAI(
|
| 68 |
+
model="google/gemini-2.0-flash-001",
|
| 69 |
+
api_key=OPENROUTER_API_KEY,
|
| 70 |
+
base_url="https://openrouter.ai/api/v1",
|
| 71 |
+
temperature=0.1 # Adding temperature for consistency
|
| 72 |
+
)
|
| 73 |
+
print("Initialized Gemini 2.0 Flash.")
|
| 74 |
+
else:
|
| 75 |
+
raise ValueError(f"Unsupported LLM choice: {llm_choice}. Choose 'gemini' or 'llama'.")
|
| 76 |
+
|
| 77 |
self.tools_map = {tool.name: tool for tool in tools}
|
| 78 |
self.graph = self._build_graph()
|
| 79 |
print("LangGraphAgent initialized.")
|
|
|
|
| 142 |
return {"messages": tool_messages}
|
| 143 |
|
| 144 |
def __call__(self, question: str) -> str:
|
| 145 |
+
print(f"Agent received question (first 100 chars): {question[:100]}...")
|
| 146 |
+
|
| 147 |
+
system_prompt = (
|
| 148 |
+
"You are an AI assistant designed to answer questions concisely. "
|
| 149 |
+
"Your goal is to provide only the direct answer to the question, without any additional explanations, conversation, or prefixes like 'FINAL ANSWER:'. "
|
| 150 |
+
"For example, if the question is 'What is the capital of France?', you should respond with 'Paris'. "
|
| 151 |
+
"If the question asks for a list, provide it comma-separated, e.g., 'apple, banana, cherry'. "
|
| 152 |
+
"If the question asks for a number, provide only the number, e.g., '42'."
|
| 153 |
+
)
|
| 154 |
+
initial_state = {"messages": [SystemMessage(content=system_prompt), HumanMessage(content=question)]}
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
final_graph_state = None
|
| 157 |
try:
|
| 158 |
for event in self.graph.stream(initial_state, {"recursion_limit": 100}): # Added recursion limit
|
|
|
|
| 159 |
if END in event:
|
| 160 |
final_graph_state = event[END]
|
| 161 |
break
|
|
|
|
|
|
|
|
|
|
| 162 |
for key in event:
|
| 163 |
if key != END:
|
| 164 |
final_graph_state = event[key]
|
| 165 |
|
| 166 |
if final_graph_state and final_graph_state["messages"]:
|
|
|
|
| 167 |
for msg in reversed(final_graph_state["messages"]):
|
| 168 |
if isinstance(msg, AIMessage) and not msg.tool_calls:
|
| 169 |
answer = msg.content.strip()
|
| 170 |
+
# Remove common prefixes that LLMs might add despite instructions
|
| 171 |
+
prefixes_to_remove = [
|
| 172 |
+
"FINAL ANSWER:", "The answer is", "Here is the answer:",
|
| 173 |
+
"The final answer is", "Answer:", "Solution:"
|
| 174 |
+
]
|
| 175 |
+
for prefix in prefixes_to_remove:
|
| 176 |
+
if answer.upper().startswith(prefix.upper()):
|
| 177 |
+
answer = answer[len(prefix):].strip()
|
| 178 |
+
|
| 179 |
+
# Remove potential quotation marks if the answer is a single word/phrase
|
| 180 |
+
if len(answer.split()) < 5: # Heuristic for short answers
|
| 181 |
+
if answer.startswith(('"', "'")) and answer.endswith(('"', "'")):
|
| 182 |
+
answer = answer[1:-1]
|
| 183 |
+
|
| 184 |
print(f"Agent returning answer: {answer}")
|
| 185 |
return answer
|
|
|
|
| 186 |
print("No suitable AI message found for final answer. Returning last message content.")
|
|
|
|
| 187 |
return str(final_graph_state["messages"][-1].content) if final_graph_state["messages"] else "Error: No messages in final state."
|
| 188 |
else:
|
| 189 |
print("Error: Agent did not reach a final state or no messages found.")
|
|
|
|
| 221 |
submit_url = f"{api_url}/submit"
|
| 222 |
|
| 223 |
try:
|
| 224 |
+
# Default to Llama for now, can be made configurable later (e.g., via Gradio input)
|
| 225 |
+
agent = LangGraphAgent(llm_choice="llama")
|
| 226 |
except Exception as e:
|
| 227 |
print(f"Error instantiating agent: {e}")
|
| 228 |
return f"Error initializing agent: {e}", None
|