alisamak commited on
Commit
7db7bee
·
verified ·
1 Parent(s): 40bcb16

Update LG_agent.py

Browse files
Files changed (1) hide show
  1. LG_agent.py +51 -8
LG_agent.py CHANGED
@@ -7,6 +7,7 @@ from langchain_openai import ChatOpenAI
7
  from tools import all_tools
8
  import inspect
9
  import os
 
10
 
11
  # 1. Setup once
12
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
@@ -25,6 +26,35 @@ chat_with_tools = chat.bind_tools(all_tools)
25
  class AgentState(TypedDict):
26
  messages: Annotated[list[AnyMessage], add_messages]
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # 3. Assistant node
29
 
30
  def assistant(state: AgentState):
@@ -114,17 +144,30 @@ def build_graph(max_steps: int = 5):
114
 
115
  # 5. BasicAgent class
116
 
117
- class BasicAgent:
 
 
 
 
 
 
 
 
 
 
 
 
118
  def __init__(self, max_steps: int = 5):
119
  self.graph = build_graph(max_steps)
 
 
 
 
 
 
 
 
120
 
121
- def __call__(self, question: str) -> str:
122
- response = self.graph({"messages": [HumanMessage(content=question)]})
123
- if response.get("messages"):
124
- final_message = response["messages"][-1]
125
- return final_message.content if hasattr(final_message, "content") else "No final message."
126
- else:
127
- return "No response."
128
 
129
  if __name__ == "__main__":
130
  agent = BasicAgent()
 
7
  from tools import all_tools
8
  import inspect
9
  import os
10
+ import re
11
 
12
  # 1. Setup once
13
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 
26
  class AgentState(TypedDict):
27
  messages: Annotated[list[AnyMessage], add_messages]
28
 
29
+ def extract_gaia_answer(text: str) -> str:
30
+ """
31
+ Extract only the final answer from the LLM output using common patterns like:
32
+ - 'The answer is: ...'
33
+ - 'Answer: ...'
34
+ - bullet points
35
+ - Or just return raw answer if short and plain
36
+
37
+ Strips everything else to conform with GAIA expectations.
38
+ """
39
+ # Common patterns GAIA outputs follow
40
+ patterns = [
41
+ r"The answer is: (.+)",
42
+ r"Answer: (.+)",
43
+ r"^([a-z0-9,\s]+)$", # raw answer line
44
+ ]
45
+ for pattern in patterns:
46
+ match = re.search(pattern, text.strip(), re.IGNORECASE | re.MULTILINE)
47
+ if match:
48
+ return match.group(1).strip()
49
+
50
+ # Fallback: if it's a single short line, return it
51
+ lines = text.strip().splitlines()
52
+ if len(lines) == 1 and len(lines[0]) < 100:
53
+ return lines[0].strip()
54
+
55
+ # Fallback: return full (stripped) content
56
+ return text.strip()
57
+
58
  # 3. Assistant node
59
 
60
  def assistant(state: AgentState):
 
144
 
145
  # 5. BasicAgent class
146
 
147
+ # class BasicAgent:
148
+ # def __init__(self, max_steps: int = 5):
149
+ # self.graph = build_graph(max_steps)
150
+
151
+ # def __call__(self, question: str) -> str:
152
+ # response = self.graph({"messages": [HumanMessage(content=question)]})
153
+ # if response.get("messages"):
154
+ # final_message = response["messages"][-1]
155
+ # return final_message.content if hasattr(final_message, "content") else "No final message."
156
+ # else:
157
+ # return "No response."
158
+
159
+ def __call__(self, question: str) -> str:
160
  def __init__(self, max_steps: int = 5):
161
  self.graph = build_graph(max_steps)
162
+
163
+ response = self.graph({"messages": [HumanMessage(content=question)]})
164
+ if response.get("messages"):
165
+ final_message = response["messages"][-1]
166
+ raw_content = final_message.content if hasattr(final_message, "content") else "No final message."
167
+ return extract_gaia_answer(raw_content)
168
+ else:
169
+ return "No response."
170
 
 
 
 
 
 
 
 
171
 
172
  if __name__ == "__main__":
173
  agent = BasicAgent()