# NOTE(review): removed stray paste artifacts ("Spaces:", "Runtime error" x2)
# that were not valid Python and appear to be residue from an online editor.
# Imports regrouped per PEP 8: stdlib, third-party, then local. All original
# imports are preserved (including currently-unused `opik`, `Distance`, `Range`,
# which may be used elsewhere or kept for planned features).
import concurrent.futures
from typing import Union

import opik
from loguru import logger
from qdrant_client.models import Distance, FieldCondition, Filter, MatchValue, Range

from llm_engineering.application import utils
from llm_engineering.application.preprocessing.dispatchers import EmbeddingDispatcher
from llm_engineering.domain.embedded_chunks import (
    EmbeddedArticleChunk,
    EmbeddedChunk,
    EmbeddedPostChunk,
    EmbeddedRepositoryChunk,
)
from llm_engineering.domain.queries import EmbeddedQuery, Query
from llm_engineering.domain.video_chunks import EmbeddedVideoChunk

from .multimodal_dispatcher import MultimodalEmbeddingDispatcher
from .query_expansion import QueryExpansion
from .reranking import Reranker
from .self_query import SelfQuery
class ContextRetriever:
    """Retrieve the most relevant document chunks for a user query.

    The retrieval pipeline is:
      1. extract metadata filters from the query (self-query),
      2. expand the query into N semantically similar queries,
      3. run a vector search for each expanded query in parallel,
      4. rerank the union of the results and keep the top-k.
    """

    def __init__(self, mock: bool = False) -> None:
        # mock=True wires the LLM-backed components to mocked implementations.
        self._query_expander = QueryExpansion(mock=mock)
        self._metadata_extractor = SelfQuery(mock=mock)
        self._reranker = Reranker(mock=mock)

    def search(self, query: Union[str, Query], k: int = 3, expand_to_n_queries: int = 3) -> list:
        """Return up to ``k`` reranked chunks relevant to ``query``.

        Args:
            query: Raw query string or an already-built ``Query`` model.
            k: Number of chunks to keep after reranking.
            expand_to_n_queries: How many expanded queries to search with.

        Returns:
            A (possibly empty) list of reranked chunks.
        """
        query_model = Query.from_str(query) if isinstance(query, str) else query
        query_model = self._metadata_extractor.generate(query_model)
        n_generated_queries = self._query_expander.generate(query_model, expand_to_n=expand_to_n_queries)

        n_k_documents: list = []
        if n_generated_queries:
            # Only spin up the thread pool when there is actual work: the
            # original created an executor even for an empty query list.
            with concurrent.futures.ThreadPoolExecutor() as executor:
                search_tasks = [
                    executor.submit(self._search, _query_model, k)
                    for _query_model in n_generated_queries
                ]
                # A search task may return None; normalize each result to a
                # list so flattening below is safe.
                n_k_documents = [
                    task.result() or [] for task in concurrent.futures.as_completed(search_tasks)
                ]

        # Every element is a list at this point, so flatten directly (the
        # original's `or [[]]` / None-filter dance was redundant).
        n_k_documents = utils.misc.flatten(n_k_documents)

        if not n_k_documents:
            return []
        return self.rerank(query, chunks=n_k_documents, keep_top_k=k)

    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
        """Vector-search posts, articles and repositories for one query.

        ``k`` is split evenly (``k // 3``) across the three data categories.

        Raises:
            ValueError: If ``k`` < 3, since each category needs a limit >= 1.
        """
        if k < 3:
            # Was an `assert`, which is silently stripped under `python -O`;
            # an explicit exception makes the contract enforceable.
            raise ValueError("k should be >= 3")

        def _search_data_category(
            data_category_odm: type[EmbeddedChunk], embedded_query: EmbeddedQuery
        ) -> list[EmbeddedChunk]:
            # Restrict results to the query's author when self-query extracted one.
            if embedded_query.author_id:
                query_filter = Filter(
                    must=[
                        FieldCondition(
                            key="author_id",
                            match=MatchValue(value=str(embedded_query.author_id)),
                        )
                    ]
                )
            else:
                query_filter = None

            return data_category_odm.search(
                query_vector=embedded_query.embedding,
                limit=k // 3,
                query_filter=query_filter,
            )

        embedded_query: EmbeddedQuery = EmbeddingDispatcher.dispatch(query)

        post_chunks = _search_data_category(EmbeddedPostChunk, embedded_query)
        articles_chunks = _search_data_category(EmbeddedArticleChunk, embedded_query)
        repositories_chunks = _search_data_category(EmbeddedRepositoryChunk, embedded_query)

        return post_chunks + articles_chunks + repositories_chunks

    def rerank(self, query: str | Query, chunks: list[EmbeddedChunk], keep_top_k: int) -> list[EmbeddedChunk]:
        """Rerank ``chunks`` against ``query``, keeping the top ``keep_top_k``."""
        if isinstance(query, str):
            query = Query.from_str(query)

        reranked_documents = self._reranker.generate(query=query, chunks=chunks, keep_top_k=keep_top_k)

        logger.info(f"{len(reranked_documents)} documents reranked successfully.")

        return reranked_documents
class VideoContextRetriever(ContextRetriever):
    """Context retriever specialized for embedded video chunks."""

    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
        """Vector-search video chunks for one query.

        Bug fix: the original body only *defined* the inner helper and never
        called it, so `_search` always returned None despite its
        ``list[EmbeddedChunk]`` annotation. It also never embedded the query,
        and the nested function wrongly declared a ``self`` parameter (the
        enclosing method's ``self`` is already available via closure).
        """

        def _search_video_category(embedded_query: EmbeddedQuery) -> list[EmbeddedChunk]:
            return EmbeddedVideoChunk.search(
                query_vector=embedded_query.embedding,
                limit=k,
                # NOTE(review): `_create_time_filter` is not defined in this
                # class or its visible parent — confirm it exists on a mixin
                # or elsewhere in the hierarchy.
                query_filter=self._create_time_filter(embedded_query),
            )

        # NOTE(review): using MultimodalEmbeddingDispatcher (imported but
        # unused in the original) to embed video queries; confirm this is the
        # intended dispatcher vs. the text-only EmbeddingDispatcher.
        embedded_query: EmbeddedQuery = MultimodalEmbeddingDispatcher.dispatch(query)

        return _search_video_category(embedded_query)