File size: 4,978 Bytes
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592cb1d
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592cb1d
4ef118d
 
 
 
 
 
 
 
 
592cb1d
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592cb1d
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592cb1d
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
Research plan API routes.
"""

from __future__ import annotations

import asyncio
import json
from collections.abc import AsyncGenerator

from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse

from ..providers import is_provider_supported
from ._request_secrets import get_llm_api_key
from ..services.research_plan import (
    generate_academic_research_plan,
    generate_research_plan,
    stream_generate_academic_research_plan,
    stream_generate_research_plan,
)

# Router that the research-plan endpoints below register themselves on;
# every route declared through it carries the "research-plan" tag.
router = APIRouter(tags=["research-plan"])


@router.post("/research-plan")
async def research_plan(request: Request) -> JSONResponse:
    """Generate a research plan and return it in one JSON response.

    The JSON body supplies ``provider`` and ``message`` (both required) plus
    the optional ``baseUrl``, ``model`` and ``researchType`` fields; the LLM
    API key is obtained through ``get_llm_api_key(request)``.

    Returns:
        ``{"plan": ...}`` on success, or a 400 response of the form
        ``{"error": ...}`` when the body is not a JSON object, a required
        value is missing, or the provider is not supported.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses; without this guard a malformed body becomes a 500.
        return JSONResponse(status_code=400, content={"error": "Invalid JSON body"})
    if not isinstance(body, dict):
        # Valid JSON that is not an object (list, string, number, null) has
        # no ``.get`` and would otherwise crash with AttributeError.
        return JSONResponse(
            status_code=400,
            content={"error": "Request body must be a JSON object"},
        )

    provider = body.get("provider")
    message = body.get("message")
    api_key = get_llm_api_key(request)
    base_url = body.get("baseUrl")
    model = body.get("model")
    # A missing or empty researchType falls back to the general planner.
    research_type = body.get("researchType") or "general"

    if not provider:
        return JSONResponse(status_code=400, content={"error": "Missing required field: provider"})
    if not message:
        return JSONResponse(status_code=400, content={"error": "Missing required field: message"})
    if not api_key:
        return JSONResponse(status_code=400, content={"error": "Missing required header: x-llm-api-key"})
    if not is_provider_supported(provider):
        return JSONResponse(status_code=400, content={"error": f"Unsupported provider: {provider}"})

    # Both planners are called with the same keyword arguments, so pick the
    # callable first and invoke it once.
    if research_type == "academic":
        build_plan = generate_academic_research_plan
    else:
        build_plan = generate_research_plan

    plan = await build_plan(
        provider=provider,
        user_message=message,
        api_key=api_key,
        base_url=base_url,
        model=model,
    )
    return JSONResponse(content={"plan": plan})


@router.post("/research-plan-stream")
async def research_plan_stream(request: Request) -> EventSourceResponse:
    """Stream research-plan generation events as server-sent events.

    The JSON body supplies ``provider`` and ``message`` (both required) plus
    the optional ``baseUrl``, ``model``, ``thinking``, ``temperature``,
    ``top_k``, ``top_p``, ``frequency_penalty``, ``presence_penalty`` and
    ``researchType`` fields; the LLM API key is obtained through
    ``get_llm_api_key(request)``.

    Validation problems are not reported through an HTTP error status:
    the response is still an event stream, carrying a single
    ``{"type": "error", "error": ...}`` event.

    Returns:
        An ``EventSourceResponse`` whose events each hold one JSON-encoded
        item produced by the selected streaming planner.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses; report them the same way as other validation errors.
        return EventSourceResponse(
            _error_stream("Invalid JSON body"),
            media_type="text/event-stream",
        )
    if not isinstance(body, dict):
        # Valid JSON that is not an object has no ``.get`` and would
        # otherwise crash with AttributeError.
        return EventSourceResponse(
            _error_stream("Request body must be a JSON object"),
            media_type="text/event-stream",
        )

    provider = body.get("provider")
    message = body.get("message")
    api_key = get_llm_api_key(request)
    base_url = body.get("baseUrl")
    model = body.get("model")
    thinking = body.get("thinking")
    temperature = body.get("temperature")
    top_k = body.get("top_k")
    top_p = body.get("top_p")
    frequency_penalty = body.get("frequency_penalty")
    presence_penalty = body.get("presence_penalty")
    # A missing or empty researchType falls back to the general planner.
    research_type = body.get("researchType") or "general"

    if not provider or not message:
        return EventSourceResponse(
            _error_stream("Missing required fields: provider, message"),
            media_type="text/event-stream",
        )
    if not api_key:
        return EventSourceResponse(
            _error_stream("Missing required header: x-llm-api-key"),
            media_type="text/event-stream",
        )
    if not is_provider_supported(provider):
        return EventSourceResponse(
            _error_stream(f"Unsupported provider: {provider}"),
            media_type="text/event-stream",
        )

    async def event_generator() -> AsyncGenerator[dict[str, str], None]:
        """Relay planner events as SSE payloads until done or disconnected."""
        try:
            # Both streaming planners take the same keyword arguments, so
            # select the callable and build the stream once.
            if research_type == "academic":
                open_stream = stream_generate_academic_research_plan
            else:
                open_stream = stream_generate_research_plan

            plan_events = open_stream(
                provider=provider,
                user_message=message,
                api_key=api_key,
                base_url=base_url,
                model=model,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                thinking=thinking,
            )

            async for event in plan_events:
                # Stop relaying as soon as the client has gone away.
                if await request.is_disconnected():
                    break
                yield {"data": json.dumps(event, ensure_ascii=False)}
        except asyncio.CancelledError:
            # Cancellation simply ends the stream; nothing further is sent.
            return

    return EventSourceResponse(event_generator(), media_type="text/event-stream")


async def _error_stream(message: str) -> AsyncGenerator[dict[str, str], None]:
    """Yield exactly one event describing an error.

    The event's ``data`` field is the JSON encoding of
    ``{"type": "error", "error": message}``; non-ASCII characters in
    *message* are emitted as-is rather than escaped.
    """
    payload = {"type": "error", "error": message}
    encoded = json.dumps(payload, ensure_ascii=False)
    yield {"data": encoded}