Rajan Sharma committed on
Commit
7125686
·
verified ·
1 Parent(s): d9257b7

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +41 -6
retriever.py CHANGED
@@ -1,5 +1,7 @@
 
1
  import logging
2
 
 
3
  try:
4
  import faiss
5
  _HAS_FAISS = True
@@ -9,8 +11,41 @@ except ImportError:
9
 
10
  from sentence_transformers import SentenceTransformer
11
 
12
- # load embedding model (still works even if FAISS missing)
13
- _model = SentenceTransformer("all-MiniLM-L6-v2")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  _index = None
16
  _docs = []
@@ -27,7 +62,7 @@ def init_retriever(docs=None):
27
 
28
  if docs:
29
  _docs = docs
30
- embeddings = _model.encode(docs, convert_to_numpy=True)
31
  d = embeddings.shape[1]
32
  _index = faiss.IndexFlatL2(d)
33
  _index.add(embeddings)
@@ -35,11 +70,11 @@ def init_retriever(docs=None):
35
def retrieve_context(query: str, k: int = 5):
    """
    Retrieve top-k docs matching query.

    Falls back to an empty list if FAISS is unavailable, the index has not
    been built, or no documents are stored.

    Args:
        query: Text to embed and search for.
        k: Number of neighbours requested from the index.

    Returns:
        The matching entries of the module-level ``_docs`` list, in the
        order the index returned them.
    """
    if not _HAS_FAISS or _index is None or not _docs:
        return []

    q_emb = _model.encode([query], convert_to_numpy=True)
    _distances, indices = _index.search(q_emb, k)
    # FAISS fills unused result slots with -1 when fewer than k vectors are
    # indexed. Without the lower bound, -1 would be accepted and silently
    # resolve to the *last* document via Python's negative indexing.
    return [_docs[i] for i in indices[0] if 0 <= i < len(_docs)]
 
1
+ import os
2
  import logging
3
 
4
+ # Optional FAISS (keeps your original behavior)
5
  try:
6
  import faiss
7
  _HAS_FAISS = True
 
11
 
12
  from sentence_transformers import SentenceTransformer
13
 
14
+ # ---- Writable cache + stable repo id for Spaces ----
15
+ _ST_CACHE = os.getenv("SENTENCE_TRANSFORMERS_HOME", "/data/.cache/sentence-transformers")
16
+ _ST_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2" # canonical repo id
17
+
18
def _load_st_model():
    """
    Load SentenceTransformer using an explicit cache folder to avoid
    Hugging Face 'xet' transport / permission issues on Spaces.

    The model id and cache location come from the module-level
    ``_ST_MODEL_ID`` and ``_ST_CACHE`` values. Two attempts are made: a plain
    load, then a retry with ``trust_remote_code=True``.

    Returns:
        The loaded ``SentenceTransformer`` instance.

    Raises:
        RuntimeError: If both load attempts fail. The message contains both
            underlying errors and the cache path; the second error is also
            attached as ``__cause__``.
    """
    # Ensure cache dir exists. Best-effort only: a failure is logged and the
    # load is still attempted.
    try:
        os.makedirs(_ST_CACHE, exist_ok=True)
    except Exception as e:
        logging.warning(f"Could not create cache directory {_ST_CACHE}: {e}")

    # Primary attempt
    try:
        return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE)
    except Exception as e1:
        # Python unbinds the ``as`` name when the except clause ends, so keep
        # our own reference; the final error message below needs it.
        first_error = e1
        logging.warning(f"Primary load failed for '{_ST_MODEL_ID}' with cache '{_ST_CACHE}': {e1}")

    # Secondary attempt (allow trust_remote_code just in case)
    # NOTE(review): trust_remote_code=True permits executing Python shipped in
    # the model repository. It is kept here to preserve existing behaviour,
    # but confirm that silently enabling it as a fallback is intended.
    try:
        return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE, trust_remote_code=True)
    except Exception as e2:
        logging.exception("Failed loading SentenceTransformer model on both attempts.")
        raise RuntimeError(
            f"Failed loading SentenceTransformer '{_ST_MODEL_ID}'.\n"
            f"First error: {first_error}\nSecond error: {e2}\n"
            f"Check cache dir permissions at: {_ST_CACHE}\n"
            f"Tip: ensure app.py sets HF_HUB_ENABLE_XET=0 and uses writable caches under /data."
        ) from e2
46
+
47
+ # Load embedding model (works even if FAISS missing)
48
+ _model = _load_st_model()
49
 
50
  _index = None
51
  _docs = []
 
62
 
63
  if docs:
64
  _docs = docs
65
+ embeddings = _model.encode(docs, convert_to_numpy=True, normalize_embeddings=False)
66
  d = embeddings.shape[1]
67
  _index = faiss.IndexFlatL2(d)
68
  _index.add(embeddings)
 
70
def retrieve_context(query: str, k: int = 5):
    """
    Retrieve up to ``k`` stored docs matching query.

    Falls back to an empty list if FAISS is unavailable, the retriever is
    not initialized (no index or no documents), or ``k`` is not positive.

    Args:
        query: Text to embed and search for.
        k: Maximum number of documents to return.

    Returns:
        A list of entries from the module-level ``_docs`` list, in the order
        the index returned them. It may be shorter than ``k``.
    """
    if not _HAS_FAISS or _index is None or not _docs:
        return []

    # A non-positive k has no valid result set; answer it here instead of
    # passing it on to the index.
    if k <= 0:
        return []

    # Never ask for more neighbours than there are documents: the surplus
    # slots would only come back as -1 padding.
    k = min(k, len(_docs))

    q_emb = _model.encode([query], convert_to_numpy=True, normalize_embeddings=False)
    _distances, indices = _index.search(q_emb, k)
    # Keep the bounds check as a safety net in case the index and _docs have
    # drifted apart (FAISS reports missing neighbours as -1).
    return [_docs[i] for i in indices[0] if 0 <= i < len(_docs)]