mgokg commited on
Commit
6d108cd
·
verified ·
1 Parent(s): 8682f4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -74
app.py CHANGED
@@ -1,109 +1,167 @@
1
-
2
-
3
-
4
-
5
-
6
- import base64
7
- import gradio as gr
8
  import os
9
- import json
10
  from google import genai
11
  from google.genai import types
12
  from gradio_client import Client
13
 
 
 
 
14
 
15
- route="""
16
- how to handle special case "zugverbindung".
17
- Wichtig: Dies Regeln gelten nur wenn eine zugverbindung angefragt wird, else answer prompt
18
- Regeln:
19
- Wenn eine Zugverbindung von {Startort} nach {Zielort} angefragt wird, return json object with Startort and Zielort.
20
- always follow json scheme below.
21
-
22
- Wichtig: Gib absolut keinen Text vor oder nach dem JSON aus (keine Erklärungen, kein "Hier ist das Ergebnis").
23
-
24
- {
25
- "start_loc": "fill in Startort here",
26
- "dest_loc": "fill in Zielort here"
27
- }
28
-
29
- """
30
-
31
-
32
-
33
- def clean_json_string(json_str):
34
  """
35
- Removes any comments or prefixes before the actual JSON content.
36
  """
37
- # Find the first occurrence of '{'
38
- json_start = json_str.find('{')
39
- if json_start == -1:
40
- # If no '{' is found, try with '[' for arrays
41
- json_start = json_str.find('[')
42
- if json_start == -1:
43
- return json_str # Return original if no JSON markers found
44
-
45
- # Extract everything from the first JSON marker
46
- cleaned_str = json_str[json_start:]
47
- return cleaned_str
48
- # Verify it's valid JSON
49
  try:
50
- json.loads(cleaned_str)
51
- return cleaned_str
52
- except json.JSONDecodeError:
53
- return json_str # Return original if cleaning results in invalid JSON
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- def generate(input_text):
 
56
  try:
57
  client = genai.Client(
58
  api_key=os.environ.get("GEMINI_API_KEY"),
59
  )
60
  except Exception as e:
61
- return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set."
 
 
 
62
 
63
- model = "gemini-flash-latest"
 
64
  contents = [
65
  types.Content(
66
  role="user",
67
- parts=[
68
- types.Part.from_text(text=f"{input_text}"),
69
- ],
70
  ),
71
  ]
 
 
72
  tools = [
73
  types.Tool(google_search=types.GoogleSearch()),
 
74
  ]
 
75
  generate_content_config = types.GenerateContentConfig(
76
  temperature=0.4,
77
- thinking_config = types.ThinkingConfig(
78
- thinking_budget=0,
79
- ),
80
  tools=tools,
81
- response_mime_type="text/plain",
 
 
82
  )
83
 
84
-
85
  response_text = ""
 
 
86
  try:
87
- for chunk in client.models.generate_content_stream(
88
- model=model,
89
- contents=contents,
90
- config=generate_content_config,
91
- ):
92
- response_text += chunk.text
93
  except Exception as e:
94
- return f"Error during generation: {e}"
95
- data = clean_json_string(response_text)
96
- data = data[:-1]
97
- return response_text, ""
98
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  if __name__ == '__main__':
101
-
102
  with gr.Blocks() as demo:
103
- title=gr.Markdown("# Gemini 2.0 Flash + Websearch")
104
- output_textbox = gr.Markdown()
105
- input_textbox = gr.Textbox(lines=3, label="", placeholder="Enter message here...")
106
- submit_button = gr.Button("send")
107
- submit_button.click(fn=generate,inputs=input_textbox,outputs=[output_textbox, input_textbox])
108
- demo.launch(show_error=True)
109
- """"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import gradio as gr
3
  from google import genai
4
  from google.genai import types
5
  from gradio_client import Client
6
 
7
# Client for the external DB Timetable App, hosted as a Hugging Face Space.
# The Space ID below comes from the DB Timetable App's documentation.
db_client = Client("mgokg/db-timetable-api")
10
 
11
def get_train_connection(dep: str, dest: str):
    """Query the external DB Timetable Space for connections dep -> dest.

    Returns the raw prediction result from the remote Space, or an error
    string if the remote call fails (callers treat both as display text).
    """
    try:
        # Endpoint name taken from the Space's MCP docs.
        return db_client.predict(
            dep=dep,
            dest=dest,
            api_name="/db_timetable_api_ui_wrapper",
        )
    except Exception as e:
        return f"Error fetching timetable: {str(e)}"
25
+
26
# Gemini tool declaration: describes get_train_connection to the model so it
# can emit a structured function call instead of free text.
train_tool = types.FunctionDeclaration(
    name="get_train_connection",
    description="Find train connections and timetables between a start location (dep) and a destination (dest).",
    parameters=types.Schema(
        type=types.Type.OBJECT,
        properties={
            "dep": types.Schema(type=types.Type.STRING, description="Departure city or station"),
            "dest": types.Schema(type=types.Type.STRING, description="Destination city or station"),
        },
        required=["dep", "dest"],
    ),
)

# Dispatch table: tool name reported by the model -> local Python callable.
tools_map = {"get_train_connection": get_train_connection}
45
 
46
def generate(input_text, history):
    """Answer *input_text* with Gemini, invoking the DB timetable tool when
    the model requests it.

    Generator yielding (partial_text, history) tuples so the Gradio UI can
    stream status messages and the final answer. Always yields at least
    once (fix: the previous version could fall through without yielding,
    leaving the pending chat turn empty forever).
    """
    # Initialize the Gemini client; fail fast with a readable message.
    try:
        client = genai.Client(
            api_key=os.environ.get("GEMINI_API_KEY"),
        )
    except Exception as e:
        yield f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set.", history
        return

    model = "gemini-2.0-flash-exp"  # Or "gemini-2.0-flash" depending on availability

    # Single-turn context; previous history could be folded in here later.
    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=input_text)],
        ),
    ]

    # NOTE(review): some Gemini API versions reject requests combining the
    # built-in google_search tool with custom function declarations in one
    # call — confirm against current API behavior before relying on both.
    tools = [
        types.Tool(google_search=types.GoogleSearch()),
        types.Tool(function_declarations=[train_tool]),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=0.4,
        tools=tools,
    )

    # First API call: let the model decide whether to answer directly or
    # request the timetable tool.
    try:
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config,
        )
    except Exception as e:
        yield f"Error during generation: {e}", history
        return

    # Inspect the first part of the first candidate for a function call.
    if response.candidates and response.candidates[0].content.parts:
        first_part = response.candidates[0].content.parts[0]

        if first_part.function_call:
            fn_name = first_part.function_call.name
            # Guard: args may be absent/None for a no-argument call.
            fn_args = dict(first_part.function_call.args or {})

            if fn_name in tools_map:
                status_msg = f"🔄 Checking trains from {fn_args.get('dep')} to {fn_args.get('dest')}..."
                yield status_msg, history

                # Execute the local tool implementation.
                api_result = tools_map[fn_name](**fn_args)

                # Feed the result back: append the model's function-call turn
                # plus a tool turn carrying our response.
                contents.append(response.candidates[0].content)
                contents.append(
                    types.Content(
                        role="tool",
                        parts=[
                            types.Part.from_function_response(
                                name=fn_name,
                                response={"result": api_result},
                            )
                        ],
                    )
                )

                # Second API call: stream the natural-language answer.
                stream = client.models.generate_content_stream(
                    model=model,
                    contents=contents,
                    config=generate_content_config,  # keep tools enabled just in case
                )
                final_text = ""
                for chunk in stream:
                    if chunk.text:
                        final_text += chunk.text
                        yield final_text, history
                return

            # Fix: surface an unknown tool request instead of yielding nothing.
            yield f"Model requested unknown tool: {fn_name}", history
            return

    # No function call: plain answer (normal chat or Google Search result).
    if response.text:
        yield response.text, history
    else:
        # Fix: fallback so the UI never hangs with an un-filled turn.
        yield "No response generated.", history
 
144
if __name__ == '__main__':
    with gr.Blocks() as demo:
        gr.Markdown("# Gemini 2.0 Flash + DB Timetable Tool")

        chatbot = gr.Chatbot(label="Conversation", height=400)
        msg = gr.Textbox(lines=1, label="Ask about trains (e.g., 'Train from Berlin to Munich')", placeholder="Enter message here...")
        clear = gr.Button("Clear")

        def user(user_message, history):
            # Clear the textbox and append a pending (unanswered) chat turn.
            return "", history + [[user_message, None]]

        def bot(history):
            latest_question = history[-1][0]
            # Stream partial responses from generate() into the pending turn.
            for partial_text, _ignored in generate(latest_question, history):
                history[-1][1] = partial_text
                yield history

        # Submit: record the user turn first, then stream the bot answer.
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(lambda: None, None, chatbot, queue=False)

    demo.launch(show_error=True)