mgokg commited on
Commit
42bd84e
·
verified ·
1 Parent(s): df06d07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -147
app.py CHANGED
@@ -1,177 +1,65 @@
 
1
  import gradio as gr
2
  import os
 
3
  from google import genai
4
  from google.genai import types
5
  from gradio_client import Client
6
 
7
- # --- 1. Robust Client Initialization for the External Tool ---
8
- db_client = None
9
-
10
- def init_db_client():
11
- global db_client
12
- try:
13
- # Use the full URL to avoid DNS resolution issues
14
- print("Connecting to DB Timetable API...")
15
- db_client = Client("https://mgokg-db-timetable-api.hf.space/")
16
- print("Successfully connected.")
17
- except Exception as e:
18
- print(f"Warning: Could not connect to DB Timetable API: {e}")
19
-
20
- # Attempt connection on startup
21
- init_db_client()
22
-
23
- # --- 2. Tool Definition ---
24
-
25
- def get_train_connection(dep: str, dest: str):
26
- """
27
- Fetches the train timetable between two cities using the external API.
28
- """
29
- global db_client
30
- # Retry connection if it failed initially
31
- if db_client is None:
32
- init_db_client()
33
- if db_client is None:
34
- return "Error: The train database is currently unreachable."
35
-
36
- try:
37
- # Calling the specific endpoint mentioned in the MCP docs
38
- result = db_client.predict(
39
- dep=dep,
40
- dest=dest,
41
- api_name="/db_timetable_api_ui_wrapper"
42
- )
43
- return result
44
- except Exception as e:
45
- return f"Error fetching timetable: {str(e)}"
46
-
47
- # Define the tool schema for Gemini
48
- train_tool = types.FunctionDeclaration(
49
- name="get_train_connection",
50
- description="Find train connections and timetables between a start location (dep) and a destination (dest).",
51
- parameters=types.Schema(
52
- type=types.Type.OBJECT,
53
- properties={
54
- "dep": types.Schema(type=types.Type.STRING, description="Departure city or station"),
55
- "dest": types.Schema(type=types.Type.STRING, description="Destination city or station"),
56
- },
57
- required=["dep", "dest"]
58
- )
59
- )
60
-
61
- # Map string name to the actual function
62
- tools_map = {
63
- "get_train_connection": get_train_connection
64
- }
65
-
66
- # --- 3. Generation Logic (v1alpha) ---
67
 
68
  def generate(input_text):
69
- if not input_text:
70
- yield "", ""
71
- return
72
-
73
  try:
74
- # Initialize Client with v1alpha
75
  client = genai.Client(
76
  api_key=os.environ.get("GEMINI_API_KEY"),
77
- #http_options={'api_version': 'v1alpha'} # <--- ENABLE v1alpha HERE
78
  )
79
  except Exception as e:
80
- yield f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set.", input_text
81
- return
82
 
83
- #model = "gemini-flash-latest"
84
  model = "gemini-flash-latest"
85
-
86
- # Configure tools
 
 
 
 
 
 
87
  tools = [
88
- types.Tool(
89
- google_search=types.GoogleSearch(),
90
- function_declarations=[train_tool]
91
- )
92
  ]
93
-
94
  generate_content_config = types.GenerateContentConfig(
95
  temperature=0.4,
 
 
 
96
  tools=tools,
 
97
  )
98
 
99
- contents = [
100
- types.Content(
101
- role="user",
102
- parts=[types.Part.from_text(text=input_text)],
103
- ),
104
- ]
105
 
106
  response_text = ""
107
-
108
  try:
109
- # First API Call
110
- response = client.models.generate_content(
111
- model=model,
112
- contents=contents,
113
- config=generate_content_config,
114
- )
115
-
116
- # Handle Function Calling
117
- if response.candidates and response.candidates[0].content.parts:
118
- part = response.candidates[0].content.parts[0]
119
-
120
- if part.function_call:
121
- fn_name = part.function_call.name
122
- fn_args = part.function_call.args
123
-
124
- if fn_name in tools_map:
125
- # Execute tool
126
- api_result = tools_map[fn_name](**fn_args)
127
-
128
- # Append history
129
- contents.append(response.candidates[0].content)
130
- contents.append(
131
- types.Content(
132
- role="tool",
133
- parts=[
134
- types.Part.from_function_response(
135
- name=fn_name,
136
- response={"result": api_result}
137
- )
138
- ]
139
- )
140
- )
141
-
142
- # Second API Call (Streamed)
143
- stream = client.models.generate_content_stream(
144
- model=model,
145
- contents=contents,
146
- config=generate_content_config
147
- )
148
-
149
- for chunk in stream:
150
- response_text += chunk.text
151
- yield response_text, ""
152
- return
153
-
154
- # Handle standard response
155
- if response.text:
156
- yield response.text, ""
157
-
158
  except Exception as e:
159
- yield f"Error during generation: {e}", input_text
160
-
161
- # --- 4. UI Setup ---
 
 
 
162
 
163
  if __name__ == '__main__':
 
164
  with gr.Blocks() as demo:
165
- gr.Markdown("# Gemini 2.0 Flash + DB train connections")
166
-
167
  output_textbox = gr.Markdown()
168
- input_textbox = gr.Textbox(lines=3, label="", placeholder="Ask for a train connection (e.g., 'Train from Berlin to Frankfurt')...")
169
- submit_button = gr.Button("Send")
170
-
171
- submit_button.click(
172
- fn=generate,
173
- inputs=input_textbox,
174
- outputs=[output_textbox, input_textbox]
175
- )
176
-
177
- demo.launch(show_error=True)
 
1
+ import base64
2
  import gradio as gr
3
  import os
4
+ import json
5
  from google import genai
6
  from google.genai import types
7
  from gradio_client import Client
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def generate(input_text):
 
 
 
 
11
  try:
 
12
  client = genai.Client(
13
  api_key=os.environ.get("GEMINI_API_KEY"),
 
14
  )
15
  except Exception as e:
16
+ return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set."
 
17
 
 
18
  model = "gemini-flash-latest"
19
+ contents = [
20
+ types.Content(
21
+ role="user",
22
+ parts=[
23
+ types.Part.from_text(text=f"{input_text}"),
24
+ ],
25
+ ),
26
+ ]
27
  tools = [
28
+ types.Tool(google_search=types.GoogleSearch()),
 
 
 
29
  ]
 
30
  generate_content_config = types.GenerateContentConfig(
31
  temperature=0.4,
32
+ thinking_config = types.ThinkingConfig(
33
+ thinking_budget=0,
34
+ ),
35
  tools=tools,
36
+ response_mime_type="text/plain",
37
  )
38
 
 
 
 
 
 
 
39
 
40
  response_text = ""
 
41
  try:
42
+ for chunk in client.models.generate_content_stream(
43
+ model=model,
44
+ contents=contents,
45
+ config=generate_content_config,
46
+ ):
47
+ response_text += chunk.text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  except Exception as e:
49
+ return f"Error during generation: {e}"
50
+ data = response_text
51
+ #data = clean_json_string(response_text)
52
+ data = data[:-1]
53
+ return response_text, ""
54
+
55
 
56
  if __name__ == '__main__':
57
+
58
  with gr.Blocks() as demo:
59
+ title=gr.Markdown("# Gemini 2.0 Flash + Websearch")
 
60
  output_textbox = gr.Markdown()
61
+ input_textbox = gr.Textbox(lines=3, label="", placeholder="Enter message here...")
62
+ submit_button = gr.Button("send")
63
+ submit_button.click(fn=generate,inputs=input_textbox,outputs=[output_textbox, input_textbox])
64
+ demo.launch(show_error=True)
65
+