| """ |
KVInfer — FastAPI Backend v4.0 (16GB Optimised)

RAM Budget (HF 16GB Space):
  4 engines × 580MB model          = 2.32GB
  4 engines × 20 sessions × 96MB   = 7.68GB
  Python + OS overhead             = ~1.00GB
  ──────────────────────────────────────────
  TOTAL                            ≈ 11.0GB ✓ (5GB headroom)

CPU Budget (2 vCPU):
  OMP_NUM_THREADS = 1 per engine
  4 engines × 1 thread = 4 threads on 2 vCPUs (light oversubscription, fine)

── SPEED MODE (single user fastest) ──
  Set env: N_ENGINES=1 OMP_NUM_THREADS=2
  → Both cores for 1 engine → ~1.5–2× faster TPS per request
  → Trade-off: only 1 user at a time

── BALANCED MODE (default) ──
  N_ENGINES=4, OMP_NUM_THREADS=1
  → 4 truly parallel users, CPU evenly shared

Concurrency:
  4 users   → truly parallel
  5th+ user → waits in the queue (is not dropped)
| """ |
|
|
| import asyncio |
| import json |
| import os |
| import time |
| import uuid |
| from contextlib import asynccontextmanager |
| from pathlib import Path |
|
|
| import psutil |
| import tiktoken |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import FileResponse, StreamingResponse |
| from pydantic import BaseModel, Field |
| from huggingface_hub import hf_hub_download |
|
|
| |
| |
| |
# Filesystem layout: the engine binary and the model artefacts are looked up
# in the directory that contains this source file.
BASE_DIR = Path(__file__).parent
INFERENCE_EXE = BASE_DIR.joinpath("inference")
MODEL_BIN = BASE_DIR.joinpath("model.bin")
TOKENIZER_BIN = BASE_DIR.joinpath("tokenizer.bin")
# Hugging Face repository the two .bin files are downloaded from when absent.
HF_REPO_ID = "NOT-OMEGA/KVInfer-152M"

# Plain-text role markers and the turn separator used to build prompts.
SYSTEM_TOKEN, USER_TOKEN, ASST_TOKEN = "System:", "User:", "Assistant:"
SEP = "\n"

# Token budget. BLOCK_SIZE is presumably the model's context length
# (TODO confirm against the engine); MAX_GEN_CEILING equals the upper bound
# accepted for ChatRequest.max_new_tokens.
BLOCK_SIZE = 1024
MAX_GEN_CEILING = 500
SAFETY_MARGIN = 24
# What is left for prompt/history once a maximal generation and the safety
# margin have been reserved.
MAX_SESSION_TOKENS = BLOCK_SIZE - (MAX_GEN_CEILING + SAFETY_MARGIN)
|
|
| |
| |
def _env_int(name: str, default: int, minimum: int = 1) -> int:
    """Read an integer setting from the environment, clamped to *minimum*.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is not set.
        minimum: Smallest acceptable value; anything lower is replaced by it.

    Returns:
        The parsed (and possibly clamped) integer.

    Raises:
        ValueError: If the variable is set to something that is not an integer.
    """
    value = int(os.environ.get(name, default))
    if value < minimum:
        # A zero/negative engine or thread count would leave the pool empty
        # or hand the engine an invalid OMP setting, so correct it loudly.
        print(f"[config] {name}={value} is below the minimum of {minimum}; using {minimum}")
        return minimum
    return value


# Number of inference subprocesses in the pool, and the OMP_NUM_THREADS value
# exported to each of them.
N_ENGINES = _env_int("N_ENGINES", 4)
OMP_PER_ENGINE = _env_int("OMP_NUM_THREADS", 1)

# Idle time (seconds) after which the session GC evicts a session.
SESSION_TTL_SECONDS = int(os.environ.get("SESSION_TTL", "1800"))
|
|
# GPT-2 byte-pair encoding; used to tokenize prompts and to turn generated
# token ids back into text.
enc = tiktoken.get_encoding("gpt2")
# Token ids forwarded to the engine as stop tokens (50256 is the id of
# <|endoftext|> in the gpt2 encoding).
STOP_TOKEN_IDS = [50256]
# Role markers that, when they show up at the start of a line in the generated
# text, mean the model has begun writing the next turn by itself.
STOP_STRINGS = [USER_TOKEN, SYSTEM_TOKEN]
|
|
|
|
| |
| |
| |
class InferenceEngine:
    """One `inference` child process, driven line-by-line over stdin/stdout.

    Wire protocol as used by this class (one text line per message):

    * the child prints ``READY`` when it accepts commands, or ``ERROR...``;
    * ``REQUEST|session|tokens_csv|max_new|temperature|top_k|stop_csv`` is
      answered by ``TOKEN <id> <elapsed_ms>`` lines followed by exactly one
      ``DONE <n_tokens> <total_ms>`` or ``ERROR...`` line;
    * ``RESET|session`` is answered by ``RESET_OK``;
    * ``QUIT`` asks the child to exit.

    The class does no locking of its own; callers must make sure only one
    command is in flight per engine at a time.
    """

    def __init__(self, eid: int):
        self.eid = eid
        self._proc = None
        self._ready = False
        # True from the moment a REQUEST is written until its terminating
        # DONE/ERROR line has been read. If a consumer abandons generate()
        # early, the flag stays set and the leftover output is discarded
        # before the next command, so replies never get mixed up.
        self._inflight = False

    @staticmethod
    def _is_safe_session_id(session_id) -> bool:
        """Return False if the id would break the '|' / newline framing."""
        text = str(session_id)
        return not any(ch in text for ch in "|\r\n")

    def _mark_dead(self):
        """Record that the child process is gone (EOF or broken pipe)."""
        self._ready = False
        self._inflight = False

    async def _read_line(self):
        """Return the next stdout line, decoded and stripped; None at EOF."""
        raw = await self._proc.stdout.readline()
        if not raw:
            return None
        return raw.decode("utf-8", errors="replace").strip()

    async def _drain_inflight(self):
        """Discard what is left of an abandoned request's output."""
        while self._inflight:
            line = await self._read_line()
            if line is None:
                self._mark_dead()
            elif line.startswith(("DONE", "ERROR")):
                self._inflight = False

    async def start(self):
        """Spawn the child process and wait until it reports READY.

        Raises:
            RuntimeError: If the binary or model file is missing, the child
                reports an ERROR line, or it exits before printing READY.
        """
        if not INFERENCE_EXE.exists():
            raise RuntimeError(f"inference binary not found: {INFERENCE_EXE}")
        if not MODEL_BIN.exists():
            raise RuntimeError(f"model.bin not found: {MODEL_BIN}")

        env = os.environ.copy()
        env["OMP_NUM_THREADS"] = str(OMP_PER_ENGINE)

        self._proc = await asyncio.create_subprocess_exec(
            str(INFERENCE_EXE),
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.DEVNULL,
            cwd=str(BASE_DIR),
            env=env,
        )

        while True:
            line = await self._read_line()
            if line is None:
                # EOF: without this check the loop would spin forever on
                # empty reads once the child has died.
                raise RuntimeError(f"Engine-{self.eid} exited before reporting READY")
            if line == "READY":
                self._ready = True
                print(f"[engine-{self.eid}] READY pid={self._proc.pid}")
                break
            if line.startswith("ERROR"):
                raise RuntimeError(f"Engine-{self.eid} startup error: {line}")

    async def stop(self):
        """Ask the child to quit; kill and reap it if it does not comply."""
        proc = self._proc
        if proc is None:
            return
        self._ready = False
        try:
            proc.stdin.write(b"QUIT\n")
            await proc.stdin.drain()
            await asyncio.wait_for(proc.wait(), timeout=3.0)
        except Exception:
            try:
                proc.kill()
            except ProcessLookupError:
                pass  # the process is already gone
            await proc.wait()  # reap it so no zombie is left behind

    async def reset_session(self, session_id: str):
        """Tell the engine to drop its state for *session_id*.

        Does nothing when the engine is not running or when the id could not
        have been sent to the engine in the first place (unsafe characters).
        """
        if not self._ready or self._proc is None:
            return
        if not self._is_safe_session_id(session_id):
            return
        await self._drain_inflight()
        if not self._ready:
            return
        try:
            self._proc.stdin.write(f"RESET|{session_id}\n".encode())
            await self._proc.stdin.drain()
        except (BrokenPipeError, ConnectionResetError):
            self._mark_dead()
            return
        while True:
            line = await self._read_line()
            if line is None:
                self._mark_dead()
                break
            if line == "RESET_OK":
                break

    async def generate(self, session_id, new_token_ids, max_new, temperature, top_k):
        """Send one REQUEST and stream the engine's answer.

        Yields dicts of one of these shapes:
            {"type": "token", "id", "text", "elapsed_ms"}
            {"type": "done", "total_tokens", "total_ms", "tps"}
            {"type": "error", "message"}
        """
        if not self._ready or self._proc is None:
            yield {"type": "error", "message": f"Engine-{self.eid} not ready"}
            return
        if not self._is_safe_session_id(session_id):
            # The id is pasted into a '|'-separated, newline-terminated
            # command; such characters would let a caller forge commands.
            yield {"type": "error", "message": "Invalid session id"}
            return

        await self._drain_inflight()
        if not self._ready:
            yield {"type": "error", "message": f"Engine-{self.eid} not ready"}
            return

        tokens_csv = ",".join(map(str, new_token_ids))
        stop_csv = ",".join(map(str, STOP_TOKEN_IDS))
        cmd = (
            f"REQUEST|{session_id}|{tokens_csv}|"
            f"{max_new}|{temperature}|{top_k}|{stop_csv}\n"
        )
        # Set before writing: once any byte is sent, output may be pending.
        self._inflight = True
        try:
            self._proc.stdin.write(cmd.encode())
            await self._proc.stdin.drain()
        except (BrokenPipeError, ConnectionResetError):
            self._mark_dead()
            yield {"type": "error", "message": f"Engine-{self.eid} exited unexpectedly"}
            return

        try:
            while True:
                line = await self._read_line()
                if line is None:
                    self._mark_dead()
                    yield {"type": "error",
                           "message": f"Engine-{self.eid} exited unexpectedly"}
                    break
                if not line:
                    continue

                if line.startswith("TOKEN"):
                    parts = line.split()
                    tid = int(parts[1])
                    ms = float(parts[2])
                    yield {"type": "token", "id": tid,
                           "text": enc.decode([tid]), "elapsed_ms": ms}

                elif line.startswith("DONE"):
                    # Clear the flag BEFORE yielding: the consumer may never
                    # resume us, and nothing is left in the pipe to drain.
                    self._inflight = False
                    parts = line.split()
                    total_t = int(parts[1])
                    total_ms = float(parts[2])
                    tps = round(total_t / (total_ms / 1000.0), 2) if total_ms > 0 else 0
                    yield {"type": "done", "total_tokens": total_t,
                           "total_ms": total_ms, "tps": tps}
                    break

                elif line.startswith("ERROR"):
                    self._inflight = False
                    yield {"type": "error", "message": line}
                    break

        except asyncio.CancelledError:
            # Cancelled while waiting for the engine: swallow the rest of the
            # answer now, while the caller still holds exclusive access.
            await self._drain_inflight()
            raise

    @property
    def pid(self):
        """PID of the child process, or None if it was never started."""
        return self._proc.pid if self._proc else None
|
|
|
|
| |
| |
| |
class EnginePool:
    """Fixed-size pool of InferenceEngine processes with sticky sessions.

    Each session id is pinned to one engine (the least loaded one at the time
    of its first request) and every engine is guarded by its own lock, so at
    most one command runs per engine while different engines work in parallel.
    """

    def __init__(self, n: int):
        self.n = n
        self.engines = [InferenceEngine(i) for i in range(n)]
        # Per-engine locks are created in start().
        self._locks: list[asyncio.Lock] = []
        self._session_map: dict[str, int] = {}
        # Sized here rather than in start() so status() is usable before the
        # engines have been started.
        self._engine_load: list[int] = [0] * n
        self._map_lock = asyncio.Lock()

    async def start(self):
        """Create the per-engine locks and start all engines concurrently."""
        self._locks = [asyncio.Lock() for _ in range(self.n)]
        await asyncio.gather(*(e.start() for e in self.engines))
        print(f"[pool] {self.n} engines ready (OMP_NUM_THREADS={OMP_PER_ENGINE} each)")

    async def stop(self):
        """Stop every engine; individual failures are collected, not raised."""
        await asyncio.gather(*(e.stop() for e in self.engines), return_exceptions=True)

    async def _assign_engine(self, session_id: str) -> int:
        """Return the engine index for *session_id*, pinning it on first use."""
        async with self._map_lock:
            if session_id not in self._session_map:
                # Prefer engines that actually came up; fall back to all of
                # them so the caller still gets the engine's own error chunk.
                candidates = [
                    i for i, eng in enumerate(self.engines) if eng._ready
                ] or list(range(self.n))
                idx = min(candidates, key=lambda i: self._engine_load[i])
                self._session_map[session_id] = idx
                self._engine_load[idx] += 1
                print(f"[pool] {session_id[:8]}… → engine-{idx} load={self._engine_load}")
            return self._session_map[session_id]

    async def _untrack_session(self, session_id: str):
        """Forget the session-to-engine pinning and release its load slot."""
        async with self._map_lock:
            if session_id in self._session_map:
                idx = self._session_map.pop(session_id)
                self._engine_load[idx] = max(0, self._engine_load[idx] - 1)

    async def generate(self, session_id, new_token_ids, max_new, temperature, top_k):
        """Stream chunks from the session's engine while holding its lock."""
        idx = await self._assign_engine(session_id)
        if idx >= len(self._locks):
            yield {"type": "error", "message": "Engine pool not started"}
            return
        async with self._locks[idx]:
            stream = self.engines[idx].generate(
                session_id, new_token_ids, max_new, temperature, top_k
            )
            try:
                async for chunk in stream:
                    yield chunk
            finally:
                # `async for` does not close its iterator when it is left
                # early; close it explicitly while the lock is still held.
                await stream.aclose()

    async def reset_session(self, session_id: str):
        """Clear the session on its engine (if pinned) and unpin it."""
        async with self._map_lock:
            idx = self._session_map.get(session_id)
        if idx is None:
            return
        try:
            engine = self.engines[idx]
            if idx < len(self._locks) and engine._ready:
                async with self._locks[idx]:
                    await engine.reset_session(session_id)
        finally:
            # Unpin even if the engine-side reset failed, otherwise the load
            # counter and the mapping would leak.
            await self._untrack_session(session_id)

    def get_all_pids(self) -> list:
        """PIDs of all engines that have a process."""
        return [e.pid for e in self.engines if e.pid]

    def status(self) -> list:
        """One status dict per engine; safe to call before start()."""
        return [
            {
                "engine_id": i,
                "pid": self.engines[i].pid,
                "sessions": self._engine_load[i],
                "busy": i < len(self._locks) and self._locks[i].locked(),
                "ready": self.engines[i]._ready,
            }
            for i in range(self.n)
        ]
|
|
|
|
# Process-wide engine pool; its engines are only started later, by
# _startup_background().
pool = EnginePool(n=N_ENGINES)
|
|
|
|
| |
| |
| |
class SessionData:
    """Python-side bookkeeping for one conversation.

    Attributes set in __init__:
        system_prompt: System text sent with the first turn.
        history: Chronological list of {"role", "content"} dicts.
        tokens_in_engine: Running token count maintained by the caller;
            0 means the next turn must include the system prompt.
        last_active: time.time() of the most recent touch().
    """

    def __init__(self, system_prompt: str):
        self.last_active = time.time()
        self.tokens_in_engine = 0
        self.history: list = []
        self.system_prompt = system_prompt

    def touch(self):
        """Mark the session as used right now."""
        self.last_active = time.time()

    def _record(self, role: str, content: str):
        """Append one message with the given role to the history."""
        self.history.append({"role": role, "content": content})

    def append_user(self, content: str):
        """Record a user message."""
        self._record("user", content)

    def append_assistant(self, content: str):
        """Record an assistant reply."""
        self._record("assistant", content)

    def new_turn_tokens(self, user_msg: str) -> list:
        """Tokenize the text to feed the engine for the next turn.

        The system prompt is prepended only while tokens_in_engine is 0.
        """
        turn = f"{USER_TOKEN} {user_msg}{SEP}{ASST_TOKEN} "
        if self.tokens_in_engine != 0:
            return enc.encode_ordinary(turn)
        preamble = f"{SYSTEM_TOKEN} {self.system_prompt}{SEP}"
        return enc.encode_ordinary(preamble + turn)
|
|
|
|
# Live sessions, keyed by session id.
sessions: dict[str, SessionData] = {}

# Process-wide counters reported by the /health and /metrics endpoints.
metrics = dict(
    total_requests=0,
    total_tokens=0,
    total_ms=0.0,
    errors=0,
    start_time=time.time(),
    sessions_evicted=0,
)
|
|
|
|
| |
| |
| |
def get_total_ram_mb() -> float:
    """Return the summed RSS of this process and all engine processes.

    The value is in units of 1e6 bytes, rounded to one decimal. Engine
    processes that no longer exist are skipped; any other failure yields 0.0.
    """

    def _rss_or_zero(pid) -> int:
        # An engine that has already exited simply contributes nothing.
        try:
            return psutil.Process(pid).memory_info().rss
        except psutil.NoSuchProcess:
            return 0

    try:
        own_rss = psutil.Process(os.getpid()).memory_info().rss
        engines_rss = sum(_rss_or_zero(pid) for pid in pool.get_all_pids())
        return round((own_rss + engines_rss) / 1e6, 1)
    except Exception:
        return 0.0
|
|
|
|
| |
| |
| |
async def session_gc_loop(interval_s: float = 300):
    """Background task: periodically evict sessions idle longer than the TTL.

    Args:
        interval_s: Seconds to sleep between passes (default: 5 minutes).

    Runs until cancelled. A failure while resetting one session is logged and
    does not stop the loop.
    """
    while True:
        await asyncio.sleep(interval_s)
        now = time.time()
        expired = [
            sid for sid, s in list(sessions.items())
            if now - s.last_active > SESSION_TTL_SECONDS
        ]
        evicted = 0
        for sid in expired:
            # Re-check: the session may have been used or removed while an
            # earlier reset in this pass was being awaited.
            current = sessions.get(sid)
            if current is None or time.time() - current.last_active <= SESSION_TTL_SECONDS:
                continue
            sessions.pop(sid, None)
            metrics["sessions_evicted"] += 1
            evicted += 1
            try:
                await pool.reset_session(sid)
            except Exception as exc:
                # One broken engine must not end garbage collection for good.
                print(f"[GC] reset failed for session {sid[:8]}: {exc}")
        if evicted:
            print(f"[GC] Evicted {evicted} idle sessions (TTL={SESSION_TTL_SECONDS}s)")
|
|
|
|
| |
| |
| |
| |
async def _startup_background():
    """Download missing model files, then start the engine pool.

    Meant to run as a background task so the server can accept connections
    while this is in progress. The blocking Hugging Face downloads run in a
    worker thread so the event loop keeps serving requests meanwhile.
    Failures are logged, never raised.
    """
    required = ((MODEL_BIN, "model.bin"), (TOKENIZER_BIN, "tokenizer.bin"))
    try:
        print("[HF HUB] Checking model files…")
        for path, filename in required:
            if path.exists():
                continue
            print(f"[HF HUB] Downloading {filename}…")
            # hf_hub_download is synchronous; calling it directly here would
            # stall every other coroutine for the whole download.
            await asyncio.to_thread(
                hf_hub_download,
                repo_id=HF_REPO_ID, filename=filename,
                local_dir=str(BASE_DIR),
            )
    except Exception as e:
        print(f"[WARNING] HF download failed: {e}")

    try:
        await pool.start()
    except Exception as e:
        print(f"[ERROR] Engine pool start failed: {e}")
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: launch background work, clean up on shutdown.

    Startup does not wait for the engines: model download / engine start and
    the session GC both run as background tasks. On shutdown the tasks are
    cancelled and the engine pool is stopped, even if shutdown was triggered
    by an exception.
    """
    # Keep strong references: the event loop only holds weak references to
    # tasks, so an unreferenced task can be garbage-collected mid-execution.
    background = [
        asyncio.create_task(_startup_background()),
        asyncio.create_task(session_gc_loop()),
    ]
    try:
        yield
    finally:
        for task in background:
            task.cancel()
        await asyncio.gather(*background, return_exceptions=True)
        await pool.stop()
|
|
|
|
# ASGI application. CORS is wide open: any origin, method and header is
# accepted.
app = FastAPI(title="KVInfer", version="4.0.0", lifespan=lifespan)

_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
| |
| |
| |
def _new_session_id() -> str:
    # Random UUID4, rendered as a string, for requests that carry no id.
    return str(uuid.uuid4())


# Body of POST /chat. Only `message` is required; a missing session_id gets a
# fresh random one. Documented with comments rather than a docstring so the
# generated API schema is unchanged.
class ChatRequest(BaseModel):
    message: str
    session_id: str = Field(default_factory=_new_session_id)
    system_prompt: str = "You are a helpful assistant."
    # Sampling controls; values outside the given ranges fail validation.
    max_new_tokens: int = Field(200, ge=1, le=500)
    temperature: float = Field(0.7, ge=0.01, le=2.0)
    top_k: int = Field(40, ge=1, le=200)
|
|
|
|
# Body of POST /chat/reset: names the session to discard.
class ResetRequest(BaseModel):
    session_id: str = Field(...)  # required, no default
|
|
|
|
| |
| |
| |
@app.get("/")
async def serve_ui():
    # Serve the single-page UI stored next to this file.
    index_page = BASE_DIR / "index.html"
    return FileResponse(index_page)
|
|
|
|
@app.get("/health")
async def health():
    # Liveness/readiness snapshot: "ok" as soon as at least one engine is
    # ready, "starting" otherwise, plus memory and uptime figures.
    ready_count = len([eng for eng in pool.engines if eng._ready])
    vm = psutil.virtual_memory()
    uptime = time.time() - metrics["start_time"]

    report = {
        "status": "starting" if ready_count == 0 else "ok",
        "engines_ready": ready_count,
        "engines_total": N_ENGINES,
        "omp_threads_per_engine": OMP_PER_ENGINE,
        "active_sessions": len(sessions),
        "sessions_evicted": metrics["sessions_evicted"],
        "process_ram_mb": get_total_ram_mb(),
        "system_ram_used_pct": vm.percent,
        "system_ram_total_gb": round(vm.total / 1e9, 1),
        "uptime_seconds": round(uptime, 1),
    }
    return report
|
|
|
|
@app.get("/pool/status")
async def pool_status():
    # Pool configuration together with the per-engine status list.
    engine_rows = pool.status()
    return dict(
        n_engines=N_ENGINES,
        omp_per_engine=OMP_PER_ENGINE,
        engines=engine_rows,
        total_sessions=len(sessions),
        session_ttl_s=SESSION_TTL_SECONDS,
    )
|
|
|
|
def _find_stop_marker(text: str, start: int = 0) -> int:
    """Return the index of the earliest "\\n<stop string>" at or after *start*, or -1."""
    hits = [pos for pos in (text.find(f"\n{s}", start) for s in STOP_STRINGS) if pos != -1]
    return min(hits, default=-1)


@app.post("/chat")
async def chat(req: ChatRequest):
    # Run one chat turn and stream the result as Server-Sent Events.
    #
    # Events: one `data: {json}` per generated token, then a "done" event
    # (with session_id and full_response) or an "error" event, then a final
    # `data: [DONE]` sentinel.
    #
    # Responds 503 while no engine is ready, 400 for a session id that cannot
    # be passed to the engine, 413 when the prompt cannot fit the context.
    if not any(e._ready for e in pool.engines):
        raise HTTPException(503, "Engines loading… please wait ~30s and retry.")

    # The id ends up inside the engine's '|'-separated, newline-terminated
    # command line, so these characters would let a client forge commands.
    if any(ch in req.session_id for ch in "|\r\n"):
        raise HTTPException(400, "session_id must not contain '|' or line breaks.")

    sess = sessions.get(req.session_id)
    if sess is None:
        sess = SessionData(req.system_prompt)
        sessions[req.session_id] = sess
    sess.touch()

    new_tokens = sess.new_turn_tokens(req.message)

    # Everything held by the engine after this turn must fit the context.
    # MAX_SESSION_TOKENS already has the generation ceiling subtracted, so
    # comparing "context + prompt + max_new_tokens" against it would count
    # the generation budget twice; compare against the full budget instead.
    context_budget = BLOCK_SIZE - SAFETY_MARGIN
    if sess.tokens_in_engine + len(new_tokens) + req.max_new_tokens > context_budget:
        await pool.reset_session(req.session_id)
        sess.tokens_in_engine = 0
        new_tokens = sess.new_turn_tokens(req.message)
        if len(new_tokens) + req.max_new_tokens > context_budget:
            # Even on a fresh context this single turn does not fit.
            raise HTTPException(
                413,
                f"Prompt too long: {len(new_tokens)} prompt tokens + "
                f"{req.max_new_tokens} max_new_tokens exceed {context_budget}.",
            )

    sess.append_user(req.message)
    metrics["total_requests"] += 1

    # Longest "\n<stop string>" marker; used to limit the re-scan window.
    marker_span = max((len(s) + 1 for s in STOP_STRINGS), default=0)

    async def event_stream():
        reply_text = ""
        stopped = False  # set once the model starts writing the next turn
        t0 = time.time()

        try:
            async for chunk in pool.generate(
                req.session_id, new_tokens,
                req.max_new_tokens, req.temperature, req.top_k,
            ):
                kind = chunk["type"]

                if kind == "token":
                    if stopped:
                        # Keep consuming so the final token count is right,
                        # but neither store nor forward the extra text.
                        continue
                    # A marker may straddle token boundaries: re-scan only
                    # the tail that the new piece could have completed.
                    scan_from = max(0, len(reply_text) - marker_span)
                    reply_text += chunk["text"]
                    cut = _find_stop_marker(reply_text, scan_from)
                    if cut != -1:
                        reply_text = reply_text[:cut]
                        stopped = True
                        continue
                    yield f"data: {json.dumps(chunk)}\n\n"

                elif kind == "done":
                    reply = reply_text.strip()
                    sess.append_assistant(reply)
                    sess.tokens_in_engine += len(new_tokens) + chunk["total_tokens"]
                    metrics["total_tokens"] += chunk["total_tokens"]
                    metrics["total_ms"] += (time.time() - t0) * 1000
                    yield f"data: {json.dumps({**chunk, 'session_id': req.session_id, 'full_response': reply})}\n\n"

                elif kind == "error":
                    metrics["errors"] += 1
                    yield f"data: {json.dumps(chunk)}\n\n"

        except Exception as e:
            metrics["errors"] += 1
            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"

        # Deliberately NOT in a `finally`: when the client disconnects the
        # generator is closed, and yielding during close raises RuntimeError.
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
|
|
|
@app.post("/chat/reset")
async def reset_chat(req: ResetRequest):
    # Forget the session on the Python side, then clear it in the pool.
    sid = req.session_id
    if sid in sessions:
        del sessions[sid]
    await pool.reset_session(sid)
    return dict(status="ok", session_id=sid)
|
|
|
|
@app.get("/chat/history")
async def get_history(session_id: str):
    # Return the stored transcript; unknown sessions yield an empty one.
    sess = sessions.get(session_id)
    if sess is None:
        return {"session_id": session_id, "turns": 0, "history": []}

    user_turns = len([msg for msg in sess.history if msg["role"] == "user"])
    idle_for = time.time() - sess.last_active
    return {
        "session_id": session_id,
        "turns": user_turns,
        "tokens_in_engine": sess.tokens_in_engine,
        "last_active_ago_s": round(idle_for, 1),
        "history": sess.history,
    }
|
|
|
|
@app.get("/metrics")
async def get_metrics():
    # Aggregate counters plus a snapshot of pool and memory state.
    requests_served = metrics["total_requests"]
    tokens_out = metrics["total_tokens"]
    elapsed_ms = metrics["total_ms"]

    # Average tokens per second over all completed generations.
    if elapsed_ms > 0:
        avg_tps = round(tokens_out / (elapsed_ms / 1000), 2)
    else:
        avg_tps = 0

    vm = psutil.virtual_memory()
    busy_engines = len([lk for lk in pool._locks if lk.locked()])

    return {
        "total_requests": requests_served,
        "total_tokens": tokens_out,
        "total_errors": metrics["errors"],
        "avg_tps": avg_tps,
        "active_sessions": len(sessions),
        "sessions_evicted_total": metrics["sessions_evicted"],
        "n_engines": N_ENGINES,
        "omp_per_engine": OMP_PER_ENGINE,
        "engines_busy": busy_engines,
        "process_ram_mb": get_total_ram_mb(),
        "system_ram_used_pct": vm.percent,
        "system_ram_total_gb": round(vm.total / 1e9, 1),
        "uptime_s": round(time.time() - metrics["start_time"], 1),
    }
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Hand uvicorn the app object instead of the "main:app" import string:
    # the string only works when this file is named main.py and makes uvicorn
    # import the module a second time (duplicating the pool and tokenizer).
    # An import string is only required for reload/workers, which are unused.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)