# textmeme_search/src/search/bm25_search.py
# Futyn-Maker
# Deploy the app
# 7e1f5f6
import os
import pickle
import time
from typing import List, Dict, Any, Callable
import numpy as np
from sqlalchemy.orm import Session
from rank_bm25 import BM25Okapi
from src.db import crud
class BM25Search:
    """
    Keyword search over memes backed by a pre-built BM25 index.

    The pickled index is loaded once when the object is created. Each call to
    ``search`` tokenizes the query, scores it against every indexed document
    and fetches the best-scoring memes from the database.
    """

    def __init__(
        self,
        db: Session,
        index_folder: str,
        tokenizer: Callable[[str], List[str]]
    ):
        """
        Initialize the BM25Search.

        Args:
            db (Session): The database session.
            index_folder (str): The folder containing the BM25 index.
            tokenizer (Callable[[str], List[str]]): A function to tokenize
                the text.
        """
        self.db = db
        self.tokenizer = tokenizer
        self.bm25 = self._load_index(index_folder)

    def _load_index(self, index_folder: str) -> BM25Okapi:
        """
        Load the BM25 index from a file.

        Args:
            index_folder (str): The folder containing the BM25 index
                (a file named ``bm25_index.pkl`` is expected inside it).

        Returns:
            BM25Okapi: The loaded BM25 index.

        Raises:
            OSError: If the index file cannot be opened (for example, it
                does not exist).
        """
        index_path = os.path.join(index_folder, 'bm25_index.pkl')
        # SECURITY: unpickling can execute arbitrary code. The index file
        # must come from a trusted source (i.e. be produced by this project).
        with open(index_path, 'rb') as f:
            return pickle.load(f)

    def search(self, query: str, n: int = 3) -> Dict[str, Any]:
        """
        Perform a search using BM25.

        Args:
            query (str): The search query.
            n (int, optional): The maximum number of results to return.
                Defaults to 3. Values less than or equal to zero yield an
                empty result list.

        Returns:
            Dict[str, Any]: A dictionary with two keys:
                ``"results"`` - a list of dicts (``id``, ``public_id``,
                ``text``, ``image_url``, ``score``) ordered from the highest
                to the lowest score;
                ``"search_time"`` - elapsed wall-clock time in seconds.
        """
        start_time = time.time()

        # A slice like [-0:] would select *every* document and a negative n
        # would select an arbitrary tail, so handle these cases explicitly.
        if n <= 0:
            return {
                "results": [],
                "search_time": time.time() - start_time
            }

        # Tokenize the query
        query_tokens = self.tokenizer(query)

        # Retrieve scores for all documents
        scores = self.bm25.get_scores(query_tokens)

        # Indices of the n best documents, best first
        top_n_indices = np.argsort(scores)[-n:][::-1]

        # Adjust indices to match database IDs (assuming IDs start from 1)
        # and remember the score of each ID. Dict insertion order keeps the
        # ranking, so iterating it later walks results best-first.
        score_by_id = {
            int(index) + 1: float(scores[index])
            for index in top_n_indices
        }

        # Retrieve memes from the database
        memes = crud.get_memes_by_ids(self.db, list(score_by_id))
        memes_by_id = {meme.id: meme for meme in memes}

        # Format the results in ranking order; the order in which the
        # database returned the rows is not relied upon. IDs that have no
        # matching row are skipped.
        results = []
        for db_id, score in score_by_id.items():
            meme = memes_by_id.get(db_id)
            if meme is None:
                continue
            results.append({
                "id": meme.id,
                "public_id": meme.public_id,
                "text": meme.text,
                "image_url": meme.image_url,
                "score": score
            })

        return {
            "results": results,
            "search_time": time.time() - start_time
        }