mgokg commited on
Commit
34c40ee
·
verified ·
1 Parent(s): 6942776

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -184
app.py CHANGED
@@ -1,187 +1,3 @@
1
- import os
2
- import asyncio
3
- import json
4
- import gradio as gr
5
-
6
- from google import genai
7
- from google.genai import types
8
-
9
- from mcp import ClientSession
10
- from mcp.client.sse import sse_client
11
-
12
- # --- CONFIGURATION ---
13
- MCP_SERVER_URL = "https://mgokg-db-timetable-api.hf.space/gradio_api/mcp/"
14
- GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
15
-
16
-
17
- async def call_mcp_tool(start_loc: str, dest_loc: str) -> str:
18
- """Connect to the MCP server to fetch train data."""
19
- async with sse_client(MCP_SERVER_URL) as (read_stream, write_stream):
20
- async with ClientSession(read_stream, write_stream) as session:
21
- await session.initialize()
22
- result = await session.call_tool(
23
- "get_train_connections",
24
- arguments={"start_loc": start_loc, "dest_loc": dest_loc}
25
- )
26
-
27
- # Defensive handling of potential result formats
28
- try:
29
- if hasattr(result, "content") and result.content:
30
- first = result.content[0]
31
- if hasattr(first, "text") and first.text:
32
- return first.text
33
- # Fallback: serialize unknown structures
34
- return json.dumps(
35
- [getattr(c, "__dict__", str(c)) for c in result.content],
36
- ensure_ascii=False
37
- )
38
- except Exception:
39
- pass
40
-
41
- return "No train data returned from MCP."
42
-
43
-
44
- async def generate(input_text: str):
45
- """
46
- Async generator handling the tool loop with the corrected aio client.
47
- Yields (markdown_message, input_box_echo) tuples for Gradio.
48
- """
49
- if not GEMINI_API_KEY:
50
- yield "### Error\n`GEMINI_API_KEY` is not set.", ""
51
- return
52
-
53
- client = genai.Client(api_key=GEMINI_API_KEY)
54
- model_id = "gemini-2.5-flash" # Ensure this is enabled for your API key
55
-
56
- # Declare the function/tool to the model
57
- train_tool = types.Tool(
58
- function_declarations=[
59
- types.FunctionDeclaration(
60
- name="get_train_connections",
61
- description="Finds live train connections between two German stations.",
62
- parameters={
63
- "type": "OBJECT",
64
- "properties": {
65
- "start_loc": {"type": "STRING", "description": "Departure city"},
66
- "dest_loc": {"type": "STRING", "description": "Destination city"},
67
- },
68
- "required": ["start_loc", "dest_loc"],
69
- },
70
- )
71
- ]
72
- )
73
-
74
- try:
75
- yield "🔍 Thinking...", ""
76
-
77
- config = types.GenerateContentConfig(
78
- tools=[train_tool, types.Tool(google_search=types.GoogleSearch())],
79
- temperature=0.3,
80
- )
81
-
82
- # Create async chat session (DO NOT await here)
83
- chat = client.aio.chats.create(model=model_id, config=config)
84
-
85
- # Send the user input to the model (await is correct here)
86
- response = chat.send_message(input_text)
87
-
88
- # --- Manual Tool Loop ---
89
- max_turns = 5
90
- for _ in range(max_turns):
91
- # Guards
92
- if not getattr(response, "candidates", None):
93
- break
94
-
95
- candidate = response.candidates[0]
96
- parts = getattr(candidate.content, "parts", []) or []
97
- if not parts:
98
- break
99
-
100
- # Collect function calls from parts (correct attribute is function_call)
101
- tool_calls = []
102
- for p in parts:
103
- call = getattr(p, "function_call", None)
104
- if call and hasattr(call, "name") and hasattr(call, "args"):
105
- tool_calls.append(call)
106
-
107
- # If no tools were requested, the model is done
108
- if not tool_calls:
109
- break
110
-
111
- tool_responses = []
112
-
113
- for call in tool_calls:
114
- if call.name == "get_train_connections":
115
- start_loc = call.args.get("start_loc", "").strip()
116
- dest_loc = call.args.get("dest_loc", "").strip()
117
- yield f"🚄 Fetching train data: **{start_loc}** → **{dest_loc}**...", ""
118
-
119
- # Call the MCP tool
120
- train_data = await call_mcp_tool(start_loc, dest_loc)
121
-
122
- # Return the tool function result to the model
123
- tool_responses.append(
124
- types.Part.from_function_response(
125
- name=call.name,
126
- response={"result": train_data},
127
- )
128
- )
129
- else:
130
- # Unknown tool – inform the model
131
- tool_responses.append(
132
- types.Part.from_function_response(
133
- name=call.name,
134
- response={
135
- "error": f"Unknown tool '{call.name}'. Only 'get_train_connections' is supported."
136
- },
137
- )
138
- )
139
-
140
- if tool_responses:
141
- yield "📝 Finalizing response...", ""
142
- # Provide tool results back to the model so it can compose the final answer
143
- response = chat.send_message(tool_responses)
144
- else:
145
- break
146
-
147
- # Prefer response.text; fallback to concatenating text parts
148
- final_text = getattr(response, "text", "") or ""
149
- if not final_text and getattr(response, "candidates", None):
150
- parts = getattr(response.candidates[0].content, "parts", [])
151
- texts = []
152
- for p in parts:
153
- if hasattr(p, "text") and p.text:
154
- texts.append(p.text)
155
- final_text = "\n".join(texts)
156
-
157
- if not final_text:
158
- final_text = "I couldn't generate a response. Please try rephrasing your question."
159
-
160
- yield final_text, ""
161
-
162
- except Exception as e:
163
- # Show the actual logic/exception error back in the UI
164
- yield f"### Logic Error\n{str(e)}", ""
165
-
166
-
167
- # --- GRADIO UI ---
168
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
169
- gr.Markdown("# 🚄 Gemini MCP Train Assistant")
170
-
171
- output_textbox = gr.Markdown(label="Response")
172
- input_textbox = gr.Textbox(label="Ask about German train connections", placeholder="Berlin to Munich?")
173
- submit_button = gr.Button("Send", variant="primary")
174
-
175
- submit_button.click(
176
- fn=generate,
177
- inputs=input_textbox,
178
- outputs=[output_textbox, input_textbox]
179
- )
180
-
181
- if __name__ == "__main__":
182
- demo.launch()
183
-
184
- """
185
  import base64
186
  import gradio as gr
187
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import base64
2
  import gradio as gr
3
  import os