File size: 4,164 Bytes
28e499f
 
34fbe97
c06718f
 
4f86dd4
2c3faf7
4f86dd4
 
202c49b
a876100
4f86dd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28e499f
a876100
40c75fc
c06718f
 
4f86dd4
28e499f
 
4f86dd4
 
 
28e499f
4f86dd4
28e499f
 
4f86dd4
28e499f
c06718f
40c75fc
 
 
28e499f
4f86dd4
28e499f
40c75fc
 
 
 
 
 
 
 
 
 
 
 
 
 
a876100
40c75fc
 
 
 
 
 
 
 
 
 
 
 
 
4f86dd4
c06718f
 
4f86dd4
28e499f
4f86dd4
28e499f
4f86dd4
a876100
 
 
4f86dd4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Dict
import logging
import time

logger = logging.getLogger(__name__)

class DocumentRetriever:
    """Semantic document retriever backed by a SentenceTransformer model.

    Loads a JSON list of documents (each expected to carry a 'content' key),
    embeds them once with an on-disk ``.npy`` cache, and ranks them against a
    query by dot-product similarity.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2', data_path='data/rupeia_document.json',
                 cache_folder=None, embedding_cache_path='data/doc_embeddings.npy'):
        """
        Initialize the DocumentRetriever with a SentenceTransformer model.

        Args:
            model_name (str): Name of the SentenceTransformer model (default: 'all-MiniLM-L6-v2').
            data_path (str): Path to the document JSON file (default: 'data/rupeia_document.json').
            cache_folder (str, optional): Directory to cache model files (default: None).
            embedding_cache_path (str): Where to cache document embeddings
                (default: 'data/doc_embeddings.npy', matching the previously
                hard-coded path, so existing callers are unaffected).

        Raises:
            Exception: re-raises whatever SentenceTransformer raises when the
                model cannot be loaded (logged first).
        """
        logger.info(f"Initializing DocumentRetriever with model: {model_name}, cache_folder: {cache_folder}")
        try:
            self.model = SentenceTransformer(model_name, cache_folder=cache_folder)
        except Exception as e:
            logger.error(f"Failed to load SentenceTransformer model: {str(e)}")
            raise
        self.data_path = data_path
        self.embedding_cache_path = embedding_cache_path
        self.documents = self._load_documents()
        self.doc_embeddings = self._load_or_compute_embeddings()

    def _load_documents(self) -> List[Dict]:
        """Load the document list from the JSON file.

        Returns an empty list (with a warning logged) on any of: missing file,
        invalid JSON, or JSON whose top level is not a list.
        """
        try:
            # Explicit encoding: the default is platform/locale dependent.
            with open(self.data_path, 'r', encoding='utf-8') as f:
                documents = json.load(f)
        except FileNotFoundError:
            logger.warning(f"Data file not found at {self.data_path}, using empty documents")
            return []
        except json.JSONDecodeError:
            logger.warning(f"Invalid JSON in {self.data_path}, using empty documents")
            return []
        if not isinstance(documents, list):
            # A dict/scalar top level would silently misbehave downstream
            # (len(), iteration, indexing by position).
            logger.warning(f"Expected a JSON list in {self.data_path}, got {type(documents).__name__}, using empty documents")
            return []
        logger.info(f"Loaded {len(documents)} documents from {self.data_path}")
        return documents

    def _load_or_compute_embeddings(self) -> np.ndarray:
        """Load cached embeddings or compute (and best-effort cache) new ones.

        The cache is trusted only if its row count matches the document count.
        NOTE(review): a cache produced by a different model or a different
        corpus of the same length would still be accepted — consider keying
        the cache on a corpus/model hash.
        """
        if not self.documents:
            logger.info("No documents to embed, returning empty embeddings")
            return np.array([])

        # Check for cached embeddings
        if os.path.exists(self.embedding_cache_path):
            try:
                embeddings = np.load(self.embedding_cache_path)
                if embeddings.shape[0] == len(self.documents):
                    logger.info(f"Loaded {embeddings.shape[0]} cached embeddings from {self.embedding_cache_path}")
                    return embeddings
                logger.warning(f"Cached embeddings shape mismatch, recomputing...")
            except Exception as e:
                logger.warning(f"Failed to load cached embeddings: {str(e)}, recomputing...")

        # Compute new embeddings; tolerate records missing a 'content' key
        # instead of crashing on KeyError.
        texts = [doc.get('content', '') for doc in self.documents]
        logger.info(f"Computing embeddings for {len(texts)} documents...")
        start_time = time.time()
        embeddings = self.model.encode(texts, batch_size=32, show_progress_bar=True)
        logger.info(f"Embedding {len(texts)} documents took {time.time() - start_time:.2f} seconds")

        # Cache embeddings (best effort — retrieval still works if this fails).
        try:
            cache_dir = os.path.dirname(self.embedding_cache_path) or '.'
            os.makedirs(cache_dir, exist_ok=True)
            np.save(self.embedding_cache_path, embeddings)
            logger.info(f"Saved embeddings to {self.embedding_cache_path}")
        except Exception as e:
            logger.warning(f"Failed to save embeddings: {str(e)}")

        return embeddings

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """Return the top-k documents most similar to *query*.

        Args:
            query (str): Free-text query to embed and score.
            top_k (int): Maximum number of documents to return (default: 3).
                Values <= 0 return an empty list.

        Returns:
            List[Dict]: Documents sorted by descending dot-product score;
            fewer than top_k if the corpus is smaller.
        """
        if not self.documents:
            logger.warning("No documents available for retrieval")
            return []
        if top_k <= 0:
            # Bugfix: np.argsort(scores)[-0:] is the WHOLE array, so top_k=0
            # previously returned every document instead of none.
            return []
        logger.info(f"Retrieving top {top_k} documents for query: {query}")
        query_embedding = self.model.encode(query)
        scores = np.dot(self.doc_embeddings, query_embedding)
        top_indices = np.argsort(scores)[-top_k:][::-1]
        results = [self.documents[i] for i in top_indices]
        logger.info(f"Retrieved {len(results)} documents")
        return results