File size: 1,872 Bytes
a9af6e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import json
import logging
from typing import AsyncGenerator

from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse

from models.chat import ChatRequest, ChatResponse, StreamResponse
from services.chat_service import ChatService

# Router for the chat endpoints: every route registered on it below is
# mounted under the "/chat" prefix and grouped under the "chat" tag.
chat_router = APIRouter(
    prefix="/chat",
    tags=["chat"],
)

def get_chat_service():
    """
    Dependency provider for the chat endpoints.

    Used via ``Depends(get_chat_service)``; every call constructs and
    returns a brand-new ``ChatService`` instance (nothing is cached here).
    """
    service = ChatService()
    return service

@chat_router.post("/", response_model=ChatResponse)
async def chat_completion(
    request: ChatRequest,
    chat_service: ChatService = Depends(get_chat_service)
):
    """
    Process a (non-streaming) chat completion request.

    Delegates to ``chat_service.process_chat(request)`` and returns its
    result, which the route declares as ``ChatResponse``.

    Raises:
        HTTPException: re-raised unchanged when the service raises one, so a
            deliberate status code chosen by the service is preserved; any
            other exception is logged and converted into a 500 whose
            ``detail`` is the exception message.
    """
    try:
        return await chat_service.process_chat(request)
    except HTTPException:
        # Already an HTTP-level error; don't mask its status code as a 500.
        raise
    except Exception as e:
        logging.getLogger(__name__).exception("Chat completion failed")
        # NOTE(review): str(e) exposes internal error text to clients;
        # consider a generic message once callers no longer depend on it.
        raise HTTPException(status_code=500, detail=str(e)) from e

def _sse_event(payload) -> str:
    """Format a model exposing ``model_dump()`` as one SSE ``data:`` frame."""
    return f"data: {json.dumps(payload.model_dump())}\n\n"


@chat_router.post("/stream")
async def chat_stream(
    request: ChatRequest,
    chat_service: ChatService = Depends(get_chat_service)
):
    """
    Stream a chat completion as Server-Sent Events.

    Every chunk produced by ``chat_service.stream_chat(request)`` is sent as
    a ``data: <json>`` frame, followed by a closing
    ``StreamResponse(finished=True)`` frame.

    Error handling:
        * Failure before the first chunk: the response is a single
          ``StreamResponse(error=..., finished=True)`` frame with status 500.
        * An ``HTTPException`` raised before the first chunk propagates
          unchanged so its status code is kept.
        * Failure after streaming has started: the status line has already
          been sent, so an error frame with ``finished=True`` is emitted and
          the stream ends, instead of the connection being dropped.
    """
    logger = logging.getLogger(__name__)
    first_chunk = None
    exhausted = False

    try:
        # Code inside a generator only runs once it is iterated (i.e. after
        # the response has started), so pull the first chunk here, while a
        # 500 status can still be returned.
        chunks = chat_service.stream_chat(request).__aiter__()
        first_chunk = await chunks.__anext__()
    except StopAsyncIteration:
        exhausted = True
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Chat stream failed before the first chunk")
        error_chunk = StreamResponse(error=str(e), finished=True)

        async def error_stream() -> AsyncGenerator[str, None]:
            yield _sse_event(error_chunk)

        return StreamingResponse(
            error_stream(),
            media_type="text/event-stream",
            status_code=500,
        )

    async def generate_stream() -> AsyncGenerator[str, None]:
        if not exhausted:
            try:
                yield _sse_event(first_chunk)
                async for chunk in chunks:
                    yield _sse_event(chunk)
            except Exception as e:
                # Headers are already on the wire; report the failure
                # in-band and terminate the stream.
                logger.exception("Chat stream failed mid-stream")
                yield _sse_event(StreamResponse(error=str(e), finished=True))
                return

        # Send final message
        yield _sse_event(StreamResponse(finished=True))

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )