import hashlib import pickle from pathlib import Path import numpy as np import torch from rank_bm25 import BM25Okapi from sentence_transformers import SentenceTransformer import warnings warnings.filterwarnings('ignore') from get_documents import load_and_process_data from parse_documents import process_documents from lemmatizer import RussianLemmatizer def normalize_array(arr): min_val = np.min(arr) max_val = np.max(arr) return (arr - min_val) / (max_val - min_val) class Retrieval: """ Структура хранения данных: ============================ 1. ДАТАФРЕЙМ ПАРАГРАФОВ (self.paragraphs_df): ┌──────────────────────┬─────────────────────────────────┐ │ Колонка │ Описание │ ├──────────────────────┼─────────────────────────────────┤ │ paragraph_id │ Уникальный ID параграфа │ │ summary │ Название документа/раздела │ │ start_year │ Год начала периода │ │ end_year │ Год окончания периода │ │ text │ Текст │ │ document_id │ Ссылка на исходный документ │ └──────────────────────┴─────────────────────────────────┘ 2. ДАТАФРЕЙМ ЧАНКОВ (self.chunks_df): ┌──────────────────────┬─────────────────────────────────┐ │ Колонка │ Описание │ ├──────────────────────┼─────────────────────────────────┤ │ chunk_id │ Уникальный ID чанка │ │ paragraph_id │ Foreign key на параграф │ │ text │ Исходный текст чанка │ │ lemmatized_text │ Лемматизированный текст │ │ (embeddings) │ (будет добавлено в будущем) │ └──────────────────────┴─────────────────────────────────┘ 3. ОБЪЕДИНЁННЫЙ ДАТАФРЕЙМ (get_merged_data()): Комбинирует оба датафрейма через JOIN по paragraph_id. Содержит все колонки обоих датафреймов. Используется для поиска и фильтрации. """ def __init__(self, use_gpu: bool = False, use_cache: bool = True): print("Инициализация RAG системы...") self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu" self.use_cache = use_cache # Путь к кэшу self.cache_dir = Path('.cache') if self.use_cache: self.cache_dir.mkdir(exist_ok=True) # Инициализация лемматизатора для русского языка print(" Инициализация лемматизатора...") self.lemmatizer = RussianLemmatizer() # Загружаем и обрабатываем данные print("1. Загрузка данных из JSON...") self.documents, self.docs_names = load_and_process_data() # self.documents after this phase: list of {'text': str, 'date': str} print(f" Загружено {len(self.documents)} сообщений") # Парсим даты из документов и создаем датафреймы self.paragraphs_df, self.chunks_df = process_documents(self.documents) # Добавляем лемматизированный текст в датафрейм чанков с кэшем print("2. Лемматизация текстов (с кэшированием)...") self.chunks_df['lemmatized_text'] = self._lemmatize_with_cache(self.chunks_df['text']) # Инициализируем CrossEncoder # self.cross_encoder = CrossEncoder('DiTy/cross-encoder-russian-msmarco') self.embedder = SentenceTransformer('cointegrated/rubert-tiny2', cache_folder="/tmp") # TODO: кэшировать эмбеддинги! self.embeddings_of_summary = self.embedder.encode(self.paragraphs_df['summary'], convert_to_tensor=True) print("RAG система готова к использованию") # ============ Методы кэширования лемматизации ============ def _load_cache(self) -> dict: """ Загружает кэш лемматизации из файловой системы. Returns: dict: {text_hash -> lemmatized_tokens} """ cache_file = self.cache_dir / 'lemmatization_cache.pkl' if cache_file.exists(): try: with open(cache_file, 'rb') as f: cache = pickle.load(f) print(f" ✓ Кэш загружен ({len(cache)} записей)") return cache except Exception as e: print(f" ⚠ Ошибка при загрузке кэша: {e}") return {} return {} def _lemmatize_with_cache(self, texts: list[str]) -> list: """ Лемматизирует тексты с использованием кэша. Проверяет хэши текстов - если хэш совпадает с кэшированным, использует кэшированный результат. Иначе перелемматизирует. Args: texts: Список текстов для лемматизации Returns: list: Лемматизированные тексты """ if not self.use_cache: # Если кэш отключен, просто лемматизировать return [self.lemmatizer.tokenize_text(text) for text in texts] # Загружаем существующий кэш cache = self._load_cache() text_hashes = {} results = [] needs_save = False for text in texts: text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest() text_hashes[text] = text_hash if text_hash in cache: # Используем кэшированный результат results.append(cache[text_hash]) else: # Лемматизируем и добавляем в кэш lemmatized = self.lemmatizer.tokenize_text(text) results.append(lemmatized) cache[text_hash] = lemmatized needs_save = True # Сохраняем кэш если были новые записи if needs_save: with open(self.cache_dir / 'lemmatization_cache.pkl', 'wb') as f: pickle.dump(cache, f) print(f" ✓ Кэш сохранён ({len(cache)} записей)") return results def semantic_search(self, query: str) -> torch.Tensor: # 1. Семантический поиск query_embedding = torch.tensor(self.embedder.encode_query(query)) semantic_scores = torch.nn.functional.cosine_similarity(self.embeddings_of_summary, query_embedding, eps=1e-8).cpu() return semantic_scores def bm25_search(self, query: str) -> np.ndarray: """BM25 поиск, используя лемматизированные чанки. Args: query: Текст запроса Returns: np.ndarray: Скоры для каждого абзаца (не предложения!) """ bm25 = BM25Okapi(self.chunks_df['lemmatized_text']) tokenized_query = self.lemmatizer.tokenize_text(query) sentences_scores = bm25.get_scores(tokenized_query) df = self.chunks_df['paragraph_id'].to_frame().copy() df['score'] = sentences_scores paragraph_scores = df.groupby('paragraph_id')['score'].max().reindex(self.paragraphs_df['paragraph_id']).fillna(0) return paragraph_scores def search(self, query: str) -> None: bm25_scores = self.bm25_search(query) semantic_scores = self.semantic_search(query).numpy() bm25_scores = normalize_array(bm25_scores) return semantic_scores + 1.0 * bm25_scores