"""
KVInfer — FastAPI Backend v4.1
Optimized for a 2 vCPU · 16 GB RAM HuggingFace Space.

RAM estimate:
  2 engines × 4 GB (Llama 1B float32)    = 8.0 GB
  2 engines × 8 sess × ~48 MB KV         = 0.8 GB
  Python + FastAPI + tokenizer           = ~0.7 GB
  ────────────────────────────────────────────
  TOTAL ≈ 9.5 GB ✓ (safe within 16 GB)
"""
| import asyncio, json, os, time, uuid |
| from contextlib import asynccontextmanager |
| from pathlib import Path |
|
|
| import psutil |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import FileResponse, StreamingResponse |
| from pydantic import BaseModel, Field |
| from huggingface_hub import hf_hub_download |
| from transformers import AutoTokenizer |
|
|
| |
# --- Filesystem layout and model source -------------------------------------
# Everything is resolved relative to the directory containing this file.
BASE_DIR = Path(__file__).parent
INFERENCE_EXE = BASE_DIR / "inference"
MODEL_BIN = BASE_DIR / "model_llama.bin"
# Repository that model_llama.bin is downloaded from when it is missing locally.
HF_REPO_ID = os.getenv("HF_REPO_ID", "YOUR_HF_USERNAME/YOUR_REPO")

# --- Token budget -----------------------------------------------------------
# NOTE(review): BLOCK_SIZE is presumably the engine's context length in
# tokens -- confirm against the inference binary.
BLOCK_SIZE = 2048
# Same value as the upper bound on ChatReq.max_new_tokens.
MAX_GEN_CEILING = 500
SAFETY_MARGIN = 50
# Once a session would exceed this many tokens, /chat resets its engine state.
MAX_SESS_TOKENS = BLOCK_SIZE - (MAX_GEN_CEILING + SAFETY_MARGIN)

# --- Engine pool size -------------------------------------------------------
N_ENGINES = int(os.getenv("N_ENGINES", "2"))

# --- Chat template pieces ---------------------------------------------------
# Token ids forwarded to the engine as the last field of every REQUEST command.
EOS_IDS = [128001, 128009]
EOT_STR = "<|eot_id|>"
SYS_H = "<|start_header_id|>system<|end_header_id|>\n\n"
USR_H = "<|start_header_id|>user<|end_header_id|>\n\n"
AST_H = "<|start_header_id|>assistant<|end_header_id|>\n\n"
# Substrings that end the visible reply when they appear in generated text.
STOP_STR = [EOT_STR, "<|start_header_id|>user", "<|start_header_id|>system"]

# Populated by load_tokenizer() during application startup.
tokenizer = None
|
|
def load_tokenizer():
    """Load the tokenizer into the module-level ``tokenizer`` global.

    The ``tokenizer_files`` directory next to this file is used when it
    exists; otherwise the ``unsloth/Llama-3.2-1B-Instruct`` identifier is
    passed to ``AutoTokenizer.from_pretrained``. The vocabulary size is
    printed once loading finishes.
    """
    global tokenizer
    local_dir = BASE_DIR / "tokenizer_files"
    if local_dir.exists():
        source = str(local_dir)
    else:
        source = "unsloth/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(source)
    print(f"[tok] vocab={tokenizer.vocab_size}")
|
|
def enc(text: str) -> list[int]:
    """Encode *text* with the global tokenizer, adding no special tokens."""
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return token_ids
|
|
def dec(ids: list[int]) -> str:
    """Decode token *ids* with the global tokenizer, keeping special tokens."""
    text = tokenizer.decode(ids, skip_special_tokens=False)
    return text
|
|
| |
class Engine:
    """One ``inference`` subprocess driven over a line protocol on stdin/stdout.

    Commands written by this class: ``REQUEST|...``, ``RESET|<sid>`` and
    ``QUIT``. Lines it understands from the process: ``READY``, ``RESET_OK``,
    ``TOKEN <id> <elapsed_ms>``, ``DONE <n_tokens> <total_ms>``, lines starting
    with ``ERROR`` and ``[engine]`` log lines.

    The class does no locking of its own; callers must not interleave
    commands on the same engine.
    """

    def __init__(self, eid):
        self.eid = eid
        self._proc = None
        self._ready = False

    async def start(self):
        """Spawn the subprocess and wait until it prints ``READY``.

        Raises:
            RuntimeError: if the binary or model file is missing, the process
                prints a line starting with ``ERROR``, or it closes stdout
                before reporting ``READY``.
        """
        if not INFERENCE_EXE.exists():
            raise RuntimeError("Binary not found")
        if not MODEL_BIN.exists():
            raise RuntimeError("model_llama.bin not found")
        env = os.environ.copy()
        env["OMP_NUM_THREADS"] = "1"
        self._proc = await asyncio.create_subprocess_exec(
            str(INFERENCE_EXE),
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.DEVNULL,
            cwd=str(BASE_DIR),
            env=env,
        )
        try:
            while True:
                raw = await self._proc.stdout.readline()
                if not raw:
                    # EOF: without this check readline() keeps returning b""
                    # and the loop would spin forever.
                    raise RuntimeError(f"engine {self.eid} exited before READY")
                line = raw.decode("utf-8", "replace").strip()
                if line.startswith("[engine]"):
                    print(f"[E{self.eid}] {line}")
                elif line == "READY":
                    self._ready = True
                    print(f"[E{self.eid}] READY pid={self._proc.pid}")
                    break
                elif line.startswith("ERROR"):
                    raise RuntimeError(line)
        except BaseException:
            # Do not leave a half-started child behind on failure/cancellation.
            if self._proc.returncode is None:
                self._proc.kill()
            raise

    async def stop(self):
        """Ask the process to ``QUIT``; kill it if that fails or takes > 3 s."""
        proc = self._proc
        if not proc:
            return
        self._ready = False
        try:
            proc.stdin.write(b"QUIT\n")
            await proc.stdin.drain()
            await asyncio.wait_for(proc.wait(), 3.0)
        except Exception:
            # Best effort: broken pipe, timeout, etc. all end in a kill.
            if proc.returncode is None:
                proc.kill()

    async def reset(self, sid):
        """Send ``RESET|<sid>`` and wait for ``RESET_OK`` (or EOF).

        Does nothing when the engine is not ready, e.g. it never started.
        """
        if not self._ready or self._proc is None:
            return
        self._proc.stdin.write(f"RESET|{sid}\n".encode())
        await self._proc.stdin.drain()
        while True:
            raw = await self._proc.stdout.readline()
            if not raw:
                # stdout closed: the process is gone.
                self._ready = False
                break
            if raw.decode("utf-8", "replace").strip() == "RESET_OK":
                break

    async def _drain_reply(self):
        """Discard output of an abandoned request up to DONE/ERROR or EOF."""
        while True:
            raw = await self._proc.stdout.readline()
            if not raw:
                self._ready = False
                return
            if raw.decode("utf-8", "replace").strip().startswith(("DONE", "ERROR")):
                return

    async def generate(self, sid, tokens, max_new, temp, top_k):
        """Send one ``REQUEST`` and yield its events as dictionaries.

        Yields:
            ``{"type": "token", "id", "text", "elapsed_ms"}`` per TOKEN line,
            then either ``{"type": "done", "total_tokens", "total_ms", "tps"}``
            or ``{"type": "error", "message"}``.

        If the consumer stops early (cancellation, ``aclose()`` or an
        exception), the rest of the reply is read and discarded so the next
        command on this engine starts from a clean pipe.
        """
        if not self._ready:
            yield {"type": "error", "message": "not ready"}
            return
        cmd = (
            f"REQUEST|{sid}|{','.join(map(str, tokens))}|{max_new}|{temp}|{top_k}"
            f"|{','.join(map(str, EOS_IDS))}\n"
        )
        self._proc.stdin.write(cmd.encode())
        await self._proc.stdin.drain()
        finished = False  # True once the reply's terminating line was consumed
        try:
            while True:
                raw = await self._proc.stdout.readline()
                if not raw:
                    # EOF mid-request: report it instead of ending silently.
                    self._ready = False
                    finished = True
                    yield {"type": "error", "message": "engine exited unexpectedly"}
                    break
                line = raw.decode("utf-8", "replace").strip()
                if not line:
                    continue
                if line.startswith("TOKEN"):
                    p = line.split()
                    tid = int(p[1])
                    yield {
                        "type": "token",
                        "id": tid,
                        "text": dec([tid]),
                        "elapsed_ms": float(p[2]),
                    }
                elif line.startswith("DONE"):
                    p = line.split()
                    t = int(p[1])
                    ms = float(p[2])
                    finished = True
                    yield {
                        "type": "done",
                        "total_tokens": t,
                        "total_ms": ms,
                        "tps": round(t / (ms / 1000), 2) if ms > 0 else 0,
                    }
                    break
                elif line.startswith("ERROR"):
                    finished = True
                    yield {"type": "error", "message": line}
                    break
        finally:
            # Awaiting (unlike yielding) is permitted while the generator is
            # being closed, so the drain also covers GeneratorExit.
            if not finished:
                await self._drain_reply()

    @property
    def pid(self):
        return self._proc.pid if self._proc else None
|
|
| |
class Pool:
    """Fixed-size set of engines with sticky session-to-engine assignment.

    A session is bound to one engine on first use and stays there until it is
    reset. Each engine has its own lock so only one command at a time is in
    flight on its pipe.
    """

    def __init__(self, n):
        self.n = n
        self.engines = [Engine(i) for i in range(n)]
        self._locks = []   # one asyncio.Lock per engine, created in start()
        self._smap = {}    # session id -> engine index
        self._load = []    # number of sessions assigned to each engine
        self._ml = None    # lock guarding _smap and _load

    async def start(self):
        """Create the locks and start all engines concurrently."""
        self._ml = asyncio.Lock()
        self._locks = [asyncio.Lock() for _ in range(self.n)]
        self._load = [0] * self.n
        await asyncio.gather(*(e.start() for e in self.engines))
        print(f"[pool] {self.n} engines up")

    async def stop(self):
        """Stop every engine, ignoring individual failures."""
        await asyncio.gather(*(e.stop() for e in self.engines), return_exceptions=True)

    async def _assign(self, sid):
        """Return the engine index for *sid*, assigning one on first use.

        New sessions go to the least-loaded ready engine; only when no engine
        is ready does the choice fall back to all engines.
        """
        async with self._ml:
            if sid not in self._smap:
                ready = [i for i, e in enumerate(self.engines) if e._ready]
                candidates = ready or range(self.n)
                idx = min(candidates, key=lambda i: self._load[i])
                self._smap[sid] = idx
                self._load[idx] += 1
            return self._smap[sid]

    async def _drop(self, sid):
        """Forget the assignment of *sid*, if any."""
        async with self._ml:
            if sid in self._smap:
                idx = self._smap.pop(sid)
                self._load[idx] = max(0, self._load[idx] - 1)

    async def generate(self, sid, tokens, max_new, temp, top_k):
        """Yield generation events for *sid* from its assigned engine."""
        idx = await self._assign(sid)
        async with self._locks[idx]:
            stream = self.engines[idx].generate(sid, tokens, max_new, temp, top_k)
            try:
                async for c in stream:
                    yield c
            finally:
                # Close the engine generator before the lock is released;
                # otherwise its cleanup is deferred to garbage collection and
                # could overlap with the next request on the same pipe.
                await stream.aclose()

    async def reset(self, sid):
        """Reset *sid* on its engine (when that engine is ready) and unassign it."""
        async with self._ml:
            idx = self._smap.get(sid)
        if idx is not None and self.engines[idx]._ready:
            async with self._locks[idx]:
                await self.engines[idx].reset(sid)
        await self._drop(sid)

    def pids(self):
        """Return the pids of engines that have a process."""
        return [e.pid for e in self.engines if e.pid]

    def status(self):
        """Return one status dictionary per engine."""
        return [
            {
                "id": i,
                "pid": self.engines[i].pid,
                # _load/_locks are empty until start() has run.
                "sessions": self._load[i] if self._load else 0,
                "busy": self._locks[i].locked() if self._locks else False,
                "ready": self.engines[i]._ready,
            }
            for i in range(self.n)
        ]
|
|
# Module-wide engine pool; engines are only spawned once pool.start() runs.
pool = Pool(n=N_ENGINES)
|
|
| |
class Sess:
    """Server-side state of one chat session.

    Attributes are the system prompt (``sys_p``), the message ``history`` as
    ``{"role", "content"}`` dictionaries, and ``n_cached``, the token count
    the caller tracks for this session.
    """

    def __init__(self, sys_p):
        self.sys_p = sys_p
        self.history = []
        self.n_cached = 0

    def _push(self, role, content):
        """Append one message with the given role to the history."""
        self.history.append({"role": role, "content": content})

    def push_user(self, m):
        """Record a user message."""
        self._push("user", m)

    def push_asst(self, m):
        """Record an assistant message."""
        self._push("assistant", m)

    def new_tokens(self, msg):
        """Return the token ids to send for user message *msg*.

        While ``n_cached`` is zero the prompt also carries the begin-of-text
        marker and the system prompt; afterwards only the new user turn and
        the assistant header are encoded.
        """
        turn = f"{USR_H}{msg}{EOT_STR}{AST_H}"
        if self.n_cached == 0:
            turn = f"<|begin_of_text|>{SYS_H}{self.sys_p}{EOT_STR}" + turn
        return enc(turn)
|
|
# Live sessions, keyed by session id.
sessions: dict[str, Sess] = {}

# Process-wide counters: requests, generated tokens, cumulative stream time in
# milliseconds, errors, and the start timestamp used for uptime.
metrics = {
    "req": 0,
    "tok": 0,
    "ms": 0.0,
    "err": 0,
    "t0": time.time(),
}
|
|
def total_ram():
    """Return the combined RSS of this process and the engine processes.

    The value is in megabytes (bytes / 1e6) rounded to one decimal. Engine
    processes that cannot be inspected are skipped. Returns ``0.0`` when the
    measurement fails altogether.
    """
    try:
        rss_bytes = psutil.Process(os.getpid()).memory_info().rss
        for pid in pool.pids():
            try:
                rss_bytes += psutil.Process(pid).memory_info().rss
            except psutil.Error:
                # The engine exited or is not accessible; leave it out.
                continue
        return round(rss_bytes / 1e6, 1)
    except Exception:
        # Best effort for monitoring endpoints: never fail the request, but
        # (unlike a bare except) let KeyboardInterrupt/SystemExit through.
        return 0.0
|
|
| |
@asynccontextmanager
async def lifespan(app):
    """Application lifespan: load the tokenizer, fetch the model, run the pool.

    A failed model download or pool start is only logged, so the application
    still comes up (with no engines reported ready). The pool is always
    stopped on shutdown, including when the application exits with an error.
    """
    print("[start] Loading tokenizer…")
    load_tokenizer()
    if not MODEL_BIN.exists():
        try:
            print("[start] Downloading model_llama.bin from HF…")
            hf_hub_download(
                repo_id=HF_REPO_ID,
                filename="model_llama.bin",
                local_dir=str(BASE_DIR),
            )
        except Exception as e:
            print(f"[warn] download failed: {e}")
    try:
        await pool.start()
    except Exception as e:
        print(f"[error] pool start: {e}")
    try:
        yield
    finally:
        # Runs even when an exception is thrown in at the yield, so engine
        # subprocesses are not left behind.
        await pool.stop()
|
|
# ASGI application; startup and shutdown are handled by lifespan().
app = FastAPI(
    title="KVInfer",
    version="4.1",
    lifespan=lifespan,
)
# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
class ChatReq(BaseModel):
    # Request body of POST /chat.
    # (Documented with comments so the model's published schema is unchanged.)

    # The user's message for this turn.
    message: str
    # Session to continue; a fresh UUID string is generated when omitted.
    session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Passed to Sess() only when the session does not exist yet.
    system_prompt: str = Field(default="You are a helpful, concise assistant.")
    # Sampling / length controls, validated to the ranges below.
    max_new_tokens: int = Field(256, ge=1, le=500)
    temperature: float = Field(0.7, ge=0.01, le=2.0)
    top_k: int = Field(40, ge=1, le=200)
|
|
class ResetReq(BaseModel):
    # Request body of POST /chat/reset: the session to discard.
    session_id: str = Field(...)
|
|
| |
# Serves the index.html located next to this file.
@app.get("/")
async def ui():
    index_page = BASE_DIR / "index.html"
    return FileResponse(index_page)
|
|
# Liveness summary: "ok" as soon as at least one engine is ready, otherwise
# "starting", plus engine, session, memory and uptime figures.
@app.get("/health")
async def health():
    mem = psutil.virtual_memory()
    ready_count = sum(1 for engine in pool.engines if engine._ready)
    if ready_count > 0:
        state = "ok"
    else:
        state = "starting"
    return {
        "status": state,
        "engines_ready": ready_count,
        "engines_total": N_ENGINES,
        "active_sessions": len(sessions),
        "process_ram_mb": total_ram(),
        "system_ram_used_pct": mem.percent,
        "uptime_seconds": round(time.time() - metrics["t0"], 1),
    }
|
|
# Aggregate counters since startup. avg_tps is total generated tokens divided
# by cumulative stream time, or 0 before any time has been recorded.
@app.get("/metrics")
async def get_metrics():
    total_req = metrics["req"]
    total_tok = metrics["tok"]
    total_ms = metrics["ms"]
    mem = psutil.virtual_memory()
    if total_ms > 0:
        avg_tps = round(total_tok / (total_ms / 1000), 2)
    else:
        avg_tps = 0
    return {
        "total_requests": total_req,
        "total_tokens": total_tok,
        "total_errors": metrics["err"],
        "avg_tps": avg_tps,
        "active_sessions": len(sessions),
        "n_engines": N_ENGINES,
        "engines_ready": sum(1 for engine in pool.engines if engine._ready),
        "engines_busy": sum(1 for lock in pool._locks if lock.locked()),
        "process_ram_mb": total_ram(),
        "system_ram_used_pct": mem.percent,
        "uptime_s": round(time.time() - metrics["t0"], 1),
    }
|
|
# POST /chat -- run one chat turn and stream the reply as server-sent events.
#
# Responds 503 while no engine is ready. Otherwise the session is looked up
# (or created with the request's system prompt), the turn is tokenized, and
# the engine state for the session is reset first when the turn would push it
# past MAX_SESS_TOKENS. The response body consists of "data:<json>" events
# (token / done / error) followed by a final "data:[DONE]" on normal
# completion.
@app.post("/chat")
async def chat(req: ChatReq):
    if not any(e._ready for e in pool.engines):
        raise HTTPException(503, "No engines ready yet — please wait a moment.")

    # Only build a Sess when the session is actually new.
    sess = sessions.get(req.session_id)
    if sess is None:
        sess = sessions[req.session_id] = Sess(req.system_prompt)

    toks = sess.new_tokens(req.message)
    if sess.n_cached + len(toks) + req.max_new_tokens > MAX_SESS_TOKENS:
        # Over budget: clear the engine-side state and resend a full prompt.
        await pool.reset(req.session_id)
        sess.n_cached = 0
        toks = sess.new_tokens(req.message)
    sess.push_user(req.message)
    metrics["req"] += 1

    async def stream():
        parts = []
        t0 = time.time()
        stopped = False  # set once a stop string has shown up in the output
        gen = pool.generate(
            req.session_id, toks, req.max_new_tokens, req.temperature, req.top_k
        )
        try:
            async for c in gen:
                kind = c["type"]
                if kind == "token" and not stopped:
                    parts.append(c["text"])
                    joined = "".join(parts)
                    for s in STOP_STR:
                        if s in joined:
                            # Keep only the text before the stop string and
                            # forward no further tokens.
                            parts = [joined[:joined.find(s)]]
                            stopped = True
                            break
                    if not stopped:
                        yield f"data:{json.dumps(c)}\n\n"
                elif kind == "done":
                    reply = "".join(parts).strip()
                    for s in STOP_STR:
                        reply = reply.split(s)[0]
                    reply = reply.strip()
                    sess.push_asst(reply)
                    sess.n_cached += len(toks) + c["total_tokens"]
                    metrics["tok"] += c["total_tokens"]
                    metrics["ms"] += (time.time() - t0) * 1000
                    final = {**c, "session_id": req.session_id, "full_response": reply}
                    yield f"data:{json.dumps(final)}\n\n"
                elif kind == "error":
                    metrics["err"] += 1
                    yield f"data:{json.dumps(c)}\n\n"
        except Exception as e:
            metrics["err"] += 1
            yield f"data:{json.dumps({'type': 'error', 'message': str(e)})}\n\n"
        finally:
            # Close the pool generator right away (e.g. on client disconnect)
            # rather than leaving its cleanup to garbage collection.
            await gen.aclose()
        # Emitted after the try block, not inside ``finally``: yielding while
        # the generator is being closed raises RuntimeError.
        yield "data:[DONE]\n\n"

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
|
# Drops the server-side session object, then resets the session in the pool.
@app.post("/chat/reset")
async def reset(req: ResetReq):
    sid = req.session_id
    sessions.pop(sid, None)
    await pool.reset(sid)
    return {"status": "ok", "session_id": sid}
|
|
# Returns the stored messages of a session. Unknown sessions yield an empty
# history with zero turns (and no "tokens_in_engine" key).
@app.get("/chat/history")
async def history(session_id: str):
    found = sessions.get(session_id)
    if found is None:
        return {"session_id": session_id, "turns": 0, "history": []}
    user_turns = sum(msg["role"] == "user" for msg in found.history)
    return {
        "session_id": session_id,
        "tokens_in_engine": found.n_cached,
        "turns": user_turns,
        "history": found.history,
    }
|
|
# Per-engine status from the pool plus the number of live sessions.
@app.get("/pool/status")
async def pool_status():
    engine_rows = pool.status()
    return {
        "n_engines": N_ENGINES,
        "engines": engine_rows,
        "sessions": len(sessions),
    }
|
|
# Direct execution: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
|
|