import os

import gradio as gr
from google import genai
from google.genai import types
from gradio_client import Client

# Client for the external DB (Deutsche Bahn) timetable API hosted as a
# Hugging Face Space; constructed once at import time and reused per call.
db_client = Client("mgokg/db-timetable-api")
| |
|
def get_train_connection(dep: str, dest: str):
    """Look up train connections between two locations via the external API.

    Args:
        dep: Departure city or station.
        dest: Destination city or station.

    Returns:
        The raw timetable payload from the remote Space, or an error string
        describing the failure (best-effort — never raises).
    """
    try:
        return db_client.predict(
            dep=dep,
            dest=dest,
            api_name="/db_timetable_api_ui_wrapper",
        )
    except Exception as e:
        return f"Error fetching timetable: {str(e)}"
| |
|
| | |
| | |
# Parameter schema for the train-lookup tool, declared separately so the
# FunctionDeclaration below stays readable.
_train_tool_params = types.Schema(
    type=types.Type.OBJECT,
    properties={
        "dep": types.Schema(type=types.Type.STRING, description="Departure city or station"),
        "dest": types.Schema(type=types.Type.STRING, description="Destination city or station"),
    },
    required=["dep", "dest"],
)

# Declaration exposed to Gemini so the model can request a tool call.
train_tool = types.FunctionDeclaration(
    name="get_train_connection",
    description="Find train connections and timetables between a start location (dep) and a destination (dest).",
    parameters=_train_tool_params,
)

# Dispatch table: function-call names returned by the model -> local callables.
tools_map = {
    "get_train_connection": get_train_connection,
}
| |
|
def generate(input_text, history):
    """Generate a chat reply for *input_text*, optionally calling the train tool.

    Yields (partial_text, history) tuples so the Gradio UI can stream updates.
    Guarantees at least one yield on every path so the chat bubble is never
    left as None (previously an unknown tool name or an empty model response
    produced no output at all).
    """
    try:
        client = genai.Client(
            api_key=os.environ.get("GEMINI_API_KEY"),
        )
    except Exception as e:
        yield f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set.", history
        return

    model = "gemini-2.0-flash-exp"

    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=input_text)],
        ),
    ]

    # NOTE(review): some Gemini API versions reject google_search combined
    # with function_declarations in a single request — confirm against the
    # google-genai docs before relying on both together.
    tools = [
        types.Tool(google_search=types.GoogleSearch()),
        types.Tool(function_declarations=[train_tool]),
    ]

    generate_content_config = types.GenerateContentConfig(
        temperature=0.4,
        tools=tools,
    )

    # First, a non-streaming call so we can inspect whether the model wants
    # a tool invocation before producing the final answer.
    try:
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config,
        )
    except Exception as e:
        yield f"Error during generation: {e}", history
        return

    if response.candidates and response.candidates[0].content.parts:
        first_part = response.candidates[0].content.parts[0]

        if first_part.function_call:
            fn_name = first_part.function_call.name
            fn_args = first_part.function_call.args

            if fn_name in tools_map:
                # Fixed mojibake: the original status string began with a
                # broken character ("๐") where an emoji was intended.
                status_msg = f"🚆 Checking trains from {fn_args.get('dep')} to {fn_args.get('dest')}..."
                yield status_msg, history

                api_result = tools_map[fn_name](**fn_args)

                # Append the model's tool request and our tool result, then
                # ask the model to compose the final answer as a stream.
                contents.append(response.candidates[0].content)
                contents.append(
                    types.Content(
                        role="tool",
                        parts=[
                            types.Part.from_function_response(
                                name=fn_name,
                                response={"result": api_result},
                            )
                        ],
                    )
                )

                stream = client.models.generate_content_stream(
                    model=model,
                    contents=contents,
                    config=generate_content_config,
                )

                final_text = ""
                for chunk in stream:
                    if chunk.text:
                        final_text += chunk.text
                        yield final_text, history
                return

            # Fix: previously an unrecognized tool name fell through with no
            # yield, leaving the chat bubble empty.
            yield f"Error: model requested unknown tool '{fn_name}'.", history
            return

    if response.text:
        yield response.text, history
    else:
        # Fix: ensure the generator always yields something.
        yield "No response generated.", history
| |
|
if __name__ == '__main__':
    with gr.Blocks() as demo:
        gr.Markdown("# Gemini 2.0 Flash + DB Timetable Tool")

        chatbot = gr.Chatbot(label="Conversation", height=400)
        msg = gr.Textbox(lines=1, label="Ask about trains (e.g., 'Train from Berlin to Munich')", placeholder="Enter message here...")
        clear = gr.Button("Clear")

        def user(user_message, history):
            # Clear the textbox and append the pending user turn (bot side None).
            return "", history + [[user_message, None]]

        def bot(history):
            # Stream partial responses from generate() into the last turn.
            latest_message = history[-1][0]
            for partial_response, _ in generate(latest_message, history):
                history[-1][1] = partial_response
                yield history

        # Submit clears the textbox immediately, then streams the bot reply.
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(lambda: None, None, chatbot, queue=False)

    demo.launch(show_error=True)