Spaces:
Sleeping
Sleeping
| import sys | |
| from pathlib import Path | |
| import gradio as gr | |
| from omegaconf import OmegaConf | |
| from sqlalchemy import create_engine | |
| from sqlalchemy.orm import sessionmaker | |
def initialize_search_engines(db, config):
    """
    Build the BM25 and semantic search engines used by the application.

    The project-local search and tokenizer modules are imported inside this
    function, so they are only resolved when it is called.

    Args:
        db: Database session handed to both engines.
        config: Configuration mapping. The keys read are
            ``index_folders.bm25``, ``index_folders.semantic``,
            ``semantic_search.model`` and ``semantic_search.query_prefix``.

    Returns:
        tuple: ``(bm25_engine, semantic_engine)`` in that order.
    """
    from search.bm25_search import BM25Search
    from search.semantic_search import SemanticSearch
    from preprocessing.mystem_tokenizer import MystemTokenizer

    # Lexical engine: the tokenizer's bound ``tokenize`` method is passed as
    # the third positional argument.
    tokenizer = MystemTokenizer()
    lexical_engine = BM25Search(
        db, config['index_folders']['bm25'], tokenizer.tokenize
    )

    # Semantic engine: keyword arguments are collected first, in the same
    # order the configuration values are consumed.
    semantic_kwargs = {
        'model': config['semantic_search']['model'],
        'embeddings_file': f"{config['index_folders']['semantic']}/embeddings.npy",
        'prefix': config['semantic_search']['query_prefix'],
    }
    semantic_engine = SemanticSearch(db, **semantic_kwargs)

    return lexical_engine, semantic_engine
def search_memes(query: str, search_type: str, num_results: int):
    """
    Search for memes using the selected search method.

    Relies on the module-level ``bm25_search`` and ``semantic_search``
    engines, which ``main()`` assigns before the interface is launched.

    Args:
        query (str): The search query.
        search_type (str): ``'BM25'`` selects the BM25 engine; any other
            value (the UI offers ``'Семантический'``) selects the semantic
            engine.
        num_results (int): The number of results to return. Integral floats
            such as ``3.0`` are accepted and converted to ``int``.

    Returns:
        tuple: ``(items, message)`` where ``items`` is a list of
            ``(image_url, text)`` pairs and ``message`` is the formatted
            search time.
    """
    # A slider widget may deliver the count as a float; make sure the engines
    # always receive a real integer.
    limit = int(num_results)

    engine = bm25_search if search_type == "BM25" else semantic_search
    response = engine.search(query, limit)

    items = [(hit['image_url'], hit['text']) for hit in response['results']]
    return items, f"Время поиска: {response['search_time']:.4f} секунд"
def _build_interface():
    """Assemble the Gradio Blocks UI and wire the search button to ``search_memes``."""
    with gr.Blocks() as demo:
        gr.Markdown("# Поиск мемов")
        gr.Markdown(
            "Добро пожаловать в приложение для поиска мемов! Введите запрос, выберите тип поиска и количество результатов."
        )

        # NOTE(review): the three input widgets are assumed to share one row;
        # confirm against the intended layout.
        with gr.Row():
            query = gr.Textbox(label="Запрос")
            search_type = gr.Radio(
                ["BM25", "Семантический"],
                label="Тип поиска",
                value="BM25"
            )
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=1,
                label="Количество результатов"
            )

        search_button = gr.Button("Найти")
        output_gallery = gr.Gallery(
            label="Результаты",
            show_label=False,
            columns=3,
            height=400
        )
        output_time = gr.Markdown()

        search_button.click(
            fn=search_memes,
            inputs=[query, search_type, num_results],
            outputs=[output_gallery, output_time]
        )
    return demo


def main():
    """
    Load the configuration, initialize the search engines and launch the UI.

    Reads ``config.yaml`` from the current working directory, opens a
    database session from ``database.url``, stores the two search engines in
    the module-level ``bm25_search`` / ``semantic_search`` names used by
    ``search_memes``, and starts the Gradio app.

    The database session and engine are released when the app stops or when
    start-up fails.
    """
    global bm25_search, semantic_search

    # Load configuration as a plain Python container.
    config = OmegaConf.to_container(OmegaConf.load('config.yaml'))

    # Initialize database session.
    engine = create_engine(config['database']['url'])
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = SessionLocal()

    try:
        # Initialize search engines before any UI callback can fire.
        bm25_search, semantic_search = initialize_search_engines(db, config)

        demo = _build_interface()
        demo.launch()
    finally:
        # Previously the session was never closed; release it and the
        # engine's connection pool on every exit path.
        db.close()
        engine.dispose()
if __name__ == "__main__":
    # Put the directory one level above this file's folder at the front of
    # the import path, then start the application.
    project_root = Path(__file__).resolve().parents[1]
    sys.path[:0] = [str(project_root)]
    main()