antimoda1 committed on
Commit ·
873ada4
1
Parent(s): de448c9
move cross encoder to embedder
Browse files- retrieval.py +40 -119
- test_cross_encoder.py +2 -16
retrieval.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import re
|
| 2 |
-
import time
|
| 3 |
import hashlib
|
| 4 |
import pickle
|
| 5 |
from pathlib import Path
|
|
@@ -8,24 +7,23 @@ import numpy as np
|
|
| 8 |
import torch
|
| 9 |
import pandas as pd
|
| 10 |
from rank_bm25 import BM25Okapi
|
|
|
|
| 11 |
import warnings
|
| 12 |
warnings.filterwarnings('ignore')
|
| 13 |
|
| 14 |
from _1_get_documents import load_and_process_data
|
| 15 |
-
from _2_splitting import
|
| 16 |
from lemmatizer import RussianLemmatizer
|
| 17 |
-
# from _3_chunking import RussianEmbedder
|
| 18 |
|
| 19 |
-
from sentence_transformers import CrossEncoder
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
class Retrieval:
|
| 26 |
-
"""
|
| 27 |
-
RAG (Retrieval-Augmented Generation) система на русском языке.
|
| 28 |
-
|
| 29 |
Структура хранения данных:
|
| 30 |
============================
|
| 31 |
|
|
@@ -55,15 +53,9 @@ class Retrieval:
|
|
| 55 |
Комбинирует оба датафрейма через JOIN по paragraph_id.
|
| 56 |
Содержит все колонки обоих датафреймов.
|
| 57 |
Используется для поиска и фильтрации.
|
| 58 |
-
|
| 59 |
-
Ключевые преимущества:
|
| 60 |
-
- Избегаем дублирования метаданных параграфов
|
| 61 |
-
- Легко фильтровать по году, summary, документу
|
| 62 |
-
- Оптимизировано для работы с 5000+ чанками
|
| 63 |
-
- Простой merge для получения полной информации
|
| 64 |
"""
|
| 65 |
|
| 66 |
-
def __init__(self, use_gpu: bool = False,
|
| 67 |
print("Инициализация RAG системы...")
|
| 68 |
self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
|
| 69 |
self.use_cache = use_cache
|
|
@@ -91,9 +83,10 @@ class Retrieval:
|
|
| 91 |
self.chunks_df['lemmatized_text'] = self._lemmatize_with_cache(self.chunks_df['text'].tolist())
|
| 92 |
|
| 93 |
# Инициализируем CrossEncoder
|
| 94 |
-
|
| 95 |
-
self.
|
| 96 |
-
|
|
|
|
| 97 |
print("✅ RAG система готова к использованию")
|
| 98 |
|
| 99 |
def _process_documents_with_dates(self):
|
|
@@ -331,46 +324,46 @@ class Retrieval:
|
|
| 331 |
(merged['end_year'] >= year_range[0])
|
| 332 |
]
|
| 333 |
|
| 334 |
-
def rerank_search(self, query: str) -> list[dict]:
|
| 335 |
-
|
| 336 |
|
| 337 |
-
|
| 338 |
-
|
| 339 |
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
# Добавляем scores в датафрейм и сортируем
|
| 348 |
-
results = self.chunks_df.copy()
|
| 349 |
-
results['score'] = scores
|
| 350 |
-
return results.sort_values('score', ascending=False).to_dict('records')
|
| 351 |
|
| 352 |
-
|
|
|
|
| 353 |
# 1. Семантический поиск
|
| 354 |
-
query_embedding = self.embedder.encode_query(query)
|
| 355 |
-
semantic_scores = torch.nn.functional.cosine_similarity(self.
|
| 356 |
return semantic_scores
|
| 357 |
|
| 358 |
-
def bm25_search(self, query: str) ->
|
| 359 |
"""BM25 поиск, используя лемматизированные чанки.
|
| 360 |
|
| 361 |
Args:
|
| 362 |
query: Текст запроса
|
| 363 |
|
| 364 |
Returns:
|
| 365 |
-
|
| 366 |
"""
|
| 367 |
bm25 = BM25Okapi(self.chunks_df['lemmatized_text'].tolist())
|
| 368 |
tokenized_query = self.lemmatizer.tokenize_text(query)
|
| 369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
|
| 371 |
# ============ Для тестирования cross-encoder ============
|
| 372 |
def test_query_with_cross_encoder(self, query: str,
|
| 373 |
-
target_summary: str):
|
| 374 |
""" Тестирует запрос с cross-encoder и выводит результаты.
|
| 375 |
|
| 376 |
Args:
|
|
@@ -379,81 +372,9 @@ class Retrieval:
|
|
| 379 |
"""
|
| 380 |
print(f"{'='*90}")
|
| 381 |
print(f" ✓ Target summary: '{target_summary}'\n")
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
# ================================================================
|
| 389 |
-
print(f" 📊 BM25 ЛЕКСИЧЕСКИЙ ПОИСК:")
|
| 390 |
-
|
| 391 |
-
# Инициализируем BM25
|
| 392 |
-
bm25_scores = self.bm25_search(query)
|
| 393 |
-
|
| 394 |
-
# Добавляем scores в помощный датафрейм
|
| 395 |
-
search_df = merged_df.copy()
|
| 396 |
-
search_df['bm25_score'] = bm25_scores
|
| 397 |
-
|
| 398 |
-
# Получаем топ-30 по BM25
|
| 399 |
-
top_bm25 = search_df.nlargest(30, 'bm25_score')
|
| 400 |
-
|
| 401 |
-
print(f" Топ-10 чанков, их summary-ы:")
|
| 402 |
-
|
| 403 |
-
# Собираем уникальные summary из BM25 результатов
|
| 404 |
-
bm25_summaries = top_bm25['summary'].unique()
|
| 405 |
-
summary_scores_bm25 = dict(
|
| 406 |
-
top_bm25.groupby('summary')['bm25_score'].first()
|
| 407 |
-
)
|
| 408 |
-
|
| 409 |
-
for rank, summary in enumerate(bm25_summaries[:10], 1):
|
| 410 |
-
score = summary_scores_bm25[summary]
|
| 411 |
-
print(f" {rank:2}. BM25={score:6.2f} [{summary[:50]:50}]")
|
| 412 |
-
|
| 413 |
-
print(f" → Уникальных summary найдено: {len(bm25_summaries)}")
|
| 414 |
-
print(f" → Целевой summary в результатах: {'✓ ДА' if target_summary in bm25_summaries else '✗ НЕТ'}")
|
| 415 |
-
|
| 416 |
-
# ================================================================
|
| 417 |
-
# 2. КРОСС-ЭНКОДЕР РАНЖИРОВАНИЕ
|
| 418 |
-
# ================================================================
|
| 419 |
-
print(f"\n 🏆 КРОСС-ЭНКОДЕР РАНЖИРОВАНИЕ:")
|
| 420 |
-
|
| 421 |
-
# Собираем ВСЕ уникальные summary
|
| 422 |
-
all_unique_summaries = merged_df['summary'].unique().tolist()
|
| 423 |
-
assert target_summary in all_unique_summaries, breakpoint()
|
| 424 |
-
|
| 425 |
-
cross_encoder_start = time.time()
|
| 426 |
-
|
| 427 |
-
# Подготавливаем пары query-summary
|
| 428 |
-
pairs = [[query, summary] for summary in all_unique_summaries]
|
| 429 |
-
|
| 430 |
-
# Ранжируем через кросс-энкодер
|
| 431 |
-
cross_scores = self.cross_encoder.predict(pairs)
|
| 432 |
-
cross_encoder_time = time.time() - cross_encoder_start
|
| 433 |
-
|
| 434 |
-
# Сортируем результаты
|
| 435 |
-
ranked_indices = sorted(
|
| 436 |
-
range(len(cross_scores)),
|
| 437 |
-
key=lambda i: cross_scores[i],
|
| 438 |
-
reverse=True
|
| 439 |
-
)
|
| 440 |
-
|
| 441 |
-
print(f" (время: {cross_encoder_time:.3f} сек)")
|
| 442 |
-
print(f" Top-5 summary (из {len(all_unique_summaries)} всего):")
|
| 443 |
-
|
| 444 |
-
cross_target_rank = None
|
| 445 |
-
|
| 446 |
-
for rank, idx in enumerate(ranked_indices[:5], 1):
|
| 447 |
-
summary = all_unique_summaries[idx]
|
| 448 |
-
score = cross_scores[idx]
|
| 449 |
-
|
| 450 |
-
is_target = summary == target_summary
|
| 451 |
-
mark = "⭐ TARGET ⭐" if is_target else " " * 13
|
| 452 |
-
|
| 453 |
-
print(f" {mark} {rank}. Cross={score:7.4f} [{summary[:50]:50}]")
|
| 454 |
-
|
| 455 |
-
if is_target:
|
| 456 |
-
cross_target_rank = rank
|
| 457 |
-
|
| 458 |
-
if not cross_target_rank:
|
| 459 |
-
print(f" ❌ Целевой summary НЕ в топ-5")
|
|
|
|
| 1 |
import re
|
|
|
|
| 2 |
import hashlib
|
| 3 |
import pickle
|
| 4 |
from pathlib import Path
|
|
|
|
| 7 |
import torch
|
| 8 |
import pandas as pd
|
| 9 |
from rank_bm25 import BM25Okapi
|
| 10 |
+
from sentence_transformers import SentenceTransformer
|
| 11 |
import warnings
|
| 12 |
warnings.filterwarnings('ignore')
|
| 13 |
|
| 14 |
from _1_get_documents import load_and_process_data
|
| 15 |
+
from _2_splitting import parse_metadata_from_document
|
| 16 |
from lemmatizer import RussianLemmatizer
|
|
|
|
| 17 |
|
|
|
|
| 18 |
|
| 19 |
+
def normalize_array(arr):
    """Min-max normalize *arr* to the [0, 1] range.

    Args:
        arr: Numeric array-like (np.ndarray or pd.Series of scores).

    Returns:
        Array of the same shape, rescaled so the minimum maps to 0 and the
        maximum to 1. If every value is equal, returns an all-zero array —
        the original divided by zero here and produced NaN/inf.
    """
    min_val = np.min(arr)
    max_val = np.max(arr)
    span = max_val - min_val
    # Guard the degenerate case: a constant array would divide by zero.
    if span == 0:
        return np.zeros_like(arr, dtype=float)
    return (arr - min_val) / span
|
| 23 |
|
| 24 |
|
| 25 |
class Retrieval:
|
| 26 |
+
"""
|
|
|
|
|
|
|
| 27 |
Структура хранения данных:
|
| 28 |
============================
|
| 29 |
|
|
|
|
| 53 |
Комбинирует оба датафрейма через JOIN по paragraph_id.
|
| 54 |
Содержит все колонки обоих датафреймов.
|
| 55 |
Используется для поиска и фильтрации.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
"""
|
| 57 |
|
| 58 |
+
def __init__(self, use_gpu: bool = False, use_cache: bool = True):
|
| 59 |
print("Инициализация RAG системы...")
|
| 60 |
self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
|
| 61 |
self.use_cache = use_cache
|
|
|
|
| 83 |
self.chunks_df['lemmatized_text'] = self._lemmatize_with_cache(self.chunks_df['text'].tolist())
|
| 84 |
|
| 85 |
# Инициализируем CrossEncoder
|
| 86 |
+
# self.cross_encoder = CrossEncoder('DiTy/cross-encoder-russian-msmarco')
|
| 87 |
+
self.embedder = SentenceTransformer('cointegrated/rubert-tiny2')
|
| 88 |
+
self.embeddings_of_summary = self.embedder.encode(self.paragraphs_df['summary'].tolist(), convert_to_tensor=True)
|
| 89 |
+
|
| 90 |
print("✅ RAG система готова к использованию")
|
| 91 |
|
| 92 |
def _process_documents_with_dates(self):
|
|
|
|
| 324 |
(merged['end_year'] >= year_range[0])
|
| 325 |
]
|
| 326 |
|
| 327 |
+
# def rerank_search(self, query: str) -> list[dict]:
|
| 328 |
+
# """Ранжирует все чанки используя CrossEncoder модель.
|
| 329 |
|
| 330 |
+
# Args:
|
| 331 |
+
# query: Текст запроса
|
| 332 |
|
| 333 |
+
# Returns:
|
| 334 |
+
# list: Отсортированный список результатов с scores
|
| 335 |
+
# """
|
| 336 |
+
# pairs = [[query, text] for text in self.paragraphs_df['summary'].tolist()]
|
| 337 |
+
# scores = self.cross_encoder.predict(pairs)
|
| 338 |
+
# sorted_scores = dict(sorted(scores.items(), key=lambda item: item[0]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
+
|
| 341 |
+
def semantic_search(self, query: str) -> torch.Tensor:
    """Score every paragraph summary against *query* by embedding similarity.

    Args:
        query: Query text.

    Returns:
        torch.Tensor: Cosine similarity between the query embedding and each
        precomputed summary embedding (one score per summary), moved to CPU.
    """
    # 1. Semantic search: embed the query with the same model that produced
    # self.embeddings_of_summary so the two vector spaces are comparable.
    query_embedding = torch.tensor(self.embedder.encode_query(query))
    # The single query vector broadcasts against the (n_summaries, dim)
    # matrix; eps guards against a zero-norm embedding in the division.
    semantic_scores = torch.nn.functional.cosine_similarity(self.embeddings_of_summary, query_embedding, eps=1e-8).cpu()
    return semantic_scores
|
| 346 |
|
| 347 |
+
def bm25_search(self, query: str) -> np.ndarray:
    """BM25 search over the lemmatized chunks, aggregated to paragraph level.

    Scores every sentence-level chunk with BM25, then takes the maximum chunk
    score per paragraph, so the best-matching sentence represents its
    paragraph.

    Args:
        query: Query text.

    Returns:
        np.ndarray: One score per row of self.paragraphs_df (paragraphs, not
        sentences!), positionally aligned with paragraphs_df['paragraph_id'];
        paragraphs with no chunks get 0.
    """
    bm25 = BM25Okapi(self.chunks_df['lemmatized_text'].tolist())
    tokenized_query = self.lemmatizer.tokenize_text(query)
    sentences_scores = bm25.get_scores(tokenized_query)
    df = self.chunks_df['paragraph_id'].to_frame().copy()
    df['score'] = sentences_scores
    # Max over each paragraph's chunks, reindexed to paragraphs_df order so
    # the result lines up positionally with the other per-paragraph arrays.
    paragraph_scores = df.groupby('paragraph_id')['score'].max().reindex(self.paragraphs_df['paragraph_id']).fillna(0)
    # Return a plain ndarray to honor the annotation — groupby yields a
    # pd.Series whose index could silently mis-align on later assignment.
    return paragraph_scores.to_numpy()
|
| 363 |
|
| 364 |
# ============ Для тестирования cross-encoder ============
|
| 365 |
def test_query_with_cross_encoder(self, query: str,
                                  target_summary: str, weight_bm25: float = 0.5, weight_semantic: float = 0.5) -> np.ndarray:
    """Score *query* against every paragraph with a BM25 + semantic blend.

    Args:
        query: Query text.
        target_summary: Expected best summary, printed for manual inspection.
        weight_bm25: Weight applied to the normalized BM25 scores.
        weight_semantic: Weight applied to the normalized semantic scores.

    Returns:
        np.ndarray: Combined per-paragraph relevance score. (The original
        annotation said ``None`` although the function does return a value —
        fixed to match the actual behavior.)
    """
    print(f"{'='*90}")
    print(f" ✓ Target summary: '{target_summary}'\n")

    bm25_scores = self.bm25_search(query)
    semantic_scores = self.semantic_search(query).numpy()
    # Bring both score distributions to [0, 1] so the weights are comparable.
    bm25_scores = normalize_array(bm25_scores)
    semantic_scores = normalize_array(semantic_scores)
    return weight_semantic * semantic_scores + weight_bm25 * bm25_scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_cross_encoder.py
CHANGED
|
@@ -36,29 +36,15 @@ test_cases = [
|
|
| 36 |
|
| 37 |
|
| 38 |
def test_cross_encoder_vs_bm25():
|
| 39 |
-
"""Тестирует кросс-энкодер vs BM25 на всех документах."""
|
| 40 |
-
print("=" * 90)
|
| 41 |
-
print("СРАВНЕНИЕ: КРОСС-ЭНКОДЕР vs BM25 ЛЕММАТИЗИРОВАННЫЙ ПОИСК")
|
| 42 |
-
print("=" * 90)
|
| 43 |
-
|
| 44 |
# Создаем объект Retrieval (загружает корпус автоматически)
|
| 45 |
retrieval = Retrieval(use_gpu=False)
|
| 46 |
|
| 47 |
-
|
| 48 |
-
print("=" * 90)
|
| 49 |
-
print("ТЕСТИРОВАНИЕ ОТДЕЛЬНЫХ ЗАПРОСОВ")
|
| 50 |
-
print("=" * 90)
|
| 51 |
-
|
| 52 |
-
for test_num, test_case in enumerate(test_cases, 1):
|
| 53 |
retrieval.test_query_with_cross_encoder(
|
| 54 |
query=test_case.query,
|
| 55 |
target_summary=test_case.good_answer,
|
| 56 |
-
test_num=test_num
|
| 57 |
)
|
| 58 |
-
|
| 59 |
-
print("\n" + "=" * 90)
|
| 60 |
-
print(f"✅ Тестирование завершено")
|
| 61 |
-
print("=" * 90)
|
| 62 |
|
| 63 |
|
| 64 |
if __name__ == "__main__":
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
def test_cross_encoder_vs_bm25():
    """Run every query from test_cases through the retrieval scoring pipeline."""
    # Create the Retrieval object (loads and indexes the corpus automatically).
    retrieval = Retrieval(use_gpu=False)

    for test_case in test_cases:
        # Each case pairs a query with the summary expected to rank best.
        retrieval.test_query_with_cross_encoder(
            query=test_case.query,
            target_summary=test_case.good_answer,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
if __name__ == "__main__":
|