# webapi / app.py
# mgokg's picture
# Update app.py
# ab0e4c1 verified
# raw
# history blame
# 2.13 kB
import os
import asyncio
import gradio as gr
from google import genai
from google.genai import types
from mcp import ClientSession
from mcp.client.sse import sse_client # Spezifischer Transport für Gradio/HF
def _describe_error(exc: BaseException) -> str:
    """Return a human-readable message for *exc*, unwrapping exception groups.

    The SSE client runs inside an anyio task group, whose failures surface as
    (possibly nested) exception groups whose ``str()`` only says how many
    sub-exceptions occurred. Descend into the first sub-exception until a leaf
    is reached, and fall back to the class name when the message is empty.
    """
    while True:
        inner = getattr(exc, "exceptions", None)
        if not isinstance(inner, (tuple, list)) or not inner:
            break
        exc = inner[0]
    return str(exc) or type(exc).__name__


async def generate(
    input_text,
    *,
    mcp_url: str = "https://mgokg-db-timetable-api.hf.space/gradio_api/mcp/sse",
    model_id: str = "gemini-2.0-flash",
):
    """Answer a query with Gemini, using the DB timetable MCP server as a tool source.

    Opens an SSE connection to the MCP server, passes the MCP session to Gemini
    together with Google Search as tools, and collects the streamed answer text.

    Args:
        input_text: The user's query from the Gradio textbox.
        mcp_url: SSE endpoint of the MCP server. Gradio serves its MCP
            endpoint under ``/gradio_api/mcp/sse``; the bare Space root
            returns an HTML page, which the SSE client rejects with a
            'text/html' error.
        model_id: Gemini model to query.

    Returns:
        A ``(markdown_answer, textbox_value)`` tuple matching the two Gradio
        outputs. On success the textbox is cleared (``""``). On empty input or
        on failure, the original query is handed back so the user can retry.
    """
    # Avoid a network round-trip (and an API error) for an empty query.
    if not input_text or not input_text.strip():
        return "Bitte gib eine Anfrage ein.", input_text or ""

    try:
        client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
        # SSE transport, because that is what the Gradio MCP endpoint speaks.
        async with sse_client(url=mcp_url) as (read_stream, write_stream):
            async with ClientSession(read_stream, write_stream) as mcp_session:
                await mcp_session.initialize()
                generate_content_config = types.GenerateContentConfig(
                    temperature=0.4,
                    # NOTE(review): some Gemini models reject combining the
                    # built-in google_search tool with function-calling tools
                    # in one request — confirm this works with model_id.
                    tools=[
                        types.Tool(google_search=types.GoogleSearch()),
                        # Passing the MCP session hands the DB server's tools
                        # to Gemini.
                        mcp_session,
                    ],
                )
                parts = []
                async for chunk in client.aio.models.generate_content_stream(
                    model=model_id,
                    contents=input_text,
                    config=generate_content_config,
                ):
                    # Chunks without text (e.g. tool-call parts) are skipped.
                    if chunk.text:
                        parts.append(chunk.text)
    except Exception as e:
        # Keep the query in the textbox so the user does not have to retype it.
        return (
            f"Verbindung zum DB-Fahrplan fehlgeschlagen: {_describe_error(e)}",
            input_text,
        )
    return "".join(parts), ""
def gradio_wrapper(input_text):
    """Synchronously run the async ``generate`` coroutine for a Gradio event.

    ``asyncio.run`` creates a fresh event loop, drives ``generate(input_text)``
    to completion and closes the loop again.

    Returns:
        Exactly what ``generate`` returns: an ``(answer, textbox_value)`` tuple.
    """
    pending = generate(input_text)
    return asyncio.run(pending)
if __name__ == '__main__':
    # Single-page UI: heading, query textbox, send button, Markdown answer pane.
    with gr.Blocks() as demo:
        gr.Markdown("# Gemini Flash + DB Timetable")
        query_box = gr.Textbox(
            label="Anfrage",
            placeholder="Wann fährt der nächste Zug von Berlin nach Hamburg?",
        )
        send_button = gr.Button("Senden")
        answer_view = gr.Markdown()
        # gradio_wrapper returns (answer, textbox_value): the first value fills
        # the answer pane, the second is written back into the query textbox.
        send_button.click(
            fn=gradio_wrapper,
            inputs=query_box,
            outputs=[answer_view, query_box],
        )
    demo.launch()