# webapi/app.py — Gemini 2.0 Flash chat app with a DB Timetable tool
# (Hugging Face Space "mgokg", commit 6d108cd)
import os
import gradio as gr
from google import genai
from google.genai import types
from gradio_client import Client
# 1. Initialize the client for the external DB Timetable App
# We use the Hugging Face Space ID provided in your documentation
# Module-level singleton: one gradio_client connection to the remote Space,
# reused by every tool invocation.
# NOTE(review): constructing this at import time performs network I/O — if the
# Space is unreachable, importing this module fails. Confirm this is intended.
db_client = Client("mgokg/db-timetable-api")
def get_train_connection(dep: str, dest: str):
    """Query the external DB Timetable Space for train connections.

    Args:
        dep: Departure city or station.
        dest: Destination city or station.

    Returns:
        The raw result from the remote ``/db_timetable_api_ui_wrapper``
        endpoint, or an error string when the call fails. Errors are
        returned rather than raised so the tool-calling loop can hand
        them back to the model as a normal function response.
    """
    try:
        # Endpoint name taken from the Space's MCP documentation.
        return db_client.predict(
            dep=dep, dest=dest, api_name="/db_timetable_api_ui_wrapper"
        )
    except Exception as e:
        return f"Error fetching timetable: {str(e)}"
# 2. Define the tool for Gemini
# This tells the model how to use the Python function above
# The schema mirrors get_train_connection's signature exactly: the declared
# name and the two required string parameters must match the Python function,
# since tools_map dispatches on this name and calls it with **args.
train_tool = types.FunctionDeclaration(
    name="get_train_connection",
    description="Find train connections and timetables between a start location (dep) and a destination (dest).",
    parameters=types.Schema(
        type=types.Type.OBJECT,
        properties={
            "dep": types.Schema(type=types.Type.STRING, description="Departure city or station"),
            "dest": types.Schema(type=types.Type.STRING, description="Destination city or station"),
        },
        required=["dep", "dest"]
    )
)
# Map the string name to the actual python function
# Dispatch table: function-call names returned by Gemini -> local callables.
tools_map = {
    "get_train_connection": get_train_connection
}
def generate(input_text, history):
    """Answer *input_text* with Gemini, invoking the timetable tool if asked.

    Generator consumed by the Gradio ``bot`` callback: yields
    ``(partial_text, history)`` tuples so the UI can stream updates.
    ``history`` is passed through unchanged.

    Args:
        input_text: The user's latest chat message.
        history: The Gradio chat history (not forwarded to the model).

    Yields:
        Tuples of (response text so far, history). Always yields at least
        once, even on error, so the UI never shows an empty bubble.
    """
    # Initialize the Gemini client per call so a missing/rotated key surfaces
    # as a chat message instead of crashing at import time.
    try:
        client = genai.Client(
            api_key=os.environ.get("GEMINI_API_KEY"),
        )
    except Exception as e:
        yield f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set.", history
        return

    model = "gemini-2.0-flash-exp"  # Or "gemini-2.0-flash" depending on availability

    # Single-turn context: only the current user message is sent.
    # (Previous chat history is intentionally not included.)
    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=input_text)],
        ),
    ]

    # 3. Configure tools (Google Search + Our Custom DB Tool)
    # NOTE(review): some Gemini models reject mixing google_search with
    # function_declarations in a single request — confirm against API docs.
    tools = [
        types.Tool(google_search=types.GoogleSearch()),
        types.Tool(function_declarations=[train_tool]),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=0.4,
        tools=tools,
    )

    # First API Call: let the model answer directly or request a tool call.
    try:
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config,
        )
    except Exception as e:
        yield f"Error during generation: {e}", history
        return

    # 4. Check if the model wants to call a function.
    # Scan ALL parts, not just the first: the model may emit a text part
    # (e.g. "Let me check...") before the function_call part, which the
    # first-part-only check would miss.
    function_call = None
    if response.candidates and response.candidates[0].content.parts:
        for part in response.candidates[0].content.parts:
            if part.function_call:
                function_call = part.function_call
                break

    if function_call:
        fn_name = function_call.name
        fn_args = dict(function_call.args or {})  # args may be None
        if fn_name in tools_map:
            status_msg = f"🔄 Checking trains from {fn_args.get('dep')} to {fn_args.get('dest')}..."
            yield status_msg, history
            # Execute the tool locally.
            api_result = tools_map[fn_name](**fn_args)
            # Feed the model's function call and our result back so it can
            # phrase a natural-language answer.
            contents.append(response.candidates[0].content)
            contents.append(
                types.Content(
                    role="tool",
                    parts=[
                        types.Part.from_function_response(
                            name=fn_name,
                            response={"result": api_result}
                        )
                    ]
                )
            )
            # Second API Call: stream the final natural language answer.
            # Wrapped in try/except like the first call so a streaming
            # failure is reported instead of crashing the callback.
            try:
                stream = client.models.generate_content_stream(
                    model=model,
                    contents=contents,
                    config=generate_content_config  # Keep tools enabled just in case
                )
                final_text = ""
                for chunk in stream:
                    if chunk.text:
                        final_text += chunk.text
                        yield final_text, history
            except Exception as e:
                yield f"Error during generation: {e}", history
            return
        # Unknown tool requested: report it rather than yielding nothing.
        yield f"Error: model requested unknown tool '{fn_name}'.", history
        return

    # No function call: plain text answer (normal chat or search-grounded).
    if response.text:
        yield response.text, history
    else:
        # Guarantee at least one yield so the chat bubble is never left None.
        yield "No response generated.", history
if __name__ == '__main__':
    # Minimal streaming chat UI around generate().
    with gr.Blocks() as demo:
        gr.Markdown("# Gemini 2.0 Flash + DB Timetable Tool")
        chatbot = gr.Chatbot(label="Conversation", height=400)
        msg = gr.Textbox(lines=1, label="Ask about trains (e.g., 'Train from Berlin to Munich')", placeholder="Enter message here...")
        clear = gr.Button("Clear")
        def user(user_message, history):
            # Step 1: clear the textbox and append the user's message with an
            # empty bot slot ([message, None]) for bot() to fill in.
            return "", history + [[user_message, None]]
        def bot(history):
            # Step 2: stream the model's answer into the last history entry.
            user_message = history[-1][0]
            # Call generate and update the last message in history
            for partial_response, _ in generate(user_message, history):
                history[-1][1] = partial_response
                yield history
        # queue=False on the user step so the message appears instantly;
        # .then() chains the (streaming) bot step after it.
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        # Reset the chat pane by writing None into the Chatbot component.
        clear.click(lambda: None, None, chatbot, queue=False)
        demo.launch(show_error=True)