Valtry committed on
Commit
f8fbbce
·
verified ·
1 Parent(s): da301ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -73
app.py CHANGED
@@ -1,89 +1,203 @@
1
  from fastapi import FastAPI
 
 
2
  from pydantic import BaseModel
3
- import torch
4
- from transformers import AutoModelForCausalLM, AutoTokenizer
5
- import uvicorn
6
-
7
- # -----------------------
8
- # LOAD MODEL
9
- # -----------------------
10
- MODEL_ID = "microsoft/phi-2"
11
-
12
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
13
-
14
- model = AutoModelForCausalLM.from_pretrained(
15
- MODEL_ID,
16
- device_map="cpu",
17
- torch_dtype=torch.float32,
18
- low_cpu_mem_usage=True
19
- )
20
-
21
- torch.set_num_threads(2)
22
-
23
- # -----------------------
24
- # FASTAPI
25
- # -----------------------
26
- app = FastAPI()
27
-
28
  class ChatRequest(BaseModel):
29
  message: str
30
-
31
-
32
- @app.get("/")
33
- def home():
34
- return {"status": "API running 🚀"}
35
-
36
-
37
- # -----------------------
38
- # CHAT (NO STREAMING)
39
- # -----------------------
40
- @app.post("/chat")
41
- def chat(req: ChatRequest):
42
-
43
- prompt = f"""You are a concise assistant.
44
- Return plain text only.
45
- No markdown.
46
- No bullet points.
47
- No numbering.
48
- No symbols like # * -.
49
- Only simple readable sentence.
50
-
51
- User: {req.message}
52
- Assistant:"""
53
-
54
- inputs = tokenizer(prompt, return_tensors="pt")
55
-
56
- outputs = model.generate(
57
- **inputs,
58
- max_new_tokens=80,
59
- temperature=0.5,
60
- do_sample=True,
61
- eos_token_id=tokenizer.eos_token_id,
62
- pad_token_id=tokenizer.eos_token_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  )
64
 
65
- text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- # 🔥 CLEAN OUTPUT
68
- if "Assistant:" in text:
69
- text = text.split("Assistant:")[-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- if "User:" in text:
72
- text = text.split("User:")[0]
73
 
74
- text = text.strip()
 
 
75
 
76
- # remove unwanted formatting
77
- text = text.replace("\n", " ")
78
- text = text.replace(" ", " ")
79
 
80
  return {
81
- "response": text
 
82
  }
83
 
 
 
 
 
 
 
84
 
85
- # -----------------------
86
- # START SERVER
87
- # -----------------------
88
  if __name__ == "__main__":
89
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI
2
+ from fastapi.responses import StreamingResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
+ from llama_cpp import Llama
6
+ from huggingface_hub import hf_hub_download
7
+ from supabase import create_client
8
+ import os, json, uvicorn, threading
9
+ from contextlib import asynccontextmanager
10
+
11
# =========================
# CONFIG
# =========================
# Settings come from the process environment; each value is None when the
# corresponding variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
SUPABASE_URL = os.environ.get("SUPABASE_URL")
SUPABASE_KEY = os.environ.get("SUPABASE_KEY")

# One Supabase client for the whole process, built at import time.
supabase = create_client(SUPABASE_URL, SUPABASE_KEY)

# Model handle; stays None until the lifespan hook assigns the loaded model.
model = None
# Set by the stop endpoint and polled by the streaming loop between chunks.
stop_flag = False
22
+
23
# =========================
# REQUEST
# =========================
class ChatRequest(BaseModel):
    """JSON body accepted by the chat endpoint."""

    # Text of the user's message (required).
    message: str
    # Sampling temperature handed to the model call.
    temperature: float = 0.7
    # True -> reply as an event stream; False -> reply as a single JSON object.
    stream: bool = False
30
+
31
# =========================
# CLEAN OUTPUT
# =========================
def clean_output(text):
    """Return *text* cut off before any stop marker, with outer whitespace removed.

    The markers are the model's end-of-turn tokens plus the plain-text role
    labels ("Human:", "Assistant:", "User:"). Everything from the first
    occurrence of a marker onwards is discarded.
    """
    markers = (
        "<|eot_id|>",
        "<|end_of_text|>",
        "<|eof|>",
        "Human:",
        "Assistant:",
        "User:",
    )
    trimmed = text
    for marker in markers:
        # partition() leaves the string untouched when the marker is absent.
        trimmed, _, _ = trimmed.partition(marker)
    return trimmed.strip()
47
+
48
# =========================
# PROMPT (NO HISTORY)
# =========================
def build_prompt(user_msg, system_msg="You are a helpful AI assistant."):
    """Build a single-turn Llama 3 chat prompt.

    The layout follows the published Llama 3 chat format: each header is
    followed by a blank line, the message content comes next, and
    ``<|eot_id|>`` closes the turn immediately after the content. The prompt
    ends with an open assistant header so the model writes the reply.

    Args:
        user_msg: The user's message, inserted verbatim.
        system_msg: The system instruction; defaults to the generic
            assistant instruction this service has always used.

    Returns:
        The complete prompt string. No conversation history is included.
    """
    return (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_msg}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_msg}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
61
+
62
# =========================
# MODEL LOAD
# =========================
def load_model():
    """Fetch the GGUF weights through the Hugging Face hub and wrap them in a Llama.

    The file is requested with ``cache_dir="/data"`` and authenticated with
    ``HF_TOKEN``.

    Returns:
        A ``Llama`` instance configured with the fixed runtime options below.
    """
    weights_path = hf_hub_download(
        repo_id="Valtry/llama3.2-3b-q4-gguf",
        filename="llama3.2-3b-q4.gguf",
        token=HF_TOKEN,
        cache_dir="/data",
    )

    # Fixed runtime configuration for the model.
    runtime_options = {
        "n_ctx": 2048,
        "n_threads": 4,
        "n_batch": 512,
        "use_mmap": True,
        "use_mlock": True,
        "f16_kv": True,
        "verbose": False,
    }
    return Llama(model_path=weights_path, **runtime_options)
81
 
82
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: load the model once before yielding.

    The loaded model is stored in the module-level ``model`` variable.
    Nothing follows the ``yield``, so this hook does no work after it.
    """
    global model
    loaded = load_model()
    model = loaded
    yield
87
+
88
# =========================
# APP
# =========================
# The lifespan hook is what populates the model before requests are served.
app = FastAPI(lifespan=lifespan)

# CORS is left wide open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
99
 
100
# =========================
# SAVE
# =========================
def save_message(role, content):
    """Insert one chat message into the Supabase ``messages`` table.

    Args:
        role: Who produced the message (callers pass "user" or "assistant").
        content: The message text.
    """
    # No timestamp is sent; per the original note the database supplies it.
    row = {"role": role, "content": content}
    messages_table = supabase.table("messages")
    messages_table.insert(row).execute()
109
+
110
# =========================
# STOP
# =========================
@app.post("/v1/stop")
def stop():
    """Raise the module-level stop flag that the streaming loop checks per chunk.

    Returns:
        A small JSON acknowledgement.
    """
    global stop_flag
    stop_flag = True
    acknowledgement = {"status": "stopped"}
    return acknowledgement
118
+
119
# =========================
# CHAT
# =========================
@app.post("/v1/chat")
async def chat(req: ChatRequest):
    """Answer a single user message, either streamed or as one JSON object.

    Args:
        req: The request body. ``message`` is the user's text,
            ``temperature`` the sampling temperature, and ``stream`` selects
            the response mode.

    Returns:
        When ``req.stream`` is true, a ``text/event-stream`` response whose
        ``data:`` lines carry ``{"choices": [{"delta": {"content": ...}}]}``
        chunks, followed by an ``event: done`` record and ``data: [DONE]``.
        Otherwise a dict with the cleaned reply under
        ``choices[0].message.content`` and ``"done": True``.

    In both modes the user message and the cleaned reply are saved from a
    background thread after generation.
    """
    prompt = build_prompt(req.message)

    # Sampling settings shared by both response modes.
    generation_kwargs = {
        "max_tokens": 2048,
        "temperature": req.temperature,
        "top_p": 0.9,
        "repeat_penalty": 1.15,
        "stop": ["<|eot_id|>", "<|end_of_text|>", "<|eof|>"],
    }

    def persist(reply):
        """Save the exchange on a separate thread so the response is not delayed."""

        def worker():
            save_message("user", req.message)
            save_message("assistant", reply)

        threading.Thread(target=worker).start()

    if req.stream:

        def generate():
            global stop_flag
            # Drop any stop request left over from before this stream began;
            # otherwise a /v1/stop sent while idle would abort this reply on
            # its very first chunk.
            stop_flag = False
            pieces = []

            for chunk in model(prompt, stream=True, **generation_kwargs):
                if stop_flag:
                    # Consume the request so it does not affect later streams.
                    stop_flag = False
                    break

                token = chunk["choices"][0]["text"]
                pieces.append(token)

                yield f"data: {json.dumps({'choices':[{'delta':{'content':token}}]})}\n\n"

            output_clean = clean_output("".join(pieces))

            yield "event: done\ndata: {}\n\n"
            yield "data: [DONE]\n\n"

            persist(output_clean)

        return StreamingResponse(generate(), media_type="text/event-stream")

    # NOTE(review): this is a synchronous call inside an async handler, so the
    # event loop is blocked until generation finishes.
    output = model(prompt, **generation_kwargs)

    text = clean_output(output["choices"][0]["text"])

    persist(text)

    return {
        "choices": [{"message": {"role": "assistant", "content": text}}],
        "done": True,
    }
191
 
192
# =========================
# ROOT
# =========================
@app.get("/")
def root():
    """Return a fixed status message showing the API is reachable."""
    status_payload = {"status": "Minimal LLaMA API running 🚀"}
    return status_payload
198
 
199
# =========================
# RUN
# =========================
if __name__ == "__main__":
    # Executed as a script: serve the app on every interface, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run("app:app", **server_options)