Spaces:
Sleeping
Sleeping
File size: 6,956 Bytes
c61a3d0 1a7ac30 ac08d2a a09fb65 b91214e 71a256d c61a3d0 91630d1 e2d584e 8bc7a37 ac08d2a 9221408 ac08d2a a09fb65 ac08d2a 9221408 ac08d2a 1a7ac30 9221408 ac08d2a 1a7ac30 9221408 ac08d2a b04d90e 1a7ac30 91630d1 ac08d2a b04d90e 1a7ac30 ac08d2a 1a7ac30 ac08d2a 91630d1 ac08d2a b04d90e ac08d2a 9221408 ac08d2a 91630d1 ac08d2a b384849 91630d1 ac08d2a 91630d1 b384849 91630d1 b384849 91630d1 ac08d2a 91630d1 a09fb65 91630d1 ac08d2a 1a7ac30 91630d1 ac08d2a 1a7ac30 ac08d2a 91630d1 aabf335 91630d1 1a7ac30 b384849 aabf335 91630d1 aabf335 1a7ac30 b91214e ac08d2a b91214e b384849 91630d1 ac08d2a a09fb65 1a7ac30 a09fb65 ac08d2a 9221408 ac08d2a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
# rag_mini.py
from __future__ import annotations
import os, math
from pathlib import Path
from typing import List, Tuple
def _first_writable(paths):
for p in paths:
if not p:
continue
try:
base = Path(p)
base.mkdir(parents=True, exist_ok=True)
test = base / ".writetest"
test.write_text("ok")
test.unlink(missing_ok=True)
return base.resolve()
except Exception:
continue
# last resort
return Path("/tmp").resolve()
# Prefer env(DATA_ROOT), then /data (Spaces persistent), else /tmp
DATA_ROOT = _first_writable([os.getenv("DATA_ROOT"), "/data", "/tmp"])
ROOT_DIR = Path(__file__).parent.resolve()  # directory containing this module
MM_ROOT = DATA_ROOT / "MaterialMind"        # app-specific data root under DATA_ROOT
DEFAULT_TOPK = 5                            # default number of hits returned by search()
# ---- where the index lives ----
INDEX_DS = os.getenv("INDEX_DS", "").strip()        # optional HF dataset repo id holding the index
INDEX_DIR_ENV = os.getenv("INDEX_DIR", "").strip()  # optional explicit index dir, relative to ROOT_DIR
INDEX_COLLECTION = os.getenv("INDEX_COLLECTION", "").strip()  # e.g., "materialmind"
# ---- embedding settings (match local!) ----
# Use BGE-small (384-d) everywhere to avoid mismatch unless you *know* you indexed with OpenAI.
EMB_PROVIDER = os.getenv("EMB_PROVIDER", "hf").strip().lower()  # "hf" or "openai"
EMB_MODEL = os.getenv("EMB_MODEL", "BAAI/bge-small-en-v1.5").strip()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # only used if EMB_PROVIDER=openai
# backends — populated lazily by _init_embedder(); at most one becomes non-None
_EMB_FAST = None  # fastembed.TextEmbedding instance
_EMB_ST = None    # sentence_transformers.SentenceTransformer instance
_EMB_OAI = None   # openai.OpenAI client
def _l2norm(vec: List[float]) -> List[float]:
s = math.sqrt(sum(x*x for x in vec)) or 1.0
return [x/s for x in vec]
def _init_embedder():
    """Initialize one embedding backend (idempotent).

    Tries, in order: OpenAI (only when EMB_PROVIDER selects it), FastEmbed,
    then SentenceTransformers.  The winning backend is stored in a module
    global; when every attempt fails, all three stay None and an error is
    printed (callers then fall back to zero vectors in _embed).
    """
    global _EMB_FAST, _EMB_ST, _EMB_OAI
    # Bug fix: _embed() invokes this on every request; without this guard the
    # embedding model was re-instantiated (and re-loaded) for every query.
    if _EMB_OAI is not None or _EMB_FAST is not None or _EMB_ST is not None:
        return
    if EMB_PROVIDER in ("openai", "oai"):
        try:
            from openai import OpenAI
            _EMB_OAI = OpenAI(api_key=OPENAI_API_KEY)
            print(f"[EMB] OpenAI embeddings ready: {EMB_MODEL}", flush=True)
            return
        except Exception as e:
            # Fall through to the HF backends rather than failing hard.
            print("[EMB] OpenAI embeddings unavailable:", e, flush=True)
    # HF path (FastEmbed → SentenceTransformers fallback)
    try:
        from fastembed import TextEmbedding
        _EMB_FAST = TextEmbedding(model_name=EMB_MODEL)
        print(f"[EMB] FastEmbed ready: {EMB_MODEL}", flush=True)
        return
    except Exception as e1:
        print("[EMB] FastEmbed unavailable:", e1, flush=True)
    try:
        from sentence_transformers import SentenceTransformer
        _EMB_ST = SentenceTransformer(EMB_MODEL)
        print(f"[EMB] SentenceTransformers ready: {EMB_MODEL}", flush=True)
        return
    except Exception as e2:
        print("[EMB] SentenceTransformers unavailable:", e2, flush=True)
    print("[EMB] ERROR: No embedding backend available. Install 'fastembed' or 'sentence-transformers'.", flush=True)
def _embed(texts: List[str]) -> List[List[float]]:
    """Return one L2-normalized embedding vector per input text.

    When no backend could be initialized, returns zero vectors (384-d,
    matching the bge-small default) so callers degrade to "no hits"
    instead of crashing.
    """
    # Bug fix: only initialize when no backend is ready yet — the original
    # called _init_embedder() unconditionally on every request, re-loading
    # the model each time.
    if _EMB_OAI is None and _EMB_FAST is None and _EMB_ST is None:
        _init_embedder()
    if _EMB_OAI is not None:
        r = _EMB_OAI.embeddings.create(model=EMB_MODEL, input=texts)
        return [_l2norm(d.embedding) for d in r.data]
    if _EMB_FAST is not None:
        return [_l2norm(v) for v in _EMB_FAST.embed(texts)]
    if _EMB_ST is not None:
        arr = _EMB_ST.encode(texts, normalize_embeddings=True)
        return [_l2norm(v.tolist()) for v in arr]
    # last resort: zeros (prevents crashes; yields 0 hits)
    return [[0.0] * 384 for _ in texts]
# ---- index discovery ----
def _has_catalog(dirpath: Path) -> bool:
for f in ["chroma.sqlite3","chroma.sqlite","chroma-collections.parquet",
"index_metadata.pickle","data_level0.bin"]:
if (dirpath/f).exists():
return True
return False
def _locate_local_index() -> Path:
    """Resolve the directory expected to hold the Chroma index.

    Precedence: the INDEX_DIR override (relative to this file), then the
    default chroma_v3 location, then the first nested folder containing a
    chroma.sqlite3, and finally the default path itself (even if empty).
    """
    if INDEX_DIR_ENV:
        return (ROOT_DIR / INDEX_DIR_ENV).resolve()
    default = (MM_ROOT / "index" / "chroma_v3").resolve()
    if _has_catalog(default):
        return default
    # Dataset snapshots sometimes nest the index one level deeper.
    nested = next(default.rglob("chroma.sqlite3"), None)
    if nested is not None:
        return nested.parent
    return default
def ensure_ready():
    """Ensure the index directory exists, pulling the HF dataset when configured.

    Best effort: a failed download only prints a warning and leaves whatever
    is already on disk in place.
    """
    target = _locate_local_index()
    target.mkdir(parents=True, exist_ok=True)
    if INDEX_DS:
        try:
            from huggingface_hub import snapshot_download
            print("[RAG] downloading index dataset:", INDEX_DS, flush=True)
            snapshot_download(
                repo_id=INDEX_DS,
                repo_type="dataset",
                local_dir=str(MM_ROOT),
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print("[RAG] dataset download failed:", e, flush=True)
    # Re-resolve: the download may have materialized a nested index folder.
    target = _locate_local_index()
    if _has_catalog(target):
        print(f"[RAG] Index OK at {target}", flush=True)
    else:
        print(f"[RAG] WARNING: No Chroma catalog found in: {target}", flush=True)
        print("    Set INDEX_DIR to the nested folder containing chroma.sqlite3", flush=True)
# ---- Chroma access ----
def _get_collection():
    """Open the Chroma collection backing the index.

    An explicitly configured INDEX_COLLECTION wins; otherwise the first
    collection found on disk is used, and as a last resort a fresh
    cosine-space "materialmind" collection is created.
    """
    import chromadb
    client = chromadb.PersistentClient(path=str(_locate_local_index()))
    if INDEX_COLLECTION:
        try:
            return client.get_collection(INDEX_COLLECTION)
        except Exception:
            # Named collection missing — create it so callers get something.
            return client.get_or_create_collection(
                name=INDEX_COLLECTION, metadata={"hnsw:space": "cosine"}
            )
    try:
        existing = client.list_collections()
        if existing:
            return client.get_collection(existing[0].name)
    except Exception:
        pass
    return client.get_or_create_collection(
        name="materialmind", metadata={"hnsw:space": "cosine"}
    )
def search(query: str, k: int = DEFAULT_TOPK) -> List[Tuple[str, str]]:
    """Return up to *k* (document, citation) pairs for *query*.

    Never raises: returns [] when no index catalog is present or when the
    embedding/query step fails (the error is printed).
    """
    local = _locate_local_index()
    if not _has_catalog(local):
        return []
    try:
        col = _get_collection()
        qvec = _embed([query])[0]
        res = col.query(
            query_embeddings=[qvec],
            n_results=int(k),
            include=["documents", "metadatas"],
        )
    except Exception as e:
        print("[RAG] query failed:", e, flush=True)
        return []
    # Chroma returns per-query lists; we issued exactly one query.
    docs = (res.get("documents") or [[]])[0]
    metas = (res.get("metadatas") or [[]])[0]
    hits: List[Tuple[str, str]] = []
    for doc, meta in zip(docs, metas):
        if not doc:
            continue
        meta = meta or {}
        src = meta.get("source") or meta.get("path") or "unknown"
        page = meta.get("page")
        # Bug fix: use `is not None` so page 0 (zero-indexed PDFs) still
        # produces a ":p.0" citation — the old truthiness check dropped it.
        cite = f"{src}:p.{page}" if page is not None else f"{src}"
        hits.append((doc, cite))
    return hits
# ---- tiny debugger (for /debug/rag) ----
def rag_debug_info():
    """Build a small status dict for the /debug/rag endpoint.

    Reports the resolved index path, each collection with its document
    count (-1 when counting fails), and the embedding configuration.
    """
    import chromadb
    index_path = _locate_local_index()
    client = chromadb.PersistentClient(path=str(index_path))
    info = {
        "index_path": str(index_path),
        "collections": [],
        "emb": {"provider": EMB_PROVIDER, "model": EMB_MODEL},
    }
    try:
        for col in client.list_collections():
            try:
                count = col.count()
            except Exception:
                count = -1  # count unavailable; still report the name
            info["collections"].append({"name": col.name, "count": count})
    except Exception as e:
        info["collections"].append({"error": str(e)})
    return info
|