from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from supabase import create_client
import os, uvicorn
from contextlib import asynccontextmanager

# =========================
# CONFIG
# =========================
HF_TOKEN = os.getenv("HF_TOKEN")
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")

supabase = create_client(SUPABASE_URL, SUPABASE_KEY)

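# Loaded once at startup by the lifespan handler below; every request reuses this instance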
model = None

# =========================
# REQUEST
# =========================
class ChatRequest(BaseModel):
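    """Incoming chat payload; request_id links the streamed reply to the row later polled via /v1/get_response."""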
    message: str
    request_id: str
    temperature: float = 0.7

# =========================
# CLEAN OUTPUT
# =========================
def clean_output(text):
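    # Truncate at the first stop token or chat-role marker the model may echo back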
    stop_words = [
        "<|eot_id|>",
        "<|end_of_text|>",
        "<|eof|>",
        "Human:",
        "Assistant:",
        "User:"
    ]
    for w in stop_words:
        if w in text:
            text = text.split(w)[0]
    return text.strip()

# =========================
# PROMPT
# =========================
def build_prompt(user_msg):
    return f"""<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
Your name is Llama and you are a cheerful friendly AI buddy made for voice conversation.
Rules:
- Always refer to yourself as Llama
- Speak naturally like a real voice conversation with a friend
- Use casual spoken language like hey sure yep got it
- Answer in 1 to 2 sentences only
- Keep answer under 30 words
- Do not use symbols
- Do not use abbreviations
- Use digits instead of words
- No new lines
- Output plain text only
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
{user_msg}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""

# =========================
# MODEL LOAD
# =========================
def load_model():
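    # Download the quantized GGUF from the Hugging Face Hub (cached under /data)
    # and load it with a 2048-token context window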
    return Llama(
        model_path=hf_hub_download(
            repo_id="Valtry/llama3.2-3b-q4-gguf",
            filename="llama3.2-3b-q4.gguf",
            token=HF_TOKEN,
            cache_dir="/data"
        ),
        n_ctx=2048,
        n_threads=4,
        n_batch=512,
        use_mmap=True,
        use_mlock=True,
        f16_kv=True,
        verbose=False
    )

@asynccontextmanager
async def lifespan(app: FastAPI):
    global model
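    # Load the model once at startup so no request pays the download/initialisation cost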
    model = load_model()
    yield

# =========================
# APP
# =========================
app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# =========================
# SAVE
# =========================
def save_message(role, content, request_id):
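    # Persist one chat turn in Supabase, keyed by request_id so clients can fetch it later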
    supabase.table("messages").insert({
        "role": role,
        "content": content,
        "request_id": request_id
    }).execute()

# =========================
# SUPABASE UPDATE HELPER
# =========================
def update_message(msg_id, content, status=None):
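    # Patch an existing message row by id; the status column is only written when a value is supplied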
    data = {"content": content}
    if status:
        data["status"] = status
    try:
        supabase.table("messages").update(data).eq("id", msg_id).execute()
    except Exception as e:
        print(f"Supabase update failed: {e}")


# =========================
# CHAT
# =========================

@app.post("/v1/chat")
async def chat(req: ChatRequest):

    def generate():

        prompt = build_prompt(req.message)

        full_text = ""

        stream = model(
            prompt,
            max_tokens=2048,
            temperature=req.temperature,
            top_p=0.9,
            repeat_penalty=1.15,
            stop=["<|eot_id|>", "<|end_of_text|>", "<|eof|>"],
            stream=True
        )

        # 🔥 STREAM DIRECTLY TO ESP
        for chunk in stream:
            token = chunk["choices"][0]["text"]

            full_text += token

            yield token.replace("\n", " ").replace("\r", "")   # ⚡ direct streaming

        # 🔥 SAVE AFTER COMPLETION
        final = clean_output(full_text)

        save_message("user", req.message, req.request_id)
        save_message("assistant", final, req.request_id)

    return StreamingResponse(generate(), media_type="text/plain")

# =========================
# GET RESPONSE
# =========================
@app.get("/v1/get_response/{request_id}")
def get_response(request_id: str):
    try:
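        # Fetch the most recent assistant message stored for this request_id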
        res = supabase.table("messages") \
            .select("content, status") \
            .eq("role", "assistant") \
            .eq("request_id", request_id) \
            .order("created_at", desc=True) \
            .limit(1) \
            .execute()

        data = res.data

        if data:
            return {
                "response": data[0]["content"],
                "status": data[0]["status"]
            }
        else:
            return {"response": None, "status": "waiting"}

    except Exception as e:
        return {"error": str(e)}

# =========================
# ROOT
# =========================
@app.get("/")
def root():
    return {"status": "LLaMA API running"}

# =========================
# RUN
# =========================
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
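
# Example usage (a sketch, assuming the server runs locally on port 7860 as above;
# "abc123" is a placeholder request id):
#
#   curl -N -X POST http://localhost:7860/v1/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "hello", "request_id": "abc123"}'
#
#   curl http://localhost:7860/v1/get_response/abc123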