# multimodal-rag / src/retrieval/vector_db.py
# itachi — Initial deployment (a809248)
"""
Vector Database Module.
Supports PostgreSQL+pgvector and FAISS for vector storage and retrieval.
"""
import json
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

from ..utils import get_logger, get_config, LoggerMixin
# Module-level logger for this file.
logger = get_logger(__name__)
# Project configuration, fetched at import time. The store classes below read
# their defaults from it (embedding dim, index types, FAISS params, DB credentials).
config = get_config()
@dataclass
class Document:
    """A text chunk with an optional embedding vector and free-form metadata.

    Attributes:
        id: Identifier of the document/chunk.
        text: The chunk's text content.
        embedding: Optional embedding vector; deliberately left out of ``to_dict``.
        metadata: Arbitrary metadata; each instance gets its own fresh dict.
    """
    id: str
    text: str
    embedding: Optional[np.ndarray] = None
    metadata: Dict = field(default_factory=dict)

    def to_dict(self) -> Dict:
        """Return ``id``, ``text`` and ``metadata`` as a plain dict (no embedding)."""
        return dict(id=self.id, text=self.text, metadata=self.metadata)
@dataclass
class SearchResult:
    """A matched document paired with its score and its position in the results.

    Attributes:
        document: The matched document.
        score: Score reported by the store that produced the result.
        rank: Position within the result list; defaults to 0.
    """
    document: Document
    score: float
    rank: int = 0

    def to_dict(self) -> Dict:
        """Flatten into a plain dict: document id/text/metadata plus score and rank."""
        doc = self.document
        flat = {"id": doc.id, "text": doc.text}
        flat["score"] = self.score
        flat["rank"] = self.rank
        flat["metadata"] = doc.metadata
        return flat
class VectorStore(ABC, LoggerMixin):
    """Interface shared by the vector-store backends in this module.

    Concrete stores implement insertion, similarity search, deletion,
    lookup by ID and the ``count`` property.
    """

    @abstractmethod
    def add_documents(
        self,
        documents: List[Document],
        embeddings: Optional[np.ndarray] = None
    ) -> List[str]:
        """Store ``documents`` and return the IDs that were written.

        ``embeddings``, when supplied, provides one vector per document and is
        used instead of each document's own ``embedding`` attribute.
        """
        ...

    @abstractmethod
    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10
    ) -> List[SearchResult]:
        """Return up to ``top_k`` documents most similar to ``query_embedding``."""
        ...

    @abstractmethod
    def delete(self, document_ids: List[str]) -> int:
        """Remove the given IDs and return how many documents were removed."""
        ...

    @abstractmethod
    def get_document(self, document_id: str) -> Optional[Document]:
        """Return the document stored under ``document_id``, or ``None``."""
        ...

    @property
    @abstractmethod
    def count(self) -> int:
        """Number of documents currently held by the store."""
        ...
class PostgresVectorStore(VectorStore):
    """
    PostgreSQL + pgvector vector store.

    Features:
    - ACID compliance
    - SQL filtering on JSONB metadata
    - IVFFlat/HNSW indexing (cosine distance)
    - Full-text search support (generated ``tsvector`` column + GIN index)

    Security note: ``table_name`` is interpolated directly into every SQL
    statement, so it must come from trusted configuration, never from user
    input. All other values are sent as bound query parameters.
    """

    def __init__(
        self,
        connection_string: Optional[str] = None,
        table_name: str = "document_embeddings",
        embedding_dim: Optional[int] = None,
        index_type: Optional[str] = None
    ):
        """
        Initialize PostgreSQL vector store.

        No connection is opened here; it is created lazily on first use.

        Args:
            connection_string: PostgreSQL connection string. If omitted, one is
                built from ``config.database`` (pg_user, pg_password, pg_host,
                pg_port, pg_database).
            table_name: Name of the embeddings table.
            embedding_dim: Dimension of embeddings; defaults to
                ``config.embedding.embedding_dim``.
            index_type: Vector index type ("ivfflat" or "hnsw"); defaults to
                ``config.database.index_type``. Any other value creates no
                vector index.
        """
        self.table_name = table_name
        self.embedding_dim = embedding_dim or config.embedding.embedding_dim
        self.index_type = index_type or config.database.index_type
        if connection_string:
            self.connection_string = connection_string
        else:
            db = config.database
            self.connection_string = (
                f"postgresql://{db.pg_user}:{db.pg_password}@"
                f"{db.pg_host}:{db.pg_port}/{db.pg_database}"
            )
        self.conn = None
        self._initialized = False

    def _connect(self):
        """Open the connection and register pgvector types (no-op if already open).

        Raises:
            ImportError: If psycopg2 or pgvector is not installed.
            Exception: Whatever ``psycopg2.connect`` raises on failure (logged first).
        """
        if self.conn is not None:
            return
        try:
            import psycopg2
            from pgvector.psycopg2 import register_vector
            self.conn = psycopg2.connect(self.connection_string)
            register_vector(self.conn)
            self.logger.info("Connected to PostgreSQL")
        except ImportError:
            self.logger.error("psycopg2 or pgvector not installed")
            raise
        except Exception as e:
            self.logger.error(f"Failed to connect: {e}")
            raise

    @contextmanager
    def _cursor(self, commit: bool = False):
        """Yield a cursor on the shared connection, connecting first if needed.

        Args:
            commit: Commit the transaction when the ``with`` body succeeds.

        On any exception the transaction is rolled back before re-raising.
        Otherwise psycopg2 leaves the connection in an aborted transaction
        and every later statement on it fails.
        """
        self._connect()
        try:
            with self.conn.cursor() as cur:
                yield cur
            if commit:
                self.conn.commit()
        except Exception:
            self.conn.rollback()
            raise

    @staticmethod
    def _vector_literal(vector) -> str:
        """Format ``vector`` in pgvector's text input form, e.g. ``"[0.1,0.2]"``.

        Used with an explicit ``::vector`` cast. A bare Python list would be
        sent as ``numeric[]``, and ``vector <=> numeric[]`` has no operator.
        """
        values = np.asarray(vector, dtype=float).ravel()
        return "[" + ",".join(repr(float(v)) for v in values) + "]"

    @staticmethod
    def _parse_metadata(raw) -> Dict:
        """Return a metadata column value as a dict (``None`` becomes ``{}``).

        psycopg2 normally decodes JSONB into Python objects already; a string
        is parsed as JSON as a fallback.
        """
        if raw is None:
            return {}
        return raw if isinstance(raw, dict) else json.loads(raw)

    @classmethod
    def _rows_to_results(cls, rows) -> List[SearchResult]:
        """Convert ``(id, chunk_text, metadata, score)`` rows into ranked results."""
        return [
            SearchResult(
                document=Document(
                    id=doc_id,
                    text=text,
                    metadata=cls._parse_metadata(metadata),
                ),
                score=float(score),
                rank=rank,
            )
            for rank, (doc_id, text, metadata, score) in enumerate(rows)
        ]

    def initialize(self):
        """Create the pgvector extension, the table and its indexes if missing.

        Every statement uses ``IF NOT EXISTS``, so calling this repeatedly is safe.
        """
        index_name = f"{self.table_name}_embedding_idx"
        with self._cursor(commit=True) as cur:
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {self.table_name} (
                    id TEXT PRIMARY KEY,
                    chunk_text TEXT NOT NULL,
                    embedding vector({self.embedding_dim}),
                    metadata JSONB DEFAULT '{{}}',
                    full_text_search tsvector GENERATED ALWAYS AS (
                        to_tsvector('english', chunk_text)
                    ) STORED,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            # Both ANN indexes use cosine ops to match the `<=>` operator in search().
            if self.index_type == "ivfflat":
                # NOTE(review): pgvector recommends building ivfflat after data is
                # loaded; here it may be built on an empty table - consider a
                # REINDEX after bulk loads.
                cur.execute(f"""
                    CREATE INDEX IF NOT EXISTS {index_name}
                    ON {self.table_name}
                    USING ivfflat (embedding vector_cosine_ops)
                    WITH (lists = {int(config.database.num_lists)})
                """)
            elif self.index_type == "hnsw":
                cur.execute(f"""
                    CREATE INDEX IF NOT EXISTS {index_name}
                    ON {self.table_name}
                    USING hnsw (embedding vector_cosine_ops)
                """)
            # GIN index backing full_text_search().
            cur.execute(f"""
                CREATE INDEX IF NOT EXISTS {self.table_name}_fts_idx
                ON {self.table_name}
                USING gin(full_text_search)
            """)
        self._initialized = True
        self.logger.info(f"Initialized table {self.table_name}")

    def add_documents(
        self,
        documents: List[Document],
        embeddings: Optional[np.ndarray] = None
    ) -> List[str]:
        """Upsert documents with their embeddings in a single transaction.

        Args:
            documents: Documents to store. An existing row with the same ``id``
                has its text, embedding and metadata overwritten.
            embeddings: Optional array with one row per document. When omitted,
                each ``doc.embedding`` is used.

        Returns:
            IDs that were written. Documents without an embedding are skipped
            with a warning. On any database error nothing is written.

        Raises:
            ValueError: If ``embeddings`` is given and its length differs from
                ``len(documents)``.
        """
        if embeddings is not None and len(embeddings) != len(documents):
            raise ValueError(
                f"Got {len(embeddings)} embeddings for {len(documents)} documents"
            )
        if not self._initialized:
            self.initialize()

        ids = []
        with self._cursor(commit=True) as cur:
            for i, doc in enumerate(documents):
                embedding = embeddings[i] if embeddings is not None else doc.embedding
                if embedding is None:
                    self.logger.warning(f"No embedding for document {doc.id}")
                    continue
                cur.execute(f"""
                    INSERT INTO {self.table_name} (id, chunk_text, embedding, metadata)
                    VALUES (%s, %s, %s::vector, %s::jsonb)
                    ON CONFLICT (id) DO UPDATE SET
                        chunk_text = EXCLUDED.chunk_text,
                        embedding = EXCLUDED.embedding,
                        metadata = EXCLUDED.metadata
                """, (
                    doc.id,
                    doc.text,
                    self._vector_literal(embedding),
                    json.dumps(doc.metadata),
                ))
                ids.append(doc.id)
        self.logger.info(f"Added {len(ids)} documents")
        return ids

    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10,
        filter_metadata: Optional[Dict] = None
    ) -> List[SearchResult]:
        """Return the ``top_k`` rows closest to ``query_embedding`` by cosine distance.

        Args:
            query_embedding: Query vector with the table's embedding dimension.
            top_k: Maximum number of results.
            filter_metadata: Optional ``{key: value}`` pairs. A row matches only
                if, for every pair, its top-level metadata ``key`` equals
                ``value``. Values are compared as JSONB, so ``1`` and ``"1"``
                differ.

        Returns:
            Results ordered by ascending cosine distance. ``score`` is the
            cosine similarity ``1 - distance`` and ``rank`` is 0-based.
        """
        vector = self._vector_literal(query_embedding)
        query = f"""
            SELECT id, chunk_text, metadata,
                   1 - (embedding <=> %s::vector) AS similarity
            FROM {self.table_name}
        """
        params: list = [vector]
        if filter_metadata:
            conditions = []
            for key, value in filter_metadata.items():
                # Key and value are bound parameters (no injection). Comparing
                # jsonb to jsonb matches strings, numbers and booleans alike.
                conditions.append("(metadata -> %s) = %s::jsonb")
                params.extend([str(key), json.dumps(value)])
            query += " WHERE " + " AND ".join(conditions)
        # Order by the raw distance expression so pgvector can use the ANN index.
        query += " ORDER BY embedding <=> %s::vector LIMIT %s"
        params.extend([vector, int(top_k)])

        with self._cursor() as cur:
            cur.execute(query, params)
            rows = cur.fetchall()
        return self._rows_to_results(rows)

    def full_text_search(
        self,
        query: str,
        top_k: int = 10
    ) -> List[SearchResult]:
        """Rank rows against ``query`` with PostgreSQL full-text search.

        ``query`` is parsed with ``plainto_tsquery('english', ...)``. Only
        matching rows are returned, ordered by ``ts_rank`` descending, and
        ``score`` is that rank.
        """
        with self._cursor() as cur:
            cur.execute(f"""
                SELECT id, chunk_text, metadata,
                       ts_rank(full_text_search, plainto_tsquery('english', %s)) AS score
                FROM {self.table_name}
                WHERE full_text_search @@ plainto_tsquery('english', %s)
                ORDER BY score DESC
                LIMIT %s
            """, (query, query, int(top_k)))
            rows = cur.fetchall()
        return self._rows_to_results(rows)

    def delete(self, document_ids: List[str]) -> int:
        """Delete rows whose ``id`` is in ``document_ids``; return the number removed."""
        ids = list(document_ids)
        if not ids:
            return 0
        with self._cursor(commit=True) as cur:
            cur.execute(
                f"DELETE FROM {self.table_name} WHERE id = ANY(%s)",
                (ids,),
            )
            deleted = cur.rowcount
        self.logger.info(f"Deleted {deleted} documents")
        return deleted

    def get_document(self, document_id: str) -> Optional[Document]:
        """Return the stored document (without its embedding) or ``None`` if absent."""
        with self._cursor() as cur:
            cur.execute(f"""
                SELECT id, chunk_text, metadata
                FROM {self.table_name}
                WHERE id = %s
            """, (document_id,))
            row = cur.fetchone()
        if row is None:
            return None
        doc_id, text, metadata = row
        return Document(id=doc_id, text=text, metadata=self._parse_metadata(metadata))

    @property
    def count(self) -> int:
        """Number of rows in the embeddings table (runs ``SELECT COUNT(*)``)."""
        with self._cursor() as cur:
            cur.execute(f"SELECT COUNT(*) FROM {self.table_name}")
            return cur.fetchone()[0]

    def close(self):
        """Close the database connection if open; a later call reconnects lazily."""
        if self.conn:
            self.conn.close()
            self.conn = None
            self.logger.info("Closed PostgreSQL connection")
class FAISSVectorStore(VectorStore):
    """
    FAISS vector store for fast similarity search.

    Vectors are L2-normalized on insert and on query, and every index type is
    built with inner-product scoring. ``SearchResult.score`` is therefore the
    cosine similarity (higher is better), whichever ``index_type`` is used.

    Features:
    - Fast approximate nearest neighbor search
    - Multiple index types (Flat, IVF, HNSW)
    - GPU acceleration: not wired up here; only CPU index classes are built.

    Deletion only unlinks documents from the ID maps. Their vectors stay in
    the FAISS index until it is rebuilt.
    """

    def __init__(
        self,
        embedding_dim: Optional[int] = None,
        index_type: Optional[str] = None,
        nlist: Optional[int] = None,
        nprobe: Optional[int] = None
    ):
        """
        Initialize FAISS vector store.

        Args:
            embedding_dim: Dimension of embeddings
                (default ``config.embedding.embedding_dim``).
            index_type: Index type "flat", "ivf" or "hnsw"
                (default ``config.database.faiss_index_type``).
            nlist: Number of clusters for IVF (default ``config.database.faiss_nlist``).
            nprobe: Number of clusters to search (default ``config.database.faiss_nprobe``).

        Raises:
            ImportError: If faiss is not installed.
            ValueError: If ``index_type`` is not supported.
        """
        self.embedding_dim = embedding_dim or config.embedding.embedding_dim
        self.index_type = index_type or config.database.faiss_index_type
        self.nlist = nlist or config.database.faiss_nlist
        self.nprobe = nprobe or config.database.faiss_nprobe
        self.index = None
        # FAISS row position -> Document. Rows of deleted/replaced docs are absent.
        self.documents: Dict[int, Document] = {}
        self.id_to_idx: Dict[str, int] = {}
        self.idx_to_id: Dict[int, str] = {}
        # Row position the next added vector will receive in the FAISS index.
        self.current_idx = 0
        self._init_index()

    def _init_index(self):
        """Import faiss and build an empty index for ``self.index_type``."""
        try:
            import faiss
        except ImportError as err:
            self.logger.error("faiss not installed")
            raise ImportError("Install faiss: pip install faiss-cpu") from err
        self.faiss = faiss

        # Only an IVF index must be trained before its first add().
        self._needs_training = False
        if self.index_type == "flat":
            self.index = faiss.IndexFlatIP(self.embedding_dim)  # Inner product
        elif self.index_type == "ivf":
            quantizer = faiss.IndexFlatIP(self.embedding_dim)
            self.index = faiss.IndexIVFFlat(
                quantizer,
                self.embedding_dim,
                self.nlist,
                faiss.METRIC_INNER_PRODUCT
            )
            self._needs_training = True
        elif self.index_type == "hnsw":
            # IndexHNSWFlat defaults to METRIC_L2, whose scores are distances
            # (lower is better). Request inner product so scores match the
            # other index types.
            self.index = faiss.IndexHNSWFlat(
                self.embedding_dim, 32, faiss.METRIC_INNER_PRODUCT
            )
            self.index.hnsw.efConstruction = 200
        else:
            raise ValueError(f"Unknown index type: {self.index_type}")
        self.logger.info(f"Initialized FAISS {self.index_type} index")

    @staticmethod
    def _normalize(vectors) -> np.ndarray:
        """Return ``vectors`` as a C-contiguous float32 2-D array with unit-L2 rows.

        A 1-D input becomes a single row. All-zero rows stay zero instead of
        turning into NaN through division by zero.
        """
        arr = np.atleast_2d(np.asarray(vectors, dtype=np.float64))
        norms = np.linalg.norm(arr, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        return np.ascontiguousarray(arr / norms, dtype=np.float32)

    def add_documents(
        self,
        documents: List[Document],
        embeddings: Optional[np.ndarray] = None
    ) -> List[str]:
        """Add documents and their embeddings to the index.

        For an untrained IVF index, the first batch trains it when it has at
        least ``nlist`` vectors. Otherwise the index is replaced by a flat one.

        Args:
            documents: Documents to add. Re-adding an existing ``id`` replaces
                the earlier entry; its old vector becomes an unreachable orphan.
            embeddings: Optional array with one row per document. When omitted,
                each ``doc.embedding`` is used.

        Returns:
            IDs of the added documents, in input order.

        Raises:
            ValueError: If a document has no embedding, or the number of
                embedding rows differs from the number of documents.
        """
        if not documents:
            return []
        if embeddings is None:
            missing = [doc.id for doc in documents if doc.embedding is None]
            if missing:
                raise ValueError(f"Documents without embeddings: {missing[:10]}")
            embeddings = np.vstack([doc.embedding for doc in documents])
        # Normalize so that inner product == cosine similarity.
        vectors = self._normalize(embeddings)
        if vectors.shape[0] != len(documents):
            # A mismatch would desynchronize FAISS row ids from self.documents.
            raise ValueError(
                f"Got {vectors.shape[0]} embeddings for {len(documents)} documents"
            )

        if self._needs_training:
            if vectors.shape[0] >= self.nlist:
                self.index.train(vectors)
            else:
                self.logger.warning(
                    f"Not enough vectors ({vectors.shape[0]}) to train IVF index "
                    f"(need {self.nlist}). Using flat index."
                )
                self.index = self.faiss.IndexFlatIP(self.embedding_dim)
            # Trained, or replaced by a flat index: no further training needed.
            self._needs_training = False

        self.index.add(vectors)

        ids = []
        for doc in documents:
            stale_idx = self.id_to_idx.get(doc.id)
            if stale_idx is not None:
                # Drop the superseded entry so it is neither searchable nor counted.
                self.documents.pop(stale_idx, None)
                self.idx_to_id.pop(stale_idx, None)
            self.documents[self.current_idx] = doc
            self.id_to_idx[doc.id] = self.current_idx
            self.idx_to_id[self.current_idx] = doc.id
            ids.append(doc.id)
            self.current_idx += 1
        self.logger.info(f"Added {len(ids)} documents to FAISS")
        return ids

    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10
    ) -> List[SearchResult]:
        """Return up to ``top_k`` live documents most similar to ``query_embedding``.

        ``score`` is the cosine similarity. ``rank`` is 0-based and contiguous.
        An empty index, or ``top_k <= 0``, yields an empty list.
        """
        ntotal = self.index.ntotal
        if top_k <= 0 or ntotal == 0:
            return []
        query = self._normalize(query_embedding)[:1]
        if self.index_type == "ivf" and hasattr(self.index, 'nprobe'):
            self.index.nprobe = self.nprobe

        # Deleted/replaced documents leave orphan vectors behind. Over-fetch by
        # that many so they cannot crowd live results out of the top_k.
        orphans = max(ntotal - len(self.documents), 0)
        k = min(top_k + orphans, ntotal)
        scores, indices = self.index.search(query, k)

        results: List[SearchResult] = []
        for idx, score in zip(indices[0], scores[0]):
            if idx == -1:  # FAISS pads missing neighbours with -1
                continue
            doc = self.documents.get(int(idx))
            if doc is None:  # orphan vector of a deleted/replaced document
                continue
            results.append(SearchResult(
                document=doc,
                score=float(score),
                rank=len(results)
            ))
            if len(results) == top_k:
                break
        return results

    def delete(self, document_ids: List[str]) -> int:
        """Unlink documents by ID; return how many known IDs were removed.

        The vectors remain in the FAISS index (true deletion is not supported
        here), but search skips them.
        """
        deleted = 0
        for doc_id in document_ids:
            idx = self.id_to_idx.pop(doc_id, None)
            if idx is None:
                continue
            self.documents.pop(idx, None)
            self.idx_to_id.pop(idx, None)
            deleted += 1
        self.logger.warning(
            f"Marked {deleted} documents as deleted. "
            "Note: FAISS doesn't support true deletion. Rebuild index for cleanup."
        )
        return deleted

    def get_document(self, document_id: str) -> Optional[Document]:
        """Return the live document stored under ``document_id``, or ``None``."""
        idx = self.id_to_idx.get(document_id)
        if idx is not None:
            return self.documents.get(idx)
        return None

    def get_all_documents(self) -> List[Document]:
        """Return all live (non-deleted) documents in the store."""
        return list(self.documents.values())

    @property
    def count(self) -> int:
        """Number of live documents (orphan vectors are not counted)."""
        return len(self.documents)

    def save(self, path: Union[str, Path]):
        """Write ``index.faiss`` and ``documents.pkl`` into ``path`` (created if missing)."""
        import pickle

        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        self.faiss.write_index(self.index, str(path / "index.faiss"))
        with open(path / "documents.pkl", 'wb') as f:
            pickle.dump({
                'documents': self.documents,
                'id_to_idx': self.id_to_idx,
                'idx_to_id': self.idx_to_id,
                'current_idx': self.current_idx
            }, f)
        self.logger.info(f"Saved FAISS index to {path}")

    def load(self, path: Union[str, Path]):
        """Load an index and document mappings previously written by ``save``.

        SECURITY: ``documents.pkl`` is read with ``pickle``, which can execute
        arbitrary code. Only load directories you created or fully trust.
        """
        import pickle

        path = Path(path)
        self.index = self.faiss.read_index(str(path / "index.faiss"))
        # A saved IVF index is normally trained already. Don't retrain or
        # replace it on the next add.
        self._needs_training = not self.index.is_trained
        with open(path / "documents.pkl", 'rb') as f:
            data = pickle.load(f)
        self.documents = data['documents']
        self.id_to_idx = data['id_to_idx']
        self.idx_to_id = data['idx_to_id']
        self.current_idx = data['current_idx']
        self.logger.info(f"Loaded FAISS index from {path}")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Vector DB Test")
    parser.add_argument("--test", action="store_true", help="Run test mode")
    parser.add_argument("--init", action="store_true", help="Initialize PostgreSQL")
    args = parser.parse_args()

    if not (args.test or args.init):
        # Nothing requested: show usage instead of exiting silently.
        parser.print_help()

    if args.test:
        print("Vector Store Test (FAISS)\n" + "=" * 50)
        # Fixed seed so the smoke test prints the same results on every run.
        np.random.seed(42)
        docs = [
            Document(id=f"doc_{i}", text=f"Sample document {i}",
                     embedding=np.random.randn(768))
            for i in range(100)
        ]
        faiss_store = FAISSVectorStore(embedding_dim=768, index_type="flat")
        faiss_store.add_documents(docs)
        print(f"Documents in store: {faiss_store.count}")

        query = np.random.randn(768)
        results = faiss_store.search(query, top_k=5)
        print("\nTop 5 results:")
        for r in results:
            print(f"  {r.document.id}: score={r.score:.4f}")

    if args.init:
        print("Initializing PostgreSQL Vector Store...")
        pg_store = PostgresVectorStore()
        try:
            pg_store.initialize()
        finally:
            # Release the connection even if initialization fails.
            pg_store.close()
        print("Done!")