import sys
from pathlib import Path

import fire
from omegaconf import OmegaConf
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


def initialize_bm25_search(db, config):
    """
    Initialize the BM25 search engine.

    Args:
        db: Database session.
        config: Configuration dictionary; reads ``config['index_folders']['bm25']``.

    Returns:
        BM25Search: BM25 search engine that tokenizes with ``MystemTokenizer.tokenize``.
    """
    # Imported inside the function, so these project modules are resolved only
    # when called -- i.e. after the __main__ block has put the project root on
    # sys.path.
    from search.bm25_search import BM25Search
    from preprocessing.mystem_tokenizer import MystemTokenizer

    custom_tokenizer = MystemTokenizer()
    return BM25Search(
        db,
        config['index_folders']['bm25'],
        custom_tokenizer.tokenize,
    )


def initialize_semantic_search(db, config):
    """
    Initialize the semantic search engine.

    Args:
        db: Database session.
        config: Configuration dictionary; reads ``config['semantic_search']['model']``,
            ``config['semantic_search']['query_prefix']`` and
            ``config['index_folders']['semantic']``.

    Returns:
        SemanticSearch: Semantic search engine backed by
        ``<index_folders.semantic>/embeddings.npy``.
    """
    # Lazy import for the same sys.path reason as in initialize_bm25_search.
    from search.semantic_search import SemanticSearch

    return SemanticSearch(
        db,
        model=config['semantic_search']['model'],
        embeddings_file=f"{config['index_folders']['semantic']}/embeddings.npy",
        prefix=config['semantic_search']['query_prefix'],
    )


def search_memes(query: str, search_type: str = 'bm25', num: int = 1,
                 config_path: str = 'config.yaml'):
    """
    Search for memes using the specified search method.

    Args:
        query (str): The search query. Non-string values (Fire turns
            numeric-looking CLI arguments such as ``123`` into numbers) are
            converted to ``str``. Empty or whitespace-only queries are rejected.
        search_type (str): The type of search to perform. Either 'bm25' or
            'semantic'. Defaults to 'bm25'.
        num (int): The number of results to return. Must be an integer >= 1.
            Defaults to 1.
        config_path (str): Path to the YAML configuration file. Defaults to
            'config.yaml', resolved relative to the current working directory.

    Returns:
        None: Prints the results (or an error message) to the console.
    """
    # Single source of truth for the supported search types and their factories.
    initializers = {
        'bm25': initialize_bm25_search,
        'semantic': initialize_semantic_search,
    }

    if query is None or not str(query).strip():
        print("Error: Query is required.")
        return
    query = str(query)

    if search_type not in initializers:
        print("Error: Invalid search type. Use 'bm25' or 'semantic'.")
        return

    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(num, bool) or not isinstance(num, int):
        print("Error: Number of results must be an integer.")
        return
    if num < 1:
        print("Error: Number of results must be at least 1.")
        return

    # resolve=True expands ${...} interpolations so the plain-dict lookups
    # below see final values rather than unresolved interpolation strings.
    config = OmegaConf.to_container(OmegaConf.load(config_path), resolve=True)

    engine = create_engine(config['database']['url'])
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = SessionLocal()
    try:
        search_engine = initializers[search_type](db, config)
        results = search_engine.search(query, num)

        for result in results['results']:
            print(result['text'])
        print(f"\nSearch time: {results['search_time']:.4f} seconds")
    finally:
        db.close()
        # Release the connection pool created by create_engine above.
        engine.dispose()


if __name__ == "__main__":
    # Make the project packages (search, preprocessing) importable when this
    # file is run directly as a script.
    project_root = Path(__file__).resolve().parents[1]
    sys.path.insert(0, str(project_root))
    fire.Fire(search_memes)