mgokg committed on
Commit
1f1770e
·
verified ·
1 Parent(s): 193c4e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -281
app.py CHANGED
@@ -1,305 +1,176 @@
1
- import base64
2
  import gradio as gr
3
  import os
4
- import json
5
- import asyncio
6
- import aiohttp
7
- from typing import Dict, Any, Optional, List
8
  from google import genai
9
  from google.genai import types
10
- from mcp import ClientSession, StdioServerParameters
11
- from mcp.client.stdio import stdio_client
12
- from mcp.client.streamable_http import streamablehttp_client
13
- import nest_asyncio
14
 
15
- # Allow nested event loops for Jupyter notebooks
16
- nest_asyncio.apply()
17
 
18
class MCPTrainClient:
    """MCP Client for train timetable API using Streamable HTTP transport."""

    def __init__(self):
        # Active MCP ClientSession; None until connect() succeeds.
        self.session = None
        # AsyncExitStack that owns the transport and session lifetimes.
        self.exit_stack = None
        self.mcp_server_url = "https://mgokg-db-timetable-api.hf.space/gradio_api/mcp/"

    async def connect(self) -> str:
        """Connect to the MCP server using Streamable HTTP transport.

        Returns a human-readable status string (✅ on success, ❌ on failure).
        """
        # Local import: AsyncExitStack was referenced but never imported at
        # module level, so connect() previously raised NameError.
        from contextlib import AsyncExitStack
        try:
            # Tear down any previous connection before reconnecting.
            if self.exit_stack:
                await self.exit_stack.aclose()

            self.exit_stack = AsyncExitStack()

            # streamablehttp_client yields (read_stream, write_stream,
            # get_session_id) — it is NOT a session itself. The streams must
            # be wrapped in a ClientSession before initialize() can be called.
            read_stream, write_stream, _ = await self.exit_stack.enter_async_context(
                streamablehttp_client(self.mcp_server_url)
            )
            self.session = await self.exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )

            # Perform the MCP handshake.
            await self.session.initialize()

            # List available tools for the status message.
            response = await self.session.list_tools()
            tool_names = [tool.name for tool in response.tools]

            return f"✅ Connected to MCP server. Available tools: {', '.join(tool_names)}"

        except Exception as e:
            return f"❌ Failed to connect to MCP server: {str(e)}"

    async def get_train_connections(self, departure: str, destination: str) -> str:
        """Get train connections between two stations via the MCP tool.

        Returns the tool's text output, or a ❌-prefixed error string.
        """
        if not self.session:
            return "❌ Not connected to MCP server. Please connect first."

        try:
            # Call the db_timetable_api_ui_wrapper tool exposed by the server.
            result = await self.session.call_tool(
                "db_timetable_api_ui_wrapper",
                arguments={"dep": departure, "dest": destination}
            )

            if result.content and len(result.content) > 0:
                return result.content[0].text
            else:
                return "❌ No results returned from train API"

        except Exception as e:
            return f"❌ Error calling train API: {str(e)}"

    async def disconnect(self):
        """Disconnect from the MCP server and release transport resources."""
        if self.exit_stack:
            await self.exit_stack.aclose()
            self.exit_stack = None
            self.session = None
77
 
 
 
78
 
79
class GeminiMCPIntegration:
    """Integration class that combines Gemini with MCP tools."""

    def __init__(self):
        self.mcp_client = MCPTrainClient()
        self.gemini_client = None
        # Dedicated event loop so async MCP calls can be driven from
        # synchronous Gradio callbacks.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def initialize_gemini(self) -> Optional[str]:
        """Initialize the Gemini client.

        Returns None on success, or an error message string on failure.
        """
        try:
            api_key = os.environ.get("GEMINI_API_KEY")
            if not api_key:
                return "GEMINI_API_KEY not found in environment variables"

            self.gemini_client = genai.Client(api_key=api_key)
            return None
        except Exception as e:
            return f"Error initializing Gemini client: {str(e)}"

    def connect_to_mcp(self) -> str:
        """Connect to MCP server (synchronous wrapper around the async client)."""
        return self.loop.run_until_complete(self.mcp_client.connect())

    def extract_train_stations(self, text: str) -> Optional[Dict[str, str]]:
        """Extract departure and destination stations from free text.

        Returns {"departure": ..., "destination": ...} (lowercased), or None
        when no train-style query is recognized.
        """
        import re

        text_lower = text.lower()

        # Common patterns for train queries (English and German).
        patterns = [
            r'(?:from|von)\s+(.+?)\s+(?:to|nach)\s+(.+)',
            r'(.+?)\s+(?:to|nach)\s+(.+)',
            r'(?:train|zug|bahn)(?:\s+from|\s+von)?\s+(.+?)\s+(?:to|nach)\s+(.+)',
        ]

        for pattern in patterns:
            match = re.search(pattern, text_lower)
            if match:
                dep = match.group(1).strip()
                dest = match.group(2).strip()

                # Clean up common trailing punctuation / politeness words.
                for suffix in ['?', '.', '!', 'please', 'bitte']:
                    if dep.endswith(suffix):
                        dep = dep[:-len(suffix)].strip()
                    if dest.endswith(suffix):
                        dest = dest[:-len(suffix)].strip()

                return {"departure": dep, "destination": dest}

        return None

    def process_query(self, query: str, mode: str = "auto") -> str:
        """Route a user query to the train MCP tool and/or Gemini websearch.

        mode: "auto" | "websearch" | "train" | "both" (matches the UI dropdown).
        """
        # Initialize Gemini lazily on first use.
        if not self.gemini_client:
            error = self.initialize_gemini()
            if error:
                return f"❌ {error}"

        # Check whether this looks like a train query.
        station_info = self.extract_train_stations(query)

        # FIX: the UI offers a "both" mode, but it was previously routed to
        # websearch only. Treat "both" like "auto"/"train" for train queries.
        if station_info and mode in ("auto", "train", "both"):
            # Use MCP for train connections; connect on demand.
            if not self.mcp_client.session:
                connect_result = self.connect_to_mcp()
                if "❌" in connect_result:
                    return connect_result

            train_result = self.loop.run_until_complete(
                self.mcp_client.get_train_connections(
                    station_info["departure"],
                    station_info["destination"]
                )
            )

            # Also get Gemini's interpretation of the raw timetable data.
            gemini_prompt = f"User asked: '{query}'. I found these train connections: {train_result}. Please provide a helpful summary and any additional context."

            try:
                contents = [types.Content(role="user", parts=[types.Part.from_text(text=gemini_prompt)])]
                config = types.GenerateContentConfig(
                    temperature=0.4,
                    response_mime_type="text/plain"
                )

                gemini_response = ""
                for chunk in self.gemini_client.models.generate_content_stream(
                    model="gemini-flash-latest",
                    contents=contents,
                    config=config
                ):
                    if hasattr(chunk, 'text'):
                        gemini_response += chunk.text

                return f"🚂 **Train Connections:**\n\n{train_result}\n\n🤖 **AI Assistant:**\n\n{gemini_response}"

            except Exception as e:
                # Best-effort: still return the timetable if Gemini fails.
                return f"🚂 **Train Connections:**\n\n{train_result}\n\n⚠️ Note: Could not generate additional AI insights: {str(e)}"

        elif mode == "train" and not station_info:
            return "❌ Could not extract station names from your query. Please use format like 'Train from [station] to [station]'"

        else:
            # Fall back to Gemini + websearch.
            return self.generate_with_websearch(query)

    def generate_with_websearch(self, query: str) -> str:
        """Answer a general query with Gemini grounded by Google Search."""
        try:
            contents = [types.Content(role="user", parts=[types.Part.from_text(text=query)])]
            tools = [types.Tool(google_search=types.GoogleSearch())]

            config = types.GenerateContentConfig(
                temperature=0.4,
                thinking_config=types.ThinkingConfig(thinking_budget=0),
                tools=tools,
                response_mime_type="text/plain"
            )

            response_text = ""
            for chunk in self.gemini_client.models.generate_content_stream(
                model="gemini-flash-latest",
                contents=contents,
                config=config
            ):
                if hasattr(chunk, 'text'):
                    response_text += chunk.text

            return f"🔍 **Web Search Results:**\n\n{response_text}"

        except Exception as e:
            return f"❌ Error during web search: {str(e)}"
216
 
 
 
 
 
 
 
 
 
 
 
217
 
218
# Shared integration instance used by every UI callback.
mcp_integration = GeminiMCPIntegration()


def process_user_input(input_text: str, mode: str = "auto") -> str:
    """Validate the message and hand it to the integration layer."""
    if input_text.strip() == "":
        return "Please enter a message."
    return mcp_integration.process_query(input_text, mode)


def connect_to_train_api() -> str:
    """Open the MCP connection and return a status string for the UI."""
    return mcp_integration.connect_to_mcp()
 
233
 
 
 
 
 
 
 
 
 
 
234
 
235
if __name__ == '__main__':
    # Build the Gradio front-end: message box + mode selector on top,
    # action buttons, then the markdown answer and connection status.
    with gr.Blocks(title="Gemini 2.0 Flash + Websearch + Train Connections (MCP)") as demo:
        gr.Markdown("# 🤖 Gemini 2.0 Flash + Websearch + Train Connections (MCP)")
        gr.Markdown("Advanced AI assistant with web search and real-time train connections via MCP.")

        with gr.Row():
            with gr.Column(scale=3):
                input_textbox = gr.Textbox(
                    lines=3,
                    label="Your Message",
                    placeholder="Enter your message here...\nExamples:\n- 'Train from Schweinfurt HBF to Oerlenbach'\n- 'What's the weather in Berlin?'\n- 'Latest news about AI'"
                )
            with gr.Column(scale=1):
                mode_dropdown = gr.Dropdown(
                    choices=["auto", "websearch", "train", "both"],
                    value="auto",
                    label="Mode",
                    info="Choose processing mode"
                )

        with gr.Row():
            submit_button = gr.Button("Send", variant="primary", scale=3)
            connect_button = gr.Button("Connect to Train API", scale=1)

        output_textbox = gr.Markdown(label="Response", show_copy_button=True)
        status_textbox = gr.Textbox(label="Connection Status", interactive=False, max_lines=2)

        # Quick-example buttons that pre-fill the input box.
        gr.Markdown("### 🎯 Quick Examples")
        with gr.Row():
            example1 = gr.Button("🚂 Train: Schweinfurt → Oerlenbach")
            example2 = gr.Button("🌤️ Weather in Munich")
            example3 = gr.Button("📰 Latest AI news")
            example4 = gr.Button("🏙️ Restaurants in Berlin")

        example1.click(fn=lambda: "Train connections from Schweinfurt HBF to Oerlenbach", outputs=input_textbox)
        example2.click(fn=lambda: "What's the weather in Munich today?", outputs=input_textbox)
        example3.click(fn=lambda: "Latest news about artificial intelligence developments", outputs=input_textbox)
        example4.click(fn=lambda: "Find good restaurants in Berlin city center", outputs=input_textbox)

        # Main wiring: button click and Enter-key submit share one handler.
        submit_button.click(
            fn=process_user_input,
            inputs=[input_textbox, mode_dropdown],
            outputs=output_textbox
        )
        connect_button.click(
            fn=connect_to_train_api,
            outputs=status_textbox
        )
        input_textbox.submit(
            fn=process_user_input,
            inputs=[input_textbox, mode_dropdown],
            outputs=output_textbox
        )

    demo.launch(show_error=True)
 
 
1
  import gradio as gr
2
  import os
 
 
 
 
3
  from google import genai
4
  from google.genai import types
5
+ from gradio_client import Client
 
 
 
6
 
7
# --- 1. Robust Client Initialization for the External Tool ---
db_client = None  # lazily (re)initialised gradio_client handle; None while unreachable


def init_db_client():
    """Connect to the external DB timetable Space and store the client globally.

    Failures are logged but not raised, so the app can still start without
    the timetable tool and retry later.
    """
    global db_client
    # Full URL rather than a Space id to avoid DNS resolution issues.
    endpoint = "https://mgokg-db-timetable-api.hf.space/"
    try:
        print("Connecting to DB Timetable API...")
        db_client = Client(endpoint)
        print("Successfully connected.")
    except Exception as e:
        print(f"Warning: Could not connect to DB Timetable API: {e}")


# Attempt connection on startup
init_db_client()

# --- 2. Tool Definition ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
def get_train_connection(dep: str, dest: str):
    """
    Fetch the train timetable between *dep* and *dest* via the external API.

    Returns the API result on success, or an error string when the Space is
    unreachable or the call fails.
    """
    global db_client
    # The startup connection may have failed; retry lazily before giving up.
    if db_client is None:
        init_db_client()
    if db_client is None:
        return "Error: The train database is currently unreachable."

    try:
        # Endpoint name taken from the Space's MCP documentation.
        return db_client.predict(
            dep=dep,
            dest=dest,
            api_name="/db_timetable_api_ui_wrapper",
        )
    except Exception as e:
        return f"Error fetching timetable: {str(e)}"
46
 
47
# Define the tool schema for Gemini function calling.
_train_params = types.Schema(
    type=types.Type.OBJECT,
    properties={
        "dep": types.Schema(type=types.Type.STRING, description="Departure city or station"),
        "dest": types.Schema(type=types.Type.STRING, description="Destination city or station"),
    },
    required=["dep", "dest"],
)

train_tool = types.FunctionDeclaration(
    name="get_train_connection",
    description="Find train connections and timetables between a start location (dep) and a destination (dest).",
    parameters=_train_params,
)

# Dispatch table: tool name reported by the model -> local implementation.
tools_map = {"get_train_connection": get_train_connection}
 
 
65
 
66
# --- 3. Generation Logic (v1alpha) ---

def generate(input_text):
    """Answer *input_text* with Gemini, dispatching to the train tool on demand.

    Generator for the Gradio UI: yields (response_markdown, input_box_value)
    tuples — the second element clears the textbox on success and restores the
    query on error.
    """
    if not input_text:
        yield "", ""
        return

    try:
        # Initialize Client (v1alpha can be enabled via http_options).
        client = genai.Client(
            api_key=os.environ.get("GEMINI_API_KEY"),
            #http_options={'api_version': 'v1alpha'} # <--- ENABLE v1alpha HERE
        )
    except Exception as e:
        yield f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set.", input_text
        return

    model = "gemini-flash-latest"

    # Google Search grounding plus the custom timetable function.
    tools = [
        types.Tool(
            google_search=types.GoogleSearch(),
            function_declarations=[train_tool]
        )
    ]

    generate_content_config = types.GenerateContentConfig(
        temperature=0.4,
        tools=tools,
    )

    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=input_text)],
        ),
    ]

    response_text = ""

    try:
        # First API call: let the model decide whether to call the tool.
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config,
        )

        if response.candidates and response.candidates[0].content.parts:
            # FIX: scan all parts for a function call instead of only parts[0];
            # the model may emit a text part before the tool request.
            fn_call = next(
                (p.function_call for p in response.candidates[0].content.parts if p.function_call),
                None,
            )

            if fn_call and fn_call.name in tools_map:
                # Execute the local tool. FIX: args can be None — default to {}.
                api_result = tools_map[fn_call.name](**dict(fn_call.args or {}))

                # Append the model turn and the tool result to the history so
                # the model can summarise the data in a second turn.
                contents.append(response.candidates[0].content)
                contents.append(
                    types.Content(
                        role="tool",
                        parts=[
                            types.Part.from_function_response(
                                name=fn_call.name,
                                response={"result": api_result}
                            )
                        ]
                    )
                )

                # Second API call, streamed incrementally to the UI.
                stream = client.models.generate_content_stream(
                    model=model,
                    contents=contents,
                    config=generate_content_config
                )

                for chunk in stream:
                    # FIX: streamed chunks may carry no text (chunk.text is
                    # None), which previously raised TypeError on +=.
                    if chunk.text:
                        response_text += chunk.text
                        yield response_text, ""
                return

        # No tool call: plain (possibly search-grounded) answer.
        if response.text:
            yield response.text, ""

    except Exception as e:
        yield f"Error during generation: {e}", input_text
159
+
160
# --- 4. UI Setup ---

if __name__ == '__main__':
    # Minimal front-end: markdown output above, query box and send button below.
    with gr.Blocks() as demo:
        gr.Markdown("# Gemini 2.0 Flash (v1alpha) + DB Timetable Tool")

        output_textbox = gr.Markdown()
        input_textbox = gr.Textbox(
            lines=3,
            label="",
            placeholder="Ask for a train connection (e.g., 'Train from Berlin to Frankfurt')...",
        )
        submit_button = gr.Button("Send")

        # generate() streams (answer, textbox_value) pairs back to both outputs.
        submit_button.click(
            fn=generate,
            inputs=input_textbox,
            outputs=[output_textbox, input_textbox],
        )

    demo.launch(show_error=True)