mgokg commited on
Commit
64369b9
·
verified ·
1 Parent(s): a272972

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -26
app.py CHANGED
@@ -5,17 +5,73 @@ import json
5
  from google import genai
6
  from google.genai import types
7
  from gradio_client import Client
 
 
8
 
9
 
10
- def generate(input_text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  try:
12
  client = genai.Client(
13
  api_key=os.environ.get("GEMINI_API_KEY"),
14
  )
15
  except Exception as e:
16
- return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set."
17
 
18
  model = "gemini-flash-latest"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  contents = [
20
  types.Content(
21
  role="user",
@@ -24,42 +80,112 @@ def generate(input_text):
24
  ],
25
  ),
26
  ]
27
- tools = [
28
- types.Tool(google_search=types.GoogleSearch()),
29
- ]
 
 
30
  generate_content_config = types.GenerateContentConfig(
31
  temperature=0.4,
32
- thinking_config = types.ThinkingConfig(
33
  thinking_budget=0,
34
  ),
35
  tools=tools,
36
  response_mime_type="text/plain",
37
  )
38
 
39
-
40
  response_text = ""
41
  try:
42
- for chunk in client.models.generate_content_stream(
43
- model=model,
44
- contents=contents,
45
- config=generate_content_config,
46
- ):
47
- response_text += chunk.text
 
48
  except Exception as e:
49
  return f"Error during generation: {e}"
50
- data = response_text
51
- #data = clean_json_string(response_text)
52
- data = data[:-1]
53
- return response_text, ""
54
-
55
 
56
- if __name__ == '__main__':
57
 
58
- with gr.Blocks() as demo:
59
- title=gr.Markdown("# Gemini 2.0 Flash + Websearch")
60
- output_textbox = gr.Markdown()
61
- input_textbox = gr.Textbox(lines=3, label="", placeholder="Enter message here...")
62
- submit_button = gr.Button("send")
63
- submit_button.click(fn=generate,inputs=input_textbox,outputs=[output_textbox, input_textbox])
64
- demo.launch(show_error=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from google import genai
6
  from google.genai import types
7
  from gradio_client import Client
8
+ import requests
9
+ from typing import Optional
10
 
11
 
12
+ def get_train_connections(departure: str, destination: str) -> str:
13
+ """
14
+ Get train connections using the db_timetable_api MCP tool.
15
+
16
+ Args:
17
+ departure: Departure station (e.g., "Schweinfurt HBF")
18
+ destination: Destination station (e.g., "Oerlenbach")
19
+
20
+ Returns:
21
+ Formatted train connection information
22
+ """
23
+ try:
24
+ # Use the db_timetable_api tool directly
25
+ result = db_timetable_api_ui_wrapper(dep=departure, dest=destination)
26
+ return result
27
+ except Exception as e:
28
+ return f"Error getting train connections: {str(e)}"
29
+
30
+
31
+ def generate_with_tools(input_text: str, use_websearch: bool = True, use_train_api: bool = False) -> str:
32
+ """
33
+ Generate response using Gemini with optional web search and train API integration.
34
+
35
+ Args:
36
+ input_text: User input text
37
+ use_websearch: Whether to use web search functionality
38
+ use_train_api: Whether to use train API for connections
39
+
40
+ Returns:
41
+ Generated response text
42
+ """
43
  try:
44
  client = genai.Client(
45
  api_key=os.environ.get("GEMINI_API_KEY"),
46
  )
47
  except Exception as e:
48
+ return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set."
49
 
50
  model = "gemini-flash-latest"
51
+
52
+ # Check if user is asking for train connections
53
+ train_keywords = ["zug", "bahn", "train", "connection", "fahrplan", "timetable", "abfahrt", "ankunft"]
54
+ is_train_query = any(keyword in input_text.lower() for keyword in train_keywords)
55
+
56
+ # If it's a train query, try to extract stations and use train API
57
+ if is_train_query and use_train_api:
58
+ # Simple extraction - in production you might want more sophisticated NLP
59
+ words = input_text.split()
60
+ potential_stations = []
61
+
62
+ # Common German station patterns
63
+ for i, word in enumerate(words):
64
+ if word.upper() == "HBF" or "bahnhof" in word.lower():
65
+ # Look for station name before HBF
66
+ if i > 0:
67
+ station = words[i-1] + " " + word
68
+ potential_stations.append(station)
69
+
70
+ if len(potential_stations) >= 2:
71
+ train_result = get_train_connections(potential_stations[0], potential_stations[1])
72
+ return f"**Train Connections Found:**\n\n{train_result}\n\n---\n\n*Powered by Deutsche Bahn API*"
73
+
74
+ # Original websearch functionality
75
  contents = [
76
  types.Content(
77
  role="user",
 
80
  ],
81
  ),
82
  ]
83
+
84
+ tools = []
85
+ if use_websearch:
86
+ tools.append(types.Tool(google_search=types.GoogleSearch()))
87
+
88
  generate_content_config = types.GenerateContentConfig(
89
  temperature=0.4,
90
+ thinking_config=types.ThinkingConfig(
91
  thinking_budget=0,
92
  ),
93
  tools=tools,
94
  response_mime_type="text/plain",
95
  )
96
 
 
97
  response_text = ""
98
  try:
99
+ for chunk in client.models.generate_content_stream(
100
+ model=model,
101
+ contents=contents,
102
+ config=generate_content_config,
103
+ ):
104
+ if hasattr(chunk, 'text'):
105
+ response_text += chunk.text
106
  except Exception as e:
107
  return f"Error during generation: {e}"
108
+
109
+ return response_text
 
 
 
110
 
 
111
 
112
+ def process_input(input_text: str, mode: str = "Auto-detect") -> str:
113
+ """
114
+ Process user input based on selected mode.
115
+
116
+ Args:
117
+ input_text: User input text
118
+ mode: Processing mode ("Auto-detect", "Web Search", "Train Connections", "Both")
119
+
120
+ Returns:
121
+ Processed response
122
+ """
123
+ if not input_text.strip():
124
+ return "Please enter a message."
125
+
126
+ if mode == "Web Search":
127
+ return generate_with_tools(input_text, use_websearch=True, use_train_api=False)
128
+ elif mode == "Train Connections":
129
+ return generate_with_tools(input_text, use_websearch=False, use_train_api=True)
130
+ elif mode == "Both":
131
+ return generate_with_tools(input_text, use_websearch=True, use_train_api=True)
132
+ else: # Auto-detect
133
+ return generate_with_tools(input_text, use_websearch=True, use_train_api=True)
134
+
135
 
136
+ if __name__ == '__main__':
137
+ with gr.Blocks(title="Gemini 2.0 Flash + Websearch + Train Connections") as demo:
138
+ gr.Markdown("# 🤖 Gemini 2.0 Flash + Websearch + Train Connections")
139
+ gr.Markdown("Ask me anything! I can search the web and find train connections for you.")
140
+
141
+ with gr.Row():
142
+ with gr.Column(scale=3):
143
+ input_textbox = gr.Textbox(
144
+ lines=3,
145
+ label="Your Message",
146
+ placeholder="Enter your message here...\nExamples:\n- 'What's the weather in Berlin?'\n- 'Train from Schweinfurt HBF to Oerlenbach'\n- 'Latest news about AI'"
147
+ )
148
+ with gr.Column(scale=1):
149
+ mode_dropdown = gr.Dropdown(
150
+ choices=["Auto-detect", "Web Search", "Train Connections", "Both"],
151
+ value="Auto-detect",
152
+ label="Mode",
153
+ info="Choose how to process your query"
154
+ )
155
+
156
+ submit_button = gr.Button("Send", variant="primary")
157
+
158
+ output_textbox = gr.Markdown(
159
+ label="Response",
160
+ show_copy_button=True
161
+ )
162
+
163
+ # Examples
164
+ gr.Markdown("### 🎯 Quick Examples")
165
+ with gr.Row():
166
+ example1 = gr.Button("Weather in Munich")
167
+ example2 = gr.Button("Train Schweinfurt HBF to Oerlenbach")
168
+ example3 = gr.Button("Latest AI news")
169
+
170
+ def set_example(example_text):
171
+ return example_text
172
+
173
+ example1.click(fn=lambda: set_example("What's the weather in Munich today?"), outputs=input_textbox)
174
+ example2.click(fn=lambda: set_example("Train connections from Schweinfurt HBF to Oerlenbach"), outputs=input_textbox)
175
+ example3.click(fn=lambda: set_example("Latest news about artificial intelligence developments"), outputs=input_textbox)
176
+
177
+ # Main functionality
178
+ submit_button.click(
179
+ fn=process_input,
180
+ inputs=[input_textbox, mode_dropdown],
181
+ outputs=output_textbox
182
+ )
183
+
184
+ # Allow Enter key submission
185
+ input_textbox.submit(
186
+ fn=process_input,
187
+ inputs=[input_textbox, mode_dropdown],
188
+ outputs=output_textbox
189
+ )
190
+
191
+ demo.launch(show_error=True)