chatvns / app /vector_store.py
liamxdev's picture
Upload folder using huggingface_hub
34b531b verified
Raw
History Blame Contribute Delete
8.12 kB
from __future__ import annotations
from collections.abc import Callable
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, FieldCondition, Filter, MatchValue, PointStruct, VectorParams
from app.config import (
EMBEDDING_API_URL,
EMBEDDING_BATCH_SIZE,
EMBEDDING_CACHE_ENABLED,
EMBEDDING_CACHE_PATH,
EMBEDDING_MODEL,
QDRANT_API_KEY,
QDRANT_COLLECTION,
QDRANT_URL,
)
from app.embedding_cache import EmbeddingCache, embedding_cache_key
from app.embeddings import get_embedding_model
from app.schemas import Chunk, RetrievedChunk
def get_qdrant_client() -> QdrantClient:
    """Build a ``QdrantClient`` pointed at ``QDRANT_URL``.

    ``api_key`` is passed only when ``QDRANT_API_KEY`` is truthy, so an unset
    or empty key yields an unauthenticated client.
    """
    auth_options = {"api_key": QDRANT_API_KEY} if QDRANT_API_KEY else {}
    return QdrantClient(url=QDRANT_URL, **auth_options)
def ensure_collection(client: QdrantClient, vector_size: int) -> None:
    """Create ``QDRANT_COLLECTION`` (cosine distance) unless it already exists.

    An existing collection is left untouched; its vector size is not compared
    against ``vector_size``.
    """
    existing = client.get_collections().collections
    already_present = any(item.name == QDRANT_COLLECTION for item in existing)
    if not already_present:
        client.create_collection(
            collection_name=QDRANT_COLLECTION,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )
def recreate_collection(client: QdrantClient, vector_size: int) -> None:
    """Drop ``QDRANT_COLLECTION`` if present, then create it afresh.

    The new collection uses ``vector_size``-dimensional vectors compared by
    cosine distance. Any points stored in the old collection are discarded.
    """
    current_names = (item.name for item in client.get_collections().collections)
    if QDRANT_COLLECTION in current_names:
        client.delete_collection(collection_name=QDRANT_COLLECTION)
    client.create_collection(
        collection_name=QDRANT_COLLECTION,
        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
    )
# Chunk attributes copied into each Qdrant point payload, in payload key order.
_CHUNK_PAYLOAD_FIELDS = (
    "text",
    "ticker",
    "scope",
    "modality",
    "source_path",
    "chunk_index",
    "structure_type",
    "heading_path",
    "token_count",
    "metadata",
)


def chunk_payload(chunk: Chunk) -> dict:
    """Flatten ``chunk`` into the payload dict stored next to its vector.

    Every name in ``_CHUNK_PAYLOAD_FIELDS`` becomes a key holding the chunk's
    attribute of the same name, in that order. Values are stored as-is
    (mutable ones such as ``metadata`` are not copied).
    """
    return {field: getattr(chunk, field) for field in _CHUNK_PAYLOAD_FIELDS}
def index_chunks(
    chunks: list[Chunk],
    batch_size: int | None = None,
    rebuild: bool = True,
    progress_callback: Callable[[dict], None] | None = None,
) -> int:
    """Embed ``chunks`` and upsert them into ``QDRANT_COLLECTION`` in batches.

    Args:
        chunks: Chunks to index. Each chunk's ``id`` becomes the point id and
            ``chunk_payload(chunk)`` its payload.
        batch_size: Chunks embedded and upserted per round trip. ``None`` (or
            ``0``) falls back to ``EMBEDDING_BATCH_SIZE``.
        rebuild: When true, drop and recreate the collection first. Otherwise,
            create it only if it does not exist yet.
        progress_callback: Optional hook called twice per batch with a fresh
            dict: ``stage == "embedding"`` after the batch is embedded and
            ``stage == "upsert"`` after it has been written.

    Returns:
        The number of points upserted.

    Raises:
        ValueError: If the effective batch size is not positive, or if the
            embedding step returns a different number of vectors than chunks.
    """
    from contextlib import ExitStack

    batch_size = batch_size or EMBEDDING_BATCH_SIZE
    # Validate before touching the collection. A non-positive step would make
    # the batch loop a silent no-op *after* recreate_collection had already
    # wiped the existing data.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    embedding_model = get_embedding_model()
    dim = embedding_model.dim
    total_chunks = len(chunks)
    total_batches = (total_chunks + batch_size - 1) // batch_size  # ceil; 0 if empty

    with ExitStack() as cleanup:
        cache = EmbeddingCache(EMBEDDING_CACHE_PATH) if EMBEDDING_CACHE_ENABLED else None
        if cache is not None:
            # Registered before the client is built, so the cache is closed
            # even if get_qdrant_client() raises.
            cleanup.callback(cache.close)
        client = get_qdrant_client()
        # Release the client's connections as well. ``close`` is looked up
        # defensively in case the installed qdrant-client lacks it.
        close_client = getattr(client, "close", None)
        if callable(close_client):
            cleanup.callback(close_client)

        if rebuild:
            recreate_collection(client, dim)
        else:
            ensure_collection(client, dim)

        indexed = 0
        for batch_number, start in enumerate(range(0, total_chunks, batch_size), start=1):
            batch = chunks[start : start + batch_size]
            vectors, cache_hits, cache_misses = embed_index_batch(batch, embedding_model, cache)
            # zip() below would silently drop chunks if fewer vectors came
            # back than chunks were sent; fail loudly instead.
            if len(vectors) != len(batch):
                raise ValueError(
                    f"embedding returned {len(vectors)} vectors for {len(batch)} chunks "
                    f"(batch {batch_number}/{total_batches})"
                )

            event = {
                "stage": "embedding",
                "batch_number": batch_number,
                "total_batches": total_batches,
                "batch_size": len(batch),
                "indexed_so_far": indexed,
                "total_chunks": total_chunks,
                "embedding_dim": dim,
                "cache_hits": cache_hits,
                "cache_misses": cache_misses,
            }
            if progress_callback:
                progress_callback(dict(event))  # copy: ``event`` is reused below

            points = [
                PointStruct(id=chunk.id, vector=vector, payload=chunk_payload(chunk))
                for chunk, vector in zip(batch, vectors)
            ]
            client.upsert(collection_name=QDRANT_COLLECTION, points=points)
            indexed += len(points)

            if progress_callback:
                progress_callback(
                    {**event, "stage": "upsert", "batch_size": len(points), "indexed_so_far": indexed}
                )
        return indexed
def embed_index_batch(
    batch: list[Chunk],
    embedding_model,
    cache: EmbeddingCache | None,
) -> tuple[list[list[float]], int, int]:
    """Return embeddings for ``batch``, reading from and filling ``cache`` when given.

    Args:
        batch: Chunks to embed; only ``chunk.text`` is used.
        embedding_model: Object exposing ``encode(list_of_texts)``. When a cache
            is used, its ``provider`` and ``dim`` also feed the cache key.
        cache: Embedding cache, or ``None`` to always call the model.

    Returns:
        ``(vectors, cache_hits, cache_misses)``. ``vectors`` is aligned with
        ``batch``, and hits/misses are counted per chunk. Without a cache,
        every chunk counts as a miss.

    Raises:
        ValueError: If the model returns a different number of vectors than
            texts it was asked to encode.
    """
    # Explicit ``is None``: a cache object that happens to be falsy (e.g. an
    # empty container-like cache) must still be used.
    if cache is None:
        vectors = embedding_model.encode([chunk.text for chunk in batch])
        if len(vectors) != len(batch):
            raise ValueError(f"embedding model returned {len(vectors)} vectors for {len(batch)} texts")
        return vectors, 0, len(batch)

    keys = [
        embedding_cache_key(
            chunk.text,
            provider=embedding_model.provider,
            model=EMBEDDING_MODEL,
            dim=embedding_model.dim,
            api_url=EMBEDDING_API_URL,
        )
        for chunk in batch
    ]
    cached = cache.get_many(keys)
    missing_indexes = [index for index, key in enumerate(keys) if key not in cached]
    if missing_indexes:
        # All key inputs except the text are constant within a batch, so
        # chunks with identical text share a key. Encode each distinct key once.
        pending = {}
        for index in missing_indexes:
            pending.setdefault(keys[index], batch[index].text)
        fresh = embedding_model.encode(list(pending.values()))
        # Check before writing to the cache so a short response never leaves
        # a partially-populated cache behind.
        if len(fresh) != len(pending):
            raise ValueError(f"embedding model returned {len(fresh)} vectors for {len(pending)} texts")
        new_entries = dict(zip(pending, fresh))
        cache.set_many(new_entries)
        for key, vector in new_entries.items():
            cached[key] = vector
    return [cached[key] for key in keys], len(batch) - len(missing_indexes), len(missing_indexes)
def search_points(
    client: QdrantClient,
    query_vector: list[float],
    query_filter: Filter | None,
    limit: int,
):
    """Return up to ``limit`` nearest points to ``query_vector`` in ``QDRANT_COLLECTION``.

    Payloads are included with each hit. The current ``query_points`` API is
    preferred; it wraps its hits in a response object, so ``.points`` is
    returned. Clients that predate it fall back to the older ``search``
    method, whose hits are returned directly. Preferring ``query_points``
    avoids calling ``search`` on newer clients, where it is deprecated.
    """
    if hasattr(client, "query_points"):
        response = client.query_points(
            collection_name=QDRANT_COLLECTION,
            query=query_vector,
            query_filter=query_filter,
            limit=limit,
            with_payload=True,
        )
        return response.points
    return client.search(
        collection_name=QDRANT_COLLECTION,
        query_vector=query_vector,
        query_filter=query_filter,
        limit=limit,
        with_payload=True,
    )
def _payload_str(payload: dict, key: str) -> str:
    """Return ``payload[key]`` as ``str``, mapping a missing or ``None`` value to ``""``.

    Plain ``str(payload.get(key, ""))`` would turn a stored ``None`` into the
    literal text ``"None"``.
    """
    value = payload.get(key)
    return "" if value is None else str(value)


def _is_market_hit(ticker_value: str, source_path: str) -> bool:
    """Heuristic used by ``retrieve`` to label a hit with scope ``"market"``.

    True for a ``MARKET`` ticker (any case), or a source path containing
    ``world_market`` or a ``/market/`` directory (either slash style).
    """
    normalized_path = source_path.replace("\\", "/")
    return (
        ticker_value.upper() == "MARKET"
        or "world_market" in source_path
        or "/market/" in normalized_path
    )


def _to_retrieved_chunk(hit) -> RetrievedChunk:
    """Convert one Qdrant scored hit (``id``, ``score``, ``payload``) into a ``RetrievedChunk``."""
    payload = hit.payload or {}
    source_path = _payload_str(payload, "source_path")
    ticker_value = _payload_str(payload, "ticker")
    scope = str(payload.get("scope") or ticker_value or "")
    if _is_market_hit(ticker_value, source_path):
        ticker_value = ""
        scope = "market"
    return RetrievedChunk(
        id=str(hit.id),
        text=_payload_str(payload, "text"),
        score=float(hit.score),
        ticker=ticker_value,
        modality=_payload_str(payload, "modality"),
        source_path=source_path,
        structure_type=_payload_str(payload, "structure_type"),
        heading_path=list(payload.get("heading_path") or []),
        metadata=dict(payload.get("metadata") or {}),
        scope=scope,
    )


def retrieve(query: str, top_k: int, ticker: str | None = None) -> list[RetrievedChunk]:
    """Embed ``query`` and return the ``top_k`` closest indexed chunks.

    Args:
        query: Free-text query; embedded with the configured embedding model.
        top_k: Maximum number of hits requested from Qdrant.
        ticker: Optional filter. When given, it is upper-cased and matched
            exactly against the payload ``ticker`` field.

    Returns:
        Hits in the order Qdrant returned them. Market-level hits (see
        ``_is_market_hit``) get an empty ticker and scope ``"market"``.

    Side effect: creates ``QDRANT_COLLECTION`` if it does not exist yet.
    """
    embedding_model = get_embedding_model()
    client = get_qdrant_client()
    try:
        ensure_collection(client, embedding_model.dim)
        query_filter = None
        if ticker:
            query_filter = Filter(
                must=[FieldCondition(key="ticker", match=MatchValue(value=ticker.upper()))]
            )
        hits = search_points(
            client=client,
            query_vector=embedding_model.encode([query])[0],
            query_filter=query_filter,
            limit=top_k,
        )
    finally:
        # A new client is built per query, so release its connections here.
        # ``close`` is looked up defensively for older qdrant-client versions.
        close_client = getattr(client, "close", None)
        if callable(close_client):
            close_client()
    return [_to_retrieved_chunk(hit) for hit in hits]