sinal-de-alerta / src /labdaps /retrieval /vector_store.py
fabianonbfilho's picture
Upload src/labdaps/retrieval/vector_store.py with huggingface_hub
c75fcb3 verified
Raw
History Blame Contribute Delete
2.32 kB
from pathlib import Path
import chromadb
from src.labdaps.config import CHROMA_DIR, COLLECTION_NAME
from src.labdaps.ingestion.chunker import Chunk
from src.labdaps.ingestion.embedder import Embedder
def _get_client() -> chromadb.PersistentClient:
    """Return a Chroma persistent client stored under ``CHROMA_DIR``.

    The directory (including any missing parents) is created before the
    client is opened, so the call also works on a fresh checkout.
    """
    storage_dir = CHROMA_DIR
    storage_dir.mkdir(parents=True, exist_ok=True)
    # Chroma receives the location as a plain string rather than a Path.
    return chromadb.PersistentClient(path=str(storage_dir))
def build_index(
    chunks: list[Chunk],
    embedder: Embedder,
    rebuild: bool = False,
    *,
    batch_size: int = 500,
) -> None:
    """Embed ``chunks`` and store them in the Chroma collection ``COLLECTION_NAME``.

    The collection is created with cosine distance if it does not exist yet.
    When it already holds entries and ``rebuild`` is false, nothing is
    embedded or written and the function returns early.

    Args:
        chunks: Chunks to index. Each one must expose ``text``,
            ``source_file``, ``page_number`` and ``chunk_index``.
        embedder: Object whose ``embed_passages(texts)`` is expected to
            return one embedding per text, in the same order.
        rebuild: Try to drop the existing collection first and index again.
        batch_size: Maximum number of chunks sent per ``upsert`` call.

    Raises:
        ValueError: If ``batch_size`` is less than 1, or if the embedder
            returns a different number of embeddings than there are chunks.
    """
    # Validate before touching the store so a bad argument never costs data.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    client = _get_client()

    if rebuild:
        try:
            client.delete_collection(COLLECTION_NAME)
        except Exception as exc:
            # Still best effort (the collection may simply not exist yet),
            # but report the reason: if the delete really failed, the
            # upserts below run on top of the old data.
            print(f"[WARN] Coleção '{COLLECTION_NAME}' não removida: {exc}")
        else:
            print(f"[INFO] Coleção '{COLLECTION_NAME}' removida.")

    collection = client.get_or_create_collection(
        name=COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"},
    )

    existing = collection.count()
    if existing > 0 and not rebuild:
        print(f"[INFO] Coleção já contém {existing} chunks. Use --rebuild para re-indexar.")
        return

    texts = [c.text for c in chunks]
    # Skip the embedder entirely when there is nothing to embed.
    embeddings = embedder.embed_passages(texts) if texts else []
    if len(embeddings) != len(texts):
        # Slicing misaligned sequences below would index wrong vectors.
        raise ValueError(
            f"embedder returned {len(embeddings)} embeddings for {len(texts)} chunks"
        )

    # NOTE(review): ids derive from chunk_index alone, so two chunks sharing a
    # chunk_index (e.g. if it restarts per source file) would overwrite each
    # other on upsert -- confirm chunk_index is unique across the corpus.
    ids = [f"chunk_{c.chunk_index}" for c in chunks]
    metadatas = [
        {
            "source_file": c.source_file,
            "page_number": c.page_number,
            "chunk_index": c.chunk_index,
        }
        for c in chunks
    ]

    total = len(chunks)
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        collection.upsert(
            ids=ids[start:end],
            embeddings=embeddings[start:end],
            documents=texts[start:end],
            metadatas=metadatas[start:end],
        )
        print(f"[INFO] Indexados chunks {start} a {end}/{total}")

    print(f"[INFO] Indexação concluída. Total: {collection.count()} chunks")
def query_store(query_embedding: list[float], top_k: int):
    """Look up the ``top_k`` nearest stored entries for one query embedding.

    Args:
        query_embedding: Embedding vector of the query.
        top_k: Number of results requested from the collection.

    Returns:
        A ``(documents, metadatas, distances)`` tuple holding the result
        lists for the single query that was submitted.

    Any error raised by ``get_collection`` or ``query`` (for example when
    the collection has not been built yet) propagates to the caller.
    """
    store = _get_client().get_collection(COLLECTION_NAME)
    hits = store.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )
    # Results come back as one list per query; exactly one query was sent,
    # so take the first entry of each requested field.
    documents, metadatas, distances = (
        hits[field][0] for field in ("documents", "metadatas", "distances")
    )
    return documents, metadatas, distances
def collection_count() -> int:
    """Return how many entries the collection ``COLLECTION_NAME`` holds.

    Any failure while opening the client, fetching the collection or
    counting (including the collection not existing) is reported as ``0``
    rather than raised.
    """
    try:
        return _get_client().get_collection(COLLECTION_NAME).count()
    except Exception:
        # Deliberately broad: callers only need "how many", and every
        # failure mode is treated as an empty index.
        return 0