# rag_mini.py
from __future__ import annotations
import os, math
from pathlib import Path
from typing import List, Tuple


def _first_writable(paths):
    """Return the first path that can be created and written to."""
    for p in paths:
        if not p:
            continue
        try:
            base = Path(p)
            base.mkdir(parents=True, exist_ok=True)
            test = base / ".writetest"
            test.write_text("ok")
            test.unlink(missing_ok=True)
            return base.resolve()
        except Exception:
            continue
    # last resort
    return Path("/tmp").resolve()


# Prefer env(DATA_ROOT), then /data (Spaces persistent), else /tmp
DATA_ROOT = _first_writable([os.getenv("DATA_ROOT"), "/data", "/tmp"])
ROOT_DIR = Path(__file__).parent.resolve()
MM_ROOT = DATA_ROOT / "MaterialMind"
DEFAULT_TOPK = 5

# ---- where the index lives ----
INDEX_DS = os.getenv("INDEX_DS", "").strip()
INDEX_DIR_ENV = os.getenv("INDEX_DIR", "").strip()
INDEX_COLLECTION = os.getenv("INDEX_COLLECTION", "").strip()  # e.g., "materialmind"

# ---- embedding settings (match local!) ----
# Use BGE-small (384-d) everywhere to avoid a dimension mismatch unless you *know*
# you indexed with OpenAI embeddings.
EMB_PROVIDER = os.getenv("EMB_PROVIDER", "hf").strip().lower()  # "hf" or "openai"
EMB_MODEL = os.getenv("EMB_MODEL", "BAAI/bge-small-en-v1.5").strip()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # only used if EMB_PROVIDER=openai

# backends (populated lazily by _init_embedder)
_EMB_FAST = None
_EMB_ST = None
_EMB_OAI = None


def _l2norm(vec: List[float]) -> List[float]:
    s = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / s for x in vec]


def _init_embedder():
    """Initialize one embedding backend (idempotent: reuse it once it is set up)."""
    global _EMB_FAST, _EMB_ST, _EMB_OAI
    if _EMB_OAI is not None or _EMB_FAST is not None or _EMB_ST is not None:
        return  # already initialized; don't reload the model on every call
    if EMB_PROVIDER in ("openai", "oai"):
        try:
            from openai import OpenAI
            _EMB_OAI = OpenAI(api_key=OPENAI_API_KEY)
            print(f"[EMB] OpenAI embeddings ready: {EMB_MODEL}", flush=True)
            return
        except Exception as e:
            print("[EMB] OpenAI embeddings unavailable:", e, flush=True)
    # HF path (FastEmbed → SentenceTransformers fallback)
    try:
        from fastembed import TextEmbedding
        _EMB_FAST = TextEmbedding(model_name=EMB_MODEL)
        print(f"[EMB] FastEmbed ready: {EMB_MODEL}", flush=True)
        return
    except Exception as e1:
        print("[EMB] FastEmbed unavailable:", e1, flush=True)
    try:
        from sentence_transformers import SentenceTransformer
        _EMB_ST = SentenceTransformer(EMB_MODEL)
        print(f"[EMB] SentenceTransformers ready: {EMB_MODEL}", flush=True)
        return
    except Exception as e2:
        print("[EMB] SentenceTransformers unavailable:", e2, flush=True)
    print("[EMB] ERROR: No embedding backend available. "
          "Install 'fastembed' or 'sentence-transformers'.", flush=True)


def _embed(texts: List[str]) -> List[List[float]]:
    _init_embedder()
    if _EMB_OAI is not None:
        r = _EMB_OAI.embeddings.create(model=EMB_MODEL, input=texts)
        return [_l2norm(d.embedding) for d in r.data]
    if _EMB_FAST is not None:
        return [_l2norm(v) for v in _EMB_FAST.embed(texts)]
    if _EMB_ST is not None:
        arr = _EMB_ST.encode(texts, normalize_embeddings=True)
        return [_l2norm(v.tolist()) for v in arr]
    # last resort: zeros (prevents crashes; yields 0 hits)
    return [[0.0] * 384 for _ in texts]
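
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: the settings above
# stress that the query-time embedder must match whatever model built the
# index (BGE-small is 384-d; OpenAI embedding models are typically wider).
# A hypothetical probe like this can surface a provider/model mismatch early,
# before queries silently return poor or empty results.
# ---------------------------------------------------------------------------
def _probe_embedding_dim() -> int:
    """Return the vector width the configured backend produces (illustrative only)."""
    return len(_embed(["dimension probe"])[0])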

# ---- index discovery ----
def _has_catalog(dirpath: Path) -> bool:
    for f in ["chroma.sqlite3", "chroma.sqlite", "chroma-collections.parquet",
              "index_metadata.pickle", "data_level0.bin"]:
        if (dirpath / f).exists():
            return True
    return False


def _locate_local_index() -> Path:
    if INDEX_DIR_ENV:
        return (ROOT_DIR / INDEX_DIR_ENV).resolve()
    base = (MM_ROOT / "index" / "chroma_v3").resolve()
    if _has_catalog(base):
        return base
    hits = list(base.rglob("chroma.sqlite3"))
    if hits:
        return hits[0].parent
    return base


def ensure_ready():
    local = _locate_local_index()
    local.mkdir(parents=True, exist_ok=True)
    if INDEX_DS:
        try:
            from huggingface_hub import snapshot_download
            print("[RAG] downloading index dataset:", INDEX_DS, flush=True)
            snapshot_download(repo_id=INDEX_DS, repo_type="dataset",
                              local_dir=str(MM_ROOT),
                              local_dir_use_symlinks=False)
        except Exception as e:
            print("[RAG] dataset download failed:", e, flush=True)
    local = _locate_local_index()
    if not _has_catalog(local):
        print(f"[RAG] WARNING: No Chroma catalog found in: {local}", flush=True)
        print("      Set INDEX_DIR to the nested folder containing chroma.sqlite3", flush=True)
    else:
        print(f"[RAG] Index OK at {local}", flush=True)


# ---- Chroma access ----
def _get_collection():
    import chromadb
    local = _locate_local_index()
    client = chromadb.PersistentClient(path=str(local))
    if INDEX_COLLECTION:
        try:
            return client.get_collection(INDEX_COLLECTION)
        except Exception:
            return client.get_or_create_collection(
                name=INDEX_COLLECTION, metadata={"hnsw:space": "cosine"}
            )
    try:
        cols = client.list_collections()
        if cols:
            return client.get_collection(cols[0].name)
    except Exception:
        pass
    return client.get_or_create_collection(
        name="materialmind", metadata={"hnsw:space": "cosine"}
    )


def search(query: str, k: int = DEFAULT_TOPK) -> List[Tuple[str, str]]:
    local = _locate_local_index()
    if not _has_catalog(local):
        return []
    try:
        col = _get_collection()
        qvec = _embed([query])[0]
        res = col.query(query_embeddings=[qvec], n_results=int(k),
                        include=["documents", "metadatas"])
    except Exception as e:
        print("[RAG] query failed:", e, flush=True)
        return []
    docs = (res.get("documents") or [[]])[0]
    metas = (res.get("metadatas") or [[]])[0]
    hits = []
    for d, m in zip(docs, metas):
        if not d:
            continue
        src = (m or {}).get("source") or (m or {}).get("path") or "unknown"
        page = (m or {}).get("page")
        cite = f"{src}" + (f":p.{page}" if page else "")
        hits.append((d, cite))
    return hits


# ---- tiny debugger (for /debug/rag) ----
def rag_debug_info():
    import chromadb
    local = _locate_local_index()
    client = chromadb.PersistentClient(path=str(local))
    info = {"index_path": str(local), "collections": [],
            "emb": {"provider": EMB_PROVIDER, "model": EMB_MODEL}}
    try:
        for c in client.list_collections():
            try:
                cnt = c.count()
            except Exception:
                cnt = -1
            info["collections"].append({"name": c.name, "count": cnt})
    except Exception as e:
        info["collections"].append({"error": str(e)})
    return info
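
# ---------------------------------------------------------------------------
# Minimal smoke test, an illustrative sketch rather than part of the original
# module: make sure the index is in place, print what Chroma sees, and run
# one retrieval. The query string below is a placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ensure_ready()
    print(rag_debug_info())
    for doc, cite in search("example query", k=3):
        print(f"[{cite}] {doc[:120]}")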