Update main.py
Browse files
main.py
CHANGED
|
@@ -1,302 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import ast
|
| 3 |
import re
|
| 4 |
-
import
|
|
|
|
| 5 |
import pathlib
|
| 6 |
import asyncio
|
|
|
|
| 7 |
from concurrent.futures import ThreadPoolExecutor
|
| 8 |
from contextlib import asynccontextmanager
|
|
|
|
| 9 |
from typing import Annotated
|
| 10 |
|
| 11 |
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
| 12 |
-
os.environ["TOKENIZERS_PARALLELISM"]
|
| 13 |
-
os.environ["HF_HUB_VERBOSITY"]
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
import torch
|
| 16 |
import numpy as np
|
| 17 |
-
import
|
| 18 |
-
|
| 19 |
-
# ββ Compatibility patches ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
-
# jina-bert-v2 (trust_remote_code) was written against transformers 4.x.
|
| 21 |
-
# Transformers 5.x removed / broke three things the model relies on.
|
| 22 |
-
# All patches are no-ops when the symbol already exists.
|
| 23 |
-
#
|
| 24 |
-
# 1. find_pruneable_heads_and_indices β removed from pytorch_utils
|
| 25 |
-
# 2. PretrainedConfig.is_decoder etc β no longer set as instance defaults
|
| 26 |
-
# 3. PreTrainedModel.get_head_mask β removed from modeling_utils in T5
|
| 27 |
-
|
| 28 |
-
# ββ patch 1: pytorch_utils ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 29 |
-
import transformers.pytorch_utils as _pt_utils
|
| 30 |
-
if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
|
| 31 |
-
def _find_pruneable_heads_and_indices(
|
| 32 |
-
heads, n_heads: int, head_size: int, already_pruned_heads
|
| 33 |
-
):
|
| 34 |
-
mask = torch.ones(n_heads, head_size)
|
| 35 |
-
heads = set(heads) - already_pruned_heads
|
| 36 |
-
for head in heads:
|
| 37 |
-
head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
|
| 38 |
-
mask[head] = 0
|
| 39 |
-
mask = mask.view(-1).contiguous().eq(1)
|
| 40 |
-
index = torch.arange(len(mask))[mask].long()
|
| 41 |
-
return heads, index
|
| 42 |
-
_pt_utils.find_pruneable_heads_and_indices = _find_pruneable_heads_and_indices
|
| 43 |
-
|
| 44 |
-
# ββ patch 2: PretrainedConfig legacy defaults βββββββββββββββββββββββββββββββββ
|
| 45 |
-
import transformers.configuration_utils as _cfg_utils
|
| 46 |
-
_PC = _cfg_utils.PretrainedConfig
|
| 47 |
-
if not hasattr(_PC, "_jina_compat_patched"):
|
| 48 |
-
_LEGACY_CFG_DEFAULTS = {
|
| 49 |
-
"is_decoder": False,
|
| 50 |
-
"add_cross_attention": False,
|
| 51 |
-
"cross_attention_hidden_size": None,
|
| 52 |
-
"use_cache": True,
|
| 53 |
-
}
|
| 54 |
-
def _pc_getattr(self, key: str):
|
| 55 |
-
if key in _LEGACY_CFG_DEFAULTS:
|
| 56 |
-
return _LEGACY_CFG_DEFAULTS[key]
|
| 57 |
-
raise AttributeError(
|
| 58 |
-
f"'{type(self).__name__}' object has no attribute '{key}'"
|
| 59 |
-
)
|
| 60 |
-
_PC.__getattr__ = _pc_getattr
|
| 61 |
-
_PC._jina_compat_patched = True
|
| 62 |
-
|
| 63 |
-
# ββ patch 3: PreTrainedModel.get_head_mask ββββββββββββββββββββββββββββββββββββ
|
| 64 |
-
import transformers.modeling_utils as _mod_utils
|
| 65 |
-
_PTM = _mod_utils.PreTrainedModel
|
| 66 |
-
if not hasattr(_PTM, "get_head_mask"):
|
| 67 |
-
def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
|
| 68 |
-
if head_mask.dim() == 1:
|
| 69 |
-
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
| 70 |
-
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
|
| 71 |
-
elif head_mask.dim() == 2:
|
| 72 |
-
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
| 73 |
-
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
|
| 74 |
-
head_mask = head_mask.to(dtype=self.dtype)
|
| 75 |
-
return head_mask
|
| 76 |
-
|
| 77 |
-
def _get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
|
| 78 |
-
if head_mask is not None:
|
| 79 |
-
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
|
| 80 |
-
if is_attention_chunked:
|
| 81 |
-
head_mask = head_mask.unsqueeze(-1)
|
| 82 |
-
else:
|
| 83 |
-
head_mask = [None] * num_hidden_layers
|
| 84 |
-
return head_mask
|
| 85 |
-
|
| 86 |
-
if not hasattr(_PTM, "_convert_head_mask_to_5d"):
|
| 87 |
-
_PTM._convert_head_mask_to_5d = _convert_head_mask_to_5d
|
| 88 |
-
_PTM.get_head_mask = _get_head_mask
|
| 89 |
-
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 90 |
-
|
| 91 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
| 92 |
from fastapi.middleware.cors import CORSMiddleware
|
| 93 |
from pydantic import BaseModel, Field
|
| 94 |
from sentence_transformers import SentenceTransformer
|
| 95 |
|
| 96 |
|
| 97 |
-
# βββββββββββββββββββββββββββ Constants
|
| 98 |
-
DIM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
|
|
|
| 100 |
def _resolve_store_dir() -> pathlib.Path:
|
| 101 |
-
""
|
| 102 |
-
Try /data/indexes (HF Spaces persistent volume).
|
| 103 |
-
Fall back to ~/.cache/code-search/indexes if /data is not writable
|
| 104 |
-
(local dev, or volume not yet mounted with correct permissions).
|
| 105 |
-
"""
|
| 106 |
-
primary = pathlib.Path("/data/indexes")
|
| 107 |
try:
|
| 108 |
primary.mkdir(parents=True, exist_ok=True)
|
| 109 |
probe = primary / ".write_probe"
|
| 110 |
-
probe.touch()
|
| 111 |
-
probe.unlink()
|
| 112 |
return primary
|
| 113 |
except OSError:
|
| 114 |
-
fallback = pathlib.Path.home() / ".cache" / "code-search" / "
|
| 115 |
fallback.mkdir(parents=True, exist_ok=True)
|
| 116 |
-
print(f"Warning: /data/
|
| 117 |
return fallback
|
| 118 |
|
| 119 |
STORE_DIR = _resolve_store_dir()
|
| 120 |
|
| 121 |
-
LANGUAGE_MAP = {
|
| 122 |
-
".py": "python",
|
| 123 |
-
".js": "javascript",
|
| 124 |
-
".ts": "typescript",
|
| 125 |
-
".tsx": "typescript",
|
| 126 |
-
".jsx": "javascript",
|
| 127 |
-
".go": "go",
|
| 128 |
-
".rs": "rust",
|
| 129 |
-
".java": "java",
|
| 130 |
-
".cpp": "cpp",
|
| 131 |
-
".c": "c",
|
| 132 |
-
".cs": "csharp",
|
| 133 |
-
".rb": "ruby",
|
| 134 |
-
".php": "php",
|
| 135 |
-
".md": "markdown",
|
| 136 |
-
".txt": "text",
|
| 137 |
-
}
|
| 138 |
-
|
| 139 |
|
| 140 |
-
# βββββββββββββββββββββββββββ
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
@asynccontextmanager
|
| 149 |
async def lifespan(app: FastAPI):
|
| 150 |
-
print("Loading jina-embeddings-v2-base-codeβ¦")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
model = SentenceTransformer(
|
| 152 |
-
"jinaai/jina-embeddings-v2-base-code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
)
|
| 154 |
model.max_seq_length = 8192
|
| 155 |
-
# Cast to float16 β cuts model VRAM/RAM from ~550 MB to ~275 MB.
|
| 156 |
-
# SentenceTransformer wraps a nn.Module; half() applies recursively.
|
| 157 |
-
model.half()
|
| 158 |
-
model.eval()
|
| 159 |
models["model"] = model
|
| 160 |
-
print("Model ready.")
|
| 161 |
-
|
| 162 |
-
# Restore persisted indexes from /data
|
| 163 |
-
if STORE_DIR.exists():
|
| 164 |
-
for faiss_path in STORE_DIR.glob("*.faiss"):
|
| 165 |
-
doc_id = faiss_path.stem
|
| 166 |
-
meta_path = STORE_DIR / f"{doc_id}.meta.pkl"
|
| 167 |
-
if not meta_path.exists():
|
| 168 |
-
continue
|
| 169 |
-
try:
|
| 170 |
-
index = faiss.read_index(str(faiss_path))
|
| 171 |
-
with open(meta_path, "rb") as f:
|
| 172 |
-
meta = pickle.load(f)
|
| 173 |
-
store[doc_id] = {"chunks": meta["chunks"], "index": index}
|
| 174 |
-
print(f"Restored index: {doc_id} ({index.ntotal} vectors)")
|
| 175 |
-
except Exception as e:
|
| 176 |
-
print(f"Warning: could not restore {doc_id}: {e}")
|
| 177 |
-
|
| 178 |
yield
|
| 179 |
models.clear()
|
| 180 |
|
| 181 |
|
| 182 |
-
# βββββββββββββββββββββββββββ App
|
| 183 |
-
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_MB", "50")) * 1024 * 1024 # default 50 MB
|
| 184 |
-
|
| 185 |
app = FastAPI(
|
| 186 |
title="Code Search API",
|
| 187 |
-
description=
|
| 188 |
-
|
| 189 |
-
"jinaai/jina-embeddings-v2-base-code + FAISS ANN search."
|
| 190 |
-
),
|
| 191 |
-
version="2.0.0",
|
| 192 |
lifespan=lifespan,
|
| 193 |
)
|
| 194 |
-
|
| 195 |
app.add_middleware(
|
| 196 |
CORSMiddleware,
|
| 197 |
-
allow_origins=["*"],
|
| 198 |
-
allow_methods=["*"],
|
| 199 |
-
allow_headers=["*"],
|
| 200 |
)
|
| 201 |
|
| 202 |
|
| 203 |
-
# βββββββββββββββββββββββββββ
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
def encode(texts: list[str]) -> np.ndarray:
|
| 207 |
"""
|
| 208 |
-
Synchronous encode
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
intermediate tensors immediately instead of waiting for GC.
|
| 213 |
"""
|
| 214 |
-
|
| 215 |
-
all_embeddings = []
|
| 216 |
for i in range(0, len(texts), ENCODE_BATCH_SIZE):
|
| 217 |
batch = texts[i : i + ENCODE_BATCH_SIZE]
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
all_embeddings.append(np.array(embs, dtype=np.float32))
|
| 226 |
-
# Free activations between batches
|
| 227 |
gc.collect()
|
| 228 |
-
|
| 229 |
-
torch.cuda.empty_cache()
|
| 230 |
-
return np.vstack(all_embeddings)
|
| 231 |
|
| 232 |
|
| 233 |
-
async def
|
| 234 |
-
"""Non-blocking wrapper β frees the event loop during model inference."""
|
| 235 |
loop = asyncio.get_event_loop()
|
| 236 |
-
return await loop.run_in_executor(_executor,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
|
|
|
|
|
|
| 238 |
|
| 239 |
-
|
| 240 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
"""
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
| 247 |
"""
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
|
|
|
|
|
|
| 262 |
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
for i in range(len(indices[0]))
|
| 275 |
-
if indices[0][i] >= 0 # FAISS returns -1 for empty slots
|
| 276 |
-
]
|
| 277 |
|
|
|
|
|
|
|
| 278 |
|
| 279 |
-
# βββββββββββββββββββββββββββ Persistence helpers βββββββββββββββββββββββββββββ
|
| 280 |
-
def persist_index(doc_id: str, chunks: list[str], index: faiss.Index) -> None:
    """Persist a document's FAISS index and chunk metadata under STORE_DIR.

    Writes two files keyed by doc_id: '<doc_id>.faiss' (the serialized FAISS
    index) and '<doc_id>.meta.pkl' (a pickle with the chunk texts and the
    doc_id) — the same pair the startup restore loop reads back.
    """
    STORE_DIR.mkdir(parents=True, exist_ok=True)
    faiss.write_index(index, str(STORE_DIR / f"{doc_id}.faiss"))
    with open(STORE_DIR / f"{doc_id}.meta.pkl", "wb") as f:
        pickle.dump({"chunks": chunks, "doc_id": doc_id}, f)
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
|
| 292 |
-
# βββββββββββββββββββββββββββ Chunking
|
| 293 |
def detect_language(filename: str) -> str:
|
| 294 |
-
|
| 295 |
-
return LANGUAGE_MAP.get(ext, "text")
|
| 296 |
|
| 297 |
|
| 298 |
def chunk_text(text: str, chunk_size: int = 3, overlap: int = 1) -> list[str]:
|
| 299 |
-
"""Sentence-window chunker for prose / markdown."""
|
| 300 |
sentences = re.split(r'(?<=[.!?])\s+', text.strip())
|
| 301 |
sentences = [s.strip() for s in sentences if s.strip()]
|
| 302 |
chunks, i = [], 0
|
|
@@ -307,7 +347,6 @@ def chunk_text(text: str, chunk_size: int = 3, overlap: int = 1) -> list[str]:
|
|
| 307 |
|
| 308 |
|
| 309 |
def chunk_fallback(source: str, max_lines: int = 40, overlap: int = 5) -> list[str]:
|
| 310 |
-
"""Fixed line-window chunking with overlap β last resort."""
|
| 311 |
lines = source.splitlines()
|
| 312 |
chunks = []
|
| 313 |
i = 0
|
|
@@ -318,16 +357,13 @@ def chunk_fallback(source: str, max_lines: int = 40, overlap: int = 5) -> list[s
|
|
| 318 |
|
| 319 |
|
| 320 |
def chunk_python(source: str, filepath: str = "") -> list[str]:
|
| 321 |
-
"""AST-based chunker β extracts functions and classes."""
|
| 322 |
try:
|
| 323 |
tree = ast.parse(source)
|
| 324 |
lines = source.splitlines()
|
| 325 |
chunks = []
|
| 326 |
for node in ast.walk(tree):
|
| 327 |
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
| 328 |
-
|
| 329 |
-
end = node.end_lineno
|
| 330 |
-
snippet = "\n".join(lines[start:end])
|
| 331 |
prefix = f"# {filepath}\n" if filepath else ""
|
| 332 |
chunks.append(f"{prefix}{snippet}")
|
| 333 |
return chunks if chunks else chunk_fallback(source)
|
|
@@ -336,14 +372,9 @@ def chunk_python(source: str, filepath: str = "") -> list[str]:
|
|
| 336 |
|
| 337 |
|
| 338 |
def chunk_generic(source: str, filepath: str = "") -> list[str]:
|
| 339 |
-
"""
|
| 340 |
-
Regex chunker for JS, TS, Go, Rust, Java, C++, etc.
|
| 341 |
-
Splits on function / class declaration boundaries.
|
| 342 |
-
"""
|
| 343 |
pattern = re.compile(
|
| 344 |
r'(?:^|\n)(?='
|
| 345 |
-
r'(?:export\s+)?'
|
| 346 |
-
r'(?:async\s+)?'
|
| 347 |
r'(?:function|class|const\s+\w+\s*=\s*(?:async\s+)?(?:\(|function)|'
|
| 348 |
r'(?:public|private|protected|static|\s)*(?:fn|func|def)\s+\w+)'
|
| 349 |
r')',
|
|
@@ -356,7 +387,6 @@ def chunk_generic(source: str, filepath: str = "") -> list[str]:
|
|
| 356 |
|
| 357 |
|
| 358 |
def chunk_code(source: str, filename: str = "") -> list[str]:
|
| 359 |
-
"""Master dispatcher β routes to the best chunker for the file type."""
|
| 360 |
lang = detect_language(filename)
|
| 361 |
if lang == "python":
|
| 362 |
return chunk_python(source, filepath=filename)
|
|
@@ -366,58 +396,36 @@ def chunk_code(source: str, filename: str = "") -> list[str]:
|
|
| 366 |
return chunk_generic(source, filepath=filename)
|
| 367 |
|
| 368 |
|
| 369 |
-
# βββββββββββββββββββββββββββ Schemas
|
| 370 |
class IndexResponse(BaseModel):
    """Response body for POST /index."""
    doc_id: str  # identifier to pass as SearchRequest.doc_id in /search
    chunks_indexed: int  # number of chunks embedded for this document
    message: str  # human-readable status summary
|
| 374 |
-
|
| 375 |
|
| 376 |
class SearchRequest(BaseModel):
|
| 377 |
-
doc_id: str = Field(
|
| 378 |
-
query: str = Field(..., description="Natural language or code query")
|
| 379 |
-
top_k: int = Field(5, ge=1, le=20)
|
| 380 |
-
|
| 381 |
|
| 382 |
class SearchResult(BaseModel):
|
| 383 |
-
rank:
|
| 384 |
-
score: float
|
| 385 |
-
text: str
|
| 386 |
-
|
| 387 |
|
| 388 |
class SearchResponse(BaseModel):
|
| 389 |
-
doc_id:
|
| 390 |
-
query: str
|
| 391 |
-
results: list[SearchResult]
|
| 392 |
-
|
| 393 |
|
| 394 |
class EmbedRequest(BaseModel):
|
| 395 |
-
texts: list[str] = Field(
|
| 396 |
-
|
| 397 |
|
| 398 |
class EmbedResponse(BaseModel):
    """Response body for POST /embed."""
    embeddings: list[list[float]]  # one embedding vector per input text
    dimensions: int  # embedding width (second axis of the embedding array)
|
| 401 |
-
|
| 402 |
|
| 403 |
class FileEntry(BaseModel):
    """One file inside a BatchIndexRequest."""
    filename: str  # used to select the language-aware chunker
    content: str  # raw file content as string
|
| 406 |
-
|
| 407 |
|
| 408 |
class BatchIndexRequest(BaseModel):
|
| 409 |
-
doc_id:
|
| 410 |
-
files: list[FileEntry]
|
| 411 |
-
replace: bool = True # if True, replaces existing index for this doc_id
|
| 412 |
-
|
| 413 |
|
| 414 |
class BatchIndexResponse(BaseModel):
|
| 415 |
-
doc_id:
|
| 416 |
-
files_indexed: int
|
| 417 |
-
chunks_indexed: int
|
| 418 |
|
| 419 |
|
| 420 |
-
# βββββββββββββββββββββββββββ Routes
|
| 421 |
@app.get("/", tags=["health"])
def root():
    """Root liveness endpoint: report status and point at the API docs."""
    payload = {"status": "ok", "docs": "/docs"}
    return payload
|
|
@@ -425,42 +433,34 @@ def root():
|
|
| 425 |
|
| 426 |
@app.get("/health", tags=["health"])
|
| 427 |
def health():
|
| 428 |
-
return {"status": "ok", "models_loaded": bool(models)
|
|
|
|
| 429 |
|
| 430 |
|
| 431 |
@app.post("/index", response_model=IndexResponse, tags=["search"])
|
| 432 |
async def index_document(
|
| 433 |
-
file:
|
| 434 |
-
doc_id:
|
| 435 |
):
|
| 436 |
-
"""
|
| 437 |
-
Upload a source file and embed it with code-aware chunking.
|
| 438 |
-
Returns the doc_id to use in /search.
|
| 439 |
-
"""
|
| 440 |
if not models:
|
| 441 |
-
raise HTTPException(503, "Model not loaded yet
|
| 442 |
|
| 443 |
-
content
|
| 444 |
if len(content) > MAX_UPLOAD_BYTES:
|
| 445 |
-
raise HTTPException(
|
| 446 |
-
|
| 447 |
-
f"
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
source = content.decode("utf-8", errors="replace")
|
| 452 |
-
filename = file.filename or "unknown"
|
| 453 |
resolved_id = doc_id.strip() or os.path.splitext(filename)[0]
|
| 454 |
|
| 455 |
chunks = chunk_code(source, filename=filename)
|
| 456 |
if not chunks:
|
| 457 |
-
raise HTTPException(400, "Document produced no chunks.
|
| 458 |
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
store[resolved_id] = {"chunks": chunks, "index": index}
|
| 462 |
-
persist_index(resolved_id, chunks, index)
|
| 463 |
-
import gc; gc.collect() # free encoding intermediates before responding
|
| 464 |
|
| 465 |
return IndexResponse(
|
| 466 |
doc_id=resolved_id,
|
|
@@ -471,36 +471,23 @@ async def index_document(
|
|
| 471 |
|
| 472 |
@app.post("/index/batch", response_model=BatchIndexResponse, tags=["search"])
|
| 473 |
async def index_batch(req: BatchIndexRequest):
|
| 474 |
-
"""
|
| 475 |
-
Index an entire codebase in one HTTP call.
|
| 476 |
-
Ideal for IDE integrations β send all files, get one searchable doc_id back.
|
| 477 |
-
"""
|
| 478 |
if not models:
|
| 479 |
raise HTTPException(503, "Model not loaded yet.")
|
| 480 |
|
| 481 |
-
|
| 482 |
-
del store[req.doc_id]
|
| 483 |
-
delete_persisted(req.doc_id)
|
| 484 |
-
|
| 485 |
all_chunks: list[str] = []
|
| 486 |
for entry in req.files:
|
| 487 |
all_chunks.extend(chunk_code(entry.content, filename=entry.filename))
|
| 488 |
|
| 489 |
if not all_chunks:
|
| 490 |
raise HTTPException(400, "No chunks produced from provided files.")
|
| 491 |
-
|
| 492 |
-
MAX_CHUNKS = int(os.getenv("MAX_CHUNKS", "10000")) # ~3 GB RAM at 10k chunks; raise carefully
|
| 493 |
if len(all_chunks) > MAX_CHUNKS:
|
| 494 |
-
raise HTTPException(
|
| 495 |
-
|
| 496 |
-
f"Too many chunks ({len(all_chunks):,}). Max: {MAX_CHUNKS:,}. "
|
| 497 |
-
"Split your project into smaller doc_id groups.",
|
| 498 |
-
)
|
| 499 |
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
persist_index(req.doc_id, all_chunks, index)
|
| 504 |
|
| 505 |
return BatchIndexResponse(
|
| 506 |
doc_id=req.doc_id,
|
|
@@ -511,11 +498,13 @@ async def index_batch(req: BatchIndexRequest):
|
|
| 511 |
|
| 512 |
@app.post("/search", response_model=SearchResponse, tags=["search"])
|
| 513 |
async def search_document(req: SearchRequest):
|
| 514 |
-
|
| 515 |
-
if req.doc_id not in store:
|
| 516 |
raise HTTPException(404, f"doc_id '{req.doc_id}' not found. Call /index first.")
|
| 517 |
|
| 518 |
-
|
|
|
|
|
|
|
|
|
|
| 519 |
return SearchResponse(
|
| 520 |
doc_id=req.doc_id,
|
| 521 |
query=req.query,
|
|
@@ -525,35 +514,33 @@ async def search_document(req: SearchRequest):
|
|
| 525 |
|
| 526 |
@app.post("/embed", response_model=EmbedResponse, tags=["embeddings"])
|
| 527 |
async def embed_texts(req: EmbedRequest):
|
| 528 |
-
"""Embed arbitrary texts. Returns raw float embeddings."""
|
| 529 |
if not models:
|
| 530 |
raise HTTPException(503, "Model not loaded yet.")
|
| 531 |
if len(req.texts) > 64:
|
| 532 |
raise HTTPException(400, "Maximum 64 texts per request.")
|
| 533 |
|
| 534 |
-
embs = await
|
| 535 |
-
return EmbedResponse(
|
| 536 |
-
embeddings=embs.tolist(),
|
| 537 |
-
dimensions=embs.shape[1],
|
| 538 |
-
)
|
| 539 |
|
| 540 |
|
| 541 |
@app.get("/documents", tags=["search"])
|
| 542 |
def list_documents():
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
|
|
|
|
|
|
|
|
|
| 550 |
|
| 551 |
|
| 552 |
@app.delete("/documents/{doc_id}", tags=["search"])
|
| 553 |
def delete_document(doc_id: str):
|
| 554 |
-
|
| 555 |
-
if doc_id not in store:
|
| 556 |
raise HTTPException(404, f"doc_id '{doc_id}' not found.")
|
| 557 |
-
|
| 558 |
-
|
| 559 |
return {"deleted": doc_id}
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Code Search API β v3.0
|
| 3 |
+
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 4 |
+
Key architecture changes from v2:
|
| 5 |
+
|
| 6 |
+
β’ Model : ONNX fp16 via sentence-transformers backend="onnx"
|
| 7 |
+
β ONNX Runtime replaces PyTorch for every forward pass.
|
| 8 |
+
β Pre-built onnx/model_fp16.onnx from the HF repo is used
|
| 9 |
+
directly β no export step, no trust_remote_code issues.
|
| 10 |
+
β All three transformers-compatibility patches removed.
|
| 11 |
+
|
| 12 |
+
β’ Storage : LanceDB (disk-backed, columnar, mmap)
|
| 13 |
+
β Vectors live on disk, not in Python RAM.
|
| 14 |
+
β Chunks stored alongside vectors in the same table β
|
| 15 |
+
no separate pickle files.
|
| 16 |
+
β FAISS removed entirely.
|
| 17 |
+
|
| 18 |
+
β’ Indexing: Streaming pipeline
|
| 19 |
+
β Chunks are produced, encoded in micro-batches, and written
|
| 20 |
+
to LanceDB immediately. The full embeddings array is never
|
| 21 |
+
held in RAM.
|
| 22 |
+
|
| 23 |
+
β’ Retrieval: On-demand table loading + LRU cache
|
| 24 |
+
β Tables are opened from disk per request.
|
| 25 |
+
β An LRU cache (default: 5 tables, TTL: 10 min) keeps
|
| 26 |
+
recently used handles warm without pinning everything.
|
| 27 |
+
|
| 28 |
+
β’ RAM budget (approximate, CPU-only HF Space):
|
| 29 |
+
Model weights (fp16 ONNX) ~275 MB
|
| 30 |
+
Encoding peak (batch=8) ~100 MB transient
|
| 31 |
+
LanceDB per query ~10-50 MB transient
|
| 32 |
+
Python overhead ~150 MB
|
| 33 |
+
βββββββββββββββββββββββββββββββββββββ
|
| 34 |
+
Total steady-state ~425 MB (vs ~16 GB before)
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
import os
|
| 38 |
import ast
|
| 39 |
import re
|
| 40 |
+
import gc
|
| 41 |
+
import time
|
| 42 |
import pathlib
|
| 43 |
import asyncio
|
| 44 |
+
from collections import OrderedDict
|
| 45 |
from concurrent.futures import ThreadPoolExecutor
|
| 46 |
from contextlib import asynccontextmanager
|
| 47 |
+
from threading import Lock
|
| 48 |
from typing import Annotated
|
| 49 |
|
| 50 |
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
| 51 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 52 |
+
os.environ["HF_HUB_VERBOSITY"] = "error"
|
| 53 |
+
# Tell ONNX Runtime to use a modest thread count so it doesn't spike RSS
|
| 54 |
+
os.environ.setdefault("OMP_NUM_THREADS", "2")
|
| 55 |
|
|
|
|
| 56 |
import numpy as np
|
| 57 |
+
import lancedb
|
| 58 |
+
import pyarrow as pa
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
| 60 |
from fastapi.middleware.cors import CORSMiddleware
|
| 61 |
from pydantic import BaseModel, Field
|
| 62 |
from sentence_transformers import SentenceTransformer
|
| 63 |
|
| 64 |
|
| 65 |
+
# βββββββββββββββββββββββββββ Constants ββββββββββββββββββββββββββββββββββββββββ
|
| 66 |
+
DIM = 768
|
| 67 |
+
ENCODE_BATCH_SIZE = int(os.getenv("ENCODE_BATCH_SIZE", "8"))
|
| 68 |
+
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_MB", "50")) * 1024 * 1024
|
| 69 |
+
MAX_CHUNKS = int(os.getenv("MAX_CHUNKS", "10000"))
|
| 70 |
+
LRU_MAXSIZE = int(os.getenv("LRU_TABLE_CACHE", "5"))
|
| 71 |
+
LRU_TTL = int(os.getenv("LRU_TTL_SECONDS", "600")) # 10 min
|
| 72 |
+
|
| 73 |
+
LANGUAGE_MAP = {
|
| 74 |
+
".py": "python", ".js": "javascript", ".ts": "typescript",
|
| 75 |
+
".tsx": "typescript", ".jsx": "javascript", ".go": "go",
|
| 76 |
+
".rs": "rust", ".java": "java", ".cpp": "cpp",
|
| 77 |
+
".c": "c", ".cs": "csharp", ".rb": "ruby",
|
| 78 |
+
".php": "php", ".md": "markdown", ".txt": "text",
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
# LanceDB schema β one row per chunk
|
| 82 |
+
_SCHEMA = pa.schema([
|
| 83 |
+
pa.field("chunk_id", pa.int32()),
|
| 84 |
+
pa.field("text", pa.large_utf8()),
|
| 85 |
+
pa.field("vector", pa.list_(pa.float32(), DIM)),
|
| 86 |
+
])
|
| 87 |
+
|
| 88 |
|
| 89 |
+
# βββββββββββββββββββββββββββ Storage directory ββββββββββββββββββββββββββββββββ
|
| 90 |
def _resolve_store_dir() -> pathlib.Path:
|
| 91 |
+
primary = pathlib.Path("/data/lancedb")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
try:
|
| 93 |
primary.mkdir(parents=True, exist_ok=True)
|
| 94 |
probe = primary / ".write_probe"
|
| 95 |
+
probe.touch(); probe.unlink()
|
|
|
|
| 96 |
return primary
|
| 97 |
except OSError:
|
| 98 |
+
fallback = pathlib.Path.home() / ".cache" / "code-search" / "lancedb"
|
| 99 |
fallback.mkdir(parents=True, exist_ok=True)
|
| 100 |
+
print(f"Warning: /data/lancedb not writable β using fallback: {fallback}")
|
| 101 |
return fallback
|
| 102 |
|
| 103 |
STORE_DIR = _resolve_store_dir()
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
+
# βββββββββββββββββββββββββββ LRU table-handle cache βββββββββββββββββββββββββββ
|
| 107 |
+
class _LRUTableCache:
|
| 108 |
+
"""
|
| 109 |
+
Keeps up to `maxsize` LanceDB table handles open in memory.
|
| 110 |
+
Entries expire after `ttl` seconds of inactivity.
|
| 111 |
+
Opening a LanceDB table is cheap (no vectors loaded into RAM), so
|
| 112 |
+
this is primarily about limiting open file-descriptor churn.
|
| 113 |
+
"""
|
| 114 |
+
def __init__(self, maxsize: int = 5, ttl: int = 600):
|
| 115 |
+
self._cache: OrderedDict = OrderedDict()
|
| 116 |
+
self._maxsize = maxsize
|
| 117 |
+
self._ttl = ttl
|
| 118 |
+
self._lock = Lock()
|
| 119 |
+
|
| 120 |
+
def get(self, key: str):
|
| 121 |
+
with self._lock:
|
| 122 |
+
entry = self._cache.get(key)
|
| 123 |
+
if entry is None:
|
| 124 |
+
return None
|
| 125 |
+
ts, tbl = entry
|
| 126 |
+
if time.monotonic() - ts > self._ttl:
|
| 127 |
+
del self._cache[key]
|
| 128 |
+
return None
|
| 129 |
+
self._cache.move_to_end(key)
|
| 130 |
+
self._cache[key] = (time.monotonic(), tbl)
|
| 131 |
+
return tbl
|
| 132 |
+
|
| 133 |
+
def set(self, key: str, tbl) -> None:
|
| 134 |
+
with self._lock:
|
| 135 |
+
if key in self._cache:
|
| 136 |
+
self._cache.move_to_end(key)
|
| 137 |
+
self._cache[key] = (time.monotonic(), tbl)
|
| 138 |
+
while len(self._cache) > self._maxsize:
|
| 139 |
+
self._cache.popitem(last=False)
|
| 140 |
+
|
| 141 |
+
def evict(self, key: str) -> None:
|
| 142 |
+
with self._lock:
|
| 143 |
+
self._cache.pop(key, None)
|
| 144 |
+
|
| 145 |
+
def keys(self):
|
| 146 |
+
with self._lock:
|
| 147 |
+
now = time.monotonic()
|
| 148 |
+
return [k for k, (ts, _) in self._cache.items()
|
| 149 |
+
if now - ts <= self._ttl]
|
| 150 |
+
|
| 151 |
+
_table_cache = _LRUTableCache(maxsize=LRU_MAXSIZE, ttl=LRU_TTL)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# βββββββββββββββββββββββββββ Global state βββββββββββββββββββββββββββββββββββββ
|
| 155 |
+
models: dict = {}
|
| 156 |
+
_executor = ThreadPoolExecutor(max_workers=2)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# βββββββββββββββββββββββββββ Lifespan βββββββββββββββββββββββββββββββββββββββββ
|
| 160 |
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load the embedding model once and expose it through the
    # module-level `models` dict; shutdown (after `yield`): drop the reference.
    print("Loading jina-embeddings-v2-base-code (ONNX fp16)β¦")
    # backend="onnx" makes sentence-transformers run the forward pass through
    # ONNX Runtime instead of PyTorch. file_name selects the pre-built fp16
    # ONNX graph shipped in the model's HuggingFace repo, so no export step is
    # needed and the custom trust_remote_code PyTorch modeling code is bypassed
    # for inference.
    model = SentenceTransformer(
        "jinaai/jina-embeddings-v2-base-code",
        backend="onnx",
        model_kwargs={
            "file_name": "onnx/model_fp16.onnx",
            "provider": "CPUExecutionProvider",
            "provider_options": [{
                # Keep ONNX Runtime's intra-op thread pool small (same default
                # as the OMP_NUM_THREADS env set at import time) to limit RSS.
                "intra_op_num_threads": int(os.getenv("OMP_NUM_THREADS", "2")),
            }],
        },
        trust_remote_code=True,
    )
    model.max_seq_length = 8192  # allow the model's full 8192-token context
    models["model"] = model
    print(f"Model ready [backend={model.backend}]")
    yield
    # Shutdown: release the model reference.
    models.clear()
|
| 185 |
|
| 186 |
|
| 187 |
+
# βββββββββββββββββββββββββββ App ββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
| 188 |
app = FastAPI(
|
| 189 |
title="Code Search API",
|
| 190 |
+
description="Semantic code search β jina-embeddings-v2-base-code ONNX fp16 + LanceDB",
|
| 191 |
+
version="3.0.0",
|
|
|
|
|
|
|
|
|
|
| 192 |
lifespan=lifespan,
|
| 193 |
)
|
|
|
|
| 194 |
app.add_middleware(
|
| 195 |
CORSMiddleware,
|
| 196 |
+
allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
|
|
|
|
|
|
|
| 197 |
)
|
| 198 |
|
| 199 |
|
| 200 |
+
# βββββββββββββββββββββββββββ Encoding βββββββββββββββββββββββββββββββββββββββββ
|
| 201 |
+
def _encode_sync(texts: list[str]) -> np.ndarray:
    """Encode *texts* synchronously through ONNX Runtime.

    Works through the input in slices of ENCODE_BATCH_SIZE, collecting
    each slice as float32 and garbage-collecting between slices to keep
    the transient footprint small. Returns an array of shape
    (len(texts), DIM). ONNX Runtime has no autograd, so torch.no_grad()
    is unnecessary here.
    """
    collected: list[np.ndarray] = []
    start = 0
    while start < len(texts):
        piece = texts[start : start + ENCODE_BATCH_SIZE]
        vectors = models["model"].encode(
            piece,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=False,
        )
        collected.append(np.asarray(vectors, dtype=np.float32))
        gc.collect()
        start += ENCODE_BATCH_SIZE
    return np.vstack(collected)
|
|
|
|
|
|
|
| 220 |
|
| 221 |
|
| 222 |
+
async def _encode_async(texts: list[str]) -> np.ndarray:
    """Encode *texts* without blocking the event loop.

    Dispatches _encode_sync to the shared thread-pool executor so model
    inference overlaps with request handling; returns the same
    float32 array _encode_sync produces.
    """
    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() is deprecated in that context since Python 3.10
    # and raises if no loop is running anyway.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_executor, _encode_sync, texts)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _normalize(embs: np.ndarray) -> np.ndarray:
|
| 228 |
+
norms = np.linalg.norm(embs, axis=1, keepdims=True)
|
| 229 |
+
return embs / np.maximum(norms, 1e-9)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# βββββββββββββββββββββββββββ LanceDB helpers ββββββββββββββββββββββββββββββββββ
|
| 233 |
+
def _db() -> lancedb.DBConnection:
    """Connect to the LanceDB database rooted at STORE_DIR."""
    return lancedb.connect(str(STORE_DIR))
|
| 235 |
+
|
| 236 |
|
| 237 |
+
def _table_exists(doc_id: str) -> bool:
    """Return True if a LanceDB table named *doc_id* exists in the store."""
    return doc_id in _db().table_names()
|
| 239 |
|
| 240 |
+
|
| 241 |
+
def _open_table(doc_id: str):
    """Fetch the table handle for *doc_id*, preferring the LRU cache.

    On a cache miss the table is opened from disk and cached for
    subsequent requests.
    """
    cached = _table_cache.get(doc_id)
    if cached is not None:
        return cached
    handle = _db().open_table(doc_id)
    _table_cache.set(doc_id, handle)
    return handle
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
async def _build_table_streaming(doc_id: str, chunks: list[str]) -> None:
    """
    Streaming index build — the heart of the memory optimisation.

    Instead of: chunk_all -> encode_all -> build_index (full array in RAM)
    We do:      for each micro-batch -> encode -> write to LanceDB -> free

    Peak RAM = one micro-batch of embeddings (8 x 768 x 4 bytes ~ 24 KB).
    LanceDB stores vectors as a memory-mapped Lance file on disk; only
    the pages touched during a query are paged into RAM at search time.

    Args:
        doc_id: table name; an existing table of that name is replaced.
        chunks: pre-chunked text snippets to encode and store.
    """
    db = _db()
    # Drop stale table if it exists (and its cached handle, which would
    # otherwise point at the dropped table).
    if doc_id in db.table_names():
        db.drop_table(doc_id)
        _table_cache.evict(doc_id)

    tbl = None
    for i in range(0, len(chunks), ENCODE_BATCH_SIZE):
        batch = chunks[i : i + ENCODE_BATCH_SIZE]
        embs = await _encode_async(batch)
        # Unit-normalise so the "dot" metric used below behaves as cosine.
        embs = _normalize(embs)

        # chunk_id is the global position (batch offset i + in-batch j).
        records = [
            {
                "chunk_id": i + j,
                "text": text,
                "vector": vec.tolist(),
            }
            for j, (text, vec) in enumerate(zip(batch, embs))
        ]

        # First batch creates the table; later batches append.
        # NOTE(review): mode="overwrite" looks redundant after the explicit
        # drop above — confirm against LanceDB create_table semantics.
        if tbl is None:
            tbl = db.create_table(doc_id, data=records,
                                  schema=_SCHEMA, mode="overwrite")
        else:
            tbl.add(records)

        # Free the batch's embeddings before encoding the next one.
        del embs, records
        gc.collect()

    # Create ANN vector index for tables large enough to benefit
    if tbl is not None and len(chunks) >= 256:
        try:
            tbl.create_index(
                metric="dot",  # vectors are pre-normalised
                vector_column_name="vector",
                num_partitions=max(1, min(256, len(chunks) // 40)),
                num_sub_vectors=96,
            )
        except Exception as e:
            # Index creation is best-effort: search still works (brute force),
            # so log and continue rather than failing the whole build.
            print(f"Warning: ANN index creation skipped for '{doc_id}': {e}")

    # Cache the fresh handle so the first search avoids a disk open.
    if tbl is not None:
        _table_cache.set(doc_id, tbl)
|
| 305 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
+
def _search_table(doc_id: str, query: str, top_k: int) -> list[dict]:
    """
    On-demand search. Opens the table handle (from LRU cache or disk),
    runs a vector search, returns top_k results. Only the pages of the
    Lance file containing the nearest vectors are paged into RAM.
    """
    query_vec = _normalize(_encode_sync([query]))[0]

    tbl = _open_table(doc_id)
    search = tbl.search(query_vec.tolist(), vector_column_name="vector")
    hits = search.metric("dot").limit(top_k).to_list()

    ranked = []
    for idx, hit in enumerate(hits):
        raw_score = hit.get("_distance", hit.get("score", 0.0))
        ranked.append(
            {
                "rank": idx + 1,
                "score": round(float(raw_score), 4),
                "text": hit["text"],
            }
        )
    return ranked
|
| 332 |
|
| 333 |
|
| 334 |
+
# βββββββββββββββββββββββββββ Chunking βββββββββββββββββββββββββββββββββββββββββ
|
| 335 |
def detect_language(filename: str) -> str:
    """Map a filename's extension to a language key; 'text' when unknown."""
    ext = os.path.splitext(filename)[-1].lower()
    return LANGUAGE_MAP.get(ext, "text")
|
|
|
|
| 337 |
|
| 338 |
|
| 339 |
def chunk_text(text: str, chunk_size: int = 3, overlap: int = 1) -> list[str]:
|
|
|
|
| 340 |
sentences = re.split(r'(?<=[.!?])\s+', text.strip())
|
| 341 |
sentences = [s.strip() for s in sentences if s.strip()]
|
| 342 |
chunks, i = [], 0
|
|
|
|
| 347 |
|
| 348 |
|
| 349 |
def chunk_fallback(source: str, max_lines: int = 40, overlap: int = 5) -> list[str]:
|
|
|
|
| 350 |
lines = source.splitlines()
|
| 351 |
chunks = []
|
| 352 |
i = 0
|
|
|
|
| 357 |
|
| 358 |
|
| 359 |
def chunk_python(source: str, filepath: str = "") -> list[str]:
|
|
|
|
| 360 |
try:
|
| 361 |
tree = ast.parse(source)
|
| 362 |
lines = source.splitlines()
|
| 363 |
chunks = []
|
| 364 |
for node in ast.walk(tree):
|
| 365 |
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
| 366 |
+
snippet = "\n".join(lines[node.lineno - 1 : node.end_lineno])
|
|
|
|
|
|
|
| 367 |
prefix = f"# {filepath}\n" if filepath else ""
|
| 368 |
chunks.append(f"{prefix}{snippet}")
|
| 369 |
return chunks if chunks else chunk_fallback(source)
|
|
|
|
| 372 |
|
| 373 |
|
| 374 |
def chunk_generic(source: str, filepath: str = "") -> list[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
pattern = re.compile(
|
| 376 |
r'(?:^|\n)(?='
|
| 377 |
+
r'(?:export\s+)?(?:async\s+)?'
|
|
|
|
| 378 |
r'(?:function|class|const\s+\w+\s*=\s*(?:async\s+)?(?:\(|function)|'
|
| 379 |
r'(?:public|private|protected|static|\s)*(?:fn|func|def)\s+\w+)'
|
| 380 |
r')',
|
|
|
|
| 387 |
|
| 388 |
|
| 389 |
def chunk_code(source: str, filename: str = "") -> list[str]:
|
|
|
|
| 390 |
lang = detect_language(filename)
|
| 391 |
if lang == "python":
|
| 392 |
return chunk_python(source, filepath=filename)
|
|
|
|
| 396 |
return chunk_generic(source, filepath=filename)
|
| 397 |
|
| 398 |
|
| 399 |
+
# βββββββββββββββββββββββββββ Schemas ββββββββββββββββββββββββββββββββββββββββββ
|
| 400 |
class IndexResponse(BaseModel):
    """Response for /index: assigned id, chunk count, status message."""

    doc_id: str
    chunks_indexed: int
    message: str
|
|
|
|
|
|
|
|
|
|
| 402 |
|
| 403 |
class SearchRequest(BaseModel):
    """Request body for /search."""

    doc_id: str = Field(...)
    query: str = Field(...)
    top_k: int = Field(5, ge=1, le=20)
|
|
|
|
|
|
|
|
|
|
| 405 |
|
| 406 |
class SearchResult(BaseModel):
    """One ranked search hit."""

    rank: int
    score: float
    text: str
|
|
|
|
|
|
|
|
|
|
| 408 |
|
| 409 |
class SearchResponse(BaseModel):
    """Response for /search: the query echoed back plus ranked results."""

    doc_id: str
    query: str
    results: list[SearchResult]
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
class EmbedRequest(BaseModel):
    """Request body for /embed."""

    texts: list[str] = Field(...)
|
|
|
|
| 414 |
|
| 415 |
class EmbedResponse(BaseModel):
    """Response for /embed: raw embedding vectors and their dimensionality."""

    embeddings: list[list[float]]
    dimensions: int
|
|
|
|
|
|
|
| 417 |
|
| 418 |
class FileEntry(BaseModel):
    """A single file submitted for batch indexing."""

    filename: str
    content: str
|
|
|
|
|
|
|
| 420 |
|
| 421 |
class BatchIndexRequest(BaseModel):
    """Request body for /index/batch."""

    doc_id: str
    files: list[FileEntry]
    replace: bool = True
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
class BatchIndexResponse(BaseModel):
    """Response for /index/batch."""

    doc_id: str
    files_indexed: int
    chunks_indexed: int
|
|
|
|
|
|
|
| 426 |
|
| 427 |
|
| 428 |
+
# βββββββββββββββββββββββββββ Routes βββββββββββββββββββββββββββββββββββββββββββ
|
| 429 |
@app.get("/", tags=["health"])
def root():
    """Liveness probe; points callers at the interactive docs."""
    payload = {"status": "ok", "docs": "/docs"}
    return payload
|
|
|
|
| 433 |
|
| 434 |
@app.get("/health", tags=["health"])
def health():
    """Readiness probe: reports whether the embedding model is loaded."""
    loaded = bool(models)
    backend = models["model"].backend if models else None
    return {"status": "ok", "models_loaded": loaded, "backend": backend}
|
| 438 |
|
| 439 |
|
| 440 |
@app.post("/index", response_model=IndexResponse, tags=["search"])
|
| 441 |
async def index_document(
|
| 442 |
+
file: Annotated[UploadFile, File(description="Source file to index")],
|
| 443 |
+
doc_id: Annotated[str, Form(description="Unique ID (defaults to filename)")] = "",
|
| 444 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
if not models:
|
| 446 |
+
raise HTTPException(503, "Model not loaded yet.")
|
| 447 |
|
| 448 |
+
content = await file.read()
|
| 449 |
if len(content) > MAX_UPLOAD_BYTES:
|
| 450 |
+
raise HTTPException(413,
|
| 451 |
+
f"File too large ({len(content)/1024/1024:.1f} MB). "
|
| 452 |
+
f"Max: {MAX_UPLOAD_BYTES//1024//1024} MB.")
|
| 453 |
+
|
| 454 |
+
source = content.decode("utf-8", errors="replace")
|
| 455 |
+
filename = file.filename or "unknown"
|
|
|
|
|
|
|
| 456 |
resolved_id = doc_id.strip() or os.path.splitext(filename)[0]
|
| 457 |
|
| 458 |
chunks = chunk_code(source, filename=filename)
|
| 459 |
if not chunks:
|
| 460 |
+
raise HTTPException(400, "Document produced no chunks.")
|
| 461 |
|
| 462 |
+
await _build_table_streaming(resolved_id, chunks)
|
| 463 |
+
gc.collect()
|
|
|
|
|
|
|
|
|
|
| 464 |
|
| 465 |
return IndexResponse(
|
| 466 |
doc_id=resolved_id,
|
|
|
|
| 471 |
|
| 472 |
@app.post("/index/batch", response_model=BatchIndexResponse, tags=["search"])
|
| 473 |
async def index_batch(req: BatchIndexRequest):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
if not models:
|
| 475 |
raise HTTPException(503, "Model not loaded yet.")
|
| 476 |
|
| 477 |
+
# Collect all chunks first (just strings β negligible RAM)
|
|
|
|
|
|
|
|
|
|
| 478 |
all_chunks: list[str] = []
|
| 479 |
for entry in req.files:
|
| 480 |
all_chunks.extend(chunk_code(entry.content, filename=entry.filename))
|
| 481 |
|
| 482 |
if not all_chunks:
|
| 483 |
raise HTTPException(400, "No chunks produced from provided files.")
|
|
|
|
|
|
|
| 484 |
if len(all_chunks) > MAX_CHUNKS:
|
| 485 |
+
raise HTTPException(413,
|
| 486 |
+
f"Too many chunks ({len(all_chunks):,}). Max: {MAX_CHUNKS:,}.")
|
|
|
|
|
|
|
|
|
|
| 487 |
|
| 488 |
+
# Streaming build β never holds full embeddings array
|
| 489 |
+
await _build_table_streaming(req.doc_id, all_chunks)
|
| 490 |
+
gc.collect()
|
|
|
|
| 491 |
|
| 492 |
return BatchIndexResponse(
|
| 493 |
doc_id=req.doc_id,
|
|
|
|
| 498 |
|
| 499 |
@app.post("/search", response_model=SearchResponse, tags=["search"])
|
| 500 |
async def search_document(req: SearchRequest):
|
| 501 |
+
if not _table_exists(req.doc_id):
|
|
|
|
| 502 |
raise HTTPException(404, f"doc_id '{req.doc_id}' not found. Call /index first.")
|
| 503 |
|
| 504 |
+
loop = asyncio.get_event_loop()
|
| 505 |
+
results = await loop.run_in_executor(
|
| 506 |
+
_executor, _search_table, req.doc_id, req.query, req.top_k
|
| 507 |
+
)
|
| 508 |
return SearchResponse(
|
| 509 |
doc_id=req.doc_id,
|
| 510 |
query=req.query,
|
|
|
|
| 514 |
|
| 515 |
@app.post("/embed", response_model=EmbedResponse, tags=["embeddings"])
async def embed_texts(req: EmbedRequest):
    """
    Embed up to 64 raw texts and return the (un-normalised) vectors.

    Raises:
        503 while the model is still loading.
        400 on an empty batch or more than 64 texts.
    """
    if not models:
        raise HTTPException(503, "Model not loaded yet.")
    if not req.texts:
        # An empty list would reach np.vstack([]) inside the encoder and
        # surface as a 500; fail fast with an explicit client error instead.
        raise HTTPException(400, "At least one text is required.")
    if len(req.texts) > 64:
        raise HTTPException(400, "Maximum 64 texts per request.")

    embs = await _encode_async(req.texts)
    return EmbedResponse(embeddings=embs.tolist(), dimensions=embs.shape[1])
|
|
|
|
|
|
|
|
|
|
| 524 |
|
| 525 |
|
| 526 |
@app.get("/documents", tags=["search"])
def list_documents():
    """List every indexed doc_id with its chunk count (-1 if unreadable)."""
    db = _db()
    docs = []
    for name in db.table_names():
        try:
            n_rows = db.open_table(name).count_rows()
        except Exception:
            n_rows = -1  # table unreadable; still report that it exists
        docs.append({"doc_id": name, "chunks": n_rows})
    return {"documents": docs}
|
| 538 |
|
| 539 |
|
| 540 |
@app.delete("/documents/{doc_id}", tags=["search"])
def delete_document(doc_id: str):
    """Drop the LanceDB table for *doc_id* and evict its cached handle."""
    if not _table_exists(doc_id):
        raise HTTPException(404, f"doc_id '{doc_id}' not found.")
    database = _db()
    database.drop_table(doc_id)
    _table_cache.evict(doc_id)
    return {"deleted": doc_id}
|