# KVInfer / main.py
# NOT-OMEGA's picture
# Update main.py
# f523258 verified
"""
KVInfer β€” FastAPI Backend v4.0 (16GB Optimised)
RAM Budget (HF 16GB Space):
4 engines Γ— 580MB model = 2.32GB
4 engines Γ— 20 sessions Γ— 96MB = 7.68GB
Python + OS overhead = ~1.00GB
──────────────────────────────────────────
TOTAL β‰ˆ 11.0GB βœ“ (5GB headroom)
CPU Budget (2 vCPU):
OMP_NUM_THREADS = 1 per engine
4 engines Γ— 1 thread = 4 threads on 2 vCPUs (light oversubscription, fine)
── SPEED MODE (single user fastest) ──
Set env: N_ENGINES=1 OMP_NUM_THREADS=2
β†’ Both cores for 1 engine β†’ ~1.5–2Γ— faster TPS per request
β†’ Trade-off: only 1 user at a time
── BALANCED MODE (default) ──
N_ENGINES=4, OMP_NUM_THREADS=1
β†’ 4 truly parallel users, CPU evenly shared
Concurrency:
4 users β†’ truly parallel
5th+ user β†’ queue mein wait (drop nahi hoga)
"""
import asyncio
import json
import os
import time
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
import psutil
import tiktoken
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel, Field
from huggingface_hub import hf_hub_download
# ─────────────────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────────────────
BASE_DIR = Path(__file__).parent
INFERENCE_EXE = BASE_DIR / "inference"
MODEL_BIN = BASE_DIR / "model.bin"
TOKENIZER_BIN = BASE_DIR / "tokenizer.bin"
HF_REPO_ID = "NOT-OMEGA/KVInfer-152M"

# Role markers of the plain-text chat prompt format.
SYSTEM_TOKEN = "System:"
USER_TOKEN = "User:"
ASST_TOKEN = "Assistant:"
SEP = "\n"

# Context budget, in tokens.
BLOCK_SIZE = 1024  # context window; assumed to match the C++ engine — TODO confirm
MAX_GEN_CEILING = 500  # largest max_new_tokens a request may ask for
SAFETY_MARGIN = 24
# Prompt-side budget left when a request generates the full MAX_GEN_CEILING.
MAX_SESSION_TOKENS = BLOCK_SIZE - MAX_GEN_CEILING - SAFETY_MARGIN

# Process layout: 16GB comfortably fits 4 engines.
# Speed mode: set N_ENGINES=1 OMP_NUM_THREADS=2 in the environment.
# Both are clamped to >= 1: a pool with zero engines can never serve a request,
# and zero OpenMP threads is not a usable setting.
N_ENGINES = max(1, int(os.environ.get("N_ENGINES", "4")))
OMP_PER_ENGINE = max(1, int(os.environ.get("OMP_NUM_THREADS", "1")))

# Sessions idle longer than this are evicted so their engine RAM is freed.
SESSION_TTL_SECONDS = int(os.environ.get("SESSION_TTL", "1800"))  # 30 min

enc = tiktoken.get_encoding("gpt2")
STOP_TOKEN_IDS = [50256]  # GPT-2 "<|endoftext|>" token id
# Role markers that mean the model has started writing another participant's turn.
STOP_STRINGS = [USER_TOKEN, SYSTEM_TOKEN]
# ─────────────────────────────────────────────────────────────────────────
# Single Inference Engine (one C++ subprocess)
# ─────────────────────────────────────────────────────────────────────────
class InferenceEngine:
    """Wrapper around one C++ ``inference`` subprocess that speaks a line protocol.

    Commands are written to the child's stdin, one per line:
      * ``REQUEST|<session>|<csv token ids>|<max_new>|<temperature>|<top_k>|<csv stop ids>``
      * ``RESET|<session>``
      * ``QUIT``
    Replies are read from stdout: ``READY`` once at startup; per request a run of
    ``TOKEN <id> <ms>`` lines ended by ``DONE <n_tokens> <ms>`` or an ``ERROR ...``
    line; ``RESET_OK`` after a reset.

    There is a single stdout stream, so calls must not overlap; ``EnginePool``
    serializes them with a per-engine lock.
    """

    # Characters that would break the '|'-separated, newline-terminated command framing.
    _FORBIDDEN_ID_CHARS = frozenset("|\r\n")

    def __init__(self, eid: int):
        self.eid = eid
        self._proc = None
        self._ready = False
        # Background task consuming the rest of a request whose consumer went away.
        # Every later stdout reader awaits it first so it never sees stale lines.
        self._pending_drain = None

    async def start(self):
        """Spawn the subprocess and wait for its ``READY`` line.

        Raises:
            RuntimeError: a required file is missing, the engine printed an
                ``ERROR`` line, or it closed stdout before signalling ``READY``.
        """
        if not INFERENCE_EXE.exists():
            raise RuntimeError(f"inference binary not found: {INFERENCE_EXE}")
        if not MODEL_BIN.exists():
            raise RuntimeError(f"model.bin not found: {MODEL_BIN}")
        env = os.environ.copy()
        env["OMP_NUM_THREADS"] = str(OMP_PER_ENGINE)  # tunable via env
        self._proc = await asyncio.create_subprocess_exec(
            str(INFERENCE_EXE),
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.DEVNULL,
            cwd=str(BASE_DIR),
            env=env,
        )
        while True:
            raw = await self._proc.stdout.readline()
            if not raw:
                # At EOF readline() returns b"" without ever suspending, so looping
                # on it would spin and starve the whole event loop.
                raise RuntimeError(f"Engine-{self.eid} exited before signalling READY")
            line = raw.decode("utf-8", errors="replace").strip()
            if line == "READY":
                self._ready = True
                print(f"[engine-{self.eid}] READY pid={self._proc.pid}")
                break
            if line.startswith("ERROR"):
                raise RuntimeError(f"Engine-{self.eid} startup error: {line}")

    async def stop(self):
        """Ask the engine to quit; kill it if that fails or takes longer than 3 s."""
        if self._proc is None:
            return
        self._ready = False
        try:
            self._proc.stdin.write(b"QUIT\n")
            await self._proc.stdin.drain()
            await asyncio.wait_for(self._proc.wait(), timeout=3.0)
        except Exception:
            if self._proc.returncode is None:
                try:
                    self._proc.kill()
                except ProcessLookupError:
                    pass  # exited between the returncode check and kill()

    async def _drain_stdout(self):
        """Discard stdout lines up to and including the next ``DONE``/``ERROR`` (or EOF)."""
        while True:
            raw = await self._proc.stdout.readline()
            if not raw:
                self._ready = False  # stdout closed: the process is gone
                return
            if raw.decode("utf-8", errors="replace").strip().startswith(("DONE", "ERROR")):
                return

    async def _await_pending_drain(self):
        """Wait for a drain left behind by an abandoned request, if there is one."""
        task = self._pending_drain
        if task is None:
            return
        try:
            # shield(): if *this* caller is cancelled, the drain keeps running and
            # stays registered, so the next caller waits for it instead.
            await asyncio.shield(task)
        except Exception as exc:
            print(f"[engine-{self.eid}] stdout drain failed: {exc!r}")
            self._ready = False  # stream position unknown; stop using this engine
        self._pending_drain = None

    async def reset_session(self, session_id: str):
        """Send ``RESET|<session_id>`` and wait for ``RESET_OK`` (or EOF).

        A no-op when the engine is not running, or when the id contains characters
        that ``generate`` refuses (such a session can never exist in the engine).
        """
        if not self._ready or self._proc is None:
            return
        if self._FORBIDDEN_ID_CHARS.intersection(session_id):
            return
        await self._await_pending_drain()
        if not self._ready:
            return
        self._proc.stdin.write(f"RESET|{session_id}\n".encode())
        await self._proc.stdin.drain()
        while True:
            raw = await self._proc.stdout.readline()
            if not raw:
                self._ready = False
                break
            if raw.decode("utf-8", errors="replace").strip() == "RESET_OK":
                break

    async def generate(self, session_id, new_token_ids, max_new, temperature, top_k):
        """Run one ``REQUEST`` and yield event dicts as the engine's lines arrive.

        Yields dicts whose ``type`` is:
          * ``"token"`` — with ``id``, decoded ``text`` and ``elapsed_ms``;
          * ``"done"`` — with ``total_tokens``, ``total_ms`` and ``tps``;
          * ``"error"`` — with ``message``.
        The stream ends after ``done`` or ``error``. If the consumer stops early
        (client disconnect, cancellation, an exception), the remainder of this
        request's output is drained in a background task so the next command on
        this engine starts from a clean stdout.
        """
        if not self._ready or self._proc is None:
            yield {"type": "error", "message": f"Engine-{self.eid} not ready"}
            return
        if self._FORBIDDEN_ID_CHARS.intersection(session_id):
            # The id is interpolated into the command line; '|' or a line break
            # would let a caller forge extra fields or whole extra commands.
            yield {"type": "error",
                   "message": "Invalid session_id: must not contain '|' or line breaks"}
            return
        await self._await_pending_drain()
        if not self._ready:
            yield {"type": "error", "message": f"Engine-{self.eid} not ready"}
            return

        tokens_csv = ",".join(map(str, new_token_ids))
        stop_csv = ",".join(map(str, STOP_TOKEN_IDS))
        cmd = (
            f"REQUEST|{session_id}|{tokens_csv}|"
            f"{max_new}|{temperature}|{top_k}|{stop_csv}\n"
        )
        finished = False  # True once this request's terminal line has been consumed
        try:
            self._proc.stdin.write(cmd.encode())
            await self._proc.stdin.drain()
            while True:
                raw = await self._proc.stdout.readline()
                if not raw:
                    finished = True
                    self._ready = False
                    yield {"type": "error",
                           "message": f"Engine-{self.eid} exited unexpectedly"}
                    break
                line = raw.decode("utf-8", errors="replace").strip()
                if not line:
                    continue
                if line.startswith("TOKEN"):
                    parts = line.split()
                    tid = int(parts[1])
                    ms = float(parts[2])
                    yield {"type": "token", "id": tid,
                           "text": enc.decode([tid]), "elapsed_ms": ms}
                elif line.startswith("DONE"):
                    finished = True  # set before parsing: a malformed DONE must not be re-drained
                    parts = line.split()
                    total_t = int(parts[1])
                    total_ms = float(parts[2])
                    tps = round(total_t / (total_ms / 1000.0), 2) if total_ms > 0 else 0
                    yield {"type": "done", "total_tokens": total_t,
                           "total_ms": total_ms, "tps": tps}
                    break
                elif line.startswith("ERROR"):
                    finished = True
                    yield {"type": "error", "message": line}
                    break
        finally:
            if not finished:
                # Runs on cancellation/GeneratorExit too. The drain is a separate task
                # because awaiting here could itself be cancelled again (e.g. by an
                # AnyIO cancel scope), leaving unread lines for the next request.
                self._pending_drain = asyncio.ensure_future(self._drain_stdout())

    @property
    def pid(self):
        """PID of the subprocess, or None if it was never started."""
        return self._proc.pid if self._proc else None
# ─────────────────────────────────────────────────────────────────────────
# Engine Pool (session affinity + per-engine lock)
# ─────────────────────────────────────────────────────────────────────────
class EnginePool:
    """Fixed set of ``InferenceEngine`` processes with sticky session routing.

    A session id is pinned to one engine for as long as it is tracked: after the
    first turn only the new turn's tokens are sent, so the engine holds the rest
    of the context. New sessions go to the least-loaded *ready* engine. A
    per-engine ``asyncio.Lock`` keeps at most one command in flight per subprocess.
    """

    def __init__(self, n: int):
        self.n = n
        self.engines = [InferenceEngine(i) for i in range(n)]
        # Engine locks are created in start().
        self._locks: list[asyncio.Lock] = []
        self._session_map: dict[str, int] = {}
        # Sized up front (plain ints) so status()/routing never index an empty list.
        self._engine_load: list[int] = [0] * n
        self._map_lock = asyncio.Lock()

    async def start(self):
        """Start every engine concurrently.

        Engines that fail are logged and stay not-ready; the rest keep serving.

        Raises:
            RuntimeError: if no engine could be started.
        """
        self._locks = [asyncio.Lock() for _ in range(self.n)]
        results = await asyncio.gather(
            *(e.start() for e in self.engines), return_exceptions=True
        )
        failures = 0
        for engine, result in zip(self.engines, results):
            if isinstance(result, BaseException):
                failures += 1
                print(f"[pool] engine-{engine.eid} failed to start: {result}")
        ready = self.n - failures
        print(f"[pool] {ready}/{self.n} engines ready (OMP_NUM_THREADS={OMP_PER_ENGINE} each)")
        if ready == 0:
            raise RuntimeError(f"no inference engine could be started ({self.n} configured)")

    async def stop(self):
        """Stop all engines; individual failures are ignored."""
        await asyncio.gather(*(e.stop() for e in self.engines), return_exceptions=True)

    async def _assign_engine(self, session_id: str) -> int:
        """Return the engine index for a session, pinning new sessions first."""
        async with self._map_lock:
            idx = self._session_map.get(session_id)
            if idx is None:
                # Prefer engines that are up; fall back to all of them so routing
                # still returns an index (generate() then reports "not ready").
                ready = [i for i, e in enumerate(self.engines) if e._ready]
                candidates = ready or range(self.n)
                idx = min(candidates, key=self._engine_load.__getitem__)
                self._session_map[session_id] = idx
                self._engine_load[idx] += 1
                print(f"[pool] {session_id[:8]}… → engine-{idx} load={self._engine_load}")
            return idx

    async def _untrack_session(self, session_id: str):
        """Forget a session's engine pinning and release its load slot."""
        async with self._map_lock:
            if session_id in self._session_map:
                idx = self._session_map.pop(session_id)
                self._engine_load[idx] = max(0, self._engine_load[idx] - 1)

    async def generate(self, session_id, new_token_ids, max_new, temperature, top_k):
        """Route a request to the session's engine and relay its events under that engine's lock."""
        idx = await self._assign_engine(session_id)
        async with self._locks[idx]:
            async for chunk in self.engines[idx].generate(
                session_id, new_token_ids, max_new, temperature, top_k
            ):
                yield chunk

    async def reset_session(self, session_id: str):
        """Reset a tracked session inside its engine, then untrack it (no-op if unknown)."""
        async with self._map_lock:
            idx = self._session_map.get(session_id)
        if idx is not None:
            async with self._locks[idx]:
                await self.engines[idx].reset_session(session_id)
            await self._untrack_session(session_id)

    def get_all_pids(self) -> list:
        """PIDs of every started engine subprocess."""
        return [e.pid for e in self.engines if e.pid]

    def status(self) -> list:
        """Per-engine snapshot: pid, pinned sessions, busy flag, readiness."""
        return [
            {
                "engine_id": i,
                "pid": engine.pid,
                "sessions": self._engine_load[i],
                # Locks only exist after start(); before that nothing can be busy.
                "busy": bool(self._locks) and self._locks[i].locked(),
                "ready": engine._ready,
            }
            for i, engine in enumerate(self.engines)
        ]
# The single process-wide pool; its engines are launched by _startup_background().
pool = EnginePool(n=N_ENGINES)
# ─────────────────────────────────────────────────────────────────────────
# Session State (TTL-based expiry for memory management)
# ─────────────────────────────────────────────────────────────────────────
class SessionData:
    """Server-side state for one chat session.

    Attributes (set in ``__init__``):
        system_prompt: system text placed at the start of a freshly built prompt.
        history: ``{"role": ..., "content": ...}`` dicts, oldest first.
        tokens_in_engine: tokens believed to be held by the engine for this
            session; 0 means the next turn must resend the system line.
        last_active: ``time.time()`` of the latest request (for TTL eviction).
    """

    def __init__(self, system_prompt: str):
        self.system_prompt = system_prompt
        self.history: list = []
        self.tokens_in_engine = 0
        self.last_active = time.time()

    def touch(self):
        """Record activity now, restarting the idle TTL."""
        self.last_active = time.time()

    def _append(self, role, content):
        # Shared by the two public appenders below.
        self.history.append({"role": role, "content": content})

    def append_user(self, content: str):
        """Add a user message to the history."""
        self._append("user", content)

    def append_assistant(self, content: str):
        """Add an assistant reply to the history."""
        self._append("assistant", content)

    def new_turn_tokens(self, user_msg: str) -> list:
        """Token ids to send for the next user turn.

        The system line is prepended only when the engine holds nothing yet for
        this session; otherwise just the new user/assistant turn is encoded.
        """
        turn = f"{USER_TOKEN} {user_msg}{SEP}{ASST_TOKEN} "
        if self.tokens_in_engine:
            return enc.encode_ordinary(turn)
        return enc.encode_ordinary(f"{SYSTEM_TOKEN} {self.system_prompt}{SEP}" + turn)
# Live chat sessions keyed by session_id (a plain in-process dict).
sessions: dict[str, SessionData] = {}

# Process-wide counters reported by /health and /metrics.
metrics = dict(
    total_requests=0,
    total_tokens=0,
    total_ms=0.0,
    errors=0,
    start_time=time.time(),
    sessions_evicted=0,
)
# ─────────────────────────────────────────────────────────────────────────
# RAM Helper
# ─────────────────────────────────────────────────────────────────────────
def get_total_ram_mb() -> float:
    """Resident memory of this server plus every engine subprocess, in MB (10**6 bytes).

    Engines that have already exited are skipped; any other failure returns 0.0.
    """
    def _rss(pid):
        return psutil.Process(pid).memory_info().rss

    try:
        rss_bytes = _rss(os.getpid())
        for engine_pid in pool.get_all_pids():
            try:
                rss_bytes += _rss(engine_pid)
            except psutil.NoSuchProcess:
                continue
        return round(rss_bytes / 1e6, 1)
    except Exception:
        return 0.0
# ─────────────────────────────────────────────────────────────────────────
# Session GC — free idle sessions from their engines (TTL-based)
# ─────────────────────────────────────────────────────────────────────────
async def session_gc_loop():
    """Background task: every 5 minutes, evict sessions idle for over SESSION_TTL_SECONDS.

    Eviction drops the server-side SessionData and resets the session in its
    engine. A failure while resetting one session is logged and skipped, so the
    loop keeps running (an unhandled exception would end this task for good).
    """
    while True:
        await asyncio.sleep(300)
        now = time.time()
        expired = [
            sid for sid, s in sessions.items()
            if now - s.last_active > SESSION_TTL_SECONDS
        ]
        evicted = 0
        for sid in expired:
            sess = sessions.get(sid)
            # Re-check: earlier evictions awaited, so this session may have been
            # used or reset by the user in the meantime.
            if sess is None or time.time() - sess.last_active <= SESSION_TTL_SECONDS:
                continue
            sessions.pop(sid, None)
            try:
                await pool.reset_session(sid)
            except Exception as exc:
                print(f"[GC] Failed to reset session {sid[:8]}…: {exc!r}")
            evicted += 1
            metrics["sessions_evicted"] += 1
        if evicted:
            print(f"[GC] Evicted {evicted} idle sessions (TTL={SESSION_TTL_SECONDS}s)")
# ─────────────────────────────────────────────────────────────────────────
# Background startup — the port opens IMMEDIATELY; model download and engine start happen afterwards
# (this fixes the HF Space getting stuck at "Starting...")
# ─────────────────────────────────────────────────────────────────────────
async def _startup_background():
    """Download missing model files, then start the engine pool.

    Runs as a background task so the server accepts requests right away. The
    blocking ``hf_hub_download`` calls run in a worker thread: called directly
    they would block the event loop, and with it /health, for the whole download.
    Failures are logged rather than raised.
    """
    try:
        print("[HF HUB] Checking model files…")
        # File names on the Hub match the local file names.
        for path in (MODEL_BIN, TOKENIZER_BIN):
            if not path.exists():
                print(f"[HF HUB] Downloading {path.name}…")
                await asyncio.to_thread(
                    hf_hub_download,
                    repo_id=HF_REPO_ID,
                    filename=path.name,
                    local_dir=str(BASE_DIR),
                )
    except Exception as e:
        print(f"[WARNING] HF download failed: {e}")
    try:
        await pool.start()
    except Exception as e:
        print(f"[ERROR] Engine pool start failed: {e}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Launch background work at startup; cancel it and stop the engines at shutdown.

    Startup work runs in tasks so the port is available immediately (HF health
    checks pass while the model downloads). The task objects are kept because the
    event loop only holds weak references to tasks, so an unreferenced task can be
    garbage-collected before it finishes.
    """
    background = [
        asyncio.create_task(_startup_background()),
        asyncio.create_task(session_gc_loop()),
    ]
    try:
        yield
    finally:
        for task in background:
            task.cancel()
        await asyncio.gather(*background, return_exceptions=True)
        await pool.stop()
app = FastAPI(
    title="KVInfer",
    version="4.0.0",
    lifespan=lifespan,
)
# Fully open CORS: any origin, method and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ─────────────────────────────────────────────────────────────────────────
# Request Models
# ─────────────────────────────────────────────────────────────────────────
class ChatRequest(BaseModel):
    """Body of ``POST /chat``.

    ``session_id`` defaults to a fresh UUID4, i.e. a new conversation.
    ``system_prompt`` is only used when the session is first created.
    """

    message: str
    session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    system_prompt: str = "You are a helpful assistant."
    # Upper bound tied to MAX_GEN_CEILING so the API limit and the context budget cannot drift apart.
    max_new_tokens: int = Field(default=200, ge=1, le=MAX_GEN_CEILING)
    temperature: float = Field(default=0.7, ge=0.01, le=2.0)
    top_k: int = Field(default=40, ge=1, le=200)


class ResetRequest(BaseModel):
    """Body of ``POST /chat/reset``."""

    session_id: str
# ─────────────────────────────────────────────────────────────────────────
# Routes
# ─────────────────────────────────────────────────────────────────────────
@app.get("/")
async def serve_ui():
return FileResponse(BASE_DIR / "index.html")
@app.get("/health")
async def health():
mem = psutil.virtual_memory()
uptime = time.time() - metrics["start_time"]
ready_count = sum(1 for e in pool.engines if e._ready)
return {
"status": "ok" if ready_count > 0 else "starting",
"engines_ready": ready_count,
"engines_total": N_ENGINES,
"omp_threads_per_engine": OMP_PER_ENGINE,
"active_sessions": len(sessions),
"sessions_evicted": metrics["sessions_evicted"],
"process_ram_mb": get_total_ram_mb(),
"system_ram_used_pct": mem.percent,
"system_ram_total_gb": round(mem.total / 1e9, 1),
"uptime_seconds": round(uptime, 1),
}
@app.get("/pool/status")
async def pool_status():
return {
"n_engines": N_ENGINES,
"omp_per_engine": OMP_PER_ENGINE,
"engines": pool.status(),
"total_sessions": len(sessions),
"session_ttl_s": SESSION_TTL_SECONDS,
}
@app.post("/chat")
async def chat(req: ChatRequest):
if not any(e._ready for e in pool.engines):
raise HTTPException(503, "Engines loading… please wait ~30s and retry.")
sess = sessions.get(req.session_id)
if sess is None:
sess = SessionData(req.system_prompt)
sessions[req.session_id] = sess
sess.touch() # TTL reset on every request
new_tokens = sess.new_turn_tokens(req.message)
# Context window overflow check
if sess.tokens_in_engine + len(new_tokens) + req.max_new_tokens > MAX_SESSION_TOKENS:
await pool.reset_session(req.session_id)
sess.tokens_in_engine = 0
new_tokens = sess.new_turn_tokens(req.message)
sess.append_user(req.message)
metrics["total_requests"] += 1
async def event_stream():
response_parts: list[str] = []
t0 = time.time()
try:
async for chunk in pool.generate(
req.session_id, new_tokens,
req.max_new_tokens, req.temperature, req.top_k,
):
if chunk["type"] == "token":
response_parts.append(chunk["text"])
joined = "".join(response_parts)
# Role bleed stop karo
hit_stop = any(f"\n{s}" in joined for s in STOP_STRINGS)
if hit_stop:
for s in STOP_STRINGS:
idx = joined.find(f"\n{s}")
if idx != -1:
response_parts = [joined[:idx]]
break
yield f"data: {json.dumps(chunk)}\n\n"
elif chunk["type"] == "done":
reply = "".join(response_parts).strip()
sess.append_assistant(reply)
sess.tokens_in_engine += len(new_tokens) + chunk["total_tokens"]
metrics["total_tokens"] += chunk["total_tokens"]
metrics["total_ms"] += (time.time() - t0) * 1000
yield f"data: {json.dumps({**chunk, 'session_id': req.session_id, 'full_response': reply})}\n\n"
elif chunk["type"] == "error":
metrics["errors"] += 1
yield f"data: {json.dumps(chunk)}\n\n"
except Exception as e:
metrics["errors"] += 1
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
finally:
yield "data: [DONE]\n\n"
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
@app.post("/chat/reset")
async def reset_chat(req: ResetRequest):
sessions.pop(req.session_id, None)
await pool.reset_session(req.session_id)
return {"status": "ok", "session_id": req.session_id}
@app.get("/chat/history")
async def get_history(session_id: str):
sess = sessions.get(session_id)
if not sess:
return {"session_id": session_id, "turns": 0, "history": []}
return {
"session_id": session_id,
"turns": sum(1 for m in sess.history if m["role"] == "user"),
"tokens_in_engine": sess.tokens_in_engine,
"last_active_ago_s": round(time.time() - sess.last_active, 1),
"history": sess.history,
}
@app.get("/metrics")
async def get_metrics():
n, tok, ms = (
metrics["total_requests"],
metrics["total_tokens"],
metrics["total_ms"],
)
mem = psutil.virtual_memory()
return {
"total_requests": n,
"total_tokens": tok,
"total_errors": metrics["errors"],
"avg_tps": round(tok / (ms / 1000), 2) if ms > 0 else 0,
"active_sessions": len(sessions),
"sessions_evicted_total": metrics["sessions_evicted"],
"n_engines": N_ENGINES,
"omp_per_engine": OMP_PER_ENGINE,
"engines_busy": sum(1 for lk in pool._locks if lk.locked()),
"process_ram_mb": get_total_ram_mb(),
"system_ram_used_pct": mem.percent,
"system_ram_total_gb": round(mem.total / 1e9, 1),
"uptime_s": round(time.time() - metrics["start_time"], 1),
}
if __name__ == "__main__":
import uvicorn
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)