neuralbroker committed on
Commit
6c79e9e
·
verified ·
1 Parent(s): bcc7db2

Update server.py (v2.1 production)

Browse files
Files changed (1) hide show
  1. server.py +398 -88
server.py CHANGED
@@ -13,36 +13,49 @@ import asyncio
13
  import json
14
  import logging
15
  import os
 
 
16
  import time
17
- from concurrent.futures import ThreadPoolExecutor
18
- from contextlib import asynccontextmanager
 
 
19
  from dataclasses import dataclass
 
20
  from pathlib import Path
21
- from threading import Lock
22
- from typing import Iterator
23
 
24
  import llama_cpp
25
  import uvicorn
26
  from fastapi import FastAPI, HTTPException, Request
27
  from fastapi.middleware.cors import CORSMiddleware
28
  from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
 
29
  from pydantic import BaseModel, Field
 
30
 
31
  APP_NAME = "BlitzKode"
32
  APP_VERSION = "2.0"
33
  CREATOR = "Sajad"
34
  ROOT_DIR = Path(__file__).resolve().parent
35
  DEFAULT_MODEL_PATH = ROOT_DIR / "blitzkode.gguf"
36
- DEFAULT_FRONTEND_PATH = ROOT_DIR / "frontend" / "index.html"
37
  DEFAULT_CONTEXT = 2048
38
  DEFAULT_MAX_PROMPT_LENGTH = 4000
39
  DEFAULT_MAX_TOKENS = 512
 
 
 
 
40
  STOP_TOKENS = ["<|im_end|>", "<|im_start|>user"]
41
 
42
  SYSTEM_PROMPT = (
43
  "<|im_start|>system\n"
44
  "You are BlitzKode, an AI coding assistant created by Sajad. "
45
  "You are an expert in Python, JavaScript, Java, C++, and other programming languages. "
 
 
 
46
  "Write clean, efficient, and well-documented code. Keep responses concise and practical.<|im_end|>"
47
  )
48
 
@@ -66,6 +79,18 @@ def _int_from_env(name: str, default: int) -> int:
66
  return default
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def _validate_prompt(prompt: str, max_length: int) -> tuple[str, JSONResponse | None]:
70
  prompt = prompt.strip()
71
  if not prompt:
@@ -81,8 +106,8 @@ def _validate_prompt(prompt: str, max_length: int) -> tuple[str, JSONResponse |
81
  @dataclass(slots=True)
82
  class Settings:
83
  root_dir: Path = ROOT_DIR
84
- model_path: Path = Path(os.getenv("BLITZKODE_MODEL_PATH", DEFAULT_MODEL_PATH))
85
- frontend_path: Path = Path(os.getenv("BLITZKODE_FRONTEND_PATH", DEFAULT_FRONTEND_PATH))
86
  host: str = os.getenv("BLITZKODE_HOST", "0.0.0.0")
87
  port: int = _int_from_env("BLITZKODE_PORT", 7860)
88
  n_gpu_layers: int = _int_from_env("BLITZKODE_GPU_LAYERS", 0)
@@ -91,19 +116,21 @@ class Settings:
91
  n_batch: int = _int_from_env("BLITZKODE_BATCH", 128)
92
  max_prompt_length: int = _int_from_env("BLITZKODE_MAX_PROMPT_LENGTH", DEFAULT_MAX_PROMPT_LENGTH)
93
  preload_model: bool = _bool_from_env("BLITZKODE_PRELOAD_MODEL", default=False)
94
- workers: int = _int_from_env("BLITZKODE_WORKERS", 2)
95
- cors_origins: str = os.getenv("BLITZKODE_CORS_ORIGINS", "*")
96
  api_key: str = os.getenv("BLITZKODE_API_KEY", "")
 
 
 
97
 
98
 
99
  class MessageItem(BaseModel):
100
- role: str
101
- content: str
102
 
103
 
104
  class GenerateRequest(BaseModel):
105
  prompt: str
106
- messages: list[MessageItem] = Field(default_factory=list)
107
  temperature: float = Field(default=0.5, ge=0.0, le=2.0)
108
  max_tokens: int = Field(default=256, ge=1, le=DEFAULT_MAX_TOKENS)
109
  top_p: float = Field(default=0.95, gt=0.0, le=1.0)
@@ -111,13 +138,127 @@ class GenerateRequest(BaseModel):
111
  repeat_penalty: float = Field(default=1.05, ge=0.8, le=2.0)
112
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  class ModelService:
115
  def __init__(self, settings: Settings):
116
  self.settings = settings
117
- self._llm = None
118
- self._lock = Lock()
119
  self._load_time_seconds: float | None = None
120
  self._last_error: str | None = None
 
121
 
122
  @property
123
  def model_loaded(self) -> bool:
@@ -135,11 +276,15 @@ class ModelService:
135
  def load_time_seconds(self) -> float | None:
136
  return self._load_time_seconds
137
 
 
 
 
 
138
  def load_model(self):
139
  if self._llm is not None:
140
  return self._llm
141
 
142
- with self._lock:
143
  if self._llm is not None:
144
  return self._llm
145
 
@@ -179,48 +324,78 @@ class ModelService:
179
  parts.append("<|im_start|>assistant\n")
180
  return "\n".join(parts)
181
 
182
- def _gen_params(self, req: GenerateRequest) -> dict:
183
- return dict(
184
- max_tokens=req.max_tokens,
185
- temperature=req.temperature,
186
- top_p=req.top_p,
187
- top_k=req.top_k,
188
- repeat_penalty=req.repeat_penalty,
189
- frequency_penalty=0.0,
190
- presence_penalty=0.0,
191
- stop=STOP_TOKENS,
 
 
 
 
 
 
 
 
192
  )
 
 
 
193
 
194
- def generate_once(self, req: GenerateRequest) -> dict[str, object]:
195
- llm = self.load_model()
196
- start = time.perf_counter()
197
-
198
- result = llm(self.build_prompt(req), **self._gen_params(req))
199
- response = result["choices"][0]["text"].strip()
200
- elapsed = time.perf_counter() - start
201
- logger.info("Generated %d chars in %.2fs", len(response), elapsed)
202
-
203
- return {"response": response, "creator": CREATOR, "model": APP_NAME, "version": APP_VERSION}
 
204
 
205
- def stream_tokens(self, req: GenerateRequest) -> Iterator[str]:
206
  llm = self.load_model()
207
- start = time.perf_counter()
208
- token_count = 0
 
 
 
 
 
 
 
 
209
 
 
 
210
  try:
211
- for token in llm(self.build_prompt(req), stream=True, **self._gen_params(req)):
 
 
 
 
 
212
  if not token.get("choices"):
213
  continue
214
  text = token["choices"][0].get("text", "")
215
  if text:
216
  token_count += 1
217
- yield f"data: {json.dumps({'token': text})}\n\n"
218
  elapsed = time.perf_counter() - start
219
  logger.info("Streamed %d tokens in %.2fs", token_count, elapsed)
220
- yield "data: [DONE]\n\n"
221
  except Exception as exc:
222
  logger.error("Stream error: %s", exc)
223
- yield f"data: {json.dumps({'error': str(exc)})}\n\n"
 
 
 
224
 
225
 
226
  def _check_api_key(request: Request, settings: Settings) -> JSONResponse | None:
@@ -228,59 +403,124 @@ def _check_api_key(request: Request, settings: Settings) -> JSONResponse | None:
228
  return None
229
  auth = request.headers.get("Authorization", "")
230
  token = auth[7:] if auth.startswith("Bearer ") else auth
231
- if token != settings.api_key:
 
 
 
 
232
  return JSONResponse({"error": "Unauthorized"}, status_code=401)
233
  return None
234
 
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  def create_app(settings: Settings | None = None) -> FastAPI:
237
  settings = settings or Settings()
238
  model_service = ModelService(settings)
239
- executor = ThreadPoolExecutor(max_workers=settings.workers)
 
240
 
241
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s", datefmt="%H:%M:%S")
242
 
243
  @asynccontextmanager
244
  async def lifespan(_: FastAPI):
245
  if settings.preload_model:
246
- try:
247
  await asyncio.to_thread(model_service.load_model)
248
- except Exception:
249
- pass
250
- try:
251
- yield
252
- finally:
253
- executor.shutdown(wait=False, cancel_futures=True)
254
 
255
  app = FastAPI(title=f"{APP_NAME} API", version=APP_VERSION, lifespan=lifespan)
256
  app.state.settings = settings
257
  app.state.model_service = model_service
258
- app.state.executor = executor
 
 
 
 
259
 
260
  cors_origins = [o.strip() for o in settings.cors_origins.split(",") if o.strip()]
261
- app.add_middleware(CORSMiddleware, allow_origins=cors_origins, allow_methods=["*"], allow_headers=["*"])
 
 
 
 
 
 
 
 
 
262
 
263
  @app.get("/")
264
  async def root():
265
  if not settings.frontend_path.exists():
266
- raise HTTPException(status_code=404, detail="Frontend file is missing.")
267
  return FileResponse(str(settings.frontend_path))
268
 
269
  @app.get("/health")
270
  async def health():
271
- status = "healthy"
272
- if not settings.frontend_path.exists() or not model_service.model_exists:
273
- status = "degraded"
274
- return JSONResponse({
275
- "status": status,
276
- "model_loaded": model_service.model_loaded,
277
- "model_path": str(settings.model_path),
278
- "model_exists": model_service.model_exists,
279
- "frontend_exists": settings.frontend_path.exists(),
280
- "version": APP_VERSION,
281
- "gpu_layers": settings.n_gpu_layers,
282
- "last_error": model_service.last_error,
283
- })
 
284
 
285
  @app.post("/generate")
286
  async def generate(req: GenerateRequest, request: Request):
@@ -292,12 +532,42 @@ def create_app(settings: Settings | None = None) -> FastAPI:
292
  if err:
293
  return err
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  try:
 
296
  sanitized = req.model_copy(update={"prompt": prompt})
297
- payload = await asyncio.get_running_loop().run_in_executor(executor, model_service.generate_once, sanitized)
 
 
 
298
  return JSONResponse(payload)
299
  except FileNotFoundError as exc:
300
  return JSONResponse({"error": str(exc)}, status_code=503)
 
 
301
  except Exception as exc:
302
  return JSONResponse({"error": str(exc)}, status_code=500)
303
 
@@ -315,31 +585,71 @@ def create_app(settings: Settings | None = None) -> FastAPI:
315
  return JSONResponse({"error": f"Model not found at {settings.model_path}"}, status_code=503)
316
 
317
  sanitized = req.model_copy(update={"prompt": prompt})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  return StreamingResponse(
319
- model_service.stream_tokens(sanitized),
320
  media_type="text/event-stream",
321
  headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
322
  )
323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  @app.get("/info")
325
  async def info():
326
- return JSONResponse({
327
- "name": APP_NAME,
328
- "creator": CREATOR,
329
- "version": APP_VERSION,
330
- "status": "ready" if model_service.model_exists else "model-missing",
331
- "mode": f"{'GPU' if settings.n_gpu_layers > 0 else 'CPU'} (llama.cpp)",
332
- "gpu_layers": settings.n_gpu_layers,
333
- "context_window": settings.n_ctx,
334
- "model_loaded": model_service.model_loaded,
335
- "load_time_seconds": model_service.load_time_seconds,
336
- "endpoints": {
337
- "generate": "POST /generate",
338
- "stream": "POST /generate/stream",
339
- "health": "GET /health",
340
- "info": "GET /info",
341
- },
342
- })
 
 
 
 
 
 
343
 
344
  return app
345
 
@@ -348,14 +658,14 @@ app = create_app()
348
 
349
 
350
  def main() -> None:
351
- s = app.state.settings
352
  print(f"\n{'=' * 50}")
353
  print(f"{APP_NAME.upper()} v{APP_VERSION}")
354
  print(f"Creator: {CREATOR}")
355
  print(f"{'=' * 50}")
356
  print(f"Model: {s.model_path}")
357
  print(f"GPU: {s.n_gpu_layers} layers")
358
- print(f"Ctx: {s.n_ctx} | Threads: {s.n_threads} | Workers: {s.workers}")
359
  print(f"URL: http://localhost:{s.port}\n")
360
 
361
  uvicorn.run(app, host=s.host, port=s.port, log_level="warning")
 
13
  import json
14
  import logging
15
  import os
16
+ import queue
17
+ import threading
18
  import time
19
+ import urllib.error
20
+ import urllib.parse
21
+ import urllib.request
22
+ from contextlib import asynccontextmanager, suppress
23
  from dataclasses import dataclass
24
+ from dataclasses import field as dataclass_field
25
  from pathlib import Path
26
+ from typing import Any, Literal, cast
 
27
 
28
  import llama_cpp
29
  import uvicorn
30
  from fastapi import FastAPI, HTTPException, Request
31
  from fastapi.middleware.cors import CORSMiddleware
32
  from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
33
+ from fastapi.staticfiles import StaticFiles
34
  from pydantic import BaseModel, Field
35
+ from starlette.middleware.base import BaseHTTPMiddleware
36
 
37
  APP_NAME = "BlitzKode"
38
  APP_VERSION = "2.0"
39
  CREATOR = "Sajad"
40
  ROOT_DIR = Path(__file__).resolve().parent
41
  DEFAULT_MODEL_PATH = ROOT_DIR / "blitzkode.gguf"
42
+ DEFAULT_FRONTEND_DIST_PATH = ROOT_DIR / "frontend" / "dist" / "index.html"
43
  DEFAULT_CONTEXT = 2048
44
  DEFAULT_MAX_PROMPT_LENGTH = 4000
45
  DEFAULT_MAX_TOKENS = 512
46
+ DEFAULT_RATE_LIMIT_MAX = 30
47
+ DEFAULT_MAX_SEARCH_RESULTS = 5
48
+ DEFAULT_SEARCH_TIMEOUT_SECONDS = 8
49
+ DEFAULT_MAX_MESSAGES = 20
50
  STOP_TOKENS = ["<|im_end|>", "<|im_start|>user"]
51
 
52
  SYSTEM_PROMPT = (
53
  "<|im_start|>system\n"
54
  "You are BlitzKode, an AI coding assistant created by Sajad. "
55
  "You are an expert in Python, JavaScript, Java, C++, and other programming languages. "
56
+ "For coding work, first understand the user's goal and constraints, then provide a short plan before code when useful. "
57
+ "Do not invent APIs, file contents, citations, or execution results. "
58
+ "If evidence is missing, say what is unknown and give a safe next step. "
59
  "Write clean, efficient, and well-documented code. Keep responses concise and practical.<|im_end|>"
60
  )
61
 
 
79
  return default
80
 
81
 
82
def _path_from_env(name: str, default: Path) -> Path:
    """Return the path stored in environment variable *name*, or *default*.

    An unset or empty variable falls back to *default*.
    """
    raw = os.getenv(name)
    if not raw:
        return default
    return Path(raw)
85
+
86
+
87
def _frontend_path_from_env() -> Path:
    """Resolve the frontend entry file.

    Uses ``BLITZKODE_FRONTEND_PATH`` when it is set to a non-empty value,
    otherwise the built ``DEFAULT_FRONTEND_DIST_PATH``.
    """
    override = os.getenv("BLITZKODE_FRONTEND_PATH")
    return Path(override) if override else DEFAULT_FRONTEND_DIST_PATH
92
+
93
+
94
  def _validate_prompt(prompt: str, max_length: int) -> tuple[str, JSONResponse | None]:
95
  prompt = prompt.strip()
96
  if not prompt:
 
106
  @dataclass(slots=True)
107
  class Settings:
108
  root_dir: Path = ROOT_DIR
109
+ model_path: Path = dataclass_field(default_factory=lambda: _path_from_env("BLITZKODE_MODEL_PATH", DEFAULT_MODEL_PATH))
110
+ frontend_path: Path = dataclass_field(default_factory=_frontend_path_from_env)
111
  host: str = os.getenv("BLITZKODE_HOST", "0.0.0.0")
112
  port: int = _int_from_env("BLITZKODE_PORT", 7860)
113
  n_gpu_layers: int = _int_from_env("BLITZKODE_GPU_LAYERS", 0)
 
116
  n_batch: int = _int_from_env("BLITZKODE_BATCH", 128)
117
  max_prompt_length: int = _int_from_env("BLITZKODE_MAX_PROMPT_LENGTH", DEFAULT_MAX_PROMPT_LENGTH)
118
  preload_model: bool = _bool_from_env("BLITZKODE_PRELOAD_MODEL", default=False)
119
+ cors_origins: str = os.getenv("BLITZKODE_CORS_ORIGINS", "http://localhost:7860")
 
120
  api_key: str = os.getenv("BLITZKODE_API_KEY", "")
121
+ web_search_enabled: bool = _bool_from_env("BLITZKODE_WEB_SEARCH", default=True)
122
+ search_timeout_seconds: int = _int_from_env("BLITZKODE_SEARCH_TIMEOUT", DEFAULT_SEARCH_TIMEOUT_SECONDS)
123
+ max_search_results: int = _int_from_env("BLITZKODE_MAX_SEARCH_RESULTS", DEFAULT_MAX_SEARCH_RESULTS)
124
 
125
 
126
  class MessageItem(BaseModel):
      # One prior chat turn supplied by the client alongside the prompt.
      # Only "user" and "assistant" roles are accepted, so a client cannot
      # submit its own "system" turn through this model.
      role: Literal["user", "assistant"]
      # Non-empty text, capped at DEFAULT_MAX_PROMPT_LENGTH characters.
      content: str = Field(min_length=1, max_length=DEFAULT_MAX_PROMPT_LENGTH)
129
 
130
 
131
  class GenerateRequest(BaseModel):
132
  prompt: str
133
+ messages: list[MessageItem] = Field(default_factory=list, max_length=DEFAULT_MAX_MESSAGES)
134
  temperature: float = Field(default=0.5, ge=0.0, le=2.0)
135
  max_tokens: int = Field(default=256, ge=1, le=DEFAULT_MAX_TOKENS)
136
  top_p: float = Field(default=0.95, gt=0.0, le=1.0)
 
138
  repeat_penalty: float = Field(default=1.05, ge=0.8, le=2.0)
139
 
140
 
141
class SearchRequest(BaseModel):
    # Request body for the standalone web-search endpoint.
    # Search text: 1 to 500 characters.
    query: str = Field(min_length=1, max_length=500)
    # Number of results requested, between 1 and 10.
    max_results: int = Field(default=DEFAULT_MAX_SEARCH_RESULTS, ge=1, le=10)
    # When true, the search service also issues extra query variants.
    deep: bool = False
145
+
146
+
147
class ResearchGenerateRequest(GenerateRequest):
    # Generation request extended with web-search options.
    # Optional explicit search text (max 500 chars); the research route
    # falls back to the prompt when this is missing or empty.
    search_query: str | None = Field(default=None, max_length=500)
    # Number of search results to fetch, between 1 and 10.
    search_results: int = Field(default=DEFAULT_MAX_SEARCH_RESULTS, ge=1, le=10)
    # Forwarded to the search service as its ``deep`` flag.
    deep_search: bool = False
151
+
152
+
153
+ @dataclass(slots=True)
154
+ class SearchResult:
155
+ title: str
156
+ url: str
157
+ snippet: str
158
+ source: str = "DuckDuckGo"
159
+
160
+ def as_dict(self) -> dict[str, str]:
161
+ return {
162
+ "title": self.title,
163
+ "url": self.url,
164
+ "snippet": self.snippet,
165
+ "source": self.source,
166
+ }
167
+
168
+
169
class WebSearchService:
    """Small client for the DuckDuckGo JSON API at ``api.duckduckgo.com``.

    Reads ``web_search_enabled``, ``search_timeout_seconds`` and
    ``max_search_results`` from the settings object it is given.
    """

    def __init__(self, settings: Settings):
        self.settings = settings

    @property
    def enabled(self) -> bool:
        """Whether web search is switched on in settings."""
        return self.settings.web_search_enabled

    def _query_variants(self, query: str, deep: bool) -> list[str]:
        """Return the queries to issue: the query itself, plus two refinements when *deep*."""
        query = " ".join(query.split())
        if not deep:
            return [query]
        return [
            query,
            f"{query} official documentation",
            f"{query} best practices",
        ]

    def _append_result(
        self, results: list[SearchResult], seen_urls: set[str], title: str, url: str, snippet: str, max_results: int
    ) -> None:
        """Append one normalised hit unless it has no URL, is a duplicate, or the list is full."""
        title = " ".join((title or "Untitled").split())[:200]
        url = (url or "").strip()
        snippet = " ".join((snippet or "").split())[:500]
        if not url or url in seen_urls or len(results) >= max_results:
            return
        seen_urls.add(url)
        results.append(SearchResult(title=title, url=url, snippet=snippet))

    def _collect_related_topics(self, topics: list[dict], results: list[SearchResult], seen_urls: set[str], max_results: int) -> None:
        """Walk ``RelatedTopics`` (including nested ``Topics`` groups) and append usable entries."""
        for topic in topics:
            if len(results) >= max_results:
                return
            if not isinstance(topic, dict):
                # Tolerate unexpected entries instead of failing the whole search.
                continue
            if "Topics" in topic:
                self._collect_related_topics(topic.get("Topics") or [], results, seen_urls, max_results)
                continue
            text = topic.get("Text", "")
            url = topic.get("FirstURL", "")
            if text and url:
                title = text.split(" - ", 1)[0]
                self._append_result(results, seen_urls, title, url, text, max_results)

    def _fetch(self, variant: str) -> dict:
        """Fetch and decode the JSON payload for one query.

        Raises:
            RuntimeError: the response body is not valid UTF-8 JSON or is not a JSON object.
            urllib.error.URLError / TimeoutError: propagated from ``urlopen``.
        """
        params = urllib.parse.urlencode(
            {
                "q": variant,
                "format": "json",
                "no_html": "1",
                "skip_disambig": "1",
            }
        )
        request = urllib.request.Request(
            f"https://api.duckduckgo.com/?{params}",
            headers={"User-Agent": f"{APP_NAME}/{APP_VERSION}"},
        )
        with urllib.request.urlopen(request, timeout=self.settings.search_timeout_seconds) as response:
            raw = response.read()
        try:
            # UnicodeDecodeError and json.JSONDecodeError are both ValueError subclasses.
            payload = json.loads(raw.decode("utf-8"))
        except ValueError as exc:
            raise RuntimeError("Search provider returned an unreadable response") from exc
        if not isinstance(payload, dict):
            raise RuntimeError("Search provider returned an unexpected payload")
        return payload

    def search(self, query: str, max_results: int = DEFAULT_MAX_SEARCH_RESULTS, deep: bool = False) -> list[dict[str, str]]:
        """Run a web search and return up to *max_results* hits as plain dicts.

        The effective limit is the smallest of *max_results*, the configured
        ``max_search_results`` (at least 1) and 10.

        Raises:
            RuntimeError: search is disabled, or the provider response is unusable.
            ValueError: *query* is empty after whitespace normalisation.
            urllib.error.URLError / TimeoutError: network failure with no results collected yet.
        """
        if not self.enabled:
            raise RuntimeError("Web search is disabled. Set BLITZKODE_WEB_SEARCH=true to enable it.")

        query = " ".join(query.split())
        if not query:
            raise ValueError("Search query is required")

        limit = min(max_results, max(1, self.settings.max_search_results), 10)
        results: list[SearchResult] = []
        seen_urls: set[str] = set()

        for variant in self._query_variants(query, deep):
            if len(results) >= limit:
                break
            try:
                payload = self._fetch(variant)
            except (RuntimeError, OSError) as exc:
                # URLError and TimeoutError are OSError subclasses. With nothing
                # collected there is nothing to salvage, so fail as before;
                # otherwise keep the hits gathered from earlier variants.
                if not results:
                    raise
                logger.warning("Search variant %r failed, keeping earlier results: %s", variant, exc)
                continue

            self._append_result(
                results,
                seen_urls,
                payload.get("Heading") or variant,
                payload.get("AbstractURL", ""),
                payload.get("AbstractText", ""),
                limit,
            )
            self._collect_related_topics(payload.get("RelatedTopics") or [], results, seen_urls, limit)

        return [result.as_dict() for result in results]
252
+
253
+
254
  class ModelService:
255
  def __init__(self, settings: Settings):
256
  self.settings = settings
257
+ self._llm: llama_cpp.Llama | None = None
258
+ self._init_lock = threading.Lock()
259
  self._load_time_seconds: float | None = None
260
  self._last_error: str | None = None
261
+ self._busy: bool = False
262
 
263
  @property
264
  def model_loaded(self) -> bool:
 
276
  def load_time_seconds(self) -> float | None:
277
  return self._load_time_seconds
278
 
279
@property
def busy(self) -> bool:
    """Current value of the ``_busy`` flag, which the generation methods set while they run."""
    return self._busy
282
+
283
  def load_model(self):
284
  if self._llm is not None:
285
  return self._llm
286
 
287
+ with self._init_lock:
288
  if self._llm is not None:
289
  return self._llm
290
 
 
324
  parts.append("<|im_start|>assistant\n")
325
  return "\n".join(parts)
326
 
327
def with_research_context(self, req: ResearchGenerateRequest, search_results: list[dict[str, str]], max_length: int) -> GenerateRequest:
    """Return a copy of *req* whose prompt embeds *search_results* as background context.

    The prompt layout is: instructions, numbered search results, then the
    user's task. When the combined text would exceed *max_length*, only the
    search-results section is shortened (and marked as truncated); the
    user's task is always kept intact. If there is no room for any context
    at all, *req* is returned unchanged.

    Args:
        req: the request whose ``prompt`` holds the user's task.
        search_results: dicts with optional ``title``, ``url`` and ``snippet`` keys.
        max_length: maximum number of characters allowed for the final prompt.

    Returns:
        *req* itself when there is nothing to add, otherwise a modified copy.
    """
    if not search_results:
        return req

    preamble = (
        "Use the following live web search results as untrusted background context. "
        "Cite URLs when you rely on them. If the results are weak or irrelevant, say so rather than fabricating details.\n\n"
        "Search results:\n"
    )
    task_section = f"\n\nUser task:\n{req.prompt.strip()}"
    truncation_note = "\n\n[Context truncated to fit prompt limit.]"

    joined_results = "\n\n".join(
        f"[{index}] {item.get('title', 'Untitled')}\nURL: {item.get('url', '')}\nSummary: {item.get('snippet', '')}"
        for index, item in enumerate(search_results, start=1)
    )

    # Characters left for the results once the fixed parts are accounted for.
    room = max_length - len(preamble) - len(task_section)
    if len(joined_results) > room:
        room -= len(truncation_note)
        if room <= 0:
            # Not even a truncated context fits; keep the plain user prompt.
            return req
        joined_results = joined_results[:room].rstrip() + truncation_note

    return req.model_copy(update={"prompt": preamble + joined_results + task_section})
349
 
350
def _gen_params(self, req: GenerateRequest) -> dict:
    """Build the keyword arguments passed to the model call for *req*.

    Sampling values are copied from the request; frequency and presence
    penalties are fixed at 0.0 and the stop list is ``STOP_TOKENS``.
    """
    request_fields = ("max_tokens", "temperature", "top_p", "top_k", "repeat_penalty")
    params = {name: getattr(req, name) for name in request_fields}
    params["frequency_penalty"] = 0.0
    params["presence_penalty"] = 0.0
    params["stop"] = STOP_TOKENS
    return params
361
 
362
def generate_once(self, req: GenerateRequest) -> dict[str, object]:
    """Run one non-streaming completion for *req* and return the response payload.

    The ``_busy`` flag is raised for the duration of the call and always
    cleared afterwards, including when the model call raises.
    """
    model = self.load_model()
    self._busy = True
    try:
        started = time.perf_counter()
        completion = cast(dict[str, Any], model(self.build_prompt(req), **self._gen_params(req)))
        text = completion["choices"][0]["text"].strip()
        duration = time.perf_counter() - started
        logger.info("Generated %d chars in %.2fs", len(text), duration)
        return {
            "response": text,
            "creator": CREATOR,
            "model": APP_NAME,
            "version": APP_VERSION,
        }
    finally:
        self._busy = False
374
 
375
def _run_stream(self, req: GenerateRequest, out_q: queue.Queue):
    """Run streaming inference and push server-sent-event strings into *out_q*.

    Intended to be the target of a worker thread. Every non-empty token is
    queued as a ``data: {...}`` event, followed by ``data: [DONE]`` on
    success or a ``data: {"error": ...}`` event on failure. A final ``None``
    is always queued so the consumer knows the stream has ended.
    """
    try:
        llm = self.load_model()
        self._busy = True
        start = time.perf_counter()
        token_count = 0
        stream = cast(Any, llm(self.build_prompt(req), stream=True, **self._gen_params(req)))
        for token in stream:
            choices = token.get("choices")
            if not choices:
                continue
            text = choices[0].get("text", "")
            if text:
                token_count += 1
                out_q.put(f"data: {json.dumps({'token': text})}\n\n")
        elapsed = time.perf_counter() - start
        logger.info("Streamed %d tokens in %.2fs", token_count, elapsed)
        out_q.put("data: [DONE]\n\n")
    except Exception as exc:
        # logger.exception records the traceback; logger.error dropped it.
        logger.exception("Stream error: %s", exc)
        out_q.put(f"data: {json.dumps({'error': str(exc)})}\n\n")
    finally:
        self._busy = False
        # Sentinel: tells the consumer no more chunks will arrive.
        out_q.put(None)
399
 
400
 
401
  def _check_api_key(request: Request, settings: Settings) -> JSONResponse | None:
 
403
  return None
404
  auth = request.headers.get("Authorization", "")
405
  token = auth[7:] if auth.startswith("Bearer ") else auth
406
+
407
+ # Timing-safe comparison (prevent timing attacks)
408
+ import hmac
409
+
410
+ if not hmac.compare_digest(token, settings.api_key):
411
  return JSONResponse({"error": "Unauthorized"}, status_code=401)
412
  return None
413
 
414
 
415
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window rate limiter keyed by client host, held in process memory.

    Each client may make at most ``max_requests`` requests per
    ``window_seconds``; further requests get a 429 response with a
    ``Retry-After`` header. Rejected requests are not counted.
    """

    def __init__(self, app, max_requests: int = DEFAULT_RATE_LIMIT_MAX, window_seconds: int = 60):
        super().__init__(app)
        self._max = max_requests
        self._window = window_seconds
        self._clients: dict[str, list[float]] = {}
        self._lock = threading.Lock()
        self._cleanup_done = 0

    def _prune_all(self, now: float) -> None:
        """Drop expired timestamps for every client and forget clients left with none.

        Must be called with ``self._lock`` held.
        """
        pruned: dict[str, list[float]] = {}
        for ip, stamps in self._clients.items():
            live = [t for t in stamps if now - t < self._window]
            if live:
                pruned[ip] = live
        self._clients = pruned

    async def dispatch(self, request: Request, call_next):
        client_ip = request.client.host if request.client else "unknown"
        now = time.monotonic()

        with self._lock:
            # Sweep the whole table periodically (every 1000 requests) so that
            # clients that stopped calling do not stay in memory.
            self._cleanup_done += 1
            if self._cleanup_done > 1000:
                self._cleanup_done = 0
                self._prune_all(now)

            timestamps = [t for t in self._clients.get(client_ip, ()) if now - t < self._window]
            limited = len(timestamps) >= self._max
            if not limited:
                timestamps.append(now)
                self._clients[client_ip] = timestamps

        # The lock is released before any response is built or awaited, so it
        # is never held across an ``await``.
        if limited:
            return JSONResponse(
                {"error": "Rate limit exceeded. Try again later."},
                status_code=429,
                headers={"Retry-After": str(self._window)},
            )
        return await call_next(request)
448
+
449
+
450
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared ``Content-Length`` exceeds a byte limit.

    NOTE(review): only the header is inspected; a request without a
    Content-Length header is passed through without a size check.
    """

    def __init__(self, app, max_bytes: int = 50_000):
        super().__init__(app)
        self._max = max_bytes

    async def dispatch(self, request: Request, call_next):
        declared = request.headers.get("content-length")
        if not declared:
            return await call_next(request)

        try:
            declared_size = int(declared)
        except ValueError:
            return JSONResponse({"error": "Invalid Content-Length header"}, status_code=400)

        if declared_size > self._max:
            return JSONResponse({"error": "Request body too large"}, status_code=413)
        return await call_next(request)
464
+
465
+
466
  def create_app(settings: Settings | None = None) -> FastAPI:
467
  settings = settings or Settings()
468
  model_service = ModelService(settings)
469
+ search_service = WebSearchService(settings)
470
+ model_lock = asyncio.Lock()
471
 
472
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s", datefmt="%H:%M:%S")
473
 
474
  @asynccontextmanager
475
  async def lifespan(_: FastAPI):
476
  if settings.preload_model:
477
+ with suppress(Exception):
478
  await asyncio.to_thread(model_service.load_model)
479
+ yield
 
 
 
 
 
480
 
481
  app = FastAPI(title=f"{APP_NAME} API", version=APP_VERSION, lifespan=lifespan)
482
  app.state.settings = settings
483
  app.state.model_service = model_service
484
+ app.state.search_service = search_service
485
+
486
+ frontend_assets_path = settings.frontend_path.parent / "assets"
487
+ if frontend_assets_path.exists():
488
+ app.mount("/assets", StaticFiles(directory=str(frontend_assets_path)), name="frontend-assets")
489
 
490
  cors_origins = [o.strip() for o in settings.cors_origins.split(",") if o.strip()]
491
+ app.add_middleware(
492
+ CORSMiddleware,
493
+ allow_origins=cors_origins,
494
+ allow_methods=["POST", "GET", "OPTIONS"],
495
+ allow_headers=["Content-Type", "Authorization"],
496
+ )
497
+
498
+ if _bool_from_env("BLITZKODE_RATE_LIMIT", default=True):
499
+ app.add_middleware(RateLimitMiddleware, max_requests=_int_from_env("BLITZKODE_RATE_LIMIT_MAX", DEFAULT_RATE_LIMIT_MAX))
500
+ app.add_middleware(RequestSizeLimitMiddleware, max_bytes=_int_from_env("BLITZKODE_MAX_REQUEST_BYTES", 50_000))
501
 
502
  @app.get("/")
503
  async def root():
504
  if not settings.frontend_path.exists():
505
+ raise HTTPException(status_code=404, detail="Frontend build is missing. Run `npm install` and `npm run build` in frontend/.")
506
  return FileResponse(str(settings.frontend_path))
507
 
508
  @app.get("/health")
509
  async def health():
510
+ status = "healthy" if model_service.model_exists else "degraded"
511
+ return JSONResponse(
512
+ {
513
+ "status": status,
514
+ "model_loaded": model_service.model_loaded,
515
+ "model_path": str(settings.model_path),
516
+ "model_exists": model_service.model_exists,
517
+ "frontend_exists": settings.frontend_path.exists(),
518
+ "version": APP_VERSION,
519
+ "gpu_layers": settings.n_gpu_layers,
520
+ "last_error": model_service.last_error,
521
+ "busy": model_service.busy,
522
+ }
523
+ )
524
 
525
  @app.post("/generate")
526
  async def generate(req: GenerateRequest, request: Request):
 
532
  if err:
533
  return err
534
 
535
+ async with model_lock:
536
+ try:
537
+ sanitized = req.model_copy(update={"prompt": prompt})
538
+ payload = await asyncio.to_thread(model_service.generate_once, sanitized)
539
+ return JSONResponse(payload)
540
+ except FileNotFoundError as exc:
541
+ return JSONResponse({"error": str(exc)}, status_code=503)
542
+ except Exception as exc:
543
+ return JSONResponse({"error": str(exc)}, status_code=500)
544
+
545
+ @app.post("/generate/research")
546
+ async def generate_research(req: ResearchGenerateRequest, request: Request):
547
+ auth_err = _check_api_key(request, settings)
548
+ if auth_err:
549
+ return auth_err
550
+
551
+ prompt, err = _validate_prompt(req.prompt, settings.max_prompt_length)
552
+ if err:
553
+ return err
554
+
555
+ if not search_service.enabled:
556
+ return JSONResponse({"error": "Web search is disabled"}, status_code=503)
557
+
558
+ search_query = (req.search_query or prompt).strip()
559
  try:
560
+ results = await asyncio.to_thread(search_service.search, search_query, req.search_results, req.deep_search)
561
  sanitized = req.model_copy(update={"prompt": prompt})
562
+ enriched = model_service.with_research_context(sanitized, results, settings.max_prompt_length)
563
+ async with model_lock:
564
+ payload = await asyncio.to_thread(model_service.generate_once, enriched)
565
+ payload["search_results"] = results
566
  return JSONResponse(payload)
567
  except FileNotFoundError as exc:
568
  return JSONResponse({"error": str(exc)}, status_code=503)
569
+ except (RuntimeError, urllib.error.URLError, TimeoutError) as exc:
570
+ return JSONResponse({"error": f"Search failed: {exc}"}, status_code=502)
571
  except Exception as exc:
572
  return JSONResponse({"error": str(exc)}, status_code=500)
573
 
 
585
  return JSONResponse({"error": f"Model not found at {settings.model_path}"}, status_code=503)
586
 
587
  sanitized = req.model_copy(update={"prompt": prompt})
588
+
589
+ async def _locked_stream():
590
+ async with model_lock:
591
+ token_q: queue.Queue = queue.Queue()
592
+ thread = threading.Thread(
593
+ target=model_service._run_stream,
594
+ args=(sanitized, token_q),
595
+ daemon=True,
596
+ )
597
+ thread.start()
598
+ # Use thread-safe queue.get() instead of deprecated get_running_loop()
599
+ while True:
600
+ chunk = await asyncio.to_thread(token_q.get)
601
+ if chunk is None:
602
+ break
603
+ yield chunk
604
+
605
  return StreamingResponse(
606
+ _locked_stream(),
607
  media_type="text/event-stream",
608
  headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
609
  )
610
 
611
+ @app.post("/search/web")
612
+ async def search_web(req: SearchRequest, request: Request):
613
+ auth_err = _check_api_key(request, settings)
614
+ if auth_err:
615
+ return auth_err
616
+
617
+ if not search_service.enabled:
618
+ return JSONResponse({"error": "Web search is disabled"}, status_code=503)
619
+
620
+ try:
621
+ results = await asyncio.to_thread(search_service.search, req.query, req.max_results, req.deep)
622
+ return JSONResponse({"query": req.query.strip(), "deep": req.deep, "results": results})
623
+ except (RuntimeError, urllib.error.URLError, TimeoutError) as exc:
624
+ return JSONResponse({"error": f"Search failed: {exc}"}, status_code=502)
625
+ except Exception as exc:
626
+ return JSONResponse({"error": str(exc)}, status_code=500)
627
+
628
  @app.get("/info")
629
  async def info():
630
+ return JSONResponse(
631
+ {
632
+ "name": APP_NAME,
633
+ "creator": CREATOR,
634
+ "version": APP_VERSION,
635
+ "status": "ready" if model_service.model_exists else "model-missing",
636
+ "mode": f"{'GPU' if settings.n_gpu_layers > 0 else 'CPU'} (llama.cpp)",
637
+ "gpu_layers": settings.n_gpu_layers,
638
+ "context_window": settings.n_ctx,
639
+ "model_loaded": model_service.model_loaded,
640
+ "load_time_seconds": model_service.load_time_seconds,
641
+ "busy": model_service.busy,
642
+ "web_search_enabled": search_service.enabled,
643
+ "endpoints": {
644
+ "generate": "POST /generate",
645
+ "research_generate": "POST /generate/research",
646
+ "stream": "POST /generate/stream",
647
+ "search": "POST /search/web",
648
+ "health": "GET /health",
649
+ "info": "GET /info",
650
+ },
651
+ }
652
+ )
653
 
654
  return app
655
 
 
658
 
659
 
660
  def main() -> None:
661
+ s = Settings()
662
  print(f"\n{'=' * 50}")
663
  print(f"{APP_NAME.upper()} v{APP_VERSION}")
664
  print(f"Creator: {CREATOR}")
665
  print(f"{'=' * 50}")
666
  print(f"Model: {s.model_path}")
667
  print(f"GPU: {s.n_gpu_layers} layers")
668
+ print(f"Ctx: {s.n_ctx} | Threads: {s.n_threads}")
669
  print(f"URL: http://localhost:{s.port}\n")
670
 
671
  uvicorn.run(app, host=s.host, port=s.port, log_level="warning")