kamp0010 commited on
Commit
42abbab
Β·
verified Β·
1 Parent(s): a872f7a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +319 -332
main.py CHANGED
@@ -1,302 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import ast
3
  import re
4
- import pickle
 
5
  import pathlib
6
  import asyncio
 
7
  from concurrent.futures import ThreadPoolExecutor
8
  from contextlib import asynccontextmanager
 
9
  from typing import Annotated
10
 
11
  os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
12
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
13
- os.environ["HF_HUB_VERBOSITY"] = "error"
 
 
14
 
15
- import torch
16
  import numpy as np
17
- import faiss
18
-
19
- # ── Compatibility patches ──────────────────────────────────────────────────────
20
- # jina-bert-v2 (trust_remote_code) was written against transformers 4.x.
21
- # Transformers 5.x removed / broke three things the model relies on.
22
- # All patches are no-ops when the symbol already exists.
23
- #
24
- # 1. find_pruneable_heads_and_indices β€” removed from pytorch_utils
25
- # 2. PretrainedConfig.is_decoder etc β€” no longer set as instance defaults
26
- # 3. PreTrainedModel.get_head_mask β€” removed from modeling_utils in T5
27
-
28
- # ── patch 1: pytorch_utils ────────────────────────────────────────────────────
29
- import transformers.pytorch_utils as _pt_utils
30
- if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
31
- def _find_pruneable_heads_and_indices(
32
- heads, n_heads: int, head_size: int, already_pruned_heads
33
- ):
34
- mask = torch.ones(n_heads, head_size)
35
- heads = set(heads) - already_pruned_heads
36
- for head in heads:
37
- head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
38
- mask[head] = 0
39
- mask = mask.view(-1).contiguous().eq(1)
40
- index = torch.arange(len(mask))[mask].long()
41
- return heads, index
42
- _pt_utils.find_pruneable_heads_and_indices = _find_pruneable_heads_and_indices
43
-
44
- # ── patch 2: PretrainedConfig legacy defaults ─────────────────────────────────
45
- import transformers.configuration_utils as _cfg_utils
46
- _PC = _cfg_utils.PretrainedConfig
47
- if not hasattr(_PC, "_jina_compat_patched"):
48
- _LEGACY_CFG_DEFAULTS = {
49
- "is_decoder": False,
50
- "add_cross_attention": False,
51
- "cross_attention_hidden_size": None,
52
- "use_cache": True,
53
- }
54
- def _pc_getattr(self, key: str):
55
- if key in _LEGACY_CFG_DEFAULTS:
56
- return _LEGACY_CFG_DEFAULTS[key]
57
- raise AttributeError(
58
- f"'{type(self).__name__}' object has no attribute '{key}'"
59
- )
60
- _PC.__getattr__ = _pc_getattr
61
- _PC._jina_compat_patched = True
62
-
63
- # ── patch 3: PreTrainedModel.get_head_mask ────────────────────────────────────
64
- import transformers.modeling_utils as _mod_utils
65
- _PTM = _mod_utils.PreTrainedModel
66
- if not hasattr(_PTM, "get_head_mask"):
67
- def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
68
- if head_mask.dim() == 1:
69
- head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
70
- head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
71
- elif head_mask.dim() == 2:
72
- head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
73
- assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
74
- head_mask = head_mask.to(dtype=self.dtype)
75
- return head_mask
76
-
77
- def _get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
78
- if head_mask is not None:
79
- head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
80
- if is_attention_chunked:
81
- head_mask = head_mask.unsqueeze(-1)
82
- else:
83
- head_mask = [None] * num_hidden_layers
84
- return head_mask
85
-
86
- if not hasattr(_PTM, "_convert_head_mask_to_5d"):
87
- _PTM._convert_head_mask_to_5d = _convert_head_mask_to_5d
88
- _PTM.get_head_mask = _get_head_mask
89
- # ──────────────────────────────────────────────────────────────────────────────
90
-
91
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form
92
  from fastapi.middleware.cors import CORSMiddleware
93
  from pydantic import BaseModel, Field
94
  from sentence_transformers import SentenceTransformer
95
 
96
 
97
- # ─────────────────────────── Constants ───────────────────────────────────────
98
- DIM = 768 # jina-embeddings-v2-base-code output dimension
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
 
100
  def _resolve_store_dir() -> pathlib.Path:
101
- """
102
- Try /data/indexes (HF Spaces persistent volume).
103
- Fall back to ~/.cache/code-search/indexes if /data is not writable
104
- (local dev, or volume not yet mounted with correct permissions).
105
- """
106
- primary = pathlib.Path("/data/indexes")
107
  try:
108
  primary.mkdir(parents=True, exist_ok=True)
109
  probe = primary / ".write_probe"
110
- probe.touch()
111
- probe.unlink()
112
  return primary
113
  except OSError:
114
- fallback = pathlib.Path.home() / ".cache" / "code-search" / "indexes"
115
  fallback.mkdir(parents=True, exist_ok=True)
116
- print(f"Warning: /data/indexes not writable β€” using fallback: {fallback}")
117
  return fallback
118
 
119
  STORE_DIR = _resolve_store_dir()
120
 
121
- LANGUAGE_MAP = {
122
- ".py": "python",
123
- ".js": "javascript",
124
- ".ts": "typescript",
125
- ".tsx": "typescript",
126
- ".jsx": "javascript",
127
- ".go": "go",
128
- ".rs": "rust",
129
- ".java": "java",
130
- ".cpp": "cpp",
131
- ".c": "c",
132
- ".cs": "csharp",
133
- ".rb": "ruby",
134
- ".php": "php",
135
- ".md": "markdown",
136
- ".txt": "text",
137
- }
138
-
139
 
140
- # ─────────────────────────── Global state ────────────────────────────────────
141
- models: dict = {}
142
- # store[doc_id] = {"chunks": list[str], "index": faiss.Index}
143
- store: dict[str, dict] = {}
144
- _executor = ThreadPoolExecutor(max_workers=2)
145
-
146
-
147
- # ─────────────────────────── Lifespan ────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  @asynccontextmanager
149
  async def lifespan(app: FastAPI):
150
- print("Loading jina-embeddings-v2-base-code…")
 
 
 
 
 
151
  model = SentenceTransformer(
152
- "jinaai/jina-embeddings-v2-base-code", trust_remote_code=True
 
 
 
 
 
 
 
 
 
153
  )
154
  model.max_seq_length = 8192
155
- # Cast to float16 β€” cuts model VRAM/RAM from ~550 MB to ~275 MB.
156
- # SentenceTransformer wraps a nn.Module; half() applies recursively.
157
- model.half()
158
- model.eval()
159
  models["model"] = model
160
- print("Model ready.")
161
-
162
- # Restore persisted indexes from /data
163
- if STORE_DIR.exists():
164
- for faiss_path in STORE_DIR.glob("*.faiss"):
165
- doc_id = faiss_path.stem
166
- meta_path = STORE_DIR / f"{doc_id}.meta.pkl"
167
- if not meta_path.exists():
168
- continue
169
- try:
170
- index = faiss.read_index(str(faiss_path))
171
- with open(meta_path, "rb") as f:
172
- meta = pickle.load(f)
173
- store[doc_id] = {"chunks": meta["chunks"], "index": index}
174
- print(f"Restored index: {doc_id} ({index.ntotal} vectors)")
175
- except Exception as e:
176
- print(f"Warning: could not restore {doc_id}: {e}")
177
-
178
  yield
179
  models.clear()
180
 
181
 
182
- # ─────────────────────────── App ─────────────────────────────────────────────
183
- MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_MB", "50")) * 1024 * 1024 # default 50 MB
184
-
185
  app = FastAPI(
186
  title="Code Search API",
187
- description=(
188
- "Upload source files and search them semantically using "
189
- "jinaai/jina-embeddings-v2-base-code + FAISS ANN search."
190
- ),
191
- version="2.0.0",
192
  lifespan=lifespan,
193
  )
194
-
195
  app.add_middleware(
196
  CORSMiddleware,
197
- allow_origins=["*"],
198
- allow_methods=["*"],
199
- allow_headers=["*"],
200
  )
201
 
202
 
203
- # ─────────────────────────── Embedding helpers ────────────────────────────────
204
- ENCODE_BATCH_SIZE = int(os.getenv("ENCODE_BATCH_SIZE", "8")) # lower = less RAM peak
205
-
206
- def encode(texts: list[str]) -> np.ndarray:
207
  """
208
- Synchronous encode with micro-batching + explicit GC.
209
- - float16 model weights halve static RAM.
210
- - Small batch size (8) keeps activation RAM low during forward pass.
211
- - gc.collect() + torch.cuda.empty_cache() after each batch release
212
- intermediate tensors immediately instead of waiting for GC.
213
  """
214
- import gc
215
- all_embeddings = []
216
  for i in range(0, len(texts), ENCODE_BATCH_SIZE):
217
  batch = texts[i : i + ENCODE_BATCH_SIZE]
218
- with torch.no_grad():
219
- embs = models["model"].encode(
220
- batch,
221
- show_progress_bar=False,
222
- convert_to_numpy=True,
223
- normalize_embeddings=False, # we normalise in FAISS
224
- )
225
- all_embeddings.append(np.array(embs, dtype=np.float32))
226
- # Free activations between batches
227
  gc.collect()
228
- if torch.cuda.is_available():
229
- torch.cuda.empty_cache()
230
- return np.vstack(all_embeddings)
231
 
232
 
233
- async def encode_async(texts: list[str]) -> np.ndarray:
234
- """Non-blocking wrapper β€” frees the event loop during model inference."""
235
  loop = asyncio.get_event_loop()
236
- return await loop.run_in_executor(_executor, encode, texts)
 
 
 
 
 
 
 
 
 
 
 
237
 
 
 
238
 
239
- # ─────────────────────────── FAISS helpers ───────────────────────────────────
240
- def build_faiss_index(embeddings: np.ndarray) -> faiss.Index:
 
 
 
 
 
 
 
 
 
241
  """
242
- Use HNSW for datasets up to ~500k vectors:
243
- - ~2x less RAM than IndexFlatIP (stores graph links, not raw vectors twice)
244
- - O(log n) search vs O(n) flat scan
245
- - M=32 is a good balance of speed/recall; raise to 64 for higher recall
246
- Fall back to IndexFlatIP for tiny datasets where HNSW overhead isn't worth it.
 
 
 
247
  """
248
- import gc
249
- faiss.normalize_L2(embeddings) # in-place β€” no copy
250
- n = len(embeddings)
251
- if n >= 100:
252
- index = faiss.IndexHNSWFlat(DIM, 32, faiss.METRIC_INNER_PRODUCT)
253
- index.hnsw.efConstruction = 200
254
- index.hnsw.efSearch = 64
255
- else:
256
- index = faiss.IndexFlatIP(DIM)
257
- index.add(embeddings)
258
- del embeddings
259
- gc.collect()
260
- return index
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
 
 
262
 
263
- def search_index(query: str, doc_id: str, top_k: int) -> list[dict]:
264
- q = encode([query])
265
- faiss.normalize_L2(q)
266
- scores, indices = store[doc_id]["index"].search(q, top_k)
267
- chunks = store[doc_id]["chunks"]
268
- return [
269
- {
270
- "rank": i + 1,
271
- "score": round(float(scores[0][i]), 4),
272
- "text": chunks[indices[0][i]],
273
- }
274
- for i in range(len(indices[0]))
275
- if indices[0][i] >= 0 # FAISS returns -1 for empty slots
276
- ]
277
 
 
 
278
 
279
- # ─────────────────────────── Persistence helpers ─────────────────────────────
280
- def persist_index(doc_id: str, chunks: list[str], index: faiss.Index) -> None:
281
- STORE_DIR.mkdir(parents=True, exist_ok=True)
282
- faiss.write_index(index, str(STORE_DIR / f"{doc_id}.faiss"))
283
- with open(STORE_DIR / f"{doc_id}.meta.pkl", "wb") as f:
284
- pickle.dump({"chunks": chunks, "doc_id": doc_id}, f)
285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
- def delete_persisted(doc_id: str) -> None:
288
- (STORE_DIR / f"{doc_id}.faiss").unlink(missing_ok=True)
289
- (STORE_DIR / f"{doc_id}.meta.pkl").unlink(missing_ok=True)
 
 
 
 
 
290
 
291
 
292
- # ─────────────────────────── Chunking helpers ────────────────────────────────
293
  def detect_language(filename: str) -> str:
294
- ext = os.path.splitext(filename)[-1].lower()
295
- return LANGUAGE_MAP.get(ext, "text")
296
 
297
 
298
  def chunk_text(text: str, chunk_size: int = 3, overlap: int = 1) -> list[str]:
299
- """Sentence-window chunker for prose / markdown."""
300
  sentences = re.split(r'(?<=[.!?])\s+', text.strip())
301
  sentences = [s.strip() for s in sentences if s.strip()]
302
  chunks, i = [], 0
@@ -307,7 +347,6 @@ def chunk_text(text: str, chunk_size: int = 3, overlap: int = 1) -> list[str]:
307
 
308
 
309
  def chunk_fallback(source: str, max_lines: int = 40, overlap: int = 5) -> list[str]:
310
- """Fixed line-window chunking with overlap β€” last resort."""
311
  lines = source.splitlines()
312
  chunks = []
313
  i = 0
@@ -318,16 +357,13 @@ def chunk_fallback(source: str, max_lines: int = 40, overlap: int = 5) -> list[s
318
 
319
 
320
  def chunk_python(source: str, filepath: str = "") -> list[str]:
321
- """AST-based chunker β€” extracts functions and classes."""
322
  try:
323
  tree = ast.parse(source)
324
  lines = source.splitlines()
325
  chunks = []
326
  for node in ast.walk(tree):
327
  if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
328
- start = node.lineno - 1
329
- end = node.end_lineno
330
- snippet = "\n".join(lines[start:end])
331
  prefix = f"# {filepath}\n" if filepath else ""
332
  chunks.append(f"{prefix}{snippet}")
333
  return chunks if chunks else chunk_fallback(source)
@@ -336,14 +372,9 @@ def chunk_python(source: str, filepath: str = "") -> list[str]:
336
 
337
 
338
  def chunk_generic(source: str, filepath: str = "") -> list[str]:
339
- """
340
- Regex chunker for JS, TS, Go, Rust, Java, C++, etc.
341
- Splits on function / class declaration boundaries.
342
- """
343
  pattern = re.compile(
344
  r'(?:^|\n)(?='
345
- r'(?:export\s+)?'
346
- r'(?:async\s+)?'
347
  r'(?:function|class|const\s+\w+\s*=\s*(?:async\s+)?(?:\(|function)|'
348
  r'(?:public|private|protected|static|\s)*(?:fn|func|def)\s+\w+)'
349
  r')',
@@ -356,7 +387,6 @@ def chunk_generic(source: str, filepath: str = "") -> list[str]:
356
 
357
 
358
  def chunk_code(source: str, filename: str = "") -> list[str]:
359
- """Master dispatcher β€” routes to the best chunker for the file type."""
360
  lang = detect_language(filename)
361
  if lang == "python":
362
  return chunk_python(source, filepath=filename)
@@ -366,58 +396,36 @@ def chunk_code(source: str, filename: str = "") -> list[str]:
366
  return chunk_generic(source, filepath=filename)
367
 
368
 
369
- # ─────────────────────────── Schemas ─────────────────────────────────────────
370
  class IndexResponse(BaseModel):
371
- doc_id: str
372
- chunks_indexed: int
373
- message: str
374
-
375
 
376
  class SearchRequest(BaseModel):
377
- doc_id: str = Field(..., description="ID returned by /index")
378
- query: str = Field(..., description="Natural language or code query")
379
- top_k: int = Field(5, ge=1, le=20)
380
-
381
 
382
  class SearchResult(BaseModel):
383
- rank: int
384
- score: float
385
- text: str
386
-
387
 
388
  class SearchResponse(BaseModel):
389
- doc_id: str
390
- query: str
391
- results: list[SearchResult]
392
-
393
 
394
  class EmbedRequest(BaseModel):
395
- texts: list[str] = Field(..., description="List of strings to embed")
396
-
397
 
398
  class EmbedResponse(BaseModel):
399
- embeddings: list[list[float]]
400
- dimensions: int
401
-
402
 
403
  class FileEntry(BaseModel):
404
- filename: str
405
- content: str # raw file content as string
406
-
407
 
408
  class BatchIndexRequest(BaseModel):
409
- doc_id: str # one doc_id for the whole project / repo
410
- files: list[FileEntry]
411
- replace: bool = True # if True, replaces existing index for this doc_id
412
-
413
 
414
  class BatchIndexResponse(BaseModel):
415
- doc_id: str
416
- files_indexed: int
417
- chunks_indexed: int
418
 
419
 
420
- # ─────────────────────────── Routes ──────────────────────────────────────────
421
  @app.get("/", tags=["health"])
422
  def root():
423
  return {"status": "ok", "docs": "/docs"}
@@ -425,42 +433,34 @@ def root():
425
 
426
  @app.get("/health", tags=["health"])
427
  def health():
428
- return {"status": "ok", "models_loaded": bool(models)}
 
429
 
430
 
431
  @app.post("/index", response_model=IndexResponse, tags=["search"])
432
  async def index_document(
433
- file: Annotated[UploadFile, File(description="Source file to index")],
434
- doc_id: Annotated[str, Form(description="Unique ID (defaults to filename)")] = "",
435
  ):
436
- """
437
- Upload a source file and embed it with code-aware chunking.
438
- Returns the doc_id to use in /search.
439
- """
440
  if not models:
441
- raise HTTPException(503, "Model not loaded yet β€” please retry in a few seconds.")
442
 
443
- content = await file.read()
444
  if len(content) > MAX_UPLOAD_BYTES:
445
- raise HTTPException(
446
- 413,
447
- f"File too large ({len(content) / 1024 / 1024:.1f} MB). "
448
- f"Max allowed: {MAX_UPLOAD_BYTES // 1024 // 1024} MB. "
449
- "Use /index/batch for large codebases.",
450
- )
451
- source = content.decode("utf-8", errors="replace")
452
- filename = file.filename or "unknown"
453
  resolved_id = doc_id.strip() or os.path.splitext(filename)[0]
454
 
455
  chunks = chunk_code(source, filename=filename)
456
  if not chunks:
457
- raise HTTPException(400, "Document produced no chunks. Check the file contents.")
458
 
459
- embeddings = await encode_async(chunks)
460
- index = build_faiss_index(embeddings.astype("float32"))
461
- store[resolved_id] = {"chunks": chunks, "index": index}
462
- persist_index(resolved_id, chunks, index)
463
- import gc; gc.collect() # free encoding intermediates before responding
464
 
465
  return IndexResponse(
466
  doc_id=resolved_id,
@@ -471,36 +471,23 @@ async def index_document(
471
 
472
  @app.post("/index/batch", response_model=BatchIndexResponse, tags=["search"])
473
  async def index_batch(req: BatchIndexRequest):
474
- """
475
- Index an entire codebase in one HTTP call.
476
- Ideal for IDE integrations β€” send all files, get one searchable doc_id back.
477
- """
478
  if not models:
479
  raise HTTPException(503, "Model not loaded yet.")
480
 
481
- if req.replace and req.doc_id in store:
482
- del store[req.doc_id]
483
- delete_persisted(req.doc_id)
484
-
485
  all_chunks: list[str] = []
486
  for entry in req.files:
487
  all_chunks.extend(chunk_code(entry.content, filename=entry.filename))
488
 
489
  if not all_chunks:
490
  raise HTTPException(400, "No chunks produced from provided files.")
491
-
492
- MAX_CHUNKS = int(os.getenv("MAX_CHUNKS", "10000")) # ~3 GB RAM at 10k chunks; raise carefully
493
  if len(all_chunks) > MAX_CHUNKS:
494
- raise HTTPException(
495
- 413,
496
- f"Too many chunks ({len(all_chunks):,}). Max: {MAX_CHUNKS:,}. "
497
- "Split your project into smaller doc_id groups.",
498
- )
499
 
500
- embeddings = await encode_async(all_chunks)
501
- index = build_faiss_index(embeddings.astype("float32"))
502
- store[req.doc_id] = {"chunks": all_chunks, "index": index}
503
- persist_index(req.doc_id, all_chunks, index)
504
 
505
  return BatchIndexResponse(
506
  doc_id=req.doc_id,
@@ -511,11 +498,13 @@ async def index_batch(req: BatchIndexRequest):
511
 
512
  @app.post("/search", response_model=SearchResponse, tags=["search"])
513
  async def search_document(req: SearchRequest):
514
- """Search a previously indexed document or codebase by doc_id."""
515
- if req.doc_id not in store:
516
  raise HTTPException(404, f"doc_id '{req.doc_id}' not found. Call /index first.")
517
 
518
- results = search_index(req.query, req.doc_id, req.top_k)
 
 
 
519
  return SearchResponse(
520
  doc_id=req.doc_id,
521
  query=req.query,
@@ -525,35 +514,33 @@ async def search_document(req: SearchRequest):
525
 
526
  @app.post("/embed", response_model=EmbedResponse, tags=["embeddings"])
527
  async def embed_texts(req: EmbedRequest):
528
- """Embed arbitrary texts. Returns raw float embeddings."""
529
  if not models:
530
  raise HTTPException(503, "Model not loaded yet.")
531
  if len(req.texts) > 64:
532
  raise HTTPException(400, "Maximum 64 texts per request.")
533
 
534
- embs = await encode_async(req.texts)
535
- return EmbedResponse(
536
- embeddings=embs.tolist(),
537
- dimensions=embs.shape[1],
538
- )
539
 
540
 
541
  @app.get("/documents", tags=["search"])
542
  def list_documents():
543
- """List all currently indexed document IDs."""
544
- return {
545
- "documents": [
546
- {"doc_id": k, "chunks": len(v["chunks"])}
547
- for k, v in store.items()
548
- ]
549
- }
 
 
 
550
 
551
 
552
  @app.delete("/documents/{doc_id}", tags=["search"])
553
  def delete_document(doc_id: str):
554
- """Remove a document from the in-memory index and from disk."""
555
- if doc_id not in store:
556
  raise HTTPException(404, f"doc_id '{doc_id}' not found.")
557
- del store[doc_id]
558
- delete_persisted(doc_id)
559
  return {"deleted": doc_id}
 
1
+ """
2
+ Code Search API β€” v3.0
3
+ ────────────────────────────────────────────────────────────────────────────
4
+ Key architecture changes from v2:
5
+
6
+ β€’ Model : ONNX fp16 via sentence-transformers backend="onnx"
7
+ β†’ ONNX Runtime replaces PyTorch for every forward pass.
8
+ β†’ Pre-built onnx/model_fp16.onnx from the HF repo is used
9
+ directly β€” no export step, no trust_remote_code issues.
10
+ β†’ All three transformers-compatibility patches removed.
11
+
12
+ β€’ Storage : LanceDB (disk-backed, columnar, mmap)
13
+ β†’ Vectors live on disk, not in Python RAM.
14
+ β†’ Chunks stored alongside vectors in the same table β€”
15
+ no separate pickle files.
16
+ β†’ FAISS removed entirely.
17
+
18
+ β€’ Indexing: Streaming pipeline
19
+ β†’ Chunks are produced, encoded in micro-batches, and written
20
+ to LanceDB immediately. The full embeddings array is never
21
+ held in RAM.
22
+
23
+ β€’ Retrieval: On-demand table loading + LRU cache
24
+ β†’ Tables are opened from disk per request.
25
+ β†’ An LRU cache (default: 5 tables, TTL: 10 min) keeps
26
+ recently used handles warm without pinning everything.
27
+
28
+ β€’ RAM budget (approximate, CPU-only HF Space):
29
+ Model weights (fp16 ONNX) ~275 MB
30
+ Encoding peak (batch=8) ~100 MB transient
31
+ LanceDB per query ~10-50 MB transient
32
+ Python overhead ~150 MB
33
+ ─────────────────────────────────────
34
+ Total steady-state ~425 MB (vs ~16 GB before)
35
+ """
36
+
37
  import os
38
  import ast
39
  import re
40
+ import gc
41
+ import time
42
  import pathlib
43
  import asyncio
44
+ from collections import OrderedDict
45
  from concurrent.futures import ThreadPoolExecutor
46
  from contextlib import asynccontextmanager
47
+ from threading import Lock
48
  from typing import Annotated
49
 
50
  os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
51
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
52
+ os.environ["HF_HUB_VERBOSITY"] = "error"
53
+ # Tell ONNX Runtime to use a modest thread count so it doesn't spike RSS
54
+ os.environ.setdefault("OMP_NUM_THREADS", "2")
55
 
 
56
  import numpy as np
57
+ import lancedb
58
+ import pyarrow as pa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form
60
  from fastapi.middleware.cors import CORSMiddleware
61
  from pydantic import BaseModel, Field
62
  from sentence_transformers import SentenceTransformer
63
 
64
 
65
# ─────────────────────────── Constants ────────────────────────────────────────
DIM = 768  # output dimension of jina-embeddings-v2-base-code
ENCODE_BATCH_SIZE = int(os.getenv("ENCODE_BATCH_SIZE", "8"))  # texts per encode micro-batch
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_MB", "50")) * 1024 * 1024  # single-upload size cap
MAX_CHUNKS = int(os.getenv("MAX_CHUNKS", "10000"))  # hard cap on chunks per batch-index call
LRU_MAXSIZE = int(os.getenv("LRU_TABLE_CACHE", "5"))  # max LanceDB table handles kept warm
LRU_TTL = int(os.getenv("LRU_TTL_SECONDS", "600"))  # 10 min

# Extension → language key used to route files to the right chunker.
LANGUAGE_MAP = {
    ".py": "python", ".js": "javascript", ".ts": "typescript",
    ".tsx": "typescript", ".jsx": "javascript", ".go": "go",
    ".rs": "rust", ".java": "java", ".cpp": "cpp",
    ".c": "c", ".cs": "csharp", ".rb": "ruby",
    ".php": "php", ".md": "markdown", ".txt": "text",
}

# LanceDB schema — one row per chunk (id, raw text, fixed-size fp32 vector)
_SCHEMA = pa.schema([
    pa.field("chunk_id", pa.int32()),
    pa.field("text", pa.large_utf8()),
    pa.field("vector", pa.list_(pa.float32(), DIM)),
])
87
+
88
 
89
# ─────────────────────────── Storage directory ────────────────────────────────
def _resolve_store_dir() -> pathlib.Path:
    """Pick the LanceDB root: /data/lancedb when writable, else a home cache dir.

    A touch/unlink probe is used in addition to mkdir because mkdir with
    exist_ok=True can succeed on a pre-existing directory where creating a
    file would still fail (e.g. a read-only mount).
    """
    preferred = pathlib.Path("/data/lancedb")
    try:
        preferred.mkdir(parents=True, exist_ok=True)
        marker = preferred / ".write_probe"
        marker.touch()
        marker.unlink()
    except OSError:
        fallback = pathlib.Path.home() / ".cache" / "code-search" / "lancedb"
        fallback.mkdir(parents=True, exist_ok=True)
        print(f"Warning: /data/lancedb not writable — using fallback: {fallback}")
        return fallback
    return preferred
102
 
103
  STORE_DIR = _resolve_store_dir()
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
# ─────────────────────────── LRU table-handle cache ───────────────────────────
class _LRUTableCache:
    """
    Bounded, TTL-based cache of open LanceDB table handles.

    At most ``maxsize`` handles are retained; an entry counts as stale once
    ``ttl`` seconds elapse without it being touched via get()/set().
    Opening a LanceDB table is cheap (no vectors loaded into RAM), so this
    cache mainly limits open file-descriptor churn.  Every operation holds
    an internal lock, so instances may be shared across worker threads.
    """

    def __init__(self, maxsize: int = 5, ttl: int = 600):
        self._maxsize = maxsize
        self._ttl = ttl
        self._entries: OrderedDict = OrderedDict()  # key -> (timestamp, handle)
        self._lock = Lock()

    def get(self, key: str):
        """Return the cached handle for *key*, or None if absent or expired."""
        with self._lock:
            record = self._entries.get(key)
            if record is None:
                return None
            stamped_at, handle = record
            if time.monotonic() - stamped_at > self._ttl:
                # Stale — drop it so the caller re-opens a fresh handle.
                del self._entries[key]
                return None
            # Touch: refresh both recency order and the activity timestamp.
            self._entries.move_to_end(key)
            self._entries[key] = (time.monotonic(), handle)
            return handle

    def set(self, key: str, tbl) -> None:
        """Insert or refresh *key*, evicting least-recently-used overflow."""
        with self._lock:
            if key in self._entries:
                self._entries.move_to_end(key)
            self._entries[key] = (time.monotonic(), tbl)
            while len(self._entries) > self._maxsize:
                self._entries.popitem(last=False)

    def evict(self, key: str) -> None:
        """Forget *key* if present; silent no-op otherwise."""
        with self._lock:
            self._entries.pop(key, None)

    def keys(self):
        """List currently non-expired keys (stale entries are not purged)."""
        with self._lock:
            now = time.monotonic()
            return [k for k, (stamped_at, _) in self._entries.items()
                    if now - stamped_at <= self._ttl]
150
+
151
+ _table_cache = _LRUTableCache(maxsize=LRU_MAXSIZE, ttl=LRU_TTL)
152
+
153
+
154
+ # ─────────────────────────── Global state ─────────────────────────────────────
155
+ models: dict = {}
156
+ _executor = ThreadPoolExecutor(max_workers=2)
157
+
158
+
159
# ─────────────────────────── Lifespan ─────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load the embedding model once, publish it in the module-level
    # `models` dict; shutdown: drop the reference so it can be collected.
    print("Loading jina-embeddings-v2-base-code (ONNX fp16)…")
    # backend="onnx" tells sentence-transformers to use ONNX Runtime instead
    # of PyTorch for the forward pass. file_name points to the pre-built
    # fp16 ONNX graph that ships with the model on HuggingFace Hub.
    # This completely bypasses the custom trust_remote_code PyTorch modeling
    # code — no compat patches needed, no PyTorch GPU/RAM usage for inference.
    model = SentenceTransformer(
        "jinaai/jina-embeddings-v2-base-code",
        backend="onnx",
        model_kwargs={
            "file_name": "onnx/model_fp16.onnx",
            "provider": "CPUExecutionProvider",
            # Keep ORT's intra-op thread pool small to limit RSS spikes;
            # mirrors the OMP_NUM_THREADS default set at import time.
            "provider_options": [{
                "intra_op_num_threads": int(os.getenv("OMP_NUM_THREADS", "2")),
            }],
        },
        trust_remote_code=True,
    )
    model.max_seq_length = 8192  # jina v2 code model accepts up to 8k tokens
    models["model"] = model
    print(f"Model ready [backend={model.backend}]")
    yield
    models.clear()
185
 
186
 
187
# ─────────────────────────── App ──────────────────────────────────────────────
app = FastAPI(
    title="Code Search API",
    description="Semantic code search — jina-embeddings-v2-base-code ONNX fp16 + LanceDB",
    version="3.0.0",
    lifespan=lifespan,
)
# Fully permissive CORS: the API is intended to be callable from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
)
198
 
199
 
200
# ─────────────────────────── Encoding ─────────────────────────────────────────
def _encode_sync(texts: list[str]) -> np.ndarray:
    """
    Synchronous encode via ONNX Runtime.

    Processes ENCODE_BATCH_SIZE texts at a time, collecting float32 parts and
    garbage-collecting between micro-batches to keep transient RSS low.
    Returns a float32 array of shape (len(texts), DIM).
    Note: no torch.no_grad() needed — ONNX Runtime has no autograd.
    """
    if not texts:
        # np.vstack([]) raises ValueError on an empty list — return an empty
        # (0, DIM) array instead so callers handle "nothing to encode" cleanly.
        return np.empty((0, DIM), dtype=np.float32)
    parts = []
    for start in range(0, len(texts), ENCODE_BATCH_SIZE):
        batch = texts[start : start + ENCODE_BATCH_SIZE]
        embs = models["model"].encode(
            batch,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=False,
        )
        parts.append(np.asarray(embs, dtype=np.float32))
        gc.collect()  # release encoder intermediates between micro-batches
    return np.vstack(parts)
 
 
220
 
221
 
222
async def _encode_async(texts: list[str]) -> np.ndarray:
    """Run _encode_sync in the worker pool so the event loop stays free.

    Uses asyncio.get_running_loop(): get_event_loop() is deprecated inside
    coroutines since Python 3.10 and scheduled for removal; inside a running
    coroutine the two are equivalent, so behavior is unchanged.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_executor, _encode_sync, texts)
225
+
226
+
227
def _normalize(embs: np.ndarray) -> np.ndarray:
    """L2-normalise each row; all-zero rows stay zero (no division by zero)."""
    row_norms = np.linalg.norm(embs, axis=1, keepdims=True)
    safe_norms = np.clip(row_norms, 1e-9, None)
    return embs / safe_norms
230
+
231
+
232
# ─────────────────────────── LanceDB helpers ──────────────────────────────────
def _db() -> lancedb.DBConnection:
    """Open a LanceDB connection rooted at the persistent store directory."""
    store_uri = str(STORE_DIR)
    return lancedb.connect(store_uri)
235
+
236
 
237
def _table_exists(doc_id: str) -> bool:
    """Whether a LanceDB table named *doc_id* already exists on disk."""
    existing = _db().table_names()
    return doc_id in existing
239
 
240
+
241
def _open_table(doc_id: str):
    """Return the table handle for *doc_id*, preferring the warm LRU cache."""
    cached = _table_cache.get(doc_id)
    if cached is not None:
        return cached
    opened = _db().open_table(doc_id)
    _table_cache.set(doc_id, opened)
    return opened
248
+
249
+
250
async def _build_table_streaming(doc_id: str, chunks: list[str]) -> None:
    """
    Streaming index build — the heart of the memory optimisation.

    Instead of: chunk_all → encode_all → build_index (full array in RAM)
    We do:      for each micro-batch → encode → write to LanceDB → free

    Peak RAM = one micro-batch of embeddings (8 × 768 × 4 bytes ≈ 24 KB).
    LanceDB stores vectors as a memory-mapped Lance file on disk; only
    the pages touched during a query are paged into RAM at search time.
    """
    db = _db()
    # Drop stale table if it exists (and any cached handle pointing at it)
    if doc_id in db.table_names():
        db.drop_table(doc_id)
        _table_cache.evict(doc_id)

    tbl = None
    for i in range(0, len(chunks), ENCODE_BATCH_SIZE):
        batch = chunks[i : i + ENCODE_BATCH_SIZE]
        embs = await _encode_async(batch)
        # Pre-normalise so the "dot" metric at search time equals cosine.
        embs = _normalize(embs)

        # One row per chunk: global chunk_id, raw text, fixed-width vector.
        records = [
            {
                "chunk_id": i + j,
                "text": text,
                "vector": vec.tolist(),
            }
            for j, (text, vec) in enumerate(zip(batch, embs))
        ]

        if tbl is None:
            # First micro-batch creates the table with the fixed schema.
            tbl = db.create_table(doc_id, data=records,
                                  schema=_SCHEMA, mode="overwrite")
        else:
            tbl.add(records)

        # Release this batch's embeddings before encoding the next one.
        del embs, records
        gc.collect()

    # Create ANN vector index for tables large enough to benefit
    if tbl is not None and len(chunks) >= 256:
        try:
            tbl.create_index(
                metric="dot",  # vectors are pre-normalised
                vector_column_name="vector",
                # Partition count scales with table size, capped at 256.
                num_partitions=max(1, min(256, len(chunks) // 40)),
                num_sub_vectors=96,  # 768 dims / 96 = 8 dims per PQ sub-vector
            )
        except Exception as e:
            # Best-effort: brute-force search still works without the index.
            print(f"Warning: ANN index creation skipped for '{doc_id}': {e}")

    if tbl is not None:
        _table_cache.set(doc_id, tbl)
305
 
 
 
 
 
 
 
306
 
307
def _search_table(doc_id: str, query: str, top_k: int) -> list[dict]:
    """
    On-demand vector search over the *doc_id* table.

    The table handle comes from the LRU cache when warm, otherwise from
    disk; only the Lance pages holding the nearest vectors are paged into
    RAM.  Returns up to *top_k* dicts with 1-based rank, rounded score,
    and the matching chunk text.
    """
    query_vec = _normalize(_encode_sync([query]))[0]

    table = _open_table(doc_id)
    request = table.search(query_vec.tolist(), vector_column_name="vector")
    hits = request.metric("dot").limit(top_k).to_list()

    ranked = []
    for position, hit in enumerate(hits, start=1):
        # LanceDB reports the metric value as "_distance"; fall back
        # defensively to "score" / 0.0 if the key is absent.
        raw_score = hit.get("_distance", hit.get("score", 0.0))
        ranked.append({
            "rank": position,
            "score": round(float(raw_score), 4),
            "text": hit["text"],
        })
    return ranked
332
 
333
 
334
+ # ─────────────────────────── Chunking ─────────────────────────────────────────
335
def detect_language(filename: str) -> str:
    """Map a filename's extension to a language key, defaulting to "text"."""
    _, extension = os.path.splitext(filename)
    return LANGUAGE_MAP.get(extension.lower(), "text")
 
337
 
338
 
339
  def chunk_text(text: str, chunk_size: int = 3, overlap: int = 1) -> list[str]:
 
340
  sentences = re.split(r'(?<=[.!?])\s+', text.strip())
341
  sentences = [s.strip() for s in sentences if s.strip()]
342
  chunks, i = [], 0
 
347
 
348
 
349
  def chunk_fallback(source: str, max_lines: int = 40, overlap: int = 5) -> list[str]:
 
350
  lines = source.splitlines()
351
  chunks = []
352
  i = 0
 
357
 
358
 
359
  def chunk_python(source: str, filepath: str = "") -> list[str]:
 
360
  try:
361
  tree = ast.parse(source)
362
  lines = source.splitlines()
363
  chunks = []
364
  for node in ast.walk(tree):
365
  if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
366
+ snippet = "\n".join(lines[node.lineno - 1 : node.end_lineno])
 
 
367
  prefix = f"# {filepath}\n" if filepath else ""
368
  chunks.append(f"{prefix}{snippet}")
369
  return chunks if chunks else chunk_fallback(source)
 
372
 
373
 
374
  def chunk_generic(source: str, filepath: str = "") -> list[str]:
 
 
 
 
375
  pattern = re.compile(
376
  r'(?:^|\n)(?='
377
+ r'(?:export\s+)?(?:async\s+)?'
 
378
  r'(?:function|class|const\s+\w+\s*=\s*(?:async\s+)?(?:\(|function)|'
379
  r'(?:public|private|protected|static|\s)*(?:fn|func|def)\s+\w+)'
380
  r')',
 
387
 
388
 
389
  def chunk_code(source: str, filename: str = "") -> list[str]:
 
390
  lang = detect_language(filename)
391
  if lang == "python":
392
  return chunk_python(source, filepath=filename)
 
396
  return chunk_generic(source, filepath=filename)
397
 
398
 
399
+ # ─────────────────────────── Schemas ──────────────────────────────────────────
400
class IndexResponse(BaseModel):
    """Response payload for a single-file /index call."""

    # PEP 8: one field per line instead of semicolon-joined statements.
    doc_id: str
    chunks_indexed: int
    message: str
 
 
 
402
 
403
class SearchRequest(BaseModel):
    """Parameters for a /search call."""

    # PEP 8: one field per line instead of semicolon-joined statements.
    doc_id: str = Field(...)
    query: str = Field(...)
    top_k: int = Field(5, ge=1, le=20)
 
 
 
405
 
406
class SearchResult(BaseModel):
    """A single ranked hit returned by /search."""

    # PEP 8: one field per line instead of semicolon-joined statements.
    rank: int
    score: float
    text: str
 
 
 
408
 
409
class SearchResponse(BaseModel):
    """Response payload for a /search call."""

    # PEP 8: one field per line instead of semicolon-joined statements.
    doc_id: str
    query: str
    results: list[SearchResult]
 
 
 
411
 
412
class EmbedRequest(BaseModel):
    """Texts to embed via /embed."""

    texts: list[str] = Field(...)
 
414
 
415
class EmbedResponse(BaseModel):
    """Embedding vectors and their dimensionality returned by /embed."""

    # PEP 8: one field per line instead of semicolon-joined statements.
    embeddings: list[list[float]]
    dimensions: int
 
 
417
 
418
class FileEntry(BaseModel):
    """One source file submitted to /index/batch."""

    # PEP 8: one field per line instead of semicolon-joined statements.
    filename: str
    content: str
 
 
420
 
421
class BatchIndexRequest(BaseModel):
    """Parameters for an /index/batch call."""

    # PEP 8: one field per line instead of semicolon-joined statements.
    doc_id: str
    files: list[FileEntry]
    replace: bool = True
 
 
 
423
 
424
class BatchIndexResponse(BaseModel):
    """Response payload for an /index/batch call."""

    # PEP 8: one field per line instead of semicolon-joined statements.
    doc_id: str
    files_indexed: int
    chunks_indexed: int
 
 
426
 
427
 
428
+ # ─────────────────────────── Routes ───────────────────────────────────────────
429
@app.get("/", tags=["health"])
def root():
    """Liveness probe; points callers at the interactive API docs."""
    payload = {"status": "ok", "docs": "/docs"}
    return payload
 
433
 
434
@app.get("/health", tags=["health"])
def health():
    """Readiness probe: reports whether the embedding model is loaded."""
    loaded = bool(models)
    backend = models["model"].backend if models else None
    return {"status": "ok", "models_loaded": loaded, "backend": backend}
438
 
439
 
440
  @app.post("/index", response_model=IndexResponse, tags=["search"])
441
  async def index_document(
442
+ file: Annotated[UploadFile, File(description="Source file to index")],
443
+ doc_id: Annotated[str, Form(description="Unique ID (defaults to filename)")] = "",
444
  ):
 
 
 
 
445
  if not models:
446
+ raise HTTPException(503, "Model not loaded yet.")
447
 
448
+ content = await file.read()
449
  if len(content) > MAX_UPLOAD_BYTES:
450
+ raise HTTPException(413,
451
+ f"File too large ({len(content)/1024/1024:.1f} MB). "
452
+ f"Max: {MAX_UPLOAD_BYTES//1024//1024} MB.")
453
+
454
+ source = content.decode("utf-8", errors="replace")
455
+ filename = file.filename or "unknown"
 
 
456
  resolved_id = doc_id.strip() or os.path.splitext(filename)[0]
457
 
458
  chunks = chunk_code(source, filename=filename)
459
  if not chunks:
460
+ raise HTTPException(400, "Document produced no chunks.")
461
 
462
+ await _build_table_streaming(resolved_id, chunks)
463
+ gc.collect()
 
 
 
464
 
465
  return IndexResponse(
466
  doc_id=resolved_id,
 
471
 
472
  @app.post("/index/batch", response_model=BatchIndexResponse, tags=["search"])
473
  async def index_batch(req: BatchIndexRequest):
 
 
 
 
474
  if not models:
475
  raise HTTPException(503, "Model not loaded yet.")
476
 
477
+ # Collect all chunks first (just strings β€” negligible RAM)
 
 
 
478
  all_chunks: list[str] = []
479
  for entry in req.files:
480
  all_chunks.extend(chunk_code(entry.content, filename=entry.filename))
481
 
482
  if not all_chunks:
483
  raise HTTPException(400, "No chunks produced from provided files.")
 
 
484
  if len(all_chunks) > MAX_CHUNKS:
485
+ raise HTTPException(413,
486
+ f"Too many chunks ({len(all_chunks):,}). Max: {MAX_CHUNKS:,}.")
 
 
 
487
 
488
+ # Streaming build β€” never holds full embeddings array
489
+ await _build_table_streaming(req.doc_id, all_chunks)
490
+ gc.collect()
 
491
 
492
  return BatchIndexResponse(
493
  doc_id=req.doc_id,
 
498
 
499
  @app.post("/search", response_model=SearchResponse, tags=["search"])
500
  async def search_document(req: SearchRequest):
501
+ if not _table_exists(req.doc_id):
 
502
  raise HTTPException(404, f"doc_id '{req.doc_id}' not found. Call /index first.")
503
 
504
+ loop = asyncio.get_event_loop()
505
+ results = await loop.run_in_executor(
506
+ _executor, _search_table, req.doc_id, req.query, req.top_k
507
+ )
508
  return SearchResponse(
509
  doc_id=req.doc_id,
510
  query=req.query,
 
514
 
515
@app.post("/embed", response_model=EmbedResponse, tags=["embeddings"])
async def embed_texts(req: EmbedRequest):
    """
    Embed up to 64 texts and return their raw vectors.

    Raises:
        HTTPException(503): embedding model not loaded yet.
        HTTPException(400): empty input, or more than 64 texts.
    """
    if not models:
        raise HTTPException(503, "Model not loaded yet.")
    if not req.texts:
        # An empty batch would produce an empty array and fail on
        # embs.shape[1] below — reject with a client error, not a 500.
        raise HTTPException(400, "At least one text is required.")
    if len(req.texts) > 64:
        raise HTTPException(400, "Maximum 64 texts per request.")

    embs = await _encode_async(req.texts)
    return EmbedResponse(embeddings=embs.tolist(), dimensions=embs.shape[1])
 
 
 
524
 
525
 
526
@app.get("/documents", tags=["search"])
def list_documents():
    """List every indexed document with its chunk count (-1 if unreadable)."""
    db = _db()

    def _describe(name: str) -> dict:
        # A table can be corrupt or mid-write; report -1 rather than fail.
        try:
            return {"doc_id": name, "chunks": db.open_table(name).count_rows()}
        except Exception:
            return {"doc_id": name, "chunks": -1}

    return {"documents": [_describe(name) for name in db.table_names()]}
538
 
539
 
540
@app.delete("/documents/{doc_id}", tags=["search"])
def delete_document(doc_id: str):
    """Drop a document's table and purge its cached handle."""
    if not _table_exists(doc_id):
        raise HTTPException(404, f"doc_id '{doc_id}' not found.")

    database = _db()
    database.drop_table(doc_id)
    _table_cache.evict(doc_id)
    return {"deleted": doc_id}
  return {"deleted": doc_id}