Rajan Sharma committed on
Commit
1dae236
·
verified ·
1 Parent(s): b9b3e60

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +15 -11
retriever.py CHANGED
@@ -1,7 +1,8 @@
1
  import os
2
  import logging
 
3
 
4
- # Optional FAISS (keeps your original behavior)
5
  try:
6
  import faiss
7
  _HAS_FAISS = True
@@ -12,14 +13,11 @@ except ImportError:
12
  from sentence_transformers import SentenceTransformer
13
 
14
  # ---- Writable cache + stable repo id for Spaces ----
15
- _ST_CACHE = os.getenv("SENTENCE_TRANSFORMERS_HOME", "/data/.cache/sentence-transformers")
 
16
  _ST_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2" # canonical repo id
17
 
18
  def _load_st_model():
19
- """
20
- Load SentenceTransformer using an explicit cache folder to avoid
21
- Hugging Face 'xet' transport / permission issues on Spaces.
22
- """
23
  # Ensure cache dir exists
24
  try:
25
  os.makedirs(_ST_CACHE, exist_ok=True)
@@ -37,12 +35,12 @@ def _load_st_model():
37
  return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE, trust_remote_code=True)
38
  except Exception as e2:
39
  logging.exception("Failed loading SentenceTransformer model on both attempts.")
40
- raise RuntimeError(
41
- f"Failed loading SentenceTransformer '{_ST_MODEL_ID}'.\n"
42
- f"First error: {e1}\nSecond error: {e2}\n"
43
- f"Check cache dir permissions at: {_ST_CACHE}\n"
44
- f"Tip: ensure app.py sets HF_HUB_ENABLE_XET=0 and uses writable caches under /data."
45
  )
 
46
 
47
  # Load embedding model (works even if FAISS missing)
48
  _model = _load_st_model()
@@ -56,12 +54,16 @@ def init_retriever(docs=None):
56
  docs: list[str] to index
57
  """
58
  global _index, _docs
 
 
 
59
  if not _HAS_FAISS:
60
  _docs = docs or []
61
  return
62
 
63
  if docs:
64
  _docs = docs
 
65
  embeddings = _model.encode(docs, convert_to_numpy=True, normalize_embeddings=False)
66
  d = embeddings.shape[1]
67
  _index = faiss.IndexFlatL2(d)
@@ -72,6 +74,8 @@ def retrieve_context(query: str, k: int = 5):
72
  Retrieve top-k docs matching query.
73
  Falls back to empty list if FAISS unavailable or not initialized.
74
  """
 
 
75
  if not _HAS_FAISS or _index is None or not _docs:
76
  return []
77
 
 
1
  import os
2
  import logging
3
+ from pathlib import Path
4
 
5
+ # Optional FAISS (keeps original behavior)
6
  try:
7
  import faiss
8
  _HAS_FAISS = True
 
13
  from sentence_transformers import SentenceTransformer
14
 
15
  # ---- Writable cache + stable repo id for Spaces ----
16
+ _HOME = Path.home()
17
+ _ST_CACHE = os.getenv("SENTENCE_TRANSFORMERS_HOME", str(_HOME / ".cache" / "sentence-transformers"))
18
  _ST_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2" # canonical repo id
19
 
20
  def _load_st_model():
 
 
 
 
21
  # Ensure cache dir exists
22
  try:
23
  os.makedirs(_ST_CACHE, exist_ok=True)
 
35
  return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE, trust_remote_code=True)
36
  except Exception as e2:
37
  logging.exception("Failed loading SentenceTransformer model on both attempts.")
38
+ # Soft-fail: disable retrieval rather than crashing the whole app
39
+ logging.error(
40
+ "Disabling retrieval due to model load failure. "
41
+ f"Check permissions for {_ST_CACHE} and HF_* env vars."
 
42
  )
43
+ return None
44
 
45
  # Load embedding model (works even if FAISS missing)
46
  _model = _load_st_model()
 
54
  docs: list[str] to index
55
  """
56
  global _index, _docs
57
+ if _model is None:
58
+ _docs = docs or []
59
+ return
60
  if not _HAS_FAISS:
61
  _docs = docs or []
62
  return
63
 
64
  if docs:
65
  _docs = docs
66
+ import numpy as np
67
  embeddings = _model.encode(docs, convert_to_numpy=True, normalize_embeddings=False)
68
  d = embeddings.shape[1]
69
  _index = faiss.IndexFlatL2(d)
 
74
  Retrieve top-k docs matching query.
75
  Falls back to an empty list if the embedding model failed to load, FAISS is unavailable, the index is not initialized, or no docs are indexed.
76
  """
77
+ if _model is None:
78
+ return []
79
  if not _HAS_FAISS or _index is None or not _docs:
80
  return []
81