""" KVInfer — FastAPI Backend v4.0 (16GB Optimised) RAM Budget (HF 16GB Space): 4 engines × 580MB model = 2.32GB 4 engines × 20 sessions × 96MB = 7.68GB Python + OS overhead = ~1.00GB ────────────────────────────────────────── TOTAL ≈ 11.0GB ✓ (5GB headroom) CPU Budget (2 vCPU): OMP_NUM_THREADS = 1 per engine 4 engines × 1 thread = 4 threads on 2 vCPUs (light oversubscription, fine) ── SPEED MODE (single user fastest) ── Set env: N_ENGINES=1 OMP_NUM_THREADS=2 → Both cores for 1 engine → ~1.5–2× faster TPS per request → Trade-off: only 1 user at a time ── BALANCED MODE (default) ── N_ENGINES=4, OMP_NUM_THREADS=1 → 4 truly parallel users, CPU evenly shared Concurrency: 4 users → truly parallel 5th+ user → queue mein wait (drop nahi hoga) """ import asyncio import json import os import time import uuid from contextlib import asynccontextmanager from pathlib import Path import psutil import tiktoken from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, StreamingResponse from pydantic import BaseModel, Field from huggingface_hub import hf_hub_download # ───────────────────────────────────────────────────────────────────────── # Config # ───────────────────────────────────────────────────────────────────────── BASE_DIR = Path(__file__).parent INFERENCE_EXE = BASE_DIR / "inference" MODEL_BIN = BASE_DIR / "model.bin" TOKENIZER_BIN = BASE_DIR / "tokenizer.bin" HF_REPO_ID = "NOT-OMEGA/KVInfer-152M" SYSTEM_TOKEN = "System:" USER_TOKEN = "User:" ASST_TOKEN = "Assistant:" SEP = "\n" BLOCK_SIZE = 1024 MAX_GEN_CEILING = 500 SAFETY_MARGIN = 24 MAX_SESSION_TOKENS = BLOCK_SIZE - MAX_GEN_CEILING - SAFETY_MARGIN # ── 16GB mein 4 engines comfortably fit ── # Speed mode: N_ENGINES=1 OMP_NUM_THREADS=2 (env se set karo) N_ENGINES = int(os.environ.get("N_ENGINES", "4")) # was 3 OMP_PER_ENGINE = int(os.environ.get("OMP_NUM_THREADS", "1")) # tune here # Session expiry: idle sessions RAM free karti hain 
SESSION_TTL_SECONDS = int(os.environ.get("SESSION_TTL", "1800"))  # 30 min

enc = tiktoken.get_encoding("gpt2")
STOP_TOKEN_IDS = [50256]  # GPT-2 <|endoftext|>
STOP_STRINGS = ["User:", "System:"]  # guard against role bleed in replies


# ─────────────────────────────────────────────────────────────────────────
# Single Inference Engine (one C++ process)
# ─────────────────────────────────────────────────────────────────────────
class InferenceEngine:
    """Async wrapper around one C++ inference subprocess.

    Line protocol over stdin/stdout (inferred from this wrapper — the
    binary itself is not visible here):
      → ``RESET|<sid>``                          ← ``RESET_OK``
      → ``REQUEST|<sid>|<ids>|<max>|<t>|<k>|<stop>``
                         ← ``TOKEN <id> <ms>`` … ``DONE <n> <ms>`` / ``ERROR …``
      → ``QUIT``
    """

    def __init__(self, eid: int):
        self.eid = eid        # engine index, used for logging only
        self._proc = None     # asyncio subprocess handle (set by start())
        self._ready = False   # True once the binary has printed READY

    async def start(self):
        """Spawn the inference binary and wait for its READY handshake.

        Raises:
            RuntimeError: binary/model missing, the process reports ERROR,
                or it exits before becoming ready.
        """
        if not INFERENCE_EXE.exists():
            raise RuntimeError(f"inference binary not found: {INFERENCE_EXE}")
        if not MODEL_BIN.exists():
            raise RuntimeError(f"model.bin not found: {MODEL_BIN}")
        env = os.environ.copy()
        env["OMP_NUM_THREADS"] = str(OMP_PER_ENGINE)  # tunable via env
        self._proc = await asyncio.create_subprocess_exec(
            str(INFERENCE_EXE),
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.DEVNULL,
            cwd=str(BASE_DIR),
            env=env,
        )
        while True:
            raw = await self._proc.stdout.readline()
            if not raw:
                # FIX: on EOF (process died before READY) readline() returns
                # b"" immediately; the old code decoded it to "" — matching
                # neither branch — and busy-spun forever.
                raise RuntimeError(f"Engine-{self.eid} exited before READY")
            line = raw.decode().strip()
            if line == "READY":
                self._ready = True
                print(f"[engine-{self.eid}] READY pid={self._proc.pid}")
                break
            if line.startswith("ERROR"):
                raise RuntimeError(f"Engine-{self.eid} startup error: {line}")

    async def stop(self):
        """Politely ask the subprocess to QUIT; kill it after 3s."""
        if self._proc:
            try:
                self._proc.stdin.write(b"QUIT\n")
                await self._proc.stdin.drain()
                await asyncio.wait_for(self._proc.wait(), timeout=3.0)
            except Exception:
                # Best-effort shutdown: a dead pipe or timeout means the
                # process gets killed instead.
                self._proc.kill()

    async def reset_session(self, session_id: str):
        """Clear the engine-side KV cache for one session (blocks until ack)."""
        self._proc.stdin.write(f"RESET|{session_id}\n".encode())
        await self._proc.stdin.drain()
        while True:
            raw = await self._proc.stdout.readline()
            # EOF also terminates the wait so a dead engine cannot hang us.
            if not raw or raw.decode().strip() == "RESET_OK":
                break

    async def generate(self, session_id, new_token_ids, max_new, temperature, top_k):
        """Stream generation results for one request.

        Yields dicts of three shapes:
            {"type": "token", "id", "text", "elapsed_ms"}
            {"type": "done", "total_tokens", "total_ms", "tps"}
            {"type": "error", "message"}

        On cancellation (client disconnect) the engine's stdout is drained
        up to DONE/ERROR so the next request does not read stale output.
        """
        if not self._ready or self._proc is None:
            yield {"type": "error", "message": f"Engine-{self.eid} not ready"}
            return
        tokens_csv = ",".join(map(str, new_token_ids))
        stop_csv = ",".join(map(str, STOP_TOKEN_IDS))
        cmd = (
            f"REQUEST|{session_id}|{tokens_csv}|"
            f"{max_new}|{temperature}|{top_k}|{stop_csv}\n"
        )
        self._proc.stdin.write(cmd.encode())
        await self._proc.stdin.drain()
        try:
            while True:
                raw = await self._proc.stdout.readline()
                if not raw:
                    break  # engine died mid-stream
                line = raw.decode("utf-8", errors="replace").strip()
                if not line:
                    continue
                if line.startswith("TOKEN"):
                    parts = line.split()
                    tid = int(parts[1])
                    ms = float(parts[2])
                    yield {"type": "token", "id": tid, "text": enc.decode([tid]), "elapsed_ms": ms}
                elif line.startswith("DONE"):
                    parts = line.split()
                    total_t = int(parts[1])
                    total_ms = float(parts[2])
                    tps = round(total_t / (total_ms / 1000.0), 2) if total_ms > 0 else 0
                    yield {"type": "done", "total_tokens": total_t, "total_ms": total_ms, "tps": tps}
                    break
                elif line.startswith("ERROR"):
                    yield {"type": "error", "message": line}
                    break
        except asyncio.CancelledError:
            # Client disconnected — drain the pipe so the engine is not left
            # with unread output for the next request.
            while True:
                raw = await self._proc.stdout.readline()
                if not raw or raw.decode().strip().startswith(("DONE", "ERROR")):
                    break
            raise

    @property
    def pid(self):
        # None until start() has spawned the subprocess.
        return self._proc.pid if self._proc else None


# ─────────────────────────────────────────────────────────────────────────
# Engine Pool (session affinity + per-engine lock)
# ─────────────────────────────────────────────────────────────────────────
class EnginePool:
    """Routes sessions to engines (sticky, least-loaded) and serialises
    access to each engine with a per-engine asyncio.Lock."""

    def __init__(self, n: int):
        self.n = n
        self.engines = [InferenceEngine(i) for i in range(n)]
        self._locks: list[asyncio.Lock] = []          # one lock per engine
        self._session_map: dict[str, int] = {}        # session_id → engine idx
        self._engine_load: list[int] = []             # sessions per engine
        self._map_lock = asyncio.Lock()               # guards the two above

    async def start(self):
        """Create locks and start all engines in parallel."""
        self._locks = [asyncio.Lock() for _ in range(self.n)]
        self._engine_load = [0] * self.n
        await asyncio.gather(*(e.start() for e in self.engines))
        print(f"[pool] {self.n} engines ready (OMP_NUM_THREADS={OMP_PER_ENGINE} each)")

    async def stop(self):
        # return_exceptions: one failed shutdown must not abort the others.
        await asyncio.gather(*(e.stop() for e in self.engines), return_exceptions=True)

    async def _assign_engine(self, session_id: str) -> int:
        """Return the engine index for a session, assigning the currently
        least-loaded engine on first sight (sticky thereafter)."""
        async with self._map_lock:
            if session_id not in self._session_map:
                idx = min(range(self.n), key=lambda i: self._engine_load[i])
                self._session_map[session_id] = idx
                self._engine_load[idx] += 1
                print(f"[pool] {session_id[:8]}… → engine-{idx} load={self._engine_load}")
            return self._session_map[session_id]

    async def _untrack_session(self, session_id: str):
        """Drop the session→engine mapping and decrement that engine's load."""
        async with self._map_lock:
            if session_id in self._session_map:
                idx = self._session_map.pop(session_id)
                self._engine_load[idx] = max(0, self._engine_load[idx] - 1)

    async def generate(self, session_id, new_token_ids, max_new, temperature, top_k):
        """Stream generation via the session's engine, holding its lock for
        the full duration (requests to the same engine queue here)."""
        idx = await self._assign_engine(session_id)
        async with self._locks[idx]:
            async for chunk in self.engines[idx].generate(
                session_id, new_token_ids, max_new, temperature, top_k
            ):
                yield chunk

    async def reset_session(self, session_id: str):
        """Clear the session's engine-side cache and drop its mapping.

        FIX: the map lock is released before awaiting the per-engine lock.
        Previously the engine lock was awaited *while holding* _map_lock, so
        a long-running generation stalled every _assign_engine call, and
        _untrack_session then re-acquired the non-reentrant _map_lock —
        a deadlock.
        """
        async with self._map_lock:
            idx = self._session_map.get(session_id)
        if idx is not None:
            async with self._locks[idx]:
                await self.engines[idx].reset_session(session_id)
            await self._untrack_session(session_id)

    def get_all_pids(self) -> list:
        """PIDs of all spawned engine subprocesses (skips unstarted ones)."""
        return [e.pid for e in self.engines if e.pid]

    def status(self) -> list:
        """Per-engine snapshot for the /pool/status endpoint."""
        return [
            {
                "engine_id": i,
                "pid": self.engines[i].pid,
                "sessions": self._engine_load[i],
                "busy": self._locks[i].locked(),
                "ready": self.engines[i]._ready,
            }
            for i in range(self.n)
        ]


pool = EnginePool(N_ENGINES)


# ─────────────────────────────────────────────────────────────────────────
# Session State (TTL-based expiry for memory management)
# ─────────────────────────────────────────────────────────────────────────
class SessionData:
    """Server-side chat state for one session (history + token accounting)."""

    def __init__(self, system_prompt: str):
        self.system_prompt = system_prompt
        self.history: list = []          # [{"role": ..., "content": ...}, ...]
        self.tokens_in_engine = 0        # tokens currently in the engine's KV cache
        self.last_active = time.time()   # TTL tracking for the GC loop

    def touch(self):
        """Refresh the idle timer (called on every request)."""
        self.last_active = time.time()

    def append_user(self, content: str):
        self.history.append({"role": "user", "content": content})

    def append_assistant(self, content: str):
        self.history.append({"role": "assistant", "content": content})

    def new_turn_tokens(self, user_msg: str) -> list:
        """Tokenise the text the engine has NOT yet seen for this turn.

        A fresh session (empty KV cache) gets the system prompt prepended;
        otherwise only the new user turn + assistant cue is sent.
        """
        if self.tokens_in_engine == 0:
            full = (
                f"{SYSTEM_TOKEN} {self.system_prompt}{SEP}"
                f"{USER_TOKEN} {user_msg}{SEP}{ASST_TOKEN} "
            )
            return enc.encode_ordinary(full)
        else:
            return enc.encode_ordinary(f"{USER_TOKEN} {user_msg}{SEP}{ASST_TOKEN} ")


sessions: dict[str, SessionData] = {}
metrics = {
    "total_requests": 0,
    "total_tokens": 0,
    "total_ms": 0.0,
    "errors": 0,
    "start_time": time.time(),
    "sessions_evicted": 0,
}


# ─────────────────────────────────────────────────────────────────────────
# RAM Helper
# ─────────────────────────────────────────────────────────────────────────
def get_total_ram_mb() -> float:
    """Total RSS (MB) of this process plus all engine subprocesses.

    Returns 0.0 on any psutil failure — this is a metrics helper and must
    never take the server down.
    """
    try:
        total = psutil.Process(os.getpid()).memory_info().rss
        for pid in pool.get_all_pids():
            try:
                total += psutil.Process(pid).memory_info().rss
            except psutil.NoSuchProcess:
                pass  # engine died between listing and lookup
        return round(total / 1e6, 1)
    except Exception:
        return 0.0


# ─────────────────────────────────────────────────────────────────────────
# Session GC — free idle sessions from the engines (TTL-based)
# ─────────────────────────────────────────────────────────────────────────
async def session_gc_loop():
    """Background task: every 5 minutes evict sessions idle past the TTL."""
    while True:
        await asyncio.sleep(300)
        now = time.time()
        expired = [
            sid
            for sid, s in list(sessions.items())
            if now - s.last_active > SESSION_TTL_SECONDS
        ]
        for sid in expired:
            sessions.pop(sid, None)
            try:
                await pool.reset_session(sid)
            except Exception as e:
                # FIX: one failed reset must not kill the GC loop forever.
                print(f"[GC] reset failed for {sid[:8]}…: {e}")
            metrics["sessions_evicted"] += 1
        if expired:
            print(f"[GC] Evicted {len(expired)} idle sessions (TTL={SESSION_TTL_SECONDS}s)")


# ─────────────────────────────────────────────────────────────────────────
# Background startup — the port opens IMMEDIATELY; model download and
# engine start happen afterwards.
# (This fixes the HF Spaces "Starting..." stuck bug.)
# ─────────────────────────────────────────────────────────────────────────
async def _startup_background():
    """Download model files and start the engine pool in the background.

    Runs as a fire-and-forget task so the HTTP port is never blocked;
    all failures are logged rather than raised (the /health endpoint
    reports readiness separately).
    """
    try:
        print("[HF HUB] Checking model files…")
        if not MODEL_BIN.exists():
            print("[HF HUB] Downloading model.bin…")
            hf_hub_download(
                repo_id=HF_REPO_ID, filename="model.bin", local_dir=str(BASE_DIR)
            )
        if not TOKENIZER_BIN.exists():
            print("[HF HUB] Downloading tokenizer.bin…")
            hf_hub_download(
                repo_id=HF_REPO_ID, filename="tokenizer.bin", local_dir=str(BASE_DIR)
            )
    except Exception as e:
        print(f"[WARNING] HF download failed: {e}")
    try:
        await pool.start()
    except Exception as e:
        print(f"[ERROR] Engine pool start failed: {e}")


# FIX: asyncio holds only weak references to tasks, so a create_task()
# result that is not stored anywhere may be garbage-collected mid-flight.
# Keep strong references for the lifetime of the app.
_background_tasks: set = set()


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Port becomes available immediately — the HF health check passes while
    # the model download / engine start continues in the background.
    startup_task = asyncio.create_task(_startup_background())
    gc_task = asyncio.create_task(session_gc_loop())
    _background_tasks.update({startup_task, gc_task})
    yield
    # FIX: cancel the infinite GC loop on shutdown instead of leaking it.
    gc_task.cancel()
    await pool.stop()


app = FastAPI(title="KVInfer", version="4.0.0", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ─────────────────────────────────────────────────────────────────────────
# Request Models
# ─────────────────────────────────────────────────────────────────────────
class ChatRequest(BaseModel):
    message: str
    # A fresh UUID is generated when the client does not supply a session id.
    session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    system_prompt: str = "You are a helpful assistant."
    max_new_tokens: int = Field(default=200, ge=1, le=500)
    temperature: float = Field(default=0.7, ge=0.01, le=2.0)
    top_k: int = Field(default=40, ge=1, le=200)


class ResetRequest(BaseModel):
    # Session to clear — both server-side history and the engine KV cache.
    session_id: str


# ─────────────────────────────────────────────────────────────────────────
# Routes
# ─────────────────────────────────────────────────────────────────────────
@app.get("/")
async def serve_ui():
    """Serve the single-page UI shipped next to this file."""
    return FileResponse(BASE_DIR / "index.html")


@app.get("/health")
async def health():
    """Liveness/readiness probe: reports "starting" until at least one
    engine has completed its READY handshake."""
    mem = psutil.virtual_memory()
    uptime = time.time() - metrics["start_time"]
    ready_count = sum(1 for e in pool.engines if e._ready)
    return {
        "status": "ok" if ready_count > 0 else "starting",
        "engines_ready": ready_count,
        "engines_total": N_ENGINES,
        "omp_threads_per_engine": OMP_PER_ENGINE,
        "active_sessions": len(sessions),
        "sessions_evicted": metrics["sessions_evicted"],
        "process_ram_mb": get_total_ram_mb(),
        "system_ram_used_pct": mem.percent,
        "system_ram_total_gb": round(mem.total / 1e9, 1),
        "uptime_seconds": round(uptime, 1),
    }


@app.get("/pool/status")
async def pool_status():
    """Per-engine load/busy/ready snapshot plus pool-wide config."""
    return {
        "n_engines": N_ENGINES,
        "omp_per_engine": OMP_PER_ENGINE,
        "engines": pool.status(),
        "total_sessions": len(sessions),
        "session_ttl_s": SESSION_TTL_SECONDS,
    }


@app.post("/chat")
async def chat(req: ChatRequest):
    """Run one chat turn and stream the reply as Server-Sent Events.

    Flow: look up / create the session, tokenise only the unseen part of
    the turn, reset the session if the context window would overflow, then
    stream token/done/error events from the engine pool.
    """
    if not any(e._ready for e in pool.engines):
        raise HTTPException(503, "Engines loading… please wait ~30s and retry.")
    sess = sessions.get(req.session_id)
    if sess is None:
        # NOTE(review): system_prompt is only honoured on session creation;
        # later requests with a different prompt keep the original one.
        sess = SessionData(req.system_prompt)
        sessions[req.session_id] = sess
    sess.touch()  # TTL reset on every request
    new_tokens = sess.new_turn_tokens(req.message)
    # Context window overflow check: reset the engine-side cache and
    # re-tokenise (the fresh session now includes the system prompt again).
    if sess.tokens_in_engine + len(new_tokens) + req.max_new_tokens > MAX_SESSION_TOKENS:
        await pool.reset_session(req.session_id)
        sess.tokens_in_engine = 0
        new_tokens = sess.new_turn_tokens(req.message)
    sess.append_user(req.message)
    metrics["total_requests"] += 1

    async def event_stream():
        # SSE generator. Deliberately keeps consuming the engine stream even
        # after a stop string is seen: breaking out early would leave unread
        # output in the engine's stdout pipe (see InferenceEngine.generate).
        response_parts: list[str] = []
        t0 = time.time()
        try:
            async for chunk in pool.generate(
                req.session_id,
                new_tokens,
                req.max_new_tokens,
                req.temperature,
                req.top_k,
            ):
                if chunk["type"] == "token":
                    response_parts.append(chunk["text"])
                    joined = "".join(response_parts)
                    # Stop role bleed: truncate the recorded reply at the
                    # first "\nUser:" / "\nSystem:".
                    # NOTE(review): truncation picks the first STOP_STRINGS
                    # entry found in list order, not the earliest position in
                    # the text, and the raw token is still streamed to the
                    # client below — only the stored reply is truncated.
                    hit_stop = any(f"\n{s}" in joined for s in STOP_STRINGS)
                    if hit_stop:
                        for s in STOP_STRINGS:
                            idx = joined.find(f"\n{s}")
                            if idx != -1:
                                response_parts = [joined[:idx]]
                                break
                    yield f"data: {json.dumps(chunk)}\n\n"
                elif chunk["type"] == "done":
                    # Persist the (possibly truncated) reply and update the
                    # session's KV-cache token count + global metrics.
                    reply = "".join(response_parts).strip()
                    sess.append_assistant(reply)
                    sess.tokens_in_engine += len(new_tokens) + chunk["total_tokens"]
                    metrics["total_tokens"] += chunk["total_tokens"]
                    metrics["total_ms"] += (time.time() - t0) * 1000
                    yield f"data: {json.dumps({**chunk, 'session_id': req.session_id, 'full_response': reply})}\n\n"
                elif chunk["type"] == "error":
                    metrics["errors"] += 1
                    yield f"data: {json.dumps(chunk)}\n\n"
        except Exception as e:
            metrics["errors"] += 1
            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
        finally:
            # Always terminate the SSE stream so clients can stop reading.
            yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        # X-Accel-Buffering: no → disable proxy buffering so tokens stream.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@app.post("/chat/reset")
async def reset_chat(req: ResetRequest):
    """Drop server-side history and clear the engine-side KV cache."""
    sessions.pop(req.session_id, None)
    await pool.reset_session(req.session_id)
    return {"status": "ok", "session_id": req.session_id}


@app.get("/chat/history")
async def get_history(session_id: str):
    """Return the stored conversation; empty payload for unknown sessions."""
    sess = sessions.get(session_id)
    if not sess:
        return {"session_id": session_id, "turns": 0, "history": []}
    return {
        "session_id": session_id,
        "turns": sum(1 for m in sess.history if m["role"] == "user"),
        "tokens_in_engine": sess.tokens_in_engine,
        "last_active_ago_s": round(time.time() - sess.last_active, 1),
        "history": sess.history,
    }


@app.get("/metrics")
async def get_metrics():
    """Aggregate server metrics (requests, tokens, TPS, RAM, engine busy)."""
    n, tok, ms = (
        metrics["total_requests"],
        metrics["total_tokens"],
        metrics["total_ms"],
    )
    mem = psutil.virtual_memory()
    return {
        "total_requests": n,
        "total_tokens": tok,
        "total_errors": metrics["errors"],
        # avg TPS over the whole process lifetime, not per request.
        "avg_tps": round(tok / (ms / 1000), 2) if ms > 0 else 0,
        "active_sessions": len(sessions),
        "sessions_evicted_total": metrics["sessions_evicted"],
        "n_engines": N_ENGINES,
        "omp_per_engine": OMP_PER_ENGINE,
        "engines_busy": sum(1 for lk in pool._locks if lk.locked()),
        "process_ram_mb": get_total_ram_mb(),
        "system_ram_used_pct": mem.percent,
        "system_ram_total_gb": round(mem.total / 1e9, 1),
        "uptime_s": round(time.time() - metrics["start_time"], 1),
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)