antimoda1 commited on
Commit
09bc630
·
1 Parent(s): 6a33050

add cross-encoder

Browse files
_1_get_documents.py CHANGED
@@ -13,7 +13,7 @@ def process_file(file_path):
13
  with open(file_path, 'r', encoding='utf-8-sig') as f:
14
  text = f.read()
15
  assert text
16
- return text, str(file_path).split('.')[-1]
17
 
18
  def process_folder_recursive(folder_path):
19
  all_messages = []
 
13
  with open(file_path, 'r', encoding='utf-8-sig') as f:
14
  text = f.read()
15
  assert text
16
+ return text, str(file_path).split('.')[0]
17
 
18
  def process_folder_recursive(folder_path):
19
  all_messages = []
_2_splitting.py CHANGED
@@ -1,6 +1,3 @@
1
- from typing import List, Dict
2
- import re
3
-
4
  # Конфиги для парсинга дат
5
  YEARS = {
6
  'O': 1918,
@@ -53,64 +50,56 @@ def _parse_date_range(date_str: str) -> tuple[int, int]:
53
  parts = date_str.split('-')
54
  start = _parse_single_year(parts[0].strip())
55
  end = _parse_single_year(parts[1].strip())
56
- return (start, end) if start <= end else (end, start)
 
57
  else:
58
  # Один год
59
  year = _parse_single_year(date_str)
60
  return (year, year)
61
 
62
 
63
- def parse_year_metadata(text: str) -> list[tuple[str, tuple[int, int]]]:
64
- """
65
- Парсит markdown текст и возвращает список (чанк_текста, (год_начала, год_конца)).
 
 
 
66
 
67
- Формат разметки: ## 1962-2002 или ### 1962-2002 или просто 1962 или O или N
68
- Разметка распространяется на абзац ниже и все последующие до новой разметки.
 
 
 
 
69
 
70
  Args:
71
  text: Полный текст документа
72
 
73
  Returns:
74
- list: [(chunk_text, (start_year, end_year)), ...]
75
-
76
- Raises:
77
- ValueError: Если документ не начинается с разметки
78
  """
79
  lines = text.split('\n')
80
 
81
- # Проверяем, начинается ли документ с разметки (# или ## или ###)
82
- if not lines or not re.match(r'^#+\s*', lines[0].strip()):
83
- raise ValueError(f"Документ не начинается с разметки! Первая строка: {lines[0] if lines else 'ПУСТО'}")
84
-
85
  result = []
 
86
  current_year_range = None
87
- current_text = []
88
 
89
  for line in lines:
90
- # Проверяем, является ли строка разметкой (начинается с #, ##, ### и т.д.)
91
- match = re.match(r'^#+\s+(.+?)$', line.strip())
92
- if match:
93
- # Сохраняем предыдущий абзац если он есть
94
- if current_text and current_year_range:
95
- chunk = '\n'.join(current_text).strip()
96
- if chunk:
97
- result.append((chunk, current_year_range))
98
-
99
- # Парсим новую разметку
100
- date_str = match.group(1).strip()
101
- current_year_range = _parse_date_range(date_str)
102
-
103
- current_text = []
104
  else:
105
- # Если это не разметка, добавляем в текущий абзац
106
- if line.strip(): # Только непустые строки
107
- current_text.append(line)
108
-
109
- # Сохраняем последний абзац
110
- if current_text and current_year_range:
111
- chunk = '\n'.join(current_text).strip()
112
- if chunk:
113
- result.append((chunk, current_year_range))
114
 
115
  return result
116
 
@@ -132,68 +121,3 @@ def years_overlap(range1: tuple[int, int], range2: tuple[int, int]) -> bool:
132
  start1, end1 = range1
133
  start2, end2 = range2
134
  return start1 <= end2 and end1 >= start2
135
-
136
-
137
- class Splitter:
138
- """
139
- Класс для работы с русскоязычными эмбеддингами в RAG пайплайне.
140
- Поддерживает дообучение Word2Vec/FastText и использование RuBERT.
141
- """
142
-
143
- def __init__(self,
144
- chunk_size: int = 350,
145
- chunk_overlap: int = 70):
146
-
147
- self.chunk_size = chunk_size
148
- self.chunk_overlap = chunk_overlap
149
-
150
- # Инициализация компонентов
151
- self.chunks = []
152
- self.chunk_metadata = []
153
- self.documents = []
154
-
155
- def load_documents(self, documents: List[Dict]):
156
- """
157
- Загрузка документов и создание чанков.
158
-
159
- Args:
160
- documents: Список словарей с полем 'text'
161
- """
162
- self.documents = documents
163
- print(f"📄 Загрузка {len(documents)} документов...")
164
-
165
- chunks = []
166
- docs_metadata = [] # ID документов для каждого чанка
167
- paragraph_metadata = [] # ID абзацев для каждого чанка
168
- paragraph_id_counter = 0
169
-
170
- for doc_id, document in enumerate(documents):
171
- # Разбиваем документ на абзацы по \n\n
172
- paragraphs = document.split('\n')
173
-
174
- for paragraph in paragraphs:
175
- paragraph = paragraph.strip()
176
-
177
- if paragraph == '':
178
- continue
179
- sentences = re.split(r'(?<=[.!?])\s+', paragraph)
180
-
181
- # Если абзац слишком длинный, используем сплиттер для его разбиения
182
- if len(sentences) > 1:
183
- for chunk in sentences:
184
- if len(chunk.strip()) >= 30:
185
- chunks.append(chunk)
186
- docs_metadata.append(doc_id)
187
- paragraph_metadata.append(paragraph_id_counter)
188
- else:
189
- # Добавляем абзац как целый чанк
190
- chunks.append(paragraph)
191
- docs_metadata.append(doc_id)
192
- paragraph_metadata.append(paragraph_id_counter)
193
-
194
- paragraph_id_counter += 1
195
-
196
- print(f"✅ Создано {len(chunks)} чанков")
197
- print(f" Из {paragraph_id_counter} абзацев в {len(documents)} документах")
198
- return chunks, docs_metadata, paragraph_metadata
199
-
 
 
 
 
1
  # Конфиги для парсинга дат
2
  YEARS = {
3
  'O': 1918,
 
50
  parts = date_str.split('-')
51
  start = _parse_single_year(parts[0].strip())
52
  end = _parse_single_year(parts[1].strip())
53
+ assert start <= end, f"Год начала {start} должен быть меньше или равен году конца {end}"
54
+ return (start, end)
55
  else:
56
  # Один год
57
  year = _parse_single_year(date_str)
58
  return (year, year)
59
 
60
 
61
+ def parse_metadata_from_document(text: str) -> list[tuple[str, tuple[int, int], str]]:
62
+ """Парсит markdown текст и возвращает список (чанк_текста, (год_начала, год_конца), summary).
63
+
64
+ Формат разметки ОБЯЗАТЕЛЕН:
65
+ - ## Summary text - заголовок summary (двойной хэш + пробел)
66
+ - ### 1962-2002 - заголовок с годом (тройной хэш + пробел)
67
 
68
+ Правила:
69
+ - Каждый документ ДОЛЖЕН начинаться с "## {summary}"
70
+ - После summary ДОЛЖНЫ быть заголовки "### {годы}" с текстом
71
+ - ## распространяется на все абзацы ниже до следующего ## или конца файла
72
+ - ### распространяется на абзацы ниже до следующего ### или ##
73
+ - Текст БЕЗ предшествующего ### Не добавляется в результат
74
 
75
  Args:
76
  text: Полный текст документа
77
 
78
  Returns:
79
+ list: [(chunk_text, (start_year, end_year), summary), ...]
 
 
 
80
  """
81
  lines = text.split('\n')
82
 
 
 
 
 
83
  result = []
84
+ current_summary = None
85
  current_year_range = None
 
86
 
87
  for line in lines:
88
+ strip_line = line.strip()
89
+ if not strip_line:
90
+ continue
91
+
92
+ if strip_line.startswith('## '):
93
+ current_summary = strip_line[3:].strip() # Пропускаем "## "
94
+
95
+ # Проверяем, является ли строка "### " (год с пробелом после)
96
+ elif strip_line.startswith('### '):
97
+ current_year_range = _parse_date_range(strip_line[4:])
98
+
 
 
 
99
  else:
100
+ # Добавляем текст только если у нас есть год
101
+ assert current_year_range and current_summary, breakpoint()
102
+ result.append((line, current_year_range, current_summary))
 
 
 
 
 
 
103
 
104
  return result
105
 
 
121
  start1, end1 = range1
122
  start2, end2 = range2
123
  return start1 <= end2 and end1 >= start2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
retrieval.py CHANGED
@@ -1,13 +1,18 @@
1
  import re
 
 
 
2
  from pathlib import Path
3
 
 
4
  import torch
 
5
  from rank_bm25 import BM25Okapi
6
  import warnings
7
  warnings.filterwarnings('ignore')
8
 
9
  from _1_get_documents import load_and_process_data
10
- from _2_splitting import parse_year_metadata, years_overlap
11
  from lemmatizer import RussianLemmatizer
12
  # from _3_chunking import RussianEmbedder
13
 
@@ -18,9 +23,55 @@ from sentence_transformers import CrossEncoder
18
 
19
 
20
  class Retrieval:
21
- def __init__(self, use_gpu: bool = False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  print("Инициализация RAG системы...")
23
  self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
24
 
25
  # Инициализация лемматизатора для русского языка
26
  print(" Инициализация лемматизатора...")
@@ -32,85 +83,271 @@ class Retrieval:
32
  # self.documents after this phase: list of {'text': str, 'date': str}
33
  print(f" Загружено {len(self.documents)} сообщений")
34
 
35
- # Парсим даты из документов и создаем чанки
36
- self.chunks, self.docs_metadata, self.paragraph_metadata, self.chunk_dates = self._process_documents_with_dates()
37
- self.bm25 = self._prepare_bm25()
38
- # self.embedder = RussianEmbedder(self.chunks, model_type="rubert")
39
-
40
- # Создание индекса
41
- # self.embeddings = self.embedder.create_index()
 
 
 
 
 
42
 
43
- print("✓ RAG система готова к работе!")
44
-
45
  def _process_documents_with_dates(self):
46
  """
47
- Обрабатывает документы с парсингом дат и создает чанки.
48
 
49
  Returns:
50
- tuple: (chunks, docs_metadata, paragraph_metadata, chunk_dates)
51
- где chunk_dates - список ((start_year, end_year), ...) для каждого чанка
 
 
 
 
 
 
 
 
 
 
 
 
52
  """
53
- chunks = []
54
- docs_metadata = []
55
- paragraph_metadata = []
56
- chunk_dates = []
57
  paragraph_id_counter = 0
 
58
 
59
  for doc_id, document in enumerate(self.documents):
60
- try:
61
- # Парсим даты из документа
62
- dated_chunks = parse_year_metadata(document)
 
63
 
64
- for chunk_text, year_range in dated_chunks:
65
- # Разбиваем на предложения, как в Splitter
66
- paragraphs = chunk_text.split('\n')
 
 
 
 
 
 
 
 
67
 
68
- for paragraph in paragraphs:
69
- paragraph = paragraph.strip()
70
- if not paragraph:
71
- continue
72
-
73
- sentences = re.split(r'(?<=[.!?])\s+', paragraph)
74
-
75
- if len(sentences) > 1:
76
- for sent in sentences:
77
- if len(sent.strip()) >= 30:
78
- chunks.append(sent)
79
- docs_metadata.append(doc_id)
80
- paragraph_metadata.append(paragraph_id_counter)
81
- chunk_dates.append(year_range)
82
- else:
83
- chunks.append(paragraph)
84
- docs_metadata.append(doc_id)
85
- paragraph_metadata.append(paragraph_id_counter)
86
- chunk_dates.append(year_range)
87
-
88
- paragraph_id_counter += 1
89
-
90
- except ValueError as e:
91
- print(f" ⚠️ Ошибка при парсинге документа {doc_id} ({self.docs_names[doc_id]}): {e}")
92
- # Пропускаем документ если он не имеет правильной разметки
93
- continue
94
-
95
- print(f"✅ Создано {len(chunks)} чанков")
96
- print(f" Из {len(set(paragraph_metadata))} абзацев в {doc_id + 1} документах")
97
- return chunks, docs_metadata, paragraph_metadata, chunk_dates
98
-
99
- def _prepare_bm25(self):
100
- """Подготавливаем BM25 индекс для ключевого поиска"""
101
- # Токенизация для BM25
102
- tokenized_chunks = [self.lemmatizer.tokenize_text(chunk) for chunk in self.chunks]
103
- return BM25Okapi(tokenized_chunks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def rerank_search(self, query: str) -> list[dict]:
 
 
 
 
 
 
 
106
  """
107
- [{'corpus_id': 0, 'score': 0.88126713},
108
- # {'corpus_id': 2, 'score': 0.001042091},
109
- # {'corpus_id': 3, 'score': 0.0010417715},
110
- # {'corpus_id': 1, 'score': 0.0010344835},
111
- # {'corpus_id': 4, 'score': 0.0010244923}]`"""
112
- reranker_model = CrossEncoder('DiTy/cross-encoder-russian-msmarco')
113
- return reranker_model.rank(query[0], self.chunks)
 
114
 
115
  def semantic_search(self, query: str) -> list:
116
  # 1. Семантический поиск
@@ -118,26 +355,105 @@ class Retrieval:
118
  semantic_scores = torch.nn.functional.cosine_similarity(self.embeddings, query_embedding, eps=1e-8).cpu()
119
  return semantic_scores
120
 
121
- def bm25_search(self, query: str) -> list:
122
- # 2. Ключевой поиск (BM25)
 
 
 
 
 
 
 
 
123
  tokenized_query = self.lemmatizer.tokenize_text(query)
124
- return self.bm25.get_scores(tokenized_query)
125
 
126
- def filter_by_year_range(self, indices: list[int], year_range: tuple[int, int]) -> list[int]:
127
- """
128
- Фильтрует индексы чанков по диапазону лет (с пересечением).
 
129
 
130
  Args:
131
- indices: Список индексов чанков для фильтрации
132
- year_range: (start_year, end_year) для поиска
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- Returns:
135
- list: Отфильтрованный список индексов
136
- """
137
- filtered = []
138
- for idx in indices:
139
- if idx < len(self.chunk_dates):
140
- chunk_range = self.chunk_dates[idx]
141
- if years_overlap(chunk_range, year_range):
142
- filtered.append(idx)
143
- return filtered
 
1
  import re
2
+ import time
3
+ import hashlib
4
+ import pickle
5
  from pathlib import Path
6
 
7
+ import numpy as np
8
  import torch
9
+ import pandas as pd
10
  from rank_bm25 import BM25Okapi
11
  import warnings
12
  warnings.filterwarnings('ignore')
13
 
14
  from _1_get_documents import load_and_process_data
15
+ from _2_splitting import years_overlap, parse_metadata_from_document
16
  from lemmatizer import RussianLemmatizer
17
  # from _3_chunking import RussianEmbedder
18
 
 
23
 
24
 
25
  class Retrieval:
26
+ """
27
+ RAG (Retrieval-Augmented Generation) система на русском языке.
28
+
29
+ Структура хранения данных:
30
+ ============================
31
+
32
+ 1. ДАТАФРЕЙМ ПАРАГРАФОВ (self.paragraphs_df):
33
+ ┌──────────────────────┬─────────────────────────────────┐
34
+ │ Колонка │ Описание │
35
+ ├──────────────────────┼─────────────────────────────────┤
36
+ │ paragraph_id │ Уникальный ID параграфа │
37
+ │ summary │ Название документа/раздела │
38
+ │ start_year │ Год начала периода │
39
+ │ end_year │ Год окончания периода │
40
+ │ document_id │ Ссылка на исходный документ │
41
+ └──────────────────────┴─────────────────────────────────┘
42
+
43
+ 2. ДАТАФРЕЙМ ЧАНКОВ (self.chunks_df):
44
+ ┌──────────────────────┬─────────────────────────────────┐
45
+ │ Колонка │ Описание │
46
+ ├──────────────────────┼─────────────────────────────────┤
47
+ │ chunk_id │ Уникальный ID чанка │
48
+ │ paragraph_id │ Foreign key на параграф │
49
+ │ text │ Исходный текст чанка │
50
+ │ lemmatized_text │ Лемматизированный текст │
51
+ │ (embeddings) │ (будет добавлено в будущем) │
52
+ └──────────────────────┴─────────────────────────────────┘
53
+
54
+ 3. ОБЪЕДИНЁННЫЙ ДАТАФРЕЙМ (get_merged_data()):
55
+ Комбинирует оба датафрейма через JOIN по paragraph_id.
56
+ Содержит все колонки обоих датафреймов.
57
+ Используется для поиска и фильтрации.
58
+
59
+ Ключевые преимущества:
60
+ - Избегаем дублирования метаданных параграфов
61
+ - Легко фильтровать по году, summary, документу
62
+ - Оптимизировано для работы с 5000+ чанками
63
+ - Простой merge для получения полной информации
64
+ """
65
+
66
+ def __init__(self, use_gpu: bool = False, load_json: bool = True, use_cache: bool = True):
67
  print("Инициализация RAG системы...")
68
  self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
69
+ self.use_cache = use_cache
70
+
71
+ # Путь к кэшу
72
+ self.cache_dir = Path('.cache')
73
+ if self.use_cache:
74
+ self.cache_dir.mkdir(exist_ok=True)
75
 
76
  # Инициализация лемматизатора для русского языка
77
  print(" Инициализация лемматизатора...")
 
83
  # self.documents after this phase: list of {'text': str, 'date': str}
84
  print(f" Загружено {len(self.documents)} сообщений")
85
 
86
+ # Парсим даты из документов и создаем датафреймы
87
+ self.paragraphs_df, self.chunks_df = self._process_documents_with_dates()
88
+
89
+ # Добавляем лемматизированный текст в датафрейм чанков с кэшем
90
+ print("2. Лемматизация текстов (с кэшированием)...")
91
+ self.chunks_df['lemmatized_text'] = self._lemmatize_with_cache(self.chunks_df['text'].tolist())
92
+
93
+ # Инициализируем CrossEncoder
94
+ print("3. Загрузка CrossEncoder модели...")
95
+ self.cross_encoder = CrossEncoder('DiTy/cross-encoder-russian-msmarco')
96
+
97
+ print("✅ RAG система готова к использованию")
98
 
 
 
99
  def _process_documents_with_dates(self):
100
  """
101
+ Обрабатывает документы с парсингом дат и создает два датафрейма.
102
 
103
  Returns:
104
+ tuple: (paragraphs_df, chunks_df)
105
+
106
+ paragraphs_df:
107
+ - paragraph_id: уникальный идентификатор абзаца
108
+ - summary: название документа/раздела
109
+ - start_year: год начала периода
110
+ - end_year: год окончания периода
111
+ - document_id: ссылка на исходный документ
112
+
113
+ chunks_df:
114
+ - chunk_id: уникальный идентификатор чанка
115
+ - paragraph_id: ссылка на абзац (foreign key)
116
+ - text: текст чанка
117
+ - lemmatized_text: лемматизированный текст (добавляется позже)
118
  """
119
+ paragraphs_data = []
120
+ chunks_data = []
121
+
 
122
  paragraph_id_counter = 0
123
+ chunk_id_counter = 0
124
 
125
  for doc_id, document in enumerate(self.documents):
126
+ dated_chunks = parse_metadata_from_document(document)
127
+
128
+ for chunk_text, year_range, summary in dated_chunks:
129
+ paragraphs = chunk_text.split('\n')
130
 
131
+ for paragraph in paragraphs:
132
+ paragraph = paragraph.strip()
133
+
134
+ # Добавляем информацию о параграфе в датафрейм параграфов
135
+ paragraphs_data.append({
136
+ 'paragraph_id': paragraph_id_counter,
137
+ 'summary': summary,
138
+ 'start_year': year_range[0],
139
+ 'end_year': year_range[1],
140
+ 'document_id': doc_id
141
+ })
142
 
143
+ # Разбиваем параграф на предложения и создаем чанки
144
+ sentences = re.split(r'(?<=[.!?])\s+', paragraph)
145
+ for sent in sentences:
146
+ chunks_data.append({
147
+ 'chunk_id': chunk_id_counter,
148
+ 'paragraph_id': paragraph_id_counter,
149
+ 'text': sent.strip()
150
+ })
151
+ chunk_id_counter += 1
152
+
153
+ paragraph_id_counter += 1
154
+
155
+ # Создаем датафреймы
156
+ paragraphs_df = pd.DataFrame(paragraphs_data)
157
+ chunks_df = pd.DataFrame(chunks_data)
158
+
159
+ print(f"Создано {len(chunks_df)} чанков")
160
+ print(f"Из {len(paragraphs_df)} абзацев в {len(set(paragraphs_df['document_id']))} документах")
161
+
162
+ return paragraphs_df, chunks_df
163
+
164
+ # ============ Методы кэширования лемматизации ============
165
+
166
+ @staticmethod
167
+ def _compute_text_hash(text: str) -> str:
168
+ """
169
+ Вычисляет SHA256 хэш текста.
170
+
171
+ Args:
172
+ text: Текст для хэширования
173
+
174
+ Returns:
175
+ str: Хэш в hex формате
176
+ """
177
+ return hashlib.sha256(text.encode('utf-8')).hexdigest()
178
+
179
+ def _load_cache(self) -> dict:
180
+ """
181
+ Загружает кэш лемматизации из файловой системы.
182
+
183
+ Returns:
184
+ dict: {text_hash -> lemmatized_tokens}
185
+ """
186
+ cache_file = self.cache_dir / 'lemmatization_cache.pkl'
187
+
188
+ if cache_file.exists():
189
+ try:
190
+ with open(cache_file, 'rb') as f:
191
+ cache = pickle.load(f)
192
+ print(f" ✓ Кэш загружен ({len(cache)} записей)")
193
+ return cache
194
+ except Exception as e:
195
+ print(f" ⚠ Ошибка при загрузке кэша: {e}")
196
+ return {}
197
+ return {}
198
+
199
+ def _save_cache(self, cache: dict) -> None:
200
+ """
201
+ Сохраняет кэш лемматизации в файловую систему.
202
+
203
+ Args:
204
+ cache: {text_hash -> lemmatized_tokens}
205
+ """
206
+ cache_file = self.cache_dir / 'lemmatization_cache.pkl'
207
+
208
+ try:
209
+ with open(cache_file, 'wb') as f:
210
+ pickle.dump(cache, f)
211
+ print(f" ✓ Кэш сохранён ({len(cache)} записей)")
212
+ except Exception as e:
213
+ print(f" ⚠ Ошибка при сохранении кэша: {e}")
214
+
215
+ def _lemmatize_with_cache(self, texts: list[str]) -> list:
216
+ """
217
+ Лемматизирует тексты с использованием кэша.
218
+ Проверяет хэши текстов - если хэш совпадает с кэшированным,
219
+ использует кэшированный результат. Иначе перелемматизирует.
220
+
221
+ Args:
222
+ texts: Список текстов для лемматизации
223
+
224
+ Returns:
225
+ list: Лемматизированные тексты
226
+ """
227
+ if not self.use_cache:
228
+ # Если кэш отключен, просто лемматизировать
229
+ return [self.lemmatizer.tokenize_text(text) for text in texts]
230
+
231
+ # Загружаем существующий кэш
232
+ cache = self._load_cache()
233
+ text_hashes = {}
234
+ results = []
235
+ needs_save = False
236
+
237
+ for text in texts:
238
+ text_hash = self._compute_text_hash(text)
239
+ text_hashes[text] = text_hash
240
+
241
+ if text_hash in cache:
242
+ # Используем кэшированный результат
243
+ results.append(cache[text_hash])
244
+ else:
245
+ # Лемматизируем и добавляем в кэш
246
+ lemmatized = self.lemmatizer.tokenize_text(text)
247
+ results.append(lemmatized)
248
+ cache[text_hash] = lemmatized
249
+ needs_save = True
250
+
251
+ # Сохраняем кэш если были новые записи
252
+ if needs_save:
253
+ self._save_cache(cache)
254
+
255
+ return results
256
+
257
+ def clear_cache(self) -> None:
258
+ """
259
+ Очищает кэш лемматизации.
260
+ """
261
+ cache_file = self.cache_dir / 'lemmatization_cache.pkl'
262
+
263
+ try:
264
+ if cache_file.exists():
265
+ cache_file.unlink()
266
+ print("✓ Кэш очищен")
267
+ else:
268
+ print("⚠ Файл кэша не найден")
269
+ except Exception as e:
270
+ print(f"⚠ Ошибка при очистке кэша: {e}")
271
+
272
+ def get_cache_stats(self) -> dict:
273
+ """
274
+ Возвращает статистику кэша.
275
+
276
+ Returns:
277
+ dict: Информация о кэше
278
+ """
279
+ cache_file = self.cache_dir / 'lemmatization_cache.pkl'
280
+
281
+ if cache_file.exists():
282
+ cache = self._load_cache() if self.use_cache else {}
283
+ file_size_mb = cache_file.stat().st_size / (1024 * 1024)
284
+
285
+ return {
286
+ 'cache_enabled': self.use_cache,
287
+ 'cache_file': str(cache_file),
288
+ 'cached_entries': len(cache),
289
+ 'file_size_mb': round(file_size_mb, 2),
290
+ 'exists': True
291
+ }
292
+ else:
293
+ return {
294
+ 'cache_enabled': self.use_cache,
295
+ 'cache_file': str(cache_file),
296
+ 'cached_entries': 0,
297
+ 'file_size_mb': 0,
298
+ 'exists': False
299
+ }
300
 
301
+ # ============ Вспомогательные методы для работы с датафреймами ============
302
+
303
+ def get_merged_data(self):
304
+ """Возвращает объединённый датафрейм чанков с метаданными параграфов.
305
+
306
+ Returns:
307
+ pd.DataFrame: Датафрейм с полями:
308
+ chunk_id, paragraph_id, text, lemmatized_text,
309
+ summary, start_year, end_year, document_id
310
+ """
311
+ return self.chunks_df.merge(
312
+ self.paragraphs_df,
313
+ on='paragraph_id',
314
+ how='left'
315
+ )
316
+
317
+ def filter_by_year_range(self, year_range: tuple[int, int]) -> pd.DataFrame:
318
+ """Возвращает чанки, которые пересекаются с заданным диапазоном лет.
319
+
320
+ Args:
321
+ year_range: (start_year, end_year)
322
+
323
+ Returns:
324
+ pd.DataFrame: Отфильтрованные чанки с метаданными
325
+ """
326
+ merged = self.get_merged_data()
327
+
328
+ # Проверяем пересечение диапазонов
329
+ return merged[
330
+ (merged['start_year'] <= year_range[1]) &
331
+ (merged['end_year'] >= year_range[0])
332
+ ]
333
+
334
  def rerank_search(self, query: str) -> list[dict]:
335
+ """Ранжирует все чанки используя CrossEncoder модель.
336
+
337
+ Args:
338
+ query: Текст запроса
339
+
340
+ Returns:
341
+ list: Отсортированный список результатов с scores
342
  """
343
+ pairs = [[query, text] for text in self.chunks_df['text'].tolist()]
344
+ scores = self.cross_encoder.predict(pairs)
345
+ breakpoint()
346
+
347
+ # Добавляем scores в датафрейм и сортируем
348
+ results = self.chunks_df.copy()
349
+ results['score'] = scores
350
+ return results.sort_values('score', ascending=False).to_dict('records')
351
 
352
  def semantic_search(self, query: str) -> list:
353
  # 1. Семантический поиск
 
355
  semantic_scores = torch.nn.functional.cosine_similarity(self.embeddings, query_embedding, eps=1e-8).cpu()
356
  return semantic_scores
357
 
358
+ def bm25_search(self, query: str) -> list:
359
+ """BM25 поиск, используя лемматизированные чанки.
360
+
361
+ Args:
362
+ query: Текст запроса
363
+
364
+ Returns:
365
+ list: Скоры для каждого чанка
366
+ """
367
+ bm25 = BM25Okapi(self.chunks_df['lemmatized_text'].tolist())
368
  tokenized_query = self.lemmatizer.tokenize_text(query)
369
+ return bm25.get_scores(tokenized_query)
370
 
371
+ # ============ Для тестирования cross-encoder ============
372
+ def test_query_with_cross_encoder(self, query: str,
373
+ target_summary: str):
374
+ """ Тестирует запрос с cross-encoder и выводит результаты.
375
 
376
  Args:
377
+ query: Текст запроса
378
+ target_summary: Ожидаемый summary
379
+ """
380
+ print(f"{'='*90}")
381
+ print(f" ✓ Target summary: '{target_summary}'\n")
382
+
383
+ # Получаем объединённый датафрейм
384
+ merged_df = self.get_merged_data()
385
+
386
+ # ================================================================
387
+ # 1. BM25 ПОИСК
388
+ # ================================================================
389
+ print(f" 📊 BM25 ЛЕКСИЧЕСКИЙ ПОИСК:")
390
+
391
+ # Инициализируем BM25
392
+ bm25_scores = self.bm25_search(query)
393
+
394
+ # Добавляем scores в помощный датафрейм
395
+ search_df = merged_df.copy()
396
+ search_df['bm25_score'] = bm25_scores
397
+
398
+ # Получаем топ-30 по BM25
399
+ top_bm25 = search_df.nlargest(30, 'bm25_score')
400
+
401
+ print(f" Топ-10 чанков, их summary-ы:")
402
+
403
+ # Собираем уникальные summary из BM25 результатов
404
+ bm25_summaries = top_bm25['summary'].unique()
405
+ summary_scores_bm25 = dict(
406
+ top_bm25.groupby('summary')['bm25_score'].first()
407
+ )
408
+
409
+ for rank, summary in enumerate(bm25_summaries[:10], 1):
410
+ score = summary_scores_bm25[summary]
411
+ print(f" {rank:2}. BM25={score:6.2f} [{summary[:50]:50}]")
412
+
413
+ print(f" → Уникальных summary найдено: {len(bm25_summaries)}")
414
+ print(f" → Целевой summary в результатах: {'✓ ДА' if target_summary in bm25_summaries else '✗ НЕТ'}")
415
+
416
+ # ================================================================
417
+ # 2. КРОСС-ЭНКОДЕР РАНЖИРОВАНИЕ
418
+ # ================================================================
419
+ print(f"\n 🏆 КРОСС-ЭНКОДЕР РАНЖИРОВАНИЕ:")
420
+
421
+ # Собираем ВСЕ уникальные summary
422
+ all_unique_summaries = merged_df['summary'].unique().tolist()
423
+ assert target_summary in all_unique_summaries, breakpoint()
424
+
425
+ cross_encoder_start = time.time()
426
+
427
+ # Подготавливаем пары query-summary
428
+ pairs = [[query, summary] for summary in all_unique_summaries]
429
+
430
+ # Ранжируем через кросс-энкодер
431
+ cross_scores = self.cross_encoder.predict(pairs)
432
+ cross_encoder_time = time.time() - cross_encoder_start
433
+
434
+ # Сортируем результаты
435
+ ranked_indices = sorted(
436
+ range(len(cross_scores)),
437
+ key=lambda i: cross_scores[i],
438
+ reverse=True
439
+ )
440
+
441
+ print(f" (время: {cross_encoder_time:.3f} сек)")
442
+ print(f" Top-5 summary (из {len(all_unique_summaries)} всего):")
443
+
444
+ cross_target_rank = None
445
+
446
+ for rank, idx in enumerate(ranked_indices[:5], 1):
447
+ summary = all_unique_summaries[idx]
448
+ score = cross_scores[idx]
449
 
450
+ is_target = summary == target_summary
451
+ mark = "⭐ TARGET ⭐" if is_target else " " * 13
452
+
453
+ print(f" {mark} {rank}. Cross={score:7.4f} [{summary[:50]:50}]")
454
+
455
+ if is_target:
456
+ cross_target_rank = rank
457
+
458
+ if not cross_target_rank:
459
+ print(f" ❌ Целевой summary НЕ в топ-5")
test_cross_encoder.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from retrieval import Retrieval
3
+
4
+
5
+ @dataclass
6
+ class TestCaseForCrossEncoder:
7
+ query: str
8
+ good_answer: str
9
+
10
+
11
+ test_cases = [
12
+ TestCaseForCrossEncoder(
13
+ 'Какие изменения в транспорте Рязани были бы полезны на текущий момент?',
14
+ 'Актуальные проекты новых троллейбусных линий, которые полезно бы построить',
15
+ ),
16
+ TestCaseForCrossEncoder(
17
+ 'Какие продления троллейбусной сети были бы полезны на текущий момент?',
18
+ 'Актуальные проекты новых троллейбусных линий, которые полезно бы построить',
19
+ ),
20
+ TestCaseForCrossEncoder(
21
+ 'Расскажи о провалившихся экспериментах в Рязани',
22
+ 'Попытки (все из которых неудачные) запустить городскую электричку в истории',
23
+ ),
24
+ TestCaseForCrossEncoder(
25
+ 'Расскажи историю маршрута маршрутки № 92 в Рязани',
26
+ 'история ныне закрытой маршрутки № 92',
27
+ ),
28
+ TestCaseForCrossEncoder(
29
+ 'Какой маршрут в Рязани закрылся из-за плохой трассировки?',
30
+ 'У троллейбусного маршрута №2 была неудачная трасса - в объезд основных узлов города',
31
+ ),
32
+ TestCaseForCrossEncoder(
33
+ 'Когда маршрут троллейбуса №10 продлили до площади Попова?',
34
+ 'история троллейбусного маршрута № 10')
35
+ ]
36
+
37
+
38
+ def test_cross_encoder_vs_bm25():
39
+ """Тестирует кросс-энкодер vs BM25 на всех документах."""
40
+ print("=" * 90)
41
+ print("СРАВНЕНИЕ: КРОСС-ЭНКОДЕР vs BM25 ЛЕММАТИЗИРОВАННЫЙ ПОИСК")
42
+ print("=" * 90)
43
+
44
+ # Создаем объект Retrieval (загружает корпус автоматически)
45
+ retrieval = Retrieval(use_gpu=False)
46
+
47
+ # Тестируем каждый тестовый случай
48
+ print("=" * 90)
49
+ print("ТЕСТИРОВАНИЕ ОТДЕЛЬНЫХ ЗАПРОСОВ")
50
+ print("=" * 90)
51
+
52
+ for test_num, test_case in enumerate(test_cases, 1):
53
+ retrieval.test_query_with_cross_encoder(
54
+ query=test_case.query,
55
+ target_summary=test_case.good_answer,
56
+ test_num=test_num
57
+ )
58
+
59
+ print("\n" + "=" * 90)
60
+ print(f"✅ Тестирование завершено")
61
+ print("=" * 90)
62
+
63
+
64
+ if __name__ == "__main__":
65
+ test_cross_encoder_vs_bm25()
66
+
tests/test_retirieval.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from retrieval import Retrieval
2
+
3
+ retr = Retrieval(use_gpu=False)
4
+ res = retr.bm25_search('канищево', top_k=5)
5
+ print(res)