Spaces:
Configuration error
Configuration error
| import string | |
| from typing import List, Optional, Tuple | |
| from langchain.chains import LLMChain | |
| from langchain.chains.base import Chain | |
| from loguru import logger | |
| from app.chroma import ChromaDenseVectorDB | |
| from app.config.models.configs import ( | |
| ResponseModel, | |
| Config, SemanticSearchConfig, | |
| ) | |
| from app.ranking import BCEReranker, rerank | |
| from app.splade import SpladeSparseVectorDB | |
class LLMBundle:
    """Retrieval-augmented QA pipeline: hybrid search, reranking and answering.

    Bundles the answering chain with a dense (Chroma) and a sparse (SPLADE)
    index, a reranker, the chunk sizes to search over and an optional HyDE
    chain used to expand the query before retrieval.
    """

    def __init__(
        self,
        chain: Chain,
        dense_db: ChromaDenseVectorDB,
        reranker: BCEReranker,
        sparse_db: SpladeSparseVectorDB,
        chunk_sizes: List[int],
        hyde_chain: Optional[LLMChain] = None,
    ) -> None:
        """Store the pipeline components.

        Args:
            chain: Chain called with ``input_documents`` and ``question`` to
                produce the final answer.
            dense_db: Dense vector store, also used to fetch documents by id.
            reranker: Model passed to ``rerank`` to score candidates.
            sparse_db: Sparse index queried once per chunk size.
            chunk_sizes: Chunk sizes to search; each one is retrieved and
                reranked separately.
            hyde_chain: Optional chain whose output is appended to the query
                before retrieval; HyDE expansion is skipped when None.
        """
        self.chain = chain
        self.dense_db = dense_db
        self.reranker = reranker
        self.sparse_db = sparse_db
        self.chunk_sizes = chunk_sizes
        self.hyde_chain = hyde_chain

    def get_relevant_documents(
        self,
        original_query: str,
        query: str,
        config: SemanticSearchConfig,
        label: str,
    ) -> Tuple[List, float]:
        """Retrieve, rerank and size-limit the documents relevant to a query.

        For every configured chunk size, candidate ids are gathered from both
        the sparse and the dense index, merged, fetched from the dense store
        and reranked against ``original_query``. The reranked list of the chunk
        size with the highest reranker score wins. That list is then trimmed so
        the total ``page_content`` length stays below ``config.max_char_size``.

        Args:
            original_query: The user's question; used for reranking.
            query: The text used for retrieval (possibly HyDE-expanded).
            config: Semantic search settings (``query_prefix``, ``max_k``,
                ``max_char_size``).
            label: Passed to the sparse search; when non-empty it is also added
                to the dense-search metadata filter.

        Returns:
            ``(documents, score)``: the selected document objects and the best
            reranker score. The score stays at the ``-1e5`` sentinel when no
            chunk size produced any candidate.
        """
        logger.debug("Evaluating query: {}", query)
        # Apply the retrieval prefix once, outside the chunk-size loop, so it is
        # not prepended again for every additional chunk size.
        if config.query_prefix:
            logger.info(f"Adding query prefix for retrieval: {config.query_prefix}")
            query = config.query_prefix + query

        # Filter dense results by chunk size only when several sizes are configured.
        filter_by_chunk_size = len(self.chunk_sizes) > 1

        best_docs: List = []
        best_score = -1e5
        for chunk_size in self.chunk_sizes:
            sparse_ids, _sparse_scores = self.sparse_db.query(
                search=query, n=config.max_k, label=label, chunk_size=chunk_size
            )
            logger.info(f"Stage 1: Got {len(sparse_ids)} documents.")

            dense_filter = {"chunk_size": chunk_size} if filter_by_chunk_size else {}
            if label:
                dense_filter["label"] = label
            # An empty filter is passed as None (no filtering).
            dense_filter = dense_filter or None
            logger.info(f"Dense embeddings filter: {dense_filter}")
            res = self.dense_db.similarity_search_with_relevance_scores(
                query, filter=dense_filter
            )
            dense_ids = [r[0].metadata["document_id"] for r in res]

            candidate_ids = set(sparse_ids) | set(dense_ids)
            if not candidate_ids:
                # Nothing retrieved for this chunk size; do not hand an empty
                # list to the reranker.
                continue
            candidates = self.dense_db.get_documents_by_id(
                document_ids=list(candidate_ids)
            )

            # Re-rank embeddings
            score, ranked_docs = rerank(
                rerank_model=self.reranker,
                query=original_query,
                docs=candidates,
            )
            if score > best_score:
                best_docs, best_score = ranked_docs, score

        selected = []
        total_chars = 0
        for doc in best_docs:
            doc_length = len(doc.page_content)
            # Skip (rather than stop at) a document that would overflow the
            # budget, so shorter lower-ranked documents can still fit.
            if total_chars + doc_length < config.max_char_size:
                selected.append(doc)
                total_chars += doc_length
        return selected, best_score

    def get_and_parse_response(
        self,
        query: str,
        config: Config,
        label: str = "",
    ) -> ResponseModel:
        """Answer ``query`` from retrieved documents and package the result.

        If a HyDE chain is configured, its output is appended to the query used
        for retrieval. The original question is used for reranking and is what
        the answering chain receives.

        Args:
            query: The user's question.
            config: Full configuration; ``config.semantic_search`` drives
                retrieval.
            label: Optional label restricting retrieval.

        Returns:
            A ``ResponseModel`` with the chain's ``output_text`` as
            ``response``, the retrieval query (including any HyDE text) as
            ``question``, the best reranker score as ``average_score``, the HyDE
            text (empty when no HyDE chain is set) as ``hyde_response``, and the
            ``page_content`` of every document in the chain's
            ``input_documents`` appended to ``semantic_search``.
        """
        original_query = query

        # Add HyDE queries (only when a HyDE chain was provided).
        hyde_response = ""
        if self.hyde_chain is not None:
            hyde_response = self.hyde_chain.run(query)
            # Separate question and hypothetical answer so their boundary words
            # are not fused into a single token.
            query = f"{query} {hyde_response}"
        logger.info(f"query: {query}")

        semantic_search_config = config.semantic_search
        most_relevant_docs, score = self.get_relevant_documents(
            original_query, query, semantic_search_config, label
        )

        res = self.chain(
            {"input_documents": most_relevant_docs, "question": original_query},
        )

        out = ResponseModel(
            response=res["output_text"],
            question=query,
            average_score=score,
            hyde_response=hyde_response,
        )
        for doc in res["input_documents"]:
            out.semantic_search.append(doc.page_content)
        return out
class PartialFormatter(string.Formatter):
    """``string.Formatter`` that tolerates missing fields and bad format specs.

    Fields that cannot be resolved (unknown keyword, too few positional
    arguments, missing attribute) render as ``missing``. A field whose value is
    ``None`` also renders as ``missing``. Values whose format spec is rejected
    render as ``bad_fmt``; when ``bad_fmt`` is None the error is re-raised.

    Note: with an explicit conversion (``{x!r}`` / ``{x!s}``), a missing field is
    converted before it reaches ``format_field`` and renders as ``'None'``.

    Example:
        >>> PartialFormatter().format("{a} {b}", a=1)
        '1 ~~'
    """

    def __init__(self, missing="~~", bad_fmt="!!"):
        """Set the placeholder strings for missing fields and bad format specs."""
        super().__init__()
        self.missing, self.bad_fmt = missing, bad_fmt

    def get_field(self, field_name, args, kwargs):
        """Resolve ``field_name``; return ``(None, field_name)`` when it is absent."""
        try:
            return super().get_field(field_name, args, kwargs)
        except (KeyError, AttributeError, IndexError):
            # IndexError covers "{0}" / "{}" with too few positional arguments.
            return None, field_name

    def format_field(self, value, spec):
        """Format ``value`` with ``spec``, substituting placeholders on failure."""
        if value is None:
            return self.missing
        try:
            return super().format_field(value, spec)
        except (ValueError, TypeError):
            # ValueError: unknown spec for the type (e.g. "d" on a str);
            # TypeError: object.__format__ rejects any non-empty spec.
            if self.bad_fmt is None:
                raise
            return self.bad_fmt