Kackle committed on
Commit
3ffe515
·
verified ·
1 Parent(s): 7e06f4a

Update gemini_agent.py

Browse files
Files changed (1) hide show
  1. gemini_agent.py +60 -121
gemini_agent.py CHANGED
@@ -5,6 +5,8 @@ from excel_parser import ExcelParser
5
  import re
6
  import time
7
  import asyncio
 
 
8
  # Add LangChain tools for Wikipedia and DuckDuckGo
9
  from langchain_community.tools import DuckDuckGoSearchRun, WikipediaQueryRun
10
  from langchain_community.utilities import WikipediaAPIWrapper
@@ -15,10 +17,14 @@ class GeminiAgent:
15
  def __init__(self):
16
  print("GeminiAgent initialized.")
17
 
18
- # Get Google API key from environment variables
19
  api_key = os.getenv('GOOGLE_API_KEY')
20
  genai.configure(api_key=api_key)
21
 
 
 
 
 
22
  self.model = genai.GenerativeModel('gemini-2.0-flash')
23
  self.last_request_time = 0
24
  self.min_request_interval = 8.0 # 8 seconds between requests (10 per minute limit, with margin)
@@ -45,9 +51,9 @@ class GeminiAgent:
45
  if self._is_actor_or_show_question(question):
46
  return await self._handle_actor_show_question(question)
47
 
48
- # Check if question is about music, albums, or discography
49
- if self._is_music_question(question):
50
- return await self._handle_music_question(question)
51
 
52
  # Regular text-based question
53
  return await self._handle_text_question(question)
@@ -68,133 +74,66 @@ class GeminiAgent:
68
  ]
69
  return any(pattern in q for pattern in actor_show_patterns)
70
 
71
- def _is_music_question(self, question: str) -> bool:
72
- """Determine if a question is about music, albums, or discography"""
73
  q = question.lower()
74
  music_patterns = [
75
- "album", "albums", "discography", "song", "songs", "track", "tracks",
76
- "musician", "singer", "artist", "band", "composer", "recorded", "released",
77
- "studio album", "live album", "compilation", "single", "ep"
78
  ]
 
 
79
  # Check for music-related terms
80
  has_music_term = any(pattern in q for pattern in music_patterns)
 
 
81
  # Check for date ranges which are common in discography questions
82
- has_date_range = re.search(r'between\s+\d{4}\s+and\s+\d{4}', q) is not None
83
- return has_music_term or has_date_range
84
-
85
- async def _handle_music_question(self, question: str) -> str:
86
- """Handle questions about music, albums, and discographies with enhanced search"""
87
- print(f"Processing music/discography question: {question[:50]}...")
88
-
89
- # Always try both Wikipedia and DuckDuckGo for these questions
90
- wiki_context = ""
91
- ddg_context = ""
92
-
93
- # Extract date range if present
94
- date_range = re.search(r'between\s+(\d{4})\s+and\s+(\d{4})', question.lower())
95
- start_year = None
96
- end_year = None
97
- if date_range:
98
- start_year = int(date_range.group(1))
99
- end_year = int(date_range.group(2))
100
- else:
101
- # Check for other date range formats
102
- date_range = re.search(r'from\s+(\d{4})\s+to\s+(\d{4})', question.lower())
103
- if date_range:
104
- start_year = int(date_range.group(1))
105
- end_year = int(date_range.group(2))
106
- else:
107
- # Check for single year with inclusion indicator
108
- included_year = re.search(r'(\d{4})\s*\(included\)', question.lower())
109
- if included_year:
110
- end_year = int(included_year.group(1))
111
-
112
- # Specifically mention Wikipedia version if specified
113
- wiki_version = None
114
- if "version of" in question.lower() and "wikipedia" in question.lower():
115
- version_match = re.search(r'(\d{4})\s+version\s+of\s+\w+\s+wikipedia', question.lower())
116
- if version_match:
117
- wiki_version = version_match.group(1)
118
- print(f"Using Wikipedia version from {wiki_version}")
119
-
120
- # Construct a more specific search query for discography questions
121
- search_query = question
122
- if "how many" in question.lower() and "album" in question.lower():
123
- # Extract artist name - look for patterns like "by [Artist Name]" or just the name before "albums"
124
- artist_match = re.search(r'by\s+([\w\s]+)\s+between', question)
125
- if not artist_match:
126
- artist_match = re.search(r'([\w\s]+)\s+albums', question)
127
 
128
- if artist_match:
129
- artist_name = artist_match.group(1).strip()
130
- search_query = f"{artist_name} complete discography studio albums {start_year if start_year else ''} to {end_year if end_year else ''}"
131
- print(f"Enhanced search query: {search_query}")
132
-
133
  try:
134
- wiki_context = self.wiki_tool.run(search_query)
135
- print("Wikipedia search completed")
136
- except Exception as e:
137
- print(f"Wikipedia tool failed: {e}")
138
-
139
- try:
140
- ddg_context = self.ddg_tool.run(search_query)
141
- print("DuckDuckGo search completed")
142
- except Exception as e:
143
- print(f"DuckDuckGo tool failed: {e}")
144
-
145
- # Combine contexts if available
146
- combined_context = ""
147
- if wiki_context and not any(x in wiki_context.lower() for x in ["not found", "no results", "does not contain"]):
148
- combined_context += f"Wikipedia context: {wiki_context}\n\n"
149
- if ddg_context and not any(x in ddg_context.lower() for x in ["not found", "no results", "does not contain"]):
150
- combined_context += f"Web search context: {ddg_context}\n\n"
151
-
152
- # Create a specialized prompt for music/discography questions
153
- prompt = f"""Based on the following context, answer this question about music or discography:
154
-
155
- Question: {question}
156
-
157
- {combined_context}
158
-
159
- """
160
-
161
- # Add specific instructions for counting albums in a date range
162
- if "how many" in question.lower() and "album" in question.lower() and start_year and end_year:
163
- prompt += f"""Count ONLY the studio albums released between {start_year} and {end_year}, inclusive of both years.
164
 
165
- Provide ONLY the numeric count as your answer, with no additional text.
166
-
167
- Make sure to count each album only once, and only count studio albums (not compilations, live albums, or EPs) unless specifically asked for those."""
168
- else:
169
- prompt += "Provide ONLY the specific information requested. No explanations or additional context."
170
-
171
- await self._rate_limit()
172
- response = self.model.generate_content(
173
- prompt,
174
- generation_config=genai.types.GenerationConfig(
175
- max_output_tokens=50,
176
- temperature=0.0
177
- )
178
- )
179
- answer = response.text.strip()
180
-
181
- # Clean up the answer to extract just the number or information
182
- # Remove common prefixes
183
- prefixes = ['The answer is', 'Based on', 'According to', 'The number is', 'There were']
184
- for prefix in prefixes:
185
- if answer.lower().startswith(prefix.lower()):
186
- answer = answer[len(prefix):].strip()
187
- if answer.startswith(','):
188
- answer = answer[1:].strip()
189
-
190
- # If the question is asking for a count, extract just the number
191
- if "how many" in question.lower():
192
- # Try to extract just the number
193
- number_match = re.search(r'\b(\d+)\b', answer)
194
- if number_match:
195
- answer = number_match.group(1)
196
 
197
- return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  async def _handle_actor_show_question(self, question: str) -> str:
200
  """Handle questions about actors, TV shows, and movies with enhanced search"""
 
5
  import re
6
  import time
7
  import asyncio
8
+ import requests
9
+ import json
10
  # Add LangChain tools for Wikipedia and DuckDuckGo
11
  from langchain_community.tools import DuckDuckGoSearchRun, WikipediaQueryRun
12
  from langchain_community.utilities import WikipediaAPIWrapper
 
17
  def __init__(self):
18
  print("GeminiAgent initialized.")
19
 
20
+ # Get API keys from environment variables
21
  api_key = os.getenv('GOOGLE_API_KEY')
22
  genai.configure(api_key=api_key)
23
 
24
+ # Google Custom Search API keys
25
+ self.google_search_api_key = os.getenv('GOOGLE_SEARCH_API_KEY')
26
+ self.google_search_cx = os.getenv('GOOGLE_SEARCH_CX')
27
+
28
  self.model = genai.GenerativeModel('gemini-2.0-flash')
29
  self.last_request_time = 0
30
  self.min_request_interval = 8.0 # 8 seconds between requests (10 per minute limit, with margin)
 
51
  if self._is_actor_or_show_question(question):
52
  return await self._handle_actor_show_question(question)
53
 
54
+ # Check if question is about music discography or albums
55
+ if self._is_discography_question(question):
56
+ return await self._handle_discography_question(question)
57
 
58
  # Regular text-based question
59
  return await self._handle_text_question(question)
 
74
  ]
75
  return any(pattern in q for pattern in actor_show_patterns)
76
 
77
def _is_discography_question(self, question: str) -> bool:
    """Determine if a question is about music discography or albums.

    A question qualifies when it mentions a music term (album, song,
    track, ...) AND either an artist term (singer, band, ...) or a year
    range such as "between 2000 and 2009", "from 2000 to 2009",
    "2000 to 2009" or "2000-2009".

    Args:
        question: The raw question text (any casing).

    Returns:
        True if the question looks like a discography question.
    """
    q = question.lower()

    # Whole-word matching: plain substring checks misfire on words such
    # as "husband" (band), "tracking" (track) or "artistic" (artist).
    # "soundtrack" is listed explicitly so it keeps counting as a music
    # term now that "track" no longer matches inside it.
    music_pattern = (
        r'\b(?:albums?|discography|published|released|recorded|'
        r'tracks?|soundtracks?|songs?|singles?)\b'
    )
    artist_pattern = r'\b(?:musicians?|singers?|artists?|bands?|composers?)\b'

    # A year is exactly four digits, not part of a longer digit run.
    year = r'(?<!\d)\d{4}(?!\d)'
    date_range_patterns = (
        rf'between\s+{year}\s+and\s+{year}',
        rf'{year}\s+to\s+{year}',      # also covers "from YYYY to YYYY"
        rf'{year}\s*[-–]\s*{year}',
    )

    # Check for music-related terms
    has_music_term = re.search(music_pattern, q) is not None
    if not has_music_term:
        return False

    # Check for artist-related terms
    has_artist_term = re.search(artist_pattern, q) is not None
    # Check for date ranges which are common in discography questions
    has_date_range = any(
        re.search(pattern, q) is not None for pattern in date_range_patterns
    )

    # A music term plus either an artist term or a date range is likely
    # a discography question
    return has_artist_term or has_date_range
98
+
99
async def _google_search(self, query: str, num_results: int = 5) -> str:
    """Perform a Google search using the Custom Search API.

    Falls back to DuckDuckGo (``self.ddg_tool``) whenever Google cannot
    provide results: credentials missing, non-200 response, no result
    items, or any error raised while calling or parsing the API.

    Args:
        query: The search query string.
        num_results: How many results to request. The value is clamped
            to 1-10 before it is sent to the API.

    Returns:
        The search results formatted as "Title / Description / URL"
        blocks separated by blank lines, or the DuckDuckGo tool's output
        when falling back.
    """
    if self.google_search_api_key and self.google_search_cx:
        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                'key': self.google_search_api_key,
                'cx': self.google_search_cx,
                'q': query,
                # Keep the count inside 1-10 so an out-of-range value
                # does not cost the request.
                'num': max(1, min(num_results, 10)),
            }

            # NOTE: this is a blocking call inside a coroutine; the
            # timeout keeps a stalled connection from hanging the agent.
            response = requests.get(url, params=params, timeout=10)
            if response.status_code == 200:
                items = response.json().get('items')
                if items:
                    # Extract and format search results
                    return "".join(
                        f"Title: {item.get('title', 'No title')}\n"
                        f"Description: {item.get('snippet', 'No description')}\n"
                        f"URL: {item.get('link', 'No link')}\n\n"
                        for item in items
                    )
                print("No search results found")
            else:
                print(f"Google Search API error: {response.status_code}")
        except Exception as e:
            # Deliberately broad: any failure here is best-effort and
            # should lead to the DuckDuckGo fallback, not a crash.
            print(f"Google Search API error: {str(e)}")
    else:
        print("Google Search API key or CX not configured, falling back to DuckDuckGo")

    # Single fallback path for every failure mode above
    return self.ddg_tool.run(query)
137
 
138
  async def _handle_actor_show_question(self, question: str) -> str:
139
  """Handle questions about actors, TV shows, and movies with enhanced search"""