EtienneB committed on
Commit
2cf9218
·
1 Parent(s): 7398226

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +51 -57
agent.py CHANGED
@@ -76,32 +76,63 @@ def build_graph():
76
  # Bind tools to LLM
77
  llm_with_tools = llm.bind_tools(tools)
78
 
79
def clean_answer(text):
    """Extract clean answer from LLM response"""
    if not text:
        return ""

    # Normalize surrounding whitespace first.
    cleaned = text.strip()

    # Regexes for boilerplate the model tends to wrap answers in:
    # leading "The answer is ..." style prefixes and trailing
    # "... is the answer." / final-period suffixes.
    boilerplate = (
        r'^(The answer is:?\s*)',
        r'^(Answer:?\s*)',
        r'^(Final answer:?\s*)',
        r'^(Result:?\s*)',
        r'(\s*is the answer\.?)$',
        r'(\s*\.)$',
    )
    for rx in boilerplate:
        cleaned = re.sub(rx, '', cleaned, flags=re.IGNORECASE)

    # Keep only the first line of a multi-line reply.
    return cleaned.partition('\n')[0].strip()
105
  def assistant(state: MessagesState):
106
  messages_with_system_prompt = [sys_msg] + state["messages"]
107
  llm_response = llm_with_tools.invoke(messages_with_system_prompt)
@@ -127,43 +158,6 @@ def build_graph():
127
  # Compile graph
128
  return builder.compile()
129
 
130
-
131
def is_valid_agent_output(output):
    """
    Checks if the output matches the required format:
    [{"task_id": ..., "submitted_answer": ...}]

    Args:
        output: A JSON string expected to encode a list of dicts.

    Returns:
        True when the string parses to a list whose every item is a dict
        containing both "task_id" and "submitted_answer" keys (an empty
        list is considered valid); False otherwise.
    """
    try:
        parsed = json.loads(output.strip())
    # Narrowed from a bare `except:` (which also swallowed
    # KeyboardInterrupt/SystemExit): json.JSONDecodeError is a ValueError;
    # AttributeError/TypeError cover non-string inputs.
    except (ValueError, TypeError, AttributeError):
        return False

    if not isinstance(parsed, list):
        return False

    # Every item must be a dict carrying both required keys.
    return all(
        isinstance(item, dict)
        and "task_id" in item
        and "submitted_answer" in item
        for item in parsed
    )
149
-
150
-
151
def extract_flat_answer(output):
    """Extract properly formatted answer from output.

    Checks whether *output* is already a JSON list whose first item is a
    dict with "task_id" and "submitted_answer" keys.

    NOTE(review): as written, BOTH paths return `output` unchanged — the
    function is currently an identity pass-through. The validation
    structure is kept so a real reformatting step can be added for the
    "not properly formatted" case.

    Args:
        output: The raw agent output string.

    Returns:
        The output string, unchanged.
    """
    try:
        # Try to parse as JSON first.
        parsed = json.loads(output.strip())
        if isinstance(parsed, list) and len(parsed) > 0:
            first_item = parsed[0]
            if isinstance(first_item, dict) and "task_id" in first_item and "submitted_answer" in first_item:
                return output  # Already properly formatted
    # Narrowed from a bare `except:`; these cover unparseable JSON and
    # non-string inputs without masking KeyboardInterrupt/SystemExit.
    except (ValueError, TypeError, AttributeError):
        pass

    # If not properly formatted, return as-is (fallback).
    return output
165
-
166
-
167
  # test
168
  if __name__ == "__main__":
169
  question = "What is 2 + 2?"
 
76
  # Bind tools to LLM
77
  llm_with_tools = llm.bind_tools(tools)
78
 
79
def clean_answer(answer: object) -> str:
    """
    Clean up the answer to remove common prefixes and formatting
    that models often add but that can cause exact match failures.

    Args:
        answer: The raw answer from the model. Any type is accepted;
            non-strings are converted to a string representation first.

    Returns:
        The cleaned answer as a string.
    """
    # Fix: the annotation previously used the builtin function `any` as a
    # type; `object` correctly expresses "accepts anything, narrow first".
    # Convert non-string types to strings before string-based cleanup.
    if not isinstance(answer, str):
        if isinstance(answer, float):
            # Integral floats (e.g. 12.0) are rendered without the ".0".
            if answer.is_integer():
                return str(int(answer))
            # Magnitudes >= 1000 get a dollar sign and thousands separators.
            # NOTE(review): assumes such float answers are monetary — confirm.
            if abs(answer) >= 1000:
                return f"${answer:,.2f}"
            return str(answer)
        # ints and any other type use their default string form.
        return str(answer)

    # Now we know answer is a string; normalize whitespace.
    answer = answer.strip()

    # Remove common boilerplate prefixes that models prepend.
    # Each prefix is tested independently (case-sensitive), so a chain
    # like "Final answer: The answer is x" is peeled in one pass.
    prefixes_to_remove = [
        "The answer is ",
        "Answer: ",
        "Final answer: ",
        "The result is ",
        "To answer this question: ",
        "Based on the information provided, ",
        "According to the information: ",
    ]
    for prefix in prefixes_to_remove:
        if answer.startswith(prefix):
            answer = answer[len(prefix):].strip()

    # Unwrap a single pair of quotes surrounding the entire answer.
    if (answer.startswith('"') and answer.endswith('"')) or (
        answer.startswith("'") and answer.endswith("'")
    ):
        answer = answer[1:-1].strip()

    return answer
135
+
136
  def assistant(state: MessagesState):
137
  messages_with_system_prompt = [sys_msg] + state["messages"]
138
  llm_response = llm_with_tools.invoke(messages_with_system_prompt)
 
158
  # Compile graph
159
  return builder.compile()
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  # test
162
  if __name__ == "__main__":
163
  question = "What is 2 + 2?"