mgokg commited on
Commit
b735e9b
·
verified ·
1 Parent(s): d7c1e56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -155
app.py CHANGED
@@ -1,166 +1,66 @@
1
- import base64
2
- import gradio as gr
3
- import os
4
- import json
5
- from google import genai
6
- from google.genai import types
7
- from gradio_client import Client
8
-
9
-
10
- route="""
11
- how to handle special case "zugverbindung".
12
- Wichtig: Dies Regeln gelten nur wenn eine zugverbindung angefragt wird, else answer prompt
13
- Regeln:
14
- Wenn eine Zugverbindung von {Startort} nach {Zielort} angefragt wird, return json object with Startort and Zielort.
15
- always follow json scheme below.
16
-
17
- Wichtig: Gib absolut keinen Text vor oder nach dem JSON aus (keine Erklärungen, kein "Hier ist das Ergebnis").
18
-
19
- {
20
- "start_loc": "fill in Startort here",
21
- "dest_loc": "fill in Zielort here"
22
- }
23
-
24
- """
25
-
26
- def clean_json_string(json_str):
27
- """
28
- Removes any comments or prefixes before the actual JSON content.
29
- """
30
- # Find the first occurrence of '{'
31
- json_start = json_str.find('{')
32
- if json_start == -1:
33
- # If no '{' is found, try with '[' for arrays
34
- json_start = json_str.find('[')
35
- if json_start == -1:
36
- return json_str # Return original if no JSON markers found
37
-
38
- # Extract everything from the first JSON marker
39
- cleaned_str = json_str[json_start:]
40
- return cleaned_str
41
- # Verify it's valid JSON
42
- try:
43
- json.loads(cleaned_str)
44
- return cleaned_str
45
- except json.JSONDecodeError:
46
- return json_str # Return original if cleaning results in invalid JSON
47
-
48
- def generate(input_text):
49
- try:
50
- client = genai.Client(
51
- api_key=os.environ.get("GEMINI_API_KEY"),
52
- )
53
- except Exception as e:
54
- return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set."
55
-
56
- model = "gemini-flash-latest"
57
- contents = [
58
- types.Content(
59
- role="user",
60
- parts=[
61
- types.Part.from_text(text=f"{input_text}"),
62
- ],
63
- ),
64
- ]
65
- tools = [
66
- types.Tool(google_search=types.GoogleSearch()),
67
- ]
68
- generate_content_config = types.GenerateContentConfig(
69
- temperature=0.4,
70
- thinking_config = types.ThinkingConfig(
71
- thinking_budget=0,
72
- ),
73
- tools=tools,
74
- response_mime_type="text/plain",
75
- )
76
-
77
-
78
- response_text = ""
79
- try:
80
- for chunk in client.models.generate_content_stream(
81
- model=model,
82
- contents=contents,
83
- config=generate_content_config,
84
- ):
85
- response_text += chunk.text
86
- except Exception as e:
87
- return f"Error during generation: {e}"
88
- data = clean_json_string(response_text)
89
- data = data[:-1]
90
- return response_text, ""
91
-
92
-
93
- if __name__ == '__main__':
94
-
95
- with gr.Blocks() as demo:
96
- title=gr.Markdown("# Gemini 2.0 Flash + Websearch")
97
- output_textbox = gr.Markdown()
98
- input_textbox = gr.Textbox(lines=3, label="", placeholder="Enter message here...")
99
- submit_button = gr.Button("send")
100
- submit_button.click(fn=generate,inputs=input_textbox,outputs=[output_textbox, input_textbox])
101
- demo.launch(show_error=True)
102
-
103
-
104
-
105
-
106
-
107
- """
108
  import os
109
  import asyncio
110
- import gradio as gr
111
  from google import genai
112
  from google.genai import types
113
-
114
  from mcp import ClientSession
115
- from mcp.client.sse import sse_client # Spezifischer Transport für Gradio/HF
116
-
117
- async def generate(input_text):
118
- # WICHTIG: Gradio MCP Server benötigen oft das Suffix /sse
119
- mcp_url = "https://mgokg-db-timetable-api.hf.space/gradio_api/mcp/"
120
-
121
- try:
122
- client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
123
-
124
- # SSE Transport nutzen, um den 'text/html' Fehler zu vermeiden
125
- async with sse_client(url=mcp_url) as (read_stream, write_stream):
126
- async with ClientSession(read_stream, write_stream) as mcp_session:
127
- await mcp_session.initialize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
- model_id = "gemini-2.0-flash"
 
 
 
 
 
130
 
131
- generate_content_config = types.GenerateContentConfig(
132
- temperature=0.4,
133
- tools=[
134
- types.Tool(google_search=types.GoogleSearch()),
135
- mcp_session # Reicht die Tools des DB-Servers an Gemini durch
136
- ],
137
  )
138
-
139
- response_text = ""
140
- async for chunk in client.aio.models.generate_content_stream(
141
- model=model_id,
142
- contents=input_text,
143
- config=generate_content_config,
144
- ):
145
- if chunk.text:
146
- response_text += chunk.text
147
 
148
- return response_text, ""
 
149
 
150
- except Exception as e:
151
- return f"Verbindung zum DB-Fahrplan fehlgeschlagen: {str(e)}", ""
152
-
153
- def gradio_wrapper(input_text):
154
- return asyncio.run(generate(input_text))
 
155
 
156
- if __name__ == '__main__':
157
- with gr.Blocks() as demo:
158
- gr.Markdown("# Gemini Flash + DB Timetable")
159
- input_tx = gr.Textbox(label="Anfrage", placeholder="Wann fährt der nächste Zug von Berlin nach Hamburg?")
160
- btn = gr.Button("Senden")
161
- output_md = gr.Markdown()
162
-
163
- btn.click(fn=gradio_wrapper, inputs=input_tx, outputs=[output_md, input_tx])
164
-
165
- demo.launch()
166
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import asyncio
 
3
  from google import genai
4
  from google.genai import types
 
5
  from mcp import ClientSession
6
+ from mcp.client.sse import sse_client
7
+
8
# Configuration of the MCP server (example: DB timetable space on Hugging Face).
MCP_URL = "https://mgokg-db-timetable-api.hf.space/gradio_api/mcp/"

async def fetch_train_connections(prompt: str):
    """Answer *prompt* with Gemini, exposing the DB-timetable MCP tools.

    Opens an SSE connection to the MCP server, advertises its tools to the
    model as function declarations, executes at most one requested tool call,
    and returns the model's final text answer.

    Args:
        prompt: The user's natural-language request
            (e.g. a train-connection query).

    Returns:
        The model's text response (after the tool round-trip, if one occurred).

    Raises:
        Propagates connection/SDK errors from the MCP transport and the
        google-genai client; callers are expected to handle them.
    """
    client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
    model_id = "gemini-2.0-flash"

    # 1. Establish the SSE transport to the MCP server.
    async with sse_client(url=MCP_URL) as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:
            # Initialize the MCP session before any tool traffic.
            await session.initialize()

            # 2. Convert the MCP tool list into Gemini function declarations.
            mcp_tools = await session.list_tools()
            tools = types.Tool(function_declarations=[
                {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.inputSchema,
                }
                for tool in mcp_tools.tools
            ])

            # 3. First model turn: let Gemini decide whether to call a tool.
            contents = [types.Content(role="user", parts=[types.Part(text=prompt)])]
            response = await client.aio.models.generate_content(
                model=model_id,
                contents=contents,
                config=types.GenerateContentConfig(
                    tools=[tools],
                    temperature=0.4,
                ),
            )

            # BUG FIX: `candidates` and `parts` are lists in the google-genai
            # SDK. The original accessed `.content` / `.function_call` directly
            # on the lists, which raises AttributeError at runtime. Scan the
            # first candidate's parts for one that carries a function_call.
            candidate = response.candidates[0]
            function_call = next(
                (part.function_call
                 for part in (candidate.content.parts or [])
                 if part.function_call),
                None,
            )

            # 4. Single tool round-trip: run the requested tool on the MCP
            #    server, then feed its result back for the final answer.
            if function_call is not None:
                tool_result = await session.call_tool(
                    function_call.name, dict(function_call.args)
                )

                # BUG FIX: CallToolResult.content is a LIST of content items;
                # the original read `.text` off the list itself.
                tool_response_part = types.Part.from_function_response(
                    name=function_call.name,
                    response={"result": tool_result.content[0].text},
                )

                # Preserve the model's tool-call turn, then append the tool
                # result as the next user turn (per the SDK's manual
                # function-calling flow).
                contents.append(candidate.content)
                contents.append(
                    types.Content(role="user", parts=[tool_response_part])
                )

                final_response = await client.aio.models.generate_content(
                    model=model_id,
                    contents=contents,
                    config=types.GenerateContentConfig(tools=[tools]),
                )
                return final_response.text

            # No tool requested: the first response already holds the answer.
            return response.text