navid72m commited on
Commit
43efcb9
·
verified ·
1 Parent(s): 192ac5e

Upload 9 files

Browse files
Files changed (9) hide show
  1. app.py +91 -0
  2. config.py +121 -0
  3. document-processor.py +239 -0
  4. embedding-model.py +286 -0
  5. package-init.py +23 -0
  6. rag-engine.py +357 -0
  7. readme.md +105 -0
  8. streamlit-app.py +229 -0
  9. vector-db.py +545 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main application entry point.
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ from fastapi import FastAPI
8
+ import uvicorn
9
+ from dotenv import load_dotenv
10
+
11
+ # Load environment variables
12
+ load_dotenv()
13
+
14
+ # Configure logging
15
+ from config import get_logging_config
16
+ import logging.config
17
+ logging.config.dictConfig(get_logging_config())
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Create FastAPI app
21
+ app = FastAPI(
22
+ title="RAG API System",
23
+ description="API for RAG-based question answering",
24
+ version="1.0.0"
25
+ )
26
+
27
+ # Initialize components
28
+ def initialize_components():
29
+ """Initialize all system components."""
30
+ logger.info("Initializing system components")
31
+
32
+ # Create embedding model
33
+ from embedding.model import create_embedding_model
34
+ embedding_model = create_embedding_model()
35
+ logger.info(f"Embedding model initialized with dimension {embedding_model.dimension}")
36
+
37
+ # Create vector database
38
+ from storage.vector_db import create_vector_database
39
+ vector_db = create_vector_database(dimension=embedding_model.dimension)
40
+ logger.info("Vector database initialized")
41
+
42
+ # Create RAG engine
43
+ from rag.engine import create_rag_engine
44
+ rag_engine = create_rag_engine(
45
+ embedder=embedding_model,
46
+ vector_db=vector_db
47
+ )
48
+ logger.info("RAG engine initialized")
49
+
50
+ return rag_engine
51
+
52
+ # Register API routes
53
+ def register_api_routes(app, rag_engine):
54
+ """Register API routes."""
55
+ from api.routes import RAGAPIRouter
56
+ router = RAGAPIRouter(app, rag_engine)
57
+ logger.info("API routes registered")
58
+
59
+ # Add health check route
60
+ @app.get("/", tags=["Root"])
61
+ async def root():
62
+ """Root endpoint returning basic system information."""
63
+ return {
64
+ "name": "RAG API System",
65
+ "version": "1.0.0",
66
+ "status": "running"
67
+ }
68
+
69
+ # Main entry point
70
+ def main():
71
+ """Main application entry point."""
72
+ logger.info("Starting RAG API system")
73
+
74
+ # Initialize components
75
+ rag_engine = initialize_components()
76
+
77
+ # Register API routes
78
+ register_api_routes(app, rag_engine)
79
+
80
+ # Run server if executed directly
81
+ if __name__ == "__main__":
82
+ host = os.getenv("API_HOST", "0.0.0.0")
83
+ port = int(os.getenv("API_PORT", "8000"))
84
+
85
+ logger.info(f"Starting server on http://{host}:{port}")
86
+ uvicorn.run(app, host=host, port=port)
87
+
88
+ return app
89
+
90
+ # Create and run application
91
+ app = main()
config.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Centralized configuration for the RAG system.
3
+ """
4
+
5
+ import os
6
+ from typing import Optional, Dict, Any
7
+ from dotenv import load_dotenv
8
+
9
+ # Load environment variables
10
+ load_dotenv()
11
+
12
+ # Embedding model settings
13
+ EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
14
+ EMBEDDING_DIMENSION = int(os.getenv("EMBEDDING_DIMENSION", "384"))
15
+ USE_GPU = os.getenv("USE_GPU", "True").lower() in ("true", "1", "t")
16
+
17
+ # Document processing settings
18
+ CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "1000"))
19
+ CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", "200"))
20
+ MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
21
+
22
+ # Vector database settings
23
+ VECTOR_DB_TYPE = os.getenv("VECTOR_DB_TYPE", "faiss") # Options: "faiss", "milvus", etc.
24
+ FAISS_INDEX_TYPE = os.getenv("FAISS_INDEX_TYPE", "Flat") # Options: "Flat", "IVF", "HNSW"
25
+ MONGODB_URI = os.getenv("MONGODB_URI", "mongodb://localhost:27017/")
26
+ DB_NAME = os.getenv("DB_NAME", "rag_db")
27
+ COLLECTION_NAME = os.getenv("COLLECTION_NAME", "documents")
28
+
29
+ # Retrieval settings
30
+ TOP_K = int(os.getenv("TOP_K", "5"))
31
+ SEARCH_TYPE = os.getenv("SEARCH_TYPE", "hybrid") # Options: "semantic", "keyword", "hybrid"
32
+ SEMANTIC_SEARCH_WEIGHT = float(os.getenv("SEMANTIC_SEARCH_WEIGHT", "0.7"))
33
+ KEYWORD_SEARCH_WEIGHT = float(os.getenv("KEYWORD_SEARCH_WEIGHT", "0.3"))
34
+
35
+ # LLM settings
36
+ LLM_MODEL_NAME = os.getenv("LLM_MODEL", "gpt-3.5-turbo")
37
+ LLM_API_KEY = os.getenv("OPENAI_API_KEY")
38
+ LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0.2"))
39
+ LLM_MAX_TOKENS = int(os.getenv("LLM_MAX_TOKENS", "512"))
40
+
41
+ # Local LLM settings (optional)
42
+ LOCAL_LLM_MODEL_NAME = os.getenv("LOCAL_LLM_MODEL", "google/flan-t5-base")
43
+ USE_LOCAL_LLM = os.getenv("USE_LOCAL_LLM", "False").lower() in ("true", "1", "t")
44
+
45
+ # API settings
46
+ API_HOST = os.getenv("API_HOST", "0.0.0.0")
47
+ API_PORT = int(os.getenv("API_PORT", "8000"))
48
+
49
+ # Logging settings
50
+ LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
51
+ LOG_FORMAT = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
52
+
53
+ # Default prompt template
54
+ DEFAULT_PROMPT_TEMPLATE = """
55
+ Answer the following question based ONLY on the provided context.
56
+ If you cannot answer the question based on the context, say "I don't have enough information to answer this question."
57
+
58
+ Context:
59
+ {context}
60
+
61
+ Question: {query}
62
+
63
+ Answer:
64
+ """
65
+
66
+
67
def get_logging_config() -> Dict[str, Any]:
    """Build the dictConfig-style logging configuration.

    A single console handler writes to stdout; the root logger and the
    handler both use the configured LOG_LEVEL, and records are rendered
    with LOG_FORMAT.
    """
    console_handler = {
        "class": "logging.StreamHandler",
        "level": LOG_LEVEL,
        "formatter": "standard",
        "stream": "ext://sys.stdout",
    }
    root_logger = {
        "handlers": ["console"],
        "level": LOG_LEVEL,
        "propagate": True,
    }
    return {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {"standard": {"format": LOG_FORMAT}},
        "handlers": {"console": console_handler},
        "loggers": {"": root_logger},
    }
93
+
94
+
95
def get_model_config(model_name: Optional[str] = None) -> Dict[str, Any]:
    """Return embedding-model settings for *model_name*.

    Known models get their preset dimension/max-length/normalization; any
    other name falls back to the environment-driven defaults. When
    *model_name* is None the configured EMBEDDING_MODEL_NAME is used.
    """
    if model_name is None:
        model_name = EMBEDDING_MODEL_NAME

    # Presets for popular models; extend as needed.
    known_models = {
        "sentence-transformers/all-MiniLM-L6-v2": {
            "dimension": 384,
            "max_length": 512,
            "normalize": True,
        },
        "sentence-transformers/all-mpnet-base-v2": {
            "dimension": 768,
            "max_length": 512,
            "normalize": True,
        },
    }

    fallback = {
        "dimension": EMBEDDING_DIMENSION,
        "max_length": MAX_LENGTH,
        "normalize": True,
    }
    return known_models.get(model_name, fallback)
document-processor.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Document processing utilities for text extraction and chunking.
3
+ """
4
+
5
+ import os
6
+ import logging
7
+ from typing import List, Dict, Any, Optional, Tuple, Union
8
+ import uuid
9
+
10
+ # Configure logging
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class DocumentProcessor:
15
+ """
16
+ Class to handle document processing, chunking, and text extraction.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ chunk_size: int = 1000,
22
+ chunk_overlap: int = 200
23
+ ):
24
+ """
25
+ Initialize the document processor.
26
+
27
+ Args:
28
+ chunk_size: Maximum size of text chunks in characters
29
+ chunk_overlap: Overlap between chunks in characters
30
+ """
31
+ self.chunk_size = chunk_size
32
+ self.chunk_overlap = chunk_overlap
33
+
34
+ def process_file(
35
+ self,
36
+ file_path: str,
37
+ metadata: Optional[Dict[str, Any]] = None
38
+ ) -> Tuple[List[str], List[Dict[str, Any]]]:
39
+ """
40
+ Process a document file: extract text and chunk it.
41
+
42
+ Args:
43
+ file_path: Path to the document file
44
+ metadata: Optional metadata about the document
45
+
46
+ Returns:
47
+ Tuple of (list of text chunks, list of metadata dictionaries)
48
+ """
49
+ if not os.path.exists(file_path):
50
+ raise FileNotFoundError(f"Document not found: {file_path}")
51
+
52
+ # Extract text from document
53
+ logger.info(f"Processing document: {file_path}")
54
+ text = self._extract_text(file_path)
55
+
56
+ if not text:
57
+ logger.warning(f"No text could be extracted from {file_path}")
58
+ return [], []
59
+
60
+ # Create base metadata if not provided
61
+ base_metadata = {"source": os.path.basename(file_path)}
62
+ if metadata:
63
+ base_metadata.update(metadata)
64
+
65
+ # Chunk the document
66
+ chunks = self._chunk_text(text, self.chunk_size, self.chunk_overlap)
67
+ logger.info(f"Created {len(chunks)} chunks from document")
68
+
69
+ # Create chunk-specific metadata
70
+ chunk_metadata = []
71
+ for i, _ in enumerate(chunks):
72
+ metadata_item = {
73
+ **base_metadata,
74
+ "chunk_id": i,
75
+ "total_chunks": len(chunks),
76
+ "document_id": str(uuid.uuid4()) # Unique ID for tracking
77
+ }
78
+ chunk_metadata.append(metadata_item)
79
+
80
+ return chunks, chunk_metadata
81
+
82
+ def _extract_text(self, file_path: str) -> str:
83
+ """
84
+ Extract text from a document file based on its extension.
85
+
86
+ Args:
87
+ file_path: Path to the document file
88
+
89
+ Returns:
90
+ Extracted text
91
+ """
92
+ _, ext = os.path.splitext(file_path)
93
+ ext = ext.lower()
94
+
95
+ if ext == '.pdf':
96
+ return self._extract_text_from_pdf(file_path)
97
+ elif ext == '.txt':
98
+ return self._extract_text_from_txt(file_path)
99
+ elif ext == '.md':
100
+ return self._extract_text_from_txt(file_path)
101
+ elif ext == '.docx':
102
+ return self._extract_text_from_docx(file_path)
103
+ else:
104
+ raise ValueError(f"Unsupported file format: {ext}")
105
+
106
+ def _extract_text_from_pdf(self, file_path: str) -> str:
107
+ """
108
+ Extract text from a PDF file.
109
+
110
+ Args:
111
+ file_path: Path to the PDF file
112
+
113
+ Returns:
114
+ Extracted text
115
+ """
116
+ try:
117
+ import PyPDF2
118
+ except ImportError:
119
+ raise ImportError(
120
+ "PyPDF2 is not installed. "
121
+ "Please install it with `pip install PyPDF2`."
122
+ )
123
+
124
+ text = ""
125
+ try:
126
+ with open(file_path, "rb") as f:
127
+ pdf_reader = PyPDF2.PdfReader(f)
128
+ num_pages = len(pdf_reader.pages)
129
+ logger.info(f"PDF has {num_pages} pages")
130
+
131
+ for page in pdf_reader.pages:
132
+ page_text = page.extract_text()
133
+ if page_text:
134
+ text += page_text + "\n\n"
135
+ except Exception as e:
136
+ logger.error(f"Error reading PDF file {file_path}: {e}")
137
+
138
+ logger.info(f"Extracted {len(text)} characters from PDF")
139
+ return text
140
+
141
+ def _extract_text_from_txt(self, file_path: str) -> str:
142
+ """
143
+ Extract text from a plain text file.
144
+
145
+ Args:
146
+ file_path: Path to the text file
147
+
148
+ Returns:
149
+ Extracted text
150
+ """
151
+ try:
152
+ with open(file_path, 'r', encoding='utf-8') as f:
153
+ text = f.read()
154
+
155
+ logger.info(f"Extracted {len(text)} characters from text file")
156
+ return text
157
+ except Exception as e:
158
+ logger.error(f"Error reading text file {file_path}: {e}")
159
+ return ""
160
+
161
+ def _extract_text_from_docx(self, file_path: str) -> str:
162
+ """
163
+ Extract text from a DOCX file.
164
+
165
+ Args:
166
+ file_path: Path to the DOCX file
167
+
168
+ Returns:
169
+ Extracted text
170
+ """
171
+ try:
172
+ import docx
173
+ except ImportError:
174
+ raise ImportError(
175
+ "python-docx is not installed. "
176
+ "Please install it with `pip install python-docx`."
177
+ )
178
+
179
+ try:
180
+ doc = docx.Document(file_path)
181
+ text = "\n\n".join([paragraph.text for paragraph in doc.paragraphs if paragraph.text])
182
+
183
+ logger.info(f"Extracted {len(text)} characters from DOCX")
184
+ return text
185
+ except Exception as e:
186
+ logger.error(f"Error reading DOCX file {file_path}: {e}")
187
+ return ""
188
+
189
+ @staticmethod
190
+ def _chunk_text(
191
+ text: str,
192
+ chunk_size: int = 1000,
193
+ overlap: int = 200
194
+ ) -> List[str]:
195
+ """
196
+ Split text into overlapping chunks.
197
+
198
+ Args:
199
+ text: The text to chunk
200
+ chunk_size: Maximum chunk size in characters
201
+ overlap: Overlap between chunks in characters
202
+
203
+ Returns:
204
+ List of text chunks
205
+ """
206
+ if not text or not text.strip():
207
+ return []
208
+
209
+ chunks = []
210
+ start = 0
211
+ text_len = len(text)
212
+
213
+ while start < text_len:
214
+ # Define the initial chunk end
215
+ end = min(start + chunk_size, text_len)
216
+
217
+ # Try to find a natural break point if not at the end of text
218
+ if end < text_len:
219
+ # Look for paragraph break
220
+ next_para = text.find('\n\n', end - overlap, end + 100)
221
+ if next_para != -1:
222
+ end = next_para + 2
223
+ else:
224
+ # Look for sentence break
225
+ for punct in ['. ', '! ', '? ', '.\n', '!\n', '?\n']:
226
+ next_sent = text.find(punct, end - overlap, end + 100)
227
+ if next_sent != -1:
228
+ end = next_sent + len(punct)
229
+ break
230
+
231
+ # Extract the chunk
232
+ chunk = text[start:end].strip()
233
+ if chunk: # Only add non-empty chunks
234
+ chunks.append(chunk)
235
+
236
+ # Move to next chunk with overlap
237
+ start = max(end - overlap, start + 1)
238
+
239
+ return chunks
embedding-model.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified embedding model implementation supporting multiple backends.
3
+ """
4
+
5
+ from typing import List, Union, Optional, Dict, Any
6
+ import logging
7
+ import numpy as np
8
+ import torch
9
+ from abc import ABC, abstractmethod
10
+
11
+ # Configure logging
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class EmbeddingModel(ABC):
16
+ """Abstract base class for embedding models."""
17
+
18
+ @abstractmethod
19
+ def embed(self, texts: Union[str, List[str]], batch_size: int = 32) -> np.ndarray:
20
+ """
21
+ Convert text(s) to embedding vector(s).
22
+
23
+ Args:
24
+ texts: Input text(s) to embed
25
+ batch_size: Batch size for processing
26
+
27
+ Returns:
28
+ Embedding vector(s) as numpy array
29
+ """
30
+ pass
31
+
32
+ @property
33
+ @abstractmethod
34
+ def dimension(self) -> int:
35
+ """Get the dimension of the embedding vectors."""
36
+ pass
37
+
38
+
39
+ class SentenceTransformerEmbedding(EmbeddingModel):
40
+ """Embedding model using sentence-transformers library."""
41
+
42
+ def __init__(
43
+ self,
44
+ model_name: str = "all-MiniLM-L6-v2",
45
+ device: Optional[str] = None,
46
+ normalize: bool = True,
47
+ **kwargs
48
+ ):
49
+ """
50
+ Initialize the sentence transformer embedding model.
51
+
52
+ Args:
53
+ model_name: Sentence transformer model name or path
54
+ device: Device to run model on ('cpu', 'cuda', 'cuda:0', etc.)
55
+ normalize: Whether to L2-normalize embeddings
56
+ **kwargs: Additional arguments for the model
57
+ """
58
+ try:
59
+ from sentence_transformers import SentenceTransformer
60
+ except ImportError:
61
+ raise ImportError(
62
+ "sentence-transformers is not installed. "
63
+ "Please install it with `pip install sentence-transformers`."
64
+ )
65
+
66
+ self.model_name = model_name
67
+ self.normalize = normalize
68
+
69
+ # Determine device
70
+ if device is None:
71
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
72
+ else:
73
+ self.device = device
74
+
75
+ logger.info(f"Loading SentenceTransformer model: {model_name} on {self.device}")
76
+ try:
77
+ self.model = SentenceTransformer(model_name, device=self.device)
78
+ self._dimension = self.model.get_sentence_embedding_dimension()
79
+ logger.info(f"Model loaded successfully. Embedding dimension: {self._dimension}")
80
+ except Exception as e:
81
+ logger.error(f"Failed to load model: {e}")
82
+ raise
83
+
84
+ def embed(self, texts: Union[str, List[str]], batch_size: int = 32) -> np.ndarray:
85
+ """
86
+ Convert text(s) to embedding vector(s).
87
+
88
+ Args:
89
+ texts: Input text(s) to embed
90
+ batch_size: Batch size for processing
91
+
92
+ Returns:
93
+ Embedding vector(s) as numpy array
94
+ """
95
+ # Handle single text input
96
+ if isinstance(texts, str):
97
+ texts = [texts]
98
+
99
+ # Validate input
100
+ if not texts:
101
+ logger.warning("Empty texts provided for embedding")
102
+ return np.array([])
103
+
104
+ try:
105
+ # Generate embeddings
106
+ embeddings = self.model.encode(
107
+ texts,
108
+ batch_size=batch_size,
109
+ show_progress_bar=False,
110
+ convert_to_numpy=True
111
+ )
112
+
113
+ # Normalize if requested
114
+ if self.normalize:
115
+ embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
116
+
117
+ return embeddings
118
+ except Exception as e:
119
+ logger.error(f"Error during embedding generation: {e}")
120
+ raise
121
+
122
+ @property
123
+ def dimension(self) -> int:
124
+ """Get the dimension of the embedding vectors."""
125
+ return self._dimension
126
+
127
+
128
+ class HuggingFaceEmbedding(EmbeddingModel):
129
+ """Embedding model using HuggingFace transformers directly."""
130
+
131
+ def __init__(
132
+ self,
133
+ model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
134
+ device: Optional[str] = None,
135
+ normalize: bool = True,
136
+ max_length: int = 512,
137
+ **kwargs
138
+ ):
139
+ """
140
+ Initialize the HuggingFace embedding model.
141
+
142
+ Args:
143
+ model_name: HuggingFace model name or path
144
+ device: Device to run model on ('cpu', 'cuda', 'cuda:0', etc.)
145
+ normalize: Whether to L2-normalize embeddings
146
+ max_length: Maximum token length for inputs
147
+ **kwargs: Additional arguments for the model
148
+ """
149
+ try:
150
+ from transformers import AutoTokenizer, AutoModel
151
+ except ImportError:
152
+ raise ImportError(
153
+ "transformers is not installed. "
154
+ "Please install it with `pip install transformers`."
155
+ )
156
+
157
+ self.model_name = model_name
158
+ self.normalize = normalize
159
+ self.max_length = max_length
160
+
161
+ # Determine device
162
+ if device is None:
163
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
164
+ else:
165
+ self.device = device
166
+
167
+ logger.info(f"Loading HuggingFace model: {model_name} on {self.device}")
168
+ try:
169
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
170
+ self.model = AutoModel.from_pretrained(model_name)
171
+ self.model.to(self.device)
172
+ self.model.eval()
173
+
174
+ # Get embedding dimension from model config
175
+ self._dimension = self.model.config.hidden_size
176
+ logger.info(f"Model loaded successfully. Embedding dimension: {self._dimension}")
177
+ except Exception as e:
178
+ logger.error(f"Failed to load model: {e}")
179
+ raise
180
+
181
+ def _mean_pooling(self, model_output, attention_mask):
182
+ """Perform mean pooling on token embeddings."""
183
+ token_embeddings = model_output.last_hidden_state
184
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
185
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
186
+
187
+ def embed(self, texts: Union[str, List[str]], batch_size: int = 32) -> np.ndarray:
188
+ """
189
+ Convert text(s) to embedding vector(s).
190
+
191
+ Args:
192
+ texts: Input text(s) to embed
193
+ batch_size: Batch size for processing
194
+
195
+ Returns:
196
+ Embedding vector(s) as numpy array
197
+ """
198
+ # Handle single text input
199
+ if isinstance(texts, str):
200
+ texts = [texts]
201
+
202
+ # Validate input
203
+ if not texts:
204
+ logger.warning("Empty texts provided for embedding")
205
+ return np.array([])
206
+
207
+ try:
208
+ all_embeddings = []
209
+
210
+ # Process in batches
211
+ for i in range(0, len(texts), batch_size):
212
+ batch_texts = texts[i:i+batch_size]
213
+
214
+ # Tokenize and move to device
215
+ inputs = self.tokenizer(
216
+ batch_texts,
217
+ padding=True,
218
+ truncation=True,
219
+ max_length=self.max_length,
220
+ return_tensors="pt"
221
+ )
222
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
223
+
224
+ # Generate embeddings
225
+ with torch.no_grad():
226
+ outputs = self.model(**inputs)
227
+ embeddings = self._mean_pooling(outputs, inputs["attention_mask"])
228
+
229
+ # Normalize if requested
230
+ if self.normalize:
231
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
232
+
233
+ # Move to CPU and convert to numpy
234
+ embeddings = embeddings.cpu().numpy()
235
+ all_embeddings.append(embeddings)
236
+
237
+ # Concatenate all batches
238
+ return np.vstack(all_embeddings) if all_embeddings else np.array([])
239
+
240
+ except Exception as e:
241
+ logger.error(f"Error during embedding generation: {e}")
242
+ raise
243
+
244
+ @property
245
+ def dimension(self) -> int:
246
+ """Get the dimension of the embedding vectors."""
247
+ return self._dimension
248
+
249
+
250
+ # Factory function to create embedding models
251
+ def create_embedding_model(
252
+ backend: str = "sentence-transformers",
253
+ model_name: Optional[str] = None,
254
+ **kwargs
255
+ ) -> EmbeddingModel:
256
+ """
257
+ Factory function to create an embedding model.
258
+
259
+ Args:
260
+ backend: Backend to use ('sentence-transformers' or 'huggingface')
261
+ model_name: Model name or path
262
+ **kwargs: Additional arguments for the model
263
+
264
+ Returns:
265
+ An EmbeddingModel instance
266
+ """
267
+ from ..config import EMBEDDING_MODEL_NAME, get_model_config
268
+
269
+ # Use config model if not specified
270
+ if model_name is None:
271
+ model_name = EMBEDDING_MODEL_NAME
272
+
273
+ # Get model-specific config
274
+ model_config = get_model_config(model_name)
275
+
276
+ # Override with provided kwargs
277
+ for k, v in kwargs.items():
278
+ model_config[k] = v
279
+
280
+ # Create the model
281
+ if backend.lower() == "sentence-transformers":
282
+ return SentenceTransformerEmbedding(model_name=model_name, **model_config)
283
+ elif backend.lower() in ["huggingface", "hf", "transformers"]:
284
+ return HuggingFaceEmbedding(model_name=model_name, **model_config)
285
+ else:
286
+ raise ValueError(f"Unsupported backend: {backend}")
package-init.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE(review): this file bundles the intended contents of several package
# __init__.py files into one listing. As a single module it cannot be
# imported (each relative import below assumes a different parent package);
# split each section into the package named in its heading before use.

# embedding/__init__.py
from .model import create_embedding_model, EmbeddingModel, SentenceTransformerEmbedding, HuggingFaceEmbedding

# storage/__init__.py
from .vector_db import Document, VectorDatabase, FaissVectorDatabase, create_vector_database

# document/__init__.py
from .processor import DocumentProcessor

# retrieval/__init__.py
# Import relevant classes if needed

# rag/__init__.py
from .engine import RAGEngine, create_rag_engine

# api/__init__.py
from .routes import RAGAPIRouter

# ui/__init__.py
# No exports needed

# utils/__init__.py
# Import utility functions if needed
rag-engine.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main RAG (Retrieval-Augmented Generation) engine implementation.
3
+ """
4
+
5
+ import os
6
+ import logging
7
+ from typing import List, Dict, Any, Optional, Tuple, Union
8
+ import numpy as np
9
+
10
+ # Configure logging
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class RAGEngine:
15
+ """Retrieval-Augmented Generation (RAG) engine for question answering."""
16
+
17
+ def __init__(
18
+ self,
19
+ embedder,
20
+ vector_db,
21
+ llm=None,
22
+ top_k: int = 5,
23
+ search_type: str = "hybrid",
24
+ prompt_template: Optional[str] = None
25
+ ):
26
+ """
27
+ Initialize the RAG engine.
28
+
29
+ Args:
30
+ embedder: Embedding model
31
+ vector_db: Vector database for document storage and retrieval
32
+ llm: Language model for text generation (optional)
33
+ top_k: Number of documents to retrieve
34
+ search_type: Type of search ('semantic', 'keyword', 'hybrid')
35
+ prompt_template: Optional custom prompt template
36
+ """
37
+ self.embedder = embedder
38
+ self.vector_db = vector_db
39
+ self.llm = llm
40
+ self.top_k = top_k
41
+ self.search_type = search_type
42
+
43
+ # Set default prompt template if none provided
44
+ if prompt_template is None:
45
+ from ..config import DEFAULT_PROMPT_TEMPLATE
46
+ self.prompt_template = DEFAULT_PROMPT_TEMPLATE
47
+ else:
48
+ self.prompt_template = prompt_template
49
+
50
+ def add_documents(
51
+ self,
52
+ texts: List[str],
53
+ metadata: Optional[List[Dict[str, Any]]] = None,
54
+ batch_size: int = 32
55
+ ) -> List[str]:
56
+ """
57
+ Add documents to the database.
58
+
59
+ Args:
60
+ texts: List of text chunks
61
+ metadata: Optional list of metadata dictionaries for each text
62
+ batch_size: Batch size for embedding generation
63
+
64
+ Returns:
65
+ List of document IDs
66
+ """
67
+ from ..storage.vector_db import Document
68
+
69
+ # Handle metadata
70
+ if metadata is None:
71
+ metadata = [{} for _ in texts]
72
+ elif len(metadata) != len(texts):
73
+ raise ValueError(f"Length mismatch: got {len(texts)} texts but {len(metadata)} metadata entries")
74
+
75
+ # Generate embeddings in batches
76
+ doc_ids = []
77
+
78
+ for i in range(0, len(texts), batch_size):
79
+ batch_texts = texts[i:i+batch_size]
80
+ batch_metadata = metadata[i:i+batch_size]
81
+
82
+ # Generate embeddings
83
+ logger.info(f"Generating embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
84
+ batch_embeddings = self.embedder.embed(batch_texts)
85
+
86
+ # Create document objects
87
+ documents = []
88
+ for text, meta, embedding in zip(batch_texts, batch_metadata, batch_embeddings):
89
+ doc = Document(text=text, metadata=meta, embedding=embedding)
90
+ documents.append(doc)
91
+
92
+ # Add to database
93
+ batch_ids = self.vector_db.add_documents(documents)
94
+ doc_ids.extend(batch_ids)
95
+
96
+ logger.info(f"Added {len(doc_ids)} documents to database")
97
+ return doc_ids
98
+
99
+ def search(
100
+ self,
101
+ query: str,
102
+ top_k: Optional[int] = None,
103
+ search_type: Optional[str] = None,
104
+ filter_dict: Optional[Dict[str, Any]] = None
105
+ ) -> List[Dict[str, Any]]:
106
+ """
107
+ Search for relevant documents.
108
+
109
+ Args:
110
+ query: Query string
111
+ top_k: Number of results to return (defaults to self.top_k)
112
+ search_type: Type of search (defaults to self.search_type)
113
+ filter_dict: Dictionary of metadata filters
114
+
115
+ Returns:
116
+ List of document dictionaries
117
+ """
118
+ if top_k is None:
119
+ top_k = self.top_k
120
+
121
+ if search_type is None:
122
+ search_type = self.search_type
123
+
124
+ # Create filter function if filter_dict is provided
125
+ filter_func = None
126
+ if filter_dict:
127
+ def filter_func(doc):
128
+ for key, value in filter_dict.items():
129
+ # Handle nested keys (e.g., "metadata.source")
130
+ if "." in key:
131
+ parts = key.split(".")
132
+ current = doc.metadata
133
+ for part in parts[:-1]:
134
+ if part not in current:
135
+ return False
136
+ current = current[part]
137
+ if parts[-1] not in current or current[parts[-1]] != value:
138
+ return False
139
+ elif key not in doc.metadata or doc.metadata[key] != value:
140
+ return False
141
+ return True
142
+
143
+ # Generate query embedding
144
+ query_embedding = self.embedder.embed(query)
145
+
146
+ # Perform search
147
+ results = self.vector_db.search(query_embedding, top_k, filter_func)
148
+
149
+ # Convert results to dictionaries
150
+ return [
151
+ {
152
+ "id": doc.id,
153
+ "text": doc.text,
154
+ "metadata": doc.metadata,
155
+ "score": score
156
+ }
157
+ for doc, score in results
158
+ ]
159
+
160
+ def generate_response(
161
+ self,
162
+ query: str,
163
+ top_k: Optional[int] = None,
164
+ search_type: Optional[str] = None,
165
+ filter_dict: Optional[Dict[str, Any]] = None,
166
+ max_tokens: int = 512
167
+ ) -> Dict[str, Any]:
168
+ """
169
+ Generate a response to a query using RAG.
170
+
171
+ Args:
172
+ query: Query string
173
+ top_k: Number of documents to retrieve
174
+ search_type: Type of search
175
+ filter_dict: Optional filter for document retrieval
176
+ max_tokens: Maximum number of tokens in the response
177
+
178
+ Returns:
179
+ Dictionary with query, response, and retrieved documents
180
+ """
181
+ # Retrieve relevant documents
182
+ retrieved_docs = self.search(query, top_k, search_type, filter_dict)
183
+
184
+ # If no documents were found, return a default message
185
+ if not retrieved_docs:
186
+ return {
187
+ "query": query,
188
+ "response": "I couldn't find any relevant information to answer your question.",
189
+ "retrieved_documents": [],
190
+ "search_type": search_type or self.search_type
191
+ }
192
+
193
+ # Format context from retrieved documents
194
+ context = self._format_context(retrieved_docs)
195
+
196
+ # Format prompt with context and query
197
+ prompt = self.prompt_template.format(context=context, query=query)
198
+
199
+ # Generate response using LLM
200
+ if self.llm is None:
201
+ logger.warning("No LLM provided, returning only retrieved documents")
202
+ response = "No language model available to generate a response. Here's what I found in the documents."
203
+ else:
204
+ response = self._generate_llm_response(prompt, max_tokens)
205
+
206
+ # Return the results
207
+ return {
208
+ "query": query,
209
+ "response": response,
210
+ "retrieved_documents": retrieved_docs,
211
+ "search_type": search_type or self.search_type
212
+ }
213
+
214
+ def _format_context(self, documents: List[Dict[str, Any]]) -> str:
215
+ """
216
+ Format retrieved documents into context for the prompt.
217
+
218
+ Args:
219
+ documents: List of retrieved documents
220
+
221
+ Returns:
222
+ Formatted context string
223
+ """
224
+ context_parts = []
225
+
226
+ for i, doc in enumerate(documents):
227
+ # Extract relevant fields
228
+ text = doc["text"]
229
+ metadata = doc["metadata"]
230
+ source = metadata.get("source", "Unknown")
231
+
232
+ # Format the document
233
+ doc_text = f"Document {i+1}: [Source: {source}]\n{text}\n"
234
+ context_parts.append(doc_text)
235
+
236
+ return "\n".join(context_parts)
237
+
238
+ def _generate_llm_response(self, prompt: str, max_tokens: int) -> str:
239
+ """
240
+ Generate a response using the LLM.
241
+
242
+ Args:
243
+ prompt: The formatted prompt
244
+ max_tokens: Maximum number of tokens in the response
245
+
246
+ Returns:
247
+ Generated response
248
+ """
249
+ if hasattr(self.llm, "generate_openai_response"):
250
+ # OpenAI-compatible LLM
251
+ return self.llm.generate_openai_response(prompt, max_tokens)
252
+ elif hasattr(self.llm, "generate_huggingface_response"):
253
+ # HuggingFace-compatible LLM
254
+ return self.llm.generate_huggingface_response(prompt, max_tokens)
255
+ else:
256
+ # Default implementation
257
+ try:
258
+ return self.llm.generate_response(prompt, max_tokens)
259
+ except Exception as e:
260
+ logger.error(f"Error generating response: {e}")
261
+ return "I encountered an error while generating a response."
262
+
263
def update_prompt_template(self, new_template: str) -> None:
    """Replace the active prompt template used for response generation.

    The template should contain ``{context}`` and ``{query}`` placeholders,
    since ``generate_response`` fills both via ``str.format``.

    Args:
        new_template: The template string to use from now on.
    """
    self.prompt_template = new_template
    logger.info("Updated prompt template")
272
+
273
def count_documents(self) -> int:
    """Return how many documents the underlying vector store holds.

    Thin delegation to the vector database's own counter.

    Returns:
        The number of stored documents.
    """
    return self.vector_db.count_documents()
281
+
282
def clear_documents(self) -> None:
    """Delete every document from the underlying vector store and log it."""
    self.vector_db.clear()
    logger.info("Cleared all documents from database")
286
+
287
+
288
+ # Factory function to create the RAG engine
289
def create_rag_engine(
    embedder=None,
    vector_db=None,
    llm=None,
    config=None
) -> RAGEngine:
    """
    Factory function to create a RAG engine.

    Args:
        embedder: Embedding model (if None, a default model is created)
        vector_db: Vector database (if None, one is created whose dimension
            matches the embedder, keeping store and embeddings compatible)
        llm: Language model (if None, created from the optional LLM module;
            the engine degrades to retrieval-only mode when unavailable)
        config: Configuration module or dictionary. When None, values are
            loaded from the package-level config module.

    Returns:
        Configured RAGEngine instance
    """
    # Default prompt used only when a provided config omits
    # DEFAULT_PROMPT_TEMPLATE.
    fallback_template = """
Answer the following question based ONLY on the provided context.

Context:
{context}

Question: {query}

Answer:
"""

    if config is None:
        # No explicit config: read from the package-level config module.
        from ..config import (
            TOP_K,
            SEARCH_TYPE,
            DEFAULT_PROMPT_TEMPLATE
        )
        top_k = TOP_K
        search_type = SEARCH_TYPE
        prompt_template = DEFAULT_PROMPT_TEMPLATE
    else:
        # The docstring promises "module or dictionary", but the previous
        # implementation only supported dicts (it called config.get and
        # raised AttributeError for modules). Support both: mapping-style
        # lookup when available, attribute lookup otherwise.
        def _cfg(name, default):
            """Read one setting from a dict-like or module-like config."""
            if hasattr(config, "get"):
                return config.get(name, default)
            return getattr(config, name, default)

        top_k = _cfg("TOP_K", 5)
        search_type = _cfg("SEARCH_TYPE", "hybrid")
        prompt_template = _cfg("DEFAULT_PROMPT_TEMPLATE", fallback_template)

    # Create embedding model if not provided
    if embedder is None:
        from ..embedding.model import create_embedding_model
        embedder = create_embedding_model()

    # Create vector database if not provided, sized to the embedder
    if vector_db is None:
        from ..storage.vector_db import create_vector_database
        vector_db = create_vector_database(dimension=embedder.dimension)

    # Create language model if not provided; the LLM module is optional
    if llm is None:
        try:
            from ..llm.model import create_llm
            llm = create_llm()
        except (ImportError, ModuleNotFoundError):
            logger.warning("LLM module not found, proceeding without an LLM")

    # Create and return the RAG engine
    return RAGEngine(
        embedder=embedder,
        vector_db=vector_db,
        llm=llm,
        top_k=top_k,
        search_type=search_type,
        prompt_template=prompt_template
    )
readme.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RAG System
2
+
3
+ A modular Retrieval-Augmented Generation (RAG) system for document-based question answering.
4
+
5
+ ## Features
6
+
7
+ - **Document Processing**: Extract and chunk text from PDF, DOCX, and TXT files
8
+ - **Semantic Search**: Embed and search documents based on meaning, not just keywords
9
+ - **Flexible Architecture**: Support for multiple embedding models and vector databases
10
+ - **REST API**: API for integrating with other applications
11
+ - **Web UI**: User-friendly Streamlit interface for document upload and querying
12
+
13
+ ## Architecture
14
+
15
+ The system consists of the following components:
16
+
17
+ - **Embedding Model**: Converts text to vector embeddings
18
+ - **Vector Database**: Stores and searches document embeddings
19
+ - **Document Processor**: Extracts and chunks text from documents
20
+ - **RAG Engine**: Combines retrieval and generation for question answering
21
+ - **API**: Exposes functionality through a RESTful API
22
+ - **UI**: Provides a user interface for interacting with the system
23
+
24
+ ## Installation
25
+
26
+ ### Prerequisites
27
+
28
+ - Python 3.8+
29
+ - pip
30
+
31
+ ### Setup
32
+
33
+ 1. Clone the repository:
34
+ ```bash
35
+ git clone https://github.com/yourusername/rag-system.git
36
+ cd rag-system
37
+ ```
38
+
39
+ 2. Install dependencies:
40
+ ```bash
41
+ pip install -r requirements.txt
42
+ ```
43
+
44
+ 3. Set up environment variables (optional):
45
+ ```bash
46
+ cp .env.example .env
47
+ # Edit .env with your settings
48
+ ```
49
+
50
+ ## Usage
51
+
52
+ ### API Server
53
+
54
+ Run the API server:
55
+
56
+ ```bash
57
+ python app.py
58
+ ```
59
+
60
+ The API will be available at http://localhost:8000
61
+
62
+ ### Streamlit UI
63
+
64
+ Run the Streamlit UI:
65
+
66
+ ```bash
67
+ streamlit run ui/app.py
68
+ ```
69
+
70
+ The UI will be available at http://localhost:8501
71
+
72
+ ## API Endpoints
73
+
74
+ - `POST /documents`: Add documents
75
+ - `POST /upload`: Upload and process document files
76
+ - `POST /query`: Query the RAG system
77
+ - `GET /search`: Search for documents
78
+ - `DELETE /documents`: Clear all documents
79
+ - `GET /health`: Check system health
80
+
81
+ ## Configuration
82
+
83
+ The system can be configured through environment variables or the `config.py` file:
84
+
85
+ - `EMBEDDING_MODEL_NAME`: Name of the embedding model
86
+ - `VECTOR_DB_TYPE`: Type of vector database to use
87
+ - `CHUNK_SIZE`: Size of document chunks
88
+ - `CHUNK_OVERLAP`: Overlap between chunks
89
+ - `TOP_K`: Number of documents to retrieve
90
+ - `SEARCH_TYPE`: Type of search (semantic, keyword, hybrid)
91
+ - `LLM_MODEL_NAME`: Name of the language model for generation
92
+ - `LLM_API_KEY`: API key for the language model
93
+
94
+ ## Extending
95
+
96
+ The modular architecture makes it easy to extend the system:
97
+
98
+ - Add new embedding models in `embedding/model.py`
99
+ - Add new vector databases in `storage/vector_db.py`
100
+ - Add support for new document types in `document/processor.py`
101
+ - Add new LLM integrations in `llm/model.py`
102
+
103
+ ## License
104
+
105
+ [MIT License](LICENSE)
streamlit-app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Streamlit UI for the RAG system.
3
+ """
4
+
5
+ import os
6
+ import streamlit as st
7
+ import tempfile
8
+ import logging
9
+ from dotenv import load_dotenv
10
+
11
+ # Load environment variables
12
+ load_dotenv()
13
+
14
+ # Configure logging
15
+ from config import get_logging_config
16
+ import logging.config
17
+ logging.config.dictConfig(get_logging_config())
18
+ logger = logging.getLogger(__name__)
19
+
20
# Set page config; Streamlit requires this to be the first st.* call in the
# script, so it must stay ahead of any widget/session-state access.
st.set_page_config(
    page_title="RAG Document QA System",
    page_icon="📚",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Initialize session state so values survive Streamlit's rerun-on-interaction
# model: document_count backs the sidebar metric, initialized flags that the
# cached engine has been built at least once.
if "document_count" not in st.session_state:
    st.session_state.document_count = 0
if "initialized" not in st.session_state:
    st.session_state.initialized = False
33
+
34
+ # Initialize RAG engine
35
@st.cache_resource
def initialize_rag_engine():
    """Build and cache the RAG engine (embedder + vector DB + engine).

    Decorated with ``st.cache_resource`` so the heavyweight components are
    constructed once per server process and shared across reruns. Imports
    are kept local so the modules load only when the engine is first built.
    """
    from embedding.model import create_embedding_model
    from storage.vector_db import create_vector_database
    from rag.engine import create_rag_engine

    # Assemble the pipeline; the vector DB dimension must match the embedder.
    embedder = create_embedding_model()
    database = create_vector_database(dimension=embedder.dimension)
    engine = create_rag_engine(
        embedder=embedder,
        vector_db=database
    )

    st.session_state.initialized = True
    return engine
52
+
53
+ # Initialize document processor
54
@st.cache_resource
def initialize_document_processor():
    """Create and cache a single DocumentProcessor for the whole app."""
    from document.processor import DocumentProcessor

    processor = DocumentProcessor()
    return processor
59
+
60
+ # Main application
61
def main() -> None:
    """Main Streamlit application.

    Renders the sidebar (upload, chunking/search settings, store controls)
    and the main panel (sample-text onboarding when the store is empty,
    question answering otherwise). Runs top-to-bottom on every Streamlit
    rerun.
    """
    # Initialize components (both are @st.cache_resource, so cheap on rerun)
    rag_engine = initialize_rag_engine()
    doc_processor = initialize_document_processor()

    # Refresh the displayed count from the store on every rerun
    st.session_state.document_count = rag_engine.count_documents()

    # Sidebar
    st.sidebar.title("📚 RAG Document QA")

    # Document upload
    st.sidebar.header("Upload Documents")
    uploaded_file = st.sidebar.file_uploader(
        "Choose a document file (PDF, TXT, DOCX)",
        type=["pdf", "txt", "md", "docx"]
    )

    # Upload settings (applied to the processor just before chunking below)
    st.sidebar.subheader("Document Settings")
    chunk_size = st.sidebar.slider(
        "Chunk Size",
        min_value=100,
        max_value=2000,
        value=1000,
        step=100,
        help="Size of text chunks in characters"
    )
    chunk_overlap = st.sidebar.slider(
        "Chunk Overlap",
        min_value=0,
        max_value=500,
        value=200,
        step=50,
        help="Overlap between chunks in characters"
    )

    # Search settings, forwarded to rag_engine.generate_response below
    st.sidebar.header("Search Settings")
    top_k = st.sidebar.slider(
        "Results to Return",
        min_value=1,
        max_value=10,
        value=3,
        help="Number of document chunks to retrieve"
    )
    search_type = st.sidebar.selectbox(
        "Search Type",
        options=["hybrid", "semantic", "keyword"],
        index=0,
        help="Type of search to perform"
    )

    # Document info
    st.sidebar.header("Document Store")
    st.sidebar.metric("Documents Stored", st.session_state.document_count)

    if st.sidebar.button("Clear All Documents"):
        rag_engine.clear_documents()
        st.session_state.document_count = 0
        st.sidebar.success("Document store cleared!")
        # NOTE(review): st.experimental_rerun is deprecated in newer
        # Streamlit releases (replaced by st.rerun) — confirm the pinned
        # Streamlit version before changing.
        st.experimental_rerun()

    # Process uploaded file
    if uploaded_file is not None:
        with st.sidebar.expander("Upload Status", expanded=True):
            with st.spinner('Processing document...'):
                # Save to a temp file because the processor reads from a
                # path; keep the original suffix so type detection works.
                with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
                    tmp_file.write(uploaded_file.getvalue())
                    tmp_file_path = tmp_file.name

                try:
                    # Apply the sidebar chunking settings before processing
                    doc_processor.chunk_size = chunk_size
                    doc_processor.chunk_overlap = chunk_overlap

                    chunks, chunk_metadata = doc_processor.process_file(
                        tmp_file_path,
                        metadata={"filename": uploaded_file.name, "source": "UI upload"}
                    )

                    if not chunks:
                        st.sidebar.error("No text could be extracted from the document.")
                    else:
                        # Add chunks to RAG engine
                        doc_ids = rag_engine.add_documents(chunks, chunk_metadata)

                        # Update document count
                        st.session_state.document_count = rag_engine.count_documents()

                        st.sidebar.success(f"Added {len(chunks)} document chunks!")
                except Exception as e:
                    st.sidebar.error(f"Error processing document: {str(e)}")
                finally:
                    # Always remove the temp file, even on failure
                    os.unlink(tmp_file_path)

    # Main content
    st.title("📚 Document Query System")

    if st.session_state.document_count == 0:
        st.info("👈 Please upload documents using the sidebar to get started.")

        # Onboarding path: let the user paste sample text directly
        st.subheader("Sample Text")
        sample_text = st.text_area(
            "Or try adding some sample text directly:",
            height=200
        )

        if sample_text and st.button("Add Sample Text"):
            with st.spinner('Processing text...'):
                # NOTE(review): this calls the processor's private
                # _chunk_text helper — consider exposing a public
                # chunking API on DocumentProcessor instead.
                chunks = doc_processor._chunk_text(sample_text, chunk_size, chunk_overlap)

                # Create per-chunk metadata mirroring the upload path
                chunk_metadata = [
                    {"source": "Sample text", "chunk_id": i, "total_chunks": len(chunks)}
                    for i in range(len(chunks))
                ]

                # Add to RAG engine
                doc_ids = rag_engine.add_documents(chunks, chunk_metadata)

                # Update document count
                st.session_state.document_count = rag_engine.count_documents()

                st.success(f"Added {len(chunks)} text chunks!")
                # NOTE(review): deprecated — see comment above.
                st.experimental_rerun()
    else:
        # Question answering
        st.subheader("Ask a Question")
        question = st.text_input("Enter your question:")

        if question:
            with st.spinner('Searching for answer...'):
                try:
                    # Retrieve + generate via the RAG engine with the
                    # sidebar's search settings
                    result = rag_engine.generate_response(
                        query=question,
                        top_k=top_k,
                        search_type=search_type
                    )

                    # Display response
                    st.markdown("### Answer")
                    st.write(result["response"])

                    # Display the retrieved chunks that backed the answer
                    st.markdown("### Sources")
                    for i, doc in enumerate(result["retrieved_documents"]):
                        with st.expander(f"Source {i+1} (Score: {doc['score']:.2f})"):
                            st.markdown(f"**Source:** {doc['metadata'].get('source', 'Unknown')}")
                            st.text(doc["text"])
                except Exception as e:
                    st.error(f"Error generating response: {str(e)}")

    # About section
    st.sidebar.markdown("---")
    st.sidebar.info(
        "This application allows you to upload documents and ask questions about their content. "
        "The system uses embedding models for semantic search and retrieval."
    )
226
+
227
# Run the application only when executed as a script
# (e.g. `streamlit run ui/app.py`), not when imported.
if __name__ == "__main__":
    main()
vector-db.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vector database implementation for document storage and retrieval.
3
+ """
4
+
5
+ from typing import List, Dict, Any, Optional, Union, Tuple, Callable
6
+ import logging
7
+ import os
8
+ import json
9
+ import uuid
10
+ import numpy as np
11
+ from dataclasses import dataclass, field, asdict
12
+
13
+ # Configure logging
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
@dataclass
class Document:
    """A document or text chunk with metadata and an optional embedding.

    Fields:
        text: raw text content of the chunk
        metadata: arbitrary metadata (e.g. source file, chunk index)
        embedding: optional float32 vector used for similarity search
        id: unique identifier, auto-generated as a UUID4 string
    """
    text: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    embedding: Optional[np.ndarray] = None
    id: str = field(default_factory=lambda: str(uuid.uuid4()))

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary.

        Returns:
            Dict of all fields, with the embedding (if any) converted from
            a numpy array to a plain list.
        """
        result = asdict(self)
        # numpy arrays are not JSON-serializable; store as a list instead
        if self.embedding is not None:
            result['embedding'] = self.embedding.tolist()
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Document':
        """Create a Document from a dictionary (inverse of ``to_dict``).

        Args:
            data: Field values; ``embedding`` may be a plain list of floats.

        Returns:
            A new Document instance.
        """
        # Work on a shallow copy: the previous implementation rewrote
        # data['embedding'] in place, mutating the caller's dict.
        data = dict(data)
        if 'embedding' in data and data['embedding'] is not None:
            data['embedding'] = np.array(data['embedding'], dtype=np.float32)
        return cls(**data)
39
+
40
+
41
class VectorDatabase:
    """Base class for vector databases.

    NOTE(review): this class mixes an abstract interface (``add_document``
    and ``search`` raise NotImplementedError) with concrete FAISS-backed
    methods (``delete_document``, ``_rebuild_index``, ``get_document``,
    ``count_documents``, ``clear``, ``save``, ``load``) that reference
    attributes (``self.documents``, ``self.id_to_index``, ``self.index``,
    ``self.needs_training``) only subclasses define. FaissVectorDatabase
    inherits and relies on these concrete methods, so they are kept here —
    but the layout looks like a file-reassembly accident; confirm the
    intended class split.
    """

    def __init__(self, dimension: int = 384):
        """
        Initialize the vector database.

        Args:
            dimension: Dimension of the embedding vectors
        """
        self.dimension = dimension

    def add_document(self, document: Document) -> str:
        """
        Add a document to the database.

        Args:
            document: Document to add

        Returns:
            Document ID
        """
        raise NotImplementedError("Subclasses must implement add_document")

    def add_documents(self, documents: List[Document]) -> List[str]:
        """
        Add multiple documents to the database (default: one at a time).

        Args:
            documents: List of documents to add

        Returns:
            List of document IDs
        """
        return [self.add_document(doc) for doc in documents]

    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 5,
        filter_func: Optional[Callable[[Document], bool]] = None
    ) -> List[Tuple[Document, float]]:
        """
        Search for similar documents.

        Args:
            query_embedding: Query embedding vector
            top_k: Number of results to return
            filter_func: Optional function to filter results

        Returns:
            List of (document, score) tuples
        """
        raise NotImplementedError("Subclasses must implement search")

    def delete_document(self, doc_id: str) -> bool:
        """
        Delete a document from the database.

        Note: FAISS doesn't support direct deletion, so we handle this
        by rebuilding the index when needed.

        Args:
            doc_id: Document ID to delete

        Returns:
            True if document was found and deleted
        """
        if doc_id not in self.documents:
            return False

        # Remove from documents dictionary
        del self.documents[doc_id]

        # Drop the id->row mapping; the stale FAISS row itself is removed
        # lazily by the next _rebuild_index call (triggered from search()
        # when the mapping count and index.ntotal disagree).
        if doc_id in self.id_to_index:
            del self.id_to_index[doc_id]

        return True

    def _rebuild_index(self):
        """Rebuild the FAISS index from scratch out of self.documents.

        Re-trains IVF-style indexes when at least 100 embedded documents
        are available, then re-adds every embedding and regenerates both
        id<->row mapping tables.
        """
        # Re-initialize the index
        self._initialize_index()
        self.id_to_index = {}
        self.index_to_id = {}

        # Collect all documents with embeddings
        docs_with_embeddings = [doc for doc in self.documents.values() if doc.embedding is not None]

        if not docs_with_embeddings:
            logger.warning("No documents with embeddings to rebuild index")
            return

        # Extract embeddings
        embeddings = np.array([doc.embedding for doc in docs_with_embeddings], dtype=np.float32)

        # Train if needed
        if self.needs_training and len(docs_with_embeddings) >= 100:
            logger.info("Training FAISS index during rebuild")
            train_data = embeddings[:min(1000, len(embeddings))]
            self.index.train(train_data)

        # Add to index if trained or doesn't need training.
        # (The original condition repeated "not self.needs_training" twice;
        # the duplicate clause was redundant and has been removed — the
        # truth table is unchanged.)
        if not self.needs_training or self.index.is_trained:
            self.index.add(embeddings)

            # Update mappings
            for i, doc in enumerate(docs_with_embeddings):
                self.id_to_index[doc.id] = i
                self.index_to_id[i] = doc.id

    def get_document(self, doc_id: str) -> Optional[Document]:
        """
        Get a document by ID.

        Args:
            doc_id: Document ID to get

        Returns:
            Document if found, None otherwise
        """
        return self.documents.get(doc_id)

    def count_documents(self) -> int:
        """
        Get the number of documents in the database.

        Returns:
            Number of documents
        """
        return len(self.documents)

    def clear(self) -> None:
        """Clear all documents, mappings, and the FAISS index."""
        self.documents = {}
        self.id_to_index = {}
        self.index_to_id = {}
        self._initialize_index()

    def save(self, directory: str) -> None:
        """
        Save the database to disk as four files: documents.json,
        mappings.json, faiss_index.bin, and metadata.json.

        Args:
            directory: Directory to save to (created if missing)
        """
        import faiss

        os.makedirs(directory, exist_ok=True)

        # Save documents
        documents_data = {doc_id: doc.to_dict() for doc_id, doc in self.documents.items()}
        with open(os.path.join(directory, "documents.json"), "w") as f:
            json.dump(documents_data, f)

        # Save mappings
        mappings = {
            "id_to_index": self.id_to_index,
            "index_to_id": {str(k): v for k, v in self.index_to_id.items()}  # Convert int keys to strings for JSON
        }
        with open(os.path.join(directory, "mappings.json"), "w") as f:
            json.dump(mappings, f)

        # Save index
        faiss.write_index(self.index, os.path.join(directory, "faiss_index.bin"))

        # Save metadata
        metadata = {
            "dimension": self.dimension,
            "index_type": self.index_type,
            "document_count": len(self.documents)
        }
        with open(os.path.join(directory, "metadata.json"), "w") as f:
            json.dump(metadata, f)

    @classmethod
    def load(cls, directory: str) -> 'FaissVectorDatabase':
        """
        Load a database from disk (inverse of ``save``).

        Args:
            directory: Directory to load from

        Returns:
            Loaded FaissVectorDatabase
        """
        import faiss

        # Load metadata
        with open(os.path.join(directory, "metadata.json"), "r") as f:
            metadata = json.load(f)

        # Create instance
        db = cls(dimension=metadata["dimension"], index_type=metadata["index_type"])

        # Load documents
        with open(os.path.join(directory, "documents.json"), "r") as f:
            documents_data = json.load(f)

        db.documents = {doc_id: Document.from_dict(doc_data) for doc_id, doc_data in documents_data.items()}

        # Load mappings
        with open(os.path.join(directory, "mappings.json"), "r") as f:
            mappings = json.load(f)

        db.id_to_index = mappings["id_to_index"]
        db.index_to_id = {int(k): v for k, v in mappings["index_to_id"].items()}  # Convert string keys back to int

        # Load index
        db.index = faiss.read_index(os.path.join(directory, "faiss_index.bin"))

        return db
256
+
257
+
258
+ # Factory function to create vector databases
259
def create_vector_database(
    db_type: str = "faiss",
    dimension: int = 384,
    **kwargs
) -> VectorDatabase:
    """
    Factory for vector database backends.

    Args:
        db_type: Backend identifier (case-insensitive); only "faiss" is
            currently supported.
        dimension: Dimension of the embedding vectors.
        **kwargs: Extra keyword arguments forwarded to the backend
            constructor (e.g. index_type).

    Returns:
        A VectorDatabase instance.

    Raises:
        ValueError: If db_type names an unsupported backend.
    """
    if db_type.lower() != "faiss":
        raise ValueError(f"Unsupported database type: {db_type}")
    return FaissVectorDatabase(dimension=dimension, **kwargs)
278
+ raise ValueError(f"Unsupported database type: {db_type}")
279
+ Args:
280
+ doc_id: Document ID to delete
281
+
282
+ Returns:
283
+ True if document was deleted, False otherwise
284
+ """
285
+ raise NotImplementedError("Subclasses must implement delete_document")
286
+
287
# NOTE(review): this abstract stub duplicates the concrete get_document
# defined earlier in the file — the file appears to have been reassembled
# out of order; confirm the intended class layout.
def get_document(self, doc_id: str) -> Optional[Document]:
    """
    Get a document by ID. Interface stub: backends must override.

    Args:
        doc_id: Document ID to get

    Returns:
        Document if found, None otherwise
    """
    raise NotImplementedError("Subclasses must implement get_document")
298
+
299
# NOTE(review): duplicate abstract stub — see layout note on get_document.
def count_documents(self) -> int:
    """
    Get the number of documents in the database. Interface stub.

    Returns:
        Number of documents
    """
    raise NotImplementedError("Subclasses must implement count_documents")
307
+
308
# NOTE(review): duplicate abstract stub — see layout note on get_document.
def clear(self) -> None:
    """Clear all documents from the database. Interface stub."""
    raise NotImplementedError("Subclasses must implement clear")
311
+
312
# NOTE(review): duplicate abstract stub — see layout note on get_document.
def save(self, directory: str) -> None:
    """
    Save the database to disk. Interface stub.

    Args:
        directory: Directory to save to
    """
    raise NotImplementedError("Subclasses must implement save")
320
+
321
# NOTE(review): duplicate abstract stub — see layout note on get_document.
@classmethod
def load(cls, directory: str) -> 'VectorDatabase':
    """
    Load a database from disk. Interface stub.

    Args:
        directory: Directory to load from

    Returns:
        Loaded database
    """
    raise NotImplementedError("Subclasses must implement load")
333
+
334
+
335
+ class FaissVectorDatabase(VectorDatabase):
336
+ """Vector database implementation using FAISS."""
337
+
338
def __init__(self, dimension: int = 384, index_type: str = "Flat"):
    """
    Initialize the FAISS vector database.

    Args:
        dimension: Dimension of the embedding vectors
        index_type: FAISS index type ("Flat", "IVF", or "HNSW"; unknown
            values fall back to "Flat" inside _initialize_index)
    """
    super().__init__(dimension)
    self.index_type = index_type
    # Primary document store plus bidirectional id <-> FAISS-row mappings
    self.documents: Dict[str, Document] = {}
    self.id_to_index: Dict[str, int] = {}
    self.index_to_id: Dict[int, str] = {}

    # Initialize FAISS index
    self._initialize_index()
354
+
355
def _initialize_index(self):
    """Create the FAISS index matching ``self.index_type``.

    Unrecognized types fall back to a flat L2 index with a warning. Also
    sets ``self.needs_training`` so callers know whether the index must be
    trained before vectors can be added (true only for IVF).

    Raises:
        ImportError: If the faiss package is not installed.
    """
    try:
        import faiss
    except ImportError:
        raise ImportError(
            "faiss-cpu is not installed. "
            "Please install it with `pip install faiss-cpu` or `pip install faiss-gpu`."
        )

    kind = self.index_type
    if kind == "IVF":
        # IVF partitions the space into cells around trained centroids,
        # so this index must be trained before vectors can be added.
        quantizer = faiss.IndexFlatL2(self.dimension)
        self.index = faiss.IndexIVFFlat(quantizer, self.dimension, 100)
        self.index.nprobe = 10  # cells probed per query
    elif kind == "HNSW":
        self.index = faiss.IndexHNSWFlat(self.dimension, 32)  # 32 neighbors per node
    else:
        if kind != "Flat":
            logger.warning(f"Unknown index type {self.index_type}, falling back to Flat")
        self.index = faiss.IndexFlatL2(self.dimension)

    # Only IVF requires an explicit training step
    self.needs_training = kind in ["IVF"]
382
+
383
def add_document(self, document: Document) -> str:
    """
    Add a single document to the database.

    Documents without an embedding are stored (retrievable by ID) but not
    added to the FAISS index, so they never appear in search results.

    Args:
        document: Document to add

    Returns:
        Document ID
    """
    # If no embedding is provided, store the document but skip indexing
    if document.embedding is None:
        logger.warning(f"Document {document.id} has no embedding - skipping indexing")
        self.documents[document.id] = document
        return document.id

    # FAISS expects a 2-D float32 array (one row per vector)
    embedding = np.array([document.embedding], dtype=np.float32)

    # Train index lazily once enough documents have accumulated
    if self.needs_training and len(self.documents) >= 100 and not self.index.is_trained:
        logger.info("Training FAISS index")
        # Only documents that actually have embeddings can be stacked for
        # training; the previous implementation called np.vstack over the
        # first 1000 stored documents unfiltered and crashed whenever any
        # of them lacked an embedding (embedding is None).
        candidates = [
            doc.embedding for doc in self.documents.values()
            if doc.embedding is not None
        ]
        if candidates:
            train_data = np.vstack(candidates[:1000])
            self.index.train(train_data)

    # Add to FAISS index if it's trained or doesn't need training
    if not self.needs_training or self.index.is_trained:
        idx = len(self.id_to_index)
        self.index.add(embedding)

        # Keep the id <-> row mappings in sync with the index
        self.id_to_index[document.id] = idx
        self.index_to_id[idx] = document.id

    # Store document
    self.documents[document.id] = document

    return document.id
422
+
423
def add_documents(self, documents: List[Document]) -> List[str]:
    """
    Add multiple documents to the database in one batch.

    Embeddings are added to FAISS in a single call, which is cheaper than
    per-document inserts. Documents without embeddings are still stored
    (retrievable by ID) but excluded from the index. The returned ID list
    covers all given documents, in input order.

    Args:
        documents: List of Document objects

    Returns:
        List of document IDs
    """
    doc_ids = []

    # First, collect all valid documents with embeddings
    valid_docs = []
    valid_embeddings = []

    for doc in documents:
        if doc.embedding is not None:
            valid_docs.append(doc)
            valid_embeddings.append(doc.embedding)

    if not valid_docs:
        logger.warning("No valid documents with embeddings to add")
        # Still store the documents themselves so the batch path matches
        # add_document(), which keeps embedding-less documents retrievable.
        # The previous implementation returned [] here and silently
        # dropped the whole batch.
        for doc in documents:
            self.documents[doc.id] = doc
            doc_ids.append(doc.id)
        return doc_ids

    # Train index if needed and we have enough data
    if self.needs_training and not self.index.is_trained:
        if len(valid_embeddings) >= 100 or (len(self.documents) + len(valid_docs)) >= 100:
            logger.info("Training FAISS index")
            # Use available embeddings (existing + incoming) for training
            train_data = np.vstack([
                *[doc.embedding for doc in list(self.documents.values()) if doc.embedding is not None],
                *valid_embeddings
            ])
            train_data = train_data[:min(1000, len(train_data))]  # Limit to 1000 samples
            self.index.train(train_data)

    # Add embeddings to FAISS index if it's trained or doesn't need training
    if not self.needs_training or self.index.is_trained:
        embeddings_array = np.array(valid_embeddings, dtype=np.float32)
        start_idx = len(self.id_to_index)
        self.index.add(embeddings_array)

        # Update id <-> row mappings for the newly added rows
        for i, doc in enumerate(valid_docs):
            idx = start_idx + i
            self.id_to_index[doc.id] = idx
            self.index_to_id[idx] = doc.id

    # Store all documents (with or without embeddings)
    for doc in documents:
        self.documents[doc.id] = doc
        doc_ids.append(doc.id)

    return doc_ids
478
+
479
def search(
    self,
    query_embedding: np.ndarray,
    top_k: int = 5,
    filter_func: Optional[Callable[[Document], bool]] = None
) -> List[Tuple[Document, float]]:
    """
    Search for similar documents.

    L2 distances returned by FAISS are converted to similarity scores via
    1 / (1 + distance), so higher scores are better.

    Args:
        query_embedding: Query embedding vector (1-D, or 2-D batch of one)
        top_k: Number of results to return
        filter_func: Optional function to filter results. Note the filter
            is applied AFTER retrieving top_k candidates from FAISS, so
            fewer than top_k results may be returned.

    Returns:
        List of (document, score) tuples, sorted by score descending
    """
    if not self.documents or not self.id_to_index:
        logger.warning("Cannot search: database is empty")
        return []

    # Ensure index is trained if needed (IVF before training can't search)
    if self.needs_training and not self.index.is_trained:
        logger.warning("Cannot search: index not trained")
        return []

    # FAISS expects a 2-D batch of queries; wrap a single 1-D vector
    if len(query_embedding.shape) == 1:
        query_embedding = np.array([query_embedding], dtype=np.float32)

    # Rebuild when the mapping table and the index have drifted apart —
    # delete_document removes mappings but leaves stale rows in FAISS
    if len(self.id_to_index) != self.index.ntotal:
        logger.info("Rebuilding index before search")
        self._rebuild_index()

    # Adjust top_k based on available items
    effective_top_k = min(top_k, self.index.ntotal)
    if effective_top_k < top_k:
        logger.warning(f"Requested top_k={top_k} but only {effective_top_k} items in index")

    # Perform search
    distances, indices = self.index.search(query_embedding, effective_top_k)

    # Retrieve documents
    results = []
    for i, idx in enumerate(indices[0]):
        if idx != -1:  # FAISS uses -1 for padding when there aren't enough results
            doc_id = self.index_to_id.get(idx)
            if doc_id and doc_id in self.documents:
                doc = self.documents[doc_id]

                # Apply filter if provided
                if filter_func is None or filter_func(doc):
                    # Convert L2 distance to similarity score (1 / (1 + distance))
                    score = 1.0 / (1.0 + distances[0][i])
                    results.append((doc, score))

    # Sort by score in descending order
    results.sort(key=lambda x: x[1], reverse=True)

    return results
540
+
541
+ def delete_document(self, doc_id: str) -> bool:
542
+ """
543
+ Delete a document from the database.
544
+
545
+