Spaces:
Sleeping
Sleeping
File size: 4,978 Bytes
4ef118d 592cb1d 4ef118d 592cb1d 4ef118d 592cb1d 4ef118d 592cb1d 4ef118d 592cb1d 4ef118d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | """
Research plan API routes.
"""
from __future__ import annotations

import asyncio
import json
import logging
from collections.abc import AsyncGenerator

from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse

from ..providers import is_provider_supported
from ._request_secrets import get_llm_api_key
from ..services.research_plan import (
    generate_academic_research_plan,
    generate_research_plan,
    stream_generate_academic_research_plan,
    stream_generate_research_plan,
)
# Router for the research-plan endpoints below; ``tags`` groups them in the OpenAPI docs.
router = APIRouter(tags=["research-plan"])
@router.post("/research-plan")
async def research_plan(request: Request) -> JSONResponse:
    """Generate a research plan and return it in a single JSON response.

    The request body must be a JSON object with ``provider`` and ``message``
    (required) and optional ``baseUrl``, ``model`` and ``researchType``. The
    LLM API key is obtained via ``get_llm_api_key(request)``; the error below
    refers to it as the ``x-llm-api-key`` header.

    ``researchType == "academic"`` selects the academic plan generator. Any
    other value, or a missing/empty one, uses the general generator.

    Returns:
        ``{"plan": ...}`` on success, or a 400 response with ``{"error": ...}``
        when the body is not a JSON object, a required field or the API key is
        missing, or the provider is not supported.
    """
    # A malformed body (JSONDecodeError / UnicodeDecodeError, both ValueError)
    # or a JSON value that is not an object (no ``.get``) would otherwise
    # surface as a 500. Treat both as client errors.
    try:
        body = await request.json()
    except ValueError:
        body = None
    if not isinstance(body, dict):
        return JSONResponse(status_code=400, content={"error": "Request body must be a JSON object"})

    provider = body.get("provider")
    message = body.get("message")
    api_key = get_llm_api_key(request)
    base_url = body.get("baseUrl")
    model = body.get("model")
    research_type = body.get("researchType") or "general"

    if not provider:
        return JSONResponse(status_code=400, content={"error": "Missing required field: provider"})
    if not message:
        return JSONResponse(status_code=400, content={"error": "Missing required field: message"})
    if not api_key:
        return JSONResponse(status_code=400, content={"error": "Missing required header: x-llm-api-key"})
    if not is_provider_supported(provider):
        return JSONResponse(status_code=400, content={"error": f"Unsupported provider: {provider}"})

    # Both generators take the same keyword arguments; only the choice differs.
    generate = (
        generate_academic_research_plan
        if research_type == "academic"
        else generate_research_plan
    )
    plan = await generate(
        provider=provider,
        user_message=message,
        api_key=api_key,
        base_url=base_url,
        model=model,
    )
    return JSONResponse(content={"plan": plan})
@router.post("/research-plan-stream")
async def research_plan_stream(request: Request) -> EventSourceResponse:
    """Stream research-plan generation events to the client as Server-Sent Events.

    The request body must be a JSON object. It carries ``provider`` and
    ``message`` (required) plus optional values that are forwarded to the
    stream generator:
    - ``baseUrl``, ``model``, ``thinking``
    - ``temperature``, ``top_k``, ``top_p``
    - ``frequency_penalty``, ``presence_penalty``
    - ``researchType``

    The LLM API key comes from ``get_llm_api_key(request)``. ``researchType ==
    "academic"`` selects the academic generator; anything else uses the
    general one.

    Each event yielded by the generator is JSON-encoded into the SSE ``data``
    field. Validation problems are not reported with an HTTP error status:
    - a body that is not a JSON object
    - a missing field or API key
    - an unsupported provider

    Each of these is sent as a single ``{"type": "error", "error": ...}``
    event via ``_error_stream``. A failure raised while generating is
    reported to the client with an event of that same shape.
    """
    # A malformed body (ValueError) or a non-object JSON value would otherwise
    # escape as a 500. Report it the same way as the other validation errors.
    try:
        body = await request.json()
    except ValueError:
        body = None
    if not isinstance(body, dict):
        return EventSourceResponse(
            _error_stream("Request body must be a JSON object"),
            media_type="text/event-stream",
        )

    provider = body.get("provider")
    message = body.get("message")
    api_key = get_llm_api_key(request)
    base_url = body.get("baseUrl")
    model = body.get("model")
    thinking = body.get("thinking")
    temperature = body.get("temperature")
    top_k = body.get("top_k")
    top_p = body.get("top_p")
    frequency_penalty = body.get("frequency_penalty")
    presence_penalty = body.get("presence_penalty")
    research_type = body.get("researchType") or "general"

    if not provider or not message:
        return EventSourceResponse(
            _error_stream("Missing required fields: provider, message"),
            media_type="text/event-stream",
        )
    if not api_key:
        return EventSourceResponse(
            _error_stream("Missing required header: x-llm-api-key"),
            media_type="text/event-stream",
        )
    if not is_provider_supported(provider):
        return EventSourceResponse(
            _error_stream(f"Unsupported provider: {provider}"),
            media_type="text/event-stream",
        )

    # Both stream generators take the same keyword arguments; only the choice differs.
    stream_factory = (
        stream_generate_academic_research_plan
        if research_type == "academic"
        else stream_generate_research_plan
    )

    async def event_generator() -> AsyncGenerator[dict[str, str], None]:
        stream = None
        try:
            stream = stream_factory(
                provider=provider,
                user_message=message,
                api_key=api_key,
                base_url=base_url,
                model=model,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                thinking=thinking,
            )
            async for event in stream:
                # Stop pulling from the upstream generator once the client is gone.
                if await request.is_disconnected():
                    break
                yield {"data": json.dumps(event, ensure_ascii=False)}
        except asyncio.CancelledError:
            # asyncio expects CancelledError to be re-raised so the canceller's
            # cancellation actually completes; do not swallow it.
            raise
        except Exception as exc:
            # An exception escaping here would only abort the stream without telling
            # the client. Log it and emit an event shaped like _error_stream's.
            logging.getLogger(__name__).exception("Research plan stream failed")
            yield {
                "data": json.dumps(
                    {"type": "error", "error": str(exc) or type(exc).__name__},
                    ensure_ascii=False,
                )
            }
        finally:
            # Close the upstream generator deterministically (e.g. after the
            # disconnect ``break``) instead of leaving it to garbage collection.
            aclose = getattr(stream, "aclose", None)
            if aclose is not None:
                await aclose()

    return EventSourceResponse(event_generator(), media_type="text/event-stream")
async def _error_stream(message: str) -> AsyncGenerator[dict[str, str], None]:
    """Emit exactly one SSE error event carrying *message*, then finish.

    The event's ``data`` is the JSON object ``{"type": "error", "error": message}``
    serialized with ``ensure_ascii=False``, so non-ASCII text is sent verbatim.
    """
    error_payload = {"type": "error", "error": message}
    encoded = json.dumps(error_payload, ensure_ascii=False)
    yield {"data": encoded}
|