import os
import logging

# Optional FAISS (keeps the original behavior: the module still loads and
# embeds without it, but retrieval is disabled).
try:
    import faiss
    _HAS_FAISS = True
except ImportError:
    logging.warning(
        "FAISS not installed; retrieval will be disabled. "
        "Install faiss-cpu or faiss-gpu for full functionality."
    )
    _HAS_FAISS = False

from sentence_transformers import SentenceTransformer

# ---- Writable cache + stable repo id for Spaces ----
_ST_CACHE = os.getenv("SENTENCE_TRANSFORMERS_HOME", "/data/.cache/sentence-transformers")
_ST_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"  # canonical repo id


def _load_st_model():
    """
    Load the SentenceTransformer using an explicit cache folder to avoid
    Hugging Face 'xet' transport / permission issues on Spaces.
    """
    # Ensure the cache dir exists before the hub tries to write into it.
    try:
        os.makedirs(_ST_CACHE, exist_ok=True)
    except Exception as e:
        logging.warning(f"Could not create cache directory {_ST_CACHE}: {e}")

    # Primary attempt: plain load with the explicit cache folder.
    try:
        return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE)
    except Exception as e1:
        logging.warning(f"Primary load failed for '{_ST_MODEL_ID}' with cache '{_ST_CACHE}': {e1}")
        # Secondary attempt (allow trust_remote_code just in case).
        try:
            return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE, trust_remote_code=True)
        except Exception as e2:
            logging.exception("Failed loading SentenceTransformer model on both attempts.")
            raise RuntimeError(
                f"Failed loading SentenceTransformer '{_ST_MODEL_ID}'.\n"
                f"First error: {e1}\nSecond error: {e2}\n"
                f"Check cache dir permissions at: {_ST_CACHE}\n"
                f"Tip: ensure app.py sets HF_HUB_ENABLE_XET=0 and uses writable caches under /data."
            ) from e2


# Load the embedding model eagerly at import time (works even if FAISS is missing).
_model = _load_st_model()

_index = None
_docs = []


def init_retriever(docs=None):
    """
    Build the FAISS index over `docs` (a list of strings).

    If FAISS is unavailable, the docs are stored but no index is built,
    so retrieve_context() will fall back to returning an empty list.
    """
    global _index, _docs
    if not _HAS_FAISS:
        _docs = docs or []
        return
    if not docs:
        return  # nothing to index; avoids encoding an empty or None list
    _docs = docs
    embeddings = _model.encode(_docs, convert_to_numpy=True, normalize_embeddings=False)
    d = embeddings.shape[1]
    _index = faiss.IndexFlatL2(d)  # exact L2 search over the raw (unnormalized) embeddings
    _index.add(embeddings)


def retrieve_context(query: str, k: int = 5):
    """
    Retrieve the top-k docs matching `query`.

    Falls back to an empty list if FAISS is unavailable or the index has
    not been initialized via init_retriever().
    """
    if not _HAS_FAISS or _index is None or not _docs:
        return []
    q_emb = _model.encode([query], convert_to_numpy=True, normalize_embeddings=False)
    D, I = _index.search(q_emb, k)
    # FAISS pads results with -1 when k exceeds the number of indexed docs.
    return [_docs[i] for i in I[0] if 0 <= i < len(_docs)]
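

# ---- Usage sketch (illustrative only, not part of the module's API) ----
# A minimal sketch of how the two entry points fit together, assuming FAISS
# is installed and the model download succeeds; the corpus below is
# hypothetical example data.
if __name__ == "__main__":
    corpus = [
        "FAISS provides exact and approximate nearest-neighbor search.",
        "SentenceTransformer models map sentences to dense vectors.",
        "Hugging Face Spaces persist writable data under /data.",
    ]
    init_retriever(corpus)  # embeds the corpus and builds the IndexFlatL2
    for doc in retrieve_context("how are sentences embedded?", k=2):
        print(doc)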