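"""Embedding-based retriever: a SentenceTransformer encoder with an optional
FAISS index, using a writable model cache suited to Hugging Face Spaces."""
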
import os
import logging

# Optional FAISS import: retrieval degrades gracefully if it is missing
try:
    import faiss
    _HAS_FAISS = True
except ImportError:
    logging.warning("FAISS not installed — retrieval will be disabled. Install faiss-cpu or faiss-gpu for full functionality.")
    _HAS_FAISS = False

from sentence_transformers import SentenceTransformer

# ---- Writable cache + stable repo id for Spaces ----
_ST_CACHE = os.getenv("SENTENCE_TRANSFORMERS_HOME", "/data/.cache/sentence-transformers")
_ST_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"  # canonical repo id

def _load_st_model():
    """
    Load SentenceTransformer using an explicit cache folder to avoid
    Hugging Face 'xet' transport / permission issues on Spaces.
    """
    # Ensure cache dir exists
    try:
        os.makedirs(_ST_CACHE, exist_ok=True)
    except Exception as e:
        logging.warning(f"Could not create cache directory {_ST_CACHE}: {e}")

    # Primary attempt
    try:
        return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE)
    except Exception as e1:
        logging.warning(f"Primary load failed for '{_ST_MODEL_ID}' with cache '{_ST_CACHE}': {e1}")

        # Secondary attempt (allow trust_remote_code just in case)
        try:
            return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE, trust_remote_code=True)
        except Exception as e2:
            logging.exception("Failed loading SentenceTransformer model on both attempts.")
            raise RuntimeError(
                f"Failed loading SentenceTransformer '{_ST_MODEL_ID}'.\n"
                f"First error: {e1}\nSecond error: {e2}\n"
                f"Check cache dir permissions at: {_ST_CACHE}\n"
                f"Tip: ensure app.py sets HF_HUB_ENABLE_XET=0 and uses writable caches under /data."
            ) from e2

# Load the embedding model up front (works even if FAISS is missing)
_model = _load_st_model()

_index = None
_docs = []

def init_retriever(docs=None):
    """
    Build a FAISS index over `docs` (list[str]).
    If FAISS is unavailable, the docs are stored but retrieval stays disabled.
    """
    global _index, _docs
    if not _HAS_FAISS:
        _docs = docs or []
        return

    if docs:
        _docs = docs
        embeddings = _model.encode(docs, convert_to_numpy=True, normalize_embeddings=False)
        d = embeddings.shape[1]
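        # IndexFlatL2 performs exact search by Euclidean distance. A common
        # alternative (not used here) is cosine similarity: encode with
        # normalize_embeddings=True and use faiss.IndexFlatIP(d) instead.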
        _index = faiss.IndexFlatL2(d)
        _index.add(embeddings)

def retrieve_context(query: str, k: int = 5):
    """
    Retrieve the top-k documents matching `query`.
    Returns an empty list if FAISS is unavailable or the index is uninitialized.
    """
    if not _HAS_FAISS or _index is None or not _docs:
        return []

    q_emb = _model.encode([query], convert_to_numpy=True, normalize_embeddings=False)
    D, I = _index.search(q_emb, k)
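    # FAISS pads short result sets with -1, so keep only valid indices.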
    return [_docs[i] for i in I[0] if 0 <= i < len(_docs)]
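
# Minimal usage sketch (hypothetical sample documents; retrieval requires
# FAISS to be installed):
if __name__ == "__main__":
    sample_docs = [
        "FAISS builds flat and approximate indexes for vector search.",
        "SentenceTransformer models encode text into dense embeddings.",
        "Hugging Face Spaces offers persistent storage under /data.",
    ]
    init_retriever(sample_docs)
    print(retrieve_context("How do I search embeddings?", k=2))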