blitzkode / server.py
neuralbroker's picture
Update server.py (v2.1 production)
6c79e9e verified
raw
history blame
26 kB
#!/usr/bin/env python3
"""
BlitzKode backend server.
Serves the bundled frontend and proxies prompts to a local GGUF model
through llama.cpp. Model is loaded lazily so the module stays importable
in tests and environments where the model artifact is not present yet.
"""
from __future__ import annotations

import asyncio
import hmac
import json
import logging
import os
import queue
import threading
import time
import urllib.error
import urllib.parse
import urllib.request
from contextlib import asynccontextmanager, suppress
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from pathlib import Path
from typing import Any, Literal, cast

import llama_cpp
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from starlette.middleware.base import BaseHTTPMiddleware
# --- Application identity -----------------------------------------------------
APP_NAME = "BlitzKode"
APP_VERSION = "2.0"
CREATOR = "Sajad"

# --- Filesystem locations, resolved relative to this file ---------------------
ROOT_DIR = Path(__file__).resolve().parent
DEFAULT_MODEL_PATH = ROOT_DIR / "blitzkode.gguf"
DEFAULT_FRONTEND_DIST_PATH = ROOT_DIR.joinpath("frontend", "dist", "index.html")

# --- Numeric defaults and limits ----------------------------------------------
DEFAULT_CONTEXT = 2_048  # llama.cpp context window (tokens), default for n_ctx
DEFAULT_MAX_PROMPT_LENGTH = 4_000  # characters
DEFAULT_MAX_TOKENS = 512  # upper bound for a request's max_tokens
DEFAULT_RATE_LIMIT_MAX = 30  # requests per client per rate-limit window
DEFAULT_MAX_SEARCH_RESULTS = 5
DEFAULT_SEARCH_TIMEOUT_SECONDS = 8
DEFAULT_MAX_MESSAGES = 20  # prior chat turns accepted per request

# --- Chat-template markers and the fixed system turn --------------------------
# Generation halts as soon as the model emits one of these strings.
STOP_TOKENS = ["<|im_end|>", "<|im_start|>user"]
SYSTEM_PROMPT = "".join(
    (
        "<|im_start|>system\n",
        "You are BlitzKode, an AI coding assistant created by Sajad. ",
        "You are an expert in Python, JavaScript, Java, C++, and other programming languages. ",
        "For coding work, first understand the user's goal and constraints, then provide a short plan before code when useful. ",
        "Do not invent APIs, file contents, citations, or execution results. ",
        "If evidence is missing, say what is unknown and give a safe next step. ",
        "Write clean, efficient, and well-documented code. Keep responses concise and practical.",
        "<|im_end|>",
    )
)

logger = logging.getLogger("blitzkode")
def _bool_from_env(name: str, default: bool = False) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
def _int_from_env(name: str, default: int) -> int:
value = os.getenv(name)
if not value:
return default
try:
return int(value)
except ValueError:
return default
def _path_from_env(name: str, default: Path) -> Path:
value = os.getenv(name)
return Path(value) if value else default
def _frontend_path_from_env() -> Path:
value = os.getenv("BLITZKODE_FRONTEND_PATH")
if value:
return Path(value)
return DEFAULT_FRONTEND_DIST_PATH
def _validate_prompt(prompt: str, max_length: int) -> tuple[str, JSONResponse | None]:
prompt = prompt.strip()
if not prompt:
return prompt, JSONResponse({"error": "Prompt is required"}, status_code=400)
if len(prompt) > max_length:
return prompt, JSONResponse(
{"error": f"Prompt too long. Max {max_length} chars."},
status_code=400,
)
return prompt, None
@dataclass(slots=True)
class Settings:
    """Server configuration sourced from ``BLITZKODE_*`` environment variables.

    Every environment-derived default is produced by a ``default_factory`` so the
    environment is read when a ``Settings()`` instance is created, not once when
    this module is imported. Plain class-level defaults would freeze the values
    seen at import time, so later environment changes (e.g. in tests, or before
    calling ``create_app``) would be silently ignored. Keyword arguments passed to
    the constructor still take precedence over the environment.
    """

    root_dir: Path = ROOT_DIR
    model_path: Path = dataclass_field(
        default_factory=lambda: _path_from_env("BLITZKODE_MODEL_PATH", DEFAULT_MODEL_PATH)
    )
    frontend_path: Path = dataclass_field(default_factory=_frontend_path_from_env)
    host: str = dataclass_field(default_factory=lambda: os.getenv("BLITZKODE_HOST", "0.0.0.0"))
    port: int = dataclass_field(default_factory=lambda: _int_from_env("BLITZKODE_PORT", 7860))
    # llama.cpp model construction parameters.
    n_gpu_layers: int = dataclass_field(default_factory=lambda: _int_from_env("BLITZKODE_GPU_LAYERS", 0))
    n_ctx: int = dataclass_field(default_factory=lambda: _int_from_env("BLITZKODE_N_CTX", DEFAULT_CONTEXT))
    # Default thread count: the CPU count clamped to the range [1, 8].
    n_threads: int = dataclass_field(
        default_factory=lambda: _int_from_env("BLITZKODE_THREADS", max(1, min(8, os.cpu_count() or 1)))
    )
    n_batch: int = dataclass_field(default_factory=lambda: _int_from_env("BLITZKODE_BATCH", 128))
    max_prompt_length: int = dataclass_field(
        default_factory=lambda: _int_from_env("BLITZKODE_MAX_PROMPT_LENGTH", DEFAULT_MAX_PROMPT_LENGTH)
    )
    preload_model: bool = dataclass_field(
        default_factory=lambda: _bool_from_env("BLITZKODE_PRELOAD_MODEL", default=False)
    )
    # Comma-separated list of allowed CORS origins.
    cors_origins: str = dataclass_field(
        default_factory=lambda: os.getenv("BLITZKODE_CORS_ORIGINS", "http://localhost:7860")
    )
    # An empty key disables API-key authentication.
    api_key: str = dataclass_field(default_factory=lambda: os.getenv("BLITZKODE_API_KEY", ""))
    web_search_enabled: bool = dataclass_field(
        default_factory=lambda: _bool_from_env("BLITZKODE_WEB_SEARCH", default=True)
    )
    search_timeout_seconds: int = dataclass_field(
        default_factory=lambda: _int_from_env("BLITZKODE_SEARCH_TIMEOUT", DEFAULT_SEARCH_TIMEOUT_SECONDS)
    )
    max_search_results: int = dataclass_field(
        default_factory=lambda: _int_from_env("BLITZKODE_MAX_SEARCH_RESULTS", DEFAULT_MAX_SEARCH_RESULTS)
    )
class MessageItem(BaseModel):
    """One earlier chat turn, replayed into the prompt as conversation history."""

    role: Literal["user", "assistant"]
    # Required; between 1 and DEFAULT_MAX_PROMPT_LENGTH characters.
    content: str = Field(..., min_length=1, max_length=DEFAULT_MAX_PROMPT_LENGTH)
class GenerateRequest(BaseModel):
    """Request body for text generation: the new prompt, optional history and sampling knobs.

    ``prompt`` length is not constrained here; the route handlers validate it
    against the configured maximum prompt length.
    """

    prompt: str
    messages: list[MessageItem] = Field(default_factory=list, max_length=DEFAULT_MAX_MESSAGES)
    # Sampling parameters forwarded to llama.cpp, each with a validated range.
    temperature: float = Field(0.5, ge=0.0, le=2.0)
    max_tokens: int = Field(256, ge=1, le=DEFAULT_MAX_TOKENS)
    top_p: float = Field(0.95, gt=0.0, le=1.0)
    top_k: int = Field(20, ge=1, le=200)
    repeat_penalty: float = Field(1.05, ge=0.8, le=2.0)
class SearchRequest(BaseModel):
    """Request body for a standalone web search."""

    query: str = Field(..., min_length=1, max_length=500)
    max_results: int = Field(DEFAULT_MAX_SEARCH_RESULTS, ge=1, le=10)
    # When true, extra phrasings of the query are searched as well.
    deep: bool = False
class ResearchGenerateRequest(GenerateRequest):
    """Generation request that is first enriched with live web-search results.

    When ``search_query`` is omitted the route searches for the prompt itself.
    """

    search_query: str | None = Field(None, max_length=500)
    search_results: int = Field(DEFAULT_MAX_SEARCH_RESULTS, ge=1, le=10)
    deep_search: bool = False
@dataclass(slots=True)
class SearchResult:
title: str
url: str
snippet: str
source: str = "DuckDuckGo"
def as_dict(self) -> dict[str, str]:
return {
"title": self.title,
"url": self.url,
"snippet": self.snippet,
"source": self.source,
}
class WebSearchService:
    """Keyword search backed by the DuckDuckGo Instant Answer JSON API.

    Hits are de-duplicated by URL and capped by the caller's ``max_results``,
    by ``settings.max_search_results`` and by a hard ceiling of 10.
    """

    def __init__(self, settings: Settings):
        self.settings = settings

    @property
    def enabled(self) -> bool:
        """Whether web search is switched on in the settings."""
        return self.settings.web_search_enabled

    def _query_variants(self, query: str, deep: bool) -> list[str]:
        """Return the queries to run: *query* alone, or two extra phrasings in deep mode."""
        query = " ".join(query.split())
        if not deep:
            return [query]
        return [
            query,
            f"{query} official documentation",
            f"{query} best practices",
        ]

    def _append_result(
        self, results: list[SearchResult], seen_urls: set[str], title: str, url: str, snippet: str, max_results: int
    ) -> None:
        """Normalize one hit and append it unless it lacks a URL, is a duplicate, or the cap is reached."""
        # str() guards against non-string values in the provider's JSON.
        title = " ".join(str(title or "Untitled").split())[:200]
        url = str(url or "").strip()
        snippet = " ".join(str(snippet or "").split())[:500]
        if not url or url in seen_urls or len(results) >= max_results:
            return
        seen_urls.add(url)
        results.append(SearchResult(title=title, url=url, snippet=snippet))

    def _collect_related_topics(
        self, topics: list[dict], results: list[SearchResult], seen_urls: set[str], max_results: int
    ) -> None:
        """Walk ``RelatedTopics`` (recursing into grouped ``Topics``) and append usable hits."""
        for topic in topics:
            if len(results) >= max_results:
                return
            # Skip malformed entries instead of failing the whole search.
            if not isinstance(topic, dict):
                continue
            if "Topics" in topic:
                nested = topic.get("Topics")
                if isinstance(nested, list):
                    self._collect_related_topics(nested, results, seen_urls, max_results)
                continue
            text = topic.get("Text", "")
            url = topic.get("FirstURL", "")
            if text and url:
                title = str(text).split(" - ", 1)[0]
                self._append_result(results, seen_urls, title, url, text, max_results)

    def _fetch_payload(self, query: str) -> dict[str, Any]:
        """Query the Instant Answer API once and return the decoded JSON object.

        Raises:
            urllib.error.URLError / TimeoutError: on network failures.
            RuntimeError: if the provider's response is not a JSON object, so
                callers report it as a search failure rather than an internal error.
        """
        params = urllib.parse.urlencode(
            {
                "q": query,
                "format": "json",
                "no_html": "1",
                "skip_disambig": "1",
            }
        )
        request = urllib.request.Request(
            f"https://api.duckduckgo.com/?{params}",
            headers={"User-Agent": f"{APP_NAME}/{APP_VERSION}"},
        )
        with urllib.request.urlopen(request, timeout=self.settings.search_timeout_seconds) as response:
            body = response.read()
        try:
            payload = json.loads(body.decode("utf-8"))
        except ValueError as exc:  # covers both JSONDecodeError and UnicodeDecodeError
            raise RuntimeError(f"Search provider returned invalid JSON: {exc}") from exc
        if not isinstance(payload, dict):
            raise RuntimeError("Search provider returned an unexpected response shape")
        return payload

    def search(self, query: str, max_results: int = DEFAULT_MAX_SEARCH_RESULTS, deep: bool = False) -> list[dict[str, str]]:
        """Run a web search and return up to the allowed number of hits as dicts.

        Args:
            query: Free-text query; whitespace is collapsed.
            max_results: Requested cap, further limited by settings and by 10.
            deep: Also search extra phrasings of the query.

        Raises:
            RuntimeError: search is disabled, or the provider sent unusable data.
            ValueError: the query is empty after whitespace normalization.
            urllib.error.URLError / TimeoutError: network failures.
        """
        if not self.enabled:
            raise RuntimeError("Web search is disabled. Set BLITZKODE_WEB_SEARCH=true to enable it.")
        query = " ".join(query.split())
        if not query:
            raise ValueError("Search query is required")
        limit = min(max_results, max(1, self.settings.max_search_results), 10)
        results: list[SearchResult] = []
        seen_urls: set[str] = set()
        for variant in self._query_variants(query, deep):
            if len(results) >= limit:
                break
            payload = self._fetch_payload(variant)
            self._append_result(
                results,
                seen_urls,
                payload.get("Heading") or variant,
                payload.get("AbstractURL", ""),
                payload.get("AbstractText", ""),
                limit,
            )
            related = payload.get("RelatedTopics", [])
            if isinstance(related, list):
                self._collect_related_topics(related, results, seen_urls, limit)
        return [result.as_dict() for result in results]
class ModelService:
    """Owns the lazily loaded llama.cpp model and runs prompts through it.

    The model is constructed on first use by :meth:`load_model`; construction
    failures are recorded in :attr:`last_error`. This class only guards model
    construction with a lock. It does not serialize inference itself; the app
    factory wraps inference calls in its own lock.
    """

    def __init__(self, settings: Settings):
        self.settings = settings
        self._llm: llama_cpp.Llama | None = None
        # Guards one-time model construction in load_model().
        self._init_lock = threading.Lock()
        self._load_time_seconds: float | None = None
        self._last_error: str | None = None
        # Informational flag, set while a generation is running.
        self._busy: bool = False

    @property
    def model_loaded(self) -> bool:
        """True once the Llama instance has been constructed."""
        return self._llm is not None

    @property
    def model_exists(self) -> bool:
        """True if the configured model file exists on disk."""
        return self.settings.model_path.exists()

    @property
    def last_error(self) -> str | None:
        """Message of the last model-load failure, cleared after a successful load."""
        return self._last_error

    @property
    def load_time_seconds(self) -> float | None:
        """Wall-clock seconds the successful model load took, if any."""
        return self._load_time_seconds

    @property
    def busy(self) -> bool:
        """True while generate_once / _run_stream is executing."""
        return self._busy

    def load_model(self):
        """Return the Llama instance, constructing it on first call.

        Checks for an existing instance, then re-checks under ``_init_lock`` so
        concurrent first callers construct the model only once.

        Raises:
            FileNotFoundError: if the model file does not exist.
            Exception: any error raised by ``llama_cpp.Llama`` is recorded in
                ``last_error`` and re-raised.
        """
        if self._llm is not None:
            return self._llm
        with self._init_lock:
            if self._llm is not None:
                return self._llm
            if not self.model_exists:
                self._last_error = f"Model not found at {self.settings.model_path}"
                raise FileNotFoundError(self._last_error)
            start_time = time.perf_counter()
            try:
                self._llm = llama_cpp.Llama(
                    model_path=str(self.settings.model_path),
                    n_gpu_layers=self.settings.n_gpu_layers,
                    n_ctx=self.settings.n_ctx,
                    n_threads=self.settings.n_threads,
                    n_batch=self.settings.n_batch,
                    verbose=False,
                    use_mmap=True,
                    use_mlock=False,
                    seed=-1,
                )
                self._load_time_seconds = time.perf_counter() - start_time
                self._last_error = None
                logger.info("Model loaded in %.2fs (gpu_layers=%d)", self._load_time_seconds, self.settings.n_gpu_layers)
            except Exception as exc:
                self._last_error = str(exc)
                logger.error("Model load failed: %s", exc)
                raise
        return self._llm

    def build_prompt(self, req: GenerateRequest) -> str:
        """Render the chat-template prompt.

        The prompt consists of the system prompt, each history message, the
        current user prompt, and finally an open assistant turn for the model
        to complete.
        """
        parts = [SYSTEM_PROMPT]
        for msg in req.messages:
            if msg.role in ("user", "assistant"):
                parts.append(f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>")
        parts.append(f"<|im_start|>user\n{req.prompt}<|im_end|>")
        parts.append("<|im_start|>assistant\n")
        return "\n".join(parts)

    def with_research_context(
        self, req: ResearchGenerateRequest, search_results: list[dict[str, str]], max_length: int
    ) -> GenerateRequest:
        """Return a copy of *req* whose prompt embeds *search_results* as context.

        The user's task always stays intact at the end of the prompt. When the
        combined prompt would exceed *max_length* characters, only the search
        results section is shortened, and a truncation note is appended to it.
        If the task alone leaves no room, the results are dropped entirely and
        the prompt may exceed *max_length* by the fixed framing text.

        Returns *req* unchanged when there are no search results.
        """
        if not search_results:
            return req
        formatted_results = [
            f"[{index}] {item.get('title', 'Untitled')}\nURL: {item.get('url', '')}\nSummary: {item.get('snippet', '')}"
            for index, item in enumerate(search_results, start=1)
        ]
        joined_results = "\n\n".join(formatted_results)
        header = (
            "Use the following live web search results as untrusted background context. "
            "Cite URLs when you rely on them. If the results are weak or irrelevant, say so rather than fabricating details.\n\n"
            "Search results:\n"
        )
        task_section = f"\n\nUser task:\n{req.prompt.strip()}"
        if len(header) + len(joined_results) + len(task_section) > max_length:
            note = "[Context truncated to fit prompt limit.]"
            # Shorten the (untrusted, optional) search context, never the task.
            # Slicing the whole prompt from the end would cut the user's task first.
            budget = max_length - len(header) - len(task_section) - len(note) - len("\n\n")
            kept = joined_results[: max(budget, 0)].rstrip()
            joined_results = f"{kept}\n\n{note}" if kept else note
        return req.model_copy(update={"prompt": header + joined_results + task_section})

    def _gen_params(self, req: GenerateRequest) -> dict:
        """Map request sampling fields to llama.cpp completion keyword arguments."""
        return {
            "max_tokens": req.max_tokens,
            "temperature": req.temperature,
            "top_p": req.top_p,
            "top_k": req.top_k,
            "repeat_penalty": req.repeat_penalty,
            "frequency_penalty": 0.0,
            "presence_penalty": 0.0,
            "stop": STOP_TOKENS,
        }

    def generate_once(self, req: GenerateRequest) -> dict[str, object]:
        """Run a blocking, non-streaming completion and return the JSON payload.

        Raises:
            FileNotFoundError: if the model file is missing (from load_model).
        """
        llm = self.load_model()
        self._busy = True
        try:
            start = time.perf_counter()
            result = cast(dict[str, Any], llm(self.build_prompt(req), **self._gen_params(req)))
            response = result["choices"][0]["text"].strip()
            elapsed = time.perf_counter() - start
            logger.info("Generated %d chars in %.2fs", len(response), elapsed)
            return {"response": response, "creator": CREATOR, "model": APP_NAME, "version": APP_VERSION}
        finally:
            self._busy = False

    def _run_stream(self, req: GenerateRequest, out_q: queue.Queue):
        """Run streaming inference in the calling (worker) thread.

        Each text chunk is put on *out_q* as an SSE ``data:`` line, followed by
        ``data: [DONE]``. On failure, an SSE error event is put instead. A final
        ``None`` sentinel is always put so the consumer knows to stop reading.
        """
        try:
            llm = self.load_model()
            self._busy = True
            start = time.perf_counter()
            token_count = 0
            stream = cast(Any, llm(self.build_prompt(req), stream=True, **self._gen_params(req)))
            for token in stream:
                if not token.get("choices"):
                    continue
                text = token["choices"][0].get("text", "")
                if text:
                    token_count += 1
                    out_q.put(f"data: {json.dumps({'token': text})}\n\n")
            elapsed = time.perf_counter() - start
            logger.info("Streamed %d tokens in %.2fs", token_count, elapsed)
            out_q.put("data: [DONE]\n\n")
        except Exception as exc:
            logger.error("Stream error: %s", exc)
            out_q.put(f"data: {json.dumps({'error': str(exc)})}\n\n")
        finally:
            self._busy = False
            out_q.put(None)
def _check_api_key(request: Request, settings: Settings) -> JSONResponse | None:
if not settings.api_key:
return None
auth = request.headers.get("Authorization", "")
token = auth[7:] if auth.startswith("Bearer ") else auth
# Timing-safe comparison (prevent timing attacks)
import hmac
if not hmac.compare_digest(token, settings.api_key):
return JSONResponse({"error": "Unauthorized"}, status_code=401)
return None
class RateLimitMiddleware(BaseHTTPMiddleware):
    """In-memory sliding-window rate limiter keyed by client IP.

    Each client may make at most ``max_requests`` requests per
    ``window_seconds``. The key is ``request.client.host``; behind a reverse
    proxy that is presumably the proxy's address (TODO confirm deployment).
    State is per process and is not shared between workers.
    """

    def __init__(self, app, max_requests: int = DEFAULT_RATE_LIMIT_MAX, window_seconds: int = 60):
        super().__init__(app)
        self._max = max_requests
        self._window = window_seconds
        # client IP -> monotonic timestamps of its requests inside the window
        self._clients: dict[str, list[float]] = {}
        self._lock = threading.Lock()
        self._cleanup_done = 0

    def _prune_all(self, now: float) -> None:
        """Drop expired timestamps for every client and forget clients left with none."""
        cutoff = now - self._window
        pruned: dict[str, list[float]] = {}
        for ip, stamps in self._clients.items():
            # Filter first, then decide whether to keep the client. The old
            # `if ts` check ran on the unfiltered list, so fully expired clients
            # survived the sweep with empty lists.
            recent = [t for t in stamps if t > cutoff]
            if recent:
                pruned[ip] = recent
        self._clients = pruned

    async def dispatch(self, request: Request, call_next):
        """Reject the request with 429 if its client exceeded the limit; otherwise pass it on."""
        client_ip = request.client.host if request.client else "unknown"
        now = time.monotonic()
        with self._lock:
            # Sweep stale clients every 1000 requests to bound memory.
            self._cleanup_done += 1
            if self._cleanup_done > 1000:
                self._cleanup_done = 0
                self._prune_all(now)
            # Same boundary as _prune_all: keep timestamps strictly inside the window.
            timestamps = [t for t in self._clients.get(client_ip, []) if now - t < self._window]
            if len(timestamps) >= self._max:
                return JSONResponse(
                    {"error": "Rate limit exceeded. Try again later."},
                    status_code=429,
                    headers={"Retry-After": str(self._window)},
                )
            timestamps.append(now)
            self._clients[client_ip] = timestamps
        return await call_next(request)
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared ``Content-Length`` exceeds ``max_bytes``.

    Only the header is inspected. A body sent without a Content-Length header
    is not measured here.
    """

    def __init__(self, app, max_bytes: int = 50_000):
        super().__init__(app)
        self._max = max_bytes

    async def dispatch(self, request: Request, call_next):
        """Return 400 for a non-integer Content-Length, 413 if it is too large, else pass through."""
        declared = request.headers.get("content-length")
        if not declared:
            return await call_next(request)
        try:
            size = int(declared)
        except ValueError:
            return JSONResponse({"error": "Invalid Content-Length header"}, status_code=400)
        if size > self._max:
            return JSONResponse({"error": "Request body too large"}, status_code=413)
        return await call_next(request)
def create_app(settings: Settings | None = None) -> FastAPI:
    """Build the BlitzKode FastAPI application.

    Args:
        settings: Configuration to use; a fresh ``Settings()`` (read from the
            environment) when omitted.

    Returns:
        The configured app. ``settings``, ``model_service`` and
        ``search_service`` are exposed on ``app.state``.
    """
    settings = settings or Settings()
    model_service = ModelService(settings)
    search_service = WebSearchService(settings)
    # Serializes every use of the single llama.cpp instance across requests.
    model_lock = asyncio.Lock()
    # Strong references to inference tasks that may outlive a cancelled request,
    # so they are not garbage-collected while still running.
    inference_tasks: set[asyncio.Task] = set()
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s", datefmt="%H:%M:%S")

    async def run_model_exclusive(func, *args):
        """Run blocking ``func(*args)`` in a worker thread while holding ``model_lock``.

        Waiting for the lock can be cancelled. Once the worker thread has
        started, however, the lock is released only when that thread finishes,
        even if the awaiting request is cancelled (e.g. the client disconnects).
        Releasing it earlier would let a second request drive the same llama.cpp
        instance concurrently.
        """
        await model_lock.acquire()
        try:
            worker = asyncio.ensure_future(asyncio.to_thread(func, *args))
        except BaseException:
            model_lock.release()
            raise
        inference_tasks.add(worker)

        def _on_done(task: asyncio.Task) -> None:
            inference_tasks.discard(task)
            model_lock.release()
            if not task.cancelled():
                # Mark any exception as retrieved in case the requester is gone.
                task.exception()

        worker.add_done_callback(_on_done)
        return await asyncio.shield(worker)

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        if settings.preload_model:
            try:
                await asyncio.to_thread(model_service.load_model)
            except Exception as exc:
                # Keep serving; load_model is retried lazily on the next request and
                # the failure is reported through /health ("last_error").
                logger.warning("Model preload failed: %s", exc)
        yield

    app = FastAPI(title=f"{APP_NAME} API", version=APP_VERSION, lifespan=lifespan)
    app.state.settings = settings
    app.state.model_service = model_service
    app.state.search_service = search_service

    frontend_assets_path = settings.frontend_path.parent / "assets"
    if frontend_assets_path.exists():
        app.mount("/assets", StaticFiles(directory=str(frontend_assets_path)), name="frontend-assets")

    cors_origins = [o.strip() for o in settings.cors_origins.split(",") if o.strip()]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        allow_methods=["POST", "GET", "OPTIONS"],
        allow_headers=["Content-Type", "Authorization"],
    )
    if _bool_from_env("BLITZKODE_RATE_LIMIT", default=True):
        app.add_middleware(RateLimitMiddleware, max_requests=_int_from_env("BLITZKODE_RATE_LIMIT_MAX", DEFAULT_RATE_LIMIT_MAX))
    app.add_middleware(RequestSizeLimitMiddleware, max_bytes=_int_from_env("BLITZKODE_MAX_REQUEST_BYTES", 50_000))

    @app.get("/")
    async def root():
        if not settings.frontend_path.exists():
            raise HTTPException(status_code=404, detail="Frontend build is missing. Run `npm install` and `npm run build` in frontend/.")
        return FileResponse(str(settings.frontend_path))

    @app.get("/health")
    async def health():
        status = "healthy" if model_service.model_exists else "degraded"
        return JSONResponse(
            {
                "status": status,
                "model_loaded": model_service.model_loaded,
                "model_path": str(settings.model_path),
                "model_exists": model_service.model_exists,
                "frontend_exists": settings.frontend_path.exists(),
                "version": APP_VERSION,
                "gpu_layers": settings.n_gpu_layers,
                "last_error": model_service.last_error,
                "busy": model_service.busy,
            }
        )

    @app.post("/generate")
    async def generate(req: GenerateRequest, request: Request):
        auth_err = _check_api_key(request, settings)
        if auth_err:
            return auth_err
        prompt, err = _validate_prompt(req.prompt, settings.max_prompt_length)
        if err:
            return err
        sanitized = req.model_copy(update={"prompt": prompt})
        try:
            payload = await run_model_exclusive(model_service.generate_once, sanitized)
        except FileNotFoundError as exc:
            return JSONResponse({"error": str(exc)}, status_code=503)
        except Exception as exc:
            return JSONResponse({"error": str(exc)}, status_code=500)
        return JSONResponse(payload)

    @app.post("/generate/research")
    async def generate_research(req: ResearchGenerateRequest, request: Request):
        auth_err = _check_api_key(request, settings)
        if auth_err:
            return auth_err
        prompt, err = _validate_prompt(req.prompt, settings.max_prompt_length)
        if err:
            return err
        if not search_service.enabled:
            return JSONResponse({"error": "Web search is disabled"}, status_code=503)
        # A missing OR whitespace-only search_query falls back to the prompt. The
        # previous `(search_query or prompt).strip()` sent "" to the search service
        # for "   ", which raised ValueError and surfaced as a 500.
        search_query = (req.search_query or "").strip() or prompt
        try:
            results = await asyncio.to_thread(search_service.search, search_query, req.search_results, req.deep_search)
        except (RuntimeError, urllib.error.URLError, TimeoutError, ValueError) as exc:
            return JSONResponse({"error": f"Search failed: {exc}"}, status_code=502)
        except Exception as exc:
            return JSONResponse({"error": str(exc)}, status_code=500)
        # Generation errors are handled separately so a model RuntimeError is not
        # misreported as "Search failed".
        try:
            sanitized = req.model_copy(update={"prompt": prompt})
            enriched = model_service.with_research_context(sanitized, results, settings.max_prompt_length)
            payload = await run_model_exclusive(model_service.generate_once, enriched)
        except FileNotFoundError as exc:
            return JSONResponse({"error": str(exc)}, status_code=503)
        except Exception as exc:
            return JSONResponse({"error": str(exc)}, status_code=500)
        payload["search_results"] = results
        return JSONResponse(payload)

    @app.post("/generate/stream")
    async def generate_stream(req: GenerateRequest, request: Request):
        auth_err = _check_api_key(request, settings)
        if auth_err:
            return auth_err
        prompt, err = _validate_prompt(req.prompt, settings.max_prompt_length)
        if err:
            return err
        if not model_service.model_exists:
            return JSONResponse({"error": f"Model not found at {settings.model_path}"}, status_code=503)
        sanitized = req.model_copy(update={"prompt": prompt})

        async def _locked_stream():
            loop = asyncio.get_running_loop()
            token_q: queue.Queue = queue.Queue()
            await model_lock.acquire()

            def _worker() -> None:
                try:
                    model_service._run_stream(sanitized, token_q)
                finally:
                    # Release on the event-loop thread only once inference has ended.
                    # If the client disconnects, this generator is closed early, but
                    # the lock stays held until the model is idle again.
                    with suppress(RuntimeError):  # event loop already closed at shutdown
                        loop.call_soon_threadsafe(model_lock.release)

            try:
                threading.Thread(target=_worker, daemon=True).start()
            except BaseException:
                model_lock.release()
                raise
            while True:
                # Blocking Queue.get runs in a helper thread so the event loop stays responsive.
                chunk = await asyncio.to_thread(token_q.get)
                if chunk is None:
                    break
                yield chunk

        return StreamingResponse(
            _locked_stream(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )

    @app.post("/search/web")
    async def search_web(req: SearchRequest, request: Request):
        auth_err = _check_api_key(request, settings)
        if auth_err:
            return auth_err
        if not search_service.enabled:
            return JSONResponse({"error": "Web search is disabled"}, status_code=503)
        if not req.query.strip():
            # Whitespace-only queries pass min_length=1 but are not searchable: client error.
            return JSONResponse({"error": "Search query is required"}, status_code=400)
        try:
            results = await asyncio.to_thread(search_service.search, req.query, req.max_results, req.deep)
        except (RuntimeError, urllib.error.URLError, TimeoutError, ValueError) as exc:
            return JSONResponse({"error": f"Search failed: {exc}"}, status_code=502)
        except Exception as exc:
            return JSONResponse({"error": str(exc)}, status_code=500)
        return JSONResponse({"query": req.query.strip(), "deep": req.deep, "results": results})

    @app.get("/info")
    async def info():
        return JSONResponse(
            {
                "name": APP_NAME,
                "creator": CREATOR,
                "version": APP_VERSION,
                "status": "ready" if model_service.model_exists else "model-missing",
                "mode": f"{'GPU' if settings.n_gpu_layers > 0 else 'CPU'} (llama.cpp)",
                "gpu_layers": settings.n_gpu_layers,
                "context_window": settings.n_ctx,
                "model_loaded": model_service.model_loaded,
                "load_time_seconds": model_service.load_time_seconds,
                "busy": model_service.busy,
                "web_search_enabled": search_service.enabled,
                "endpoints": {
                    "generate": "POST /generate",
                    "research_generate": "POST /generate/research",
                    "stream": "POST /generate/stream",
                    "search": "POST /search/web",
                    "health": "GET /health",
                    "info": "GET /info",
                },
            }
        )

    return app
# Module-level ASGI application, created at import so an ASGI server can import
# it directly; main() below also serves it.
app = create_app()


def main() -> None:
    """Print a startup banner and serve ``app`` with uvicorn."""
    cfg = Settings()
    rule = "=" * 50
    banner = [
        f"\n{rule}",
        f"{APP_NAME.upper()} v{APP_VERSION}",
        f"Creator: {CREATOR}",
        rule,
        f"Model: {cfg.model_path}",
        f"GPU: {cfg.n_gpu_layers} layers",
        f"Ctx: {cfg.n_ctx} | Threads: {cfg.n_threads}",
        f"URL: http://localhost:{cfg.port}\n",
    ]
    print("\n".join(banner))
    uvicorn.run(app, host=cfg.host, port=cfg.port, log_level="warning")


if __name__ == "__main__":
    main()