NOT-OMEGA committed on
Commit
28073f6
·
verified ·
1 Parent(s): 1fc8fbb

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +66 -209
main.py CHANGED
@@ -3,25 +3,17 @@ KVInfer — FastAPI Backend v2.1
3
  ========================================
4
  Fixes applied:
5
  #1 Persistent C++ process — model loads ONCE at startup via lifespan.
6
- All requests share one process via asyncio.Lock (serialized, no spawn overhead).
7
- #2 O(n) token cache — each session stores which tokens have already been
8
- sent to the C++ engine. New turns only encode + send NEW tokens.
9
- #3 Session KV-cache reuse — C++ engine persists KV cache per session;
10
- Python only sends the incremental new tokens each turn.
11
- #4 Stop-token bleed fix — only EOS (50256) used as stop token since plain
12
- text format ("User:") doesn't have a dedicated special token ID.
13
- #7 Chat template format fixed to match actual SFT training format:
14
- "System: ...\nUser: ...\nAssistant: " — NOT GPT-2 special angle tokens
15
- which tiktoken would fragment into multiple pieces and the model never
16
- saw during training.
17
  """
18
-
19
  import asyncio
20
  import json
21
  import os
22
  import time
23
  import uuid
24
- from collections import defaultdict
25
  from contextlib import asynccontextmanager
26
  from pathlib import Path
27
  from typing import AsyncGenerator
@@ -30,80 +22,47 @@ import psutil
30
  import tiktoken
31
  from fastapi import FastAPI, HTTPException
32
  from fastapi.middleware.cors import CORSMiddleware
33
- from fastapi.responses import StreamingResponse
34
  from pydantic import BaseModel, Field
35
 
36
  # ─────────────────────────────────────────────────────────────────────────
37
  # Config
38
  # ─────────────────────────────────────────────────────────────────────────
39
-
40
  BASE_DIR = Path(__file__).parent
41
  INFERENCE_EXE = BASE_DIR / "inference"
42
  MODEL_BIN = BASE_DIR / "model.bin"
43
 
44
- # FIX #7 — Chat template MUST match your SFT training data format exactly.
45
- #
46
- # GPT-2 tiktoken has NO special tokens for <|system|>, <|user|>, <|assistant|>.
47
- # tiktoken breaks them into multiple fragments:
48
- # "<|user|>" → [27, 91, 7220, 91, 29] (5 separate tokens!)
49
- # Your SFT model NEVER saw these fragments during training → garbage output.
50
- #
51
- # Your SFT training used plain text format:
52
- # "System: You are a helpful assistant.\n"
53
- # "User: Hello\n"
54
- # "Assistant: Hi\n"
55
- #
56
- # We MUST use the same format here.
57
  SYSTEM_TOKEN = "System:"
58
  USER_TOKEN = "User:"
59
  ASST_TOKEN = "Assistant:"
60
  SEP = "\n"
61
 
62
- # Context limit: 1024 (block_size) - 500 (max generation) - 24 (safety margin)
63
- # This is the maximum tokens we allow in the KV cache before a soft reset.
64
- # Formula: block_size - max_new_tokens_ceiling - safety_margin
65
- BLOCK_SIZE = 1024 # must match the block_size (context length) of the model in model.bin
66
- MAX_GEN_CEILING = 500 # max allowed by API (see ChatRequest)
67
- SAFETY_MARGIN = 24 # newlines, role tokens, off-by-one buffer
68
  MAX_SESSION_TOKENS = BLOCK_SIZE - MAX_GEN_CEILING - SAFETY_MARGIN # = 500
69
 
70
  # ─────────────────────────────────────────────────────────────────────────
71
  # Tokenizer
72
  # ─────────────────────────────────────────────────────────────────────────
73
-
74
- enc = tiktoken.get_encoding("gpt2")
75
-
76
- # Only EOS stop token needed. Plain text "User:" has no dedicated token ID
77
- # to stop on — the model was trained to emit 50256 at end of each reply.
78
  STOP_TOKEN_IDS = [50256]
79
-
80
- # String-level stop patterns — model may generate these as plain text since
81
- # training used plain "User:" / "System:" (not special tokens).
82
- # We catch them in the Python streaming loop before sending to the client.
83
- STOP_STRINGS = ["User:", "System:", "Assistant:"]
84
 
85
  # ─────────────────────────────────────────────────────────────────────────
86
- # Persistent Engine (FIX #1)
87
  # ─────────────────────────────────────────────────────────────────────────
88
-
89
  class InferenceEngine:
90
- """
91
- Wraps one long-lived inference.exe process.
92
- All requests are serialised through self._lock so the single
93
- stdin/stdout pipe stays consistent.
94
- """
95
-
96
  def __init__(self):
97
- self._proc: asyncio.subprocess.Process | None = None
98
- self._lock = asyncio.Lock()
99
  self._ready = False
100
 
101
  async def start(self):
102
  if not INFERENCE_EXE.exists():
103
- raise RuntimeError(f"inference.exe not found at {INFERENCE_EXE}")
104
  if not MODEL_BIN.exists():
105
  raise RuntimeError(f"model.bin not found at {MODEL_BIN}")
106
-
107
  self._proc = await asyncio.create_subprocess_exec(
108
  str(INFERENCE_EXE),
109
  stdin=asyncio.subprocess.PIPE,
@@ -111,7 +70,6 @@ class InferenceEngine:
111
  stderr=asyncio.subprocess.DEVNULL,
112
  cwd=str(BASE_DIR),
113
  )
114
- # Wait for READY signal (model loaded)
115
  while True:
116
  line = (await self._proc.stdout.readline()).decode().strip()
117
  if line == "READY":
@@ -132,101 +90,62 @@ class InferenceEngine:
132
 
133
  async def reset_session(self, session_id: str):
134
  async with self._lock:
135
- cmd = f"RESET|{session_id}\n".encode()
136
- self._proc.stdin.write(cmd)
137
  await self._proc.stdin.drain()
138
- # read RESET_OK
139
  await self._proc.stdout.readline()
140
 
141
- async def generate(
142
- self,
143
- session_id: str,
144
- new_token_ids: list[int],
145
- max_new: int,
146
- temperature: float,
147
- top_k: int,
148
- ) -> AsyncGenerator[dict, None]:
149
- """
150
- Yields dicts: {"type":"token","id":int,"text":str,"elapsed_ms":float}
151
- {"type":"done","total_tokens":int,"total_ms":float,"tps":float}
152
- {"type":"error","message":str}
153
- """
154
  if not self._ready or self._proc is None:
155
  yield {"type": "error", "message": "Engine not ready"}
156
  return
157
-
158
  tokens_csv = ",".join(map(str, new_token_ids))
159
  stop_csv = ",".join(map(str, STOP_TOKEN_IDS))
160
  cmd = f"REQUEST|{session_id}|{tokens_csv}|{max_new}|{temperature}|{top_k}|{stop_csv}\n"
161
-
162
  async with self._lock:
163
  self._proc.stdin.write(cmd.encode())
164
  await self._proc.stdin.drain()
165
-
166
- gen_count = 0
167
  while True:
168
- raw = await self._proc.stdout.readline()
169
  line = raw.decode("utf-8", errors="replace").strip()
170
  if not line:
171
  continue
172
-
173
  if line.startswith("TOKEN"):
174
  parts = line.split()
175
  tid = int(parts[1])
176
  ms = float(parts[2])
177
- gen_count += 1
178
  yield {"type": "token", "id": tid,
179
  "text": enc.decode([tid]), "elapsed_ms": ms}
180
-
181
  elif line.startswith("DONE"):
182
  parts = line.split()
183
  total_t = int(parts[1])
184
  total_ms = float(parts[2])
185
- tps = round(total_t / (total_ms / 1000.0), 2) if total_ms > 0 else 0
186
  yield {"type": "done", "total_tokens": total_t,
187
  "total_ms": total_ms, "tps": tps}
188
  break
189
-
190
  elif line.startswith("ERROR"):
191
  yield {"type": "error", "message": line}
192
  break
193
 
194
-
195
  engine = InferenceEngine()
196
 
197
  # ─────────────────────────────────────────────────────────────────────────
198
- # Session State (FIX #2 + #3)
199
  # ─────────────────────────────────────────────────────────────────────────
200
-
201
  class SessionData:
202
- """
203
- Tracks what the C++ engine already knows for this session so we
204
- only ever send NEW incremental tokens — O(1) per turn instead of O(n).
205
- """
206
  def __init__(self, system_prompt: str):
207
- self.system_prompt = system_prompt
208
- self.history: list[dict] = [] # {"role":..., "content":...}
209
- self.tokens_in_engine: int = 0 # how many tokens C++ has processed
210
- self.total_chars: int = 0
211
 
212
- def append_user(self, content: str):
213
  self.history.append({"role": "user", "content": content})
214
 
215
- def append_assistant(self, content: str):
216
  self.history.append({"role": "assistant", "content": content})
217
 
218
- def new_turn_tokens(self, user_msg: str) -> list[int]:
219
- """
220
- Returns ONLY the token IDs the C++ engine has not seen yet.
221
- Format matches EXACTLY what SFT training used:
222
- System: <prompt>
223
- User: <msg>
224
- Assistant:
225
- encode_ordinary() ensures tiktoken never interprets anything as
226
- a special token (like <|endoftext|>) mid-prompt by accident.
227
- """
228
  if self.tokens_in_engine == 0:
229
- # First turn - send full context: system + first user message
230
  full = (
231
  f"{SYSTEM_TOKEN} {self.system_prompt}{SEP}"
232
  f"{USER_TOKEN} {user_msg}{SEP}"
@@ -234,27 +153,11 @@ Assistant:
234
  )
235
  return enc.encode_ordinary(full)
236
  else:
237
- # Subsequent turns - engine already has prior context in KV cache.
238
- # Only send new user message + assistant cue.
239
- incremental = (
240
- f"{USER_TOKEN} {user_msg}{SEP}"
241
- f"{ASST_TOKEN} "
242
- )
243
  return enc.encode_ordinary(incremental)
244
 
245
- # NOTE: We intentionally do NOT re-encode the assistant reply to count tokens.
246
- # chunk["total_tokens"] from C++ is the exact generated token count — using
247
- # enc.encode_ordinary(reply) would re-tokenize decoded text and can differ
248
- # due to BPE whitespace/boundary effects. C++ count is always ground truth.
249
-
250
-
251
- sessions: dict[str, SessionData] = {}
252
-
253
- # ─────────────────────────────────────────────────────────────────────────
254
- # Server Metrics
255
- # ─────────────────────────────────────────────────────────────────────────
256
-
257
- metrics = {
258
  "total_requests": 0,
259
  "total_tokens": 0,
260
  "total_ms": 0.0,
@@ -265,37 +168,25 @@ metrics = {
265
  # ─────────────────────────────────────────────────────────────────────────
266
  # App + Lifespan
267
  # ─────────────────────────────────────────────────────────────────────────
268
-
269
  @asynccontextmanager
270
  async def lifespan(app: FastAPI):
271
- # Startup — launch C++ engine once
272
  try:
273
  await engine.start()
274
  except Exception as e:
275
  print(f"[WARNING] Could not start engine: {e}")
276
  print("[WARNING] Server will start but /chat will return 503 until engine is ready.")
277
  yield
278
- # Shutdown
279
  await engine.stop()
280
 
281
-
282
- app = FastAPI(
283
- title="KVInfer",
284
- version="2.0.0",
285
- lifespan=lifespan,
286
- )
287
-
288
  app.add_middleware(
289
  CORSMiddleware,
290
- allow_origins=["*"],
291
- allow_methods=["*"],
292
- allow_headers=["*"],
293
  )
294
 
295
  # ─────────────────────────────────────────────────────────────────────────
296
  # Pydantic Models
297
  # ─────────────────────────────────────────────────────────────────────────
298
-
299
  class ChatRequest(BaseModel):
300
  message: str
301
  session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
@@ -317,6 +208,12 @@ class GenerateRequest(BaseModel):
317
  # Routes
318
  # ─────────────────────────────────────────────────────────────────────────
319
 
 
 
 
 
 
 
320
  @app.get("/health")
321
  async def health():
322
  mem = psutil.virtual_memory()
@@ -336,34 +233,23 @@ async def health():
336
 
337
  @app.post("/chat")
338
  async def chat(req: ChatRequest):
339
- """SSE streaming chat — real-time token-by-token output."""
340
  if not engine._ready:
341
- raise HTTPException(503, "Engine not ready. Check inference.exe and model.bin.")
342
-
343
- # Get or create session
344
  sess = sessions.get(req.session_id)
345
  if sess is None:
346
  sess = SessionData(req.system_prompt)
347
  sessions[req.session_id] = sess
348
-
349
- # FIX #2 — only encode NEW tokens (incremental)
350
  new_tokens = sess.new_turn_tokens(req.message)
351
-
352
- # Guard: don't overflow context
353
  if sess.tokens_in_engine + len(new_tokens) + req.max_new_tokens > MAX_SESSION_TOKENS:
354
- # Soft reset: clear C++ session KV cache, then re-send only the system prompt and the current user message (earlier turns are dropped from the engine context)
355
  await engine.reset_session(req.session_id)
356
  sess.tokens_in_engine = 0
357
- # Re-encode as full prompt
358
  new_tokens = sess.new_turn_tokens(req.message)
359
-
360
  sess.append_user(req.message)
361
  metrics["total_requests"] += 1
362
 
363
  async def event_stream():
364
- response_parts: list[str] = []
365
  t0 = time.time()
366
-
367
  try:
368
  async for chunk in engine.generate(
369
  req.session_id, new_tokens,
@@ -371,40 +257,25 @@ async def chat(req: ChatRequest):
371
  ):
372
  if chunk["type"] == "token":
373
  response_parts.append(chunk["text"])
374
-
375
- # String-level stop detection (Fix #8).
376
- # The model was trained on plain "User:" text β€” it may
377
- # regenerate the next speaker role instead of stopping on EOS.
378
- # We catch this here before streaming the token to the client.
379
  joined = "".join(response_parts)
380
- hit_stop = any(s in joined for s in STOP_STRINGS[:-1]) # User: / System:
381
- if hit_stop:
382
- # Trim the leaked role marker from the reply
383
  for s in STOP_STRINGS[:-1]:
384
  idx = joined.find(s)
385
  if idx != -1:
386
  response_parts = [joined[:idx]]
387
  break
388
-
389
  yield f"data: {json.dumps(chunk)}\n\n"
390
-
391
  elif chunk["type"] == "done":
392
  reply = "".join(response_parts).strip()
393
  sess.append_assistant(reply)
394
-
395
- # FIX #2 — update how many tokens the engine now holds
396
  sess.tokens_in_engine += len(new_tokens) + chunk["total_tokens"]
397
-
398
  elapsed = (time.time() - t0) * 1000
399
  metrics["total_tokens"] += chunk["total_tokens"]
400
  metrics["total_ms"] += elapsed
401
-
402
  yield f"data: {json.dumps({**chunk, 'session_id': req.session_id, 'full_response': reply})}\n\n"
403
-
404
  elif chunk["type"] == "error":
405
  metrics["errors"] += 1
406
  yield f"data: {json.dumps(chunk)}\n\n"
407
-
408
  except Exception as e:
409
  metrics["errors"] += 1
410
  yield f"data: {json.dumps({'type':'error','message':str(e)})}\n\n"
@@ -438,30 +309,22 @@ async def get_history(session_id: str):
438
 
439
  @app.post("/generate")
440
  async def generate(req: GenerateRequest):
441
- """Non-streaming single generation (backward-compat)."""
442
  if not engine._ready:
443
  raise HTTPException(503, "Engine not ready.")
444
-
445
  token_ids = enc.encode_ordinary(req.prompt)
446
  tmp_sess = f"_gen_{uuid.uuid4().hex}"
447
- generated: list[str] = []
448
- total_ms = 0.0
449
-
450
- async for chunk in engine.generate(
451
- tmp_sess, token_ids, req.max_tokens, req.temperature, req.top_k
452
- ):
453
  if chunk["type"] == "token":
454
  generated.append(chunk["text"])
455
  elif chunk["type"] == "done":
456
  total_ms = chunk["total_ms"]
457
  elif chunk["type"] == "error":
458
  raise HTTPException(500, chunk["message"])
459
-
460
- # Clean up temp session from C++ engine
461
  await engine.reset_session(tmp_sess)
462
  text = "".join(generated)
463
  tps = len(generated) / (total_ms / 1000.0) if total_ms > 0 else 0
464
-
465
  return {
466
  "prompt": req.prompt, "generated_text": text,
467
  "tokens_in": len(token_ids), "tokens_out": len(generated),
@@ -477,24 +340,22 @@ async def get_metrics():
477
  mem = psutil.virtual_memory()
478
  proc = psutil.Process(os.getpid())
479
  return {
480
- "total_requests": n,
481
- "total_tokens": tok,
482
- "avg_tps": round(tok/(ms/1000),2) if ms>0 else 0,
483
- "avg_latency_ms": round(ms/n,2) if n>0 else 0,
484
- "errors": metrics["errors"],
485
- "active_sessions": len(sessions),
486
- "process_ram_mb": round(proc.memory_info().rss/1e6,1),
487
- "system_ram_used_pct": mem.percent,
488
- "uptime_s": round(time.time()-metrics["start_time"],1),
489
  }
490
 
491
 
492
  @app.get("/benchmark/run")
493
  async def benchmark_run():
494
- """Quick 5-prompt internal benchmark (used by frontend modal)."""
495
  if not engine._ready:
496
  raise HTTPException(503, "Engine not ready.")
497
-
498
  prompts = [
499
  "What is artificial intelligence?",
500
  "How does a CPU work?",
@@ -503,36 +364,32 @@ async def benchmark_run():
503
  "How does photosynthesis work?",
504
  ]
505
  results = []
506
-
507
  for p in prompts:
508
  sid = f"_bench_{uuid.uuid4().hex}"
509
  toks = enc.encode_ordinary(f"{USER_TOKEN} {p}\n{ASST_TOKEN} ")
510
- gen = 0; total_ms = 0.0; ttft_ms = 0.0; first = True
511
  t0 = time.time()
512
-
513
  async for c in engine.generate(sid, toks, 80, 0.7, 40):
514
  if c["type"] == "token":
515
  gen += 1
516
- if first: ttft_ms = (time.time()-t0)*1000; first=False
517
  elif c["type"] == "done":
518
  total_ms = c["total_ms"]
519
-
520
  await engine.reset_session(sid)
521
- tps = gen/(total_ms/1000) if total_ms>0 else 0
522
  results.append({
523
- "prompt_preview": p[:40],
524
- "tokens_in": len(toks),
525
- "tokens_out": gen,
526
- "ttft_ms": round(ttft_ms,1),
527
- "total_ms": round(total_ms,1),
528
- "tokens_per_sec": round(tps,2),
529
  })
530
-
531
- avg_tps = sum(r["tokens_per_sec"] for r in results)/len(results)
532
- avg_ttft = sum(r["ttft_ms"] for r in results)/len(results)
533
  return {
534
- "summary": {"avg_tps": round(avg_tps,2),
535
- "avg_ttft_ms": round(avg_ttft,1),
536
  "runs": len(results)},
537
  "details": results,
538
  }
@@ -540,4 +397,4 @@ async def benchmark_run():
540
 
541
  if __name__ == "__main__":
542
  import uvicorn
543
- uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=False)
 
3
  ========================================
4
  Fixes applied:
5
  #1 Persistent C++ process — model loads ONCE at startup via lifespan.
6
+ #2 O(n) token cache — incremental tokens only per turn.
7
+ #3 Session KV-cache reuse.
8
+ #4 Stop-token bleed fix.
9
+ #7 Chat template format fixed to match SFT training format.
10
+ #HF Serves index.html at "/" for HF Spaces Docker deployment.
 
 
 
 
 
 
11
  """
 
12
  import asyncio
13
  import json
14
  import os
15
  import time
16
  import uuid
 
17
  from contextlib import asynccontextmanager
18
  from pathlib import Path
19
  from typing import AsyncGenerator
 
22
  import tiktoken
23
  from fastapi import FastAPI, HTTPException
24
  from fastapi.middleware.cors import CORSMiddleware
25
+ from fastapi.responses import FileResponse, StreamingResponse
26
  from pydantic import BaseModel, Field
27
 
28
  # ─────────────────────────────────────────────────────────────────────────
29
  # Config
30
  # ─────────────────────────────────────────────────────────────────────────
 
31
  BASE_DIR = Path(__file__).parent
32
  INFERENCE_EXE = BASE_DIR / "inference"
33
  MODEL_BIN = BASE_DIR / "model.bin"
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  SYSTEM_TOKEN = "System:"
36
  USER_TOKEN = "User:"
37
  ASST_TOKEN = "Assistant:"
38
  SEP = "\n"
39
 
40
+ BLOCK_SIZE = 1024
41
+ MAX_GEN_CEILING = 500
42
+ SAFETY_MARGIN = 24
 
 
 
43
  MAX_SESSION_TOKENS = BLOCK_SIZE - MAX_GEN_CEILING - SAFETY_MARGIN # = 500
44
 
45
  # ─────────────────────────────────────────────────────────────────────────
46
  # Tokenizer
47
  # ─────────────────────────────────────────────────────────────────────────
48
+ enc = tiktoken.get_encoding("gpt2")
 
 
 
 
49
  STOP_TOKEN_IDS = [50256]
50
+ STOP_STRINGS = ["User:", "System:", "Assistant:"]
 
 
 
 
51
 
52
  # ─────────────────────────────────────────────────────────────────────────
53
+ # Persistent Engine
54
  # ─────────────────────────────────────────────────────────────────────────
 
55
  class InferenceEngine:
 
 
 
 
 
 
56
  def __init__(self):
57
+ self._proc = None
58
+ self._lock = asyncio.Lock()
59
  self._ready = False
60
 
61
  async def start(self):
62
  if not INFERENCE_EXE.exists():
63
+ raise RuntimeError(f"inference not found at {INFERENCE_EXE}")
64
  if not MODEL_BIN.exists():
65
  raise RuntimeError(f"model.bin not found at {MODEL_BIN}")
 
66
  self._proc = await asyncio.create_subprocess_exec(
67
  str(INFERENCE_EXE),
68
  stdin=asyncio.subprocess.PIPE,
 
70
  stderr=asyncio.subprocess.DEVNULL,
71
  cwd=str(BASE_DIR),
72
  )
 
73
  while True:
74
  line = (await self._proc.stdout.readline()).decode().strip()
75
  if line == "READY":
 
90
 
91
  async def reset_session(self, session_id: str):
92
  async with self._lock:
93
+ self._proc.stdin.write(f"RESET|{session_id}\n".encode())
 
94
  await self._proc.stdin.drain()
 
95
  await self._proc.stdout.readline()
96
 
97
+ async def generate(self, session_id, new_token_ids, max_new, temperature, top_k):
 
 
 
 
 
 
 
 
 
 
 
 
98
  if not self._ready or self._proc is None:
99
  yield {"type": "error", "message": "Engine not ready"}
100
  return
 
101
  tokens_csv = ",".join(map(str, new_token_ids))
102
  stop_csv = ",".join(map(str, STOP_TOKEN_IDS))
103
  cmd = f"REQUEST|{session_id}|{tokens_csv}|{max_new}|{temperature}|{top_k}|{stop_csv}\n"
 
104
  async with self._lock:
105
  self._proc.stdin.write(cmd.encode())
106
  await self._proc.stdin.drain()
 
 
107
  while True:
108
+ raw = await self._proc.stdout.readline()
109
  line = raw.decode("utf-8", errors="replace").strip()
110
  if not line:
111
  continue
 
112
  if line.startswith("TOKEN"):
113
  parts = line.split()
114
  tid = int(parts[1])
115
  ms = float(parts[2])
 
116
  yield {"type": "token", "id": tid,
117
  "text": enc.decode([tid]), "elapsed_ms": ms}
 
118
  elif line.startswith("DONE"):
119
  parts = line.split()
120
  total_t = int(parts[1])
121
  total_ms = float(parts[2])
122
+ tps = round(total_t / (total_ms / 1000.0), 2) if total_ms > 0 else 0
123
  yield {"type": "done", "total_tokens": total_t,
124
  "total_ms": total_ms, "tps": tps}
125
  break
 
126
  elif line.startswith("ERROR"):
127
  yield {"type": "error", "message": line}
128
  break
129
 
 
130
  engine = InferenceEngine()
131
 
132
  # ─────────────────────────────────────────────────────────────────────────
133
+ # Session State
134
  # ─────────────────────────────────────────────────────────────────────────
 
135
  class SessionData:
 
 
 
 
136
  def __init__(self, system_prompt: str):
137
+ self.system_prompt = system_prompt
138
+ self.history = []
139
+ self.tokens_in_engine = 0
 
140
 
141
+ def append_user(self, content):
142
  self.history.append({"role": "user", "content": content})
143
 
144
+ def append_assistant(self, content):
145
  self.history.append({"role": "assistant", "content": content})
146
 
147
+ def new_turn_tokens(self, user_msg):
 
 
 
 
 
 
 
 
 
148
  if self.tokens_in_engine == 0:
 
149
  full = (
150
  f"{SYSTEM_TOKEN} {self.system_prompt}{SEP}"
151
  f"{USER_TOKEN} {user_msg}{SEP}"
 
153
  )
154
  return enc.encode_ordinary(full)
155
  else:
156
+ incremental = f"{USER_TOKEN} {user_msg}{SEP}{ASST_TOKEN} "
 
 
 
 
 
157
  return enc.encode_ordinary(incremental)
158
 
159
+ sessions = {}
160
+ metrics = {
 
 
 
 
 
 
 
 
 
 
 
161
  "total_requests": 0,
162
  "total_tokens": 0,
163
  "total_ms": 0.0,
 
168
  # ─────────────────────────────────────────────────────────────────────────
169
  # App + Lifespan
170
  # ─────────────────────────────────────────────────────────────────────────
 
171
  @asynccontextmanager
172
  async def lifespan(app: FastAPI):
 
173
  try:
174
  await engine.start()
175
  except Exception as e:
176
  print(f"[WARNING] Could not start engine: {e}")
177
  print("[WARNING] Server will start but /chat will return 503 until engine is ready.")
178
  yield
 
179
  await engine.stop()
180
 
181
+ app = FastAPI(title="KVInfer", version="2.1.0", lifespan=lifespan)
 
 
 
 
 
 
182
  app.add_middleware(
183
  CORSMiddleware,
184
+ allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
 
 
185
  )
186
 
187
  # ─────────────────────────────────────────────────────────────────────────
188
  # Pydantic Models
189
  # ─────────────────────────────────────────────────────────────────────────
 
190
  class ChatRequest(BaseModel):
191
  message: str
192
  session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
 
208
  # Routes
209
  # ─────────────────────────────────────────────────────────────────────────
210
 
211
+ @app.get("/")
212
+ async def serve_ui():
213
+ """Serve the Chat UI — required for HF Spaces Docker deployment."""
214
+ return FileResponse(BASE_DIR / "index.html")
215
+
216
+
217
  @app.get("/health")
218
  async def health():
219
  mem = psutil.virtual_memory()
 
233
 
234
  @app.post("/chat")
235
  async def chat(req: ChatRequest):
 
236
  if not engine._ready:
237
+ raise HTTPException(503, "Engine not ready. Check inference and model.bin.")
 
 
238
  sess = sessions.get(req.session_id)
239
  if sess is None:
240
  sess = SessionData(req.system_prompt)
241
  sessions[req.session_id] = sess
 
 
242
  new_tokens = sess.new_turn_tokens(req.message)
 
 
243
  if sess.tokens_in_engine + len(new_tokens) + req.max_new_tokens > MAX_SESSION_TOKENS:
 
244
  await engine.reset_session(req.session_id)
245
  sess.tokens_in_engine = 0
 
246
  new_tokens = sess.new_turn_tokens(req.message)
 
247
  sess.append_user(req.message)
248
  metrics["total_requests"] += 1
249
 
250
  async def event_stream():
251
+ response_parts = []
252
  t0 = time.time()
 
253
  try:
254
  async for chunk in engine.generate(
255
  req.session_id, new_tokens,
 
257
  ):
258
  if chunk["type"] == "token":
259
  response_parts.append(chunk["text"])
 
 
 
 
 
260
  joined = "".join(response_parts)
261
+ if any(s in joined for s in STOP_STRINGS[:-1]):
 
 
262
  for s in STOP_STRINGS[:-1]:
263
  idx = joined.find(s)
264
  if idx != -1:
265
  response_parts = [joined[:idx]]
266
  break
 
267
  yield f"data: {json.dumps(chunk)}\n\n"
 
268
  elif chunk["type"] == "done":
269
  reply = "".join(response_parts).strip()
270
  sess.append_assistant(reply)
 
 
271
  sess.tokens_in_engine += len(new_tokens) + chunk["total_tokens"]
 
272
  elapsed = (time.time() - t0) * 1000
273
  metrics["total_tokens"] += chunk["total_tokens"]
274
  metrics["total_ms"] += elapsed
 
275
  yield f"data: {json.dumps({**chunk, 'session_id': req.session_id, 'full_response': reply})}\n\n"
 
276
  elif chunk["type"] == "error":
277
  metrics["errors"] += 1
278
  yield f"data: {json.dumps(chunk)}\n\n"
 
279
  except Exception as e:
280
  metrics["errors"] += 1
281
  yield f"data: {json.dumps({'type':'error','message':str(e)})}\n\n"
 
309
 
310
  @app.post("/generate")
311
  async def generate(req: GenerateRequest):
 
312
  if not engine._ready:
313
  raise HTTPException(503, "Engine not ready.")
 
314
  token_ids = enc.encode_ordinary(req.prompt)
315
  tmp_sess = f"_gen_{uuid.uuid4().hex}"
316
+ generated = []
317
+ total_ms = 0.0
318
+ async for chunk in engine.generate(tmp_sess, token_ids, req.max_tokens, req.temperature, req.top_k):
 
 
 
319
  if chunk["type"] == "token":
320
  generated.append(chunk["text"])
321
  elif chunk["type"] == "done":
322
  total_ms = chunk["total_ms"]
323
  elif chunk["type"] == "error":
324
  raise HTTPException(500, chunk["message"])
 
 
325
  await engine.reset_session(tmp_sess)
326
  text = "".join(generated)
327
  tps = len(generated) / (total_ms / 1000.0) if total_ms > 0 else 0
 
328
  return {
329
  "prompt": req.prompt, "generated_text": text,
330
  "tokens_in": len(token_ids), "tokens_out": len(generated),
 
340
  mem = psutil.virtual_memory()
341
  proc = psutil.Process(os.getpid())
342
  return {
343
+ "total_requests": n,
344
+ "total_tokens": tok,
345
+ "avg_tps": round(tok/(ms/1000), 2) if ms > 0 else 0,
346
+ "avg_latency_ms": round(ms/n, 2) if n > 0 else 0,
347
+ "errors": metrics["errors"],
348
+ "active_sessions": len(sessions),
349
+ "process_ram_mb": round(proc.memory_info().rss/1e6, 1),
350
+ "system_ram_used_pct": mem.percent,
351
+ "uptime_s": round(time.time()-metrics["start_time"], 1),
352
  }
353
 
354
 
355
  @app.get("/benchmark/run")
356
  async def benchmark_run():
 
357
  if not engine._ready:
358
  raise HTTPException(503, "Engine not ready.")
 
359
  prompts = [
360
  "What is artificial intelligence?",
361
  "How does a CPU work?",
 
364
  "How does photosynthesis work?",
365
  ]
366
  results = []
 
367
  for p in prompts:
368
  sid = f"_bench_{uuid.uuid4().hex}"
369
  toks = enc.encode_ordinary(f"{USER_TOKEN} {p}\n{ASST_TOKEN} ")
370
+ gen = 0; total_ms = 0.0; ttft_ms = 0.0; first = True
371
  t0 = time.time()
 
372
  async for c in engine.generate(sid, toks, 80, 0.7, 40):
373
  if c["type"] == "token":
374
  gen += 1
375
+ if first: ttft_ms = (time.time()-t0)*1000; first = False
376
  elif c["type"] == "done":
377
  total_ms = c["total_ms"]
 
378
  await engine.reset_session(sid)
379
+ tps = gen/(total_ms/1000) if total_ms > 0 else 0
380
  results.append({
381
+ "prompt_preview": p[:40],
382
+ "tokens_in": len(toks),
383
+ "tokens_out": gen,
384
+ "ttft_ms": round(ttft_ms, 1),
385
+ "total_ms": round(total_ms, 1),
386
+ "tokens_per_sec": round(tps, 2),
387
  })
388
+ avg_tps = sum(r["tokens_per_sec"] for r in results) / len(results)
389
+ avg_ttft = sum(r["ttft_ms"] for r in results) / len(results)
 
390
  return {
391
+ "summary": {"avg_tps": round(avg_tps, 2),
392
+ "avg_ttft_ms": round(avg_ttft, 1),
393
  "runs": len(results)},
394
  "details": results,
395
  }
 
397
 
398
  if __name__ == "__main__":
399
  import uvicorn
400
+ uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)