File size: 3,760 Bytes
7e1f5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import sys
from pathlib import Path

import gradio as gr
from omegaconf import OmegaConf
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


def initialize_search_engines(db, config):
    """
    Build the BM25 and semantic search engines on top of one database session.

    The project-local classes are imported at call time rather than at module
    import time, so they are only resolved once this function actually runs.

    Args:
        db: Database session passed to both engines.
        config: Configuration mapping. The keys read are
            ``index_folders.bm25``, ``index_folders.semantic``,
            ``semantic_search.model`` and ``semantic_search.query_prefix``.

    Returns:
        tuple: ``(bm25_search, semantic_search)`` - the two constructed engines,
        in that order.
    """
    from search.bm25_search import BM25Search
    from search.semantic_search import SemanticSearch
    from preprocessing.mystem_tokenizer import MystemTokenizer

    # BM25 engine: receives the tokenizer's bound ``tokenize`` method as its
    # tokenization callable.
    tokenizer = MystemTokenizer()
    bm25_index_dir = config['index_folders']['bm25']
    lexical_engine = BM25Search(db, bm25_index_dir, tokenizer.tokenize)

    # Semantic engine: model name, embeddings file path and query prefix all
    # come from the configuration.
    semantic_cfg = config['semantic_search']
    model_name = semantic_cfg['model']
    embeddings_path = f"{config['index_folders']['semantic']}/embeddings.npy"
    query_prefix = semantic_cfg['query_prefix']
    dense_engine = SemanticSearch(
        db,
        model=model_name,
        embeddings_file=embeddings_path,
        prefix=query_prefix,
    )

    return lexical_engine, dense_engine


def search_memes(query: str, search_type: str, num_results: int):
    """
    Run a meme search with the selected engine and shape the result for display.

    Relies on the module-level ``bm25_search`` and ``semantic_search`` objects
    having been assigned before the first call.

    Args:
        query (str): The text to search for.
        search_type (str): ``'BM25'`` selects the BM25 engine; any other value
            (the UI offers ``'Семантический'``) selects the semantic engine.
        num_results (int): How many results to request from the engine.

    Returns:
        tuple: ``(items, message)`` where ``items`` is a list of
        ``(image_url, text)`` pairs and ``message`` is a string reporting the
        search time with four decimal places.
    """
    # Only the chosen engine's global is looked up, exactly as with if/else.
    engine = bm25_search if search_type == "BM25" else semantic_search
    response = engine.search(query, num_results)

    gallery_items = [(hit['image_url'], hit['text']) for hit in response['results']]
    elapsed = response['search_time']

    return gallery_items, f"Время поиска: {elapsed:.4f} секунд"


def main():
    """
    Load the configuration, build the search engines and serve the Gradio UI.

    Steps:
        1. Read ``config.yaml`` (relative to the current working directory)
           and convert it to plain Python containers.
        2. Open a SQLAlchemy session on ``config['database']['url']``.
        3. Publish the engines as the module globals ``bm25_search`` and
           ``semantic_search``, which ``search_memes`` reads.
        4. Build the Gradio interface and launch it.

    The database session is closed and the engine's connection pool disposed
    when this function exits, whether normally or through an exception.
    """
    global bm25_search, semantic_search

    # Load configuration
    config = OmegaConf.to_container(OmegaConf.load('config.yaml'))

    # Initialize database session
    engine = create_engine(config['database']['url'])
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = SessionLocal()

    try:
        # Initialize search engines
        bm25_search, semantic_search = initialize_search_engines(db, config)

        # Gradio interface
        with gr.Blocks() as demo:
            gr.Markdown("# Поиск мемов")
            gr.Markdown(
                "Добро пожаловать в приложение для поиска мемов! Введите запрос, выберите тип поиска и количество результатов."
            )

            with gr.Row():
                query = gr.Textbox(label="Запрос")
                search_type = gr.Radio(
                    ["BM25", "Семантический"],
                    label="Тип поиска",
                    value="BM25"
                )
                num_results = gr.Slider(
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=1,
                    label="Количество результатов"
                )

            search_button = gr.Button("Найти")

            output_gallery = gr.Gallery(
                label="Результаты",
                show_label=False,
                columns=3,
                height=400
            )
            output_time = gr.Markdown()

            # Component order in `inputs` matches search_memes' parameter
            # order; `outputs` matches its (items, message) return tuple.
            search_button.click(
                fn=search_memes,
                inputs=[query, search_type, num_results],
                outputs=[output_gallery, output_time]
            )

        # NOTE(review): when run as a script, launch() is expected to block
        # until the server stops, so the cleanup below runs at shutdown -
        # confirm against the installed Gradio version.
        demo.launch()
    finally:
        # Release database resources even if engine/UI setup failed.
        db.close()
        engine.dispose()


if __name__ == "__main__":
    # Put the directory one level above this file's folder at the front of
    # sys.path before calling main(). The `search.*` and `preprocessing.*`
    # imports inside initialize_search_engines are deferred to call time and
    # presumably resolve against that directory (project layout is not
    # visible from this file - confirm).
    project_root = Path(__file__).resolve().parents[1]
    sys.path.insert(0, str(project_root))
    main()