RAG2 / calculate_params.py
antimoda1
more refactor
8109cc7
"""В данном файле определяются параметры для оптимального объединения BM25 и семантического поиска.
Параметры оптимизируются на тестовых примерах"""
from dataclasses import dataclass
from retrieval import Retrieval
import numpy as np
@dataclass
class TestCaseFindSent:
query: str
good_answer: str
compare_with: str = 'text'
@dataclass
class TestCaseFindSummary:
query: str
good_answer: str
compare_with: str = 'summary'
test_cases = (
TestCaseFindSummary(
'Какие изменения в транспорте были бы полезны на текущий момент?',
'Актуальные проекты новых троллейбусных линий, которые полезно бы построить',
),
TestCaseFindSummary(
'Какие продления троллейбусной сети были бы полезны на текущий момент?',
'Актуальные проекты новых троллейбусных линий, которые полезно бы построить',
),
TestCaseFindSummary(
'Расскажи о провалившихся экспериментах',
'Попытки (все из которых неудачные) запустить городскую электричку в истории',
),
TestCaseFindSummary(
'Расскажи историю маршрута маршрутки № 92',
'история ныне закрытой маршрутки № 92'),
TestCaseFindSummary(
'Какой маршрут закрылся из-за плохой трассировки?',
'У троллейбусного маршрута №2 была неудачная трасса - в объезд основных узлов города',
),
TestCaseFindSummary(
'Когда маршрут троллейбуса №10 продлили до площади Попова?',
'история троллейбусного маршрута № 10'),
TestCaseFindSummary(
'Какой троллейбусный маршрут закрыли из-за того, что после продлений он стал слишком длинным и нестабильным',
'История троллейбусного маршрута № 8, этап с увеличением трассы и снижение числа троллейбусов на маршруте № 8, его постепенная деградация и в конце концов закрытие.'
),
TestCaseFindSent(
'Какие называлось МУП УРТ, пока в её составе был трамвай?',
'УРТ стало троллейбусно-трамвайным управлением (ТТУ) только в 2007 году и таковым являлась до 2010 (когда закрыли трамвай)'),
TestCaseFindSent(
'Когда был закрыт троллейбусный маршрут № 18 и почему?',
'Соответственно, в феврале 2013 года 18 троллейбус был закрыт как дублирующий 8, а 8 троллейбус пущен через Канищево до телезавода, теперь совершенно объезжая Приокский.',
),
TestCaseFindSent(
'Почему на bus62 присутствует остановка на улице Кудрявцева? Там когда-то что-то было?',
'В таком виде он застал начало эпохи Глонасса и поэтому автобусная конечная "Центральный рынок" на улице Кудрявцева до сих пор есть в "Умном транспорте" (он же bus62), но на ней ничего не разворачивается уже 13 лет.'
),
TestCaseFindSent(
'Сколько рейсов делала городская электричка, запущенная в 2009 году?',
'Ну во-первых электричка ходила редко, всего пять раз в день.'
),
TestCaseFindSent(
'Когда закрылась версия третьего автобуса, когда он ходил в Канищево?',
'И с 17 октября 1987 года третий автобус закрывается, причем делается это на редкость официально, со всеми объявлениями.'
),
TestCaseFindSent(
'Когда в Рязань поступили автобусы ЗИЛ-158?',
'Впервые ЗИЛ-158 появились в автоколонне 73 в декабре 1957 года'
),
TestCaseFindSent(
'Когда пропали двухвагонные трамвайные поезда?',
'А ещё в 2004 году на таком маршруте ездили двухвагонные поезда из трамваев.',
),
TestCaseFindSent(
'Когда общественный транспорт впервые пришёл в Ворошиловку?',
'Поэтому в 1980 году были открыты маршруты № 22 в Карцево и № 23 до улицы Берёзовой, а в 1981-м — 24-й автобус на Ворошиловку.'
),
TestCaseFindSent(
'Когда общественный транспорт впервые пришёл в Карцево?',
'Поэтому в 1980 году были открыты маршруты № 22 в Карцево и № 23 до улицы Берёзовой, а в 1981-м — 24-й автобус на Ворошиловку.'
),
TestCaseFindSent(
'Когда закрыли депо 1?',
'Год спустя, в 2016 году, закрыли первое троллейбусное депо.'),
TestCaseFindSent(
'Когда закрылась 91-я маршрутка?',
'В ноябре 2021 года 91 маршрутка прекратила свою работу.'
),
TestCaseFindSent(
'Когда на остановочных табличках Рязани появились конечные остановки маршрутов?',
'Наконец, в 2015 году в истории остановочных табличек Рязани происходит настоящая революция. Впервые за 60 с лишним лет их истории на них наконец-то появились конечные остановки маршрутов. '
)
)
def test_retrieval():
"""Тестирует кросс-энкодер vs BM25 на всех документах."""
# Создаем объект Retrieval (загружает корпус автоматически)
retrieval = Retrieval()
good_paragraphs = []
bm25_scores = []
semantic_scores = []
for test_case in test_cases:
good_paragraphs_for_case = []
for i, text in enumerate(retrieval.paragraphs_df[test_case.compare_with]):
if test_case.good_answer in text:
good_paragraphs_for_case.append(i)
assert good_paragraphs_for_case, breakpoint()
good_paragraphs.append(good_paragraphs_for_case)
# masks.append(np.isin(test_case.query, retrieval.paragraphs_df['text']))
bm25_scores.append(retrieval.bm25_search(query=test_case.query))
semantic_scores.append(retrieval.semantic_search(query=test_case.query))
bm25_scores = np.array(bm25_scores)
semantic_scores = np.array(semantic_scores)
def func(params):
scores = bm25_scores * params[0] + semantic_scores
ranks = np.argsort(scores, axis=1) # важно: сортировка по возрастанию от нерелевантных к релевантным
# Получаем ранги релевантных параграфов для каждого testcase
index_maps = [{v: i for i, v in enumerate(ranks_case)} for ranks_case in ranks]
relevant_indices = [
[index_map[x] for x in good_paragraphs_for_case] for index_map, good_paragraphs_for_case in zip(index_maps, good_paragraphs, strict=True)]
# Метрика "среднее+минимум" внутри каждого testcase, а затем "среднее+минимум" по всем testcases
metrics = np.array([np.mean(relevant_ids_case) + np.min(relevant_ids_case) for relevant_ids_case in relevant_indices])
return -(np.mean(metrics) + np.min(metrics)) # чем больше метрика, тем скор (и ранг) выше, поэтому берём минус для оптимизации на минимум
from scipy.optimize import differential_evolution, minimize
params = differential_evolution(func, bounds=[(0.2, 5)])
print(params)
print(f"Оптимальный вес для BM25: {params.x[0]:.4f}")
if __name__ == "__main__":
test_retrieval()