"""Qdrant-backed vector index: embed and upsert chunks, and retrieve them by query.

Every read and write targets the single collection named by ``QDRANT_COLLECTION``.
"""

from __future__ import annotations

from collections.abc import Callable

from qdrant_client import QdrantClient
from qdrant_client.http.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)

from app.config import (
    EMBEDDING_API_URL,
    EMBEDDING_BATCH_SIZE,
    EMBEDDING_CACHE_ENABLED,
    EMBEDDING_CACHE_PATH,
    EMBEDDING_MODEL,
    QDRANT_API_KEY,
    QDRANT_COLLECTION,
    QDRANT_URL,
)
from app.embedding_cache import EmbeddingCache, embedding_cache_key
from app.embeddings import get_embedding_model
from app.schemas import Chunk, RetrievedChunk


def get_qdrant_client() -> QdrantClient:
    """Build a new Qdrant client pointed at ``QDRANT_URL``.

    ``QDRANT_API_KEY`` is passed only when it is set (truthy). Each call returns
    a fresh client; the functions in this module close the clients they create.
    """
    kwargs = {"url": QDRANT_URL}
    if QDRANT_API_KEY:
        kwargs["api_key"] = QDRANT_API_KEY
    return QdrantClient(**kwargs)


def _close_client(client: QdrantClient) -> None:
    """Release the client's connections, if this qdrant-client version exposes ``close``."""
    close = getattr(client, "close", None)
    if callable(close):
        close()


def _collection_exists(client: QdrantClient) -> bool:
    """Return True if ``QDRANT_COLLECTION`` is among the server's collections."""
    return any(
        collection.name == QDRANT_COLLECTION
        for collection in client.get_collections().collections
    )


def _create_collection(client: QdrantClient, vector_size: int) -> None:
    """Create ``QDRANT_COLLECTION`` with cosine-distance vectors of ``vector_size`` dims."""
    client.create_collection(
        collection_name=QDRANT_COLLECTION,
        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
    )


def ensure_collection(client: QdrantClient, vector_size: int) -> None:
    """Create the collection if it is missing; otherwise leave it untouched.

    An existing collection is not checked against ``vector_size``.
    """
    if not _collection_exists(client):
        _create_collection(client, vector_size)


def recreate_collection(client: QdrantClient, vector_size: int) -> None:
    """Drop the collection (if present) and create it empty.

    This discards every point previously stored in the collection.
    """
    if _collection_exists(client):
        client.delete_collection(collection_name=QDRANT_COLLECTION)
    _create_collection(client, vector_size)


def chunk_payload(chunk: Chunk) -> dict:
    """Return the Qdrant payload stored alongside a chunk's vector.

    ``retrieve`` reads these same keys back when building ``RetrievedChunk``s.
    """
    return {
        "text": chunk.text,
        "ticker": chunk.ticker,
        "scope": chunk.scope,
        "modality": chunk.modality,
        "source_path": chunk.source_path,
        "chunk_index": chunk.chunk_index,
        "structure_type": chunk.structure_type,
        "heading_path": chunk.heading_path,
        "token_count": chunk.token_count,
        "metadata": chunk.metadata,
    }


def _progress_event(
    stage: str,
    indexed_so_far: int,
    *,
    batch_number: int,
    total_batches: int,
    batch_len: int,
    total_chunks: int,
    embedding_dim: int,
    cache_hits: int,
    cache_misses: int,
) -> dict:
    """Build one progress-callback payload; both stages share these keys and order."""
    return {
        "stage": stage,
        "batch_number": batch_number,
        "total_batches": total_batches,
        "batch_size": batch_len,
        "indexed_so_far": indexed_so_far,
        "total_chunks": total_chunks,
        "embedding_dim": embedding_dim,
        "cache_hits": cache_hits,
        "cache_misses": cache_misses,
    }


def _index_batches(
    client: QdrantClient,
    chunks: list[Chunk],
    batch_size: int,
    embedding_model,
    cache: EmbeddingCache | None,
    progress_callback: Callable[[dict], None] | None,
) -> int:
    """Embed and upsert ``chunks`` batch by batch; return the number of points upserted."""
    total_chunks = len(chunks)
    # Ceiling division; yields 0 for an empty chunk list.
    total_batches = (total_chunks + batch_size - 1) // batch_size
    indexed = 0
    for start in range(0, total_chunks, batch_size):
        batch = chunks[start : start + batch_size]
        vectors, cache_hits, cache_misses = embed_index_batch(batch, embedding_model, cache)
        # zip() below would otherwise silently drop chunks that got no vector.
        if len(vectors) != len(batch):
            raise ValueError(
                f"embedding model returned {len(vectors)} vectors for {len(batch)} chunks"
            )

        stats = {
            "batch_number": start // batch_size + 1,
            "total_batches": total_batches,
            "batch_len": len(batch),
            "total_chunks": total_chunks,
            "embedding_dim": embedding_model.dim,
            "cache_hits": cache_hits,
            "cache_misses": cache_misses,
        }
        if progress_callback is not None:
            progress_callback(_progress_event("embedding", indexed, **stats))

        points = [
            PointStruct(id=chunk.id, vector=vector, payload=chunk_payload(chunk))
            for chunk, vector in zip(batch, vectors)
        ]
        client.upsert(collection_name=QDRANT_COLLECTION, points=points)
        indexed += len(points)

        if progress_callback is not None:
            progress_callback(_progress_event("upsert", indexed, **stats))
    return indexed


def index_chunks(
    chunks: list[Chunk],
    batch_size: int | None = None,
    rebuild: bool = True,
    progress_callback: Callable[[dict], None] | None = None,
) -> int:
    """Embed ``chunks`` in batches and upsert them into the Qdrant collection.

    Args:
        chunks: Chunks to index. Each chunk's ``id`` becomes its Qdrant point id.
        batch_size: Chunks per embed/upsert batch. ``None`` (or 0) falls back to
            ``EMBEDDING_BATCH_SIZE``.
        rebuild: If True, drop and recreate the collection before indexing.
            Otherwise, create it only when missing.
        progress_callback: Optional callable invoked twice per batch. It receives
            a dict whose ``"stage"`` is ``"embedding"`` (after embedding, before
            upsert) and then ``"upsert"`` (after the upsert).

    Returns:
        The number of points upserted.

    Raises:
        ValueError: If the effective batch size is < 1 (checked before the
            collection is touched), or if the embedding model returns a
            different number of vectors than there are chunks in a batch.
    """
    batch_size = batch_size or EMBEDDING_BATCH_SIZE
    # Validate before any destructive step. A bad size used to fail (or silently
    # index nothing) only after recreate_collection() had wiped the collection.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    embedding_model = get_embedding_model()
    cache = EmbeddingCache(EMBEDDING_CACHE_PATH) if EMBEDDING_CACHE_ENABLED else None
    try:
        client = get_qdrant_client()
        try:
            if rebuild:
                recreate_collection(client, embedding_model.dim)
            else:
                ensure_collection(client, embedding_model.dim)
            return _index_batches(
                client, chunks, batch_size, embedding_model, cache, progress_callback
            )
        finally:
            _close_client(client)
    finally:
        if cache is not None:
            cache.close()


def embed_index_batch(
    batch: list[Chunk],
    embedding_model,
    cache: EmbeddingCache | None,
) -> tuple[list[list[float]], int, int]:
    """Embed a batch of chunks, reusing cached vectors when a cache is given.

    Returns:
        ``(vectors, cache_hits, cache_misses)``. ``vectors`` is in the same order
        as ``batch``. Without a cache, every chunk counts as a miss.

    Chunks missing from the cache are encoded once per distinct cache key, even
    if the same key appears several times in the batch. The new vectors are
    written back to the cache.
    """
    # `is None` rather than truthiness. If EmbeddingCache defines __len__, an
    # empty cache would otherwise be treated as "no cache" (not verified here).
    if cache is None:
        return embedding_model.encode([chunk.text for chunk in batch]), 0, len(batch)

    keys = [
        embedding_cache_key(
            chunk.text,
            provider=embedding_model.provider,
            model=EMBEDDING_MODEL,
            dim=embedding_model.dim,
            api_url=EMBEDDING_API_URL,
        )
        for chunk in batch
    ]
    cached = cache.get_many(keys)
    missing_indexes = [index for index, key in enumerate(keys) if key not in cached]

    if missing_indexes:
        # Distinct missing key -> text of its first occurrence in the batch.
        pending: dict = {}
        for index in missing_indexes:
            pending.setdefault(keys[index], batch[index].text)
        new_vectors = dict(zip(pending, embedding_model.encode(list(pending.values()))))
        cache.set_many(new_vectors)
        for key, vector in new_vectors.items():
            cached[key] = vector

    misses = len(missing_indexes)
    return [cached[key] for key in keys], len(batch) - misses, misses


def search_points(
    client: QdrantClient,
    query_vector: list[float],
    query_filter: Filter | None,
    limit: int,
):
    """Run a vector search on the collection and return the scored hits with payloads.

    Prefers ``query_points`` (the current qdrant-client API). Falls back to the
    legacy ``search`` method on clients that do not provide ``query_points``.
    """
    if hasattr(client, "query_points"):
        response = client.query_points(
            collection_name=QDRANT_COLLECTION,
            query=query_vector,
            query_filter=query_filter,
            limit=limit,
            with_payload=True,
        )
        return response.points
    return client.search(
        collection_name=QDRANT_COLLECTION,
        query_vector=query_vector,
        query_filter=query_filter,
        limit=limit,
        with_payload=True,
    )


def _is_market_hit(ticker_value: str, source_path: str) -> bool:
    """Return True if a hit belongs to market-wide content rather than a single ticker."""
    normalized = source_path.replace("\\", "/")
    return (
        ticker_value.upper() == "MARKET"
        or "world_market" in normalized
        or "/market/" in normalized
    )


def _to_retrieved_chunk(hit) -> RetrievedChunk:
    """Convert one Qdrant scored hit into a ``RetrievedChunk``.

    Market-wide hits get an empty ticker and ``scope="market"``.
    """
    payload = hit.payload or {}
    source_path = str(payload.get("source_path", ""))
    ticker_value = str(payload.get("ticker", ""))
    scope = str(payload.get("scope") or ticker_value or "")
    if _is_market_hit(ticker_value, source_path):
        ticker_value = ""
        scope = "market"
    return RetrievedChunk(
        id=str(hit.id),
        text=str(payload.get("text", "")),
        score=float(hit.score),
        ticker=ticker_value,
        modality=str(payload.get("modality", "")),
        source_path=source_path,
        structure_type=str(payload.get("structure_type", "")),
        heading_path=list(payload.get("heading_path") or []),
        metadata=dict(payload.get("metadata") or {}),
        scope=scope,
    )


def retrieve(query: str, top_k: int, ticker: str | None = None) -> list[RetrievedChunk]:
    """Return the ``top_k`` chunks most similar to ``query``.

    Args:
        query: Free-text query. It is embedded with the configured embedding model.
        top_k: Maximum number of hits to return.
        ticker: If given, only points whose payload ``ticker`` equals
            ``ticker.upper()`` are searched.

    The collection is created (empty) first if it does not exist yet.
    """
    embedding_model = get_embedding_model()
    client = get_qdrant_client()
    try:
        ensure_collection(client, embedding_model.dim)
        query_filter = None
        if ticker:
            query_filter = Filter(
                must=[FieldCondition(key="ticker", match=MatchValue(value=ticker.upper()))]
            )
        hits = search_points(
            client=client,
            query_vector=embedding_model.encode([query])[0],
            query_filter=query_filter,
            limit=top_k,
        )
    finally:
        _close_client(client)
    return [_to_retrieved_chunk(hit) for hit in hits]