mgokg commited on
Commit
92ec150
Β·
verified Β·
1 Parent(s): c12ff7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -1
app.py CHANGED
@@ -1,3 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import base64
2
  import gradio as gr
3
  import os
@@ -62,4 +191,4 @@ if __name__ == '__main__':
62
  submit_button = gr.Button("send")
63
  submit_button.click(fn=generate,inputs=input_textbox,outputs=[output_textbox, input_textbox])
64
  demo.launch(show_error=True)
65
-
 
1
+ import os
2
+ import json
3
+ import asyncio
4
+ import gradio as gr
5
+ from google import genai
6
+ from google.genai import types
7
+ from mcp import ClientSession
8
+ from mcp.client.sse import sse_client
9
+
10
+ # --- CONFIGURATION ---
11
+ MCP_SERVER_URL = "https://mgokg-db-api-mcp.hf.space/sse"
12
+ # Ensure your API key is set in your environment variables
13
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
14
+
15
+ async def call_mcp_tool(start_loc, dest_loc):
16
+ """Connects to the MCP server to fetch train data."""
17
+ async with sse_client(MCP_SERVER_URL) as (read_stream, write_stream):
18
+ async with ClientSession(read_stream, write_stream) as session:
19
+ await session.initialize()
20
+ result = await session.call_tool("get_train_connections", arguments={
21
+ "start_loc": start_loc,
22
+ "dest_loc": dest_loc
23
+ })
24
+ return result.content[0].text
25
+
26
+ async def generate(input_text):
27
+ """Async generator to handle tool loops and UI updates."""
28
+ if not GEMINI_API_KEY:
29
+ yield "### Error\nGEMINI_API_KEY is not set.", ""
30
+ return
31
+
32
+ client = genai.Client(api_key=GEMINI_API_KEY)
33
+ # Using Gemini 3 Flash as requested
34
+ model_id = "gemini-flash-latest"
35
+
36
+ # Define the train tool schema
37
+ train_tool = types.Tool(
38
+ function_declarations=[
39
+ types.FunctionDeclaration(
40
+ name="get_train_connections",
41
+ description="Finds live train connections between two German stations.",
42
+ parameters={
43
+ "type": "OBJECT",
44
+ "properties": {
45
+ "start_loc": {"type": "STRING", "description": "Departure city"},
46
+ "dest_loc": {"type": "STRING", "description": "Destination city"}
47
+ },
48
+ "required": ["start_loc", "dest_loc"]
49
+ }
50
+ )
51
+ ]
52
+ )
53
+
54
+ try:
55
+ yield "πŸ” Thinking...", ""
56
+
57
+ config = types.GenerateContentConfig(
58
+ tools=[train_tool, types.Tool(google_search=types.GoogleSearch())],
59
+ temperature=0.3
60
+ )
61
+
62
+ chat = client.chats.create(model=model_id, config=config)
63
+ response = await chat.send_message_async(input_text)
64
+
65
+ # --- Manual Tool Loop ---
66
+ max_turns = 5
67
+ for _ in range(max_turns):
68
+ if not response.candidates[0].content.parts:
69
+ break
70
+
71
+ tool_calls = [p.tool_call for p in response.candidates[0].content.parts if p.tool_call]
72
+ if not tool_calls:
73
+ break
74
+
75
+ tool_responses = []
76
+ for call in tool_calls:
77
+ if call.name == "get_train_connections":
78
+ yield f"πŸš„ Fetching train data for **{call.args['start_loc']}** to **{call.args['dest_loc']}**...", ""
79
+
80
+ # Execute the MCP call
81
+ train_data = await call_mcp_tool(call.args["start_loc"], call.args["dest_loc"])
82
+
83
+ tool_responses.append(
84
+ types.Part.from_function_response(
85
+ name=call.name,
86
+ response={"result": train_data}
87
+ )
88
+ )
89
+
90
+ if tool_responses:
91
+ yield "πŸ“ Processing train data...", ""
92
+ response = await chat.send_message_async(tool_responses)
93
+ else:
94
+ break
95
+
96
+ # Final output
97
+ yield response.text, ""
98
+
99
+ except Exception as e:
100
+ yield f"### Logic Error\n{str(e)}", ""
101
+
102
+ # --- GRADIO UI ---
103
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
104
+ gr.Markdown("# πŸš„ Gemini 3 Flash + Live DB Trains")
105
+ gr.Markdown("I search the web and check live train connections via an MCP Server.")
106
+
107
+ with gr.Column():
108
+ output_textbox = gr.Markdown(label="Response")
109
+ input_textbox = gr.Textbox(
110
+ lines=3,
111
+ label="Ask me anything",
112
+ placeholder="e.g. 'Is there a train from Berlin to Munich tonight?'"
113
+ )
114
+ submit_button = gr.Button("Send", variant="primary")
115
+
116
+ # Gradio handles async generators automatically
117
+ submit_button.click(
118
+ fn=generate,
119
+ inputs=input_textbox,
120
+ outputs=[output_textbox, input_textbox]
121
+ )
122
+
123
+ if __name__ == '__main__':
124
+ demo.launch()
125
+
126
+
127
+
128
+
129
+ """
130
  import base64
131
  import gradio as gr
132
  import os
 
191
  submit_button = gr.Button("send")
192
  submit_button.click(fn=generate,inputs=input_textbox,outputs=[output_textbox, input_textbox])
193
  demo.launch(show_error=True)
194
+ """