File size: 5,867 Bytes
c7e0ef5
 
 
 
7a0766b
c7e0ef5
 
 
 
dfee524
c7e0ef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c92162
c7e0ef5
8c92162
 
c7e0ef5
 
 
 
 
 
 
 
 
 
7a0766b
c7e0ef5
 
7a0766b
 
 
9541a9b
c7e0ef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfee524
 
 
 
c7e0ef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a0766b
 
c7e0ef5
 
 
 
7a0766b
c7e0ef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c92162
 
 
 
c7e0ef5
 
 
 
 
8c92162
c7e0ef5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import re
from dataclasses import dataclass
from typing import List, Optional, Set

from dotenv import load_dotenv
import nltk
from nltk.corpus import stopwords


class RegexQueryPreprocessor:
    """Preprocesses Russian search queries by stripping common question
    phrasings (e.g. "в какой презентации ...") and, optionally, removing
    stopwords so only the content-bearing terms remain.
    """

    @dataclass
    class QueryPattern:
        """A cleanup rule: a regex to match and the text to replace it with."""

        pattern: str  # raw regex source; compiled with re.IGNORECASE in __init__
        # NOTE(review): replacement/description are currently unused — matches
        # are always stripped (substituted with ""). Kept for forward compat.
        replacement: str = ""
        description: str = ""

    def __init__(self, remove_stopwords: bool = True) -> None:
        """Build the stopword set and compile all cleanup patterns.

        Args:
            remove_stopwords: When True, clean_query() also drops Russian
                stopwords from the cleaned result.
        """
        # Download required NLTK data if not already present.
        try:
            nltk.data.find("corpora/stopwords")
        except LookupError:
            nltk.download("stopwords")

        self._remove_stopwords = remove_stopwords
        self._stopwords: Set[str] = set(stopwords.words("russian"))

        # Custom Russian stopwords not covered by the NLTK list.
        # fmt: off
        self._custom_stopwords = {
            "разные", "какие", "когда",
            "который", "которой", "которая", "которые", "был", "была", "были",
            "также", "именно", "либо", "или", "где", "как", "какой", "какая",
            "быть", "есть", "это", "эта", "эти", "для", "при", "про"
        }
        self._stopwords.update(self._custom_stopwords)
        # fmt: on

        # Query-cleanup patterns, grouped by the kind of phrasing they strip.
        self._patterns = {
            "presentation_patterns": [
                self.QueryPattern(
                    r"^в какой презентации (?:был[аи]?|рассматривали?|говорили?|обсуждали?|показывали?|рассказывали?|перечисляли?) ",
                ),
                self.QueryPattern(
                    r"^в презентации (?:был[аио]?|рассматривал?|говорил?|обсуждал?|показывал?|сравнивал?)(?:и?|ась|ось|а) ",
                ),
                self.QueryPattern(
                    r"^презентаци(?:я|и) (?:про|с|в которой|где|со?) ",
                ),
            ],
            "slide_patterns": [
                self.QueryPattern(
                    r"(?:на )?слайд(?:е|ы)? (?:с|был[аи]?|про|где) ",
                ),
            ],
            "question_patterns": [
                self.QueryPattern(
                    r"^где (?:был[аи]?|обсуждали?|говорили про) ",
                ),
                self.QueryPattern(
                    r"^о чем (?:рассказывал[аи]?|говорил[аи]?) ",
                ),
            ],
        }

        # Pre-compile every pattern once (case-insensitive) for reuse per query.
        self._compiled_patterns = {
            category: [re.compile(p.pattern, re.IGNORECASE) for p in patterns]
            for category, patterns in self._patterns.items()
        }

    @property
    def id(self) -> str:
        """Identifier of this preprocessor (its class name)."""
        return self.__class__.__name__

    def remove_stopwords_from_text(self, text: str) -> str:
        """Remove stopwords from whitespace-separated text (case-insensitive)."""
        return " ".join(
            token for token in text.split() if token.lower() not in self._stopwords
        )

    def clean_query(self, query: str) -> str:
        """
        Remove common patterns, stopwords, and standardize the query.

        Args:
            query: Input search query

        Returns:
            Cleaned query with removed patterns and standardized format
        """
        # Strip basic punctuation. Case is preserved: patterns are compiled
        # with IGNORECASE, so lowercasing the query is unnecessary.
        query = query.strip()
        query = re.sub(r"[?,!.]", "", query)

        # Apply every compiled pattern from every category.
        for patterns in self._compiled_patterns.values():
            for pattern in patterns:
                query = pattern.sub("", query)

        # Collapse runs of whitespace left behind by the substitutions.
        query = re.sub(r"\s+", " ", query).strip()

        # Remove stopwords if enabled.
        if self._remove_stopwords:
            query = self.remove_stopwords_from_text(query)

        return query

    def __call__(self, query: str, *args, **kwargs) -> str:
        """Alias for clean_query() so the instance is directly callable."""
        return self.clean_query(query, *args, **kwargs)


if __name__ == "__main__":
    import fire

    load_dotenv()

    class CLI:
        """Command line interface for RegexQueryPreprocessor."""

        def __init__(self):
            self.preprocessor = RegexQueryPreprocessor()

        def clean(self, *queries: str, remove_stopwords: bool = True) -> None:
            """
            Clean queries and show original->cleaned pairs.

            Args:
                queries: One or more query strings (varargs)
                remove_stopwords: Whether to remove stopwords
            """
            # Toggle stopword removal on the shared preprocessor instance.
            self.preprocessor._remove_stopwords = remove_stopwords

            # Process each query and print the before/after pair
            # (cleaned text highlighted in blue via ANSI escapes).
            print("Original -> Cleaned")
            print("-" * 50)
            for query in queries:
                cleaned = self.preprocessor.clean_query(query)
                print(f"{query} -> \033[94m{cleaned} \033[0m")

        def clean_gsheets(
            self,
            sheet_id: Optional[str] = None,
            gid: Optional[str] = None,
            remove_stopwords: bool = True,
        ) -> None:
            """
            Load the 'question' column from a spreadsheet and clean each entry.

            Args:
                sheet_id: Spreadsheet identifier passed to load_spreadsheet
                gid: Sheet tab identifier passed to load_spreadsheet
                remove_stopwords: Whether to remove stopwords
            """
            # Imported lazily: only needed for this subcommand.
            from src.config.spreadsheets import load_spreadsheet

            df = load_spreadsheet(sheet_id, gid)
            questions = df["question"]
            return self.clean(*questions, remove_stopwords=remove_stopwords)

    # Start CLI
    fire.Fire(CLI)