mgokg committed on
Commit
83c49dc
·
verified ·
1 Parent(s): 2517ef7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -67
app.py CHANGED
@@ -1,39 +1,12 @@
1
- import os
2
- import json
3
- import asyncio
4
- import gradio as gr
5
- from google import genai
6
- from google.genai import types
7
- from mcp import ClientSession
8
- from mcp.client.sse import sse_client
9
-
10
- # --- CONFIGURATION ---
11
- MCP_SERVER_URL = "https://mgokg-db-api-mcp.hf.space/sse"
12
- GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
13
-
14
async def call_mcp_tool(start_loc, dest_loc):
    """Fetch live train-connection data from the remote MCP server.

    Opens an SSE transport to ``MCP_SERVER_URL``, initializes an MCP client
    session, and invokes the ``get_train_connections`` tool for the given
    start and destination locations.

    Args:
        start_loc: Name of the departure location, passed through verbatim.
        dest_loc: Name of the destination location, passed through verbatim.

    Returns:
        The text of the first content item in the tool result.
    """
    async with sse_client(MCP_SERVER_URL) as (reader, writer):
        async with ClientSession(reader, writer) as mcp_session:
            await mcp_session.initialize()
            tool_args = {"start_loc": start_loc, "dest_loc": dest_loc}
            tool_result = await mcp_session.call_tool(
                "get_train_connections", arguments=tool_args
            )
            # The tool replies with a list of content parts; only the first
            # (text) part is used downstream.
            return tool_result.content[0].text
24
-
25
- def generate(input_text):
26
  if not GEMINI_API_KEY:
27
- return "Error: GEMINI_API_KEY is not set.", ""
 
28
 
29
  client = genai.Client(api_key=GEMINI_API_KEY)
30
-
31
- # CRITICAL: Use gemini-2.0-flash-exp.
32
- # Do NOT use gemini-2.0-flash-thinking-exp as it doesn't support tools.
33
- model_id = "gemini-3-flash-preview" # Korrekte Modell-ID verwenden
34
-
35
 
36
- # Define the train tool schema
37
  train_tool = types.Tool(
38
  function_declarations=[
39
  types.FunctionDeclaration(
@@ -52,24 +25,20 @@ def generate(input_text):
52
  )
53
 
54
  try:
55
- # 1. Setup config WITHOUT thinking_config
56
  config = types.GenerateContentConfig(
57
  tools=[train_tool, types.Tool(google_search=types.GoogleSearch())],
58
  temperature=0.3
59
  )
60
 
61
  chat = client.chats.create(model=model_id, config=config)
62
- response = chat.send_message(input_text)
 
63
 
64
- # 2. Manual Tool Loop
65
- # We must manually process tool calls because we are calling an external MCP server
66
  max_turns = 5
67
  for _ in range(max_turns):
68
- # Exit if there are no parts or no tool calls
69
  if not response.candidates[0].content.parts:
70
  break
71
 
72
- # Find all tool calls in the message parts
73
  tool_calls = [p.tool_call for p in response.candidates[0].content.parts if p.tool_call]
74
  if not tool_calls:
75
  break
@@ -77,9 +46,8 @@ def generate(input_text):
77
  tool_responses = []
78
  for call in tool_calls:
79
  if call.name == "get_train_connections":
80
- # Execute the MCP call
81
- # We use asyncio.run because Gradio's click handler is synchronous
82
- train_data = asyncio.run(call_mcp_tool(call.args["start_loc"], call.args["dest_loc"]))
83
 
84
  tool_responses.append(
85
  types.Part.from_function_response(
@@ -88,36 +56,12 @@ def generate(input_text):
88
  )
89
  )
90
 
91
- # If we have results, send them back to the model
92
  if tool_responses:
93
- response = chat.send_message(tool_responses)
94
  else:
95
  break
96
 
97
  return response.text, ""
98
 
99
  except Exception as e:
100
- return f"### Logic Error\n{str(e)}", ""
101
-
102
# --- GRADIO UI ---
# Wire up a minimal single-column chat-style interface: one markdown output
# pane, one multi-line question box, and a submit button bound to generate().
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚄 Gemini 2.0 Flash + Live DB Trains")
    gr.Markdown("I can search the web and check live train connections via an MCP Server.")

    with gr.Column():
        response_pane = gr.Markdown(label="Response")
        question_box = gr.Textbox(
            lines=3,
            label="Ask me anything",
            placeholder="e.g. 'Is there a train from Berlin to Munich tonight?'",
        )
        send_button = gr.Button("Send", variant="primary")

    # generate() returns a pair: the answer markdown and an empty string
    # that clears the input box after each submission.
    send_button.click(
        fn=generate,
        inputs=question_box,
        outputs=[response_pane, question_box],
    )

if __name__ == '__main__':
    demo.launch()
 
1
+ # Change the function to be async
2
+ async def generate(input_text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  if not GEMINI_API_KEY:
4
+ yield "Error: GEMINI_API_KEY is not set."
5
+ return
6
 
7
  client = genai.Client(api_key=GEMINI_API_KEY)
8
+ model_id = "gemini-3-flash-preview"
 
 
 
 
9
 
 
10
  train_tool = types.Tool(
11
  function_declarations=[
12
  types.FunctionDeclaration(
 
25
  )
26
 
27
  try:
 
28
  config = types.GenerateContentConfig(
29
  tools=[train_tool, types.Tool(google_search=types.GoogleSearch())],
30
  temperature=0.3
31
  )
32
 
33
  chat = client.chats.create(model=model_id, config=config)
34
+ # Use await for the async call
35
+ response = await chat.send_message_async(input_text)
36
 
 
 
37
  max_turns = 5
38
  for _ in range(max_turns):
 
39
  if not response.candidates[0].content.parts:
40
  break
41
 
 
42
  tool_calls = [p.tool_call for p in response.candidates[0].content.parts if p.tool_call]
43
  if not tool_calls:
44
  break
 
46
  tool_responses = []
47
  for call in tool_calls:
48
  if call.name == "get_train_connections":
49
+ # Now we can simply await the tool call directly
50
+ train_data = await call_mcp_tool(call.args["start_loc"], call.args["dest_loc"])
 
51
 
52
  tool_responses.append(
53
  types.Part.from_function_response(
 
56
  )
57
  )
58
 
 
59
  if tool_responses:
60
+ response = await chat.send_message_async(tool_responses)
61
  else:
62
  break
63
 
64
  return response.text, ""
65
 
66
  except Exception as e:
67
+ return f"### Logic Error\n{str(e)}", ""