parthsinha committed on
Commit
4796dc0
·
verified ·
1 Parent(s): 5af4193

Update chatbot_engine.py

Browse files
Files changed (1) hide show
  1. chatbot_engine.py +305 -278
chatbot_engine.py CHANGED
@@ -1,25 +1,46 @@
1
  import re
2
- from typing import Dict, List, Any, Tuple
 
 
3
  from data_processor import DataProcessor
4
  import utils
5
 
6
- class FetiiChatbot:
7
  """
8
- GPT-style chatbot that can answer questions about Fetii rideshare data.
 
9
  """
10
 
11
- def __init__(self, data_processor: DataProcessor):
12
- """Initialize the chatbot with a data processor."""
13
  self.data_processor = data_processor
14
  self.conversation_history = []
 
 
 
15
 
 
 
 
 
 
16
  self.query_patterns = {
 
 
 
 
 
 
 
 
 
 
 
17
  'location_stats': [
18
  r'how many.*(?:groups?|trips?).*(?:went to|to|from)\s+([^?]+?)(?:\s+(?:last|this|yesterday|today|week|month|year).*?)?[?.]?$',
19
  r'(?:trips?|groups?).*(?:to|from)\s+([^?]+?)(?:\s+(?:last|this|yesterday|today|week|month|year).*?)?[?.]?$',
20
  r'tell me about\s+([^?]+?)(?:\s+(?:last|this|yesterday|today|week|month|year).*?)?[?.]?$',
21
  r'stats for\s+([^?]+?)(?:\s+(?:last|this|yesterday|today|week|month|year).*?)?[?.]?$',
22
- r'(?:show me|find|search)\s+([^?]+?)(?:\s+(?:trips?|data|stats))?(?:\s+(?:last|this|yesterday|today|week|month|year).*?)?[?.]?$'
23
  ],
24
  'time_patterns': [
25
  r'when do.*groups?.*ride',
@@ -38,13 +59,6 @@ class FetiiChatbot:
38
  r'most popular.*locations?',
39
  r'busiest.*locations?',
40
  r'hottest spots?',
41
- r'show.*(?:pickup|drop-?off|locations?)',
42
- r'list.*locations?'
43
- ],
44
- 'demographics': [
45
- r'(\d+)[-–](\d+) year[- ]olds?',
46
- r'age group',
47
- r'demographics?'
48
  ],
49
  'general_stats': [
50
  r'how many total',
@@ -56,304 +70,340 @@ class FetiiChatbot:
56
  r'total trips'
57
  ]
58
  }
59
-
60
- self.time_patterns = [
61
- r'\s+(?:last|this|yesterday|today)\s+(?:week|month|year|night)',
62
- r'\s+(?:last|this)\s+(?:monday|tuesday|wednesday|thursday|friday|saturday|sunday)',
63
- r'\s+(?:in\s+)?(?:january|february|march|april|may|june|july|august|september|october|november|december)',
64
- r'\s+(?:last|this|next)\s+\w+',
65
- r'\s+(?:yesterday|today|tonight)',
66
- r'\s+\d{1,2}\/\d{1,2}\/\d{2,4}',
67
- r'\s+\d{1,2}-\d{1,2}-\d{2,4}'
68
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  def process_query(self, user_query: str) -> str:
71
  """Process a user query and return an appropriate response."""
72
- user_query = user_query.lower().strip()
73
 
74
  self.conversation_history.append({"role": "user", "content": user_query})
75
 
76
  try:
77
- query_type, params = self._parse_query(user_query)
78
- response = self._generate_response(query_type, params, user_query)
79
- self.conversation_history.append({"role": "assistant", "content": response})
 
 
 
 
 
 
80
 
 
 
 
81
  return response
82
 
83
  except Exception as e:
84
- error_response = ("I'm having trouble understanding that question. "
85
- "Try asking about specific locations, times, or group sizes. "
86
- "For example: 'How many groups went to The Aquarium on 6th?' or "
87
- "'What are the peak hours for large groups?'")
88
  return error_response
89
 
90
- def _clean_location_from_query(self, location_text: str) -> str:
91
- """Clean time references from location text."""
92
- cleaned = location_text.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- for pattern in self.time_patterns:
95
- cleaned = re.sub(pattern, '', cleaned, flags=re.IGNORECASE)
96
 
97
- cleaned = re.sub(r'\s+', ' ', cleaned).strip()
 
 
98
 
99
- return cleaned
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  def _parse_query(self, query: str) -> Tuple[str, Dict[str, Any]]:
102
  """Parse the user query to determine intent and extract parameters."""
103
  params = {}
104
 
 
 
 
 
 
 
 
 
 
 
 
105
  for pattern in self.query_patterns['location_stats']:
106
  match = re.search(pattern, query, re.IGNORECASE)
107
  if match:
108
  location = match.group(1).strip()
109
- location = self._clean_location_from_query(location)
110
  if location:
111
  params['location'] = location
112
  return 'location_stats', params
113
 
 
114
  for pattern in self.query_patterns['time_patterns']:
115
  if re.search(pattern, query, re.IGNORECASE):
116
- group_match = re.search(r'(\d+)\+?', query)
117
- if group_match:
118
- params['min_group_size'] = int(group_match.group(1))
119
  return 'time_patterns', params
120
 
121
  for pattern in self.query_patterns['group_size']:
122
- match = re.search(pattern, query, re.IGNORECASE)
123
- if match:
124
- if match.groups():
125
- params['group_size'] = int(match.group(1))
126
  return 'group_size', params
127
 
128
  for pattern in self.query_patterns['top_locations']:
129
  if re.search(pattern, query, re.IGNORECASE):
130
- if 'pickup' in query or 'pick up' in query:
131
- params['location_type'] = 'pickup'
132
- elif 'drop' in query:
133
- params['location_type'] = 'dropoff'
134
- else:
135
- params['location_type'] = 'both'
136
  return 'top_locations', params
137
 
138
- for pattern in self.query_patterns['demographics']:
139
- match = re.search(pattern, query, re.IGNORECASE)
140
- if match and match.groups():
141
- if len(match.groups()) == 2:
142
- params['age_range'] = (int(match.group(1)), int(match.group(2)))
143
- return 'demographics', params
144
-
145
  for pattern in self.query_patterns['general_stats']:
146
  if re.search(pattern, query, re.IGNORECASE):
147
  return 'general_stats', params
148
 
149
  return 'general_stats', params
150
 
151
- def _fuzzy_search_location(self, query_location: str) -> List[Tuple[str, int]]:
152
- """Search for locations using fuzzy matching."""
153
- all_pickups = self.data_processor.df['pickup_main'].value_counts()
154
- all_dropoffs = self.data_processor.df['dropoff_main'].value_counts()
155
-
156
- all_locations = {}
157
- for location, count in all_pickups.items():
158
- all_locations[location] = all_locations.get(location, 0) + count
159
- for location, count in all_dropoffs.items():
160
- all_locations[location] = all_locations.get(location, 0) + count
161
-
162
- matches = []
163
- query_lower = query_location.lower()
164
-
165
- # Exact match
166
- for location, count in all_locations.items():
167
- if query_lower == location.lower():
168
- matches.append((location, count))
169
-
170
- # Partial match
171
- if not matches:
172
- for location, count in all_locations.items():
173
- if query_lower in location.lower() or location.lower() in query_lower:
174
- matches.append((location, count))
175
-
176
- # Word match
177
- if not matches:
178
- query_words = query_lower.split()
179
- for location, count in all_locations.items():
180
- location_lower = location.lower()
181
- if any(word in location_lower for word in query_words if len(word) > 2):
182
- matches.append((location, count))
183
-
184
- matches.sort(key=lambda x: x[1], reverse=True)
185
- return matches[:5]
186
 
187
- def _generate_response(self, query_type: str, params: Dict[str, Any], original_query: str) -> str:
188
- """Generate a response based on the query type and parameters."""
 
189
 
190
- if query_type == 'location_stats':
191
- return self._handle_location_stats(params, original_query)
192
- elif query_type == 'time_patterns':
193
- return self._handle_time_patterns(params)
194
- elif query_type == 'group_size':
195
- return self._handle_group_size(params)
196
- elif query_type == 'top_locations':
197
- return self._handle_top_locations(params)
198
- elif query_type == 'demographics':
199
- return self._handle_demographics(params)
200
- elif query_type == 'general_stats':
201
- return self._handle_general_stats()
202
- else:
203
- return self._handle_fallback(original_query)
204
 
205
- def _handle_location_stats(self, params: Dict[str, Any], original_query: str) -> str:
206
- """Handle location-specific statistics queries."""
207
  location = params.get('location', '')
208
-
209
  stats = self.data_processor.get_location_stats(location)
210
 
211
  if stats['pickup_count'] == 0 and stats['dropoff_count'] == 0:
212
- matches = self._fuzzy_search_location(location)
213
-
214
- if matches:
215
- best_match = matches[0][0]
216
- stats = self.data_processor.get_location_stats(best_match)
217
-
218
- if stats['pickup_count'] > 0 or stats['dropoff_count'] > 0:
219
- response = f"<strong>Found results for '{best_match}'</strong> (closest match to '{location}'):\n\n"
220
- else:
221
- response = f"I couldn't find exact data for '{location}'. Did you mean one of these?\n\n"
222
- for match_location, count in matches[:3]:
223
- response += f"• <strong>{match_location}</strong> ({count} total trips)\n"
224
- response += f"\nTry asking: 'Tell me about {matches[0][0]}'"
225
- return response
226
- else:
227
- return f"I couldn't find any trips associated with '{location}'. Try checking the spelling or asking about a different location like 'West Campus' or 'The Aquarium on 6th'."
228
- else:
229
- best_match = location.title()
230
- response = f"<strong>Stats for {best_match}:</strong>\n\n"
231
 
232
  if stats['pickup_count'] > 0:
233
- response += f"<strong>{stats['pickup_count']} pickup trips</strong> with an average group size of {stats['avg_group_size_pickup']:.1f}\n"
234
- if stats['peak_hours_pickup']:
235
- peak_hours = ', '.join([utils.format_time(h) for h in stats['peak_hours_pickup']])
236
- response += f"Most popular pickup times: {peak_hours}\n"
237
 
238
  if stats['dropoff_count'] > 0:
239
- response += f"<strong>{stats['dropoff_count']} drop-off trips</strong> with an average group size of {stats['avg_group_size_dropoff']:.1f}\n"
240
- if stats['peak_hours_dropoff']:
241
- peak_hours = ', '.join([utils.format_time(h) for h in stats['peak_hours_dropoff']])
242
- response += f"Most popular drop-off times: {peak_hours}\n"
243
-
244
- total_trips = stats['pickup_count'] + stats['dropoff_count']
245
- insights = self.data_processor.get_quick_insights()
246
- percentage = (total_trips / insights['total_trips']) * 100
247
-
248
- response += f"\n<strong>Insight:</strong> This location accounts for {percentage:.1f}% of all Austin trips!"
249
-
250
- if any(word in original_query for word in ['last', 'this', 'month', 'week', 'yesterday', 'today']):
251
- response += f"\n\n<strong>Note:</strong> This data covers our full Austin dataset. For specific time periods, the patterns shown represent typical activity for this location."
252
 
253
  return response
254
 
255
  def _handle_time_patterns(self, params: Dict[str, Any]) -> str:
256
  """Handle time pattern queries."""
257
- min_group_size = params.get('min_group_size', None)
258
-
259
- time_data = self.data_processor.get_time_patterns(min_group_size)
260
-
261
- response = "<strong>Peak Riding Times:</strong>\n\n"
262
-
263
- if min_group_size:
264
- response += f"<em>For groups of {min_group_size}+ riders:</em>\n\n"
265
-
266
  hourly_counts = time_data['hourly_counts']
267
- top_hours = sorted(hourly_counts.items(), key=lambda x: x[1], reverse=True)[:5]
268
 
269
- response += "<strong>Busiest Hours:</strong>\n"
270
  for i, (hour, count) in enumerate(top_hours, 1):
271
- time_label = utils.format_time(hour)
272
- response += f"{i}. <strong>{time_label}</strong> - {count} trips\n"
273
-
274
- time_categories = time_data['time_category_counts']
275
- response += "\n<strong>By Time Period:</strong>\n"
276
- for period, count in sorted(time_categories.items(), key=lambda x: x[1], reverse=True):
277
- response += f"• <strong>{period}:</strong> {count} trips\n"
278
-
279
- peak_hour = top_hours[0][0]
280
- peak_count = top_hours[0][1]
281
- response += f"\n<strong>Insight:</strong> {utils.format_time(peak_hour)} is the absolute peak with {peak_count} trips!"
282
 
283
  return response
284
 
285
  def _handle_group_size(self, params: Dict[str, Any]) -> str:
286
  """Handle group size queries."""
287
- target_size = params.get('group_size', 6)
288
-
289
  insights = self.data_processor.get_quick_insights()
290
- group_distribution = insights['group_size_distribution']
291
-
292
- response = f"<strong>Group Size Analysis ({target_size}+ passengers):</strong>\n\n"
293
-
294
- large_group_trips = sum(count for size, count in group_distribution.items() if size >= target_size)
295
- total_trips = insights['total_trips']
296
- percentage = (large_group_trips / total_trips) * 100
297
-
298
- response += f"• <strong>{large_group_trips} trips</strong> had {target_size}+ passengers ({percentage:.1f}% of all trips)\n"
299
-
300
- response += f"\n<strong>Breakdown of {target_size}+ passenger groups:</strong>\n"
301
- large_groups = {size: count for size, count in group_distribution.items() if size >= target_size}
302
- for size, count in sorted(large_groups.items(), key=lambda x: x[1], reverse=True)[:8]:
303
- group_pct = (count / large_group_trips) * 100 if large_group_trips > 0 else 0
304
- response += f"• <strong>{size} passengers:</strong> {count} trips ({group_pct:.1f}%)\n"
305
-
306
- avg_size = insights['avg_group_size']
307
- response += f"\n<strong>Insight:</strong> Average group size is {avg_size:.1f} passengers - most rides are group experiences!"
308
-
309
  return response
310
 
311
  def _handle_top_locations(self, params: Dict[str, Any]) -> str:
312
  """Handle top locations queries."""
313
- location_type = params.get('location_type', 'both')
314
  insights = self.data_processor.get_quick_insights()
 
315
 
316
- response = "<strong>Most Popular Locations:</strong>\n\n"
317
-
318
- if location_type in ['pickup', 'both']:
319
- response += "<strong>Top Pickup Spots:</strong>\n"
320
- for i, (location, count) in enumerate(list(insights['top_pickups'])[:8], 1):
321
- response += f"{i}. <strong>{location}</strong> - {count} pickups\n"
322
-
323
- if location_type in ['dropoff', 'both']:
324
- if location_type == 'both':
325
- response += "\n<strong>Top Drop-off Destinations:</strong>\n"
326
- else:
327
- response += "<strong>Top Drop-off Destinations:</strong>\n"
328
- for i, (location, count) in enumerate(list(insights['top_dropoffs'])[:8], 1):
329
- response += f"{i}. <strong>{location}</strong> - {count} drop-offs\n"
330
-
331
- if location_type in ['pickup', 'both']:
332
- top_pickup = list(insights['top_pickups'])[0]
333
- response += f"\n<strong>Insight:</strong> {top_pickup[0]} dominates pickups with {top_pickup[1]} trips!"
334
-
335
- return response
336
-
337
- def _handle_demographics(self, params: Dict[str, Any]) -> str:
338
- """Handle demographics queries."""
339
- age_range = params.get('age_range', (18, 24))
340
-
341
- response = f"<strong>Demographics Analysis ({age_range[0]}-{age_range[1]} year olds):</strong>\n\n"
342
- response += "I'd love to help with demographic analysis, but I don't currently have access to rider age data in this dataset. "
343
- response += "However, I can tell you about the locations and times that are popular with different group sizes!\n\n"
344
-
345
- insights = self.data_processor.get_quick_insights()
346
- response += "<strong>Popular spots that might appeal to younger riders:</strong>\n"
347
-
348
- entertainment_spots = ['The Aquarium on 6th', 'Wiggle Room', "Shakespeare's", 'LUNA Rooftop', 'Green Light Social']
349
-
350
- for spot in entertainment_spots[:5]:
351
- for location, count in insights['top_dropoffs']:
352
- if spot.lower() in location.lower():
353
- response += f"• <strong>{location}</strong> - {count} drop-offs\n"
354
- break
355
-
356
- response += "\n<strong>Insight:</strong> Late night hours (10 PM - 1 AM) see the highest activity, which often correlates with younger demographics!"
357
 
358
  return response
359
 
@@ -361,53 +411,22 @@ class FetiiChatbot:
361
  """Handle general statistics queries."""
362
  insights = self.data_processor.get_quick_insights()
363
 
364
- response = "<strong>Fetii Austin Overview:</strong>\n\n"
365
-
366
- response += f"<strong>Total Trips Analyzed:</strong> {insights['total_trips']:,}\n"
367
- response += f"<strong>Average Group Size:</strong> {insights['avg_group_size']:.1f} passengers\n"
368
- response += f"<strong>Peak Hour:</strong> {utils.format_time(insights['peak_hour'])}\n"
369
- response += f"<strong>Large Groups (6+):</strong> {insights['large_groups_count']} trips ({insights['large_groups_pct']:.1f}%)\n\n"
370
-
371
- response += "<strong>Top Hotspots:</strong>\n"
372
- top_pickup = list(insights['top_pickups'])[0]
373
- top_dropoff = list(insights['top_dropoffs'])[0]
374
- response += f"• Most popular pickup: <strong>{top_pickup[0]}</strong> ({top_pickup[1]} trips)\n"
375
- response += f"• Most popular destination: <strong>{top_dropoff[0]}</strong> ({top_dropoff[1]} trips)\n\n"
376
-
377
- group_dist = insights['group_size_distribution']
378
- most_common_size = max(group_dist.items(), key=lambda x: x[1])
379
- response += f"<strong>Most Common Group Size:</strong> {most_common_size[0]} passengers ({most_common_size[1]} trips)\n\n"
380
-
381
- response += "<strong>Key Insights:</strong>\n"
382
- response += f"• {insights['large_groups_pct']:.0f}% of all rides are large groups (6+ people)\n"
383
- response += "• Peak activity happens late evening (10-11 PM)\n"
384
- response += "• West Campus dominates as the top pickup location\n"
385
- response += "• Entertainment venues are the most popular destinations"
386
 
387
  return response
388
 
389
  def _handle_fallback(self, query: str) -> str:
390
- """Handle queries that don't match any specific pattern."""
391
- response = "I'm not sure I understood that question perfectly. Here's what I can help you with:\n\n"
392
-
393
- response += "<strong>Location Questions:</strong>\n"
394
- response += "• 'How many groups went to [location]?'\n"
395
- response += "• 'Tell me about [location]'\n"
396
- response += " 'Top pickup/drop-off spots'\n\n"
397
-
398
- response += "<strong>Time Questions:</strong>\n"
399
- response += "• 'When do large groups typically ride?'\n"
400
- response += "• 'Peak hours for groups of 6+'\n"
401
- response += "• 'Busiest times'\n\n"
402
-
403
- response += "<strong>Group Size Questions:</strong>\n"
404
- response += "• 'How many trips had 10+ passengers?'\n"
405
- response += "• 'Large group patterns'\n"
406
- response += "• 'Average group size'\n\n"
407
-
408
- response += "Would you like to try asking one of these types of questions?"
409
-
410
- return response
411
 
412
  def get_conversation_history(self) -> List[Dict[str, str]]:
413
  """Get the conversation history."""
@@ -415,4 +434,12 @@ class FetiiChatbot:
415
 
416
  def clear_history(self):
417
  """Clear the conversation history."""
418
- self.conversation_history = []
 
 
 
 
 
 
 
 
 
1
  import re
2
+ import json
3
+ import requests
4
+ from typing import Dict, List, Any, Tuple, Optional
5
  from data_processor import DataProcessor
6
  import utils
7
 
8
+ class EnhancedFetiiChatbot:
9
  """
10
+ Enhanced conversational chatbot with Google Gemini AI integration for Fetii rideshare data analysis.
11
+ Falls back to pattern-based responses when AI is unavailable.
12
  """
13
 
14
+ def __init__(self, data_processor: DataProcessor, use_ai: bool = True, gemini_api_key: str = None):
15
+ """Initialize the enhanced chatbot with Gemini AI capabilities."""
16
  self.data_processor = data_processor
17
  self.conversation_history = []
18
+ self.use_ai = use_ai
19
+ self.gemini_api_key = gemini_api_key
20
+ self.ai_available = False
21
 
22
+ # Initialize Gemini AI if API key provided
23
+ if self.use_ai and self.gemini_api_key:
24
+ self._setup_gemini()
25
+
26
+ # Fallback pattern-based system
27
  self.query_patterns = {
28
+ 'greetings': [
29
+ r'^(?:hi|hello|hey|good morning|good afternoon|good evening|greetings?)(?:\s+.*)?$',
30
+ r'^(?:what\'?s up|how are you|how\'?s it going|sup)(?:\s+.*)?$',
31
+ r'^(?:thanks?|thank you|thx|appreciate it)(?:\s+.*)?$'
32
+ ],
33
+ 'casual_conversation': [
34
+ r'^(?:how are you|what are you|who are you|what can you do)(?:\s+.*)?$',
35
+ r'^(?:tell me about yourself|what\'?s your name|introduce yourself)(?:\s+.*)?$',
36
+ r'^(?:help|what can you help with|what do you do)(?:\s+.*)?$',
37
+ r'^(?:i\'?m (?:good|fine|okay|great|tired|busy))(?:\s+.*)?$'
38
+ ],
39
  'location_stats': [
40
  r'how many.*(?:groups?|trips?).*(?:went to|to|from)\s+([^?]+?)(?:\s+(?:last|this|yesterday|today|week|month|year).*?)?[?.]?$',
41
  r'(?:trips?|groups?).*(?:to|from)\s+([^?]+?)(?:\s+(?:last|this|yesterday|today|week|month|year).*?)?[?.]?$',
42
  r'tell me about\s+([^?]+?)(?:\s+(?:last|this|yesterday|today|week|month|year).*?)?[?.]?$',
43
  r'stats for\s+([^?]+?)(?:\s+(?:last|this|yesterday|today|week|month|year).*?)?[?.]?$',
 
44
  ],
45
  'time_patterns': [
46
  r'when do.*groups?.*ride',
 
59
  r'most popular.*locations?',
60
  r'busiest.*locations?',
61
  r'hottest spots?',
 
 
 
 
 
 
 
62
  ],
63
  'general_stats': [
64
  r'how many total',
 
70
  r'total trips'
71
  ]
72
  }
73
+
74
+ def _setup_gemini(self):
75
+ """Setup Gemini AI connection."""
76
+ try:
77
+ # Test Gemini API connection with minimal request
78
+ test_payload = {
79
+ "contents": [
80
+ {
81
+ "parts": [
82
+ {"text": "Hi"}
83
+ ]
84
+ }
85
+ ],
86
+ "generationConfig": {
87
+ "temperature": 0.7,
88
+ "maxOutputTokens": 10
89
+ }
90
+ }
91
+
92
+ response = requests.post(
93
+ f'https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key={self.gemini_api_key}',
94
+ headers={'Content-Type': 'application/json'},
95
+ json=test_payload,
96
+ timeout=5
97
+ )
98
+
99
+ if response.status_code == 200:
100
+ self.ai_available = True
101
+ print("✅ Gemini AI connected successfully")
102
+ elif response.status_code == 429:
103
+ print("⚠️ Gemini API rate limit reached - falling back to pattern-based responses")
104
+ self.ai_available = False
105
+ elif response.status_code == 400:
106
+ print("⚠️ Invalid Gemini API key or request")
107
+ self.ai_available = False
108
+ else:
109
+ print(f"⚠️ Gemini AI connection failed: {response.status_code}")
110
+ self.ai_available = False
111
+
112
+ except Exception as e:
113
+ print(f"⚠️ Failed to connect to Gemini AI: {str(e)}")
114
+ self.ai_available = False
115
 
116
  def process_query(self, user_query: str) -> str:
117
  """Process a user query and return an appropriate response."""
118
+ user_query = user_query.strip()
119
 
120
  self.conversation_history.append({"role": "user", "content": user_query})
121
 
122
  try:
123
+ # Get relevant data context
124
+ context = self._get_data_context(user_query)
125
+
126
+ # Try AI response first if available
127
+ if self.ai_available:
128
+ ai_response = self._get_gemini_response(user_query, context)
129
+ if ai_response:
130
+ self.conversation_history.append({"role": "assistant", "content": ai_response})
131
+ return ai_response
132
 
133
+ # Fallback to pattern-based response
134
+ response = self._pattern_based_response(user_query.lower())
135
+ self.conversation_history.append({"role": "assistant", "content": response})
136
  return response
137
 
138
  except Exception as e:
139
+ error_response = ("I'm having a bit of trouble processing that request. "
140
+ "Let me help you explore Austin rideshare data - try asking about specific locations, "
141
+ "time patterns, or group sizes. What would you like to discover?")
 
142
  return error_response
143
 
144
+ def _get_data_context(self, query: str) -> str:
145
+ """Extract relevant data context based on the query."""
146
+ insights = self.data_processor.get_quick_insights()
147
+ query_lower = query.lower()
148
+
149
+ # Base context always included
150
+ context_parts = [
151
+ f"Total Austin rideshare trips analyzed: {insights['total_trips']:,}",
152
+ f"Average group size: {insights['avg_group_size']:.1f} passengers",
153
+ f"Peak activity hour: {utils.format_time(insights['peak_hour'])}",
154
+ f"Large groups (6+): {insights['large_groups_pct']:.1f}% of all trips"
155
+ ]
156
+
157
+ # Add query-specific context
158
+ if any(word in query_lower for word in ['location', 'place', 'pickup', 'dropoff', 'where', 'destination']):
159
+ top_pickups = dict(list(insights['top_pickups'])[:5])
160
+ top_dropoffs = dict(list(insights['top_dropoffs'])[:5])
161
+ context_parts.extend([
162
+ f"Top pickup locations: {top_pickups}",
163
+ f"Top destinations: {top_dropoffs}"
164
+ ])
165
+
166
+ if any(word in query_lower for word in ['time', 'hour', 'peak', 'busy', 'when']):
167
+ time_data = self.data_processor.get_time_patterns()
168
+ hourly_top = dict(sorted(time_data['hourly_counts'].items(), key=lambda x: x[1], reverse=True)[:5])
169
+ context_parts.append(f"Hourly trip distribution: {hourly_top}")
170
+
171
+ if any(word in query_lower for word in ['group', 'size', 'passenger', 'people']):
172
+ group_dist = dict(list(insights['group_size_distribution'].items())[:8])
173
+ context_parts.append(f"Group size distribution: {group_dist}")
174
+
175
+ # Extract specific location if mentioned
176
+ potential_locations = self._extract_locations_from_query(query)
177
+ if potential_locations:
178
+ for location in potential_locations[:2]: # Limit to 2 locations
179
+ stats = self.data_processor.get_location_stats(location)
180
+ if stats['pickup_count'] > 0 or stats['dropoff_count'] > 0:
181
+ context_parts.append(
182
+ f"'{location}' stats: {stats['pickup_count']} pickups, "
183
+ f"{stats['dropoff_count']} dropoffs"
184
+ )
185
+
186
+ return "\n".join(context_parts)
187
+
188
+ def _extract_locations_from_query(self, query: str) -> List[str]:
189
+ """Extract potential location names from the query."""
190
+ # Get all known locations
191
+ all_pickups = self.data_processor.df['pickup_main'].unique()
192
+ all_dropoffs = self.data_processor.df['dropoff_main'].unique()
193
+ all_locations = set(list(all_pickups) + list(all_dropoffs))
194
 
195
+ query_lower = query.lower()
196
+ found_locations = []
197
 
198
+ for location in all_locations:
199
+ if location.lower() in query_lower:
200
+ found_locations.append(location)
201
 
202
+ return found_locations
203
+
204
+ def _get_gemini_response(self, query: str, context: str) -> Optional[str]:
205
+ """Get response from Gemini AI with improved error handling."""
206
+ try:
207
+ # Create system prompt with data context
208
+ system_prompt = f"""You are Fetii AI, a friendly and knowledgeable assistant specializing in Austin rideshare analytics.
209
+
210
+ Your personality:
211
+ - Conversational and helpful
212
+ - Provide specific data-driven insights
213
+ - Use the actual data provided in context
214
+ - Format responses clearly with key numbers highlighted
215
+ - Be enthusiastic about patterns and trends
216
+ - Keep responses concise but informative (under 150 words)
217
+
218
+ Current Austin rideshare data context:
219
+ {context}
220
+
221
+ Important: Always use the specific numbers and data from the context above. Don't make up statistics.
222
+
223
+ User query: {query}
224
+
225
+ Response:"""
226
+
227
+ payload = {
228
+ "contents": [
229
+ {
230
+ "parts": [
231
+ {"text": system_prompt}
232
+ ]
233
+ }
234
+ ],
235
+ "generationConfig": {
236
+ "temperature": 0.7,
237
+ "maxOutputTokens": 200,
238
+ "topP": 0.8,
239
+ "topK": 40
240
+ }
241
+ }
242
+
243
+ response = requests.post(
244
+ f'https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key={self.gemini_api_key}',
245
+ headers={'Content-Type': 'application/json'},
246
+ json=payload,
247
+ timeout=15
248
+ )
249
+
250
+ if response.status_code == 200:
251
+ result = response.json()
252
+ if 'candidates' in result and len(result['candidates']) > 0:
253
+ content = result['candidates'][0]['content']['parts'][0]['text']
254
+ return content.strip()
255
+ elif response.status_code == 429:
256
+ print("⚠️ Gemini API rate limit reached - falling back to pattern-based response")
257
+ self.ai_available = False
258
+ return None
259
+ elif response.status_code == 400:
260
+ print("⚠️ Invalid Gemini API request")
261
+ self.ai_available = False
262
+ return None
263
+ else:
264
+ print(f"Gemini API error: {response.status_code} - {response.text}")
265
+
266
+ except requests.exceptions.Timeout:
267
+ print("⚠️ Gemini API timeout - falling back to pattern-based response")
268
+ return None
269
+ except Exception as e:
270
+ print(f"Error calling Gemini API: {str(e)}")
271
+
272
+ return None
273
+
274
+ def _pattern_based_response(self, query: str) -> str:
275
+ """Fallback pattern-based response system."""
276
+ query_type, params = self._parse_query(query)
277
+
278
+ if query_type == 'greetings':
279
+ return self._handle_greetings(query)
280
+ elif query_type == 'casual_conversation':
281
+ return self._handle_casual_conversation(query)
282
+ elif query_type == 'location_stats':
283
+ return self._handle_location_stats(params, query)
284
+ elif query_type == 'time_patterns':
285
+ return self._handle_time_patterns(params)
286
+ elif query_type == 'group_size':
287
+ return self._handle_group_size(params)
288
+ elif query_type == 'top_locations':
289
+ return self._handle_top_locations(params)
290
+ elif query_type == 'general_stats':
291
+ return self._handle_general_stats()
292
+ else:
293
+ return self._handle_fallback(query)
294
 
295
  def _parse_query(self, query: str) -> Tuple[str, Dict[str, Any]]:
296
  """Parse the user query to determine intent and extract parameters."""
297
  params = {}
298
 
299
+ # Check for greetings first
300
+ for pattern in self.query_patterns['greetings']:
301
+ if re.search(pattern, query, re.IGNORECASE):
302
+ return 'greetings', params
303
+
304
+ # Check for casual conversation
305
+ for pattern in self.query_patterns['casual_conversation']:
306
+ if re.search(pattern, query, re.IGNORECASE):
307
+ return 'casual_conversation', params
308
+
309
+ # Check for location stats
310
  for pattern in self.query_patterns['location_stats']:
311
  match = re.search(pattern, query, re.IGNORECASE)
312
  if match:
313
  location = match.group(1).strip()
 
314
  if location:
315
  params['location'] = location
316
  return 'location_stats', params
317
 
318
+ # Check other patterns
319
  for pattern in self.query_patterns['time_patterns']:
320
  if re.search(pattern, query, re.IGNORECASE):
 
 
 
321
  return 'time_patterns', params
322
 
323
  for pattern in self.query_patterns['group_size']:
324
+ if re.search(pattern, query, re.IGNORECASE):
 
 
 
325
  return 'group_size', params
326
 
327
  for pattern in self.query_patterns['top_locations']:
328
  if re.search(pattern, query, re.IGNORECASE):
 
 
 
 
 
 
329
  return 'top_locations', params
330
 
 
 
 
 
 
 
 
331
  for pattern in self.query_patterns['general_stats']:
332
  if re.search(pattern, query, re.IGNORECASE):
333
  return 'general_stats', params
334
 
335
  return 'general_stats', params
336
 
337
+ def _handle_greetings(self, query: str) -> str:
338
+ """Handle greeting messages."""
339
+ if any(word in query.lower() for word in ['thanks', 'thank you']):
340
+ return "You're welcome! Happy to help you explore Austin rideshare patterns."
341
+
342
+ return ("Hello! I'm Fetii AI, your Austin rideshare analytics assistant. "
343
+ "I can help you understand trip patterns, popular locations, peak hours, and group behaviors. "
344
+ "What would you like to explore?")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
+ def _handle_casual_conversation(self, query: str) -> str:
347
+ """Handle casual conversation."""
348
+ query_lower = query.lower()
349
 
350
+ if any(phrase in query_lower for phrase in ['how are you', 'how\'s it going']):
351
+ return ("I'm doing great, thanks for asking! I'm excited to help you explore Austin rideshare data. "
352
+ "What aspect of the data interests you most?")
353
+
354
+ if any(phrase in query_lower for phrase in ['who are you', 'what are you']):
355
+ return ("I'm Fetii AI, your specialized assistant for Austin rideshare analytics! "
356
+ "I analyze real Austin rideshare data to provide insights about trip patterns, "
357
+ "popular destinations, peak hours, and group behaviors. What would you like to explore?")
358
+
359
+ return ("I'm here to help you explore Austin rideshare data! "
360
+ "Ask me about trip patterns, locations, or any trends you're curious about.")
 
 
 
361
 
362
+ def _handle_location_stats(self, params: Dict[str, Any], query: str) -> str:
363
+ """Handle location-specific queries."""
364
  location = params.get('location', '')
 
365
  stats = self.data_processor.get_location_stats(location)
366
 
367
  if stats['pickup_count'] == 0 and stats['dropoff_count'] == 0:
368
+ return f"I couldn't find trips for '{location}'. Try a different location like 'West Campus' or 'Downtown'."
369
+
370
+ response = f"**Stats for {location.title()}:**\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
 
372
  if stats['pickup_count'] > 0:
373
+ response += f"**{stats['pickup_count']} pickup trips** with average group size {stats['avg_group_size_pickup']:.1f}\n"
 
 
 
374
 
375
  if stats['dropoff_count'] > 0:
376
+ response += f"**{stats['dropoff_count']} drop-off trips** with average group size {stats['avg_group_size_dropoff']:.1f}\n"
 
 
 
 
 
 
 
 
 
 
 
 
377
 
378
  return response
379
 
380
  def _handle_time_patterns(self, params: Dict[str, Any]) -> str:
381
  """Handle time pattern queries."""
382
+ time_data = self.data_processor.get_time_patterns()
 
 
 
 
 
 
 
 
383
  hourly_counts = time_data['hourly_counts']
384
+ top_hours = sorted(hourly_counts.items(), key=lambda x: x[1], reverse=True)[:3]
385
 
386
+ response = "**Peak Hours Analysis:**\n\n"
387
  for i, (hour, count) in enumerate(top_hours, 1):
388
+ response += f"{i}. **{utils.format_time(hour)}** - {count} trips\n"
 
 
 
 
 
 
 
 
 
 
389
 
390
  return response
391
 
392
  def _handle_group_size(self, params: Dict[str, Any]) -> str:
393
  """Handle group size queries."""
 
 
394
  insights = self.data_processor.get_quick_insights()
395
+ response = f"**Group Size Analysis:**\n\n"
396
+ response += f"Average group size: **{insights['avg_group_size']:.1f} passengers**\n"
397
+ response += f"Large groups (6+): **{insights['large_groups_pct']:.1f}%** of all trips"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  return response
399
 
400
  def _handle_top_locations(self, params: Dict[str, Any]) -> str:
401
  """Handle top locations queries."""
 
402
  insights = self.data_processor.get_quick_insights()
403
+ response = "**Top Pickup Locations:**\n\n"
404
 
405
+ for i, (location, count) in enumerate(list(insights['top_pickups'])[:5], 1):
406
+ response += f"{i}. **{location}** - {count} trips\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
408
  return response
409
 
 
411
  """Handle general statistics queries."""
412
  insights = self.data_processor.get_quick_insights()
413
 
414
+ response = "**Austin Rideshare Overview:**\n\n"
415
+ response += f"**Total Trips:** {insights['total_trips']:,}\n"
416
+ response += f"**Average Group Size:** {insights['avg_group_size']:.1f} passengers\n"
417
+ response += f"**Peak Hour:** {utils.format_time(insights['peak_hour'])}\n"
418
+ response += f"**Large Groups:** {insights['large_groups_pct']:.1f}% (6+ passengers)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
 
420
  return response
421
 
422
  def _handle_fallback(self, query: str) -> str:
423
+ """Handle unrecognized queries."""
424
+ return ("I can help you explore Austin rideshare data! Try asking about:\n\n"
425
+ "• Specific locations: 'Tell me about West Campus'\n"
426
+ " Time patterns: 'What are the peak hours?'\n"
427
+ "• Group sizes: 'How many large groups ride?'\n"
428
+ "• General stats: 'Give me an overview'\n\n"
429
+ "What interests you most?")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
  def get_conversation_history(self) -> List[Dict[str, str]]:
432
  """Get the conversation history."""
 
434
 
435
  def clear_history(self):
436
  """Clear the conversation history."""
437
+ self.conversation_history = []
438
+
439
def set_gemini_api_key(self, api_key: str):
    """Store a new Gemini API key and refresh the AI connection state.

    Args:
        api_key: The new key. A falsy value (e.g. '') marks the AI as
            unavailable instead of triggering setup.
    """
    self.gemini_api_key = api_key

    if not api_key:
        # No key: flag the AI as unavailable and skip setup.
        self.ai_available = False
        return

    self._setup_gemini()