# Qurio: backend-python/src/routes/research_plan.py
# Revision 592cb1d by veeiiinnnnn ("new").
"""
Research plan API routes.
"""
from __future__ import annotations

import asyncio
import json
import logging
from collections.abc import AsyncGenerator

from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse

from ..providers import is_provider_supported
from ..services.research_plan import (
    generate_academic_research_plan,
    generate_research_plan,
    stream_generate_academic_research_plan,
    stream_generate_research_plan,
)
from ._request_secrets import get_llm_api_key
# Router for the research-plan endpoints in this module (both routes below are
# registered on it), tagged "research-plan".
router: APIRouter = APIRouter(tags=["research-plan"])
@router.post("/research-plan")
async def research_plan(request: Request) -> JSONResponse:
    """Generate a complete research plan and return it in one JSON response.

    The request body must be a JSON object. Recognised fields:

    * ``provider`` (required): LLM provider id; must pass ``is_provider_supported``.
    * ``message`` (required): the user's research request.
    * ``baseUrl`` / ``model`` (optional): forwarded to the planner unchanged.
    * ``researchType`` (optional): ``"academic"`` selects the academic planner;
      any other value, or none, selects the general planner.

    The LLM API key is obtained via ``get_llm_api_key(request)``. The error
    message below refers to it as the ``x-llm-api-key`` header.

    Returns:
        ``{"plan": <plan>}`` on success.

        A 400 response ``{"error": "..."}`` when the body is not a JSON object,
        a required value is missing, or the provider is unsupported.

        Exceptions raised by the planner itself are not caught here.
    """
    # request.json() raises ValueError (JSONDecodeError, or UnicodeDecodeError
    # for undecodable bytes) on a malformed body. Report it as a client error
    # instead of letting it surface as a 500.
    try:
        body = await request.json()
    except ValueError:
        return JSONResponse(status_code=400, content={"error": "Request body must be valid JSON"})
    # Valid JSON that is not an object (array, string, number) has no .get().
    if not isinstance(body, dict):
        return JSONResponse(status_code=400, content={"error": "Request body must be a JSON object"})

    provider = body.get("provider")
    message = body.get("message")
    api_key = get_llm_api_key(request)
    base_url = body.get("baseUrl")
    model = body.get("model")
    research_type = body.get("researchType") or "general"

    if not provider:
        return JSONResponse(status_code=400, content={"error": "Missing required field: provider"})
    if not message:
        return JSONResponse(status_code=400, content={"error": "Missing required field: message"})
    if not api_key:
        return JSONResponse(status_code=400, content={"error": "Missing required header: x-llm-api-key"})
    if not is_provider_supported(provider):
        return JSONResponse(status_code=400, content={"error": f"Unsupported provider: {provider}"})

    # Both planners take the same keyword arguments; only the implementation differs.
    planner = (
        generate_academic_research_plan
        if research_type == "academic"
        else generate_research_plan
    )
    plan = await planner(
        provider=provider,
        user_message=message,
        api_key=api_key,
        base_url=base_url,
        model=model,
    )
    return JSONResponse(content={"plan": plan})
@router.post("/research-plan-stream")
async def research_plan_stream(request: Request) -> EventSourceResponse:
    """Stream research-plan generation to the client as Server-Sent Events.

    The request body must be a JSON object with the same fields as
    ``/research-plan`` (``provider``, ``message``, ``baseUrl``, ``model``,
    ``researchType``). It may also carry the sampling options ``thinking``,
    ``temperature``, ``top_k``, ``top_p``, ``frequency_penalty`` and
    ``presence_penalty``, which are forwarded unchanged to the streaming planner.

    Each SSE message's ``data`` field is a JSON-encoded event produced by the
    planner. Validation failures, and exceptions raised while streaming, are
    reported in-band as a ``{"type": "error", "error": ...}`` event. No HTTP
    error status is set, because the response is still an event stream.
    Forwarding stops as soon as the client is detected as disconnected.
    """
    # request.json() raises ValueError on a malformed body, and a non-object
    # body has no .get(). Report both as error events instead of crashing.
    try:
        body = await request.json()
    except ValueError:
        return EventSourceResponse(
            _error_stream("Request body must be valid JSON"),
            media_type="text/event-stream",
        )
    if not isinstance(body, dict):
        return EventSourceResponse(
            _error_stream("Request body must be a JSON object"),
            media_type="text/event-stream",
        )

    provider = body.get("provider")
    message = body.get("message")
    api_key = get_llm_api_key(request)
    base_url = body.get("baseUrl")
    model = body.get("model")
    thinking = body.get("thinking")
    temperature = body.get("temperature")
    top_k = body.get("top_k")
    top_p = body.get("top_p")
    frequency_penalty = body.get("frequency_penalty")
    presence_penalty = body.get("presence_penalty")
    research_type = body.get("researchType") or "general"

    if not provider or not message:
        return EventSourceResponse(
            _error_stream("Missing required fields: provider, message"),
            media_type="text/event-stream",
        )
    if not api_key:
        return EventSourceResponse(
            _error_stream("Missing required header: x-llm-api-key"),
            media_type="text/event-stream",
        )
    if not is_provider_supported(provider):
        return EventSourceResponse(
            _error_stream(f"Unsupported provider: {provider}"),
            media_type="text/event-stream",
        )

    async def event_generator() -> AsyncGenerator[dict[str, str], None]:
        """Relay planner events as SSE ``data`` payloads until done or disconnected."""
        stream = None
        try:
            # Both streaming planners take the same keyword arguments.
            stream_factory = (
                stream_generate_academic_research_plan
                if research_type == "academic"
                else stream_generate_research_plan
            )
            stream = stream_factory(
                provider=provider,
                user_message=message,
                api_key=api_key,
                base_url=base_url,
                model=model,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                thinking=thinking,
            )
            async for event in stream:
                if await request.is_disconnected():
                    break
                yield {"data": json.dumps(event, ensure_ascii=False)}
        except asyncio.CancelledError:
            # Cancellation, e.g. from the SSE response on client disconnect,
            # must propagate. Swallowing it breaks asyncio's task cancellation.
            raise
        except Exception as exc:
            # Headers (200) are already sent, so surface the failure in-band,
            # in the same event shape as _error_stream, and keep a server-side trace.
            logging.getLogger(__name__).exception("Research plan stream failed")
            error = str(exc) or type(exc).__name__
            yield {"data": json.dumps({"type": "error", "error": error}, ensure_ascii=False)}
        finally:
            # Leaving `async for` early (break / error / cancel) does not close the
            # upstream async generator. Close it now so its resources are released promptly.
            aclose = getattr(stream, "aclose", None)
            if aclose is not None:
                await aclose()

    return EventSourceResponse(event_generator(), media_type="text/event-stream")
async def _error_stream(message: str) -> AsyncGenerator[dict[str, str], None]:
    """Emit a single SSE event whose JSON payload reports *message* as an error.

    Used to answer a streaming request that failed validation with a
    well-formed event stream instead of an HTTP error body.
    """
    payload = {"type": "error", "error": message}
    encoded = json.dumps(payload, ensure_ascii=False)
    yield {"data": encoded}