| | import hashlib |
| | import pickle |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import torch |
| | from rank_bm25 import BM25Okapi |
| | from sentence_transformers import SentenceTransformer |
| | import warnings |
| | warnings.filterwarnings('ignore') |
| |
|
| | from get_documents import load_and_process_data |
| | from parse_documents import process_documents |
| | from lemmatizer import RussianLemmatizer |
| |
|
| |
|
| | def normalize_array(arr): |
| | min_val = np.min(arr) |
| | max_val = np.max(arr) |
| | return (arr - min_val) / (max_val - min_val) |
| |
|
| |
|
| | class Retrieval: |
| | """ |
| | Структура хранения данных: |
| | ============================ |
| | |
| | 1. ДАТАФРЕЙМ ПАРАГРАФОВ (self.paragraphs_df): |
| | ┌──────────────────────┬─────────────────────────────────┐ |
| | │ Колонка │ Описание │ |
| | ├──────────────────────┼─────────────────────────────────┤ |
| | │ paragraph_id │ Уникальный ID параграфа │ |
| | │ summary │ Название документа/раздела │ |
| | │ start_year │ Год начала периода │ |
| | │ end_year │ Год окончания периода │ |
| | │ text │ Текст │ |
| | │ document_id │ Ссылка на исходный документ │ |
| | └──────────────────────┴─────────────────────────────────┘ |
| | |
| | 2. ДАТАФРЕЙМ ЧАНКОВ (self.chunks_df): |
| | ┌──────────────────────┬─────────────────────────────────┐ |
| | │ Колонка │ Описание │ |
| | ├──────────────────────┼─────────────────────────────────┤ |
| | │ chunk_id │ Уникальный ID чанка │ |
| | │ paragraph_id │ Foreign key на параграф │ |
| | │ text │ Исходный текст чанка │ |
| | │ lemmatized_text │ Лемматизированный текст │ |
| | │ (embeddings) │ (будет добавлено в будущем) │ |
| | └──────────────────────┴─────────────────────────────────┘ |
| | |
| | 3. ОБЪЕДИНЁННЫЙ ДАТАФРЕЙМ (get_merged_data()): |
| | Комбинирует оба датафрейма через JOIN по paragraph_id. |
| | Содержит все колонки обоих датафреймов. |
| | Используется для поиска и фильтрации. |
| | """ |
| | |
| | def __init__(self, use_gpu: bool = False, use_cache: bool = True): |
| | print("Инициализация RAG системы...") |
| | self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu" |
| | self.use_cache = use_cache |
| | |
| | |
| | self.cache_dir = Path('.cache') |
| | if self.use_cache: |
| | self.cache_dir.mkdir(exist_ok=True) |
| | |
| | |
| | print(" Инициализация лемматизатора...") |
| | self.lemmatizer = RussianLemmatizer() |
| | |
| | |
| | print("1. Загрузка данных из JSON...") |
| | self.documents, self.docs_names = load_and_process_data() |
| | |
| | print(f" Загружено {len(self.documents)} сообщений") |
| | |
| | |
| | self.paragraphs_df, self.chunks_df = process_documents(self.documents) |
| | |
| | |
| | print("2. Лемматизация текстов (с кэшированием)...") |
| | self.chunks_df['lemmatized_text'] = self._lemmatize_with_cache(self.chunks_df['text']) |
| | |
| | |
| | |
| | self.embedder = SentenceTransformer('cointegrated/rubert-tiny2', cache_folder="/tmp") |
| | |
| | self.embeddings_of_summary = self.embedder.encode(self.paragraphs_df['summary'], convert_to_tensor=True) |
| |
|
| | print("RAG система готова к использованию") |
| |
|
| |
|
| | |
| |
|
| | def _load_cache(self) -> dict: |
| | """ |
| | Загружает кэш лемматизации из файловой системы. |
| | |
| | Returns: |
| | dict: {text_hash -> lemmatized_tokens} |
| | """ |
| | cache_file = self.cache_dir / 'lemmatization_cache.pkl' |
| | |
| | if cache_file.exists(): |
| | try: |
| | with open(cache_file, 'rb') as f: |
| | cache = pickle.load(f) |
| | print(f" ✓ Кэш загружен ({len(cache)} записей)") |
| | return cache |
| | except Exception as e: |
| | print(f" ⚠ Ошибка при загрузке кэша: {e}") |
| | return {} |
| | return {} |
| | |
| | def _lemmatize_with_cache(self, texts: list[str]) -> list: |
| | """ |
| | Лемматизирует тексты с использованием кэша. |
| | Проверяет хэши текстов - если хэш совпадает с кэшированным, |
| | использует кэшированный результат. Иначе перелемматизирует. |
| | |
| | Args: |
| | texts: Список текстов для лемматизации |
| | |
| | Returns: |
| | list: Лемматизированные тексты |
| | """ |
| | if not self.use_cache: |
| | |
| | return [self.lemmatizer.tokenize_text(text) for text in texts] |
| | |
| | |
| | cache = self._load_cache() |
| | text_hashes = {} |
| | results = [] |
| | needs_save = False |
| | |
| | for text in texts: |
| | text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest() |
| | text_hashes[text] = text_hash |
| | |
| | if text_hash in cache: |
| | |
| | results.append(cache[text_hash]) |
| | else: |
| | |
| | lemmatized = self.lemmatizer.tokenize_text(text) |
| | results.append(lemmatized) |
| | cache[text_hash] = lemmatized |
| | needs_save = True |
| | |
| | |
| | if needs_save: |
| | with open(self.cache_dir / 'lemmatization_cache.pkl', 'wb') as f: |
| | pickle.dump(cache, f) |
| | print(f" ✓ Кэш сохранён ({len(cache)} записей)") |
| | |
| | return results |
| | |
| | def semantic_search(self, query: str) -> torch.Tensor: |
| | |
| | query_embedding = torch.tensor(self.embedder.encode_query(query)) |
| | semantic_scores = torch.nn.functional.cosine_similarity(self.embeddings_of_summary, query_embedding, eps=1e-8).cpu() |
| | return semantic_scores |
| | |
| | def bm25_search(self, query: str) -> np.ndarray: |
| | """BM25 поиск, используя лемматизированные чанки. |
| | |
| | Args: |
| | query: Текст запроса |
| | |
| | Returns: |
| | np.ndarray: Скоры для каждого абзаца (не предложения!) |
| | """ |
| | bm25 = BM25Okapi(self.chunks_df['lemmatized_text']) |
| | tokenized_query = self.lemmatizer.tokenize_text(query) |
| | sentences_scores = bm25.get_scores(tokenized_query) |
| | df = self.chunks_df['paragraph_id'].to_frame().copy() |
| | df['score'] = sentences_scores |
| | paragraph_scores = df.groupby('paragraph_id')['score'].max().reindex(self.paragraphs_df['paragraph_id']).fillna(0) |
| | return paragraph_scores |
| | |
| | def search(self, query: str) -> None: |
| | bm25_scores = self.bm25_search(query) |
| | semantic_scores = self.semantic_search(query).numpy() |
| | bm25_scores = normalize_array(bm25_scores) |
| | return semantic_scores + 1.0 * bm25_scores |