File size: 2,963 Bytes
7e1f5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import sys
from pathlib import Path

import fire
from omegaconf import OmegaConf
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


def initialize_bm25_search(db, config):
    """
    Construct a BM25 search engine for the given database session.

    The search modules are imported lazily, inside the function, so they are
    only loaded when this engine type is actually requested.

    Args:
        db: Database session, passed straight through to ``BM25Search``.
        config: Configuration mapping; must contain
            ``config['index_folders']['bm25']`` (the BM25 index location).

    Returns:
        BM25Search: Initialized BM25 search engine, using
        ``MystemTokenizer().tokenize`` as its tokenizer callable.
    """
    from search.bm25_search import BM25Search
    from preprocessing.mystem_tokenizer import MystemTokenizer

    # Tokenizer is built before the config lookup, matching the original order.
    tokenize_fn = MystemTokenizer().tokenize
    bm25_index_dir = config['index_folders']['bm25']
    return BM25Search(db, bm25_index_dir, tokenize_fn)


def initialize_semantic_search(db, config):
    """
    Construct an embedding-based semantic search engine for the given session.

    The search module is imported lazily, inside the function, so it is only
    loaded when this engine type is actually requested.

    Args:
        db: Database session, passed straight through to ``SemanticSearch``.
        config: Configuration mapping; must contain
            ``config['semantic_search']['model']``,
            ``config['semantic_search']['query_prefix']`` and
            ``config['index_folders']['semantic']``.

    Returns:
        SemanticSearch: Initialized semantic search engine whose embeddings
        are read from ``<index_folders.semantic>/embeddings.npy``.
    """
    from search.semantic_search import SemanticSearch

    semantic_settings = config['semantic_search']
    model_name = semantic_settings['model']
    embeddings_path = f"{config['index_folders']['semantic']}/embeddings.npy"
    query_prefix = semantic_settings['query_prefix']

    return SemanticSearch(
        db,
        model=model_name,
        embeddings_file=embeddings_path,
        prefix=query_prefix,
    )


def search_memes(query: str, search_type: str = 'bm25', num: int = 1,
                 config_path: str = 'config.yaml'):
    """
    Search for memes using the specified search method and print the hits.

    Args:
        query (str): The search query. Non-string values (Fire turns CLI
            tokens such as ``42`` or ``True`` into Python literals) are
            converted with ``str``. Blank queries are rejected.
        search_type (str): The type of search to perform. Either 'bm25' or
            'semantic'. Defaults to 'bm25'.
        num (int): The number of results to return. Must be a positive
            integer. Defaults to 1.
        config_path (str): Path to the YAML configuration file. A relative
            path is resolved against the current working directory.
            Defaults to 'config.yaml'.

    Returns:
        None: Result texts and the search time are printed to stdout;
        validation errors are printed to stderr.
    """
    # Fire parses literal-looking CLI arguments (e.g. ``0``) into non-str
    # values; normalize so a numeric query is not mistaken for a missing one.
    if query is not None and not isinstance(query, str):
        query = str(query)
    if not query or not query.strip():
        print("Error: Query is required.", file=sys.stderr)
        return
    if search_type not in ('bm25', 'semantic'):
        print("Error: Invalid search type. Use 'bm25' or 'semantic'.",
              file=sys.stderr)
        return
    # bool is a subclass of int, so it must be excluded explicitly; floats and
    # strings would otherwise fail later (or crash on the ``<`` comparison).
    if not isinstance(num, int) or isinstance(num, bool):
        print("Error: Number of results must be an integer.", file=sys.stderr)
        return
    if num < 1:
        print("Error: Number of results must be at least 1.", file=sys.stderr)
        return

    # resolve=True expands ${...} interpolations so downstream code receives
    # concrete values instead of raw interpolation strings.
    config = OmegaConf.to_container(OmegaConf.load(config_path), resolve=True)

    # The engine is created per call, so it must also be disposed per call to
    # release its pooled connections.
    engine = create_engine(config['database']['url'])
    try:
        session_factory = sessionmaker(autocommit=False, autoflush=False,
                                       bind=engine)
        db = session_factory()
        try:
            # search_type was validated above, so exactly one branch applies.
            build_search_engine = (initialize_bm25_search
                                   if search_type == 'bm25'
                                   else initialize_semantic_search)
            search_engine = build_search_engine(db, config)

            results = search_engine.search(query, num)

            for result in results['results']:
                print(result['text'])
            print(f"\nSearch time: {results['search_time']:.4f} seconds")
        finally:
            db.close()
    finally:
        engine.dispose()


if __name__ == "__main__":
    # When executed as a script, prepend the parent of this file's directory
    # to sys.path — presumably so the lazily imported `search` and
    # `preprocessing` packages resolve — then hand the CLI over to Fire.
    sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
    fire.Fire(search_memes)