Spaces:
Sleeping
Sleeping
Update rag_mini.py
Browse files
- rag_mini.py +18 -11
rag_mini.py
CHANGED
|
@@ -18,15 +18,17 @@ EMB_MODEL = "BAAI/bge-small-en-v1.5"
|
|
| 18 |
|
| 19 |
def _init_embedder():
|
| 20 |
global _EMB_FAST, _EMB_ST
|
| 21 |
-
if _EMB_FAST or _EMB_ST:
|
|
|
|
| 22 |
try:
|
| 23 |
from fastembed import TextEmbedding
|
| 24 |
_EMB_FAST = TextEmbedding(model_name=EMB_MODEL)
|
| 25 |
print("[EMB] FastEmbed ready:", EMB_MODEL, flush=True)
|
| 26 |
except Exception as e:
|
| 27 |
-
print("[EMB] FastEmbed unavailable ->
|
| 28 |
from sentence_transformers import SentenceTransformer
|
| 29 |
_EMB_ST = SentenceTransformer(EMB_MODEL)
|
|
|
|
| 30 |
|
| 31 |
def _embed(texts:List[str])->List[List[float]]:
|
| 32 |
_init_embedder()
|
|
@@ -35,7 +37,8 @@ def _embed(texts:List[str])->List[List[float]]:
|
|
| 35 |
return _EMB_ST.encode(texts, normalize_embeddings=True).tolist()
|
| 36 |
|
| 37 |
def _has_catalog(dirpath:Path)->bool:
|
| 38 |
-
for f in ["chroma.sqlite3","chroma.sqlite","chroma-collections.parquet",
|
|
|
|
| 39 |
if (dirpath/f).exists():
|
| 40 |
return True
|
| 41 |
return False
|
|
@@ -43,14 +46,16 @@ def _has_catalog(dirpath:Path)->bool:
|
|
| 43 |
def _locate_local_index()->Path:
|
| 44 |
# If user specified a precise directory, use it
|
| 45 |
if INDEX_DIR_ENV:
|
| 46 |
-
return ROOT_DIR / INDEX_DIR_ENV
|
| 47 |
-
# default path where we’ll snapshot_download
|
| 48 |
-
base = MM_ROOT / "index" / "chroma_v3"
|
| 49 |
# try direct
|
| 50 |
-
if _has_catalog(base):
|
|
|
|
| 51 |
# try nested uuid
|
| 52 |
hits = list(base.rglob("chroma.sqlite3"))
|
| 53 |
-
if hits:
|
|
|
|
| 54 |
return base
|
| 55 |
|
| 56 |
def ensure_ready():
|
|
@@ -67,7 +72,7 @@ def ensure_ready():
|
|
| 67 |
local_dir=str(MM_ROOT), local_dir_use_symlinks=False)
|
| 68 |
except Exception as e:
|
| 69 |
print("[RAG] dataset download failed:", e, flush=True)
|
| 70 |
-
#
|
| 71 |
local = _locate_local_index()
|
| 72 |
if not _has_catalog(local):
|
| 73 |
print(f"[RAG] WARNING: No Chroma catalog found in: {local}", flush=True)
|
|
@@ -95,7 +100,8 @@ def search(query:str, k:int=DEFAULT_TOPK)->List[Tuple[str,str]]:
|
|
| 95 |
try:
|
| 96 |
col = _get_collection()
|
| 97 |
qvec = _embed([query])[0]
|
| 98 |
-
res = col.query(query_embeddings=[qvec], n_results=int(k),
|
|
|
|
| 99 |
except Exception as e:
|
| 100 |
print("[RAG] query failed:", e, flush=True)
|
| 101 |
return []
|
|
@@ -103,7 +109,8 @@ def search(query:str, k:int=DEFAULT_TOPK)->List[Tuple[str,str]]:
|
|
| 103 |
metas = (res.get("metadatas") or [[]])[0]
|
| 104 |
hits=[]
|
| 105 |
for d, m in zip(docs, metas):
|
| 106 |
-
if not d:
|
|
|
|
| 107 |
src = (m or {}).get("source") or (m or {}).get("path") or "unknown"
|
| 108 |
page= (m or {}).get("page")
|
| 109 |
cite = f"{src}" + (f":p.{page}" if page else "")
|
|
|
|
| 18 |
|
| 19 |
def _init_embedder():
|
| 20 |
global _EMB_FAST, _EMB_ST
|
| 21 |
+
if _EMB_FAST or _EMB_ST:
|
| 22 |
+
return
|
| 23 |
try:
|
| 24 |
from fastembed import TextEmbedding
|
| 25 |
_EMB_FAST = TextEmbedding(model_name=EMB_MODEL)
|
| 26 |
print("[EMB] FastEmbed ready:", EMB_MODEL, flush=True)
|
| 27 |
except Exception as e:
|
| 28 |
+
print("[EMB] FastEmbed unavailable -> SentenceTransformers:", e, flush=True)
|
| 29 |
from sentence_transformers import SentenceTransformer
|
| 30 |
_EMB_ST = SentenceTransformer(EMB_MODEL)
|
| 31 |
+
print("[EMB] ST ready:", EMB_MODEL, flush=True)
|
| 32 |
|
| 33 |
def _embed(texts:List[str])->List[List[float]]:
|
| 34 |
_init_embedder()
|
|
|
|
| 37 |
return _EMB_ST.encode(texts, normalize_embeddings=True).tolist()
|
| 38 |
|
| 39 |
def _has_catalog(dirpath:Path)->bool:
|
| 40 |
+
for f in ["chroma.sqlite3","chroma.sqlite","chroma-collections.parquet",
|
| 41 |
+
"index_metadata.pickle","data_level0.bin"]:
|
| 42 |
if (dirpath/f).exists():
|
| 43 |
return True
|
| 44 |
return False
|
|
|
|
| 46 |
def _locate_local_index()->Path:
|
| 47 |
# If user specified a precise directory, use it
|
| 48 |
if INDEX_DIR_ENV:
|
| 49 |
+
return (ROOT_DIR / INDEX_DIR_ENV).resolve()
|
| 50 |
+
# default base path where we’ll snapshot_download
|
| 51 |
+
base = (MM_ROOT / "index" / "chroma_v3").resolve()
|
| 52 |
# try direct
|
| 53 |
+
if _has_catalog(base):
|
| 54 |
+
return base
|
| 55 |
# try nested uuid
|
| 56 |
hits = list(base.rglob("chroma.sqlite3"))
|
| 57 |
+
if hits:
|
| 58 |
+
return hits[0].parent
|
| 59 |
return base
|
| 60 |
|
| 61 |
def ensure_ready():
|
|
|
|
| 72 |
local_dir=str(MM_ROOT), local_dir_use_symlinks=False)
|
| 73 |
except Exception as e:
|
| 74 |
print("[RAG] dataset download failed:", e, flush=True)
|
| 75 |
+
# re-locate (after download)
|
| 76 |
local = _locate_local_index()
|
| 77 |
if not _has_catalog(local):
|
| 78 |
print(f"[RAG] WARNING: No Chroma catalog found in: {local}", flush=True)
|
|
|
|
| 100 |
try:
|
| 101 |
col = _get_collection()
|
| 102 |
qvec = _embed([query])[0]
|
| 103 |
+
res = col.query(query_embeddings=[qvec], n_results=int(k),
|
| 104 |
+
include=["documents","metadatas"])
|
| 105 |
except Exception as e:
|
| 106 |
print("[RAG] query failed:", e, flush=True)
|
| 107 |
return []
|
|
|
|
| 109 |
metas = (res.get("metadatas") or [[]])[0]
|
| 110 |
hits=[]
|
| 111 |
for d, m in zip(docs, metas):
|
| 112 |
+
if not d:
|
| 113 |
+
continue
|
| 114 |
src = (m or {}).get("source") or (m or {}).get("path") or "unknown"
|
| 115 |
page= (m or {}).get("page")
|
| 116 |
cite = f"{src}" + (f":p.{page}" if page else "")
|