mgokg commited on
Commit
cb83081
·
verified ·
1 Parent(s): 7f8466d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +298 -51
app.py CHANGED
@@ -2,64 +2,311 @@ import base64
2
  import gradio as gr
3
  import os
4
  import json
 
 
 
5
  from google import genai
6
  from google.genai import types
7
- from gradio_client import Client
 
 
 
8
 
 
 
9
 
10
def generate(input_text):
    """Send *input_text* to Gemini (with Google Search grounding) and return the reply.

    Returns a ``(response_text, "")`` 2-tuple; the empty string clears the input
    box in the Gradio UI. On failure it returns ``(error_message, input_text)``
    so the user's message is preserved.
    """
    try:
        client = genai.Client(
            api_key=os.environ.get("GEMINI_API_KEY"),
        )
    except Exception as e:
        # BUGFIX: error paths previously returned a single string while the
        # success path returned a 2-tuple, breaking the two-output Gradio
        # binding. Keep the message text unchanged but return a pair.
        return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set.", input_text

    model = "gemini-flash-latest"
    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_text(text=f"{input_text}"),
            ],
        ),
    ]
    tools = [
        types.Tool(google_search=types.GoogleSearch()),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=0.4,
        # Disable the model's "thinking" phase for lower latency.
        thinking_config=types.ThinkingConfig(
            thinking_budget=0,
        ),
        tools=tools,
        response_mime_type="text/plain",
    )

    response_text = ""
    try:
        for chunk in client.models.generate_content_stream(
            model=model,
            contents=contents,
            config=generate_content_config,
        ):
            # BUGFIX: chunk.text can be None for non-text chunks (e.g. tool
            # metadata); concatenating None raised TypeError before.
            if chunk.text:
                response_text += chunk.text
    except Exception as e:
        return f"Error during generation: {e}", input_text

    # NOTE: removed dead code — the original built `data = response_text[:-1]`
    # (plus a commented-out clean_json_string call) but never used it.
    return response_text, ""
55
 
56
if __name__ == '__main__':
    # Minimal chat UI: markdown answer panel above a 3-line input and a send button.
    with gr.Blocks() as demo:
        gr.Markdown("# Gemini 2.0 Flash + Websearch")
        answer_view = gr.Markdown()
        message_box = gr.Textbox(lines=3, label="", placeholder="Enter message here...")
        send_btn = gr.Button("send")
        # generate() returns (answer, "") so the input box is cleared on success.
        send_btn.click(
            fn=generate,
            inputs=message_box,
            outputs=[answer_view, message_box],
        )
        demo.launch(show_error=True)
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import gradio as gr
3
  import os
4
  import json
5
+ import asyncio
6
+ import aiohttp
7
+ from typing import Dict, Any, Optional, List
8
  from google import genai
9
  from google.genai import types
10
+ from contextlib import AsyncExitStack
11
+ from mcp import ClientSession
12
+ from mcp.client.streamable_http import streamablehttp_client
13
+ import nest_asyncio
14
 
15
+ # Allow nested event loops for Jupyter notebooks
16
+ nest_asyncio.apply()
17
 
18
class MCPTrainClient:
    """MCP client for the DB train-timetable API using Streamable HTTP transport."""

    def __init__(self):
        self.session = None    # active mcp.ClientSession, or None when disconnected
        self.exit_stack = None  # AsyncExitStack owning transport + session lifetimes
        self.tools = []         # tools advertised by the server after connect()
        self.mcp_server_url = "https://mgokg-db-timetable-api.hf.space/gradio_api/mcp/"

    async def connect(self) -> str:
        """Connect to the MCP server and cache its tool list.

        Returns a human-readable status string (✅/❌ prefixed) rather than
        raising, so the Gradio layer can display it directly.
        """
        try:
            # Drop any previous connection before opening a new one.
            if self.exit_stack:
                await self.exit_stack.aclose()

            self.exit_stack = AsyncExitStack()

            # Connect using Streamable HTTP transport.
            transport = await self.exit_stack.enter_async_context(
                streamablehttp_client(self.mcp_server_url)
            )

            self.read_stream, self.write_stream = transport
            self.session = await self.exit_stack.enter_async_context(
                ClientSession(self.read_stream, self.write_stream)
            )

            await self.session.initialize()

            # List available tools.
            response = await self.session.list_tools()
            self.tools = response.tools

            tool_names = [tool.name for tool in self.tools]
            return f"✅ Connected to MCP server. Available tools: {', '.join(tool_names)}"

        except Exception as e:
            # BUGFIX: a failed attempt previously leaked the half-opened
            # AsyncExitStack and could leave a stale session behind; close it
            # best-effort and reset state before reporting the original error.
            if self.exit_stack:
                try:
                    await self.exit_stack.aclose()
                except Exception:
                    pass
                self.exit_stack = None
            self.session = None
            self.tools = []
            return f"❌ Failed to connect to MCP server: {str(e)}"

    async def get_train_connections(self, departure: str, destination: str) -> str:
        """Query train connections between *departure* and *destination* via MCP."""
        if not self.session:
            return "❌ Not connected to MCP server. Please connect first."

        try:
            # Call the db_timetable_api_ui_wrapper tool.
            result = await self.session.call_tool(
                "db_timetable_api_ui_wrapper",
                arguments={"dep": departure, "dest": destination}
            )

            if result.content and len(result.content) > 0:
                # BUGFIX: content items are not guaranteed to be text blocks;
                # fall back to str() instead of crashing on a missing .text.
                first = result.content[0]
                text = getattr(first, "text", None)
                return text if text is not None else str(first)
            else:
                return "❌ No results returned from train API"

        except Exception as e:
            return f"❌ Error calling train API: {str(e)}"

    async def disconnect(self):
        """Close the session/transport and reset connection state."""
        if self.exit_stack:
            await self.exit_stack.aclose()
            self.exit_stack = None
            self.session = None
            self.tools = []
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
class GeminiMCPIntegration:
    """Routes user queries to either the MCP train tool or Gemini web search."""

    def __init__(self):
        self.mcp_client = MCPTrainClient()
        self.gemini_client = None
        # Dedicated event loop so the async MCP client can be driven from
        # synchronous Gradio callbacks via run_until_complete().
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def initialize_gemini(self) -> Optional[str]:
        """Create the Gemini client. Returns an error message, or None on success."""
        try:
            api_key = os.environ.get("GEMINI_API_KEY")
            if not api_key:
                return "GEMINI_API_KEY not found in environment variables"

            self.gemini_client = genai.Client(api_key=api_key)
            return None
        except Exception as e:
            return f"Error initializing Gemini client: {str(e)}"

    def connect_to_mcp(self) -> str:
        """Connect to the MCP server (synchronous wrapper)."""
        return self.loop.run_until_complete(self.mcp_client.connect())

    def extract_train_stations(self, text: str) -> Optional[Dict[str, str]]:
        """Extract departure/destination station names from free-form text.

        Returns ``{"departure": ..., "destination": ...}`` (lower-cased), or
        None when no from/to pattern is recognised.
        """
        import re

        text_lower = text.lower()

        # Common patterns for train queries; first match wins.
        patterns = [
            r'(?:from|von)\s+(.+?)\s+(?:to|nach)\s+(.+)',
            r'(.+?)\s+(?:to|nach)\s+(.+)',
            r'(?:train|zug|bahn)(?:\s+from|\s+von)?\s+(.+?)\s+(?:to|nach)\s+(.+)',
        ]

        for pattern in patterns:
            match = re.search(pattern, text_lower)
            if not match:
                continue
            dep = match.group(1).strip()
            dest = match.group(2).strip()

            # Trim trailing punctuation / politeness words.
            for suffix in ['?', '.', '!', 'please', 'bitte']:
                if dep.endswith(suffix):
                    dep = dep[:-len(suffix)].strip()
                if dest.endswith(suffix):
                    dest = dest[:-len(suffix)].strip()

            return {"departure": dep, "destination": dest}

        return None

    def process_query(self, query: str, mode: str = "auto") -> str:
        """Process a user query with the appropriate tool for *mode*."""

        # Initialize Gemini lazily on first use.
        if not self.gemini_client:
            error = self.initialize_gemini()
            if error:
                return f"❌ {error}"

        # Check whether this looks like a train query.
        station_info = self.extract_train_stations(query)

        # BUGFIX: the UI dropdown offers "both", which previously fell through
        # to plain websearch; treat it like the train-capable modes here.
        if station_info and mode in ("auto", "train", "both"):
            # Use MCP for train connections; connect on demand.
            if not self.mcp_client.session:
                connect_result = self.connect_to_mcp()
                if "❌" in connect_result:
                    return connect_result

            train_result = self.loop.run_until_complete(
                self.mcp_client.get_train_connections(
                    station_info["departure"],
                    station_info["destination"]
                )
            )

            # Also ask Gemini to summarise the raw connection data.
            gemini_prompt = f"User asked: '{query}'. I found these train connections: {train_result}. Please provide a helpful summary and any additional context."

            try:
                contents = [types.Content(role="user", parts=[types.Part.from_text(text=gemini_prompt)])]
                config = types.GenerateContentConfig(
                    temperature=0.4,
                    response_mime_type="text/plain"
                )

                gemini_response = ""
                for chunk in self.gemini_client.models.generate_content_stream(
                    model="gemini-flash-latest",
                    contents=contents,
                    config=config
                ):
                    # BUGFIX: chunk.text can be None even when the attribute
                    # exists; `+= None` raised TypeError before.
                    chunk_text = getattr(chunk, 'text', None)
                    if chunk_text:
                        gemini_response += chunk_text

                return f"🚂 **Train Connections:**\n\n{train_result}\n\n🤖 **AI Assistant:**\n\n{gemini_response}"

            except Exception as e:
                # Still return the raw train data when the summary fails.
                return f"🚂 **Train Connections:**\n\n{train_result}\n\n⚠️ Note: Could not generate additional AI insights: {str(e)}"

        elif mode == "train" and not station_info:
            return "❌ Could not extract station names from your query. Please use format like 'Train from [station] to [station]'"

        else:
            # Fall back to the original Gemini + websearch behaviour.
            return self.generate_with_websearch(query)

    def generate_with_websearch(self, query: str) -> str:
        """Original Gemini + Google-Search-grounded generation."""
        try:
            contents = [types.Content(role="user", parts=[types.Part.from_text(text=query)])]
            tools = [types.Tool(google_search=types.GoogleSearch())]

            config = types.GenerateContentConfig(
                temperature=0.4,
                # Disable the model's "thinking" phase for lower latency.
                thinking_config=types.ThinkingConfig(thinking_budget=0),
                tools=tools,
                response_mime_type="text/plain"
            )

            response_text = ""
            for chunk in self.gemini_client.models.generate_content_stream(
                model="gemini-flash-latest",
                contents=contents,
                config=config
            ):
                # BUGFIX: guard against None chunk text (see process_query).
                chunk_text = getattr(chunk, 'text', None)
                if chunk_text:
                    response_text += chunk_text

            return f"🔍 **Web Search Results:**\n\n{response_text}"

        except Exception as e:
            return f"❌ Error during web search: {str(e)}"
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
# Single shared integration object used by every Gradio callback.
mcp_integration = GeminiMCPIntegration()


def process_user_input(input_text: str, mode: str = "auto") -> str:
    """Gradio callback: validate the message, then dispatch it to the router."""
    # Guard clause: reject blank / whitespace-only messages up front.
    if not input_text.strip():
        return "Please enter a message."
    # Pass the original (unstripped) text through to the query router.
    return mcp_integration.process_query(input_text, mode)
235
+
236
+
237
def connect_to_train_api() -> str:
    """Gradio callback: open the MCP connection and report its status string."""
    status = mcp_integration.connect_to_mcp()
    return status
240
 
 
 
 
 
 
 
 
241
 
242
if __name__ == '__main__':
    with gr.Blocks(title="Gemini 2.0 Flash + Websearch + Train Connections (MCP)") as demo:
        # Page heading and short description.
        gr.Markdown("# 🤖 Gemini 2.0 Flash + Websearch + Train Connections (MCP)")
        gr.Markdown("Advanced AI assistant with web search and real-time train connections via MCP.")

        # Input row: wide message box next to a narrow mode selector.
        with gr.Row():
            with gr.Column(scale=3):
                message_input = gr.Textbox(
                    lines=3,
                    label="Your Message",
                    placeholder="Enter your message here...\nExamples:\n- 'Train from Schweinfurt HBF to Oerlenbach'\n- 'What's the weather in Berlin?'\n- 'Latest news about AI'"
                )
            with gr.Column(scale=1):
                mode_select = gr.Dropdown(
                    choices=["auto", "websearch", "train", "both"],
                    value="auto",
                    label="Mode",
                    info="Choose processing mode"
                )

        # Action row: primary send button plus the MCP connect button.
        with gr.Row():
            send_btn = gr.Button("Send", variant="primary", scale=3)
            connect_btn = gr.Button("Connect to Train API", scale=1)

        answer_view = gr.Markdown(
            label="Response",
            show_copy_button=True
        )

        status_view = gr.Textbox(
            label="Connection Status",
            interactive=False,
            max_lines=2
        )

        # Quick-example buttons: each click pre-fills the message box.
        gr.Markdown("### 🎯 Quick Examples")
        example_specs = [
            ("🚂 Train: Schweinfurt → Oerlenbach", "Train connections from Schweinfurt HBF to Oerlenbach"),
            ("🌤️ Weather in Munich", "What's the weather in Munich today?"),
            ("📰 Latest AI news", "Latest news about artificial intelligence developments"),
            ("🏙️ Restaurants in Berlin", "Find good restaurants in Berlin city center"),
        ]
        with gr.Row():
            for btn_label, example_text in example_specs:
                example_btn = gr.Button(btn_label)
                # Bind example_text as a default argument so each lambda
                # captures its own value (late-binding closure pitfall).
                example_btn.click(fn=lambda t=example_text: t, outputs=message_input)

        # Button click and Enter key both run the same handler.
        send_btn.click(
            fn=process_user_input,
            inputs=[message_input, mode_select],
            outputs=answer_view
        )
        message_input.submit(
            fn=process_user_input,
            inputs=[message_input, mode_select],
            outputs=answer_view
        )
        connect_btn.click(
            fn=connect_to_train_api,
            outputs=status_view
        )

        demo.launch(show_error=True)