File size: 4,505 Bytes
aca8ab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
ChromaDB vector store with persistent storage.
"""
import logging
from typing import List, Optional
from pathlib import Path
import chromadb
from chromadb.config import Settings

from utils.schemas import PaperChunk
from rag.embeddings import EmbeddingGenerator

# NOTE(review): calling basicConfig at import time configures the root logger
# for the entire process; libraries conventionally leave this to the
# application entry point — confirm this module is only imported by the app.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-scoped logger named after this module, per logging convention.
logger = logging.getLogger(__name__)


class VectorStore:
    """ChromaDB-backed vector store for research paper chunks.

    Wraps a persistent ChromaDB collection and provides de-duplicating
    inserts plus similarity search with optional per-paper filtering.
    """

    def __init__(
        self,
        persist_directory: str = "data/chroma_db",
        collection_name: str = "research_papers"
    ) -> None:
        """
        Initialize vector store.

        Args:
            persist_directory: Directory for persistent storage
                (created if missing).
            collection_name: Name of the collection.
        """
        self.persist_directory = Path(persist_directory)
        self.persist_directory.mkdir(parents=True, exist_ok=True)
        self.collection_name = collection_name

        # Persistent client so the index survives process restarts;
        # telemetry disabled, reset allowed (useful for test/dev workflows).
        self.client = chromadb.PersistentClient(
            path=str(self.persist_directory),
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True
            )
        )

        # Idempotent: reuses the collection if it already exists on disk.
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"description": "Research paper chunks for RAG"}
        )

        logger.info("Vector store initialized with %d chunks", self.collection.count())

    def add_chunks(
        self,
        chunks: List[PaperChunk],
        embeddings: List[List[float]]
    ) -> None:
        """
        Add chunks to the vector store, skipping chunk IDs already present.

        Args:
            chunks: List of PaperChunk objects.
            embeddings: List of embedding vectors, parallel to ``chunks``.

        Raises:
            ValueError: If ``chunks`` and ``embeddings`` differ in length.
        """
        if not chunks or not embeddings:
            logger.warning("No chunks or embeddings provided")
            return

        if len(chunks) != len(embeddings):
            raise ValueError("Number of chunks and embeddings must match")

        # Prepare data for ChromaDB.
        ids = [chunk.chunk_id for chunk in chunks]
        documents = [chunk.content for chunk in chunks]
        # ChromaDB rejects None metadata values, so every optional field
        # gets an explicit fallback (the original omitted one for arxiv_url).
        metadatas = [
            {
                "paper_id": chunk.paper_id,
                "section": chunk.section or "unknown",
                "page_number": chunk.page_number or 0,
                "arxiv_url": chunk.arxiv_url or "",
                "title": chunk.metadata.get("title", ""),
                "authors": ",".join(chunk.metadata.get("authors", [])),
                "chunk_index": chunk.metadata.get("chunk_index", 0)
            }
            for chunk in chunks
        ]

        # Dedup check only needs IDs; include=[] avoids pulling documents
        # and metadata back from storage.
        existing_ids = set(self.collection.get(ids=ids, include=[])["ids"])
        new_indices = [i for i, chunk_id in enumerate(ids) if chunk_id not in existing_ids]

        if not new_indices:
            logger.info("All chunks already exist in vector store")
            return

        # Add only the chunks whose IDs were not already stored.
        self.collection.add(
            ids=[ids[i] for i in new_indices],
            documents=[documents[i] for i in new_indices],
            embeddings=[embeddings[i] for i in new_indices],
            metadatas=[metadatas[i] for i in new_indices]
        )

        logger.info("Added %d new chunks to vector store", len(new_indices))

    def search(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        paper_ids: Optional[List[str]] = None
    ) -> dict:
        """
        Search for chunks similar to the query embedding.

        Args:
            query_embedding: Query embedding vector.
            top_k: Number of results to return.
            paper_ids: Optional filter restricting results to these papers.

        Returns:
            ChromaDB query result dict (ids, documents, metadatas, distances).
        """
        # Single ID uses direct equality; multiple IDs need the $in
        # operator clause per ChromaDB's where-filter syntax.
        where = None
        if paper_ids:
            if len(paper_ids) == 1:
                where = {"paper_id": paper_ids[0]}
            else:
                where = {"paper_id": {"$in": paper_ids}}

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where=where
        )

        logger.info("Found %d results", len(results["ids"][0]))
        return results