RAG2 / retrieval.py
antimoda1
add TODO
0aa6d2c
import hashlib
import pickle
from pathlib import Path
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
import warnings
warnings.filterwarnings('ignore')
from get_documents import load_and_process_data
from parse_documents import process_documents
from lemmatizer import RussianLemmatizer
def normalize_array(arr):
min_val = np.min(arr)
max_val = np.max(arr)
return (arr - min_val) / (max_val - min_val)
class Retrieval:
"""
Структура хранения данных:
============================
1. ДАТАФРЕЙМ ПАРАГРАФОВ (self.paragraphs_df):
┌──────────────────────┬─────────────────────────────────┐
│ Колонка │ Описание │
├──────────────────────┼─────────────────────────────────┤
│ paragraph_id │ Уникальный ID параграфа │
│ summary │ Название документа/раздела │
│ start_year │ Год начала периода │
│ end_year │ Год окончания периода │
│ text │ Текст │
│ document_id │ Ссылка на исходный документ │
└──────────────────────┴─────────────────────────────────┘
2. ДАТАФРЕЙМ ЧАНКОВ (self.chunks_df):
┌──────────────────────┬─────────────────────────────────┐
│ Колонка │ Описание │
├──────────────────────┼─────────────────────────────────┤
│ chunk_id │ Уникальный ID чанка │
│ paragraph_id │ Foreign key на параграф │
│ text │ Исходный текст чанка │
│ lemmatized_text │ Лемматизированный текст │
│ (embeddings) │ (будет добавлено в будущем) │
└──────────────────────┴─────────────────────────────────┘
3. ОБЪЕДИНЁННЫЙ ДАТАФРЕЙМ (get_merged_data()):
Комбинирует оба датафрейма через JOIN по paragraph_id.
Содержит все колонки обоих датафреймов.
Используется для поиска и фильтрации.
"""
def __init__(self, use_gpu: bool = False, use_cache: bool = True):
print("Инициализация RAG системы...")
self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
self.use_cache = use_cache
# Путь к кэшу
self.cache_dir = Path('.cache')
if self.use_cache:
self.cache_dir.mkdir(exist_ok=True)
# Инициализация лемматизатора для русского языка
print(" Инициализация лемматизатора...")
self.lemmatizer = RussianLemmatizer()
# Загружаем и обрабатываем данные
print("1. Загрузка данных из JSON...")
self.documents, self.docs_names = load_and_process_data()
# self.documents after this phase: list of {'text': str, 'date': str}
print(f" Загружено {len(self.documents)} сообщений")
# Парсим даты из документов и создаем датафреймы
self.paragraphs_df, self.chunks_df = process_documents(self.documents)
# Добавляем лемматизированный текст в датафрейм чанков с кэшем
print("2. Лемматизация текстов (с кэшированием)...")
self.chunks_df['lemmatized_text'] = self._lemmatize_with_cache(self.chunks_df['text'])
# Инициализируем CrossEncoder
# self.cross_encoder = CrossEncoder('DiTy/cross-encoder-russian-msmarco')
self.embedder = SentenceTransformer('cointegrated/rubert-tiny2', cache_folder="/tmp")
# TODO: кэшировать эмбеддинги!
self.embeddings_of_summary = self.embedder.encode(self.paragraphs_df['summary'], convert_to_tensor=True)
print("RAG система готова к использованию")
# ============ Методы кэширования лемматизации ============
def _load_cache(self) -> dict:
"""
Загружает кэш лемматизации из файловой системы.
Returns:
dict: {text_hash -> lemmatized_tokens}
"""
cache_file = self.cache_dir / 'lemmatization_cache.pkl'
if cache_file.exists():
try:
with open(cache_file, 'rb') as f:
cache = pickle.load(f)
print(f" ✓ Кэш загружен ({len(cache)} записей)")
return cache
except Exception as e:
print(f" ⚠ Ошибка при загрузке кэша: {e}")
return {}
return {}
def _lemmatize_with_cache(self, texts: list[str]) -> list:
"""
Лемматизирует тексты с использованием кэша.
Проверяет хэши текстов - если хэш совпадает с кэшированным,
использует кэшированный результат. Иначе перелемматизирует.
Args:
texts: Список текстов для лемматизации
Returns:
list: Лемматизированные тексты
"""
if not self.use_cache:
# Если кэш отключен, просто лемматизировать
return [self.lemmatizer.tokenize_text(text) for text in texts]
# Загружаем существующий кэш
cache = self._load_cache()
text_hashes = {}
results = []
needs_save = False
for text in texts:
text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
text_hashes[text] = text_hash
if text_hash in cache:
# Используем кэшированный результат
results.append(cache[text_hash])
else:
# Лемматизируем и добавляем в кэш
lemmatized = self.lemmatizer.tokenize_text(text)
results.append(lemmatized)
cache[text_hash] = lemmatized
needs_save = True
# Сохраняем кэш если были новые записи
if needs_save:
with open(self.cache_dir / 'lemmatization_cache.pkl', 'wb') as f:
pickle.dump(cache, f)
print(f" ✓ Кэш сохранён ({len(cache)} записей)")
return results
def semantic_search(self, query: str) -> torch.Tensor:
# 1. Семантический поиск
query_embedding = torch.tensor(self.embedder.encode_query(query))
semantic_scores = torch.nn.functional.cosine_similarity(self.embeddings_of_summary, query_embedding, eps=1e-8).cpu()
return semantic_scores
def bm25_search(self, query: str) -> np.ndarray:
"""BM25 поиск, используя лемматизированные чанки.
Args:
query: Текст запроса
Returns:
np.ndarray: Скоры для каждого абзаца (не предложения!)
"""
bm25 = BM25Okapi(self.chunks_df['lemmatized_text'])
tokenized_query = self.lemmatizer.tokenize_text(query)
sentences_scores = bm25.get_scores(tokenized_query)
df = self.chunks_df['paragraph_id'].to_frame().copy()
df['score'] = sentences_scores
paragraph_scores = df.groupby('paragraph_id')['score'].max().reindex(self.paragraphs_df['paragraph_id']).fillna(0)
return paragraph_scores
def search(self, query: str) -> None:
bm25_scores = self.bm25_search(query)
semantic_scores = self.semantic_search(query).numpy()
bm25_scores = normalize_array(bm25_scores)
return semantic_scores + 1.0 * bm25_scores