Kackle commited on
Commit
60f4659
·
verified ·
1 Parent(s): f152db2

multi search

Browse files
Files changed (1) hide show
  1. gemini_agent.py +245 -2
gemini_agent.py CHANGED
@@ -54,6 +54,10 @@ class GeminiAgent:
54
  # Check if question is about music discography or albums
55
  if self._is_discography_question(question):
56
  return await self._handle_discography_question(question)
 
 
 
 
57
 
58
  # Regular text-based question
59
  return await self._handle_text_question(question)
@@ -96,8 +100,30 @@ class GeminiAgent:
96
  # If it has a music term and either an artist term or a date range, it's likely a discography question
97
  return has_music_term and (has_artist_term or has_date_range)
98
 
99
- async def _google_search(self, query: str, num_results: int = 5) -> str:
100
- """Perform a Google search using the Custom Search API"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  if not self.google_search_api_key or not self.google_search_cx:
102
  print("Google Search API key or CX not configured, falling back to DuckDuckGo")
103
  return self.ddg_tool.run(query)
@@ -111,6 +137,14 @@ class GeminiAgent:
111
  'num': num_results
112
  }
113
 
 
 
 
 
 
 
 
 
114
  response = requests.get(url, params=params)
115
  if response.status_code != 200:
116
  print(f"Google Search API error: {response.status_code}")
@@ -127,6 +161,17 @@ class GeminiAgent:
127
  title = item.get('title', 'No title')
128
  snippet = item.get('snippet', 'No description')
129
  link = item.get('link', 'No link')
 
 
 
 
 
 
 
 
 
 
 
130
  formatted_results += f"Title: {title}\nDescription: {snippet}\nURL: {link}\n\n"
131
 
132
  return formatted_results
@@ -210,6 +255,204 @@ If the answer is a person's name, provide ONLY their first name as requested."""
210
 
211
  return answer
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  async def _handle_discography_question(self, question: str) -> str:
214
  """Handle questions about music discography with enhanced search capabilities"""
215
  print(f"Processing discography question: {question[:50]}...")
 
54
  # Check if question is about music discography or albums
55
  if self._is_discography_question(question):
56
  return await self._handle_discography_question(question)
57
+
58
+ # Check if question is about competitions, awards, or recipients
59
+ if self._is_competition_question(question):
60
+ return await self._handle_competition_question(question)
61
 
62
  # Regular text-based question
63
  return await self._handle_text_question(question)
 
100
  # If it has a music term and either an artist term or a date range, it's likely a discography question
101
  return has_music_term and (has_artist_term or has_date_range)
102
 
103
+ def _is_competition_question(self, question: str) -> bool:
104
+ """Determine if a question is about competitions, awards, or recipients"""
105
+ q = question.lower()
106
+ competition_patterns = [
107
+ "competition", "award", "prize", "medal", "recipient", "winner", "laureate",
108
+ "finalist", "champion", "trophy", "recognition", "honor", "honour", "nominee"
109
+ ]
110
+
111
+ # Check for competition-related terms
112
+ has_competition_term = any(pattern in q for pattern in competition_patterns)
113
+
114
+ # Check for specific patterns that indicate complex competition questions
115
+ complex_patterns = [
116
+ "first name", "last name", "nationality", "country", "no longer exists",
117
+ "century", "decade", "after\s+\d{4}", "before\s+\d{4}", "between\s+\d{4}",
118
+ "youngest", "oldest", "only", "ever", "never"
119
+ ]
120
+
121
+ has_complex_pattern = any(re.search(pattern, q) for pattern in complex_patterns)
122
+
123
+ return has_competition_term and has_complex_pattern
124
+
125
+ async def _google_search(self, query: str, num_results: int = 5, exact_terms: str = None, site_restrict: str = None) -> str:
126
+ """Perform a Google search using the Custom Search API with enhanced options"""
127
  if not self.google_search_api_key or not self.google_search_cx:
128
  print("Google Search API key or CX not configured, falling back to DuckDuckGo")
129
  return self.ddg_tool.run(query)
 
137
  'num': num_results
138
  }
139
 
140
+ # Add exact terms if provided
141
+ if exact_terms:
142
+ params['exactTerms'] = exact_terms
143
+
144
+ # Add site restriction if provided
145
+ if site_restrict:
146
+ params['siteSearch'] = site_restrict
147
+
148
  response = requests.get(url, params=params)
149
  if response.status_code != 200:
150
  print(f"Google Search API error: {response.status_code}")
 
161
  title = item.get('title', 'No title')
162
  snippet = item.get('snippet', 'No description')
163
  link = item.get('link', 'No link')
164
+
165
+ # Try to get more content if available
166
+ page_map = item.get('pagemap', {})
167
+ meta_desc = ""
168
+ if 'metatags' in page_map and page_map['metatags']:
169
+ meta_desc = page_map['metatags'][0].get('og:description', '')
170
+
171
+ # Add the meta description if it provides additional information
172
+ if meta_desc and meta_desc not in snippet:
173
+ snippet += " " + meta_desc
174
+
175
  formatted_results += f"Title: {title}\nDescription: {snippet}\nURL: {link}\n\n"
176
 
177
  return formatted_results
 
255
 
256
  return answer
257
 
258
+ async def _multi_search(self, queries: list, num_results: int = 5, include_sites: list = None) -> str:
259
+ """Perform multiple searches and combine the results with enhanced options"""
260
+ combined_results = ""
261
+
262
+ # Define authoritative sites for different domains
263
+ authoritative_sites = {
264
+ "music": ["grammy.org", "billboard.com", "allmusic.com", "musicbrainz.org"],
265
+ "competition": ["wikipedia.org", "britannica.com"],
266
+ "awards": ["nobelprize.org", "pulitzer.org", "oscars.org"],
267
+ "classical": ["classicalmusic.org", "gramophone.co.uk", "medici.tv"]
268
+ }
269
+
270
+ # Process each query
271
+ for i, query in enumerate(queries):
272
+ print(f"Searching for query {i+1}/{len(queries)}: {query[:50]}...")
273
+ try:
274
+ # Standard search
275
+ result = await self._google_search(query, num_results)
276
+ if result:
277
+ combined_results += f"=== Results for query: {query} ===\n{result}\n\n"
278
+
279
+ # If specific sites are provided, search those too
280
+ if include_sites:
281
+ for site in include_sites:
282
+ site_result = await self._google_search(query, num_results=3, site_restrict=site)
283
+ if site_result and "no results" not in site_result.lower():
284
+ combined_results += f"=== Results from {site} for: {query} ===\n{site_result}\n\n"
285
+
286
+ # For competition questions, try some authoritative sites
287
+ if "competition" in query.lower() or "award" in query.lower() or "prize" in query.lower():
288
+ for site in authoritative_sites["competition"] + authoritative_sites["awards"]:
289
+ site_result = await self._google_search(query, num_results=2, site_restrict=site)
290
+ if site_result and "no results" not in site_result.lower():
291
+ combined_results += f"=== Results from {site} for: {query} ===\n{site_result}\n\n"
292
+
293
+ # For classical music questions, try classical music sites
294
+ if "classical" in query.lower() or "conductor" in query.lower() or "orchestra" in query.lower():
295
+ for site in authoritative_sites["classical"]:
296
+ site_result = await self._google_search(query, num_results=2, site_restrict=site)
297
+ if site_result and "no results" not in site_result.lower():
298
+ combined_results += f"=== Results from {site} for: {query} ===\n{site_result}\n\n"
299
+
300
+ # Try exact term matching for key entities
301
+ key_terms = self._extract_key_terms(query)
302
+ if key_terms:
303
+ exact_result = await self._google_search(query, num_results=3, exact_terms=key_terms)
304
+ if exact_result and "no results" not in exact_result.lower():
305
+ combined_results += f"=== Results with exact match for '{key_terms}' ===\n{exact_result}\n\n"
306
+
307
+ except Exception as e:
308
+ print(f"Search failed for query {i+1}: {e}")
309
+
310
+ return combined_results
311
+
312
+ def _extract_key_terms(self, query: str) -> str:
313
+ """Extract key terms from a query for exact matching"""
314
+ # Extract competition names
315
+ competition_match = re.search(r'(\w+\s+Competition|\w+\s+Award|\w+\s+Prize)', query, re.IGNORECASE)
316
+ if competition_match:
317
+ return competition_match.group(1)
318
+
319
+ # Extract dates
320
+ date_match = re.search(r'(\d{4})', query)
321
+ if date_match:
322
+ return date_match.group(1)
323
+
324
+ # Extract countries
325
+ country_patterns = ["Soviet Union", "Yugoslavia", "Czechoslovakia", "East Germany"]
326
+ for country in country_patterns:
327
+ if country.lower() in query.lower():
328
+ return country
329
+
330
+ return ""
331
+
332
+ async def _handle_competition_question(self, question: str) -> str:
333
+ """Handle questions about competitions, awards, and recipients with advanced search"""
334
+ print(f"Processing competition question: {question[:50]}...")
335
+
336
+ # Extract key entities from the question
337
+ competition_name = ""
338
+ time_period = ""
339
+ nationality_info = ""
340
+
341
+ # Try to extract competition name
342
+ competition_patterns = [
343
+ r'(\w+\s+Competition)', # "Malko Competition"
344
+ r'(\w+\s+Award)', # "Nobel Award"
345
+ r'(\w+\s+Prize)' # "Pulitzer Prize"
346
+ ]
347
+
348
+ for pattern in competition_patterns:
349
+ match = re.search(pattern, question, re.IGNORECASE)
350
+ if match:
351
+ competition_name = match.group(1)
352
+ break
353
+
354
+ # Extract time period information
355
+ time_patterns = [
356
+ r'(\d{2}(?:st|nd|rd|th)\s+[Cc]entury)', # "20th Century"
357
+ r'(after\s+\d{4})', # "after 1977"
358
+ r'(before\s+\d{4})', # "before 1990"
359
+ r'(between\s+\d{4}\s+and\s+\d{4})' # "between 1977 and 2000"
360
+ ]
361
+
362
+ for pattern in time_patterns:
363
+ match = re.search(pattern, question, re.IGNORECASE)
364
+ if match:
365
+ time_period = match.group(1)
366
+ break
367
+
368
+ # Extract nationality information
369
+ if "nationality" in question.lower() or "country" in question.lower():
370
+ if "no longer exists" in question.lower():
371
+ nationality_info = "country that no longer exists"
372
+
373
+ # Construct specialized search queries
374
+ search_queries = []
375
+
376
+ # Generic competition queries
377
+ if competition_name:
378
+ base_query = f"{competition_name} winners list"
379
+ search_queries.append(base_query)
380
+
381
+ if time_period:
382
+ search_queries.append(f"{competition_name} winners {time_period}")
383
+
384
+ if nationality_info:
385
+ search_queries.append(f"{competition_name} winners {nationality_info}")
386
+
387
+ # For questions about countries that no longer exist, add general queries
388
+ if "no longer exists" in nationality_info:
389
+ # Add queries for common dissolved countries without hardcoding specific competitions
390
+ dissolved_countries = ["Soviet Union", "Yugoslavia", "Czechoslovakia", "East Germany"]
391
+ for country in dissolved_countries:
392
+ search_queries.append(f"{competition_name} winners from {country}")
393
+
394
+ # Add more specific queries
395
+ if time_period and nationality_info:
396
+ search_queries.append(f"{competition_name} winners {time_period} {nationality_info}")
397
+ else:
398
+ # If we couldn't extract competition name, use the original question
399
+ search_queries.append(question)
400
+
401
+ # Perform multiple searches with different queries
402
+ combined_context = await self._multi_search(search_queries)
403
+
404
+ # Also try Wikipedia for general information
405
+ wiki_context = ""
406
+ try:
407
+ if competition_name:
408
+ wiki_context = self.wiki_tool.run(competition_name)
409
+ print("Wikipedia search completed")
410
+ except Exception as e:
411
+ print(f"Wikipedia tool failed: {e}")
412
+
413
+ # Add Wikipedia context if available
414
+ if wiki_context and not any(x in wiki_context.lower() for x in ["not found", "no results", "does not contain"]):
415
+ combined_context += f"Wikipedia context: {wiki_context}\n\n"
416
+
417
+ # Create a specialized prompt for competition questions
418
+ prompt = f"""Based on the following search results, answer this question about a competition or award:
419
+
420
+ {combined_context}
421
+
422
+ Question: {question}
423
+
424
+ Analyze the search results carefully to find information about competition winners, their nationalities, and the time periods.
425
+ If the question asks about a country that no longer exists, look for winners from countries like the Soviet Union, Yugoslavia, Czechoslovakia, East Germany, etc.
426
+ If asked for a first name only, extract just the first name from the full name.
427
+
428
+ Provide ONLY the specific information requested with no explanations."""
429
+
430
+ await self._rate_limit()
431
+ response = self.model.generate_content(
432
+ prompt,
433
+ generation_config=genai.types.GenerationConfig(
434
+ max_output_tokens=100,
435
+ temperature=0.0
436
+ )
437
+ )
438
+ answer = response.text.strip()
439
+
440
+ # Clean up the answer
441
+ prefixes = ['The answer is', 'Based on', 'According to', 'The first name is', 'The recipient is']
442
+ for prefix in prefixes:
443
+ if answer.lower().startswith(prefix.lower()):
444
+ answer = answer[len(prefix):].strip()
445
+ if answer.startswith(','):
446
+ answer = answer[1:].strip()
447
+
448
+ # If the question asks for just a first name, extract it
449
+ if "first name" in question.lower():
450
+ name_parts = answer.split()
451
+ if name_parts:
452
+ answer = name_parts[0].rstrip(',.')
453
+
454
+ return answer
455
+
456
  async def _handle_discography_question(self, question: str) -> str:
457
  """Handle questions about music discography with enhanced search capabilities"""
458
  print(f"Processing discography question: {question[:50]}...")