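"""
Minimal OpenAI-style chat-completions server for Equall/Saul-7B-Instruct-v1,
served on CPU through a transformers text-generation pipeline. Incoming chat
histories are normalized in build_prompt() before being rendered with the
model's chat template.
"""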
import time

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from typing import List, Optional
from transformers import AutoTokenizer, pipeline

MODEL_ID = "Equall/Saul-7B-Instruct-v1"

print("Loading model... this can take a while on first start.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
pipe = pipeline(
    "text-generation",
    model=MODEL_ID,
    tokenizer=tokenizer,
    device=-1,                      # CPU only
    max_new_tokens=512,
    pad_token_id=tokenizer.eos_token_id,
)

app = FastAPI()



class ChatMessage(BaseModel):
    role: str  # "system" | "user" | "assistant"
    content: str


class ChatRequest(BaseModel):
    model: Optional[str] = None  # ignored; accepted for OpenAI-style compatibility
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.0
    max_tokens: Optional[int] = 512


@app.get("/")
def root():
    return {"status": "ok", "model": MODEL_ID}


def build_prompt(raw_messages: List[dict]) -> str:
    """
    Normalize messages so they fit the template:
    - Collect system messages and prepend their text to the first user message.
    - Drop leading assistant messages.
    - Merge consecutive messages with the same role.
    - Ensure we end up with user/assistant/user/assistant/... only.
    """

    system_parts = []
    ua_messages = []

    # Separate system vs user/assistant
    for m in raw_messages:
        role = m.get("role")
        content = m.get("content", "")
        if role == "system":
            if content:
                system_parts.append(content)
        elif role in ("user", "assistant"):
            ua_messages.append({"role": role, "content": content})
        # ignore anything else

    # Drop leading assistants (template wants to start with user)
    while ua_messages and ua_messages[0]["role"] != "user":
        ua_messages.pop(0)

    # Merge consecutive messages with same role
    normalized: List[dict] = []
    for m in ua_messages:
        if not normalized:
            normalized.append(m)
        else:
            if normalized[-1]["role"] == m["role"]:
                normalized[-1]["content"] += "\n\n" + m["content"]
            else:
                normalized.append(m)

    if not normalized:
        raise ValueError("No user messages found after normalization.")

    # Prepend system text into the first user message, if any
    if system_parts:
        system_text = "\n\n".join(system_parts)
        if normalized[0]["role"] == "user":
            normalized[0]["content"] = system_text + "\n\n" + normalized[0]["content"]
        else:
            # If for some reason first is assistant, prepend a synthetic user
            normalized.insert(0, {"role": "user", "content": system_text})

    # At this point we should only have user/assistant alternating.
    # Let tokenizer.apply_chat_template enforce the exact format.
    prompt = tokenizer.apply_chat_template(
        normalized,
        tokenize=False,
        add_generation_prompt=True,
    )
    return prompt
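
# For illustration (hypothetical history): an input such as
#   [{"role": "system", "content": "Answer briefly."},
#    {"role": "user", "content": "Hi"},
#    {"role": "user", "content": "Still there?"}]
# normalizes to a single user turn,
#   [{"role": "user", "content": "Answer briefly.\n\nHi\n\nStill there?"}]
# which apply_chat_template then renders into the model-specific prompt string.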


@app.post("/debug-echo")
async def debug_echo(request: ChatRequest):
    body = await request.body()
    print("DEBUG ECHO BODY:", body)
    return {"ok": True}

@app.post("/v1/chat/completions")
def chat(request: ChatRequest):
    try:
        messages = [m.model_dump() for m in request.messages]  # .dict() on Pydantic v1
        prompt = build_prompt(messages)
    except Exception as e:
        # Don't crash the app – return a 400 with explanation
        raise HTTPException(status_code=400, detail=f"Invalid message history: {e}")

    temperature = request.temperature or 0.0
    gen_kwargs = {"max_new_tokens": request.max_tokens or 512}
    if temperature > 0:
        # temperature/top_p only take effect when sampling is enabled
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=1.0)
    else:
        # Greedy decoding; passing temperature=0.0 alongside do_sample=False
        # only triggers generation warnings in transformers.
        gen_kwargs["do_sample"] = False

    # return_full_text=False makes the pipeline return only the completion,
    # avoiding a fragile slice of the prompt off the decoded output.
    outputs = pipe(prompt, return_full_text=False, **gen_kwargs)
    reply = outputs[0]["generated_text"].strip()

    return {
        "id": "chatcmpl-1",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ID,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": reply,
                },
                "finish_reason": "stop",
            }
        ],
    }
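
# Example request against a local instance (hypothetical port; assumes the
# file is saved as main.py and started with `uvicorn main:app --port 8000`):
#
#   curl -X POST http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}]}'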