mgokg commited on
Commit
83992cb
·
verified ·
1 Parent(s): dc11235

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -30
app.py CHANGED
@@ -6,25 +6,19 @@ from google import genai
6
  from google.genai import types
7
  from gradio_client import Client
8
 
9
-
10
  route="""
11
  how to handle special case "zugverbindung".
12
  Wichtig: Dies Regeln gelten nur wenn eine zugverbindung angefragt wird, else answer prompt
13
  Regeln:
14
  Wenn eine Zugverbindung von {Startort} nach {Zielort} angefragt wird, return json object with Startort and Zielort.
15
  always follow json scheme below.
16
-
17
  Wichtig: Gib absolut keinen Text vor oder nach dem JSON aus (keine Erklärungen, kein "Hier ist das Ergebnis").
18
-
19
  {
20
  "start_loc": "fill in Startort here",
21
  "dest_loc": "fill in Zielort here"
22
  }
23
-
24
  """
25
 
26
-
27
-
28
  def clean_json_string(json_str):
29
  """
30
  Removes any comments or prefixes before the actual JSON content.
@@ -39,22 +33,95 @@ def clean_json_string(json_str):
39
 
40
  # Extract everything from the first JSON marker
41
  cleaned_str = json_str[json_start:]
 
 
 
 
 
42
  return cleaned_str
43
- # Verify it's valid JSON
 
 
 
 
 
 
 
 
 
 
 
 
44
  try:
45
- json.loads(cleaned_str)
46
- return cleaned_str
47
- except json.JSONDecodeError:
48
- return json_str # Return original if cleaning results in invalid JSON
 
 
 
 
49
 
50
  def generate(input_text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  try:
52
  client = genai.Client(
53
  api_key=os.environ.get("GEMINI_API_KEY"),
54
  )
55
  except Exception as e:
56
- return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set."
57
-
58
  model = "gemini-flash-latest"
59
  contents = [
60
  types.Content(
@@ -64,41 +131,39 @@ def generate(input_text):
64
  ],
65
  ),
66
  ]
 
67
  tools = [
68
  types.Tool(google_search=types.GoogleSearch()),
69
  ]
 
70
  generate_content_config = types.GenerateContentConfig(
71
  temperature=0.4,
72
- thinking_config = types.ThinkingConfig(
73
  thinking_budget=0,
74
  ),
75
  tools=tools,
76
  response_mime_type="text/plain",
77
  )
78
-
79
-
80
  response_text = ""
81
  try:
82
- for chunk in client.models.generate_content_stream(
83
- model=model,
84
- contents=contents,
85
- config=generate_content_config,
86
- ):
87
- response_text += chunk.text
88
  except Exception as e:
89
- return f"Error during generation: {e}"
90
- data = clean_json_string(response_text)
91
- data = data[:-1]
92
  return response_text, ""
93
-
94
 
95
  if __name__ == '__main__':
96
-
97
  with gr.Blocks() as demo:
98
- title=gr.Markdown("# Gemini 2.0 Flash + Websearch")
 
99
  output_textbox = gr.Markdown()
100
  input_textbox = gr.Textbox(lines=3, label="", placeholder="Enter message here...")
101
  submit_button = gr.Button("send")
102
- submit_button.click(fn=generate,inputs=input_textbox,outputs=[output_textbox, input_textbox])
103
  demo.launch(show_error=True)
104
-
 
6
  from google.genai import types
7
  from gradio_client import Client
8
 
 
9
  route="""
10
  how to handle special case "zugverbindung".
11
  Wichtig: Dies Regeln gelten nur wenn eine zugverbindung angefragt wird, else answer prompt
12
  Regeln:
13
  Wenn eine Zugverbindung von {Startort} nach {Zielort} angefragt wird, return json object with Startort and Zielort.
14
  always follow json scheme below.
 
15
  Wichtig: Gib absolut keinen Text vor oder nach dem JSON aus (keine Erklärungen, kein "Hier ist das Ergebnis").
 
16
  {
17
  "start_loc": "fill in Startort here",
18
  "dest_loc": "fill in Zielort here"
19
  }
 
20
  """
21
 
 
 
22
  def clean_json_string(json_str):
23
  """
24
  Removes any comments or prefixes before the actual JSON content.
 
33
 
34
  # Extract everything from the first JSON marker
35
  cleaned_str = json_str[json_start:]
36
+
37
+ # Remove trailing markdown code fence if present
38
+ if cleaned_str.endswith('```'):
39
+ cleaned_str = cleaned_str[:-3].rstrip()
40
+
41
  return cleaned_str
42
+
43
+ def is_train_connection_query(text):
44
+ """
45
+ Checks if the query is about train connections.
46
+ """
47
+ keywords = ['zugverbindung', 'zug', 'bahn', 'von', 'nach', 'fahrt', 'reise', 'verbindung']
48
+ text_lower = text.lower()
49
+ return any(keyword in text_lower for keyword in keywords) and ('von' in text_lower or 'nach' in text_lower)
50
+
51
+ def get_train_connections(start_loc, dest_loc):
52
+ """
53
+ Calls the MCP server to get train connections.
54
+ """
55
  try:
56
+ client = Client("mgokg/db-timetable-api")
57
+ result = client.predict(
58
+ query=f"Verbindung von {start_loc} nach {dest_loc}",
59
+ api_name="/db_timetable_api_ui_wrapper"
60
+ )
61
+ return result
62
+ except Exception as e:
63
+ return f"Error calling MCP server: {e}"
64
 
65
  def generate(input_text):
66
+ # Check if this is a train connection query
67
+ if is_train_connection_query(input_text):
68
+ try:
69
+ # Use Gemini to extract start and destination
70
+ client = genai.Client(
71
+ api_key=os.environ.get("GEMINI_API_KEY"),
72
+ )
73
+
74
+ model = "gemini-flash-latest"
75
+ contents = [
76
+ types.Content(
77
+ role="user",
78
+ parts=[
79
+ types.Part.from_text(text=f"{route}\n\nUser query: {input_text}"),
80
+ ],
81
+ ),
82
+ ]
83
+
84
+ generate_content_config = types.GenerateContentConfig(
85
+ temperature=0.1,
86
+ response_mime_type="application/json",
87
+ )
88
+
89
+ response = client.models.generate_content(
90
+ model=model,
91
+ contents=contents,
92
+ config=generate_content_config,
93
+ )
94
+
95
+ # Parse the JSON response
96
+ json_str = clean_json_string(response.text)
97
+ location_data = json.loads(json_str)
98
+
99
+ start_loc = location_data.get("start_loc", "")
100
+ dest_loc = location_data.get("dest_loc", "")
101
+
102
+ if start_loc and dest_loc:
103
+ # Call MCP server for train connections
104
+ train_data = get_train_connections(start_loc, dest_loc)
105
+
106
+ # Format the response nicely
107
+ formatted_response = f"## Zugverbindung von {start_loc} nach {dest_loc}\n\n{train_data}"
108
+ return formatted_response, ""
109
+ else:
110
+ return "Konnte Start- oder Zielort nicht identifizieren. Bitte gib beide Orte an.", ""
111
+
112
+ except json.JSONDecodeError as e:
113
+ return f"Error parsing location data: {e}", ""
114
+ except Exception as e:
115
+ return f"Error processing train connection request: {e}", ""
116
+
117
+ # If not a train query, use regular Gemini with web search
118
  try:
119
  client = genai.Client(
120
  api_key=os.environ.get("GEMINI_API_KEY"),
121
  )
122
  except Exception as e:
123
+ return f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set.", ""
124
+
125
  model = "gemini-flash-latest"
126
  contents = [
127
  types.Content(
 
131
  ],
132
  ),
133
  ]
134
+
135
  tools = [
136
  types.Tool(google_search=types.GoogleSearch()),
137
  ]
138
+
139
  generate_content_config = types.GenerateContentConfig(
140
  temperature=0.4,
141
+ thinking_config=types.ThinkingConfig(
142
  thinking_budget=0,
143
  ),
144
  tools=tools,
145
  response_mime_type="text/plain",
146
  )
147
+
 
148
  response_text = ""
149
  try:
150
+ for chunk in client.models.generate_content_stream(
151
+ model=model,
152
+ contents=contents,
153
+ config=generate_content_config,
154
+ ):
155
+ response_text += chunk.text
156
  except Exception as e:
157
+ return f"Error during generation: {e}", ""
158
+
 
159
  return response_text, ""
 
160
 
161
  if __name__ == '__main__':
 
162
  with gr.Blocks() as demo:
163
+ title = gr.Markdown("# Gemini 2.0 Flash + Websearch + Deutsche Bahn Verbindungen")
164
+ gr.Markdown("Frage nach Zugverbindungen (z.B. 'Zugverbindung von Berlin nach München') oder stelle allgemeine Fragen.")
165
  output_textbox = gr.Markdown()
166
  input_textbox = gr.Textbox(lines=3, label="", placeholder="Enter message here...")
167
  submit_button = gr.Button("send")
168
+ submit_button.click(fn=generate, inputs=input_textbox, outputs=[output_textbox, input_textbox])
169
  demo.launch(show_error=True)