Spaces:
Sleeping
Sleeping
| import sys | |
| from pathlib import Path | |
| import fire | |
| from omegaconf import OmegaConf | |
| from sqlalchemy import create_engine | |
| from sqlalchemy.orm import sessionmaker | |
def initialize_bm25_search(db, config):
    """
    Construct a BM25 search engine over the given database session.

    The ``tokenize`` method of a fresh ``MystemTokenizer`` instance is handed
    to the engine as its tokenizer callable, together with the folder
    configured under ``config['index_folders']['bm25']``.

    Args:
        db: Database session.
        config: Configuration dictionary (a plain dict, as produced by
            ``OmegaConf.to_container`` in ``search_memes``).

    Returns:
        BM25Search: Initialized BM25 search engine.
    """
    # Imported here rather than at module level so that these project
    # packages resolve only after the __main__ block has put the project
    # root on sys.path (and only when this search type is requested).
    from search.bm25_search import BM25Search
    from preprocessing.mystem_tokenizer import MystemTokenizer

    tokenizer = MystemTokenizer()
    bm25_index_dir = config['index_folders']['bm25']
    return BM25Search(db, bm25_index_dir, tokenizer.tokenize)
def initialize_semantic_search(db, config):
    """
    Construct an embedding-based semantic search engine.

    The embeddings file is expected at ``<index_folders.semantic>/embeddings.npy``.
    The model name and query prefix come from the ``semantic_search`` section
    of the configuration.

    Args:
        db: Database session.
        config: Configuration dictionary (a plain dict, as produced by
            ``OmegaConf.to_container`` in ``search_memes``).

    Returns:
        SemanticSearch: Initialized semantic search engine.
    """
    # Deferred import: the project root is only added to sys.path by the
    # __main__ block, after this module has been loaded.
    from search.semantic_search import SemanticSearch

    model_name = config['semantic_search']['model']
    embeddings_path = "{}/embeddings.npy".format(config['index_folders']['semantic'])
    query_prefix = config['semantic_search']['query_prefix']
    return SemanticSearch(
        db,
        model=model_name,
        embeddings_file=embeddings_path,
        prefix=query_prefix,
    )
def search_memes(query: str, search_type: str = 'bm25', num: int = 1,
                 config_path: str = 'config.yaml'):
    """
    Search for memes using the specified search method.

    Args:
        query (str): The search query. Non-string values (Fire parses CLI
            tokens such as ``42`` into ints) are converted to ``str``.
            ``None`` and whitespace-only queries are rejected.
        search_type (str): The type of search to perform. Either 'bm25' or
            'semantic'. Defaults to 'bm25'.
        num (int): The number of results to return. Must be an integer
            >= 1. Defaults to 1.
        config_path (str): Path to the YAML configuration file. Defaults to
            'config.yaml', resolved relative to the current working directory.

    Returns:
        None: Prints the results to stdout; validation errors go to stderr.
    """
    # Fire turns CLI tokens like "0" or "42" into ints, so validate on the
    # string form; a plain truthiness check would wrongly reject a query of 0.
    if query is None or not str(query).strip():
        print("Error: Query is required.", file=sys.stderr)
        return
    query = str(query)
    if search_type not in ('bm25', 'semantic'):
        print("Error: Invalid search type. Use 'bm25' or 'semantic'.", file=sys.stderr)
        return
    # A bare `--num` flag arrives as True; bool is an int subclass, so exclude it.
    if not isinstance(num, int) or isinstance(num, bool):
        print("Error: Number of results must be an integer.", file=sys.stderr)
        return
    if num < 1:
        print("Error: Number of results must be at least 1.", file=sys.stderr)
        return

    # Load configuration and convert it to plain Python containers
    config = OmegaConf.to_container(OmegaConf.load(config_path))

    # Initialize database session
    engine = create_engine(config['database']['url'])
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = SessionLocal()
    try:
        # Initialize search engine (search_type was validated above)
        if search_type == 'bm25':
            search_engine = initialize_bm25_search(db, config)
        else:
            search_engine = initialize_semantic_search(db, config)

        # Perform search and print results
        results = search_engine.search(query, num)
        for result in results['results']:
            print(result['text'])
        print(f"\nSearch time: {results['search_time']:.4f} seconds")
    finally:
        db.close()
        # Release the connection pool owned by the engine created above.
        engine.dispose()
if __name__ == "__main__":
    # Put the directory one level above this script's folder at the front of
    # sys.path so the function-local imports of the project packages
    # (search, preprocessing) resolve when the file is run directly.
    project_root = Path(__file__).resolve().parent.parent
    sys.path.insert(0, str(project_root))
    # Expose search_memes as a command-line interface.
    fire.Fire(search_memes)