Valtry committed on
Commit
3bee657
·
verified ·
1 Parent(s): d1aba81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -93
app.py CHANGED
@@ -1,11 +1,10 @@
1
  from fastapi import FastAPI
2
- from fastapi.responses import StreamingResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
  from llama_cpp import Llama
6
  from huggingface_hub import hf_hub_download
7
  from supabase import create_client
8
- import os, json, uvicorn, threading
9
  from contextlib import asynccontextmanager
10
 
11
  # =========================
@@ -18,15 +17,14 @@ SUPABASE_KEY = os.getenv("SUPABASE_KEY")
18
  supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
19
 
20
  model = None
21
- stop_flag = False
22
 
23
  # =========================
24
  # REQUEST
25
  # =========================
26
  class ChatRequest(BaseModel):
27
  message: str
 
28
  temperature: float = 0.7
29
- stream: bool = False
30
 
31
  # =========================
32
  # CLEAN OUTPUT
@@ -46,7 +44,7 @@ def clean_output(text):
46
  return text.strip()
47
 
48
  # =========================
49
- # PROMPT (NO HISTORY)
50
  # =========================
51
  def build_prompt(user_msg):
52
  return f"""<|begin_of_text|>
@@ -100,22 +98,13 @@ app.add_middleware(
100
  # =========================
101
  # SAVE
102
  # =========================
103
- def save_message(role, content):
104
  supabase.table("messages").insert({
105
  "role": role,
106
- "content": content
107
- # timestamp auto handled by DB
108
  }).execute()
109
 
110
- # =========================
111
- # STOP
112
- # =========================
113
- @app.post("/v1/stop")
114
- def stop():
115
- global stop_flag
116
- stop_flag = True
117
- return {"status": "stopped"}
118
-
119
  # =========================
120
  # CHAT
121
  # =========================
@@ -124,105 +113,53 @@ async def chat(req: ChatRequest):
124
 
125
  prompt = build_prompt(req.message)
126
 
127
- temp, rp, tp = req.temperature, 1.15, 0.9
128
- max_tokens = 2048
129
-
130
- if req.stream:
131
-
132
- def generate():
133
- global stop_flag
134
- output = ""
135
-
136
- stream = model(
137
- prompt,
138
- max_tokens=max_tokens,
139
- temperature=temp,
140
- top_p=tp,
141
- repeat_penalty=rp,
142
- stop=["<|eot_id|>", "<|end_of_text|>", "<|eof|>"],
143
- stream=True
144
- )
145
-
146
- for chunk in stream:
147
-
148
- if stop_flag:
149
- stop_flag = False
150
- break
151
-
152
- token = chunk["choices"][0]["text"]
153
- output += token
154
-
155
- yield f"data: {json.dumps({'choices':[{'delta':{'content':token}}]})}\n\n"
156
-
157
- output_clean = clean_output(output)
158
-
159
- yield "event: done\ndata: {}\n\n"
160
- yield "data: [DONE]\n\n"
161
-
162
- def save_async():
163
- save_message("user", req.message)
164
- save_message("assistant", output_clean)
165
-
166
- threading.Thread(target=save_async).start()
167
-
168
- return StreamingResponse(generate(), media_type="text/event-stream")
169
-
170
  output = model(
171
  prompt,
172
- max_tokens=max_tokens,
173
- temperature=temp,
174
- top_p=tp,
175
- repeat_penalty=rp,
176
  stop=["<|eot_id|>", "<|end_of_text|>", "<|eof|>"]
177
  )
178
 
179
  text = clean_output(output["choices"][0]["text"])
180
 
181
- def save_async():
182
- save_message("user", req.message)
183
- save_message("assistant", text)
184
 
185
- threading.Thread(target=save_async).start()
186
 
187
- return {
188
- "choices":[{"message":{"role":"assistant","content":text}}],
189
- "done":True
190
- }
191
-
192
- @app.get("/v1/latest")
193
- def get_latest():
194
  try:
195
  res = supabase.table("messages") \
196
- .select("role, content") \
 
 
197
  .order("created_at", desc=True) \
198
- .limit(2) \
199
  .execute()
200
 
201
- data = res.data or []
202
-
203
- user_msg = ""
204
- assistant_msg = ""
205
-
206
- for item in reversed(data):
207
- if item["role"] == "user":
208
- user_msg = item["content"]
209
- elif item["role"] == "assistant":
210
- assistant_msg = item["content"]
211
 
212
- return {
213
- "user": user_msg,
214
- "assistant": assistant_msg
215
- }
216
 
217
  except Exception as e:
218
- return {"error": str(e)}
219
 
220
  # =========================
221
  # ROOT
222
  # =========================
223
  @app.get("/")
224
  def root():
225
- return {"status": "Minimal LLaMA API running 🚀"}
226
 
227
  # =========================
228
  # RUN
 
1
  from fastapi import FastAPI
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from llama_cpp import Llama
5
  from huggingface_hub import hf_hub_download
6
  from supabase import create_client
7
+ import os, uvicorn
8
  from contextlib import asynccontextmanager
9
 
10
  # =========================
 
17
  supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
18
 
19
  model = None
 
20
 
21
  # =========================
22
  # REQUEST
23
  # =========================
24
  class ChatRequest(BaseModel):
25
  message: str
26
+ request_id: str
27
  temperature: float = 0.7
 
28
 
29
  # =========================
30
  # CLEAN OUTPUT
 
44
  return text.strip()
45
 
46
  # =========================
47
+ # PROMPT
48
  # =========================
49
  def build_prompt(user_msg):
50
  return f"""<|begin_of_text|>
 
98
  # =========================
99
  # SAVE
100
  # =========================
101
+ def save_message(role, content, request_id):
102
  supabase.table("messages").insert({
103
  "role": role,
104
+ "content": content,
105
+ "request_id": request_id
106
  }).execute()
107
 
 
 
 
 
 
 
 
 
 
108
  # =========================
109
  # CHAT
110
  # =========================
 
113
 
114
  prompt = build_prompt(req.message)
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  output = model(
117
  prompt,
118
+ max_tokens=2048,
119
+ temperature=req.temperature,
120
+ top_p=0.9,
121
+ repeat_penalty=1.15,
122
  stop=["<|eot_id|>", "<|end_of_text|>", "<|eof|>"]
123
  )
124
 
125
  text = clean_output(output["choices"][0]["text"])
126
 
127
+ # ✅ SAVE BOTH
128
+ save_message("user", req.message, req.request_id)
129
+ save_message("assistant", text, req.request_id)
130
 
131
+ return {"status": "saved"}
132
 
133
+ # =========================
134
+ # GET RESPONSE
135
+ # =========================
136
+ @app.get("/v1/get_response/{request_id}")
137
+ def get_response(request_id: str):
 
 
138
  try:
139
  res = supabase.table("messages") \
140
+ .select("content") \
141
+ .eq("role", "assistant") \
142
+ .eq("request_id", request_id) \
143
  .order("created_at", desc=True) \
144
+ .limit(1) \
145
  .execute()
146
 
147
+ data = res.data
 
 
 
 
 
 
 
 
 
148
 
149
+ if data:
150
+ return {"response": data[0]["content"]}
151
+ else:
152
+ return {"response": None}
153
 
154
  except Exception as e:
155
+ return {"error": str(e)}
156
 
157
  # =========================
158
  # ROOT
159
  # =========================
160
  @app.get("/")
161
  def root():
162
+ return {"status": "LLaMA API running 🚀"}
163
 
164
  # =========================
165
  # RUN