# vivekchakraverty's picture
# Fix ZeroGPU retrieval: pin jina query embedder to CPU
# e48654b verified
"""Retrieval over the GDScript corpus.
Loads the prebuilt FAISS index (cosine / IndexIDMap2, faiss_id == chunk id) and
chunks.jsonl, embeds the query with the same jina code model used to build the
index, and returns the top-k chunk records. Runs on CPU (query embedding is one
text at a time, fast).
"""
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
import faiss
import numpy as np
# Location of the prebuilt retrieval artifacts. The GDRAG_SPACE_DATA
# environment variable, when set, overrides the default "data" directory that
# sits next to this module.
if "GDRAG_SPACE_DATA" in os.environ:
    DATA_DIR = Path(os.environ["GDRAG_SPACE_DATA"])
else:
    DATA_DIR = Path(Path(__file__).parent / "data")

# Serialized FAISS index and the JSONL file holding the chunk records.
FAISS_PATH = DATA_DIR.joinpath("embeddings.faiss")
CHUNKS_PATH = DATA_DIR.joinpath("chunks.jsonl")

# Embedding model used for queries (the module docstring states it is the same
# model the index was built with).
EMBED_MODEL = "jinaai/jina-embeddings-v2-base-code"
@dataclass
class Hit:
    """One retrieved chunk together with its similarity score.

    Field order is part of the interface: instances may be built positionally.
    """

    # Similarity score reported by the index search for this chunk.
    score: float
    # Chunk body (the "text" field of the chunk record).
    text: str
    # Provenance fields copied from the chunk record.
    repo: str
    origin_url: str
    file_path: str
    # The record's "kind" field (meaning is defined by the corpus builder).
    kind: str
# ---------------------------------------------------------------------------
# Lazy singletons (loaded once per process)
# ---------------------------------------------------------------------------
@lru_cache(maxsize=1)
def _index() -> faiss.Index:
    """Read the FAISS index from FAISS_PATH.

    The result is memoized, so the file is read at most once per process.
    """
    index_file = str(FAISS_PATH)
    return faiss.read_index(index_file)
@lru_cache(maxsize=1)
def _chunks() -> dict[int, dict]:
by_id: dict[int, dict] = {}
with open(CHUNKS_PATH, "r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
try:
r = json.loads(line)
except json.JSONDecodeError:
continue
by_id[r["id"]] = r
return by_id
@lru_cache(maxsize=1)
def _embedder():
    """Build the query embedding model (memoized: one instance per process)."""
    # Imported lazily so that importing this module does not pull in
    # sentence_transformers until an embedder is actually needed.
    from sentence_transformers import SentenceTransformer

    # The pinned transformers (~=4.45) loads jina's remote code without shims.
    # device="cpu" is REQUIRED on ZeroGPU: retrieve() embeds the query outside
    # the @spaces.GPU block, where CUDA isn't really allocated. Left on auto,
    # the model lands on a phantom cuda device and returns zero vectors.
    model = SentenceTransformer(EMBED_MODEL, trust_remote_code=True, device="cpu")
    return model
def _embed_query(query: str) -> np.ndarray:
    """Embed one query string and return it as a float32 array.

    The query is encoded as a single-item batch with normalized embeddings;
    presumably the result has shape (1, dim) — confirm against the model.
    """
    model = _embedder()
    embedding = model.encode(
        [query],
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    return np.asarray(embedding, dtype=np.float32)
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def index_available() -> bool:
    """Report whether both the FAISS index file and chunks.jsonl exist."""
    if not FAISS_PATH.exists():
        return False
    return CHUNKS_PATH.exists()
def retrieve(query: str, k: int = 6) -> list[Hit]:
    """Return the top-k GDScript chunks most relevant to the query.

    Args:
        query: Free-text query; a blank or whitespace-only query yields [].
        k: Maximum number of hits to return. A non-positive k yields []
            instead of being passed to the index search.

    Returns:
        Up to k Hit objects in the order the index returned them. Returns []
        if the index hasn't been built/uploaded yet, so the Space still runs
        (answers without retrieval) until the Colab build pushes the index.
    """
    # Cheap checks first: none of them loads the model or the index.
    if k <= 0 or not query.strip() or not index_available():
        return []
    qv = _embed_query(query)
    scores, ids = _index().search(qv, k)
    chunks = _chunks()
    hits: list[Hit] = []
    # One query was embedded, so only row 0 of the result arrays is read.
    for score, cid in zip(scores[0], ids[0]):
        # Negative ids are skipped (presumably FAISS padding when fewer than
        # k results exist — confirm).
        if cid < 0:
            continue
        rec = chunks.get(int(cid))
        if not rec:
            # Index id with no matching (or an empty) chunk record.
            continue
        hits.append(Hit(
            score=float(score),
            text=rec.get("text", ""),
            repo=rec.get("repo", ""),
            origin_url=rec.get("origin_url", ""),
            file_path=rec.get("file_path", ""),
            kind=rec.get("kind", ""),
        ))
    return hits
def warmup() -> None:
    """Preload index, chunks and embedder (call at Space startup).

    Does nothing when the index files are not available.
    """
    if not index_available():
        return
    # Same order as before: index, then chunk records, then the model.
    for load in (_index, _chunks, _embedder):
        load()
if __name__ == "__main__":
    # Manual smoke test: query comes from the command line, else a default.
    import sys

    cli_text = " ".join(sys.argv[1:])
    q = cli_text or "how do I use @export and signals in GDScript"
    print(f"Query: {q}\n")

    results = retrieve(q, k=6)
    for i, h in enumerate(results, 1):
        print(f"[{i}] score={h.score:.3f} {h.repo} {h.file_path}")
        # First 160 characters of the chunk, flattened onto one line.
        preview = h.text[:160].replace("\n", " ")
        print(" " + preview + "...\n")