File size: 9,520 Bytes
09bc630 2a48bd3 5759868 09bc630 5759868 873ada4 5759868 8109cc7 2aaeb1b 5759868 873ada4 5759868 0ffefbc 5759868 873ada4 09bc630 b6d731b 09bc630 b6d731b 09bc630 b6d731b 09bc630 873ada4 5759868 09bc630 5759868 2aaeb1b d061e47 2a48bd3 5759868 09bc630 4dd2f1d 09bc630 4dd2f1d 09bc630 873ada4 96830f5 0aa6d2c 4dd2f1d 873ada4 8109cc7 5759868 09bc630 4dd2f1d 09bc630 4dd2f1d 09bc630 6a7ab41 4dd2f1d 09bc630 873ada4 5759868 873ada4 5759868 873ada4 09bc630 873ada4 09bc630 6a7ab41 2aaeb1b 873ada4 0ffefbc 3302068 873ada4 6a7ab41 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 | import hashlib
import pickle
from pathlib import Path
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
import warnings
warnings.filterwarnings('ignore')
from get_documents import load_and_process_data
from parse_documents import process_documents
from lemmatizer import RussianLemmatizer
def normalize_array(arr):
min_val = np.min(arr)
max_val = np.max(arr)
return (arr - min_val) / (max_val - min_val)
class Retrieval:
"""
Структура хранения данных:
============================
1. ДАТАФРЕЙМ ПАРАГРАФОВ (self.paragraphs_df):
┌──────────────────────┬─────────────────────────────────┐
│ Колонка │ Описание │
├──────────────────────┼─────────────────────────────────┤
│ paragraph_id │ Уникальный ID параграфа │
│ summary │ Название документа/раздела │
│ start_year │ Год начала периода │
│ end_year │ Год окончания периода │
│ text │ Текст │
│ document_id │ Ссылка на исходный документ │
└──────────────────────┴─────────────────────────────────┘
2. ДАТАФРЕЙМ ЧАНКОВ (self.chunks_df):
┌──────────────────────┬─────────────────────────────────┐
│ Колонка │ Описание │
├──────────────────────┼─────────────────────────────────┤
│ chunk_id │ Уникальный ID чанка │
│ paragraph_id │ Foreign key на параграф │
│ text │ Исходный текст чанка │
│ lemmatized_text │ Лемматизированный текст │
│ (embeddings) │ (будет добавлено в будущем) │
└──────────────────────┴─────────────────────────────────┘
3. ОБЪЕДИНЁННЫЙ ДАТАФРЕЙМ (get_merged_data()):
Комбинирует оба датафрейма через JOIN по paragraph_id.
Содержит все колонки обоих датафреймов.
Используется для поиска и фильтрации.
"""
def __init__(self, use_gpu: bool = False, use_cache: bool = True):
print("Инициализация RAG системы...")
self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
self.use_cache = use_cache
# Путь к кэшу
self.cache_dir = Path('.cache')
if self.use_cache:
self.cache_dir.mkdir(exist_ok=True)
# Инициализация лемматизатора для русского языка
print(" Инициализация лемматизатора...")
self.lemmatizer = RussianLemmatizer()
# Загружаем и обрабатываем данные
print("1. Загрузка данных из JSON...")
self.documents, self.docs_names = load_and_process_data()
# self.documents after this phase: list of {'text': str, 'date': str}
print(f" Загружено {len(self.documents)} сообщений")
# Парсим даты из документов и создаем датафреймы
self.paragraphs_df, self.chunks_df = process_documents(self.documents)
# Добавляем лемматизированный текст в датафрейм чанков с кэшем
print("2. Лемматизация текстов (с кэшированием)...")
self.chunks_df['lemmatized_text'] = self._lemmatize_with_cache(self.chunks_df['text'])
# Инициализируем CrossEncoder
# self.cross_encoder = CrossEncoder('DiTy/cross-encoder-russian-msmarco')
self.embedder = SentenceTransformer('cointegrated/rubert-tiny2', cache_folder="/tmp")
# TODO: кэшировать эмбеддинги!
self.embeddings_of_summary = self.embedder.encode(self.paragraphs_df['summary'], convert_to_tensor=True)
print("RAG система готова к использованию")
# ============ Методы кэширования лемматизации ============
def _load_cache(self) -> dict:
"""
Загружает кэш лемматизации из файловой системы.
Returns:
dict: {text_hash -> lemmatized_tokens}
"""
cache_file = self.cache_dir / 'lemmatization_cache.pkl'
if cache_file.exists():
try:
with open(cache_file, 'rb') as f:
cache = pickle.load(f)
print(f" ✓ Кэш загружен ({len(cache)} записей)")
return cache
except Exception as e:
print(f" ⚠ Ошибка при загрузке кэша: {e}")
return {}
return {}
def _lemmatize_with_cache(self, texts: list[str]) -> list:
"""
Лемматизирует тексты с использованием кэша.
Проверяет хэши текстов - если хэш совпадает с кэшированным,
использует кэшированный результат. Иначе перелемматизирует.
Args:
texts: Список текстов для лемматизации
Returns:
list: Лемматизированные тексты
"""
if not self.use_cache:
# Если кэш отключен, просто лемматизировать
return [self.lemmatizer.tokenize_text(text) for text in texts]
# Загружаем существующий кэш
cache = self._load_cache()
text_hashes = {}
results = []
needs_save = False
for text in texts:
text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
text_hashes[text] = text_hash
if text_hash in cache:
# Используем кэшированный результат
results.append(cache[text_hash])
else:
# Лемматизируем и добавляем в кэш
lemmatized = self.lemmatizer.tokenize_text(text)
results.append(lemmatized)
cache[text_hash] = lemmatized
needs_save = True
# Сохраняем кэш если были новые записи
if needs_save:
with open(self.cache_dir / 'lemmatization_cache.pkl', 'wb') as f:
pickle.dump(cache, f)
print(f" ✓ Кэш сохранён ({len(cache)} записей)")
return results
def semantic_search(self, query: str) -> torch.Tensor:
# 1. Семантический поиск
query_embedding = torch.tensor(self.embedder.encode_query(query))
semantic_scores = torch.nn.functional.cosine_similarity(self.embeddings_of_summary, query_embedding, eps=1e-8).cpu()
return semantic_scores
def bm25_search(self, query: str) -> np.ndarray:
"""BM25 поиск, используя лемматизированные чанки.
Args:
query: Текст запроса
Returns:
np.ndarray: Скоры для каждого абзаца (не предложения!)
"""
bm25 = BM25Okapi(self.chunks_df['lemmatized_text'])
tokenized_query = self.lemmatizer.tokenize_text(query)
sentences_scores = bm25.get_scores(tokenized_query)
df = self.chunks_df['paragraph_id'].to_frame().copy()
df['score'] = sentences_scores
paragraph_scores = df.groupby('paragraph_id')['score'].max().reindex(self.paragraphs_df['paragraph_id']).fillna(0)
return paragraph_scores
def search(self, query: str) -> None:
bm25_scores = self.bm25_search(query)
semantic_scores = self.semantic_search(query).numpy()
bm25_scores = normalize_array(bm25_scores)
return semantic_scores + 1.0 * bm25_scores |