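# Minimal OpenAI-compatible chat completions server: loads Qwen2-0.5B-Instruct
# on CPU with Hugging Face transformers and serves /v1/chat/completions via
# FastAPI, with optional Server-Sent Events streaming.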
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from sse_starlette.sse import EventSourceResponse
import torch
from threading import Thread
import json
import uvicorn

app = FastAPI()

model_name = "Qwen/Qwen2-0.5B-Instruct"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # "torch_dtype" is the widely supported name; recent transformers also accept "dtype"
    low_cpu_mem_usage=True
).to("cpu")

print("Model loaded successfully!")


class ChatRequest(BaseModel):
    model: str = "auric-ai"
    messages: list
    stream: bool = False
    max_tokens: int = 512
    temperature: float = 0.1


@app.post("/v1/chat/completions")
async def chat(req: ChatRequest):

    # Use Qwen2's chat template for proper system/user/assistant formatting
    prompt = tokenizer.apply_chat_template(
        req.messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    prompt_length = inputs["input_ids"].shape[-1]

    # Clamp parameters: temperature must be > 0 when do_sample=True, and
    # max_tokens is capped to keep CPU generation time bounded.
    temperature = max(req.temperature, 0.01)
    max_tokens = min(req.max_tokens, 2048)

    # ---------------- STREAM MODE ----------------

    if req.stream:

        # TextIteratorStreamer yields decoded text chunks while generation runs
        # in a background thread, so tokens can be streamed as they are produced.
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )

        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True
        )

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        async def event_generator():
            # Wrap each decoded chunk in an OpenAI-style "delta" payload and
            # emit it as a server-sent event.
            for token in streamer:
                data = {
                    "choices": [
                        {
                            "delta": {
                                "content": token
                            }
                        }
                    ]
                }

                yield {
                    "event": "message",
                    "data": json.dumps(data)
                }

            # Signal end of stream with OpenAI's "[DONE]" marker.
            yield {
                "event": "message",
                "data": "[DONE]"
            }

        return EventSourceResponse(event_generator())

    # ---------------- NORMAL MODE ----------------

    else:

        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True
        )

        # Only decode the newly generated tokens, not the prompt
        response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

        return {
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": response.strip()
                    }
                }
            ]
        }

if __name__ == "__main__":
    # Each uvicorn worker is a separate process that loads its own copy of the
    # model, so workers=3 keeps three copies of Qwen2-0.5B in memory.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, workers=3)
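
# ---------------------------------------------------------------------------
# Example client (a sketch, not part of the server). It assumes the server is
# running locally on the port configured above (7860) and that the `requests`
# package is installed; adjust the URL if you deploy elsewhere.
#
#   import json
#   import requests
#
#   URL = "http://localhost:7860/v1/chat/completions"
#   messages = [{"role": "user", "content": "Say hello in one sentence."}]
#
#   # Non-streaming request: the full reply arrives in one JSON response.
#   resp = requests.post(URL, json={"messages": messages, "stream": False})
#   print(resp.json()["choices"][0]["message"]["content"])
#
#   # Streaming request: read SSE lines and parse each "data:" payload
#   # until the "[DONE]" marker emitted by the server above.
#   with requests.post(URL, json={"messages": messages, "stream": True},
#                      stream=True) as r:
#       for line in r.iter_lines():
#           if not line.startswith(b"data: "):
#               continue
#           payload = line[len(b"data: "):]
#           if payload == b"[DONE]":
#               break
#           print(json.loads(payload)["choices"][0]["delta"]["content"], end="")
# ---------------------------------------------------------------------------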