Spaces:
Sleeping
Sleeping
import os
import logging
from pathlib import Path

# Optional FAISS (keeps original behavior): the app must still start on hosts
# where FAISS is not installed, so the import is guarded and a flag records
# availability for the retriever functions below.
try:
    import faiss
    _HAS_FAISS = True
except ImportError:
    logging.warning("FAISS not installed — retrieval will be disabled. Install faiss-cpu or faiss-gpu for full functionality.")
    _HAS_FAISS = False

from sentence_transformers import SentenceTransformer

# ---- Writable cache + stable repo id for Spaces ----
# On Hugging Face Spaces the default HF cache location may not be writable;
# SENTENCE_TRANSFORMERS_HOME lets deployments point at a writable directory,
# falling back to the conventional per-user cache path.
_HOME = Path.home()
_ST_CACHE = os.getenv("SENTENCE_TRANSFORMERS_HOME", str(_HOME / ".cache" / "sentence-transformers"))
_ST_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"  # canonical repo id
def _load_st_model():
    """Load the sentence-embedding model, returning None on failure.

    Strategy:
      1. Ensure the cache directory exists (best-effort).
      2. Try a plain load of the canonical repo id.
      3. Retry once with trust_remote_code=True.
      4. Soft-fail: log and return None so the app keeps running with
         retrieval disabled instead of crashing at import time.

    Returns:
        SentenceTransformer instance, or None if both load attempts failed.
    """
    # Best-effort cache dir creation; a failure here is survivable because
    # the loader may still find a usable cache elsewhere.
    try:
        os.makedirs(_ST_CACHE, exist_ok=True)
    except OSError as e:
        # Narrowed from bare Exception: makedirs failures surface as OSError.
        logging.warning("Could not create cache directory %s: %s", _ST_CACHE, e)

    # Primary attempt: default load against the canonical repo id.
    try:
        return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE)
    except Exception as e1:
        # Lazy %-style args instead of f-strings: the message is only
        # formatted if the record is actually emitted.
        logging.warning(
            "Primary load failed for '%s' with cache '%s': %s",
            _ST_MODEL_ID, _ST_CACHE, e1,
        )

    # Secondary attempt (allow trust_remote_code just in case). If the first
    # attempt succeeded we already returned, so reaching here means retry.
    try:
        return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE, trust_remote_code=True)
    except Exception:
        logging.exception("Failed loading SentenceTransformer model on both attempts.")
        # Soft-fail: disable retrieval rather than crashing the whole app.
        logging.error(
            "Disabling retrieval due to model load failure. "
            "Check permissions for %s and HF_* env vars.",
            _ST_CACHE,
        )
        return None
# Load embedding model (works even if FAISS missing)
_model = _load_st_model()  # None if the model failed to load (retrieval disabled)
_index = None  # FAISS index; built lazily by init_retriever(), None until then
_docs = []  # documents backing the index; row i of the index embeds _docs[i]
def init_retriever(docs=None):
    """
    Initialize the FAISS index over *docs* if FAISS is available.

    docs: list[str] to index. A falsy value (None or []) stores/keeps an
          empty doc list without (re)building the index.

    Soft-degrades: when the embedding model or FAISS is unavailable, only
    the docs are stored so retrieve_context() can safely return [].
    """
    global _index, _docs
    # No embedding model: remember the docs but skip indexing entirely.
    if _model is None:
        _docs = docs or []
        return
    # FAISS missing: same soft-degrade path.
    if not _HAS_FAISS:
        _docs = docs or []
        return
    if docs:
        _docs = docs
        # encode(..., convert_to_numpy=True) already yields an ndarray; the
        # former local `import numpy as np` was unused and has been removed.
        embeddings = _model.encode(docs, convert_to_numpy=True, normalize_embeddings=False)
        d = embeddings.shape[1]
        # Exact (brute-force) L2 index over raw, unnormalized vectors.
        _index = faiss.IndexFlatL2(d)
        _index.add(embeddings)
def retrieve_context(query: str, k: int = 5):
    """
    Return up to *k* indexed documents most similar to *query*.

    Gracefully returns an empty list when the model failed to load, FAISS
    is unavailable, or no index has been built yet.
    """
    # Single guard covering every "retrieval disabled / not initialized" case.
    if _model is None or not _HAS_FAISS or _index is None or not _docs:
        return []

    query_vec = _model.encode([query], convert_to_numpy=True, normalize_embeddings=False)
    distances, indices = _index.search(query_vec, k)

    # FAISS pads with -1 when fewer than k rows exist; keep valid ids only.
    hits = []
    for doc_id in indices[0]:
        if 0 <= doc_id < len(_docs):
            hits.append(_docs[doc_id])
    return hits