mrmadblack commited on
Commit
325785f
·
verified ·
1 Parent(s): 996a96e

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +291 -174
server.py CHANGED
@@ -1,4 +1,16 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
2
  from fastapi.responses import StreamingResponse, JSONResponse
3
  from pydantic import BaseModel
4
  from huggingface_hub import hf_hub_download
@@ -10,283 +22,388 @@ import json
10
  import time
11
  import hashlib
12
  import threading
 
13
 
14
  app = FastAPI()
15
 
16
- # -------------------------
17
- # MODEL CONFIG
18
- # -------------------------
 
19
 
20
  MODELS = {
21
- "qwen:0.8b": {
22
- "repo": "Qwen/Qwen3.5-0.8B-GGUF",
23
- "file": "qwen3.5-0.8b-q4_k_m.gguf",
24
- "path": "models/qwen_0_8b.gguf",
25
- "port": 8081
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  },
27
- "qwen:2b": {
28
- "repo": "Qwen/Qwen3.5-2B-GGUF",
29
- "file": "qwen3.5-2b-q4_k_m.gguf",
30
- "path": "models/qwen_2b.gguf",
31
- "port": 8082
32
- }
33
  }
34
 
35
- LLAMA_SERVER = "./llama.cpp/build/bin/llama-server"
 
36
 
37
- os.makedirs("models", exist_ok=True)
38
 
39
- # -------------------------
40
  # REQUEST MODELS
41
- # -------------------------
42
 
43
  class ChatRequest(BaseModel):
44
- model: str
45
  messages: list
46
- stream: bool = True
 
47
 
48
 
49
  class GenerateRequest(BaseModel):
50
- model: str
51
- prompt: str
 
 
52
 
53
 
54
- # -------------------------
55
- # PROMPT BUILDER (QWEN)
56
- # -------------------------
57
-
58
- def build_prompt(messages):
59
 
 
60
  prompt = ""
61
-
 
 
62
  for m in messages:
63
- role = m["role"]
64
- content = m["content"]
65
-
66
- if role == "user":
 
 
 
67
  prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
68
-
69
  elif role == "assistant":
70
  prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
71
-
72
  prompt += "<|im_start|>assistant\n"
73
-
74
  return prompt
75
 
76
 
77
- # -------------------------
78
- # DOWNLOAD MODELS
79
- # -------------------------
80
 
81
- def download_models():
 
 
 
 
 
 
 
 
82
 
83
- for name, m in MODELS.items():
84
 
85
- if os.path.exists(m["path"]):
86
- continue
 
87
 
88
- print("Downloading", name)
89
 
90
- f = hf_hub_download(
91
- repo_id=m["repo"],
92
- filename=m["file"]
93
- )
 
 
94
 
95
- os.system(f"cp {f} {m['path']}")
 
96
 
97
- download_models()
98
 
99
- # -------------------------
100
  # START LLAMA SERVERS
101
- # -------------------------
102
 
103
- def start_model(name, cfg):
104
 
105
- threads = "2"
106
 
107
- print("Starting", name)
 
108
 
109
- subprocess.Popen([
110
- LLAMA_SERVER,
111
 
112
- "-m", cfg["path"],
113
-
114
- "--host", "0.0.0.0",
115
- "--port", str(cfg["port"]),
116
-
117
- "--threads", threads,
118
- "--parallel", "2",
119
-
120
- "--ctx-size", "4096",
121
- "--batch-size", "1024",
122
- "--ubatch-size", "512",
123
-
124
- "-ngl", "0"
125
- ])
126
-
127
- for i in range(30):
128
  try:
129
- r = requests.get(f"http://localhost:{cfg['port']}/health")
130
-
131
  if r.status_code == 200:
132
- print(name, "ready")
133
- return
134
- except:
 
135
  pass
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  time.sleep(1)
 
 
 
 
138
 
139
- raise RuntimeError(name + " failed to start")
140
 
 
 
 
141
 
142
- def start_all_models():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- for name, cfg in MODELS.items():
145
- threading.Thread(
146
- target=start_model,
147
- args=(name, cfg),
148
- daemon=True
149
- ).start()
150
 
151
- start_all_models()
 
 
 
 
 
 
 
 
 
152
 
153
- # -------------------------
 
154
  # ROOT
155
- # -------------------------
156
 
157
  @app.get("/")
158
  def root():
159
- return {"status": "running"}
 
160
 
161
- # -------------------------
162
- # OLLAMA TAGS
163
- # -------------------------
164
 
165
  @app.get("/api/tags")
166
  def tags():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
- models = []
169
-
170
- for name, m in MODELS.items():
171
-
172
- size = os.path.getsize(m["path"])
173
 
174
- with open(m["path"], "rb") as f:
175
- digest = hashlib.sha256(f.read()).hexdigest()
 
176
 
177
- models.append({
178
- "name": name,
179
- "model": name,
180
- "modified_at": time.strftime("%Y-%m-%dT%H:%M:%SZ"),
181
- "size": size,
182
- "digest": digest,
183
- "details": {
184
- "format": "gguf",
185
- "family": "qwen",
186
- "parameter_size": name.split(":")[1]
187
- }
188
- })
189
 
190
- return {"models": models}
191
 
192
- # -------------------------
193
- # GENERATE
194
- # -------------------------
195
 
196
  @app.post("/api/generate")
197
  def generate(req: GenerateRequest):
 
 
 
 
198
 
199
- cfg = MODELS[req.model]
 
 
200
 
201
  r = requests.post(
202
  f"http://localhost:{cfg['port']}/completion",
203
- json={
204
- "prompt": req.prompt,
205
- "n_predict": 512
206
- }
207
  )
208
 
209
- data = r.json()
 
 
210
 
211
- return {
212
- "model": req.model,
213
- "response": data.get("content", ""),
214
- "done": True
215
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
- # -------------------------
218
- # CHAT
219
- # -------------------------
 
220
 
221
  @app.post("/api/chat")
222
  def chat(req: ChatRequest):
 
 
223
 
224
- cfg = MODELS[req.model]
225
 
226
  prompt = build_prompt(req.messages)
 
 
 
227
 
228
  r = requests.post(
229
  f"http://localhost:{cfg['port']}/completion",
230
- json={
231
- "prompt": prompt,
232
- "stream": req.stream,
233
- "n_predict": 1024,
234
- "temperature": 0.7
235
- },
236
- stream=req.stream
237
  )
238
 
239
  if not req.stream:
240
-
241
- data = r.json()
242
-
243
  return JSONResponse({
244
- "model": req.model,
245
- "message": {
246
- "role": "assistant",
247
- "content": data.get("content", "")
248
- },
249
- "done": True
250
  })
251
 
252
- def stream():
253
-
254
  for line in r.iter_lines():
255
-
256
  if not line:
257
  continue
258
-
259
- line = line.decode().replace("data:", "").strip()
260
-
261
  try:
262
  data = json.loads(line)
263
- except:
264
  continue
265
-
266
  token = data.get("content", "")
267
-
268
  yield json.dumps({
269
- "model": req.model,
270
- "message": {
271
- "role": "assistant",
272
- "content": token
273
- },
274
- "done": False
275
  }) + "\n"
 
 
 
276
 
277
- yield json.dumps({
278
- "model": req.model,
279
- "done": True
280
- }) + "\n"
281
 
282
- return StreamingResponse(
283
- stream(),
284
- media_type="application/x-ndjson"
285
- )
286
 
287
- # -------------------------
288
- # START API
289
- # -------------------------
290
 
291
  if __name__ == "__main__":
292
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ """
2
+ Ollama-compatible API server
3
+ Models: Qwen3.5-0.8B (fast) + Qwen3.5-2B (smart)
4
+ Optimized for HuggingFace free tier: 2 vCPU, 16GB RAM
5
+
6
+ FIXES vs previous version:
7
+ 1. Removed --flash-attn / --mlock / --no-mmap (not all llama.cpp builds support them — caused silent crash)
8
+ 2. llama-server logs go to llama_<model>.log so errors are visible in HF Space terminal
9
+ 3. /api/chat and /api/generate now WAIT up to 120s for server readiness
10
+ instead of immediately crashing with ConnectionRefused
11
+ """
12
+
13
+ from fastapi import FastAPI, HTTPException
14
  from fastapi.responses import StreamingResponse, JSONResponse
15
  from pydantic import BaseModel
16
  from huggingface_hub import hf_hub_download
 
22
  import time
23
  import hashlib
24
  import threading
25
+ from typing import Optional
26
 
27
  app = FastAPI()
28
 
29
+
30
# ---------------------------
# MODEL CONFIGS
# ---------------------------
# Each entry maps an Ollama-style model name to its GGUF file on the HF Hub,
# the local path it is copied to, and the llama-server launch parameters.
# Ports are unique per model so both backend servers can run side by side.

MODELS = {
    "qwen3.5-0.8b": {
        "path": "models/qwen3.5-0.8b.gguf",
        "repo": "bartowski/Qwen_Qwen3.5-0.8B-GGUF",
        "file": "Qwen_Qwen3.5-0.8B-Q4_K_M.gguf",
        "port": 8080,
        "param_size": "0.8B",
        "family": "qwen3.5",
        "threads": 2,   # HF free tier provides 2 vCPUs
        "ctx": 2048,    # context window passed to llama-server (-c)
        "batch": 512,
    },
    "qwen3.5-2b": {
        "path": "models/qwen3.5-2b.gguf",
        "repo": "bartowski/Qwen_Qwen3.5-2B-GGUF",
        "file": "Qwen_Qwen3.5-2B-Q4_K_M.gguf",
        "port": 8081,
        "param_size": "2B",
        "family": "qwen3.5",
        "threads": 2,
        "ctx": 2048,
        "batch": 512,
    },
}

# Fallback model when a request omits the name or it matches nothing.
DEFAULT_MODEL = "qwen3.5-0.8b"
# Path to the llama.cpp server binary built inside the Space image.
LLAMA_SERVER = "./llama.cpp/build/bin/llama-server"
61
 
 
62
 
63
+ # ---------------------------
64
  # REQUEST MODELS
65
+ # ---------------------------
66
 
67
class ChatRequest(BaseModel):
    """Body of POST /api/chat (Ollama-compatible)."""
    model: str = DEFAULT_MODEL      # fuzzy-matched via resolve_model()
    messages: list                  # [{"role": ..., "content": ...}, ...]
    stream: bool = True             # Ollama chat defaults to streaming
    options: Optional[dict] = None  # sampling overrides (temperature, top_p, ...)
72
 
73
 
74
class GenerateRequest(BaseModel):
    """Body of POST /api/generate (Ollama-compatible raw completion)."""
    model: str = DEFAULT_MODEL      # fuzzy-matched via resolve_model()
    prompt: str                     # raw prompt, sent to llama.cpp unmodified
    stream: bool = False            # generate defaults to non-streaming
    options: Optional[dict] = None  # sampling overrides (temperature, top_p, ...)
79
 
80
 
81
+ # ---------------------------
82
+ # PROMPT BUILDER (Qwen3.5 ChatML)
83
+ # ---------------------------
 
 
84
 
85
def build_prompt(messages: list) -> str:
    """Render an Ollama-style message list into a Qwen ChatML prompt.

    Injects a default system turn when the caller supplies none, skips
    messages with empty content, and leaves the prompt open with an
    assistant header so the model continues from there.

    Args:
        messages: list of dicts with "role" and "content" keys.

    Returns:
        The assembled ChatML prompt string.
    """
    prompt = ""
    # Qwen chat models expect a system turn; add one if the client didn't.
    if not any(m.get("role") == "system" for m in messages):
        prompt += "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    for m in messages:
        role = m.get("role", "user")
        # FIX: content may be None (e.g. tool-call messages) — the previous
        # `m.get("content", "").strip()` raised AttributeError on None.
        content = (m.get("content") or "").strip()
        if not content:
            continue
        if role == "system":
            prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
        elif role == "user":
            prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
        elif role == "assistant":
            prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
    # Open assistant turn for generation.
    prompt += "<|im_start|>assistant\n"
    return prompt
103
 
104
 
105
+ # ---------------------------
106
+ # MODEL RESOLVER
107
+ # ---------------------------
108
 
109
def resolve_model(name: str) -> str:
    """Map a client-supplied model name onto a key of MODELS.

    Exact (case-insensitive) matches win; otherwise a substring match in
    either direction is accepted; anything else falls back to DEFAULT_MODEL.
    """
    candidate = (name or DEFAULT_MODEL).lower().strip()
    if candidate in MODELS:
        return candidate
    partial = [k for k in MODELS if k in candidate or candidate in k]
    return partial[0] if partial else DEFAULT_MODEL
118
 
 
119
 
120
# ---------------------------
# DOWNLOAD MODELS
# ---------------------------

os.makedirs("models", exist_ok=True)

def download_model(cfg: dict) -> None:
    """Fetch cfg['file'] from cfg['repo'] and place it at cfg['path'].

    No-op when the target already exists (Space restarts reuse the copy).
    FIX: uses shutil.copyfile instead of `os.system("cp ...")` — shell-free,
    safe for paths with spaces/quotes, and raises on failure instead of
    silently ignoring it.
    """
    if os.path.exists(cfg["path"]):
        return
    import shutil  # local import: keeps the file's visible import block unchanged
    print(f"Downloading {cfg['file']} ...")
    downloaded = hf_hub_download(repo_id=cfg["repo"], filename=cfg["file"])
    shutil.copyfile(downloaded, cfg["path"])
    print(f" ✓ saved to {cfg['path']}")
132
 
133
# Download every configured model up front so the first request never blocks
# on a multi-hundred-MB fetch.
for m in MODELS.values():
    download_model(m)
135
 
 
136
 
137
+ # ---------------------------
138
  # START LLAMA SERVERS
139
+ # ---------------------------
140
 
141
# Readiness flags, one per model, flipped to True by start_llama() once the
# corresponding llama-server answers /health. Read by wait_for_model().
_server_ready: dict = {k: False for k in MODELS}


def start_llama(model_name: str, cfg: dict):
    """Launch a llama-server subprocess for `model_name` and wait for health.

    Runs in a daemon thread (see loop below). Server stdout/stderr go to
    llama_<model>.log so crashes are visible in the HF Space terminal.
    Returns the Popen handle on success, None if the server never became
    healthy within ~3 minutes.
    """
    print(f"Starting llama-server for {model_name} on port {cfg['port']} ...")

    # FIX 1: Write logs to file — safe flags only, no --flash-attn/--mlock/--no-mmap
    # NOTE(review): handle is intentionally left open for the child's lifetime.
    log = open(f"llama_{model_name}.log", "w")

    process = subprocess.Popen([
        LLAMA_SERVER,
        "-m", cfg["path"],
        "--host", "0.0.0.0",
        "--port", str(cfg["port"]),
        "-c", str(cfg["ctx"]),
        "--threads", str(cfg["threads"]),
        "--batch-size", str(cfg["batch"]),
        "-ngl", "0",  # CPU only
        "-np", "1",   # 1 parallel slot
    ], stdout=log, stderr=log)

    url = f"http://localhost:{cfg['port']}/health"

    # Poll /health every 2s, up to 90 attempts (~3 min total).
    for i in range(90):
        time.sleep(2)
        try:
            r = requests.get(url, timeout=2)
            if r.status_code == 200:
                _server_ready[model_name] = True
                print(f" ✓ {model_name} ready (took ~{(i+1)*2}s)")
                return process
        except Exception:
            pass  # server not accepting connections yet — keep polling

        # FIX 2: Echo last log line so HF Space logs show real llama-server output
        try:
            with open(f"llama_{model_name}.log") as lf:
                lines = [l.strip() for l in lf.read().splitlines() if l.strip()]
                print(f" [{model_name}] {lines[-1] if lines else 'starting...'}")
        except Exception:
            print(f" waiting for {model_name}... ({i+1}/90)")

    print(f" ✗ {model_name} failed — check llama_{model_name}.log")
    return None


# Launch each model's server concurrently so the API comes up immediately;
# requests arriving before readiness are held by wait_for_model().
for name, cfg in MODELS.items():
    threading.Thread(target=start_llama, args=(name, cfg), daemon=True).start()
189
+
190
+
191
+ # ---------------------------
192
+ # READINESS GUARD ← KEY FIX
193
+ # ---------------------------
194
+
195
def wait_for_model(model_key: str, timeout: int = 120):
    """Block the request until `model_key`'s llama-server is healthy.

    FIX 3: instead of crashing with ConnectionRefused while the backend is
    still loading, poll the readiness flag once per second; if the server
    never comes up within `timeout` seconds, answer with a clean 503.
    """
    give_up_at = time.time() + timeout
    while time.time() < give_up_at:
        if _server_ready.get(model_key):
            return
        time.sleep(1)
    raise HTTPException(
        status_code=503,
        detail=f"Model '{model_key}' is still loading. Please wait and retry."
    )
210
 
 
211
 
212
+ # ---------------------------
213
+ # HELPERS
214
+ # ---------------------------
215
 
216
def model_meta(name: str, cfg: dict) -> dict:
    """Build an Ollama-style model description for /api/tags, /api/show, /api/ps.

    The digest is a SHA-256 of the first 64 KiB of the GGUF file — cheap to
    compute per request yet stable per file. FIX: the previous version hashed
    with MD5 but labelled the result "sha256:", which misrepresented the
    algorithm. Missing files yield size 0 and an empty digest, not an error.
    """
    path = cfg["path"]
    size = os.path.getsize(path) if os.path.exists(path) else 0
    digest = ""
    if os.path.exists(path):
        with open(path, "rb") as f:
            # Hash only a 64 KiB prefix: full multi-GB hashing per request
            # would be far too slow.
            digest = hashlib.sha256(f.read(65536)).hexdigest()
    return {
        "name": name,
        "model": name,
        "modified_at": time.strftime("%Y-%m-%dT%H:%M:%SZ"),
        "size": size,
        "digest": f"sha256:{digest}",
        "details": {
            "format": "gguf",
            "family": cfg["family"],
            "families": [cfg["family"]],
            "parameter_size": cfg["param_size"],
            "quantization_level": "Q4_K_M",
        },
    }
236
 
 
 
 
 
 
 
237
 
238
def llama_params(options: Optional[dict]) -> dict:
    """Translate Ollama request `options` into llama.cpp /completion params.

    Every supported key has a sensible default; unknown keys are ignored.
    """
    opts = options or {}
    sampling_defaults = {
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 40,
        "repeat_penalty": 1.1,
    }
    params = {key: opts.get(key, dflt) for key, dflt in sampling_defaults.items()}
    # Ollama calls it num_predict; llama.cpp calls it n_predict.
    params["n_predict"] = opts.get("num_predict", 1024)
    params["stop"] = opts.get("stop", ["<|im_end|>", "<|endoftext|>"])
    return params
248
 
249
+
250
+ # ---------------------------
251
  # ROOT
252
+ # ---------------------------
253
 
254
@app.get("/")
def root():
    """Status endpoint: reports which backend llama servers are healthy."""
    ready_snapshot = dict(_server_ready)
    return {"status": "running", "models_ready": ready_snapshot}
257
+
258
 
259
+ # ---------------------------
260
+ # /api/tags
261
+ # ---------------------------
262
 
263
@app.get("/api/tags")
def tags():
    """Ollama /api/tags: list every configured model with its metadata."""
    listed = []
    for model_name, model_cfg in MODELS.items():
        listed.append(model_meta(model_name, model_cfg))
    return {"models": listed}
266
+
267
+
268
+ # ---------------------------
269
+ # /api/show
270
+ # ---------------------------
271
+
272
@app.post("/api/show")
def show(body: dict):
    """Ollama /api/show: model metadata plus Modelfile-style fields."""
    key = resolve_model(body.get("name", DEFAULT_MODEL))
    meta = model_meta(key, MODELS[key])
    meta["modelfile"] = f"FROM {key}\n"
    meta["parameters"] = "num_ctx 2048\nnum_predict 1024"
    # Qwen ChatML template in Ollama's Go-template syntax.
    template_parts = [
        "<|im_start|>system\n{{ .System }}<|im_end|>\n",
        "<|im_start|>user\n{{ .Prompt }}<|im_end|>\n",
        "<|im_start|>assistant\n",
    ]
    meta["template"] = "".join(template_parts)
    return meta
285
 
 
 
 
 
 
286
 
287
+ # ---------------------------
288
+ # /api/ps
289
+ # ---------------------------
290
 
291
@app.get("/api/ps")
def ps():
    """Ollama /api/ps: models whose llama-server has reported healthy."""
    running = [
        {
            **model_meta(model_name, model_cfg),
            "expires_at": "0001-01-01T00:00:00Z",  # never unloaded
            "size_vram": 0,                         # CPU-only deployment
        }
        for model_name, model_cfg in MODELS.items()
        if _server_ready.get(model_name)
    ]
    return {"models": running}
 
 
301
 
 
302
 
303
+ # ---------------------------
304
+ # /api/generate
305
+ # ---------------------------
306
 
307
@app.post("/api/generate")
def generate(req: GenerateRequest):
    """Ollama /api/generate: raw-prompt completion, optionally streamed.

    Resolves the model name, blocks until its llama-server is healthy
    (FIX 3), then proxies the request to llama.cpp's /completion endpoint.
    Streamed upstream chunks are re-emitted as Ollama-style NDJSON lines.
    """
    key = resolve_model(req.model)
    cfg = MODELS[key]

    wait_for_model(key)  # ← blocks until ready, not crash

    params = llama_params(req.options)
    params["prompt"] = req.prompt
    params["stream"] = req.stream

    r = requests.post(
        f"http://localhost:{cfg['port']}/completion",
        json=params, stream=req.stream, timeout=120,
    )

    if not req.stream:
        # Non-streaming: single JSON body from llama.cpp.
        text = r.json().get("content", "").strip()
        return {"model": req.model, "response": text, "done": True, "done_reason": "stop"}

    def stream_gen():
        # llama.cpp emits SSE-style "data: {...}" lines; strip the prefix,
        # parse, and re-wrap each token as an Ollama NDJSON chunk.
        for line in r.iter_lines():
            if not line:
                continue
            line = line.decode("utf-8").strip()
            if line.startswith("data:"):
                line = line[5:].strip()
            try:
                data = json.loads(line)
            except Exception:
                continue  # skip malformed/partial lines rather than abort the stream
            token = data.get("content", "")
            done = data.get("stop", False)
            yield json.dumps({"model": req.model, "response": token, "done": done}) + "\n"
            if done:
                break
        # Terminal chunk — emitted even when upstream already signalled stop.
        yield json.dumps({"model": req.model, "response": "", "done": True, "done_reason": "stop"}) + "\n"

    return StreamingResponse(stream_gen(), media_type="application/x-ndjson",
                             headers={"Cache-Control": "no-cache"})
347
 
348
+
349
+ # ---------------------------
350
+ # /api/chat
351
+ # ---------------------------
352
 
353
@app.post("/api/chat")
def chat(req: ChatRequest):
    """Ollama /api/chat: multi-turn chat completion, streamed by default.

    Builds a Qwen ChatML prompt from the message list, blocks until the
    model's llama-server is healthy (FIX 3), then proxies to llama.cpp's
    /completion endpoint, re-emitting tokens as Ollama chat NDJSON chunks.
    """
    key = resolve_model(req.model)
    cfg = MODELS[key]

    wait_for_model(key)  # blocks until ready, not crash

    prompt = build_prompt(req.messages)
    params = llama_params(req.options)
    params["prompt"] = prompt
    params["stream"] = req.stream

    r = requests.post(
        f"http://localhost:{cfg['port']}/completion",
        json=params, stream=req.stream, timeout=120,
    )

    if not req.stream:
        # Non-streaming: return the whole assistant message at once.
        text = r.json().get("content", "").strip()
        return JSONResponse({
            "model": req.model,
            "message": {"role": "assistant", "content": text},
            "done": True, "done_reason": "stop",
        })

    def stream_gen():
        # llama.cpp emits SSE-style "data: {...}" lines; strip the prefix,
        # parse, and wrap each token in an Ollama chat-message chunk.
        for line in r.iter_lines():
            if not line:
                continue
            line = line.decode("utf-8").strip()
            if line.startswith("data:"):
                line = line[5:].strip()
            try:
                data = json.loads(line)
            except Exception:
                continue  # skip malformed/partial lines rather than abort the stream
            token = data.get("content", "")
            done = data.get("stop", False)
            yield json.dumps({
                "model": req.model,
                "message": {"role": "assistant", "content": token},
                "done": done,
            }) + "\n"
            if done:
                break
        # Terminal chunk — emitted even when upstream already signalled stop.
        yield json.dumps({"model": req.model, "done": True, "done_reason": "stop"}) + "\n"

    return StreamingResponse(stream_gen(), media_type="application/x-ndjson",
                             headers={"Cache-Control": "no-cache"})
 
 
402
 
 
 
 
 
403
 
404
# ---------------------------
# START
# ---------------------------

if __name__ == "__main__":
    # Single worker: process-local state (_server_ready, llama subprocesses)
    # would be duplicated per worker, so keep workers=1.
    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)