mgokg commited on
Commit
3286ab8
·
verified ·
1 Parent(s): a5072c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -141
app.py CHANGED
@@ -1,177 +1,104 @@
 
1
  import gradio as gr
2
  import os
 
3
  from google import genai
4
  from google.genai import types
5
  from gradio_client import Client
6
 
7
- # --- 1. Robust Client Initialization ---
8
- # We initialize this globally but handle errors so the app doesn't crash on startup.
9
- db_client = None
10
 
11
- def init_db_client():
12
- global db_client
13
- try:
14
- # Use the DIRECT URL to avoid DNS resolution issues with the Hub API
15
- print("Connecting to DB Timetable API...")
16
- db_client = Client("https://mgokg-db-timetable-api.hf.space/")
17
- print("Successfully connected.")
18
- except Exception as e:
19
- print(f"Warning: Could not connect to DB Timetable API: {e}")
20
 
21
- # Attempt connection on startup
22
- init_db_client()
23
 
24
- # --- 2. Tool Definition ---
 
 
 
25
 
26
- def get_train_connection(dep: str, dest: str):
27
- """
28
- Fetches the train timetable between two cities using the external API.
29
- """
30
- global db_client
31
- # If client failed to load initially, try one more time
32
- if db_client is None:
33
- init_db_client()
34
- if db_client is None:
35
- return "Error: The train database is currently unreachable. Please check your network connection."
36
 
37
- try:
38
- # Calling the specific endpoint mentioned in the MCP docs
39
- result = db_client.predict(
40
- dep=dep,
41
- dest=dest,
42
- api_name="/db_timetable_api_ui_wrapper"
43
- )
44
- return result
45
- except Exception as e:
46
- return f"Error fetching timetable: {str(e)}"
47
 
48
- # Define the tool schema for Gemini
49
- train_tool = types.FunctionDeclaration(
50
- name="get_train_connection",
51
- description="Find train connections and timetables between a start location (dep) and a destination (dest).",
52
- parameters=types.Schema(
53
- type=types.Type.OBJECT,
54
- properties={
55
- "dep": types.Schema(type=types.Type.STRING, description="Departure city or station"),
56
- "dest": types.Schema(type=types.Type.STRING, description="Destination city or station"),
57
- },
58
- required=["dep", "dest"]
59
- )
60
- )
61
 
62
- # Map string name to the actual function
63
- tools_map = {
64
- "get_train_connection": get_train_connection
65
- }
66
-
67
- # --- 3. Generation Logic ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  def generate(input_text):
70
- if not input_text:
71
- yield "", ""
72
- return
73
-
74
  try:
75
- # Initialize Client with v1alpha
76
  client = genai.Client(
77
  api_key=os.environ.get("GEMINI_API_KEY"),
78
- http_options={'api_version': 'v1alpha'} # <--- ENABLE v1alpha HERE
79
  )
80
  except Exception as e:
81
- yield f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set.", input_text
82
- return
83
-
84
- model = "gemini-2.0-flash-latest" # Ensure you use a model version that supports tools
85
 
86
- # Configure tools (Google Search + Our Custom DB Tool)
 
 
 
 
 
 
 
 
87
  tools = [
88
- types.Tool(
89
- google_search=types.GoogleSearch(),
90
- function_declarations=[train_tool]
91
- )
92
  ]
93
-
94
  generate_content_config = types.GenerateContentConfig(
95
  temperature=0.4,
 
 
 
96
  tools=tools,
 
97
  )
98
 
99
- contents = [
100
- types.Content(
101
- role="user",
102
- parts=[types.Part.from_text(text=input_text)],
103
- ),
104
- ]
105
 
106
  response_text = ""
107
-
108
  try:
109
- # First API Call: Ask Gemini what to do
110
- response = client.models.generate_content(
111
- model=model,
112
- contents=contents,
113
- config=generate_content_config,
114
- )
115
-
116
- # Check if Gemini wants to call a function
117
- if response.candidates and response.candidates[0].content.parts:
118
- part = response.candidates[0].content.parts[0]
119
-
120
- # If it's a function call
121
- if part.function_call:
122
- fn_name = part.function_call.name
123
- fn_args = part.function_call.args
124
-
125
- if fn_name in tools_map:
126
- # 1. Execute the Python function (calls the external Gradio app)
127
- api_result = tools_map[fn_name](**fn_args)
128
-
129
- # 2. Feed the result back to Gemini
130
- contents.append(response.candidates[0].content) # Add the model's call to history
131
- contents.append(
132
- types.Content(
133
- role="tool",
134
- parts=[
135
- types.Part.from_function_response(
136
- name=fn_name,
137
- response={"result": api_result}
138
- )
139
- ]
140
- )
141
- )
142
-
143
- # 3. Get the final natural language answer
144
- stream = client.models.generate_content_stream(
145
- model=model,
146
- contents=contents,
147
- config=generate_content_config
148
- )
149
-
150
- for chunk in stream:
151
- response_text += chunk.text
152
- yield response_text, ""
153
- return
154
-
155
- # If no function call, just return the text (e.g. normal chat or Google Search)
156
- if response.text:
157
- yield response.text, ""
158
-
159
  except Exception as e:
160
- yield f"Error during generation: {e}", input_text
161
-
162
- # --- 4. UI Setup ---
 
 
163
 
164
  if __name__ == '__main__':
 
165
  with gr.Blocks() as demo:
166
- title = gr.Markdown("# Gemini 2.0 Flash + DB Timetable Tool")
167
  output_textbox = gr.Markdown()
168
- input_textbox = gr.Textbox(lines=3, label="", placeholder="Ask for a train connection (e.g., 'Train from Berlin to Frankfurt')...")
169
- submit_button = gr.Button("Send")
170
-
171
- submit_button.click(
172
- fn=generate,
173
- inputs=input_textbox,
174
- outputs=[output_textbox, input_textbox]
175
- )
176
-
177
- demo.launch(show_error=True)
 
1
+ import base64
2
  import gradio as gr
3
  import os
4
+ import json
5
  from google import genai
6
  from google.genai import types
7
  from gradio_client import Client
8
 
 
 
 
9
 
10
+ route="""
11
+ how to handle special case "zugverbindung".
12
+ Wichtig: Dies Regeln gelten nur wenn eine zugverbindung angefragt wird, else answer prompt
13
+ Regeln:
14
+ Wenn eine Zugverbindung von {Startort} nach {Zielort} angefragt wird, return json object with Startort and Zielort.
15
+ always follow json scheme below.
 
 
 
16
 
17
+ Wichtig: Gib absolut keinen Text vor oder nach dem JSON aus (keine Erklärungen, kein "Hier ist das Ergebnis").
 
18
 
19
+ {
20
+ "start_loc": "fill in Startort here",
21
+ "dest_loc": "fill in Zielort here"
22
+ }
23
 
24
+ """
 
 
 
 
 
 
 
 
 
25
 
 
 
 
 
 
 
 
 
 
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ def clean_json_string(json_str):
29
+ """
30
+ Removes any comments or prefixes before the actual JSON content.
31
+ """
32
+ # Find the first occurrence of '{'
33
+ json_start = json_str.find('{')
34
+ if json_start == -1:
35
+ # If no '{' is found, try with '[' for arrays
36
+ json_start = json_str.find('[')
37
+ if json_start == -1:
38
+ return json_str # Return original if no JSON markers found
39
+
40
+ # Extract everything from the first JSON marker
41
+ cleaned_str = json_str[json_start:]
42
+ return cleaned_str
43
+ # Verify it's valid JSON
44
+ try:
45
+ json.loads(cleaned_str)
46
+ return cleaned_str
47
+ except json.JSONDecodeError:
48
+ return json_str # Return original if cleaning results in invalid JSON
49
 
50
  def generate(input_text):
 
 
 
 
51
  try:
 
52
  client = genai.Client(
53
  api_key=os.environ.get("GEMINI_API_KEY"),
 
54
  )
55
  except Exception as e:
56
+ return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set."
 
 
 
57
 
58
+ model = "gemini-flash-latest"
59
+ contents = [
60
+ types.Content(
61
+ role="user",
62
+ parts=[
63
+ types.Part.from_text(text=f"{input_text}"),
64
+ ],
65
+ ),
66
+ ]
67
  tools = [
68
+ types.Tool(google_search=types.GoogleSearch()),
 
 
 
69
  ]
 
70
  generate_content_config = types.GenerateContentConfig(
71
  temperature=0.4,
72
+ thinking_config = types.ThinkingConfig(
73
+ thinking_budget=0,
74
+ ),
75
  tools=tools,
76
+ response_mime_type="text/plain",
77
  )
78
 
 
 
 
 
 
 
79
 
80
  response_text = ""
 
81
  try:
82
+ for chunk in client.models.generate_content_stream(
83
+ model=model,
84
+ contents=contents,
85
+ config=generate_content_config,
86
+ ):
87
+ response_text += chunk.text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  except Exception as e:
89
+ return f"Error during generation: {e}"
90
+ data = clean_json_string(response_text)
91
+ data = data[:-1]
92
+ return response_text, ""
93
+
94
 
95
  if __name__ == '__main__':
96
+
97
  with gr.Blocks() as demo:
98
+ title=gr.Markdown("# Gemini 2.0 Flash + Websearch")
99
  output_textbox = gr.Markdown()
100
+ input_textbox = gr.Textbox(lines=3, label="", placeholder="Enter message here...")
101
+ submit_button = gr.Button("send")
102
+ submit_button.click(fn=generate,inputs=input_textbox,outputs=[output_textbox, input_textbox])
103
+ demo.launch(show_error=True)
104
+ """"""