mgokg committed on
Commit
839116a
·
verified ·
1 Parent(s): 58a714a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -25
app.py CHANGED
@@ -7,11 +7,12 @@ from google.genai import types
7
  from mcp import ClientSession
8
  from mcp.client.sse import sse_client
9
 
 
10
  MCP_SERVER_URL = "https://mgokg-db-api-mcp.hf.space/sse"
11
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
12
 
13
  async def call_mcp_tool(start_loc, dest_loc):
14
- """Executes the train search via your MCP server."""
15
  async with sse_client(MCP_SERVER_URL) as (read_stream, write_stream):
16
  async with ClientSession(read_stream, write_stream) as session:
17
  await session.initialize()
@@ -27,16 +28,16 @@ def generate(input_text):
27
 
28
  client = genai.Client(api_key=GEMINI_API_KEY)
29
 
30
- # Use 2.0 Flash (Tools work here).
31
- # Do NOT use 'gemini-2.0-flash-thinking-exp' as it doesn't support tools yet.
32
  model_id = "gemini-2.0-flash-exp"
33
 
34
- # Define the tool explicitly
35
  train_tool = types.Tool(
36
  function_declarations=[
37
  types.FunctionDeclaration(
38
  name="get_train_connections",
39
- description="Finds live train connections between two German stations/cities.",
40
  parameters={
41
  "type": "OBJECT",
42
  "properties": {
@@ -50,45 +51,43 @@ def generate(input_text):
50
  )
51
 
52
  try:
53
- # 1. Create the config WITHOUT thinking_config
54
  config = types.GenerateContentConfig(
55
  tools=[train_tool, types.Tool(google_search=types.GoogleSearch())],
56
- temperature=0.3,
57
- # thinking_config REMOVED - this is the key fix
58
  )
59
 
60
- # 2. Start the chat session
61
  chat = client.chats.create(model=model_id, config=config)
62
  response = chat.send_message(input_text)
63
 
64
- # 3. Handle the Tool-Use Loop manually
65
- # This handles both your MCP tool and Google Search
66
- max_iterations = 5
67
- for _ in range(max_iterations):
68
- # Check if there is a tool call in the first part of the message
69
  if not response.candidates[0].content.parts:
70
  break
71
-
72
- tool_calls = [p.tool_call for p in response.candidates[0].content.parts if p.tool_call]
73
 
 
 
74
  if not tool_calls:
75
  break
76
 
77
  tool_responses = []
78
  for call in tool_calls:
79
  if call.name == "get_train_connections":
80
- # Run the MCP call
81
- result_data = asyncio.run(call_mcp_tool(call.args["start_loc"], call.args["dest_loc"]))
 
 
82
  tool_responses.append(
83
  types.Part.from_function_response(
84
  name=call.name,
85
- response={"result": result_data}
86
  )
87
  )
88
- # Google Search is handled automatically by the model if configured,
89
- # but if it returns a call, we let the loop handle the logic.
90
-
91
- # Send all tool results back to the model
92
  if tool_responses:
93
  response = chat.send_message(tool_responses)
94
  else:
@@ -97,6 +96,6 @@ def generate(input_text):
97
  return response.text, ""
98
 
99
  except Exception as e:
100
- return f"### Error encountered\n{str(e)}", ""
101
 
102
- # (Keep your existing Gradio Blocks code below)
 
7
  from mcp import ClientSession
8
  from mcp.client.sse import sse_client
9
 
10
+ # --- CONFIGURATION ---
11
  MCP_SERVER_URL = "https://mgokg-db-api-mcp.hf.space/sse"
12
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
13
 
14
  async def call_mcp_tool(start_loc, dest_loc):
15
+ """Connects to the MCP server to fetch train data."""
16
  async with sse_client(MCP_SERVER_URL) as (read_stream, write_stream):
17
  async with ClientSession(read_stream, write_stream) as session:
18
  await session.initialize()
 
28
 
29
  client = genai.Client(api_key=GEMINI_API_KEY)
30
 
31
+ # CRITICAL: Use gemini-2.0-flash-exp.
32
+ # Do NOT use gemini-2.0-flash-thinking-exp as it doesn't support tools.
33
  model_id = "gemini-2.0-flash-exp"
34
 
35
+ # Define the train tool schema
36
  train_tool = types.Tool(
37
  function_declarations=[
38
  types.FunctionDeclaration(
39
  name="get_train_connections",
40
+ description="Finds live train connections between two German stations.",
41
  parameters={
42
  "type": "OBJECT",
43
  "properties": {
 
51
  )
52
 
53
  try:
54
+ # 1. Setup config WITHOUT thinking_config
55
  config = types.GenerateContentConfig(
56
  tools=[train_tool, types.Tool(google_search=types.GoogleSearch())],
57
+ temperature=0.3
 
58
  )
59
 
 
60
  chat = client.chats.create(model=model_id, config=config)
61
  response = chat.send_message(input_text)
62
 
63
+ # 2. Manual Tool Loop
64
+ # We must manually process tool calls because we are calling an external MCP server
65
+ max_turns = 5
66
+ for _ in range(max_turns):
67
+ # Exit if there are no parts or no tool calls
68
  if not response.candidates[0].content.parts:
69
  break
 
 
70
 
71
+ # Find all tool calls in the message parts
72
+ tool_calls = [p.tool_call for p in response.candidates[0].content.parts if p.tool_call]
73
  if not tool_calls:
74
  break
75
 
76
  tool_responses = []
77
  for call in tool_calls:
78
  if call.name == "get_train_connections":
79
+ # Execute the MCP call
80
+ # We use asyncio.run because Gradio's click handler is synchronous
81
+ train_data = asyncio.run(call_mcp_tool(call.args["start_loc"], call.args["dest_loc"]))
82
+
83
  tool_responses.append(
84
  types.Part.from_function_response(
85
  name=call.name,
86
+ response={"result": train_data}
87
  )
88
  )
89
+
90
+ # If we have results, send them back to the model
 
 
91
  if tool_responses:
92
  response = chat.send_message(tool_responses)
93
  else:
 
96
  return response.text, ""
97
 
98
  except Exception as e:
99
+ return f"### Logic Error\n{str(e)}", ""
100
 
101
+ # (The rest of your Gradio code remains the same)