antimoda1 commited on
Commit
873ada4
·
1 Parent(s): de448c9

move cross encoder to embedder

Browse files
Files changed (2) hide show
  1. retrieval.py +40 -119
  2. test_cross_encoder.py +2 -16
retrieval.py CHANGED
@@ -1,5 +1,4 @@
1
  import re
2
- import time
3
  import hashlib
4
  import pickle
5
  from pathlib import Path
@@ -8,24 +7,23 @@ import numpy as np
8
  import torch
9
  import pandas as pd
10
  from rank_bm25 import BM25Okapi
 
11
  import warnings
12
  warnings.filterwarnings('ignore')
13
 
14
  from _1_get_documents import load_and_process_data
15
- from _2_splitting import years_overlap, parse_metadata_from_document
16
  from lemmatizer import RussianLemmatizer
17
- # from _3_chunking import RussianEmbedder
18
 
19
- from sentence_transformers import CrossEncoder
20
 
21
- # Модель будет загружена автоматически
22
- # model = CrossEncoder('DiTy/cross-encoder-russian-msmarco', max_length=512)
 
 
23
 
24
 
25
  class Retrieval:
26
- """
27
- RAG (Retrieval-Augmented Generation) система на русском языке.
28
-
29
  Структура хранения данных:
30
  ============================
31
 
@@ -55,15 +53,9 @@ class Retrieval:
55
  Комбинирует оба датафрейма через JOIN по paragraph_id.
56
  Содержит все колонки обоих датафреймов.
57
  Используется для поиска и фильтрации.
58
-
59
- Ключевые преимущества:
60
- - Избегаем дублирования метаданных параграфов
61
- - Легко фильтровать по году, summary, документу
62
- - Оптимизировано для работы с 5000+ чанками
63
- - Простой merge для получения полной информации
64
  """
65
 
66
- def __init__(self, use_gpu: bool = False, load_json: bool = True, use_cache: bool = True):
67
  print("Инициализация RAG системы...")
68
  self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
69
  self.use_cache = use_cache
@@ -91,9 +83,10 @@ class Retrieval:
91
  self.chunks_df['lemmatized_text'] = self._lemmatize_with_cache(self.chunks_df['text'].tolist())
92
 
93
  # Инициализируем CrossEncoder
94
- print("3. Загрузка CrossEncoder модели...")
95
- self.cross_encoder = CrossEncoder('DiTy/cross-encoder-russian-msmarco')
96
-
 
97
  print("✅ RAG система готова к использованию")
98
 
99
  def _process_documents_with_dates(self):
@@ -331,46 +324,46 @@ class Retrieval:
331
  (merged['end_year'] >= year_range[0])
332
  ]
333
 
334
- def rerank_search(self, query: str) -> list[dict]:
335
- """Ранжирует все чанки используя CrossEncoder модель.
336
 
337
- Args:
338
- query: Текст запроса
339
 
340
- Returns:
341
- list: Отсортированный список результатов с scores
342
- """
343
- pairs = [[query, text] for text in self.chunks_df['text'].tolist()]
344
- scores = self.cross_encoder.predict(pairs)
345
- breakpoint()
346
-
347
- # Добавляем scores в датафрейм и сортируем
348
- results = self.chunks_df.copy()
349
- results['score'] = scores
350
- return results.sort_values('score', ascending=False).to_dict('records')
351
 
352
- def semantic_search(self, query: str) -> list:
 
353
  # 1. Семантический поиск
354
- query_embedding = self.embedder.encode_query(query)
355
- semantic_scores = torch.nn.functional.cosine_similarity(self.embeddings, query_embedding, eps=1e-8).cpu()
356
  return semantic_scores
357
 
358
- def bm25_search(self, query: str) -> list:
359
  """BM25 поиск, используя лемматизированные чанки.
360
 
361
  Args:
362
  query: Текст запроса
363
 
364
  Returns:
365
- list: Скоры для каждого чанка
366
  """
367
  bm25 = BM25Okapi(self.chunks_df['lemmatized_text'].tolist())
368
  tokenized_query = self.lemmatizer.tokenize_text(query)
369
- return bm25.get_scores(tokenized_query)
 
 
 
 
370
 
371
  # ============ Для тестирования cross-encoder ============
372
  def test_query_with_cross_encoder(self, query: str,
373
- target_summary: str):
374
  """ Тестирует запрос с cross-encoder и выводит результаты.
375
 
376
  Args:
@@ -379,81 +372,9 @@ class Retrieval:
379
  """
380
  print(f"{'='*90}")
381
  print(f" ✓ Target summary: '{target_summary}'\n")
382
-
383
- # Получаем объединённый датафрейм
384
- merged_df = self.get_merged_data()
385
-
386
- # ================================================================
387
- # 1. BM25 ПОИСК
388
- # ================================================================
389
- print(f" 📊 BM25 ЛЕКСИЧЕСКИЙ ПОИСК:")
390
-
391
- # Инициализируем BM25
392
- bm25_scores = self.bm25_search(query)
393
-
394
- # Добавляем scores в помощный датафрейм
395
- search_df = merged_df.copy()
396
- search_df['bm25_score'] = bm25_scores
397
-
398
- # Получаем топ-30 по BM25
399
- top_bm25 = search_df.nlargest(30, 'bm25_score')
400
-
401
- print(f" Топ-10 чанков, их summary-ы:")
402
-
403
- # Собираем уникальные summary из BM25 результатов
404
- bm25_summaries = top_bm25['summary'].unique()
405
- summary_scores_bm25 = dict(
406
- top_bm25.groupby('summary')['bm25_score'].first()
407
- )
408
-
409
- for rank, summary in enumerate(bm25_summaries[:10], 1):
410
- score = summary_scores_bm25[summary]
411
- print(f" {rank:2}. BM25={score:6.2f} [{summary[:50]:50}]")
412
-
413
- print(f" → Уникальных summary найдено: {len(bm25_summaries)}")
414
- print(f" → Целевой summary в результатах: {'✓ ДА' if target_summary in bm25_summaries else '✗ НЕТ'}")
415
-
416
- # ================================================================
417
- # 2. КРОСС-ЭНКОДЕР РАНЖИРОВАНИЕ
418
- # ================================================================
419
- print(f"\n 🏆 КРОСС-ЭНКОДЕР РАНЖИРОВАНИЕ:")
420
-
421
- # Собираем ВСЕ уникальные summary
422
- all_unique_summaries = merged_df['summary'].unique().tolist()
423
- assert target_summary in all_unique_summaries, breakpoint()
424
-
425
- cross_encoder_start = time.time()
426
-
427
- # Подготавливаем пары query-summary
428
- pairs = [[query, summary] for summary in all_unique_summaries]
429
-
430
- # Ранжируем через кросс-энкодер
431
- cross_scores = self.cross_encoder.predict(pairs)
432
- cross_encoder_time = time.time() - cross_encoder_start
433
-
434
- # Сортируем результаты
435
- ranked_indices = sorted(
436
- range(len(cross_scores)),
437
- key=lambda i: cross_scores[i],
438
- reverse=True
439
- )
440
-
441
- print(f" (время: {cross_encoder_time:.3f} сек)")
442
- print(f" Top-5 summary (из {len(all_unique_summaries)} всего):")
443
-
444
- cross_target_rank = None
445
-
446
- for rank, idx in enumerate(ranked_indices[:5], 1):
447
- summary = all_unique_summaries[idx]
448
- score = cross_scores[idx]
449
-
450
- is_target = summary == target_summary
451
- mark = "⭐ TARGET ⭐" if is_target else " " * 13
452
-
453
- print(f" {mark} {rank}. Cross={score:7.4f} [{summary[:50]:50}]")
454
-
455
- if is_target:
456
- cross_target_rank = rank
457
-
458
- if not cross_target_rank:
459
- print(f" ❌ Целевой summary НЕ в топ-5")
 
1
  import re
 
2
  import hashlib
3
  import pickle
4
  from pathlib import Path
 
7
  import torch
8
  import pandas as pd
9
  from rank_bm25 import BM25Okapi
10
+ from sentence_transformers import SentenceTransformer
11
  import warnings
12
  warnings.filterwarnings('ignore')
13
 
14
  from _1_get_documents import load_and_process_data
15
+ from _2_splitting import parse_metadata_from_document
16
  from lemmatizer import RussianLemmatizer
 
17
 
 
18
 
19
+ def normalize_array(arr):
20
+ min_val = np.min(arr)
21
+ max_val = np.max(arr)
22
+ return (arr - min_val) / (max_val - min_val)
23
 
24
 
25
  class Retrieval:
26
+ """
 
 
27
  Структура хранения данных:
28
  ============================
29
 
 
53
  Комбинирует оба датафрейма через JOIN по paragraph_id.
54
  Содержит все колонки обоих датафреймов.
55
  Используется для поиска и фильтрации.
 
 
 
 
 
 
56
  """
57
 
58
+ def __init__(self, use_gpu: bool = False, use_cache: bool = True):
59
  print("Инициализация RAG системы...")
60
  self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
61
  self.use_cache = use_cache
 
83
  self.chunks_df['lemmatized_text'] = self._lemmatize_with_cache(self.chunks_df['text'].tolist())
84
 
85
  # Инициализируем CrossEncoder
86
+ # self.cross_encoder = CrossEncoder('DiTy/cross-encoder-russian-msmarco')
87
+ self.embedder = SentenceTransformer('cointegrated/rubert-tiny2')
88
+ self.embeddings_of_summary = self.embedder.encode(self.paragraphs_df['summary'].tolist(), convert_to_tensor=True)
89
+
90
  print("✅ RAG система готова к использованию")
91
 
92
  def _process_documents_with_dates(self):
 
324
  (merged['end_year'] >= year_range[0])
325
  ]
326
 
327
+ # def rerank_search(self, query: str) -> list[dict]:
328
+ # """Ранжирует все чанки используя CrossEncoder модель.
329
 
330
+ # Args:
331
+ # query: Текст запроса
332
 
333
+ # Returns:
334
+ # list: Отсортированный список результатов с scores
335
+ # """
336
+ # pairs = [[query, text] for text in self.paragraphs_df['summary'].tolist()]
337
+ # scores = self.cross_encoder.predict(pairs)
338
+ # sorted_scores = dict(sorted(scores.items(), key=lambda item: item[0]))
 
 
 
 
 
339
 
340
+
341
+ def semantic_search(self, query: str) -> torch.Tensor:
342
  # 1. Семантический поиск
343
+ query_embedding = torch.tensor(self.embedder.encode_query(query))
344
+ semantic_scores = torch.nn.functional.cosine_similarity(self.embeddings_of_summary, query_embedding, eps=1e-8).cpu()
345
  return semantic_scores
346
 
347
+ def bm25_search(self, query: str) -> np.ndarray:
348
  """BM25 поиск, используя лемматизированные чанки.
349
 
350
  Args:
351
  query: Текст запроса
352
 
353
  Returns:
354
+ np.ndarray: Скоры для каждого абзаца (не предложения!)
355
  """
356
  bm25 = BM25Okapi(self.chunks_df['lemmatized_text'].tolist())
357
  tokenized_query = self.lemmatizer.tokenize_text(query)
358
+ sentences_scores = bm25.get_scores(tokenized_query)
359
+ df = self.chunks_df['paragraph_id'].to_frame().copy()
360
+ df['score'] = sentences_scores
361
+ paragraph_scores = df.groupby('paragraph_id')['score'].max().reindex(self.paragraphs_df['paragraph_id']).fillna(0)
362
+ return paragraph_scores
363
 
364
  # ============ Для тестирования cross-encoder ============
365
  def test_query_with_cross_encoder(self, query: str,
366
+ target_summary: str, weight_bm25: float = 0.5, weight_semantic: float = 0.5) -> None:
367
  """ Тестирует запрос с cross-encoder и выводит результаты.
368
 
369
  Args:
 
372
  """
373
  print(f"{'='*90}")
374
  print(f" ✓ Target summary: '{target_summary}'\n")
375
+
376
+ bm25_scores = self.bm25_search(query)
377
+ semantic_scores = self.semantic_search(query).numpy()
378
+ bm25_scores = normalize_array(bm25_scores)
379
+ semantic_scores = normalize_array(semantic_scores)
380
+ return weight_semantic * semantic_scores + weight_bm25 * bm25_scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_cross_encoder.py CHANGED
@@ -36,29 +36,15 @@ test_cases = [
36
 
37
 
38
  def test_cross_encoder_vs_bm25():
39
- """Тестирует кросс-энкодер vs BM25 на всех документах."""
40
- print("=" * 90)
41
- print("СРАВНЕНИЕ: КРОСС-ЭНКОДЕР vs BM25 ЛЕММАТИЗИРОВАННЫЙ ПОИСК")
42
- print("=" * 90)
43
-
44
  # Создаем объект Retrieval (загружает корпус автоматически)
45
  retrieval = Retrieval(use_gpu=False)
46
 
47
- # Тестируем каждый тестовый случай
48
- print("=" * 90)
49
- print("ТЕСТИРОВАНИЕ ОТДЕЛЬНЫХ ЗАПРОСОВ")
50
- print("=" * 90)
51
-
52
- for test_num, test_case in enumerate(test_cases, 1):
53
  retrieval.test_query_with_cross_encoder(
54
  query=test_case.query,
55
  target_summary=test_case.good_answer,
56
- test_num=test_num
57
  )
58
-
59
- print("\n" + "=" * 90)
60
- print(f"✅ Тестирование завершено")
61
- print("=" * 90)
62
 
63
 
64
  if __name__ == "__main__":
 
36
 
37
 
38
  def test_cross_encoder_vs_bm25():
39
+ """Тестирует кросс-энкодер vs BM25 на всех документах."""
 
 
 
 
40
  # Создаем объект Retrieval (загружает корпус автоматически)
41
  retrieval = Retrieval(use_gpu=False)
42
 
43
+ for test_case in test_cases:
 
 
 
 
 
44
  retrieval.test_query_with_cross_encoder(
45
  query=test_case.query,
46
  target_summary=test_case.good_answer,
 
47
  )
 
 
 
 
48
 
49
 
50
  if __name__ == "__main__":