mgokg committed on
Commit
8682f4d
·
verified ·
1 Parent(s): 880e5ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -127
app.py CHANGED
@@ -1,10 +1,16 @@
 
 
 
 
 
1
  import base64
2
  import gradio as gr
3
  import os
4
  import json
5
- import requests
6
  from google import genai
7
  from google.genai import types
 
 
8
 
9
  route="""
10
  how to handle special case "zugverbindung".
@@ -12,13 +18,18 @@ Wichtig: Dies Regeln gelten nur wenn eine zugverbindung angefragt wird, else ans
12
  Regeln:
13
  Wenn eine Zugverbindung von {Startort} nach {Zielort} angefragt wird, return json object with Startort and Zielort.
14
  always follow json scheme below.
 
15
  Wichtig: Gib absolut keinen Text vor oder nach dem JSON aus (keine Erklärungen, kein "Hier ist das Ergebnis").
 
16
  {
17
  "start_loc": "fill in Startort here",
18
  "dest_loc": "fill in Zielort here"
19
  }
 
20
  """
21
 
 
 
22
  def clean_json_string(json_str):
23
  """
24
  Removes any comments or prefixes before the actual JSON content.
@@ -33,125 +44,22 @@ def clean_json_string(json_str):
33
 
34
  # Extract everything from the first JSON marker
35
  cleaned_str = json_str[json_start:]
36
-
37
- # Remove trailing markdown code fence if present
38
- if cleaned_str.endswith('```'):
39
- cleaned_str = cleaned_str[:-3].rstrip()
40
-
41
  return cleaned_str
42
-
43
- def is_train_connection_query(text):
44
- """
45
- Checks if the query is about train connections.
46
- """
47
- keywords = ['zugverbindung', 'zug', 'bahn', 'von', 'nach', 'fahrt', 'reise', 'verbindung']
48
- text_lower = text.lower()
49
- return any(keyword in text_lower for keyword in keywords) and ('von' in text_lower or 'nach' in text_lower)
50
-
51
- def get_train_connections(start_loc, dest_loc):
52
- """
53
- Calls the MCP server via HTTP to get train connections.
54
- """
55
  try:
56
- mcp_url = "https://mgokg-db-timetable-api.hf.space/gradio_api/mcp/"
57
-
58
- # MCP protocol request structure
59
- payload = {
60
- "jsonrpc": "2.0",
61
- "id": 1,
62
- "method": "tools/call",
63
- "params": {
64
- "name": "db_timetable_api_ui_wrapper",
65
- "arguments": {
66
- "query": f"Verbindung von {start_loc} nach {dest_loc}"
67
- }
68
- }
69
- }
70
-
71
- headers = {
72
- "Content-Type": "application/json"
73
- }
74
-
75
- response = requests.post(mcp_url, json=payload, headers=headers, timeout=30)
76
- response.raise_for_status()
77
-
78
- result = response.json()
79
-
80
- # Extract the result from MCP response
81
- if "result" in result:
82
- content = result["result"].get("content", [])
83
- if content and len(content) > 0:
84
- return content[0].get("text", "Keine Verbindungen gefunden")
85
-
86
- return f"Unerwartete Antwortstruktur: {result}"
87
-
88
- except requests.exceptions.Timeout:
89
- return "Die Anfrage hat zu lange gedauert. Bitte versuche es erneut."
90
- except requests.exceptions.RequestException as e:
91
- return f"Fehler beim Abrufen der Zugverbindungen: {e}"
92
- except Exception as e:
93
- return f"Unerwarteter Fehler: {e}"
94
 
95
  def generate(input_text):
96
- # Check if this is a train connection query
97
- if is_train_connection_query(input_text):
98
- try:
99
- # Use Gemini to extract start and destination
100
- client = genai.Client(
101
- api_key=os.environ.get("GEMINI_API_KEY"),
102
- )
103
-
104
- model = "gemini-flash-latest"
105
- contents = [
106
- types.Content(
107
- role="user",
108
- parts=[
109
- types.Part.from_text(text=f"{route}\n\nUser query: {input_text}"),
110
- ],
111
- ),
112
- ]
113
-
114
- generate_content_config = types.GenerateContentConfig(
115
- temperature=0.1,
116
- response_mime_type="application/json",
117
- )
118
-
119
- response = client.models.generate_content(
120
- model=model,
121
- contents=contents,
122
- config=generate_content_config,
123
- )
124
-
125
- # Parse the JSON response
126
- json_str = clean_json_string(response.text)
127
- location_data = json.loads(json_str)
128
-
129
- start_loc = location_data.get("start_loc", "")
130
- dest_loc = location_data.get("dest_loc", "")
131
-
132
- if start_loc and dest_loc:
133
- # Call MCP server for train connections
134
- train_data = get_train_connections(start_loc, dest_loc)
135
-
136
- # Format the response nicely
137
- formatted_response = f"## Zugverbindung von {start_loc} nach {dest_loc}\n\n{train_data}"
138
- return formatted_response, ""
139
- else:
140
- return "Konnte Start- oder Zielort nicht identifizieren. Bitte gib beide Orte an.", ""
141
-
142
- except json.JSONDecodeError as e:
143
- return f"Error parsing location data: {e}\nResponse was: {response.text}", ""
144
- except Exception as e:
145
- return f"Error processing train connection request: {e}", ""
146
-
147
- # If not a train query, use regular Gemini with web search
148
  try:
149
  client = genai.Client(
150
  api_key=os.environ.get("GEMINI_API_KEY"),
151
  )
152
  except Exception as e:
153
- return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set.", ""
154
-
155
  model = "gemini-flash-latest"
156
  contents = [
157
  types.Content(
@@ -161,39 +69,41 @@ def generate(input_text):
161
  ],
162
  ),
163
  ]
164
-
165
  tools = [
166
  types.Tool(google_search=types.GoogleSearch()),
167
  ]
168
-
169
  generate_content_config = types.GenerateContentConfig(
170
  temperature=0.4,
171
- thinking_config=types.ThinkingConfig(
172
  thinking_budget=0,
173
  ),
174
  tools=tools,
175
  response_mime_type="text/plain",
176
  )
177
-
 
178
  response_text = ""
179
  try:
180
- for chunk in client.models.generate_content_stream(
181
- model=model,
182
- contents=contents,
183
- config=generate_content_config,
184
- ):
185
- response_text += chunk.text
186
  except Exception as e:
187
- return f"Error during generation: {e}", ""
188
-
 
189
  return response_text, ""
 
190
 
191
  if __name__ == '__main__':
 
192
  with gr.Blocks() as demo:
193
- title = gr.Markdown("# Gemini 2.0 Flash + Websearch + Deutsche Bahn Verbindungen")
194
- gr.Markdown("Frage nach Zugverbindungen (z.B. 'Zugverbindung von Berlin nach München') oder stelle allgemeine Fragen.")
195
  output_textbox = gr.Markdown()
196
  input_textbox = gr.Textbox(lines=3, label="", placeholder="Enter message here...")
197
  submit_button = gr.Button("send")
198
- submit_button.click(fn=generate, inputs=input_textbox, outputs=[output_textbox, input_textbox])
199
- demo.launch(show_error=True)
 
 
1
+
2
+
3
+
4
+
5
+
6
  import base64
7
  import gradio as gr
8
  import os
9
  import json
 
10
  from google import genai
11
  from google.genai import types
12
+ from gradio_client import Client
13
+
14
 
15
  route="""
16
  how to handle special case "zugverbindung".
 
18
  Regeln:
19
  Wenn eine Zugverbindung von {Startort} nach {Zielort} angefragt wird, return json object with Startort and Zielort.
20
  always follow json scheme below.
21
+
22
  Wichtig: Gib absolut keinen Text vor oder nach dem JSON aus (keine Erklärungen, kein "Hier ist das Ergebnis").
23
+
24
  {
25
  "start_loc": "fill in Startort here",
26
  "dest_loc": "fill in Zielort here"
27
  }
28
+
29
  """
30
 
31
+
32
+
33
  def clean_json_string(json_str):
34
  """
35
  Removes any comments or prefixes before the actual JSON content.
 
44
 
45
  # Extract everything from the first JSON marker
46
  cleaned_str = json_str[json_start:]
 
 
 
 
 
47
  return cleaned_str
48
+ # Verify it's valid JSON
 
 
 
 
 
 
 
 
 
 
 
 
49
  try:
50
+ json.loads(cleaned_str)
51
+ return cleaned_str
52
+ except json.JSONDecodeError:
53
+ return json_str # Return original if cleaning results in invalid JSON
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  def generate(input_text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  try:
57
  client = genai.Client(
58
  api_key=os.environ.get("GEMINI_API_KEY"),
59
  )
60
  except Exception as e:
61
+ return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set."
62
+
63
  model = "gemini-flash-latest"
64
  contents = [
65
  types.Content(
 
69
  ],
70
  ),
71
  ]
 
72
  tools = [
73
  types.Tool(google_search=types.GoogleSearch()),
74
  ]
 
75
  generate_content_config = types.GenerateContentConfig(
76
  temperature=0.4,
77
+ thinking_config = types.ThinkingConfig(
78
  thinking_budget=0,
79
  ),
80
  tools=tools,
81
  response_mime_type="text/plain",
82
  )
83
+
84
+
85
  response_text = ""
86
  try:
87
+ for chunk in client.models.generate_content_stream(
88
+ model=model,
89
+ contents=contents,
90
+ config=generate_content_config,
91
+ ):
92
+ response_text += chunk.text
93
  except Exception as e:
94
+ return f"Error during generation: {e}"
95
+ data = clean_json_string(response_text)
96
+ data = data[:-1]
97
  return response_text, ""
98
+
99
 
100
  if __name__ == '__main__':
101
+
102
  with gr.Blocks() as demo:
103
+ title=gr.Markdown("# Gemini 2.0 Flash + Websearch")
 
104
  output_textbox = gr.Markdown()
105
  input_textbox = gr.Textbox(lines=3, label="", placeholder="Enter message here...")
106
  submit_button = gr.Button("send")
107
+ submit_button.click(fn=generate,inputs=input_textbox,outputs=[output_textbox, input_textbox])
108
+ demo.launch(show_error=True)
109
+ """"""