File size: 5,280 Bytes
18ab7fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5eb084
 
18ab7fd
a5eb084
 
 
 
 
 
 
 
 
 
 
 
18ab7fd
a5eb084
 
18ab7fd
 
 
 
 
47202a9
 
 
18ab7fd
 
 
 
 
 
47202a9
 
 
 
 
 
a58577d
 
 
 
 
 
 
47202a9
 
 
 
18ab7fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python
"""Query the local FAISS vectorstore and return top-k chunks using HF Inference API.
Usage: python scripts/query_qa.py --query "what is the personal income tax threshold" --top_k 5
"""
import argparse
import json
import os
import pickle
from pathlib import Path

from dotenv import load_dotenv, dotenv_values
load_dotenv()

import faiss
import numpy as np
import requests


def embed_text_hf(texts, model_id="sentence-transformers/all-mpnet-base-v2", api_token=None, timeout=15):
    """Embed *texts* via the Hugging Face Inference API.

    Args:
        texts: list of strings to embed.
        model_id: HF model repo id used for feature extraction.
        api_token: HF API token; required (raises if ``None``).
        timeout: per-request timeout in seconds.

    Returns:
        ``np.ndarray`` of shape ``(len(texts), dim)``, dtype float32.

    Raises:
        RuntimeError: when the token is missing or the API reports an error.
            (``RuntimeError`` is a subclass of ``Exception``, so existing
            callers catching ``Exception`` keep working.)
        requests.RequestException: on network-level failures (incl. timeout).
    """
    if api_token is None:
        raise RuntimeError("HF_TOKEN not found. Please set HF_TOKEN in your .env or environment variables.")

    # Use the new router endpoint (api-inference is deprecated)
    api_url = f"https://router.huggingface.co/hf-inference/models/{model_id}"
    headers = {"Authorization": f"Bearer {api_token}"}

    payload = {"inputs": texts, "options": {"wait_for_model": True}}
    response = requests.post(api_url, json=payload, headers=headers, timeout=timeout)

    if response.status_code != 200:
        if response.status_code == 401:
            raise RuntimeError("HF API error 401: Unauthorized. Check your HF_TOKEN and model access permissions.")
        if response.status_code == 503:
            raise RuntimeError("HF API: Model is loading, please retry in a moment.")
        raise RuntimeError(f"HF API error {response.status_code}: {response.text}")

    embeddings = response.json()
    # Some API errors come back as HTTP 200 with an "error" payload.
    if isinstance(embeddings, dict) and "error" in embeddings:
        raise RuntimeError(f"HF API error: {embeddings['error']}")

    # Feature-extraction may return token-level vectors with shape
    # (n_texts, n_tokens, dim); mean-pool over the token axis to get one
    # vector per text. NOTE(review): plain mean pooling ignores padding /
    # attention masks — acceptable for single short queries, confirm for
    # batched inputs of differing lengths.
    emb_array = np.array(embeddings, dtype=np.float32)
    if emb_array.ndim == 3:
        emb_array = emb_array.mean(axis=1)

    return emb_array


def load_vectorstore(persist_dir="vectorstore"):
    """Load the FAISS index and its pickled chunk metadata from *persist_dir*.

    Returns:
        (index, docs): the FAISS index and the list of per-chunk metadata
        dicts whose positions line up with the index ids.
    """
    import logging
    logger = logging.getLogger(__name__)

    base = Path(persist_dir)
    index_path = base / "faiss_index.bin"
    meta_path = base / "metadata.pkl"

    # Log existence/sizes up front so a missing or truncated store is
    # obvious before faiss/pickle raise a less friendly error.
    index_size = index_path.stat().st_size if index_path.exists() else 0
    meta_size = meta_path.stat().st_size if meta_path.exists() else 0
    logger.info(f"Loading vectorstore from {base}")
    logger.info(f"Index exists: {index_path.exists()}, size: {index_size}")
    logger.info(f"Metadata exists: {meta_path.exists()}, size: {meta_size}")

    index = faiss.read_index(str(index_path))
    logger.info(f"FAISS index loaded, ntotal: {index.ntotal}")

    with open(meta_path, "rb") as f:
        docs = pickle.load(f)
    logger.info(f"Metadata loaded, {len(docs)} documents")

    return index, docs


def query(index, docs, q, model_id="sentence-transformers/all-mpnet-base-v2", top_k=5, api_token=None):
    """Embed *q* with a local sentence-transformers model and search the index.

    Args:
        index: FAISS index to search. The query vector is L2-normalized
            before searching, which presumes the stored vectors were too
            (inner-product index acting as cosine) — TODO confirm against
            the ingestion script.
        docs: list of metadata dicts aligned with the FAISS ids.
        q: query string.
        model_id: sentence-transformers model used to encode the query.
        top_k: number of results to return.
        api_token: optional HF token, used only for model download.

    Returns:
        list of dicts with keys score/text/source/page/chunk_id, best first.
    """
    import logging
    logger = logging.getLogger(__name__)

    # Cache loaded models on the function itself, keyed by model_id. The
    # previous single-slot global cache returned whatever model was loaded
    # first even when a later call asked for a different model_id.
    cache = getattr(query, "_model_cache", None)
    if cache is None:
        cache = query._model_cache = {}

    model = cache.get(model_id)
    if model is None:
        # Imported lazily: only needed when a model actually has to load.
        # (HF API doesn't support direct embeddings for sentence-transformers,
        # so encoding is done locally.)
        from sentence_transformers import SentenceTransformer

        logger.info(f"Loading SentenceTransformer model: {model_id}")
        try:
            # Set HF token for model download if available
            hf_token = api_token or os.getenv("HF_TOKEN")
            if hf_token:
                os.environ["HF_TOKEN"] = hf_token

            # Use cache folder if set (Docker builds pre-download here)
            cache_folder = os.getenv("SENTENCE_TRANSFORMERS_HOME")
            if cache_folder:
                model = SentenceTransformer(model_id, cache_folder=cache_folder)
            else:
                model = SentenceTransformer(model_id)
            cache[model_id] = model
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    emb = model.encode([q], show_progress_bar=False, convert_to_numpy=True)
    emb = np.array(emb, dtype=np.float32)
    # Normalize (epsilon guards against a zero vector).
    emb = emb / (np.linalg.norm(emb, axis=1, keepdims=True) + 1e-12)

    D, I = index.search(emb, top_k)
    results = []
    for score, idx in zip(D[0], I[0]):
        if idx < 0:  # FAISS pads with -1 when fewer than top_k hits exist
            continue
        meta = docs[idx]
        results.append({
            "score": float(score),
            "text": meta["text"],
            "source": meta["source"],
            "page": meta["page"],
            "chunk_id": meta["chunk_id"],
        })
    return results

# Lazily-populated cache for the loaded embedding model, so repeated
# queries in one process don't reload it from disk.
_st_model = None


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--query", required=True)
    parser.add_argument("--persist_dir", default="vectorstore")
    # Must match the model the vectorstore was built with. Every other
    # default in this module uses all-mpnet-base-v2; the previous default
    # (nvidia/llama-embed-nemotron-8b) would encode queries with a different
    # model than the stored index, yielding incompatible/meaningless scores.
    parser.add_argument("--model", default="sentence-transformers/all-mpnet-base-v2")
    parser.add_argument("--top_k", type=int, default=5)
    parser.add_argument("--hf_token", default=None, help="Hugging Face token (overrides HF_TOKEN env/.env)")
    args = parser.parse_args()

    index, docs = load_vectorstore(args.persist_dir)
    res = query(index, docs, args.query, args.model, args.top_k, api_token=args.hf_token)
    print(json.dumps(res, indent=2))