File size: 6,956 Bytes
c61a3d0
1a7ac30
ac08d2a
a09fb65
b91214e
71a256d
c61a3d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91630d1
e2d584e
8bc7a37
ac08d2a
9221408
 
ac08d2a
a09fb65
ac08d2a
9221408
 
 
ac08d2a
 
 
 
 
 
 
 
 
 
1a7ac30
 
9221408
ac08d2a
 
 
 
 
 
 
 
 
 
1a7ac30
 
9221408
ac08d2a
b04d90e
 
 
 
1a7ac30
91630d1
ac08d2a
b04d90e
 
 
 
1a7ac30
ac08d2a
1a7ac30
ac08d2a
 
 
91630d1
ac08d2a
b04d90e
ac08d2a
9221408
ac08d2a
 
91630d1
ac08d2a
 
b384849
 
91630d1
 
 
 
ac08d2a
91630d1
b384849
 
 
 
91630d1
b384849
 
91630d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac08d2a
91630d1
a09fb65
91630d1
 
ac08d2a
 
 
 
 
 
 
1a7ac30
 
 
91630d1
 
 
ac08d2a
 
 
1a7ac30
ac08d2a
91630d1
 
aabf335
 
91630d1
1a7ac30
b384849
 
aabf335
91630d1
aabf335
1a7ac30
b91214e
ac08d2a
b91214e
b384849
 
91630d1
ac08d2a
a09fb65
1a7ac30
a09fb65
ac08d2a
9221408
ac08d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# rag_mini.py 
from __future__ import annotations
import os, math
from pathlib import Path
from typing import List, Tuple

def _first_writable(paths):
    for p in paths:
        if not p: 
            continue
        try:
            base = Path(p)
            base.mkdir(parents=True, exist_ok=True)
            test = base / ".writetest"
            test.write_text("ok")
            test.unlink(missing_ok=True)
            return base.resolve()
        except Exception:
            continue
    # last resort
    return Path("/tmp").resolve()

# Prefer env(DATA_ROOT), then /data (Spaces persistent), else /tmp
DATA_ROOT = _first_writable([os.getenv("DATA_ROOT"), "/data", "/tmp"])
ROOT_DIR  = Path(__file__).parent.resolve()  # directory containing this module
MM_ROOT   = DATA_ROOT / "MaterialMind"       # app-specific data root under DATA_ROOT

# Default number of results returned by search().
DEFAULT_TOPK = 5


# ---- where the index lives ----
INDEX_DS         = os.getenv("INDEX_DS", "").strip()           # HF dataset repo id holding the prebuilt index (optional)
INDEX_DIR_ENV    = os.getenv("INDEX_DIR", "").strip()          # explicit index dir, resolved relative to ROOT_DIR
INDEX_COLLECTION = os.getenv("INDEX_COLLECTION", "").strip()  # e.g., "materialmind"

# ---- embedding settings (match local!) ----
# Use BGE-small (384-d) everywhere to avoid mismatch unless you *know* you indexed with OpenAI.
EMB_PROVIDER  = os.getenv("EMB_PROVIDER", "hf").strip().lower()  # "hf" or "openai"
EMB_MODEL     = os.getenv("EMB_MODEL", "BAAI/bge-small-en-v1.5").strip()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # only used if EMB_PROVIDER=openai

# backends (exactly one is set by _init_embedder, the others stay None)
_EMB_FAST = None
_EMB_ST   = None
_EMB_OAI  = None

def _l2norm(vec: List[float]) -> List[float]:
    s = math.sqrt(sum(x*x for x in vec)) or 1.0
    return [x/s for x in vec]

def _init_embedder():
    """Initialize exactly one embedding backend (idempotent).

    Preference order: OpenAI (only when EMB_PROVIDER selects it), then
    FastEmbed, then SentenceTransformers.  Once a backend is ready, later
    calls return immediately -- previously this function re-imported and
    re-loaded the embedding model on every _embed() call, which is slow.
    Failures are logged and the next backend is tried; if none is
    available, all three globals stay None and _embed() degrades to zeros.
    """
    global _EMB_FAST, _EMB_ST, _EMB_OAI
    # Already initialized -> skip the expensive model loading below.
    if _EMB_OAI is not None or _EMB_FAST is not None or _EMB_ST is not None:
        return
    if EMB_PROVIDER in ("openai","oai"):
        try:
            from openai import OpenAI
            _EMB_OAI = OpenAI(api_key=OPENAI_API_KEY)
            print(f"[EMB] OpenAI embeddings ready: {EMB_MODEL}", flush=True)
            return
        except Exception as e:
            print("[EMB] OpenAI embeddings unavailable:", e, flush=True)
    # HF path (FastEmbed → SentenceTransformers fallback)
    try:
        from fastembed import TextEmbedding
        _EMB_FAST = TextEmbedding(model_name=EMB_MODEL)
        print(f"[EMB] FastEmbed ready: {EMB_MODEL}", flush=True)
        return
    except Exception as e1:
        print("[EMB] FastEmbed unavailable:", e1, flush=True)
    try:
        from sentence_transformers import SentenceTransformer
        _EMB_ST = SentenceTransformer(EMB_MODEL)
        print(f"[EMB] SentenceTransformers ready: {EMB_MODEL}", flush=True)
        return
    except Exception as e2:
        print("[EMB] SentenceTransformers unavailable:", e2, flush=True)
        print("[EMB] ERROR: No embedding backend available. Install 'fastembed' or 'sentence-transformers'.", flush=True)

def _embed(texts: List[str]) -> List[List[float]]:
    """Embed *texts* with whichever backend is available.

    All outputs are L2-normalized so cosine similarity behaves consistently
    across backends.  With no backend at all, zero vectors are returned so
    callers keep working (and simply get no hits).
    """
    _init_embedder()
    if _EMB_OAI is not None:
        response = _EMB_OAI.embeddings.create(model=EMB_MODEL, input=texts)
        return [_l2norm(item.embedding) for item in response.data]
    if _EMB_FAST is not None:
        return [_l2norm(vector) for vector in _EMB_FAST.embed(texts)]
    if _EMB_ST is not None:
        matrix = _EMB_ST.encode(texts, normalize_embeddings=True)
        return [_l2norm(row.tolist()) for row in matrix]
    # Last resort: 384-d zeros (BGE-small dimensionality) to prevent crashes.
    return [[0.0] * 384 for _ in texts]

# ---- index discovery ----
def _has_catalog(dirpath: Path) -> bool:
    for f in ["chroma.sqlite3","chroma.sqlite","chroma-collections.parquet",
              "index_metadata.pickle","data_level0.bin"]:
        if (dirpath/f).exists():
            return True
    return False

def _locate_local_index() -> Path:
    """Resolve the directory expected to hold the Chroma index.

    INDEX_DIR (env) wins when set (resolved against ROOT_DIR).  Otherwise
    check the default MM_ROOT/index/chroma_v3 location; if it has no
    catalog, search nested folders for a chroma.sqlite3 and use its parent.
    """
    if INDEX_DIR_ENV:
        return (ROOT_DIR / INDEX_DIR_ENV).resolve()
    default = (MM_ROOT / "index" / "chroma_v3").resolve()
    if _has_catalog(default):
        return default
    nested = list(default.rglob("chroma.sqlite3"))
    return nested[0].parent if nested else default

def ensure_ready():
    """Ensure the index directory exists, optionally pulling it from the Hub.

    When INDEX_DS names a HF dataset repo, its snapshot is downloaded into
    MM_ROOT (failures are logged, not raised).  The index location is then
    re-resolved and a warning is printed if no Chroma catalog is found.
    """
    _locate_local_index().mkdir(parents=True, exist_ok=True)
    if INDEX_DS:
        try:
            from huggingface_hub import snapshot_download
            print("[RAG] downloading index dataset:", INDEX_DS, flush=True)
            snapshot_download(repo_id=INDEX_DS, repo_type="dataset",
                              local_dir=str(MM_ROOT), local_dir_use_symlinks=False)
        except Exception as e:
            print("[RAG] dataset download failed:", e, flush=True)
    # Re-resolve: the download may have created the nested catalog folder.
    local = _locate_local_index()
    if _has_catalog(local):
        print(f"[RAG] Index OK at {local}", flush=True)
    else:
        print(f"[RAG] WARNING: No Chroma catalog found in: {local}", flush=True)
        print("      Set INDEX_DIR to the nested folder containing chroma.sqlite3", flush=True)

# ---- Chroma access ----
def _get_collection():
    """Open the Chroma collection to query.

    Honors INDEX_COLLECTION when set (created on demand with cosine space);
    otherwise falls back to the first existing collection, and finally to a
    fresh "materialmind" collection.
    """
    import chromadb
    client = chromadb.PersistentClient(path=str(_locate_local_index()))
    if INDEX_COLLECTION:
        try:
            return client.get_collection(INDEX_COLLECTION)
        except Exception:
            return client.get_or_create_collection(
                name=INDEX_COLLECTION, metadata={"hnsw:space": "cosine"}
            )
    try:
        existing = client.list_collections()
        if existing:
            # NOTE(review): assumes list_collections() yields objects with a
            # .name attribute; newer chromadb versions return plain names --
            # verify against the pinned chromadb version.
            return client.get_collection(existing[0].name)
    except Exception:
        pass
    return client.get_or_create_collection(
        name="materialmind", metadata={"hnsw:space": "cosine"}
    )

def search(query: str, k: int = DEFAULT_TOPK) -> List[Tuple[str, str]]:
    """Return up to *k* (document, citation) pairs most similar to *query*.

    Returns [] when no index catalog is present or the query fails (the
    failure is logged, not raised).  Citations are "source" or
    "source:p.N" when the metadata records a page number.
    """
    local = _locate_local_index()
    if not _has_catalog(local):
        return []
    try:
        col = _get_collection()
        qvec = _embed([query])[0]
        res = col.query(query_embeddings=[qvec], n_results=int(k),
                        include=["documents","metadatas"])
    except Exception as e:
        print("[RAG] query failed:", e, flush=True)
        return []
    docs  = (res.get("documents") or [[]])[0]
    metas = (res.get("metadatas") or [[]])[0]
    hits: List[Tuple[str, str]] = []
    for doc, meta in zip(docs, metas):
        if not doc:
            continue
        meta = meta or {}
        src = meta.get("source") or meta.get("path") or "unknown"
        page = meta.get("page")
        # Bug fix: compare against None, not truthiness -- the old
        # `if page` silently dropped a legitimate page number of 0.
        cite = f"{src}" if page is None else f"{src}:p.{page}"
        hits.append((doc, cite))
    return hits

# ---- tiny debugger (for /debug/rag) ----
def rag_debug_info():
    """Collect index path, per-collection counts, and embedding config for /debug/rag."""
    import chromadb
    local = _locate_local_index()
    client = chromadb.PersistentClient(path=str(local))
    info = {
        "index_path": str(local),
        "collections": [],
        "emb": {"provider": EMB_PROVIDER, "model": EMB_MODEL},
    }
    try:
        for col in client.list_collections():
            try:
                count = col.count()
            except Exception:
                # count() can fail on stale/corrupt collections; flag with -1.
                count = -1
            info["collections"].append({"name": col.name, "count": count})
    except Exception as e:
        info["collections"].append({"error": str(e)})
    return info