File size: 3,405 Bytes
d4b63e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec4c32b
d4b63e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python3
"""OpenAI-compatible Chat Completions API server for the Fourth GPT model."""

import time
import uuid
import json
import os
from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
from model import FourthModel

# ASGI application; the title and version are published in the OpenAPI schema.
app = FastAPI(version="1.0.0", title="Fourth GPT API")

# Let browser clients on any origin call every route with any HTTP method and
# any request header. allow_credentials is left at CORSMiddleware's default.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# One shared model object used by the request handlers. Its load() method is
# invoked from the startup hook rather than here at import time.
fourth = FourthModel()


# A single chat turn. Used both for incoming request messages and for the
# assistant reply placed in the response. Field order is the JSON key order.
class Message(BaseModel):
    role: str  # e.g. "user" / "assistant"; the chat handler only selects role == "user"
    content: str  # typed as a plain str (no structured / multi-part content)


# Request body for POST /v1/chat/completions (OpenAI Chat Completions shape).
class ChatCompletionRequest(BaseModel):
    model: str = "fourth-gpt"  # echoed back in the response; not used to select a model
    messages: list[Message]  # the handler rejects an empty list with HTTP 400
    max_tokens: int = 128  # forwarded to fourth.generate(max_tokens=...)
    temperature: float = 0.5  # forwarded to fourth.generate(temperature=...)
    stream: bool = False  # never read by the handler; presumably accepted for client compatibility


# One entry of ChatCompletionResponse.choices.
class Choice(BaseModel):
    index: int = 0  # the handler returns exactly one choice, so this stays 0
    message: Message
    finish_reason: str = "stop"  # the handler never overrides this, so every reply reports "stop"


# Usage accounting attached to a chat completion response.
class Usage(BaseModel):
    prompt_tokens: int  # the handler fills this with len() of the prompt string (characters)
    completion_tokens: int  # likewise len() of the generated text
    total_tokens: int  # prompt_tokens + completion_tokens


# Response body for chat completions, mirroring OpenAI's "chat.completion" object.
class ChatCompletionResponse(BaseModel):
    id: str  # the handler builds "chatcmpl-" + 12 hex chars of a uuid4
    object: str = "chat.completion"
    created: int  # Unix time in whole seconds (int(time.time()) in the handler)
    model: str  # echoes the request's `model` field
    choices: list[Choice]
    usage: Usage


# Model descriptor returned by the /models endpoints (OpenAI "model" object).
class ModelInfo(BaseModel):
    id: str
    object: str = "model"
    created: int = 0  # the endpoints in this file pass the current time instead of this default
    owned_by: str = "fourth"


# Envelope for the model listing endpoint (OpenAI "list" object).
class ModelList(BaseModel):
    object: str = "list"
    data: list[ModelInfo]


@app.on_event("startup")
def startup() -> None:
    """Call ``fourth.load()`` once when the application starts up.

    NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    favour of a ``lifespan`` handler; migrating requires changing the app setup.
    """
    fourth.load()


@app.get("/")
def root():
    # Banner endpoint: confirms the server is up and names the model it serves.
    # "params" is a hard-coded figure (presumably the parameter count); it is
    # not derived from the loaded model.
    banner = {
        "message": "Fourth GPT API is running",
        "model": "fourth-gpt",
        "params": 344256,
    }
    return banner


@app.get("/v1/models")
@app.get("/models")
def list_models():
    # Catalogue containing the single served model. Its `created` field is the
    # time of this request, not a fixed creation date.
    only_model = ModelInfo(id="fourth-gpt", created=int(time.time()))
    return ModelList(data=[only_model])


@app.get("/v1/models/{model_id}")
@app.get("/models/{model_id}")
def get_model(model_id: str):
    # Look up one model by id. "fourth-gpt" is the only known id; anything else
    # is a 404. `created` is stamped with the current time, as in list_models.
    if model_id == "fourth-gpt":
        return ModelInfo(id="fourth-gpt", created=int(time.time()))
    raise HTTPException(status_code=404, detail="Model not found")


@app.post("/v1/chat/completions")
@app.post("/chat/completions")
def chat_completions(
    req: ChatCompletionRequest,
    authorization: Optional[str] = Header(None),
):
    # Serve an OpenAI-style chat completion from the single local model.
    #
    # Only the most recent message with role "user" is sent to the model as the
    # prompt; system/assistant turns and earlier user turns are not forwarded.
    # `authorization` is accepted (presumably so clients that send an API key
    # are not rejected) but is never checked. `req.stream` is not read: the
    # reply is always a single JSON body.
    #
    # Raises HTTPException(400) when `messages` is empty, when no user message
    # is present, when `max_tokens` < 1, or when `temperature` is negative/NaN.
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages must not be empty")

    # Newest-first scan; None means no turn had role "user".
    user_msg = next(
        (msg.content for msg in reversed(req.messages) if msg.role == "user"),
        None,
    )
    if user_msg is None:
        raise HTTPException(status_code=400, detail="No user message found")

    # Reject out-of-range sampling settings with a client error instead of
    # passing them to the model, where their effect is not defined by this API.
    # These checks run after the message checks so requests that were already
    # rejected keep the same error.
    if req.max_tokens < 1:
        raise HTTPException(status_code=400, detail="max_tokens must be at least 1")
    # Written as `not (x >= 0)` so that NaN is rejected as well.
    if not (req.temperature >= 0):
        raise HTTPException(
            status_code=400, detail="temperature must be a non-negative number"
        )

    response_text = fourth.generate(
        prompt=user_msg,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
    )

    # NOTE(review): usage is reported in characters (len of the strings), not
    # model tokens. This is presumably fine if the model is character-level; confirm.
    prompt_tokens = len(user_msg)
    completion_tokens = len(response_text)

    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
        created=int(time.time()),
        model=req.model,
        choices=[
            Choice(
                message=Message(role="assistant", content=response_text),
            )
        ],
        usage=Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )


@app.get("/health")
def health():
    # Health probe. "status" is always "ok"; "model_loaded" reports whether
    # fourth.model is set (presumably populated by fourth.load() at startup).
    loaded = fourth.model is not None
    return {"status": "ok", "model_loaded": loaded}


if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces. The port comes from $PORT, falling back to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)