mgokg commited on
Commit
e0b699c
·
verified ·
1 Parent(s): 1b0d7b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -74
app.py CHANGED
@@ -1,84 +1,86 @@
1
- import base64
2
- import gradio as gr
3
- import os
4
  import json
5
- from google import genai
6
- from google.genai import types
7
- from gradio_client import Client
8
 
9
- def clean_json_string(json_str):
10
- """
11
- Removes any comments or prefixes before the actual JSON content.
12
- """
13
- # Find the first occurrence of '{'
14
- json_start = json_str.find('{')
15
- if json_start == -1:
16
- # If no '{' is found, try with '[' for arrays
17
- json_start = json_str.find('[')
18
- if json_start == -1:
19
- return json_str # Return original if no JSON markers found
20
-
21
- # Extract everything from the first JSON marker
22
- cleaned_str = json_str[json_start:]
23
- return cleaned_str
24
- # Verify it's valid JSON
25
- try:
26
- json.loads(cleaned_str)
27
- return cleaned_str
28
- except json.JSONDecodeError:
29
- return json_str # Return original if cleaning results in invalid JSON
30
 
31
- def generate(input_text):
 
 
32
  try:
33
- client = genai.Client(
34
- api_key=os.environ.get("GEMINI_API_KEY"),
35
- )
36
- except Exception as e:
37
- return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set."
 
 
 
 
 
 
38
 
39
- model = "gemini-flash-latest"
40
- contents = [
41
- types.Content(
42
- role="user",
43
- parts=[
44
- types.Part.from_text(text=input_text),
45
- ],
46
- ),
47
- ]
48
- tools = [
49
- types.Tool(google_search=types.GoogleSearch()),
50
- ]
51
- generate_content_config = types.GenerateContentConfig(
52
- temperature=0.4,
53
- thinking_config = types.ThinkingConfig(
54
- thinking_budget=0,
55
- ),
56
- tools=tools,
57
- response_mime_type="text/plain",
58
- )
59
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- response_text = ""
62
  try:
63
- for chunk in client.models.generate_content_stream(
64
- model=model,
65
- contents=contents,
66
- config=generate_content_config,
67
- ):
68
- response_text += chunk.text
69
- except Exception as e:
70
- return f"Error during generation: {e}"
71
- data = clean_json_string(response_text)
72
- data = data[:-1]
73
- return response_text, ""
74
-
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- if __name__ == '__main__':
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- with gr.Blocks() as demo:
79
- title=gr.Markdown("# Gemini 2.0 Flash + Websearch")
80
- output_textbox = gr.Markdown()
81
- input_textbox = gr.Textbox(lines=3, label="", placeholder="Enter message here...")
82
- submit_button = gr.Button("send")
83
- submit_button.click(fn=generate,inputs=input_textbox,outputs=[output_textbox, input_textbox])
84
- demo.launch(show_error=True)
 
1
+ import requests
 
 
2
  import json
3
+ from datetime import datetime
4
+ from mcp.server.fastmcp import FastMCP
 
5
 
6
+ # Initialize FastMCP Server
7
+ mcp = FastMCP("DB-Train-Helper")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ BASE_URL = "https://v6.db.transport.rest"
10
+
11
+ def get_station_id(query):
12
  try:
13
+ response = requests.get(f"{BASE_URL}/locations", params={
14
+ "poi": "false",
15
+ "addresses": "false",
16
+ "query": query
17
+ })
18
+ data = response.json()
19
+ if data and len(data) > 0:
20
+ return data[0]["id"]
21
+ return None
22
+ except:
23
+ return None
24
 
25
+ def format_duration(departure_str, arrival_str):
26
+ fmt = "%Y-%m-%dT%H:%M:%S%z"
27
+ dep = datetime.strptime(departure_str, fmt)
28
+ arr = datetime.strptime(arrival_str, fmt)
29
+ diff = arr - dep
30
+ hours, remainder = divmod(diff.seconds, 3600)
31
+ minutes = remainder // 60
32
+ return f"{hours}h {minutes}min" if hours > 0 else f"{minutes}min"
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ @mcp.tool()
35
+ def get_train_connections(start_loc: str, dest_loc: str) -> str:
36
+ """
37
+ Fetches train connections between two German cities and returns
38
+ a beautifully formatted HTML snippet with departure times and platforms.
39
+ """
40
+ start_id = get_station_id(start_loc)
41
+ dest_id = get_station_id(dest_loc)
42
+
43
+ if not start_id or not dest_id:
44
+ return f"Error: Could not find station IDs for {start_loc} or {dest_loc}."
45
 
 
46
  try:
47
+ response = requests.get(f"{BASE_URL}/journeys", params={
48
+ "from": start_id,
49
+ "to": dest_id,
50
+ "results": 3
51
+ })
52
+ journey_data = response.json()
53
+
54
+ connections_list = []
55
+ for j in journey_data.get("journeys", []):
56
+ legs = j.get("legs", [])
57
+ if not legs: continue
58
+
59
+ first, last = legs[0], legs[-1]
60
+
61
+ conn_obj = {
62
+ "departure": datetime.strptime(first["departure"], "%Y-%m-%dT%H:%M:%S%z").strftime("%H:%M"),
63
+ "arrival": datetime.strptime(last["arrival"], "%Y-%m-%dT%H:%M:%S%z").strftime("%H:%M"),
64
+ "startLocation": first["origin"]["name"],
65
+ "destination": last["destination"]["name"],
66
+ "duration": format_duration(first["departure"], last["arrival"]),
67
+ "platform": f"Gl. {first.get('departurePlatform', '-')}"
68
+ }
69
+ connections_list.append(conn_obj)
70
 
71
+ # We return the HTML just as your original code did
72
+ connections_json = json.dumps(connections_list)
73
+
74
+ return f"""
75
+ <div style="font-family: sans-serif; background: #1a1a2e; padding: 20px; border-radius: 10px;">
76
+ <h2 style="color: white;">Connections: {start_loc} to {dest_loc}</h2>
77
+ {connections_json}
78
+ <p style="color: #a0a0a0;">(Code generated for HTML rendering)</p>
79
+ </div>
80
+ """
81
+
82
+ except Exception as e:
83
+ return f"Error fetching journeys: {str(e)}"
84
 
85
+ if __name__ == "__main__":
86
+ mcp.run()