File size: 4,211 Bytes
f7dc7b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501deb0
f7dc7b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d832c6f
f7dc7b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os

# Runtime configuration, read once from the process environment at import time.

# Bot used when a request body does not name a model.
DEFAULT_MODEL = os.environ.get("BOT", "Claude-3-Sonnet")
# TCP port handed to uvicorn when this file is run as a script.
LISTEN_PORT = int(os.environ.get("PORT", 7860))
# Base URL forwarded to the Poe client on every bot query.
BASE_URL = os.environ.get("BASE", "https://api.poe.com/bot/")
# Both secrets fall back to "" so that "unset" and "set but empty" are
# rejected by the same check below.
POE_API_KEY = os.environ.get("POE_API_KEY", "")
AUTHORIZATION_API_KEY = os.environ.get("AUTHORIZATION_API_KEY", "")

# Fail fast: refuse to import (and therefore to start) without both secrets.
if not (POE_API_KEY and AUTHORIZATION_API_KEY):
    raise ValueError("POE_API_KEY and AUTHORIZATION_API_KEY must be set in the environment variables")

from fastapi import FastAPI, Request, Header, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

import uvicorn
from typing import AsyncGenerator
import json

from fastapi_poe.types import ProtocolMessage
from fastapi_poe.client import get_bot_response


# Application instance. CORS is configured fully open: any origin, method and
# header is accepted, and credentials are allowed.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

async def generate_responses(api_key: str, formatted_messages: list, bot_name: str) -> AsyncGenerator[str, None]:
    """Stream a Poe bot reply as OpenAI-style chat-completion SSE chunks.

    Args:
        api_key: Key forwarded to ``get_bot_response``.
        formatted_messages: Conversation messages forwarded to
            ``get_bot_response``.
        bot_name: Bot to query; also echoed back as the ``model`` field of
            every chunk.

    Yields:
        One ``data: <json>`` server-sent event per partial response, each
        terminated by a blank line. The last yielded string carries a closing
        chunk (empty ``delta``, ``finish_reason`` of ``"stop"``) followed by
        the ``data: [DONE]`` terminator.
    """
    import time
    import uuid

    # `logprobs` and `finish_reason` belong to the choice object, next to
    # `delta`, not inside it. The closing chunk below already writes
    # `finish_reason` at this level, so the streamed chunks now agree with it.
    choice = {
        "index": 0,
        "delta": {"content": ""},
        "logprobs": None,
        "finish_reason": None,
    }
    # One id and creation time per completion, shared by all of its chunks,
    # instead of the same hard-coded values for every response ever served.
    chunk = {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": bot_name,
        "choices": [choice],
    }

    # NOTE(review): logit_bias presumably suppresses token id 24383; confirm
    # the intent and that the Poe client honours it.
    async for partial in get_bot_response(messages=formatted_messages, bot_name=bot_name, api_key=api_key,
                                          base_url=BASE_URL,
                                          skip_system_prompt=False,
                                          logit_bias={'24383': -100}):
        # Only the delta changes between chunks; the envelope is reused.
        choice["delta"] = {"content": partial.text}
        yield f"data: {json.dumps(chunk)}\n\n"

    # Closing chunk: empty delta plus the stop reason, then the end marker.
    choice["delta"] = {}
    choice["finish_reason"] = "stop"
    yield f"data: {json.dumps(chunk)}\n\ndata: [DONE]\n\n"


@app.post("/hf/v1/chat/completions")
async def chat_completions(request: Request, authorization: str = Header(None)):
    """Authenticate the caller and stream a bot reply as server-sent events.

    The request body is read as JSON. ``model`` selects the bot (falling back
    to ``DEFAULT_MODEL``) and ``messages`` is the conversation, each entry
    providing ``role`` and ``content``.

    Args:
        request: Incoming request whose JSON body holds the chat parameters.
        authorization: Value of the ``Authorization`` header, expected in the
            form ``Bearer $API_KEY``.

    Returns:
        A ``StreamingResponse`` that relays the chunks produced by
        ``generate_responses``.

    Raises:
        HTTPException: 401 when the header is missing, carries no token, or
            the token does not match ``AUTHORIZATION_API_KEY``; 400 when the
            body is not a JSON object or a message entry is malformed.
    """
    import secrets

    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization header is missing")

    # partition() never raises, so a header without a space yields an empty
    # token (rejected below) instead of an IndexError.
    _, _, api_key = authorization.partition(" ")
    api_key = api_key.strip()
    # Constant-time comparison; operands are encoded because compare_digest
    # rejects non-ASCII str arguments.
    key_matches = bool(api_key) and secrets.compare_digest(
        api_key.encode(), AUTHORIZATION_API_KEY.encode()
    )
    if not key_matches:
        # Must be raised: a *returned* HTTPException is treated as an ordinary
        # return value and never produces a 401 response.
        raise HTTPException(status_code=401, detail="Invalid API Key")

    try:
        body = await request.json()
    except ValueError as err:  # json.JSONDecodeError is a ValueError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from err
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    # Extract bot_name (model) and messages from the request body
    bot_name = body.get("model", DEFAULT_MODEL)
    messages = body.get("messages", [])

    # OpenAI calls the model's turns "assistant"; the Poe protocol uses "bot".
    # NOTE(review): temperature is read per message here, whereas OpenAI-style
    # clients send it at the top level of the body; confirm what
    # ProtocolMessage does with it.
    try:
        formatted_messages = [
            ProtocolMessage(
                role=msg["role"].lower().replace("assistant", "bot"),
                content=msg["content"],
                temperature=msg.get("temperature", 0.95),
            )
            for msg in messages
        ]
    except (KeyError, TypeError, AttributeError, ValueError) as err:
        # Missing keys, wrong types, or (presumably) model validation errors
        # are the client's fault, so answer 400 rather than 500.
        raise HTTPException(status_code=400, detail="Malformed 'messages' entry") from err

    # The body is a stream of "data: ..." events, so advertise it as SSE.
    return StreamingResponse(
        generate_responses(POE_API_KEY, formatted_messages, bot_name),
        media_type="text/event-stream",
    )

if __name__ == '__main__':
    # uvloop is optional: install it when it can be imported, otherwise start
    # the server without it.
    try:
        import uvloop
    except ImportError:
        pass
    else:
        uvloop.install()

    uvicorn.run(app, port=LISTEN_PORT, host="0.0.0.0")