i-dhilip committed on
Commit
44d380b
·
verified ·
1 Parent(s): 689cba9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -26
app.py CHANGED
@@ -15,7 +15,7 @@ from langchain_community.utilities.arxiv import ArxivAPIWrapper
15
 
16
  from langgraph.graph import StateGraph, END
17
 
18
- from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage, AIMessage
19
  from langchain_openai import ChatOpenAI
20
 
21
  # --- Constants ---
@@ -49,16 +49,31 @@ class AgentState(TypedDict):
49
  next_action: Optional[str] # To decide if we need to call tools or respond
50
 
51
  class LangGraphAgent:
52
- def __init__(self):
53
- print("LangGraphAgent initializing...")
54
  if not OPENROUTER_API_KEY:
55
  raise ValueError("OPENROUTER_API_KEY is not set. Cannot initialize LLM.")
56
 
57
- self.llm = ChatOpenAI(
58
- model="google/gemini-2.0-flash-001",
59
- api_key=OPENROUTER_API_KEY,
60
- base_url="https://openrouter.ai/api/v1"
61
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  self.tools_map = {tool.name: tool for tool in tools}
63
  self.graph = self._build_graph()
64
  print("LangGraphAgent initialized.")
@@ -127,41 +142,48 @@ class LangGraphAgent:
127
  return {"messages": tool_messages}
128
 
129
  def __call__(self, question: str) -> str:
130
- print(f"Agent received question (first 50 chars): {question[:100]}...")
131
- initial_state = {"messages": [HumanMessage(content=question)]}
 
 
 
 
 
 
 
 
132
 
133
- # The GAIA prompt example suggests not including "FINAL ANSWER" and just replying with the answer.
134
- # We need to ensure the LLM is prompted to provide a direct answer after tool use.
135
- # For simplicity in this template, we will take the last AI message content as the answer.
136
- # A more robust solution might involve a specific "final answer" node or prompt engineering.
137
-
138
  final_graph_state = None
139
  try:
140
  for event in self.graph.stream(initial_state, {"recursion_limit": 100}): # Added recursion limit
141
- # print(f"Graph event: {event}") # For debugging stream
142
  if END in event:
143
  final_graph_state = event[END]
144
  break
145
- # Update final_graph_state with the latest state from any node
146
- # This ensures we have the latest messages even if END is not directly reached by llm
147
- # (e.g. if recursion limit is hit)
148
  for key in event:
149
  if key != END:
150
  final_graph_state = event[key]
151
 
152
  if final_graph_state and final_graph_state["messages"]:
153
- # Get the last AI message as the answer
154
  for msg in reversed(final_graph_state["messages"]):
155
  if isinstance(msg, AIMessage) and not msg.tool_calls:
156
  answer = msg.content.strip()
157
- # Ensure no "FINAL ANSWER:" prefix as per GAIA instructions
158
- if answer.upper().startswith("FINAL ANSWER:"):
159
- answer = answer[len("FINAL ANSWER:"):].strip()
 
 
 
 
 
 
 
 
 
 
 
160
  print(f"Agent returning answer: {answer}")
161
  return answer
162
- # Fallback if no suitable AI message is found
163
  print("No suitable AI message found for final answer. Returning last message content.")
164
- # This might be a tool call or an intermediate thought, not ideal.
165
  return str(final_graph_state["messages"][-1].content) if final_graph_state["messages"] else "Error: No messages in final state."
166
  else:
167
  print("Error: Agent did not reach a final state or no messages found.")
@@ -199,7 +221,8 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
199
  submit_url = f"{api_url}/submit"
200
 
201
  try:
202
- agent = LangGraphAgent()
 
203
  except Exception as e:
204
  print(f"Error instantiating agent: {e}")
205
  return f"Error initializing agent: {e}", None
 
15
 
16
  from langgraph.graph import StateGraph, END
17
 
18
+ from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage, AIMessage, SystemMessage
19
  from langchain_openai import ChatOpenAI
20
 
21
  # --- Constants ---
 
49
  next_action: Optional[str] # To decide if we need to call tools or respond
50
 
51
  class LangGraphAgent:
52
+ def __init__(self, llm_choice: str = "gemini"):
53
+ print(f"LangGraphAgent initializing with {llm_choice}...")
54
  if not OPENROUTER_API_KEY:
55
  raise ValueError("OPENROUTER_API_KEY is not set. Cannot initialize LLM.")
56
 
57
+ if llm_choice == "llama":
58
+ self.llm = ChatOpenAI(
59
+ model="meta-llama/llama-3.1-8b-instruct:free",
60
+ api_key=OPENROUTER_API_KEY,
61
+ base_url="https://openrouter.ai/api/v1",
62
+ temperature=0.1, # Llama models can be sensitive to temperature
63
+ # max_tokens=150 # Llama 8B might benefit from a smaller max_token for concise answers
64
+ )
65
+ print("Initialized Llama 3.1 8B Instruct.")
66
+ elif llm_choice == "gemini":
67
+ self.llm = ChatOpenAI(
68
+ model="google/gemini-2.0-flash-001",
69
+ api_key=OPENROUTER_API_KEY,
70
+ base_url="https://openrouter.ai/api/v1",
71
+ temperature=0.1 # Adding temperature for consistency
72
+ )
73
+ print("Initialized Gemini 2.0 Flash.")
74
+ else:
75
+ raise ValueError(f"Unsupported LLM choice: {llm_choice}. Choose 'gemini' or 'llama'.")
76
+
77
  self.tools_map = {tool.name: tool for tool in tools}
78
  self.graph = self._build_graph()
79
  print("LangGraphAgent initialized.")
 
142
  return {"messages": tool_messages}
143
 
144
  def __call__(self, question: str) -> str:
145
+ print(f"Agent received question (first 100 chars): {question[:100]}...")
146
+
147
+ system_prompt = (
148
+ "You are an AI assistant designed to answer questions concisely. "
149
+ "Your goal is to provide only the direct answer to the question, without any additional explanations, conversation, or prefixes like 'FINAL ANSWER:'. "
150
+ "For example, if the question is 'What is the capital of France?', you should respond with 'Paris'. "
151
+ "If the question asks for a list, provide it comma-separated, e.g., 'apple, banana, cherry'. "
152
+ "If the question asks for a number, provide only the number, e.g., '42'."
153
+ )
154
+ initial_state = {"messages": [SystemMessage(content=system_prompt), HumanMessage(content=question)]}
155
 
 
 
 
 
 
156
  final_graph_state = None
157
  try:
158
  for event in self.graph.stream(initial_state, {"recursion_limit": 100}): # Added recursion limit
 
159
  if END in event:
160
  final_graph_state = event[END]
161
  break
 
 
 
162
  for key in event:
163
  if key != END:
164
  final_graph_state = event[key]
165
 
166
  if final_graph_state and final_graph_state["messages"]:
 
167
  for msg in reversed(final_graph_state["messages"]):
168
  if isinstance(msg, AIMessage) and not msg.tool_calls:
169
  answer = msg.content.strip()
170
+ # Remove common prefixes that LLMs might add despite instructions
171
+ prefixes_to_remove = [
172
+ "FINAL ANSWER:", "The answer is", "Here is the answer:",
173
+ "The final answer is", "Answer:", "Solution:"
174
+ ]
175
+ for prefix in prefixes_to_remove:
176
+ if answer.upper().startswith(prefix.upper()):
177
+ answer = answer[len(prefix):].strip()
178
+
179
+ # Remove potential quotation marks if the answer is a single word/phrase
180
+ if len(answer.split()) < 5: # Heuristic for short answers
181
+ if answer.startswith(('"', "'")) and answer.endswith(('"', "'")):
182
+ answer = answer[1:-1]
183
+
184
  print(f"Agent returning answer: {answer}")
185
  return answer
 
186
  print("No suitable AI message found for final answer. Returning last message content.")
 
187
  return str(final_graph_state["messages"][-1].content) if final_graph_state["messages"] else "Error: No messages in final state."
188
  else:
189
  print("Error: Agent did not reach a final state or no messages found.")
 
221
  submit_url = f"{api_url}/submit"
222
 
223
  try:
224
+ # Default to Llama for now, can be made configurable later (e.g., via Gradio input)
225
+ agent = LangGraphAgent(llm_choice="llama")
226
  except Exception as e:
227
  print(f"Error instantiating agent: {e}")
228
  return f"Error initializing agent: {e}", None