mgokg commited on
Commit
a1a18be
·
verified ·
1 Parent(s): ce99afe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -71
app.py CHANGED
@@ -1,86 +1,108 @@
1
- import requests
2
  import json
3
- from datetime import datetime
4
- from mcp.server.fastmcp import FastMCP
 
 
 
 
5
 
6
- # Initialize FastMCP Server
7
- mcp = FastMCP("DB-Train-Helper")
 
8
 
9
- BASE_URL = "https://v6.db.transport.rest"
 
 
 
 
 
 
 
 
 
 
10
 
11
- def get_station_id(query):
12
- try:
13
- response = requests.get(f"{BASE_URL}/locations", params={
14
- "poi": "false",
15
- "addresses": "false",
16
- "query": query
17
- })
18
- data = response.json()
19
- if data and len(data) > 0:
20
- return data[0]["id"]
21
- return None
22
- except:
23
- return None
 
 
 
 
 
24
 
25
- def format_duration(departure_str, arrival_str):
26
- fmt = "%Y-%m-%dT%H:%M:%S%z"
27
- dep = datetime.strptime(departure_str, fmt)
28
- arr = datetime.strptime(arrival_str, fmt)
29
- diff = arr - dep
30
- hours, remainder = divmod(diff.seconds, 3600)
31
- minutes = remainder // 60
32
- return f"{hours}h {minutes}min" if hours > 0 else f"{minutes}min"
33
 
34
- @mcp.tool()
35
- def get_train_connections(start_loc: str, dest_loc: str) -> str:
36
- """
37
- Fetches train connections between two German cities and returns
38
- a beautifully formatted HTML snippet with departure times and platforms.
39
- """
40
- start_id = get_station_id(start_loc)
41
- dest_id = get_station_id(dest_loc)
42
 
43
- if not start_id or not dest_id:
44
- return f"Error: Could not find station IDs for {start_loc} or {dest_loc}."
 
 
 
 
 
 
45
 
46
  try:
47
- response = requests.get(f"{BASE_URL}/journeys", params={
48
- "from": start_id,
49
- "to": dest_id,
50
- "results": 3
51
- })
52
- journey_data = response.json()
53
 
54
- connections_list = []
55
- for j in journey_data.get("journeys", []):
56
- legs = j.get("legs", [])
57
- if not legs: continue
58
-
59
- first, last = legs[0], legs[-1]
60
 
61
- conn_obj = {
62
- "departure": datetime.strptime(first["departure"], "%Y-%m-%dT%H:%M:%S%z").strftime("%H:%M"),
63
- "arrival": datetime.strptime(last["arrival"], "%Y-%m-%dT%H:%M:%S%z").strftime("%H:%M"),
64
- "startLocation": first["origin"]["name"],
65
- "destination": last["destination"]["name"],
66
- "duration": format_duration(first["departure"], last["arrival"]),
67
- "platform": f"Gl. {first.get('departurePlatform', '-')}"
68
- }
69
- connections_list.append(conn_obj)
 
 
 
 
 
 
 
70
 
71
- # We return the HTML just as your original code did
72
- connections_json = json.dumps(connections_list)
73
-
74
- return f"""
75
- <div style="font-family: sans-serif; background: #1a1a2e; padding: 20px; border-radius: 10px;">
76
- <h2 style="color: white;">Connections: {start_loc} to {dest_loc}</h2>
77
- {connections_json}
78
- <p style="color: #a0a0a0;">(Code generated for HTML rendering)</p>
79
- </div>
80
- """
81
 
82
  except Exception as e:
83
- return f"Error fetching journeys: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- if __name__ == "__main__":
86
- mcp.run()
 
1
+ import os
2
  import json
3
+ import asyncio
4
+ import gradio as gr
5
+ from google import genai
6
+ from google.genai import types
7
+ from mcp import ClientSession
8
+ from mcp.client.sse import sse_client
9
 
10
+ # --- CONFIGURATION ---
11
+ MCP_SERVER_URL = "https://mgokg/DB_API_MCP.hf.space/sse"
12
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
13
 
14
+ # 1. Helper to talk to your MCP Server
15
+ async def call_mcp_train_tool(start_loc, dest_loc):
16
+ async with sse_client(MCP_SERVER_URL) as (read_stream, write_stream):
17
+ async with ClientSession(read_stream, write_stream) as session:
18
+ await session.initialize()
19
+ result = await session.call_tool("get_train_connections", arguments={
20
+ "start_loc": start_loc,
21
+ "dest_loc": dest_loc
22
+ })
23
+ # result.content[0].text contains the JSON list from your MCP server
24
+ return result.content[0].text
25
 
26
+ # 2. Tool Definition for Gemini
27
+ # This tells Gemini what the tool does and what arguments it needs
28
+ train_tool_declaration = types.Tool(
29
+ function_declarations=[
30
+ types.FunctionDeclaration(
31
+ name="get_train_connections",
32
+ description="Get live train connections between two German cities/stations.",
33
+ parameters=types.Schema(
34
+ type="OBJECT",
35
+ properties={
36
+ "start_loc": types.Schema(type="STRING", description="The departure station"),
37
+ "dest_loc": types.Schema(type="STRING", description="The destination station")
38
+ },
39
+ required=["start_loc", "dest_loc"]
40
+ )
41
+ )
42
+ ]
43
+ )
44
 
45
+ def generate(input_text):
46
+ if not GEMINI_API_KEY:
47
+ return "Error: GEMINI_API_KEY is not set.", ""
 
 
 
 
 
48
 
49
+ client = genai.Client(api_key=GEMINI_API_KEY)
50
+ model_id = "gemini-2.0-flash-exp" # Or gemini-flash-latest
 
 
 
 
 
 
51
 
52
+ # We use a chat session to handle the multi-turn tool loop
53
+ chat = client.chats.create(
54
+ model=model_id,
55
+ config=types.GenerateContentConfig(
56
+ tools=[train_tool_declaration, types.Tool(google_search=types.GoogleSearch())],
57
+ temperature=0.4
58
+ )
59
+ )
60
 
61
  try:
62
+ # Step 1: Send initial request to Gemini
63
+ response = chat.send_message(input_text)
 
 
 
 
64
 
65
+ # Step 2: Check if Gemini wants to use a tool
66
+ # We loop in case Gemini needs multiple tool calls
67
+ while response.candidates[0].content.parts[0].tool_call:
68
+ tool_call = response.candidates[0].content.parts[0].tool_call
 
 
69
 
70
+ if tool_call.name == "get_train_connections":
71
+ # Extract arguments Gemini provided
72
+ args = tool_call.args
73
+
74
+ # Execute the MCP call (running async code in sync Gradio)
75
+ train_data = asyncio.run(call_mcp_train_tool(args["start_loc"], args["dest_loc"]))
76
+
77
+ # Send the tool result back to Gemini
78
+ response = chat.send_message(
79
+ types.Part.from_function_response(
80
+ name="get_train_connections",
81
+ response={"result": train_data}
82
+ )
83
+ )
84
+ else:
85
+ break # Handle other tools or exit
86
 
87
+ return response.text, ""
 
 
 
 
 
 
 
 
 
88
 
89
  except Exception as e:
90
+ return f"Error during generation: {str(e)}", ""
91
+
92
+ # --- GRADIO UI ---
93
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
94
+ gr.Markdown("# 🚄 Gemini 2.0 + DB Train MCP")
95
+ gr.Markdown("Ask about train connections (e.g., 'When is the next train from Munich to Berlin?')")
96
+
97
+ output_textbox = gr.Markdown(label="Response")
98
+ input_textbox = gr.Textbox(lines=2, label="Your Question")
99
+ submit_button = gr.Button("Send", variant="primary")
100
+
101
+ submit_button.click(
102
+ fn=generate,
103
+ inputs=input_textbox,
104
+ outputs=[output_textbox, input_textbox]
105
+ )
106
 
107
+ if __name__ == '__main__':
108
+ demo.launch()