antimoda1 commited on
Commit
6a7ab41
·
1 Parent(s): 4dd2f1d

add file to test

Browse files
Files changed (2) hide show
  1. calculate_params.py +152 -0
  2. retrieval.py +3 -4
calculate_params.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """В данном файле определяются параметры для оптимального объединения BM25 и семантического поиска.
2
+ Параметры оптимизируются на тестовых примерах"""
3
+ from dataclasses import dataclass
4
+ from retrieval import Retrieval
5
+
6
+ import numpy as np
7
+
8
+
9
+ @dataclass
10
+ class TestCaseFindSent:
11
+ query: str
12
+ good_answer: str
13
+ compare_with: str = 'text'
14
+
15
+
16
+ @dataclass
17
+ class TestCaseFindSummary:
18
+ query: str
19
+ good_answer: str
20
+ compare_with: str = 'summary'
21
+
22
+
23
+ test_cases = (
24
+ TestCaseFindSummary(
25
+ 'Какие изменения в транспорте были бы полезны на текущий момент?',
26
+ 'Актуальные проекты новых троллейбусных линий, которые полезно бы построить',
27
+ ),
28
+ TestCaseFindSummary(
29
+ 'Какие продления троллейбусной сети были бы полезны на текущий момент?',
30
+ 'Актуальные проекты новых троллейбусных линий, которые полезно бы построить',
31
+ ),
32
+ TestCaseFindSummary(
33
+ 'Расскажи о провалившихся экспериментах',
34
+ 'Попытки (все из которых неудачные) запустить городскую электричку в истории',
35
+ ),
36
+ TestCaseFindSummary(
37
+ 'Расскажи историю маршрута маршрутки № 92',
38
+ 'история ныне закрытой маршрутки № 92'),
39
+ TestCaseFindSummary(
40
+ 'Какой маршрут закрылся из-за плохой трассировки?',
41
+ 'У троллейбусного маршрута №2 была неудачная трасса - в объезд основных узлов города',
42
+ ),
43
+ TestCaseFindSummary(
44
+ 'Когда маршрут троллейбуса №10 продлили до площади Попова?',
45
+ 'история троллейбусного маршрута № 10'),
46
+ TestCaseFindSummary(
47
+ 'Какой троллейбусный маршрут закрыли из-за того, что после продлений он стал слишком длинным и нестабильным',
48
+ 'История троллейбусного маршрута № 8, этап с увеличением трассы и снижение числа троллейбусов на маршруте № 8, его постепенная деградация и в конце концов закрытие.'
49
+ ),
50
+ TestCaseFindSent(
51
+ 'Какие называлось МУП УРТ, пока в её составе был трамвай?',
52
+ 'УРТ стало троллейбусно-трамвайным управлением (ТТУ) только в 2007 году и таковым являлась до 2010 (когда закрыли трамвай)'),
53
+ TestCaseFindSent(
54
+ 'Когда был закрыт троллейбусный маршрут № 18 и почему?',
55
+ 'Соответственно, в феврале 2013 года 18 троллейбус был закрыт как дублирующий 8, а 8 троллейбус пущен через Канищево до телезавода, теперь совершенно объезжая Приокский.',
56
+ ),
57
+ TestCaseFindSent(
58
+ 'Почему на bus62 присутствует остановка на улице Кудрявцева? Там когда-то что-то было?',
59
+ 'В таком виде он застал начало эпохи Глонасса и поэтому автобусная конечная "Центральный рынок" на улице Кудрявцева до сих пор есть в "Умном транспорте" (он же bus62), но на ней ничего не разворачивается уже 13 лет.'
60
+ ),
61
+ TestCaseFindSent(
62
+ 'Сколько рейсов делала городская электричка, запущенная в 2009 году?',
63
+ 'Ну во-первых электричка ходила редко, всего пять раз в день.'
64
+ ),
65
+ TestCaseFindSent(
66
+ 'Когда закрылась версия третьего автобуса, когда он ходил в Канищево?',
67
+ 'И с 17 октября 1987 года третий автобус закрывается, причем делается это на редкость официально, со всеми объявлениями.'
68
+ ),
69
+ TestCaseFindSent(
70
+ 'Когда в Рязань поступили автобусы ЗИЛ-158?',
71
+ 'Впервые ЗИЛ-158 появились в автоколонне 73 в декабре 1957 года'
72
+ ),
73
+ TestCaseFindSent(
74
+ 'Когда пропали двухвагонные трамвайные поезда?',
75
+ 'А ещё в 2004 году на таком маршруте ездили двухвагонные поезда из трамваев.',
76
+ ),
77
+ TestCaseFindSent(
78
+ 'Когда общественный транспорт впервые пришёл в Ворошиловку?',
79
+ 'Поэтому в 1980 году были открыты маршруты № 22 в Карцево и № 23 до улицы Берёзовой, а в 1981-м — 24-й автобус на Ворошиловку.'
80
+ ),
81
+ TestCaseFindSent(
82
+ 'Когда общественный транспорт впервые пришёл в Карцево?',
83
+ 'Поэтому в 1980 году были открыты маршруты № 22 в Карцево и № 23 до улицы Берёзовой, а в 1981-м — 24-й автобус на Ворошиловку.'
84
+ ),
85
+ TestCaseFindSent(
86
+ 'Когда закрыли депо 1?',
87
+ 'Год спустя, в 2016 году, закрыли первое троллейбусное депо.'),
88
+ TestCaseFindSent(
89
+ 'Когда закрылась 91-я маршрутка?',
90
+ 'В ноябре 2021 года 91 маршрутка прекратила свою работу.'
91
+ ),
92
+ TestCaseFindSent(
93
+ 'Когда на остановочных табличках Рязани появились конечные остановки маршрутов?',
94
+ 'Наконец, в 2015 году в истории остановочных табличек Рязани происходит настоящая революция. Впервые за 60 с лишним лет их истории на них наконец-то появились конечные остановки маршрутов. '
95
+ )
96
+ )
97
+
98
+
99
+ def get_ranks(scores, good_paragraphs):
100
+ scores = np.array(scores)
101
+ ranks = np.argsort(scores, axis=1) # важно: сортировка по возрастанию от нерелевантных к релевантным
102
+ mask = np.array([np.isin(rank_for_case, good_paragraphs_for_case)
103
+ for rank_for_case, good_paragraphs_for_case in zip(ranks, good_paragraphs, strict=True)])
104
+ relevant_ranks = [ranks_case[mask_case] for mask_case, ranks_case in zip(mask, ranks, strict=True)]
105
+ breakpoint()
106
+ return relevant_ranks
107
+
108
+ def test_cross_encoder_vs_bm25():
109
+ """Тестирует кросс-энкодер vs BM25 на всех документах."""
110
+ # Создаем объект Retrieval (загружает корпус автоматически)
111
+ retrieval = Retrieval()
112
+
113
+ good_paragraphs = []
114
+ bm25_scores = []
115
+ semantic_scores = []
116
+ for test_case in test_cases:
117
+ good_paragraphs_for_case = []
118
+ for i, text in enumerate(retrieval.paragraphs_df[test_case.compare_with]):
119
+ if test_case.good_answer in text:
120
+ good_paragraphs_for_case.append(i)
121
+ assert good_paragraphs_for_case, breakpoint()
122
+ good_paragraphs.append(good_paragraphs_for_case)
123
+ # masks.append(np.isin(test_case.query, retrieval.paragraphs_df['text']))
124
+ bm25_scores.append(retrieval.bm25_search(query=test_case.query))
125
+ semantic_scores.append(retrieval.semantic_search(query=test_case.query))
126
+
127
+ bm25_scores = np.array(bm25_scores)
128
+ semantic_scores = np.array(semantic_scores)
129
+
130
+ def func(params):
131
+ scores = bm25_scores * params[0] + semantic_scores
132
+ # ranks = np.argsort(scores, axis=1) # важно: сортировка по возрастанию от нерелевантных к релевантным
133
+
134
+ # Получаем ранги релевантных параграфов для каждого testcase
135
+ index_maps = [{v: i for i, v in enumerate(scores_case)} for scores_case in scores]
136
+ relevant_indices = [
137
+ [index_map[x] for x in good_paragraphs_for_case] for index_map, good_paragraphs_for_case in zip(index_maps, good_paragraphs, strict=True)]
138
+
139
+ # Метрика "среднее+минимум" внутри каждого testcase, а затем "среднее+минимум" по всем testcases
140
+ metrics = np.array([np.mean(relevant_ids_case) + np.min(relevant_ids_case) for relevant_ids_case in relevant_indices])
141
+ return -(np.mean(metrics) + np.min(metrics)) # чем больше метрика, тем скор (и ранг) выше, поэтому берём минус для оптимизации на минимум
142
+
143
+ from scipy.optimize import differential_evolution, minimize
144
+ params = differential_evolution(func, bounds=[(0.2, 5)])
145
+ print(params)
146
+ print(f"Оптимальный вес для BM25: {params.x[0]:.4f}")
147
+
148
+
149
+
150
+ if __name__ == "__main__":
151
+ test_cross_encoder_vs_bm25()
152
+
retrieval.py CHANGED
@@ -149,7 +149,7 @@ class Retrieval:
149
 
150
  # Сохраняем кэш если были новые записи
151
  if needs_save:
152
- with open(cache, 'wb') as f:
153
  pickle.dump(cache, f)
154
  print(f" ✓ Кэш сохранён ({len(cache)} записей)")
155
 
@@ -170,7 +170,7 @@ class Retrieval:
170
  Returns:
171
  np.ndarray: Скоры для каждого абзаца (не предложения!)
172
  """
173
- bm25 = BM25Okapi(self.chunks_df['lemmatized_text'].tolist())
174
  tokenized_query = self.lemmatizer.tokenize_text(query)
175
  sentences_scores = bm25.get_scores(tokenized_query)
176
  df = self.chunks_df['paragraph_id'].to_frame().copy()
@@ -190,5 +190,4 @@ class Retrieval:
190
  bm25_scores = self.bm25_search(query)
191
  semantic_scores = self.semantic_search(query).numpy()
192
  bm25_scores = normalize_array(bm25_scores)
193
- semantic_scores = normalize_array(semantic_scores)
194
- return weight_semantic * semantic_scores + weight_bm25 * bm25_scores
 
149
 
150
  # Сохраняем кэш если были новые записи
151
  if needs_save:
152
+ with open(self.cache_dir / 'lemmatization_cache.pkl', 'wb') as f:
153
  pickle.dump(cache, f)
154
  print(f" ✓ Кэш сохранён ({len(cache)} записей)")
155
 
 
170
  Returns:
171
  np.ndarray: Скоры для каждого абзаца (не предложения!)
172
  """
173
+ bm25 = BM25Okapi(self.chunks_df['lemmatized_text'])
174
  tokenized_query = self.lemmatizer.tokenize_text(query)
175
  sentences_scores = bm25.get_scores(tokenized_query)
176
  df = self.chunks_df['paragraph_id'].to_frame().copy()
 
190
  bm25_scores = self.bm25_search(query)
191
  semantic_scores = self.semantic_search(query).numpy()
192
  bm25_scores = normalize_array(bm25_scores)
193
+ return semantic_scores + 1.0 * bm25_scores