File size: 3,911 Bytes
777ea0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e48654b
 
 
777ea0e
e48654b
777ea0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Retrieval over the GDScript corpus.

Loads the prebuilt FAISS index (cosine / IndexIDMap2, faiss_id == chunk id) and
chunks.jsonl, embeds the query with the same jina code model used to build the
index, and returns the top-k chunk records. Runs on CPU (query embedding is one
text at a time, fast).
"""
from __future__ import annotations

import json
import os
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path

import faiss
import numpy as np

# Directory holding the prebuilt retrieval artifacts. The GDRAG_SPACE_DATA
# environment variable overrides it; otherwise the "data" folder next to this
# file is used.
DATA_DIR = Path(os.environ.get("GDRAG_SPACE_DATA", Path(__file__).parent / "data"))
# FAISS index file read by _index().
FAISS_PATH = DATA_DIR / "embeddings.faiss"
# JSON-lines chunk records read by _chunks().
CHUNKS_PATH = DATA_DIR / "chunks.jsonl"
# Model name handed to SentenceTransformer in _embedder().
EMBED_MODEL = "jinaai/jina-embeddings-v2-base-code"


@dataclass
class Hit:
    """One retrieval result: a search score plus fields copied from a chunk record."""

    # Score reported by the FAISS search for this chunk, converted to float.
    score: float
    # The chunk record's "text" value ("" when the record has none).
    text: str
    # The chunk record's "repo" value ("" when the record has none).
    repo: str
    # The chunk record's "origin_url" value ("" when the record has none).
    origin_url: str
    # The chunk record's "file_path" value ("" when the record has none).
    file_path: str
    # The chunk record's "kind" value ("" when the record has none).
    kind: str


# ---------------------------------------------------------------------------
# Lazy singletons (loaded once per process)
# ---------------------------------------------------------------------------
@lru_cache(maxsize=1)
def _index() -> faiss.Index:
    """Read the FAISS index stored at FAISS_PATH.

    The result is memoized by ``lru_cache``, so the file is read only on the
    first call.
    """
    # faiss.read_index is given a plain string path, so convert the Path first.
    index_file = str(FAISS_PATH)
    return faiss.read_index(index_file)


@lru_cache(maxsize=1)
def _chunks() -> dict[int, dict]:
    """Load chunks.jsonl into a dict keyed by each record's "id" field.

    Loading is best-effort: blank lines, lines that are not valid JSON, and
    JSON values that are not objects carrying an "id" key are skipped, so a
    single bad line cannot prevent the remaining chunks from loading.

    The result is memoized by ``lru_cache``, so the file is parsed only on the
    first call.

    Returns:
        Mapping from record id to the full decoded record.
    """
    by_id: dict[int, dict] = {}
    with open(CHUNKS_PATH, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            # json.loads accepts any JSON value (null, numbers, lists, ...);
            # only objects that carry an "id" can be indexed. Skip the rest
            # instead of letting a TypeError/KeyError abort the whole load.
            if not isinstance(record, dict) or "id" not in record:
                continue
            by_id[record["id"]] = record
    return by_id


@lru_cache(maxsize=1)
def _embedder():
    """Create the SentenceTransformer used for query embedding.

    ``sentence_transformers`` is imported inside the function, and the model
    is memoized by ``lru_cache`` so it is constructed only on the first call.
    """
    from sentence_transformers import SentenceTransformer

    # The pinned transformers (~=4.45) loads jina's remote code without shims.
    # Forcing the CPU device is REQUIRED on ZeroGPU: queries are embedded in
    # retrieve(), outside the @spaces.GPU block, where CUDA isn't really
    # allocated. With automatic device selection the model lands on a phantom
    # cuda device and returns zero vectors.
    model = SentenceTransformer(
        EMBED_MODEL,
        trust_remote_code=True,
        device="cpu",
    )
    return model


def _embed_query(query: str) -> np.ndarray:
    """Embed a single query string and return it as a float32 NumPy array.

    The query is encoded as a one-item batch with normalized embeddings and
    the progress bar switched off.
    """
    model = _embedder()
    embedding = model.encode(
        [query],
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    # The FAISS search in retrieve() is fed this array; force float32.
    return np.asarray(embedding, dtype=np.float32)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def index_available() -> bool:
    """Report whether both the FAISS index file and the chunks file exist."""
    required_files = (FAISS_PATH, CHUNKS_PATH)
    # all() stops at the first missing file, like the chained ``and`` would.
    return all(path.exists() for path in required_files)


def retrieve(query: str, k: int = 6) -> list[Hit]:
    """Return the top-k GDScript chunks most relevant to the query.

    ``query`` is embedded and searched against the FAISS index; ``k`` is the
    maximum number of hits to return. Hits keep the order in which the index
    search returned them. Result slots with a negative id, and ids that have
    no matching chunk record, are dropped, so fewer than ``k`` hits may come
    back.

    Returns [] when the query is blank, when ``k`` is not positive, or if the
    index hasn't been built/uploaded yet, so the Space still runs (answers
    without retrieval) until the Colab build pushes the index.
    """
    # Cheap argument checks first. The FAISS search rejects k <= 0, so treat
    # that as "no results requested" instead of letting the search raise.
    if k <= 0 or not query.strip():
        return []
    if not index_available():
        return []

    qv = _embed_query(query)
    scores, ids = _index().search(qv, k)
    chunks = _chunks()

    hits: list[Hit] = []
    # Only one query vector was searched, so row 0 holds all of its results.
    for score, cid in zip(scores[0], ids[0]):
        if cid < 0:
            # FAISS marks unfilled result slots with a negative id.
            continue
        rec = chunks.get(int(cid))
        if not rec:
            # Index entry without a (non-empty) chunk record: nothing to show.
            continue
        hits.append(Hit(
            score=float(score),
            text=rec.get("text", ""),
            repo=rec.get("repo", ""),
            origin_url=rec.get("origin_url", ""),
            file_path=rec.get("file_path", ""),
            kind=rec.get("kind", ""),
        ))
    return hits


def warmup() -> None:
    """Eagerly load the index, chunk table and embedder (call at Space startup).

    Does nothing when the index files are not available yet.
    """
    if not index_available():
        return
    # Each loader is memoized, so calling it once here primes its cache.
    for loader in (_index, _chunks, _embedder):
        loader()


if __name__ == "__main__":
    import sys

    # Command-line smoke test: the arguments form the query; with no (or only
    # empty) arguments a demo question is used instead.
    cli_text = " ".join(sys.argv[1:])
    q = cli_text or "how do I use @export and signals in GDScript"
    print(f"Query: {q}\n")

    results = retrieve(q, k=6)
    for i, h in enumerate(results, 1):
        # Show the first 160 characters of the chunk on a single line.
        preview = h.text[:160].replace("\n", " ")
        print(f"[{i}] score={h.score:.3f}  {h.repo}  {h.file_path}")
        print("    " + preview + "...\n")