from fastapi import FastAPI, HTTPException, Header
from pydantic import BaseModel
from typing import List

from starlette.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse

from groq import AsyncGroq, Groq
from groq.resources import Models
from groq.types import ModelList
from groq.types.chat.completion_create_params import Message

from json import dumps
import async_timeout
import asyncio

GENERATION_TIMEOUT_SEC = 60

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)


class ChatInput(BaseModel):
    model: str
    messages: List[Message]
    stream: bool
    temperature: float = 0
    max_tokens: int = 100
    user: str = "user"


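async def get_groq_response(client: AsyncGroq, req: ChatInput):
    """Yield Groq completion chunks as server-sent events."""
    # async_timeout raises asyncio.TimeoutError when the `async with` block
    # exits, so the try/except must wrap the context manager, not sit inside it.
    try:
        async with async_timeout.timeout(GENERATION_TIMEOUT_SEC):
            stream = await client.chat.completions.create(
                messages=req.messages,
                model=req.model,
                temperature=req.temperature,
                max_tokens=req.max_tokens,
                stream=req.stream,
            )

            async for chunk in stream:
                yield {"data": dumps(chunk.dict())}

    except asyncio.TimeoutError:
        # The response headers have already been sent by the time the stream
        # is running, so an HTTP 504 can no longer be returned; report the
        # timeout to the client as an SSE error event instead.
        yield {"event": "error", "data": dumps({"detail": "Stream timed out"})}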


@app.get("/models")
async def models(authorization: str = Header()) -> ModelList:
    # Use the async client so the upstream request does not block the event loop.
    client = AsyncGroq(
        api_key=authorization.split(" ")[-1],
    )

    return await client.models.list()


@app.post("/chat/completions")
async def completion(req: ChatInput, authorization: str = Header()):
    client = AsyncGroq(
        api_key=authorization.split(" ")[-1],
    )

    if req.stream:
        return EventSourceResponse(get_groq_response(client, req))

    response = await client.chat.completions.create(
        messages=req.messages,
        model=req.model,
        temperature=req.temperature,
        max_tokens=req.max_tokens,
        stream=req.stream,
    )

    return response