Azizahalq commited on
Commit
b384849
·
verified ·
1 Parent(s): f78275d

Update rag_mini.py

Browse files
Files changed (1) hide show
  1. rag_mini.py +18 -11
rag_mini.py CHANGED
@@ -18,15 +18,17 @@ EMB_MODEL = "BAAI/bge-small-en-v1.5"
18
 
19
  def _init_embedder():
20
  global _EMB_FAST, _EMB_ST
21
- if _EMB_FAST or _EMB_ST: return
 
22
  try:
23
  from fastembed import TextEmbedding
24
  _EMB_FAST = TextEmbedding(model_name=EMB_MODEL)
25
  print("[EMB] FastEmbed ready:", EMB_MODEL, flush=True)
26
  except Exception as e:
27
- print("[EMB] FastEmbed unavailable -> ST:", e, flush=True)
28
  from sentence_transformers import SentenceTransformer
29
  _EMB_ST = SentenceTransformer(EMB_MODEL)
 
30
 
31
  def _embed(texts:List[str])->List[List[float]]:
32
  _init_embedder()
@@ -35,7 +37,8 @@ def _embed(texts:List[str])->List[List[float]]:
35
  return _EMB_ST.encode(texts, normalize_embeddings=True).tolist()
36
 
37
  def _has_catalog(dirpath:Path)->bool:
38
- for f in ["chroma.sqlite3","chroma.sqlite","chroma-collections.parquet","index_metadata.pickle","data_level0.bin"]:
 
39
  if (dirpath/f).exists():
40
  return True
41
  return False
@@ -43,14 +46,16 @@ def _has_catalog(dirpath:Path)->bool:
43
  def _locate_local_index()->Path:
44
  # If user specified a precise directory, use it
45
  if INDEX_DIR_ENV:
46
- return ROOT_DIR / INDEX_DIR_ENV
47
- # default path where we’ll snapshot_download
48
- base = MM_ROOT / "index" / "chroma_v3"
49
  # try direct
50
- if _has_catalog(base): return base
 
51
  # try nested uuid
52
  hits = list(base.rglob("chroma.sqlite3"))
53
- if hits: return hits[0].parent
 
54
  return base
55
 
56
  def ensure_ready():
@@ -67,7 +72,7 @@ def ensure_ready():
67
  local_dir=str(MM_ROOT), local_dir_use_symlinks=False)
68
  except Exception as e:
69
  print("[RAG] dataset download failed:", e, flush=True)
70
- # relocalize (after download)
71
  local = _locate_local_index()
72
  if not _has_catalog(local):
73
  print(f"[RAG] WARNING: No Chroma catalog found in: {local}", flush=True)
@@ -95,7 +100,8 @@ def search(query:str, k:int=DEFAULT_TOPK)->List[Tuple[str,str]]:
95
  try:
96
  col = _get_collection()
97
  qvec = _embed([query])[0]
98
- res = col.query(query_embeddings=[qvec], n_results=int(k), include=["documents","metadatas"])
 
99
  except Exception as e:
100
  print("[RAG] query failed:", e, flush=True)
101
  return []
@@ -103,7 +109,8 @@ def search(query:str, k:int=DEFAULT_TOPK)->List[Tuple[str,str]]:
103
  metas = (res.get("metadatas") or [[]])[0]
104
  hits=[]
105
  for d, m in zip(docs, metas):
106
- if not d: continue
 
107
  src = (m or {}).get("source") or (m or {}).get("path") or "unknown"
108
  page= (m or {}).get("page")
109
  cite = f"{src}" + (f":p.{page}" if page else "")
 
18
 
19
  def _init_embedder():
20
  global _EMB_FAST, _EMB_ST
21
+ if _EMB_FAST or _EMB_ST:
22
+ return
23
  try:
24
  from fastembed import TextEmbedding
25
  _EMB_FAST = TextEmbedding(model_name=EMB_MODEL)
26
  print("[EMB] FastEmbed ready:", EMB_MODEL, flush=True)
27
  except Exception as e:
28
+ print("[EMB] FastEmbed unavailable -> SentenceTransformers:", e, flush=True)
29
  from sentence_transformers import SentenceTransformer
30
  _EMB_ST = SentenceTransformer(EMB_MODEL)
31
+ print("[EMB] ST ready:", EMB_MODEL, flush=True)
32
 
33
  def _embed(texts:List[str])->List[List[float]]:
34
  _init_embedder()
 
37
  return _EMB_ST.encode(texts, normalize_embeddings=True).tolist()
38
 
39
  def _has_catalog(dirpath:Path)->bool:
40
+ for f in ["chroma.sqlite3","chroma.sqlite","chroma-collections.parquet",
41
+ "index_metadata.pickle","data_level0.bin"]:
42
  if (dirpath/f).exists():
43
  return True
44
  return False
 
46
  def _locate_local_index()->Path:
47
  # If user specified a precise directory, use it
48
  if INDEX_DIR_ENV:
49
+ return (ROOT_DIR / INDEX_DIR_ENV).resolve()
50
+ # default base path where we’ll snapshot_download
51
+ base = (MM_ROOT / "index" / "chroma_v3").resolve()
52
  # try direct
53
+ if _has_catalog(base):
54
+ return base
55
  # try nested uuid
56
  hits = list(base.rglob("chroma.sqlite3"))
57
+ if hits:
58
+ return hits[0].parent
59
  return base
60
 
61
  def ensure_ready():
 
72
  local_dir=str(MM_ROOT), local_dir_use_symlinks=False)
73
  except Exception as e:
74
  print("[RAG] dataset download failed:", e, flush=True)
75
+ # re-locate (after download)
76
  local = _locate_local_index()
77
  if not _has_catalog(local):
78
  print(f"[RAG] WARNING: No Chroma catalog found in: {local}", flush=True)
 
100
  try:
101
  col = _get_collection()
102
  qvec = _embed([query])[0]
103
+ res = col.query(query_embeddings=[qvec], n_results=int(k),
104
+ include=["documents","metadatas"])
105
  except Exception as e:
106
  print("[RAG] query failed:", e, flush=True)
107
  return []
 
109
  metas = (res.get("metadatas") or [[]])[0]
110
  hits=[]
111
  for d, m in zip(docs, metas):
112
+ if not d:
113
+ continue
114
  src = (m or {}).get("source") or (m or {}).get("path") or "unknown"
115
  page= (m or {}).get("page")
116
  cite = f"{src}" + (f":p.{page}" if page else "")