mtyrrell's picture
sources
f852f01
raw
history blame
6.66 kB
import gradio as gr
import asyncio
import logging
import json
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from utils.generator import generate_streaming, generate
# Logging setup: every record at INFO or above is written both to the
# console stream and to the local file "app.log".
logging.basicConfig(
    handlers=[logging.StreamHandler(), logging.FileHandler('app.log')],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by the endpoints and the Gradio wrapper.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------
# FastAPI app for ChatUI endpoints
# ---------------------------------------------------------------------
# FastAPI application that the @app.post route decorators in this module
# register on. The same name is rebound further down to the value returned by
# gr.mount_gradio_app(), which adds the Gradio UI under /gradio.
app = FastAPI(title="ChatFed Generator", version="1.0.0")
@app.post("/generate")
async def generate_endpoint(request: Request):
    """
    Answer a query in one shot (no streaming), using the ChatUI format.

    Request body (JSON object):
        {
            "query": "user question",
            "context": [...]   // list of retrieval results
        }
    A missing "query" defaults to "" and a missing "context" to [].

    Response (ChatUI format), as returned by ``generate``:
        {
            "answer": "response with citations [1][2]",
            "sources": [{"link": "doc://...", "title": "..."}]
        }

    Any exception (bad JSON, generation failure, ...) is logged with its
    traceback and reported to the caller as ``{"error": "<message>"}``.
    """
    try:
        payload = await request.json()
        return await generate(
            payload.get("query", ""),
            payload.get("context", []),
            chatui_format=True,
        )
    except Exception as exc:
        logger.exception("Generation endpoint failed")
        return {"error": str(exc)}
@app.post("/generate/stream")
async def generate_stream_endpoint(request: Request):
    """
    Streaming generation endpoint for ChatUI format.

    Expected request body:
        {
            "query": "user question",
            "context": [...] // list of retrieval results
        }
    A missing "query" defaults to "" and a missing "context" to [].

    Returns Server-Sent Events in ChatUI format:
        event: data
        data: "response chunk"

        event: sources
        data: {"sources": [...]}

        event: end

    Failures are reported in-band as an ``event: error`` frame whose data is
    ``{"error": "<message>"}``. This covers both a request body that cannot
    be read/parsed and an exception raised while the generator is being
    consumed (i.e. after the response has already started).
    """

    def sse_frame(event_name, payload_json):
        # One SSE message: event name, JSON-encoded data line, blank line.
        return f"event: {event_name}\ndata: {payload_json}\n\n"

    try:
        body = await request.json()
        query = body.get("query", "")
        context = body.get("context", [])
    except Exception as e:
        logger.exception("Streaming endpoint failed")
        # Serialize the message NOW. The generator below runs only after this
        # function has returned, and Python unbinds `e` as soon as the except
        # clause exits, so reading `e` lazily would raise NameError.
        error_payload = json.dumps({'error': str(e)})

        async def error_stream():
            yield sse_frame("error", error_payload)

        return StreamingResponse(
            error_stream(),
            media_type="text/event-stream"
        )

    async def event_stream():
        # The body of this generator executes while the response is being
        # sent, outside the try/except above, so it needs its own handler to
        # turn mid-stream failures into an error event instead of a silently
        # truncated stream.
        try:
            async for event in generate_streaming(query, context, chatui_format=True):
                event_type = event["event"]
                if event_type == "end":
                    yield sse_frame("end", "{}")
                elif event_type in ("data", "sources", "error"):
                    yield sse_frame(event_type, json.dumps(event["data"]))
                # Any other event type is ignored.
        except Exception as stream_error:
            logger.exception("Streaming generation failed")
            yield sse_frame("error", json.dumps({'error': str(stream_error)}))

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
        }
    )
# ---------------------------------------------------------------------
# Wrapper function to handle async streaming for Gradio
# ---------------------------------------------------------------------
def generate_streaming_wrapper(query: str, context: str):
    """Wrapper to convert async generator to sync generator for Gradio.

    Drives ``generate_streaming(query, context, chatui_format=False)`` on a
    private event loop and yields the text accumulated so far after every
    chunk (Gradio replaces the output with each yielded value, so the full
    running text is yielded rather than the individual chunk).

    Args:
        query: The user question.
        context: The context to answer from, as entered in the UI textbox.

    Yields:
        The concatenation of all chunks received so far.
    """
    logger.info(f"Starting generation request - Query length: {len(query)}, Context length: {len(context)}")

    async def _async_generator():
        async for chunk in generate_streaming(query, context, chatui_format=False):
            yield chunk

    # Use a dedicated loop for this call instead of asyncio.get_event_loop():
    # that call is deprecated when no loop is set, and a loop created as a
    # fallback was previously never closed.
    loop = asyncio.new_event_loop()
    logger.debug("Created new event loop")

    # Convert async generator to sync generator
    async_gen = _async_generator()

    # Accumulate chunks for Gradio streaming
    accumulated_text = ""
    chunk_count = 0

    try:
        while True:
            try:
                chunk = loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                break
            accumulated_text += chunk
            chunk_count += 1
            yield accumulated_text  # Yield the accumulated text, not just the chunk
        logger.info(f"Generation completed - Total chunks: {chunk_count}, Final text length: {len(accumulated_text)}")
    finally:
        # Runs on normal completion, on error, and when the consumer stops
        # early (generator closed): finalize the async generator so its
        # cleanup code executes, then release the loop.
        try:
            loop.run_until_complete(async_gen.aclose())
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()
# ---------------------------------------------------------------------
# Gradio Interface with MCP support and streaming
# ---------------------------------------------------------------------
logger.info("Initializing Gradio interface")
# Gradio front end for the generator: two free-text inputs (query, context)
# and one text output that is updated as generate_streaming_wrapper yields.
ui = gr.Interface(
    fn=generate_streaming_wrapper,  # Use streaming wrapper function
    inputs=[
        gr.Textbox(
            label="Query",
            lines=2,
            placeholder="Enter query here",
            info="The query to search for in the vector database"
        ),
        gr.Textbox(
            label="Context",
            lines=8,
            placeholder="Paste relevant context here",
            # Help-text typo fixed: "except" -> "accept".
            info="Provide the context/documents to use for answering. The API expects a list of dictionaries, but the UI should accept anything"
        ),
    ],
    outputs=gr.Textbox(
        label="Generated Answer",
        lines=6,
        show_copy_button=True
    ),
    title="ChatFed Generation Module",
    description="Ask questions based on provided context. Intended for use in RAG pipelines as an MCP server with other ChatFed modules (i.e. context supplied by semantic retriever service).",
    api_name="generate"
)
# Mount Gradio app to FastAPI
# The Gradio interface `ui` is attached under the /gradio path of the FastAPI
# application, and `app` is rebound to the object returned by the mount call,
# which is what uvicorn serves in the __main__ block.
app = gr.mount_gradio_app(app, ui, path="/gradio")
# Script entry point: serve the combined FastAPI + Gradio application with
# uvicorn on all interfaces, port 7860.
# NOTE(review): nothing in this block enables an MCP server explicitly; if MCP
# is required, confirm it is turned on elsewhere (e.g. via configuration).
if __name__ == "__main__":
    import uvicorn

    for startup_message in (
        "Starting ChatFed Generation Module server",
        "FastAPI server will be available at http://0.0.0.0:7860",
        "Gradio UI will be available at http://0.0.0.0:7860/gradio",
        "ChatUI endpoints: /generate (non-streaming), /generate/stream (streaming)",
    ):
        logger.info(startup_message)

    uvicorn.run(app, host="0.0.0.0", port=7860)