final_project2 / src /vector_store.py
dnj0's picture
Simplify
b802cc4
raw
history blame
8.13 kB
"""
Vector store and embedder.
"""
import os
import json
from typing import List, Dict
import chromadb
from sentence_transformers import SentenceTransformer
import numpy as np
from config import CHROMA_DB_PATH, EMBEDDING_MODEL, EMBEDDING_DIM
class CLIPEmbedder:
    """Text embedder backed by a ``SentenceTransformer`` model.

    The model is loaded once in ``__init__`` and reused for every call.
    Both ``embed`` and ``embed_batch`` are best-effort: if encoding fails
    they print the error and return zero vectors of length
    ``EMBEDDING_DIM`` instead of raising.
    """

    def __init__(self, model_name: str = EMBEDDING_MODEL):
        """Load the embedding model.

        Args:
            model_name: Model identifier handed to ``SentenceTransformer``.
        """
        print(f"Embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        print("Model loaded successfully")

    @staticmethod
    def _to_list(embedding):
        """Return ``embedding`` as a plain list when it exposes ``tolist()``."""
        return embedding.tolist() if hasattr(embedding, 'tolist') else embedding

    def embed(self, text: str) -> List[float]:
        """Embed a single text.

        Args:
            text: Text to encode.

        Returns:
            The embedding as a list of floats, or a zero vector of length
            ``EMBEDDING_DIM`` if encoding raised.
        """
        try:
            embedding = self.model.encode(text, convert_to_numpy=False)
            return self._to_list(embedding)
        except Exception as e:
            print(f"Error embedding text: {e}")
            return [0.0] * EMBEDDING_DIM

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts in one model call.

        Args:
            texts: Texts to encode.

        Returns:
            One embedding per input text, in input order. If encoding
            raised, one independent zero vector of length ``EMBEDDING_DIM``
            per input text.
        """
        if not texts:
            return []
        try:
            embeddings = self.model.encode(texts, convert_to_numpy=False)
            return [self._to_list(e) for e in embeddings]
        except Exception as e:
            print(f"Error embedding batch: {e}")
            # Build a fresh list per row: ``[[...]] * n`` would repeat the
            # SAME inner list n times, so mutating one row would change all.
            return [[0.0] * EMBEDDING_DIM for _ in texts]
class VectorStore:
    """ChromaDB-backed vector store for text chunks, image OCR text and tables.

    Items are embedded with a ``CLIPEmbedder`` and kept in one collection
    (``COLLECTION_NAME``) of a ``chromadb.PersistentClient`` rooted at
    ``CHROMA_DB_PATH``. Most methods are best-effort: they print errors
    instead of raising.
    """

    # Single collection used by every method of this class.
    COLLECTION_NAME = "multimodal_rag"

    def __init__(self):
        """Create the embedder, open the ChromaDB client and the collection."""
        self.persist_directory = CHROMA_DB_PATH
        self.embedder = CLIPEmbedder()

        print(f"\nInitializing ChromaDB: {self.persist_directory}")
        try:
            self.client = chromadb.PersistentClient(
                path=self.persist_directory
            )
            print("ChromaDB initialized")
        except Exception as e:
            print(f"Error initializing ChromaDB: {e}")
            # One retry; if this also fails the exception propagates,
            # because the store is unusable without a client.
            self.client = chromadb.PersistentClient(
                path=self.persist_directory
            )

        try:
            # NOTE(review): a collection that was already persisted without
            # the cosine metadata may keep its original distance space;
            # confirm whether existing stores need a clear_all() + re-ingest.
            self.collection = self._open_cosine_collection()
            count = self.collection.count()
            print(f"Collection loaded: {count} items in store")
        except Exception as e:
            print(f"Error with collection: {e}")
            # Fallback: open/create the collection without extra metadata.
            self.collection = self.client.get_or_create_collection(
                name=self.COLLECTION_NAME
            )

    def _open_cosine_collection(self):
        """Get or create the collection, requesting cosine distance."""
        return self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"}
        )

    def add_documents(self, documents: Dict, doc_id: str):
        """Embed and store the contents of one parsed document.

        Args:
            documents: Mapping describing the document. Keys read here:
                ``'text'`` (string, split into overlapping chunks),
                ``'images'`` (iterable of dicts; entries with a truthy
                ``'ocr_text'`` are stored, ``'path'`` is kept as metadata),
                ``'tables'`` (iterable of dicts; entries with a truthy
                ``'content'`` are stored).
            doc_id: Identifier stored in every item's metadata and used as
                the prefix of every item id.

        Errors raised by the collection while adding are printed, not raised.
        """
        texts: List[str] = []
        metadatas: List[Dict] = []
        ids: List[str] = []

        print(f"\nAdding document: {doc_id}")

        if 'text' in documents and documents['text']:
            chunks = self._chunk_text(documents['text'], chunk_size=1000, overlap=200)
            for idx, chunk in enumerate(chunks):
                texts.append(chunk)
                metadatas.append({
                    'doc_id': doc_id,
                    'type': 'text',
                    'chunk_idx': str(idx)
                })
                ids.append(f"{doc_id}_text_{idx}")
            print(f" Text: {len(chunks)} chunks")

        if 'images' in documents:
            image_count = 0
            for idx, image_data in enumerate(documents['images']):
                # Only images that produced OCR text are searchable.
                if not image_data.get('ocr_text'):
                    continue
                texts.append(f"Image {idx}: {image_data['ocr_text']}")
                metadatas.append({
                    'doc_id': doc_id,
                    'type': 'image',
                    'image_idx': str(idx),
                    'image_path': image_data.get('path', '')
                })
                ids.append(f"{doc_id}_image_{idx}")
                image_count += 1
            if image_count > 0:
                print(f" Images: {image_count} with OCR text")

        if 'tables' in documents:
            table_count = 0
            for idx, table_data in enumerate(documents['tables']):
                content = table_data.get('content')
                if not content:
                    continue
                texts.append(f"Table {idx}: {content}")
                metadatas.append({
                    'doc_id': doc_id,
                    'type': 'table',
                    'table_idx': str(idx)
                })
                ids.append(f"{doc_id}_table_{idx}")
                table_count += 1
            if table_count > 0:
                print(f" Tables: {table_count}")

        if not texts:
            return

        print(f" 🔄 Generating {len(texts)} embeddings...")
        embeddings = self.embedder.embed_batch(texts)
        try:
            self.collection.add(
                ids=ids,
                documents=texts,
                embeddings=embeddings,
                metadatas=metadatas
            )
            print(f"Successfully added {len(texts)} items to vector store")
        except Exception as e:
            print(f"Error adding to collection: {e}")

    def search(self, query: str, n_results: int = 5) -> List[Dict]:
        """Return the stored items closest to ``query``.

        Args:
            query: Free-text query; embedded with the store's embedder.
            n_results: Number of results requested from the collection.

        Returns:
            A list of dicts with keys ``'content'``, ``'metadata'``,
            ``'distance'`` and ``'type'`` (taken from the metadata,
            ``'unknown'`` if absent). Empty list on any error.
        """
        try:
            query_embedding = self.embedder.embed(query)
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results
            )

            formatted_results = []
            if results['documents']:
                for i, doc in enumerate(results['documents'][0]):
                    # ``or {}`` guards against a None metadata entry, which
                    # would otherwise break the ``.get`` below.
                    metadata = (results['metadatas'][0][i] if results['metadatas'] else {}) or {}
                    distance = results['distances'][0][i] if results['distances'] else 0
                    formatted_results.append({
                        'content': doc,
                        'metadata': metadata,
                        'distance': distance,
                        'type': metadata.get('type', 'unknown')
                    })
            return formatted_results
        except Exception as e:
            print(f"Error searching vector store: {e}")
            return []

    def _chunk_text(self, text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
        """Split ``text`` into windows of ``chunk_size`` characters.

        Consecutive windows share ``overlap`` characters. The last window
        may be shorter than ``chunk_size``.

        Args:
            text: Text to split.
            chunk_size: Maximum characters per chunk; must be positive.
            overlap: Characters shared by neighbouring chunks; must satisfy
                ``0 <= overlap < chunk_size``.

        Returns:
            The chunks in order; an empty list for empty ``text``.

        Raises:
            ValueError: If ``chunk_size`` or ``overlap`` is out of range
                (``overlap >= chunk_size`` would never advance the window).
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

        step = chunk_size - overlap
        chunks = []
        start = 0
        while start < len(text):
            end = start + chunk_size
            chunks.append(text[start:end])
            # Stop once the window reaches the end of the text; another
            # iteration would only emit a tail already contained in this chunk.
            if end >= len(text):
                break
            start += step
        return chunks

    def get_collection_info(self) -> Dict:
        """Return name, item count, status and persist path of the collection.

        Returns:
            On success a dict with ``'name'``, ``'count'``, ``'status'``
            (``'active'``) and ``'persist_path'``; on error a dict with
            ``'status'`` set to ``'error'`` and the error text in ``'message'``.
        """
        try:
            count = self.collection.count()
            return {
                'name': self.COLLECTION_NAME,
                'count': count,
                'status': 'active',
                'persist_path': self.persist_directory
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {'status': 'error', 'message': str(e)}

    def delete_by_doc_id(self, doc_id: str):
        """Delete every stored item whose metadata ``doc_id`` equals ``doc_id``.

        Errors are printed, not raised.
        """
        try:
            results = self.collection.get(where={'doc_id': doc_id})
            if results['ids']:
                self.collection.delete(ids=results['ids'])
                print(f"Deleted {len(results['ids'])} documents for {doc_id}")
        except Exception as e:
            print(f"Error deleting documents: {e}")

    def clear_all(self):
        """Drop the collection and recreate it empty (cosine distance).

        Errors are printed, not raised.
        """
        try:
            self.client.delete_collection(name=self.COLLECTION_NAME)
            self.collection = self._open_cosine_collection()
            print("Collection cleared")
        except Exception as e:
            print(f"Error clearing collection: {e}")