File size: 9,520 Bytes
09bc630
 
2a48bd3
5759868
09bc630
5759868
 
873ada4
5759868
 
 
8109cc7
 
2aaeb1b
5759868
 
873ada4
 
 
 
5759868
0ffefbc
5759868
873ada4
09bc630
 
 
 
 
b6d731b
09bc630
 
 
 
 
b6d731b
09bc630
 
 
 
 
b6d731b
09bc630
 
 
 
 
 
 
 
 
 
 
 
 
 
873ada4
5759868
 
09bc630
 
 
 
 
 
5759868
2aaeb1b
 
d061e47
2a48bd3
5759868
 
 
 
 
 
09bc630
4dd2f1d
09bc630
 
 
4dd2f1d
09bc630
 
873ada4
96830f5
0aa6d2c
4dd2f1d
873ada4
8109cc7
5759868
09bc630
 
4dd2f1d
09bc630
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dd2f1d
09bc630
 
 
 
 
 
 
 
 
 
 
 
 
 
6a7ab41
4dd2f1d
 
09bc630
 
 
873ada4
5759868
873ada4
 
5759868
 
873ada4
09bc630
 
 
 
 
 
873ada4
09bc630
6a7ab41
2aaeb1b
873ada4
 
 
 
 
0ffefbc
3302068
873ada4
 
 
6a7ab41
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import hashlib
import pickle
from pathlib import Path

import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
import warnings
warnings.filterwarnings('ignore')

from get_documents import load_and_process_data
from parse_documents import process_documents
from lemmatizer import RussianLemmatizer


def normalize_array(arr):
    min_val = np.min(arr)
    max_val = np.max(arr)
    return (arr - min_val) / (max_val - min_val)


class Retrieval:
    """ 
    Структура хранения данных:
    ============================
    
    1. ДАТАФРЕЙМ ПАРАГРАФОВ (self.paragraphs_df):
       ┌──────────────────────┬─────────────────────────────────┐
       │ Колонка              │ Описание                        │
       ├──────────────────────┼─────────────────────────────────┤
       │ paragraph_id         │ Уникальный ID параграфа         │
       │ summary              │ Название документа/раздела      │
       │ start_year           │ Год начала периода              │
       │ end_year             │ Год окончания периода           │
       │ text                 │ Текст                           │
       │ document_id          │ Ссылка на исходный документ     │
       └──────────────────────┴─────────────────────────────────┘
    
    2. ДАТАФРЕЙМ ЧАНКОВ (self.chunks_df):
       ┌──────────────────────┬─────────────────────────────────┐
       │ Колонка              │ Описание                        │
       ├──────────────────────┼─────────────────────────────────┤
       │ chunk_id             │ Уникальный ID чанка             │
       │ paragraph_id         │ Foreign key на параграф         │
       │ text                 │ Исходный текст чанка            │
       │ lemmatized_text      │ Лемматизированный текст         │
       │ (embeddings)         │ (будет добавлено в будущем)     │
       └──────────────────────┴─────────────────────────────────┘
    
    3. ОБЪЕДИНЁННЫЙ ДАТАФРЕЙМ (get_merged_data()):
       Комбинирует оба датафрейма через JOIN по paragraph_id.
       Содержит все колонки обоих датафреймов.
       Используется для поиска и фильтрации.
    """
    
    def __init__(self, use_gpu: bool = False, use_cache: bool = True):
        print("Инициализация RAG системы...")
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        self.use_cache = use_cache
        
        # Путь к кэшу
        self.cache_dir = Path('.cache')
        if self.use_cache:
            self.cache_dir.mkdir(exist_ok=True)
        
        # Инициализация лемматизатора для русского языка
        print("   Инициализация лемматизатора...")
        self.lemmatizer = RussianLemmatizer()
        
        # Загружаем и обрабатываем данные
        print("1. Загрузка данных из JSON...")
        self.documents, self.docs_names = load_and_process_data()
        # self.documents after this phase: list of {'text': str, 'date': str}
        print(f"   Загружено {len(self.documents)} сообщений")
        
        # Парсим даты из документов и создаем датафреймы
        self.paragraphs_df, self.chunks_df = process_documents(self.documents)
        
        # Добавляем лемматизированный текст в датафрейм чанков с кэшем
        print("2. Лемматизация текстов (с кэшированием)...")
        self.chunks_df['lemmatized_text'] = self._lemmatize_with_cache(self.chunks_df['text'])
        
        # Инициализируем CrossEncoder
        # self.cross_encoder = CrossEncoder('DiTy/cross-encoder-russian-msmarco')
        self.embedder = SentenceTransformer('cointegrated/rubert-tiny2', cache_folder="/tmp")
        # TODO: кэшировать эмбеддинги!
        self.embeddings_of_summary = self.embedder.encode(self.paragraphs_df['summary'], convert_to_tensor=True)

        print("RAG система готова к использованию")


    # ============ Методы кэширования лемматизации ============

    def _load_cache(self) -> dict:
        """
        Загружает кэш лемматизации из файловой системы.
        
        Returns:
            dict: {text_hash -> lemmatized_tokens}
        """
        cache_file = self.cache_dir / 'lemmatization_cache.pkl'
        
        if cache_file.exists():
            try:
                with open(cache_file, 'rb') as f:
                    cache = pickle.load(f)
                print(f"   ✓ Кэш загружен ({len(cache)} записей)")
                return cache
            except Exception as e:
                print(f"   ⚠ Ошибка при загрузке кэша: {e}")
                return {}
        return {}
    
    def _lemmatize_with_cache(self, texts: list[str]) -> list:
        """
        Лемматизирует тексты с использованием кэша.
        Проверяет хэши текстов - если хэш совпадает с кэшированным,
        использует кэшированный результат. Иначе перелемматизирует.
        
        Args:
            texts: Список текстов для лемматизации
            
        Returns:
            list: Лемматизированные тексты
        """
        if not self.use_cache:
            # Если кэш отключен, просто лемматизировать
            return [self.lemmatizer.tokenize_text(text) for text in texts]
        
        # Загружаем существующий кэш
        cache = self._load_cache()
        text_hashes = {}
        results = []
        needs_save = False
        
        for text in texts:
            text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
            text_hashes[text] = text_hash
            
            if text_hash in cache:
                # Используем кэшированный результат
                results.append(cache[text_hash])
            else:
                # Лемматизируем и добавляем в кэш
                lemmatized = self.lemmatizer.tokenize_text(text)
                results.append(lemmatized)
                cache[text_hash] = lemmatized
                needs_save = True
        
        # Сохраняем кэш если были новые записи
        if needs_save:
            with open(self.cache_dir / 'lemmatization_cache.pkl', 'wb') as f:
                pickle.dump(cache, f)
            print(f"   ✓ Кэш сохранён ({len(cache)} записей)")
        
        return results
    
    def semantic_search(self, query: str) -> torch.Tensor:    
        # 1. Семантический поиск
        query_embedding = torch.tensor(self.embedder.encode_query(query))
        semantic_scores = torch.nn.functional.cosine_similarity(self.embeddings_of_summary, query_embedding, eps=1e-8).cpu()
        return semantic_scores
    
    def bm25_search(self, query: str) -> np.ndarray:
        """BM25 поиск, используя лемматизированные чанки.
        
        Args:
            query: Текст запроса
            
        Returns:
            np.ndarray: Скоры для каждого абзаца (не предложения!)
        """
        bm25 = BM25Okapi(self.chunks_df['lemmatized_text'])
        tokenized_query = self.lemmatizer.tokenize_text(query)
        sentences_scores = bm25.get_scores(tokenized_query)
        df = self.chunks_df['paragraph_id'].to_frame().copy()
        df['score'] = sentences_scores
        paragraph_scores = df.groupby('paragraph_id')['score'].max().reindex(self.paragraphs_df['paragraph_id']).fillna(0)
        return paragraph_scores
    
    def search(self, query: str) -> None:
        bm25_scores = self.bm25_search(query)        
        semantic_scores = self.semantic_search(query).numpy()
        bm25_scores = normalize_array(bm25_scores)
        return semantic_scores + 1.0 * bm25_scores