Mozzicstar committed on
Commit ·
47202a9
1
Parent(s): d2f4c95
Add logging and HF_TOKEN for model loading
Browse files — scripts/query_qa.py +14 -1
scripts/query_qa.py
CHANGED
|
@@ -59,13 +59,26 @@ def load_vectorstore(persist_dir="vectorstore"):
|
|
| 59 |
|
| 60 |
def query(index, docs, q, model_id="sentence-transformers/all-mpnet-base-v2", top_k=5, api_token=None):
|
| 61 |
"""Query the vectorstore using local sentence-transformers model."""
|
|
|
|
|
|
|
|
|
|
| 62 |
# Use local model - HF API doesn't support direct embeddings for sentence-transformers
|
| 63 |
from sentence_transformers import SentenceTransformer
|
| 64 |
|
| 65 |
# Cache model in module-level variable
|
| 66 |
global _st_model
|
| 67 |
if '_st_model' not in globals() or _st_model is None:
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
emb = _st_model.encode([q], show_progress_bar=False, convert_to_numpy=True)
|
| 71 |
emb = np.array(emb, dtype=np.float32)
|
|
|
|
| 59 |
|
| 60 |
def query(index, docs, q, model_id="sentence-transformers/all-mpnet-base-v2", top_k=5, api_token=None):
|
| 61 |
"""Query the vectorstore using local sentence-transformers model."""
|
| 62 |
+
import logging
|
| 63 |
+
logger = logging.getLogger(__name__)
|
| 64 |
+
|
| 65 |
# Use local model - HF API doesn't support direct embeddings for sentence-transformers
|
| 66 |
from sentence_transformers import SentenceTransformer
|
| 67 |
|
| 68 |
# Cache model in module-level variable
|
| 69 |
global _st_model
|
| 70 |
if '_st_model' not in globals() or _st_model is None:
|
| 71 |
+
logger.info(f"Loading SentenceTransformer model: {model_id}")
|
| 72 |
+
try:
|
| 73 |
+
# Set HF token for model download if available
|
| 74 |
+
hf_token = api_token or os.getenv("HF_TOKEN")
|
| 75 |
+
if hf_token:
|
| 76 |
+
os.environ["HF_TOKEN"] = hf_token
|
| 77 |
+
_st_model = SentenceTransformer(model_id)
|
| 78 |
+
logger.info("Model loaded successfully")
|
| 79 |
+
except Exception as e:
|
| 80 |
+
logger.error(f"Failed to load model: {e}")
|
| 81 |
+
raise
|
| 82 |
|
| 83 |
emb = _st_model.encode([q], show_progress_bar=False, convert_to_numpy=True)
|
| 84 |
emb = np.array(emb, dtype=np.float32)
|