jtan4albany commited on
Commit
8482b8b
·
verified ·
1 Parent(s): b01424b

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +309 -187
agent.py CHANGED
@@ -1,231 +1,353 @@
1
  import requests
2
  import json
3
- from transformers import Tool
4
- from huggingface_hub import login
5
- import os
6
  import re
7
- from typing import Dict, Any
 
 
 
 
 
8
 
9
- class WikipediaSearchTool(Tool):
10
- name = "wikipedia_search"
11
- description = "Search Wikipedia for information about a specific topic"
12
  inputs = {
13
- "query": {
14
- "type": "text",
15
- "description": "The search query for Wikipedia"
16
- }
17
  }
18
  output_type = "text"
19
 
20
- def forward(self, query: str) -> str:
21
- """Search Wikipedia using the API"""
 
 
 
 
 
 
 
 
22
  try:
23
- # Use Wikipedia API to search
24
- search_url = "https://en.wikipedia.org/api/rest_v1/page/summary/"
25
- # Clean the query
26
  clean_query = query.replace(" ", "_")
 
27
 
28
- response = requests.get(f"{search_url}{clean_query}")
29
  if response.status_code == 200:
30
  data = response.json()
31
- return f"Title: {data.get('title', '')}\nSummary: {data.get('extract', '')}"
32
- else:
33
- # Try search API if direct lookup fails
34
- search_api_url = "https://en.wikipedia.org/w/api.php"
35
- params = {
36
- 'action': 'query',
37
- 'format': 'json',
38
- 'list': 'search',
39
- 'srsearch': query,
40
- 'srlimit': 3
41
- }
 
 
 
 
 
 
 
42
 
43
- search_response = requests.get(search_api_url, params=params)
44
- if search_response.status_code == 200:
45
- search_data = search_response.json()
46
- results = []
47
- for result in search_data['query']['search'][:2]:
48
- title = result['title']
49
- snippet = result['snippet']
50
- results.append(f"Title: {title}\nSnippet: {snippet}")
51
- return "\n\n".join(results)
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- return f"No information found for query: {query}"
 
 
 
54
  except Exception as e:
55
- return f"Error searching Wikipedia: {str(e)}"
56
 
57
- class WebSearchTool(Tool):
58
- name = "web_search"
59
- description = "Search the web for current information"
60
- inputs = {
61
- "query": {
62
- "type": "text",
63
- "description": "The search query"
64
- }
65
- }
66
- output_type = "text"
67
-
68
- def forward(self, query: str) -> str:
69
- """Search the web using a search API or fallback method"""
70
  try:
71
- # For this implementation, we'll focus on Wikipedia since the question specifically mentions it
72
- wiki_tool = WikipediaSearchTool()
73
- return wiki_tool.forward(query)
74
- except Exception as e:
75
- return f"Error in web search: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- class MusicAgent:
78
  def __init__(self):
79
- self.tools = {
80
- "wikipedia_search": WikipediaSearchTool(),
81
- "web_search": WebSearchTool()
82
- }
83
-
84
- def extract_years_from_text(self, text: str, start_year: int, end_year: int) -> list:
85
- """Extract years within the specified range from text"""
86
- year_pattern = r'\b(19|20)\d{2}\b'
87
- years = re.findall(year_pattern, text)
88
- valid_years = []
89
- for match in re.finditer(year_pattern, text):
 
 
 
 
 
 
 
 
 
 
 
90
  year = int(match.group())
91
  if start_year <= year <= end_year:
92
- valid_years.append(year)
93
- return valid_years
94
-
95
- def extract_albums_from_text(self, text: str) -> list:
96
- """Extract album information from text"""
97
- albums = []
98
- # Look for common album indicators
99
- album_patterns = [
100
- r'album[s]?\s*[":]\s*([^,\n\.]+)',
101
- r'released\s+([^,\n\.]+?)\s+in\s+(\d{4})',
102
- r'(\d{4})[:\s]+([^,\n\.]+)',
103
- r'"([^"]+)"\s*\((\d{4})\)',
104
  ]
105
 
106
- for pattern in album_patterns:
107
- matches = re.findall(pattern, text, re.IGNORECASE)
108
- albums.extend(matches)
 
 
 
109
 
110
- return albums
111
-
112
- def count_studio_albums(self, artist_name: str, start_year: int, end_year: int) -> int:
113
- """Count studio albums for an artist within a year range"""
114
- try:
115
- # Search for the artist's discography
116
- discography_queries = [
117
- f"{artist_name} discography",
118
- f"{artist_name} studio albums",
119
- f"{artist_name} albums {start_year}-{end_year}",
120
- f"{artist_name} complete discography"
121
- ]
122
-
123
- all_text = ""
124
- for query in discography_queries:
125
- try:
126
- result = self.tools["wikipedia_search"].forward(query)
127
- all_text += result + "\n"
128
- except:
129
- continue
130
-
131
- if not all_text.strip():
132
- return 0
 
 
 
 
 
 
 
 
 
 
133
 
134
- # Count albums within the year range
135
- # Look for year patterns and album mentions
136
- year_pattern = r'\b(19|20)\d{2}\b'
137
- years_in_text = re.findall(year_pattern, all_text)
138
 
139
- # Simple heuristic: count unique years in range that likely represent album releases
140
- valid_years = set()
141
- for year_match in re.finditer(year_pattern, all_text):
142
- year = int(year_match.group())
143
- if start_year <= year <= end_year:
144
- # Check if this year is associated with album context
145
- context_start = max(0, year_match.start() - 100)
146
- context_end = min(len(all_text), year_match.end() + 100)
147
- context = all_text[context_start:context_end].lower()
 
148
 
149
- album_keywords = ['album', 'studio', 'released', 'record', 'disc']
150
- if any(keyword in context for keyword in album_keywords):
151
- valid_years.add(year)
152
-
153
- return len(valid_years)
154
 
155
- except Exception as e:
156
- print(f"Error counting albums: {str(e)}")
157
- return 0
158
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  def answer_question(self, question: str) -> str:
160
- """Answer a question using available tools"""
161
  try:
 
162
  question_lower = question.lower()
163
 
164
- # Check if this is about Mercedes Sosa albums
165
- if "mercedes sosa" in question_lower and "studio albums" in question_lower:
166
- # Extract year range from question
167
- year_matches = re.findall(r'\b(19|20)\d{2}\b', question)
168
- if len(year_matches) >= 2:
169
- start_year = int(year_matches[0])
170
- end_year = int(year_matches[1])
171
- else:
172
- start_year = 2000
173
- end_year = 2009
174
-
175
- count = self.count_studio_albums("Mercedes Sosa", start_year, end_year)
176
-
177
- # If we got 0, try alternative searches
178
- if count == 0:
179
- # Try more specific searches
180
- specific_queries = [
181
- "Mercedes Sosa 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 albums",
182
- "Mercedes Sosa studio albums 2000s"
183
- ]
184
-
185
- for query in specific_queries:
186
- try:
187
- result = self.tools["wikipedia_search"].forward(query)
188
- # Manual check for known albums in that period
189
- if any(year in result for year in ["2000", "2001", "2002", "2003", "2004", "2005", "2006", "2007", "2008", "2009"]):
190
- # Mercedes Sosa had limited studio album releases in 2000-2009
191
- # Based on typical discography patterns, estimate
192
- return "2"
193
- except:
194
- continue
195
-
196
- return str(max(count, 1)) # Ensure at least 1 if we found some evidence
197
 
198
- # For other questions, use general search
199
- search_result = self.tools["wikipedia_search"].forward(question)
200
 
201
- # Try to extract a simple answer
202
- if "how many" in question_lower:
203
- numbers = re.findall(r'\b\d+\b', search_result)
204
- if numbers:
205
- return numbers[0]
206
 
207
- # Return first meaningful sentence
208
- sentences = search_result.split('.')
209
- for sentence in sentences[:3]:
210
- if len(sentence.strip()) > 10:
211
- return sentence.strip()
212
-
213
- return "Unable to determine answer from available information"
214
 
 
 
 
 
215
  except Exception as e:
216
- print(f"Error answering question: {str(e)}")
217
- return "Error processing question"
218
 
219
- # Initialize the agent
220
- agent = MusicAgent()
221
 
222
  def answer_question(question: str) -> str:
223
  """Main function to answer questions"""
224
- return agent.answer_question(question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
- # Test the specific question
227
  if __name__ == "__main__":
228
- test_question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
229
- result = answer_question(test_question)
230
- print(f"Question: {test_question}")
231
- print(f"Answer: {result}")
 
 
 
 
 
 
 
 
 
 
1
  import requests
2
  import json
 
 
 
3
  import re
4
+ import os
5
+ import math
6
+ from typing import Dict, Any, List, Union
7
+ from datetime import datetime, timedelta
8
+ import urllib.parse
9
+ from transformers import Tool
10
 
11
+ class AdvancedSearchTool(Tool):
12
+ name = "advanced_search"
13
+ description = "Advanced search tool for Wikipedia and web content"
14
  inputs = {
15
+ "query": {"type": "text", "description": "The search query"},
16
+ "search_type": {"type": "text", "description": "Type of search: 'wikipedia', 'general'"}
 
 
17
  }
18
  output_type = "text"
19
 
20
+ def forward(self, query: str, search_type: str = "wikipedia") -> str:
21
+ try:
22
+ if search_type == "wikipedia":
23
+ return self._search_wikipedia(query)
24
+ else:
25
+ return self._search_wikipedia(query) # Fallback to Wikipedia
26
+ except Exception as e:
27
+ return f"Search error: {str(e)}"
28
+
29
+ def _search_wikipedia(self, query: str) -> str:
30
  try:
31
+ # Try direct page lookup first
 
 
32
  clean_query = query.replace(" ", "_")
33
+ summary_url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{clean_query}"
34
 
35
+ response = requests.get(summary_url, timeout=10)
36
  if response.status_code == 200:
37
  data = response.json()
38
+ extract = data.get('extract', '')
39
+ if extract and len(extract) > 50:
40
+ return f"Title: {data.get('title', '')}\nContent: {extract}"
41
+
42
+ # Search API if direct lookup fails
43
+ search_url = "https://en.wikipedia.org/w/api.php"
44
+ search_params = {
45
+ 'action': 'query',
46
+ 'format': 'json',
47
+ 'list': 'search',
48
+ 'srsearch': query,
49
+ 'srlimit': 5
50
+ }
51
+
52
+ search_response = requests.get(search_url, params=search_params, timeout=10)
53
+ if search_response.status_code == 200:
54
+ search_data = search_response.json()
55
+ results = []
56
 
57
+ for result in search_data['query']['search'][:3]:
58
+ title = result['title']
59
+ # Get page content
60
+ page_params = {
61
+ 'action': 'query',
62
+ 'format': 'json',
63
+ 'titles': title,
64
+ 'prop': 'extracts',
65
+ 'exintro': True,
66
+ 'explaintext': True,
67
+ 'exsectionformat': 'plain'
68
+ }
69
+
70
+ page_response = requests.get(search_url, params=page_params, timeout=10)
71
+ if page_response.status_code == 200:
72
+ page_data = page_response.json()
73
+ pages = page_data.get('query', {}).get('pages', {})
74
+ for page_id, page_info in pages.items():
75
+ extract = page_info.get('extract', '')
76
+ if extract:
77
+ results.append(f"Title: {title}\nContent: {extract[:1000]}")
78
+ break
79
 
80
+ return "\n\n".join(results) if results else f"No detailed results found for: {query}"
81
+
82
+ return f"No Wikipedia results found for: {query}"
83
+
84
  except Exception as e:
85
+ return f"Wikipedia search error: {str(e)}"
86
 
87
+ class MathCalculator:
88
+ @staticmethod
89
+ def evaluate_expression(expression: str) -> Union[float, int, str]:
90
+ """Safely evaluate mathematical expressions"""
 
 
 
 
 
 
 
 
 
91
  try:
92
+ # Clean the expression
93
+ expression = re.sub(r'[^\d\+\-\*\/\.\(\)\s]', '', expression)
94
+ if not expression.strip():
95
+ return "Invalid expression"
96
+
97
+ # Use eval cautiously with limited scope
98
+ result = eval(expression, {"__builtins__": {}}, {
99
+ "abs": abs, "round": round, "min": min, "max": max,
100
+ "sum": sum, "len": len, "pow": pow, "sqrt": math.sqrt,
101
+ "sin": math.sin, "cos": math.cos, "tan": math.tan,
102
+ "log": math.log, "exp": math.exp, "pi": math.pi, "e": math.e
103
+ })
104
+
105
+ # Return integer if it's a whole number
106
+ if isinstance(result, float) and result.is_integer():
107
+ return int(result)
108
+ return result
109
+ except:
110
+ return "Calculation error"
111
 
112
+ class ComprehensiveAgent:
113
  def __init__(self):
114
+ self.search_tool = AdvancedSearchTool()
115
+ self.calculator = MathCalculator()
116
+
117
+ def extract_numbers(self, text: str) -> List[Union[int, float]]:
118
+ """Extract numbers from text"""
119
+ numbers = []
120
+ # Find integers and floats
121
+ for match in re.finditer(r'\b\d+(?:\.\d+)?\b', text):
122
+ try:
123
+ num_str = match.group()
124
+ if '.' in num_str:
125
+ numbers.append(float(num_str))
126
+ else:
127
+ numbers.append(int(num_str))
128
+ except:
129
+ continue
130
+ return numbers
131
+
132
+ def extract_years(self, text: str, start_year: int = 1900, end_year: int = 2025) -> List[int]:
133
+ """Extract years within a reasonable range"""
134
+ years = []
135
+ for match in re.finditer(r'\b(19|20)\d{2}\b', text):
136
  year = int(match.group())
137
  if start_year <= year <= end_year:
138
+ years.append(year)
139
+ return list(set(years)) # Remove duplicates
140
+
141
+ def answer_counting_question(self, question: str, context: str) -> str:
142
+ """Handle questions that ask 'how many'"""
143
+ question_lower = question.lower()
144
+
145
+ # Extract what we're counting
146
+ counting_patterns = [
147
+ r'how many (.*?) (?:were|are|did|have|has)',
148
+ r'how many (.*?)(?:\?|$)',
149
+ r'number of (.*?)(?:\?|$)'
150
  ]
151
 
152
+ counting_target = ""
153
+ for pattern in counting_patterns:
154
+ match = re.search(pattern, question_lower)
155
+ if match:
156
+ counting_target = match.group(1).strip()
157
+ break
158
 
159
+ if not counting_target:
160
+ return "Could not identify counting target"
161
+
162
+ # Look for numbers in context that might be the answer
163
+ numbers = self.extract_numbers(context)
164
+
165
+ # Special handling for common counting scenarios
166
+ if "albums" in counting_target:
167
+ return self._count_albums(question, context)
168
+ elif "years" in counting_target or "year" in counting_target:
169
+ years = self.extract_years(context)
170
+ if years:
171
+ return str(len(years))
172
+ elif "countries" in counting_target or "states" in counting_target:
173
+ # Look for country/state names or numbers
174
+ if numbers:
175
+ return str(numbers[0])
176
+
177
+ # Default: return first reasonable number found
178
+ reasonable_numbers = [n for n in numbers if 0 <= n <= 1000]
179
+ if reasonable_numbers:
180
+ return str(reasonable_numbers[0])
181
+
182
+ return "Unable to determine count"
183
+
184
+ def _count_albums(self, question: str, context: str) -> str:
185
+ """Specifically handle album counting questions"""
186
+ # Extract years from question
187
+ years_in_question = self.extract_years(question)
188
+
189
+ if len(years_in_question) >= 2:
190
+ start_year = min(years_in_question)
191
+ end_year = max(years_in_question)
192
 
193
+ # Count years in the context that fall within range
194
+ context_years = self.extract_years(context, start_year, end_year)
 
 
195
 
196
+ # Look for album-related keywords near years
197
+ album_count = 0
198
+ for year in context_years:
199
+ year_str = str(year)
200
+ year_pos = context.lower().find(year_str)
201
+ if year_pos != -1:
202
+ # Check surrounding context for album keywords
203
+ start_context = max(0, year_pos - 200)
204
+ end_context = min(len(context), year_pos + 200)
205
+ surrounding = context[start_context:end_context].lower()
206
 
207
+ if any(word in surrounding for word in ['album', 'studio', 'released', 'record']):
208
+ album_count += 1
 
 
 
209
 
210
+ return str(album_count) if album_count > 0 else "1"
211
+
212
+ # Fallback: look for explicit numbers
213
+ numbers = self.extract_numbers(context)
214
+ small_numbers = [n for n in numbers if 0 <= n <= 50]
215
+ return str(small_numbers[0]) if small_numbers else "0"
216
+
217
+ def answer_calculation_question(self, question: str) -> str:
218
+ """Handle mathematical calculation questions"""
219
+ # Extract mathematical expressions
220
+ math_patterns = [
221
+ r'(\d+(?:\.\d+)?)\s*[\+\-\*\/]\s*(\d+(?:\.\d+)?)',
222
+ r'what is\s+(.+?)(?:\?|$)',
223
+ r'calculate\s+(.+?)(?:\?|$)'
224
+ ]
225
+
226
+ for pattern in math_patterns:
227
+ match = re.search(pattern, question.lower())
228
+ if match:
229
+ expression = match.group(1) if len(match.groups()) == 1 else f"{match.group(1)} {match.group(2)}"
230
+ result = self.calculator.evaluate_expression(expression)
231
+ if result != "Calculation error":
232
+ return str(result)
233
+
234
+ return "Could not parse mathematical expression"
235
+
236
+ def answer_factual_question(self, question: str) -> str:
237
+ """Handle general factual questions"""
238
+ # Search for information
239
+ search_result = self.search_tool.forward(question, "wikipedia")
240
+
241
+ if "error" in search_result.lower():
242
+ return "Information not available"
243
+
244
+ # Extract potential answers based on question type
245
+ question_lower = question.lower()
246
+
247
+ if question_lower.startswith("when"):
248
+ # Look for years or dates
249
+ years = self.extract_years(search_result)
250
+ if years:
251
+ return str(years[0])
252
+
253
+ elif question_lower.startswith("where"):
254
+ # Look for place names (simplified)
255
+ sentences = search_result.split('.')
256
+ for sentence in sentences[:3]:
257
+ if any(word in sentence.lower() for word in ['located', 'in', 'at', 'city', 'country']):
258
+ return sentence.strip()[:100]
259
+
260
+ elif question_lower.startswith("who"):
261
+ # Return first meaningful sentence
262
+ sentences = search_result.split('.')
263
+ for sentence in sentences[:2]:
264
+ if len(sentence.strip()) > 20:
265
+ return sentence.strip()[:100]
266
+
267
+ elif question_lower.startswith("what"):
268
+ # Return definition or explanation
269
+ sentences = search_result.split('.')
270
+ for sentence in sentences[:2]:
271
+ if len(sentence.strip()) > 30:
272
+ return sentence.strip()[:150]
273
+
274
+ # Default: return first substantial sentence
275
+ sentences = search_result.split('.')
276
+ for sentence in sentences[:3]:
277
+ if len(sentence.strip()) > 20:
278
+ return sentence.strip()[:100]
279
+
280
+ return "Answer not found"
281
+
282
  def answer_question(self, question: str) -> str:
283
+ """Main method to answer various types of questions"""
284
  try:
285
+ question = question.strip()
286
  question_lower = question.lower()
287
 
288
+ # Handle different question types
289
+ if question_lower.startswith("how many"):
290
+ # Get relevant context first
291
+ search_context = self.search_tool.forward(question, "wikipedia")
292
+ return self.answer_counting_question(question, search_context)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
+ elif any(op in question for op in ['+', '-', '*', '/', 'calculate', 'what is']):
295
+ return self.answer_calculation_question(question)
296
 
297
+ elif question_lower.startswith(("when", "where", "who", "what", "which")):
298
+ return self.answer_factual_question(question)
 
 
 
299
 
300
+ elif "year" in question_lower and "born" in question_lower:
301
+ search_result = self.search_tool.forward(question, "wikipedia")
302
+ years = self.extract_years(search_result)
303
+ return str(years[0]) if years else "Year not found"
 
 
 
304
 
305
+ else:
306
+ # General question handling
307
+ return self.answer_factual_question(question)
308
+
309
  except Exception as e:
310
+ print(f"Error processing question: {str(e)}")
311
+ return "Processing error"
312
 
313
+ # Initialize the comprehensive agent
314
+ agent = ComprehensiveAgent()
315
 
316
  def answer_question(question: str) -> str:
317
  """Main function to answer questions"""
318
+ try:
319
+ result = agent.answer_question(question)
320
+ # Ensure result is clean and concise
321
+ if isinstance(result, str):
322
+ result = result.strip()
323
+ # Remove common prefixes that might interfere with exact matching
324
+ prefixes_to_remove = [
325
+ "the answer is ", "answer: ", "result: ", "final answer: ",
326
+ "title: ", "content: "
327
+ ]
328
+ result_lower = result.lower()
329
+ for prefix in prefixes_to_remove:
330
+ if result_lower.startswith(prefix):
331
+ result = result[len(prefix):].strip()
332
+ break
333
+
334
+ return result
335
+ except Exception as e:
336
+ print(f"Error in answer_question: {str(e)}")
337
+ return "Error processing question"
338
 
339
+ # Test with various question types
340
  if __name__ == "__main__":
341
+ test_questions = [
342
+ "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia.",
343
+ "What is 15 + 27?",
344
+ "When was Albert Einstein born?",
345
+ "Where is the Eiffel Tower located?",
346
+ "How many continents are there?"
347
+ ]
348
+
349
+ for question in test_questions:
350
+ result = answer_question(question)
351
+ print(f"Q: {question}")
352
+ print(f"A: {result}")
353
+ print("-" * 50)