mgokg commited on
Commit
7dd6dde
·
verified ·
1 Parent(s): 04067cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -94
app.py CHANGED
@@ -1,108 +1,84 @@
 
 
1
  import os
2
  import json
3
- import asyncio
4
- import gradio as gr
5
  from google import genai
6
  from google.genai import types
7
- from mcp import ClientSession
8
- from mcp.client.sse import sse_client
9
-
10
- # --- CONFIGURATION ---
11
- MCP_SERVER_URL = "https://mgokg/DB_API_MCP.hf.space/sse"
12
- GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
13
 
14
- # 1. Helper to talk to your MCP Server
15
- async def call_mcp_train_tool(start_loc, dest_loc):
16
- async with sse_client(MCP_SERVER_URL) as (read_stream, write_stream):
17
- async with ClientSession(read_stream, write_stream) as session:
18
- await session.initialize()
19
- result = await session.call_tool("get_train_connections", arguments={
20
- "start_loc": start_loc,
21
- "dest_loc": dest_loc
22
- })
23
- # result.content[0].text contains the JSON list from your MCP server
24
- return result.content[0].text
25
-
26
- # 2. Tool Definition for Gemini
27
- # This tells Gemini what the tool does and what arguments it needs
28
- train_tool_declaration = types.Tool(
29
- function_declarations=[
30
- types.FunctionDeclaration(
31
- name="get_train_connections",
32
- description="Get live train connections between two German cities/stations.",
33
- parameters=types.Schema(
34
- type="OBJECT",
35
- properties={
36
- "start_loc": types.Schema(type="STRING", description="The departure station"),
37
- "dest_loc": types.Schema(type="STRING", description="The destination station")
38
- },
39
- required=["start_loc", "dest_loc"]
40
- )
41
- )
42
- ]
43
- )
44
 
45
  def generate(input_text):
46
- if not GEMINI_API_KEY:
47
- return "Error: GEMINI_API_KEY is not set.", ""
48
-
49
- client = genai.Client(api_key=GEMINI_API_KEY)
50
- model_id = "gemini-flash-latest" # Or gemini-flash-latest
51
-
52
- # We use a chat session to handle the multi-turn tool loop
53
- chat = client.chats.create(
54
- model=model_id,
55
- config=types.GenerateContentConfig(
56
- tools=[train_tool_declaration, types.Tool(google_search=types.GoogleSearch())],
57
- temperature=0.4
58
  )
59
- )
 
60
 
61
- try:
62
- # Step 1: Send initial request to Gemini
63
- response = chat.send_message(input_text)
64
-
65
- # Step 2: Check if Gemini wants to use a tool
66
- # We loop in case Gemini needs multiple tool calls
67
- while response.candidates[0].content.parts[0].tool_call:
68
- tool_call = response.candidates[0].content.parts[0].tool_call
69
-
70
- if tool_call.name == "get_train_connections":
71
- # Extract arguments Gemini provided
72
- args = tool_call.args
73
-
74
- # Execute the MCP call (running async code in sync Gradio)
75
- train_data = asyncio.run(call_mcp_train_tool(args["start_loc"], args["dest_loc"]))
76
-
77
- # Send the tool result back to Gemini
78
- response = chat.send_message(
79
- types.Part.from_function_response(
80
- name="get_train_connections",
81
- response={"result": train_data}
82
- )
83
- )
84
- else:
85
- break # Handle other tools or exit
86
 
87
- return response.text, ""
88
 
 
 
 
 
 
 
 
 
89
  except Exception as e:
90
- return f"Error during generation: {str(e)}", ""
91
-
92
- # --- GRADIO UI ---
93
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
94
- gr.Markdown("# 🚄 Gemini 2.0 + DB Train MCP")
95
- gr.Markdown("Ask about train connections (e.g., 'When is the next train from Munich to Berlin?')")
96
-
97
- output_textbox = gr.Markdown(label="Response")
98
- input_textbox = gr.Textbox(lines=2, label="Your Question")
99
- submit_button = gr.Button("Send", variant="primary")
100
-
101
- submit_button.click(
102
- fn=generate,
103
- inputs=input_textbox,
104
- outputs=[output_textbox, input_textbox]
105
- )
106
 
107
  if __name__ == '__main__':
108
- demo.launch()
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import gradio as gr
3
  import os
4
  import json
 
 
5
  from google import genai
6
  from google.genai import types
7
+ from gradio_client import Client
 
 
 
 
 
8
 
9
def clean_json_string(json_str):
    """
    Strip any leading non-JSON text (comments, chatter, markdown fences)
    that precedes the actual JSON payload in ``json_str``.

    Args:
        json_str: Raw model output that should contain a JSON object or array.

    Returns:
        The substring starting at the first '{' (or '[' if no '{' exists),
        but only if that substring parses as valid JSON; otherwise the
        original string is returned unchanged.
    """
    # Find the first occurrence of '{' (object start).
    json_start = json_str.find('{')
    if json_start == -1:
        # If no '{' is found, try with '[' for arrays.
        json_start = json_str.find('[')
        if json_start == -1:
            return json_str  # No JSON markers at all — return unchanged.

    # Extract everything from the first JSON marker onward.
    cleaned_str = json_str[json_start:]

    # Verify the extracted text is valid JSON before committing to it.
    # (In the original code this check was dead code behind an early return.)
    try:
        json.loads(cleaned_str)
        return cleaned_str
    except json.JSONDecodeError:
        return json_str  # Cleaning produced invalid JSON — keep the original.
 
 
 
 
 
 
 
 
 
30
 
31
def generate(input_text):
    """
    Send ``input_text`` to the Gemini model (with Google Search grounding
    enabled) and collect the streamed answer.

    Args:
        input_text: The user's question from the Gradio textbox.

    Returns:
        A 2-tuple ``(response_markdown, "")`` so the Gradio click handler
        can both display the answer and clear the input textbox. Error
        paths also return a 2-tuple to keep the output arity consistent
        (the original returned a bare string on errors, which mismatched
        the two wired Gradio outputs).
    """
    try:
        client = genai.Client(
            api_key=os.environ.get("GEMINI_API_KEY"),
        )
    except Exception as e:
        return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set.", ""

    model = "gemini-flash-latest"
    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_text(text=input_text),
            ],
        ),
    ]
    tools = [
        types.Tool(google_search=types.GoogleSearch()),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=0.4,
        # Disable "thinking" tokens to minimize latency/cost.
        thinking_config=types.ThinkingConfig(thinking_budget=0),
        tools=tools,
        response_mime_type="text/plain",
    )

    response_text = ""
    try:
        for chunk in client.models.generate_content_stream(
            model=model,
            contents=contents,
            config=generate_content_config,
        ):
            # Some stream chunks carry only metadata (e.g. grounding info);
            # chunk.text may be None there, so guard before concatenating.
            if chunk.text:
                response_text += chunk.text
    except Exception as e:
        return f"Error during generation: {e}", ""

    # NOTE(review): the original also ran clean_json_string(response_text)
    # and sliced off the last character, but never used the result — that
    # dead code has been removed; the raw streamed text is returned as-is.
    return response_text, ""
 
 
 
 
 
 
 
 
 
 
 
75
 
76
if __name__ == '__main__':

    # Minimal chat-style UI: a Markdown pane for the answer above a
    # free-text input and a submit button.
    with gr.Blocks() as demo:
        title = gr.Markdown("# Gemini 2.0 Flash + Websearch")
        output_textbox = gr.Markdown()
        input_textbox = gr.Textbox(
            lines=3,
            label="",
            placeholder="Enter message here...",
        )
        submit_button = gr.Button("send")
        # generate returns (answer, "") — the second value clears the input.
        submit_button.click(
            fn=generate,
            inputs=input_textbox,
            outputs=[output_textbox, input_textbox],
        )
        demo.launch(show_error=True)