mgokg commited on
Commit
58a714a
·
verified ·
1 Parent(s): 17adbac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -72
app.py CHANGED
@@ -7,106 +7,96 @@ from google.genai import types
7
  from mcp import ClientSession
8
  from mcp.client.sse import sse_client
9
 
10
- # --- CONFIGURATION ---
11
  MCP_SERVER_URL = "https://mgokg-db-api-mcp.hf.space/sse"
12
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
13
 
14
- # 1. Function to connect to your MCP Server and run the tool
15
  async def call_mcp_tool(start_loc, dest_loc):
 
16
  async with sse_client(MCP_SERVER_URL) as (read_stream, write_stream):
17
  async with ClientSession(read_stream, write_stream) as session:
18
  await session.initialize()
19
- # Calling the tool defined in your MCP server
20
  result = await session.call_tool("get_train_connections", arguments={
21
  "start_loc": start_loc,
22
  "dest_loc": dest_loc
23
  })
24
- # result.content[0].text contains the JSON train data
25
  return result.content[0].text
26
 
27
- # 2. Define the tool for Gemini
28
- # This describes the tool so Gemini knows when to use it
29
- train_tool = types.Tool(
30
- function_declarations=[
31
- types.FunctionDeclaration(
32
- name="get_train_connections",
33
- description="Finds live train connections between two German stations/cities.",
34
- parameters=types.Schema(
35
- type="OBJECT",
36
- properties={
37
- "start_loc": types.Schema(type="STRING", description="Departure city/station"),
38
- "dest_loc": types.Schema(type="STRING", description="Destination city/station")
39
- },
40
- required=["start_loc", "dest_loc"]
41
- )
42
- )
43
- ]
44
- )
45
-
46
  def generate(input_text):
47
  if not GEMINI_API_KEY:
48
  return "Error: GEMINI_API_KEY is not set.", ""
49
 
50
  client = genai.Client(api_key=GEMINI_API_KEY)
51
- model_id = "gemini-flash-latest" # Recommended for tool use
 
 
 
52
 
53
- # Create a chat session to handle the tool-use loop
54
- chat = client.chats.create(
55
- model=model_id,
56
- config=types.GenerateContentConfig(
57
- tools=[train_tool, types.Tool(google_search=types.GoogleSearch())],
58
- temperature=0.3
59
- )
 
 
 
 
 
 
 
 
 
60
  )
61
 
62
  try:
63
- # Step 1: Send user input to Gemini
 
 
 
 
 
 
 
 
64
  response = chat.send_message(input_text)
65
-
66
- # Step 2: Tool Loop
67
- # If Gemini emits a tool call, we execute it and send results back
68
- while response.candidates[0].content.parts[0].tool_call:
69
- tool_call = response.candidates[0].content.parts[0].tool_call
70
-
71
- if tool_call.name == "get_train_connections":
72
- args = tool_call.args
73
- # Execute the actual MCP request
74
- train_data = asyncio.run(call_mcp_tool(args["start_loc"], args["dest_loc"]))
75
 
76
- # Feed the train data back to Gemini
77
- response = chat.send_message(
78
- types.Part.from_function_response(
79
- name="get_train_connections",
80
- response={"result": train_data}
 
 
 
 
 
 
 
 
 
 
81
  )
82
- )
 
 
 
 
 
83
  else:
84
- break # Exit if it's a tool we don't recognize here
85
 
86
  return response.text, ""
87
 
88
  except Exception as e:
89
- return f"Error: {str(e)}", ""
90
-
91
- # --- GRADIO UI ---
92
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
93
- gr.Markdown("# 🚄 Gemini 2.0 Flash + Live DB Trains")
94
- gr.Markdown("I can search the web and check live train connections via an MCP Server.")
95
-
96
- with gr.Column():
97
- output_textbox = gr.Markdown(label="Response")
98
- input_textbox = gr.Textbox(
99
- lines=3,
100
- label="Ask me anything",
101
- placeholder="e.g. 'Is there a train from Berlin to Munich tonight?'"
102
- )
103
- submit_button = gr.Button("Send", variant="primary")
104
-
105
- submit_button.click(
106
- fn=generate,
107
- inputs=input_textbox,
108
- outputs=[output_textbox, input_textbox]
109
- )
110
 
111
- if __name__ == '__main__':
112
- demo.launch()
 
7
  from mcp import ClientSession
8
  from mcp.client.sse import sse_client
9
 
 
10
  MCP_SERVER_URL = "https://mgokg-db-api-mcp.hf.space/sse"
11
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
12
 
 
13
  async def call_mcp_tool(start_loc, dest_loc):
14
+ """Executes the train search via your MCP server."""
15
  async with sse_client(MCP_SERVER_URL) as (read_stream, write_stream):
16
  async with ClientSession(read_stream, write_stream) as session:
17
  await session.initialize()
 
18
  result = await session.call_tool("get_train_connections", arguments={
19
  "start_loc": start_loc,
20
  "dest_loc": dest_loc
21
  })
 
22
  return result.content[0].text
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def generate(input_text):
25
  if not GEMINI_API_KEY:
26
  return "Error: GEMINI_API_KEY is not set.", ""
27
 
28
  client = genai.Client(api_key=GEMINI_API_KEY)
29
+
30
+ # Use 2.0 Flash (Tools work here).
31
+ # Do NOT use 'gemini-2.0-flash-thinking-exp' as it doesn't support tools yet.
32
+ model_id = "gemini-2.0-flash-exp"
33
 
34
+ # Define the tool explicitly
35
+ train_tool = types.Tool(
36
+ function_declarations=[
37
+ types.FunctionDeclaration(
38
+ name="get_train_connections",
39
+ description="Finds live train connections between two German stations/cities.",
40
+ parameters={
41
+ "type": "OBJECT",
42
+ "properties": {
43
+ "start_loc": {"type": "STRING", "description": "Departure city"},
44
+ "dest_loc": {"type": "STRING", "description": "Destination city"}
45
+ },
46
+ "required": ["start_loc", "dest_loc"]
47
+ }
48
+ )
49
+ ]
50
  )
51
 
52
  try:
53
+ # 1. Create the config WITHOUT thinking_config
54
+ config = types.GenerateContentConfig(
55
+ tools=[train_tool, types.Tool(google_search=types.GoogleSearch())],
56
+ temperature=0.3,
57
+ # thinking_config REMOVED - this is the key fix
58
+ )
59
+
60
+ # 2. Start the chat session
61
+ chat = client.chats.create(model=model_id, config=config)
62
  response = chat.send_message(input_text)
63
+
64
+ # 3. Handle the Tool-Use Loop manually
65
+ # This handles both your MCP tool and Google Search
66
+ max_iterations = 5
67
+ for _ in range(max_iterations):
68
+ # Check if there is a tool call in the first part of the message
69
+ if not response.candidates[0].content.parts:
70
+ break
 
 
71
 
72
+ tool_calls = [p.tool_call for p in response.candidates[0].content.parts if p.tool_call]
73
+
74
+ if not tool_calls:
75
+ break
76
+
77
+ tool_responses = []
78
+ for call in tool_calls:
79
+ if call.name == "get_train_connections":
80
+ # Run the MCP call
81
+ result_data = asyncio.run(call_mcp_tool(call.args["start_loc"], call.args["dest_loc"]))
82
+ tool_responses.append(
83
+ types.Part.from_function_response(
84
+ name=call.name,
85
+ response={"result": result_data}
86
+ )
87
  )
88
+ # Google Search is handled automatically by the model if configured,
89
+ # but if it returns a call, we let the loop handle the logic.
90
+
91
+ # Send all tool results back to the model
92
+ if tool_responses:
93
+ response = chat.send_message(tool_responses)
94
  else:
95
+ break
96
 
97
  return response.text, ""
98
 
99
  except Exception as e:
100
+ return f"### Error encountered\n{str(e)}", ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+ # (Keep your existing Gradio Blocks code below)