Spaces:
Build error
Build error
| import os | |
| import pickle | |
| import time | |
| from typing import List, Dict, Any, Callable | |
| import numpy as np | |
| from sqlalchemy.orm import Session | |
| from rank_bm25 import BM25Okapi | |
| from src.db import crud | |
class BM25Search:
    """Keyword search over memes using a pre-built, pickled BM25 index."""

    def __init__(
        self,
        db: Session,
        index_folder: str,
        tokenizer: Callable[[str], List[str]]
    ):
        """
        Initialize the BM25Search.

        Args:
            db (Session): The database session used to fetch matching memes.
            index_folder (str): The folder containing ``bm25_index.pkl``.
            tokenizer (Callable[[str], List[str]]): A function to tokenize the
                query text. NOTE(review): presumably this must be the same
                tokenizer used to build the index, otherwise query terms will
                not match indexed terms -- confirm against the indexing script.
        """
        self.db = db
        self.tokenizer = tokenizer
        self.bm25 = self._load_index(index_folder)

    def _load_index(self, index_folder: str) -> BM25Okapi:
        """
        Load the BM25 index from ``<index_folder>/bm25_index.pkl``.

        Args:
            index_folder (str): The folder containing the BM25 index.

        Returns:
            BM25Okapi: The unpickled BM25 index.

        Raises:
            FileNotFoundError: If ``bm25_index.pkl`` is not present in the folder.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only point index_folder at index files produced by this project and
        # stored in a trusted location.
        with open(os.path.join(index_folder, 'bm25_index.pkl'), 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _top_n_indices(scores: np.ndarray, n: int) -> np.ndarray:
        """
        Return the positions of the ``n`` highest scores, best first.

        Args:
            scores (np.ndarray): One BM25 score per indexed document.
            n (int): How many positions to return. ``n <= 0`` yields none.

        Returns:
            np.ndarray: Integer positions into ``scores``. At most ``n`` are
            returned (fewer if the corpus is smaller). Equal scores keep
            corpus order.
        """
        if n <= 0:
            # Guard: a slice like ``[-0:]`` would otherwise select everything.
            return np.empty(0, dtype=np.intp)
        # A stable sort on the negated scores gives descending order while
        # keeping ties deterministic (lower corpus position first).
        return np.argsort(-scores, kind="stable")[:n]

    def search(self, query: str, n: int = 3) -> Dict[str, Any]:
        """
        Perform a search using BM25.

        Args:
            query (str): The search query.
            n (int, optional): The maximum number of results to return.
                Defaults to 3. Values ``<= 0`` return no results.

        Returns:
            Dict[str, Any]: ``{"results": [...], "search_time": seconds}``.
            ``results`` is ordered by descending BM25 score. Each entry has
            ``id``, ``public_id``, ``text``, ``image_url`` and ``score``
            (a float). Top-ranked IDs with no matching database row are
            skipped.
        """
        # perf_counter is monotonic, so it is suitable for measuring elapsed time.
        start_time = time.perf_counter()

        query_tokens = self.tokenizer(query)
        scores = np.asarray(self.bm25.get_scores(query_tokens))

        top_n_indices = self._top_n_indices(scores, n)

        # Index positions are 0-based. Database IDs are assumed to start at 1
        # and follow the corpus order the index was built from.
        db_ids = [int(idx) + 1 for idx in top_n_indices]
        score_by_id = {
            db_id: float(scores[idx])
            for db_id, idx in zip(db_ids, top_n_indices)
        }

        memes = crud.get_memes_by_ids(self.db, db_ids) if db_ids else []
        meme_by_id = {meme.id: meme for meme in memes}

        # Nothing here guarantees the order in which get_memes_by_ids returns
        # rows, so rebuild the list in rank order (best score first).
        results = []
        for db_id in db_ids:
            meme = meme_by_id.get(db_id)
            if meme is None:
                continue
            results.append({
                "id": meme.id,
                "public_id": meme.public_id,
                "text": meme.text,
                "image_url": meme.image_url,
                "score": score_by_id[db_id],
            })

        return {
            "results": results,
            "search_time": time.perf_counter() - start_time
        }