File size: 2,763 Bytes
7125686
14fa872
1dae236
14fa872
1dae236
14fa872
 
 
 
 
 
 
c044be1
 
7125686
1dae236
 
7125686
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dae236
 
 
 
7125686
1dae236
7125686
 
 
14fa872
 
 
 
 
 
 
 
 
 
1dae236
 
 
14fa872
 
 
 
 
 
1dae236
7125686
14fa872
 
 
 
 
 
 
7125686
14fa872
1dae236
 
14fa872
 
 
7125686
14fa872
7125686
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import logging
from pathlib import Path

# Optional FAISS (keeps original behavior)
# FAISS is a soft dependency: if the import fails the module still loads,
# and the retrieval functions below degrade to returning empty results.
try:
    import faiss
    _HAS_FAISS = True
except ImportError:
    # Warn once at import time and record the absence for later guards.
    logging.warning("FAISS not installed — retrieval will be disabled. Install faiss-cpu or faiss-gpu for full functionality.")
    _HAS_FAISS = False

from sentence_transformers import SentenceTransformer

# ---- Writable cache + stable repo id for Spaces ----
# NOTE(review): presumably chosen because Spaces containers restrict writes
# outside $HOME — confirm against the deployment environment.
_HOME = Path.home()
# Cache dir for downloaded model weights; SENTENCE_TRANSFORMERS_HOME, if set,
# takes precedence (assumed to point at a writable directory — TODO confirm).
_ST_CACHE = os.getenv("SENTENCE_TRANSFORMERS_HOME", str(_HOME / ".cache" / "sentence-transformers"))
_ST_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"  # canonical repo id

def _load_st_model():
    """Load the SentenceTransformer embedding model, soft-failing to ``None``.

    Tries a plain load first, then retries once with
    ``trust_remote_code=True`` in case the repo ships a custom modeling
    script.

    Returns:
        A ``SentenceTransformer`` instance, or ``None`` when both attempts
        fail — callers are expected to disable retrieval rather than crash.
    """
    # Ensure cache dir exists (makedirs only raises OSError subclasses).
    try:
        os.makedirs(_ST_CACHE, exist_ok=True)
    except OSError as e:
        logging.warning("Could not create cache directory %s: %s", _ST_CACHE, e)

    # Primary attempt
    try:
        return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE)
    except Exception as e1:
        logging.warning(
            "Primary load failed for '%s' with cache '%s': %s",
            _ST_MODEL_ID, _ST_CACHE, e1,
        )

        # Secondary attempt (allow trust_remote_code just in case)
        try:
            return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE, trust_remote_code=True)
        except Exception:
            # logging.exception captures the active traceback automatically.
            logging.exception("Failed loading SentenceTransformer model on both attempts.")
            # Soft-fail: disable retrieval rather than crashing the whole app
            logging.error(
                "Disabling retrieval due to model load failure. "
                "Check permissions for %s and HF_* env vars.",
                _ST_CACHE,
            )
            return None

# Load embedding model (works even if FAISS missing)
_model = _load_st_model()

# Module-level retrieval state, populated by init_retriever():
# _index — the FAISS index (None until built); _docs — the indexed strings.
_index = None
_docs = []

def init_retriever(docs=None):
    """
    Initialize the FAISS index over *docs*.

    docs: list[str] to index. When None or empty, any previously built
        index is cleared (matching the degraded-mode paths, which always
        overwrite the stored docs).

    If the embedding model or FAISS is unavailable, the docs are stored
    but no index is built, so retrieve_context() returns [].
    """
    global _index, _docs
    # Store docs unconditionally so all paths agree on module state.
    _docs = docs or []

    # Degraded mode: no embedding model or no FAISS — keep docs, no index.
    if _model is None or not _HAS_FAISS:
        return

    if not _docs:
        # Nothing to index; drop any stale index from a previous call.
        _index = None
        return

    # encode(..., convert_to_numpy=True) already yields an ndarray, so no
    # explicit numpy import is needed here.
    embeddings = _model.encode(_docs, convert_to_numpy=True, normalize_embeddings=False)
    d = embeddings.shape[1]  # embedding dimensionality
    _index = faiss.IndexFlatL2(d)
    _index.add(embeddings)

def retrieve_context(query: str, k: int = 5):
    """
    Return up to *k* indexed documents most similar to *query*.

    Yields an empty list whenever retrieval is unavailable: no embedding
    model, FAISS missing, or no index/docs built yet.
    """
    retrieval_ready = (
        _model is not None and _HAS_FAISS and _index is not None and _docs
    )
    if not retrieval_ready:
        return []

    query_vec = _model.encode([query], convert_to_numpy=True, normalize_embeddings=False)
    _distances, neighbor_ids = _index.search(query_vec, k)
    # FAISS pads with -1 when fewer than k neighbors exist; keep only
    # indices that actually address a stored document.
    valid = [idx for idx in neighbor_ids[0] if 0 <= idx < len(_docs)]
    return [_docs[idx] for idx in valid]