Azizahalq commited on
Commit
ac08d2a
·
verified ·
1 Parent(s): 4e72a0f

Update rag_mini.py

Browse files
Files changed (1) hide show
  1. rag_mini.py +82 -22
rag_mini.py CHANGED
@@ -1,55 +1,85 @@
1
  from __future__ import annotations
2
- import os, sys
3
  from pathlib import Path
4
  from typing import List, Tuple
5
 
 
6
  ROOT_DIR = Path(__file__).parent.resolve()
7
  MM_ROOT = ROOT_DIR / "MaterialMind"
8
  DEFAULT_TOPK = 5
9
 
10
- INDEX_DS = os.getenv("INDEX_DS", "").strip()
11
- INDEX_DIR_ENV = os.getenv("INDEX_DIR", "").strip()
 
 
12
 
13
- _EMB_FAST=None; _EMB_ST=None
14
- EMB_MODEL = "BAAI/bge-small-en-v1.5"
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def _init_embedder():
17
- global _EMB_FAST, _EMB_ST
18
- if _EMB_FAST or _EMB_ST:
19
- return
 
 
 
 
 
 
 
 
20
  try:
21
  from fastembed import TextEmbedding
22
- _EMB_FAST = TextEmbedding(model_name=EMB_MODEL)
23
- print("[EMB] FastEmbed ready:", EMB_MODEL, flush=True)
24
  return
25
  except Exception as e1:
26
  print("[EMB] FastEmbed unavailable:", e1, flush=True)
27
  try:
28
  from sentence_transformers import SentenceTransformer
29
  _EMB_ST = SentenceTransformer(EMB_MODEL)
30
- print("[EMB] SentenceTransformers ready:", EMB_MODEL, flush=True)
31
  return
32
  except Exception as e2:
33
  print("[EMB] SentenceTransformers unavailable:", e2, flush=True)
34
  print("[EMB] ERROR: No embedding backend available. Install 'fastembed' or 'sentence-transformers'.", flush=True)
35
 
36
- def _embed(texts:List[str])->List[List[float]]:
37
  _init_embedder()
 
 
 
38
  if _EMB_FAST is not None:
39
- return [v for v in _EMB_FAST.embed(texts)]
40
  if _EMB_ST is not None:
41
- return _EMB_ST.encode(texts, normalize_embeddings=True).tolist()
42
- # Fallback: no embeddings – return zeros to avoid crashing
43
- return [[0.0]*384 for _ in texts] # length doesn’t matter; Chroma ignores if we don’t query
 
 
 
44
 
45
- def _has_catalog(dirpath:Path)->bool:
 
46
  for f in ["chroma.sqlite3","chroma.sqlite","chroma-collections.parquet",
47
  "index_metadata.pickle","data_level0.bin"]:
48
  if (dirpath/f).exists():
49
  return True
50
  return False
51
 
52
- def _locate_local_index()->Path:
53
  if INDEX_DIR_ENV:
54
  return (ROOT_DIR / INDEX_DIR_ENV).resolve()
55
  base = (MM_ROOT / "index" / "chroma_v3").resolve()
@@ -78,19 +108,30 @@ def ensure_ready():
78
  else:
79
  print(f"[RAG] Index OK at {local}", flush=True)
80
 
 
81
  def _get_collection():
82
  import chromadb
83
  local = _locate_local_index()
84
  client = chromadb.PersistentClient(path=str(local))
 
 
 
 
 
 
 
 
85
  try:
86
  cols = client.list_collections()
87
  if cols:
88
  return client.get_collection(cols[0].name)
89
  except Exception:
90
  pass
91
- return client.get_or_create_collection(name="materialmind")
 
 
92
 
93
- def search(query:str, k:int=DEFAULT_TOPK)->List[Tuple[str,str]]:
94
  local = _locate_local_index()
95
  if not _has_catalog(local):
96
  return []
@@ -104,12 +145,31 @@ def search(query:str, k:int=DEFAULT_TOPK)->List[Tuple[str,str]]:
104
  return []
105
  docs = (res.get("documents") or [[]])[0]
106
  metas = (res.get("metadatas") or [[]])[0]
107
- hits=[]
108
  for d, m in zip(docs, metas):
109
  if not d:
110
  continue
111
  src = (m or {}).get("source") or (m or {}).get("path") or "unknown"
112
- page= (m or {}).get("page")
113
  cite = f"{src}" + (f":p.{page}" if page else "")
114
  hits.append((d, cite))
115
  return hits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
+ import os, math
3
  from pathlib import Path
4
  from typing import List, Tuple
5
 
6
+ # ---- paths / constants ----
7
  ROOT_DIR = Path(__file__).parent.resolve()
8
  MM_ROOT = ROOT_DIR / "MaterialMind"
9
  DEFAULT_TOPK = 5
10
 
11
+ # ---- where the index lives ----
12
+ INDEX_DS = os.getenv("INDEX_DS", "").strip()
13
+ INDEX_DIR_ENV = os.getenv("INDEX_DIR", "").strip()
14
+ INDEX_COLLECTION = os.getenv("INDEX_COLLECTION", "").strip() # e.g., "materialmind"
15
 
16
+ # ---- embedding settings (match local!) ----
17
+ # Use BGE-small (384-d) everywhere to avoid mismatch
18
+ EMB_PROVIDER = os.getenv("EMB_PROVIDER", "hf").strip().lower() # "hf" or "openai"
19
+ EMB_MODEL = os.getenv("EMB_MODEL", "BAAI/bge-small-en-v1.5").strip()
20
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # only used if EMB_PROVIDER=openai
21
+
22
+ # backends
23
+ _EMB_FAST = None
24
+ _EMB_ST = None
25
+ _EMB_OAI = None
26
+
27
+ def _l2norm(vec: List[float]) -> List[float]:
28
+ s = math.sqrt(sum(x*x for x in vec)) or 1.0
29
+ return [x/s for x in vec]
30
 
31
  def _init_embedder():
32
+ """Initialize exactly one embedding backend based on EMB_PROVIDER."""
33
+ global _EMB_FAST, _EMB_ST, _EMB_OAI
34
+ if EMB_PROVIDER in ("openai","oai"):
35
+ try:
36
+ from openai import OpenAI
37
+ _EMB_OAI = OpenAI(api_key=OPENAI_API_KEY)
38
+ print(f"[EMB] OpenAI embeddings ready: {EMB_MODEL}", flush=True)
39
+ return
40
+ except Exception as e:
41
+ print("[EMB] OpenAI embeddings unavailable:", e, flush=True)
42
+ # HF path (FastEmbed → SentenceTransformers fallback)
43
  try:
44
  from fastembed import TextEmbedding
45
+ _EMB_FAST = TextEmbedding(model_name=EMB_MODEL) # we’ll L2-normalize ourselves
46
+ print(f"[EMB] FastEmbed ready: {EMB_MODEL}", flush=True)
47
  return
48
  except Exception as e1:
49
  print("[EMB] FastEmbed unavailable:", e1, flush=True)
50
  try:
51
  from sentence_transformers import SentenceTransformer
52
  _EMB_ST = SentenceTransformer(EMB_MODEL)
53
+ print(f"[EMB] SentenceTransformers ready: {EMB_MODEL}", flush=True)
54
  return
55
  except Exception as e2:
56
  print("[EMB] SentenceTransformers unavailable:", e2, flush=True)
57
  print("[EMB] ERROR: No embedding backend available. Install 'fastembed' or 'sentence-transformers'.", flush=True)
58
 
59
+ def _embed(texts: List[str]) -> List[List[float]]:
60
  _init_embedder()
61
+ if _EMB_OAI is not None:
62
+ r = _EMB_OAI.embeddings.create(model=EMB_MODEL, input=texts)
63
+ return [_l2norm(d.embedding) for d in r.data]
64
  if _EMB_FAST is not None:
65
+ return [_l2norm(v) for v in _EMB_FAST.embed(texts)]
66
  if _EMB_ST is not None:
67
+ # ST can normalize internally, but we also L2-normalize for safety
68
+ from numpy import array
69
+ arr = _EMB_ST.encode(texts, normalize_embeddings=True)
70
+ return [_l2norm(list(v)) for v in array(arr).tolist()]
71
+ # last resort: zeros (prevents crashes; yields 0 hits)
72
+ return [[0.0]*384 for _ in texts]
73
 
74
+ # ---- index discovery ----
75
+ def _has_catalog(dirpath: Path) -> bool:
76
  for f in ["chroma.sqlite3","chroma.sqlite","chroma-collections.parquet",
77
  "index_metadata.pickle","data_level0.bin"]:
78
  if (dirpath/f).exists():
79
  return True
80
  return False
81
 
82
+ def _locate_local_index() -> Path:
83
  if INDEX_DIR_ENV:
84
  return (ROOT_DIR / INDEX_DIR_ENV).resolve()
85
  base = (MM_ROOT / "index" / "chroma_v3").resolve()
 
108
  else:
109
  print(f"[RAG] Index OK at {local}", flush=True)
110
 
111
+ # ---- Chroma access ----
112
  def _get_collection():
113
  import chromadb
114
  local = _locate_local_index()
115
  client = chromadb.PersistentClient(path=str(local))
116
+ if INDEX_COLLECTION:
117
+ try:
118
+ return client.get_collection(INDEX_COLLECTION)
119
+ except Exception:
120
+ # create with cosine metric to match unit-normalized embeddings
121
+ return client.get_or_create_collection(
122
+ name=INDEX_COLLECTION, metadata={"hnsw:space": "cosine"}
123
+ )
124
  try:
125
  cols = client.list_collections()
126
  if cols:
127
  return client.get_collection(cols[0].name)
128
  except Exception:
129
  pass
130
+ return client.get_or_create_collection(
131
+ name="materialmind", metadata={"hnsw:space": "cosine"}
132
+ )
133
 
134
+ def search(query: str, k: int = DEFAULT_TOPK) -> List[Tuple[str, str]]:
135
  local = _locate_local_index()
136
  if not _has_catalog(local):
137
  return []
 
145
  return []
146
  docs = (res.get("documents") or [[]])[0]
147
  metas = (res.get("metadatas") or [[]])[0]
148
+ hits = []
149
  for d, m in zip(docs, metas):
150
  if not d:
151
  continue
152
  src = (m or {}).get("source") or (m or {}).get("path") or "unknown"
153
+ page = (m or {}).get("page")
154
  cite = f"{src}" + (f":p.{page}" if page else "")
155
  hits.append((d, cite))
156
  return hits
157
+
158
+ # ---- tiny debugger (optional) ----
159
+ def rag_debug_info():
160
+ import chromadb
161
+ local = _locate_local_index()
162
+ client = chromadb.PersistentClient(path=str(local))
163
+ info = {"index_path": str(local), "collections": [], "emb": {
164
+ "provider": EMB_PROVIDER, "model": EMB_MODEL
165
+ }}
166
+ try:
167
+ for c in client.list_collections():
168
+ try:
169
+ cnt = c.count()
170
+ except Exception:
171
+ cnt = -1
172
+ info["collections"].append({"name": c.name, "count": cnt})
173
+ except Exception as e:
174
+ info["collections"].append({"error": str(e)})
175
+ return info