# rag_pipeline.py
import os
import pickle
import logging
import asyncio

import numpy as np
from keybert import KeyBERT

from app.search.bm25_search import BM25_search
from app.search.faiss_search import FAISS_search
from app.search.hybrid_search import Hybrid_search
from app.utils.token_counter import TokenCounter

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def extract_keywords_async(doc, threshold=0.4, top_n=5):
    # Note: runs synchronously despite the "_async" suffix in its name.
    # Instantiating KeyBERT here means the underlying model is reloaded on every call.
    kw_model = KeyBERT()
    keywords = kw_model.extract_keywords(doc, threshold=threshold, top_n=top_n)
    return [key for key, _ in keywords]


class RAGSystem:
    """Hybrid (BM25 + FAISS) retrieval pipeline with token accounting and picklable state."""

    def __init__(self, embedding_model):
        self.token_counter = TokenCounter()
        self.documents = []
        self.doc_ids = []
        self.results = []
        self.meta_data = []
        self.embedding_model = embedding_model
        self.bm25_wrapper = BM25_search()
        self.faiss_wrapper = FAISS_search(embedding_model)
        self.hybrid_search = Hybrid_search(self.bm25_wrapper, self.faiss_wrapper)

    def add_document(self, doc_id, text, meta_data=None):
        self.token_counter.add_document(doc_id, text)
        self.doc_ids.append(doc_id)
        self.documents.append(text)
        self.meta_data.append(meta_data)
        self.bm25_wrapper.add_document(doc_id, text)
        self.faiss_wrapper.add_document(doc_id, text)

    def delete_document(self, doc_id):
        try:
            index = self.doc_ids.index(doc_id)
            del self.doc_ids[index]
            del self.documents[index]
            del self.meta_data[index]  # keep metadata aligned with doc_ids/documents
            self.bm25_wrapper.remove_document(index)
            self.faiss_wrapper.remove_document(index)
            self.token_counter.remove_document(doc_id)
        except ValueError:
            logger.warning(f"Document ID {doc_id} not found.")

    async def adv_query(self, query_text, keywords, top_k=15, prefixes=None):
        results = await self.hybrid_search.advanced_search(
            query_text,
            keywords=keywords,
            top_n=top_k,
            threshold=0.43,
            prefixes=prefixes
        )
        retrieved_docs = []
        if results:
            seen_docs = set()
            for doc_id, score in results:
                if doc_id in seen_docs:
                    continue
                seen_docs.add(doc_id)
                # Fetch the index of the document
                try:
                    index = self.doc_ids.index(doc_id)
                except ValueError as e:
                    logger.error(f"doc_id {doc_id} not found in self.doc_ids: {e}")
                    continue
                # Validate index range
                if index >= len(self.documents) or index >= len(self.meta_data):
                    logger.error(f"Index {index} out of range for documents or metadata")
                    continue
                doc = self.documents[index]
                meta_data = self.meta_data[index] or {}  # tolerate documents added without metadata
                # Metadata may also carry a file name ('source') and page number ('page')
                url = meta_data.get('source', 'unknown_url')  # Default URL fallback
                self.results.append(doc)
                retrieved_docs.append({"url": url, "text": doc})
            return retrieved_docs
        else:
            return [{"url": "None.", "text": None}]

    def get_total_tokens(self):
        return self.token_counter.get_total_tokens()

    def get_context(self):
        return "\n".join(self.results)

    def save_state(self, path):
        # Save doc_ids, documents, metadata, and token counter state
        with open(f"{path}_state.pkl", 'wb') as f:
            pickle.dump({
                "doc_ids": self.doc_ids,
                "documents": self.documents,
                "meta_data": self.meta_data,
                "token_counts": self.token_counter.doc_tokens
            }, f)

    def load_state(self, path):
        if os.path.exists(f"{path}_state.pkl"):
            with open(f"{path}_state.pkl", 'rb') as f:
                state_data = pickle.load(f)
            self.doc_ids = state_data["doc_ids"]
            self.documents = state_data["documents"]
            self.meta_data = state_data["meta_data"]
            self.token_counter.doc_tokens = state_data["token_counts"]
            # Clear and rebuild BM25 and FAISS indices from the restored documents
            self.bm25_wrapper.clear_documents()
            self.faiss_wrapper.clear_documents()
            for doc_id, document in zip(self.doc_ids, self.documents):
                self.bm25_wrapper.add_document(doc_id, document)
                self.faiss_wrapper.add_document(doc_id, document)
            self.token_counter.total_tokens = sum(self.token_counter.doc_tokens.values())
            logger.info("System state loaded successfully with documents and indices rebuilt.")
        else:
            logger.info("No previous state found, initializing fresh state.")
            self.doc_ids = []
            self.documents = []
            self.meta_data = []
            self.token_counter = TokenCounter()
            self.bm25_wrapper = BM25_search()
            self.faiss_wrapper = FAISS_search(self.embedding_model)
            self.hybrid_search = Hybrid_search(self.bm25_wrapper, self.faiss_wrapper)
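

# --- Minimal usage sketch --------------------------------------------------
# Illustration only, not part of the pipeline above. It assumes FAISS_search
# accepts a sentence-transformers style embedding model and that the
# app.search wrappers behave as their names suggest; the model name and the
# sample document/metadata are placeholders.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer  # assumed embedding backend

    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model name
    rag = RAGSystem(embedding_model)

    rag.add_document(
        "doc-1",
        "FAISS is a library for efficient vector similarity search.",
        meta_data={"source": "https://example.com/faiss.pdf", "page": 1},
    )

    # Keyword extraction feeds the hybrid search's keyword filter.
    keywords = extract_keywords_async("vector similarity search with FAISS")

    # adv_query is a coroutine, so drive it with asyncio.run().
    docs = asyncio.run(rag.adv_query("What is FAISS used for?", keywords, top_k=5))
    print(docs)

    rag.save_state("rag")  # writes rag_state.pkl next to the script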