Spaces:
Build error
Build error
| import asyncio | |
| from rank_bm25 import BM25Okapi | |
| # import nltk | |
| import string | |
| from typing import List, Set, Optional | |
| # from nltk.corpus import stopwords | |
| # from nltk.stem import WordNetLemmatizer | |
| import os | |
| # Commented out this function that downloads NLTK resources. | |
| # def download_nltk_resources(): | |
| # """ | |
| # Downloads required NLTK resources synchronously. | |
| # """ | |
| # resources = ['punkt', 'stopwords', 'wordnet', 'omw-1.4'] | |
| # nltk_data_path = "/tmp/nltk_data" | |
| # os.makedirs(nltk_data_path, exist_ok=True) | |
| # nltk.data.path.append(nltk_data_path) | |
| # for resource in resources: | |
| # try: | |
| # nltk.download(resource, download_dir=nltk_data_path, quiet=True) | |
| # except Exception as e: | |
| # print(f"Error downloading {resource}: {str(e)}") | |
class BM25_search:
    """
    In-memory BM25 keyword search over a corpus of text documents.

    Documents are stored in three parallel lists (raw text, caller-supplied
    IDs, and preprocessed tokens). The BM25Okapi index is rebuilt from the
    token lists every time the corpus changes.

    NLTK-based tokenization, stopword removal and lemmatization are currently
    disabled: text is lowercased, stripped of ASCII punctuation and split on
    whitespace. The ``remove_stopwords`` and ``perform_lemmatization`` flags
    are stored on the instance but have no effect on preprocessing.
    """

    # Retained for backward compatibility; the NLTK download step that used
    # to flip this flag is disabled.
    nltk_resources_downloaded = False

    # Translation table that deletes ASCII punctuation. Built once here
    # instead of on every preprocess() call.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    def __init__(self, remove_stopwords: bool = True, perform_lemmatization: bool = False):
        """
        Creates an empty search index.

        Args:
            remove_stopwords: Stored on the instance; currently unused because
                NLTK stopword removal is disabled.
            perform_lemmatization: Stored on the instance; currently unused
                because NLTK lemmatization is disabled.
        """
        # Parallel lists: position i in each list describes the same document.
        self.documents: List[str] = []
        self.doc_ids: List[str] = []
        self.tokenized_docs: List[List[str]] = []
        # None whenever the corpus is empty (see update_bm25).
        self.bm25: Optional[BM25Okapi] = None
        self.remove_stopwords = remove_stopwords
        self.perform_lemmatization = perform_lemmatization

    def preprocess(self, text: str) -> List[str]:
        """
        Lowercases the text, removes ASCII punctuation and splits on whitespace.

        Returns:
            The list of tokens; empty if the text has no non-punctuation,
            non-whitespace characters.
        """
        return text.lower().translate(self._PUNCTUATION_TABLE).split()

    def add_document(self, doc_id: str, new_doc: str) -> None:
        """
        Adds a new document to the corpus and rebuilds the BM25 index.

        Args:
            doc_id: Identifier stored alongside the document.
            new_doc: Raw document text.
        """
        processed_tokens = self.preprocess(new_doc)
        self.documents.append(new_doc)
        self.doc_ids.append(doc_id)
        self.tokenized_docs.append(processed_tokens)
        self.update_bm25()
        print(f"Added document ID: {doc_id}")

    async def remove_document(self, index: int) -> None:
        """
        Removes the document at the given position and rebuilds the BM25 index.

        An out-of-range index (including any negative index) only prints a
        message; nothing is removed. The method is a coroutine but awaits
        nothing itself.
        """
        if not 0 <= index < len(self.documents):
            print(f"Index {index} is out of bounds.")
            return
        removed_doc_id = self.doc_ids[index]
        del self.documents[index]
        del self.doc_ids[index]
        del self.tokenized_docs[index]
        self.update_bm25()
        print(f"Removed document ID: {removed_doc_id}")

    def update_bm25(self) -> None:
        """
        Rebuilds the BM25 index from the current tokenized documents.

        When the corpus is empty the index is reset to None, so a stale index
        built from already-removed documents can never be queried.
        """
        if self.tokenized_docs:
            self.bm25 = BM25Okapi(self.tokenized_docs)
            print("BM25 index has been initialized.")
        else:
            # Drop any previous index: it would still describe removed documents.
            self.bm25 = None
            print("No documents to initialize BM25.")

    def get_scores(self, query: str) -> List[float]:
        """
        Computes BM25 scores for all documents based on the given query.

        Returns:
            Whatever BM25Okapi.get_scores returns, one score per document in
            corpus order (presumably a NumPy array rather than a plain list,
            verify against rank_bm25), or an empty list if no index exists.
        """
        processed_query = self.preprocess(query)
        print(f"Tokenized Query: {processed_query}")
        if self.bm25 is None:
            print("BM25 is not initialized.")
            return []
        return self.bm25.get_scores(processed_query)

    def get_top_n_docs(self, query: str, n: int = 5) -> List[str]:
        """
        Returns the top N documents (raw text) for a given query.

        Returns an empty list if no index exists.
        """
        processed_query = self.preprocess(query)
        if self.bm25 is None:
            print("BM25 is not initialized.")
            return []
        return self.bm25.get_top_n(processed_query, self.documents, n)

    def clear_documents(self) -> None:
        """
        Removes every document and discards the BM25 index.
        """
        self.documents = []
        self.doc_ids = []
        self.tokenized_docs = []
        self.bm25 = None
        print("BM25 documents cleared and index reset.")

    def get_document(self, doc_id: str) -> str:
        """
        Retrieves a document's raw text by its document ID.

        If the same ID was added more than once, the first match is returned.
        Returns an empty string (and prints a message) when the ID is unknown.
        """
        try:
            index = self.doc_ids.index(doc_id)
        except ValueError:
            print(f"Document ID {doc_id} not found.")
            return ""
        return self.documents[index]
async def initialize_bm25_search(remove_stopwords: bool = True, perform_lemmatization: bool = False) -> BM25_search:
    """
    Builds and returns a fresh BM25_search instance.

    This coroutine is a thin wrapper around the BM25_search constructor: it
    awaits nothing itself and downloads no NLTK resources.
    """
    search_engine = BM25_search(
        remove_stopwords=remove_stopwords,
        perform_lemmatization=perform_lemmatization,
    )
    return search_engine