mgokg commited on
Commit
186a462
·
verified ·
1 Parent(s): 7dd6dde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -70
app.py CHANGED
@@ -1,84 +1,112 @@
1
- import base64
2
- import gradio as gr
3
  import os
4
  import json
 
 
5
  from google import genai
6
  from google.genai import types
7
- from gradio_client import Client
 
8
 
9
- def clean_json_string(json_str):
10
- """
11
- Removes any comments or prefixes before the actual JSON content.
12
- """
13
- # Find the first occurrence of '{'
14
- json_start = json_str.find('{')
15
- if json_start == -1:
16
- # If no '{' is found, try with '[' for arrays
17
- json_start = json_str.find('[')
18
- if json_start == -1:
19
- return json_str # Return original if no JSON markers found
20
-
21
- # Extract everything from the first JSON marker
22
- cleaned_str = json_str[json_start:]
23
- return cleaned_str
24
- # Verify it's valid JSON
25
- try:
26
- json.loads(cleaned_str)
27
- return cleaned_str
28
- except json.JSONDecodeError:
29
- return json_str # Return original if cleaning results in invalid JSON
30
 
31
- def generate(input_text):
32
- try:
33
- client = genai.Client(
34
- api_key=os.environ.get("GEMINI_API_KEY"),
35
- )
36
- except Exception as e:
37
- return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set."
 
 
 
 
 
38
 
39
- model = "gemini-flash-latest"
40
- contents = [
41
- types.Content(
42
- role="user",
43
- parts=[
44
- types.Part.from_text(text=input_text),
45
- ],
46
- ),
47
- ]
48
- tools = [
49
- types.Tool(google_search=types.GoogleSearch()),
 
 
 
 
 
50
  ]
51
- generate_content_config = types.GenerateContentConfig(
52
- temperature=0.4,
53
- thinking_config = types.ThinkingConfig(
54
- thinking_budget=0,
55
- ),
56
- tools=tools,
57
- response_mime_type="text/plain",
58
- )
59
 
 
 
 
 
 
 
 
 
60
 
61
- response_text = ""
62
  try:
63
- for chunk in client.models.generate_content_stream(
64
- model=model,
65
- contents=contents,
66
- config=generate_content_config,
67
- ):
68
- response_text += chunk.text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  except Exception as e:
70
- return f"Error during generation: {e}"
71
- data = clean_json_string(response_text)
72
- data = data[:-1]
73
- return response_text, ""
74
-
75
 
76
- if __name__ == '__main__':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- with gr.Blocks() as demo:
79
- title=gr.Markdown("# Gemini 2.0 Flash + Websearch")
80
- output_textbox = gr.Markdown()
81
- input_textbox = gr.Textbox(lines=3, label="", placeholder="Enter message here...")
82
- submit_button = gr.Button("send")
83
- submit_button.click(fn=generate,inputs=input_textbox,outputs=[output_textbox, input_textbox])
84
- demo.launch(show_error=True)
 
 
 
1
  import os
2
  import json
3
+ import asyncio
4
+ import gradio as gr
5
  from google import genai
6
  from google.genai import types
7
+ from mcp import ClientSession
8
+ from mcp.client.sse import sse_client
9
 
10
+ # --- CONFIGURATION ---
11
+ MCP_SERVER_URL = "https://mgokg-db-api-mcp.hf.space/sse"
12
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # 1. Function to connect to your MCP Server and run the tool
15
+ async def call_mcp_tool(start_loc, dest_loc):
16
+ async with sse_client(MCP_SERVER_URL) as (read_stream, write_stream):
17
+ async with ClientSession(read_stream, write_stream) as session:
18
+ await session.initialize()
19
+ # Calling the tool defined in your MCP server
20
+ result = await session.call_tool("get_train_connections", arguments={
21
+ "start_loc": start_loc,
22
+ "dest_loc": dest_loc
23
+ })
24
+ # result.content[0].text contains the JSON train data
25
+ return result.content[0].text
26
 
27
+ # 2. Define the tool for Gemini
28
+ # This describes the tool so Gemini knows when to use it
29
+ train_tool = types.Tool(
30
+ function_declarations=[
31
+ types.FunctionDeclaration(
32
+ name="get_train_connections",
33
+ description="Finds live train connections between two German stations/cities.",
34
+ parameters=types.Schema(
35
+ type="OBJECT",
36
+ properties={
37
+ "start_loc": types.Schema(type="STRING", description="Departure city/station"),
38
+ "dest_loc": types.Schema(type="STRING", description="Destination city/station")
39
+ },
40
+ required=["start_loc", "dest_loc"]
41
+ )
42
+ )
43
  ]
44
+ )
45
+
46
+ def generate(input_text):
47
+ if not GEMINI_API_KEY:
48
+ return "Error: GEMINI_API_KEY is not set.", ""
49
+
50
+ client = genai.Client(api_key=GEMINI_API_KEY)
51
+ model_id = "gemini-2.0-flash-exp" # Recommended for tool use
52
 
53
+ # Create a chat session to handle the tool-use loop
54
+ chat = client.chats.create(
55
+ model=model_id,
56
+ config=types.GenerateContentConfig(
57
+ tools=[train_tool, types.Tool(google_search=types.GoogleSearch())],
58
+ temperature=0.3
59
+ )
60
+ )
61
 
 
62
  try:
63
+ # Step 1: Send user input to Gemini
64
+ response = chat.send_message(input_text)
65
+
66
+ # Step 2: Tool Loop
67
+ # If Gemini emits a tool call, we execute it and send results back
68
+ while response.candidates[0].content.parts[0].tool_call:
69
+ tool_call = response.candidates[0].content.parts[0].tool_call
70
+
71
+ if tool_call.name == "get_train_connections":
72
+ args = tool_call.args
73
+ # Execute the actual MCP request
74
+ train_data = asyncio.run(call_mcp_tool(args["start_loc"], args["dest_loc"]))
75
+
76
+ # Feed the train data back to Gemini
77
+ response = chat.send_message(
78
+ types.Part.from_function_response(
79
+ name="get_train_connections",
80
+ response={"result": train_data}
81
+ )
82
+ )
83
+ else:
84
+ break # Exit if it's a tool we don't recognize here
85
+
86
+ return response.text, ""
87
+
88
  except Exception as e:
89
+ return f"Error: {str(e)}", ""
 
 
 
 
90
 
91
+ # --- GRADIO UI ---
92
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
93
+ gr.Markdown("# 🚄 Gemini 2.0 Flash + Live DB Trains")
94
+ gr.Markdown("I can search the web and check live train connections via an MCP Server.")
95
+
96
+ with gr.Column():
97
+ output_textbox = gr.Markdown(label="Response")
98
+ input_textbox = gr.Textbox(
99
+ lines=3,
100
+ label="Ask me anything",
101
+ placeholder="e.g. 'Is there a train from Berlin to Munich tonight?'"
102
+ )
103
+ submit_button = gr.Button("Send", variant="primary")
104
+
105
+ submit_button.click(
106
+ fn=generate,
107
+ inputs=input_textbox,
108
+ outputs=[output_textbox, input_textbox]
109
+ )
110
 
111
+ if __name__ == '__main__':
112
+ demo.launch()