| | """В данном файле определяются параметры для оптимального объединения BM25 и семантического поиска. |
| | Параметры оптимизируются на тестовых примерах""" |
| | from dataclasses import dataclass |
| | from retrieval import Retrieval |
| |
|
| | import numpy as np |
| |
|
| |
|
| | @dataclass |
| | class TestCaseFindSent: |
| | query: str |
| | good_answer: str |
| | compare_with: str = 'text' |
| |
|
| |
|
| | @dataclass |
| | class TestCaseFindSummary: |
| | query: str |
| | good_answer: str |
| | compare_with: str = 'summary' |
| |
|
| |
|
| | test_cases = ( |
| | TestCaseFindSummary( |
| | 'Какие изменения в транспорте были бы полезны на текущий момент?', |
| | 'Актуальные проекты новых троллейбусных линий, которые полезно бы построить', |
| | ), |
| | TestCaseFindSummary( |
| | 'Какие продления троллейбусной сети были бы полезны на текущий момент?', |
| | 'Актуальные проекты новых троллейбусных линий, которые полезно бы построить', |
| | ), |
| | TestCaseFindSummary( |
| | 'Расскажи о провалившихся экспериментах', |
| | 'Попытки (все из которых неудачные) запустить городскую электричку в истории', |
| | ), |
| | TestCaseFindSummary( |
| | 'Расскажи историю маршрута маршрутки № 92', |
| | 'история ныне закрытой маршрутки № 92'), |
| | TestCaseFindSummary( |
| | 'Какой маршрут закрылся из-за плохой трассировки?', |
| | 'У троллейбусного маршрута №2 была неудачная трасса - в объезд основных узлов города', |
| | ), |
| | TestCaseFindSummary( |
| | 'Когда маршрут троллейбуса №10 продлили до площади Попова?', |
| | 'история троллейбусного маршрута № 10'), |
| | TestCaseFindSummary( |
| | 'Какой троллейбусный маршрут закрыли из-за того, что после продлений он стал слишком длинным и нестабильным', |
| | 'История троллейбусного маршрута № 8, этап с увеличением трассы и снижение числа троллейбусов на маршруте № 8, его постепенная деградация и в конце концов закрытие.' |
| | ), |
| | TestCaseFindSent( |
| | 'Какие называлось МУП УРТ, пока в её составе был трамвай?', |
| | 'УРТ стало троллейбусно-трамвайным управлением (ТТУ) только в 2007 году и таковым являлась до 2010 (когда закрыли трамвай)'), |
| | TestCaseFindSent( |
| | 'Когда был закрыт троллейбусный маршрут № 18 и почему?', |
| | 'Соответственно, в феврале 2013 года 18 троллейбус был закрыт как дублирующий 8, а 8 троллейбус пущен через Канищево до телезавода, теперь совершенно объезжая Приокский.', |
| | ), |
| | TestCaseFindSent( |
| | 'Почему на bus62 присутствует остановка на улице Кудрявцева? Там когда-то что-то было?', |
| | 'В таком виде он застал начало эпохи Глонасса и поэтому автобусная конечная "Центральный рынок" на улице Кудрявцева до сих пор есть в "Умном транспорте" (он же bus62), но на ней ничего не разворачивается уже 13 лет.' |
| | ), |
| | TestCaseFindSent( |
| | 'Сколько рейсов делала городская электричка, запущенная в 2009 году?', |
| | 'Ну во-первых электричка ходила редко, всего пять раз в день.' |
| | ), |
| | TestCaseFindSent( |
| | 'Когда закрылась версия третьего автобуса, когда он ходил в Канищево?', |
| | 'И с 17 октября 1987 года третий автобус закрывается, причем делается это на редкость официально, со всеми объявлениями.' |
| | ), |
| | TestCaseFindSent( |
| | 'Когда в Рязань поступили автобусы ЗИЛ-158?', |
| | 'Впервые ЗИЛ-158 появились в автоколонне 73 в декабре 1957 года' |
| | ), |
| | TestCaseFindSent( |
| | 'Когда пропали двухвагонные трамвайные поезда?', |
| | 'А ещё в 2004 году на таком маршруте ездили двухвагонные поезда из трамваев.', |
| | ), |
| | TestCaseFindSent( |
| | 'Когда общественный транспорт впервые пришёл в Ворошиловку?', |
| | 'Поэтому в 1980 году были открыты маршруты № 22 в Карцево и № 23 до улицы Берёзовой, а в 1981-м — 24-й автобус на Ворошиловку.' |
| | ), |
| | TestCaseFindSent( |
| | 'Когда общественный транспорт впервые пришёл в Карцево?', |
| | 'Поэтому в 1980 году были открыты маршруты № 22 в Карцево и № 23 до улицы Берёзовой, а в 1981-м — 24-й автобус на Ворошиловку.' |
| | ), |
| | TestCaseFindSent( |
| | 'Когда закрыли депо 1?', |
| | 'Год спустя, в 2016 году, закрыли первое троллейбусное депо.'), |
| | TestCaseFindSent( |
| | 'Когда закрылась 91-я маршрутка?', |
| | 'В ноябре 2021 года 91 маршрутка прекратила свою работу.' |
| | ), |
| | TestCaseFindSent( |
| | 'Когда на остановочных табличках Рязани появились конечные остановки маршрутов?', |
| | 'Наконец, в 2015 году в истории остановочных табличек Рязани происходит настоящая революция. Впервые за 60 с лишним лет их истории на них наконец-то появились конечные остановки маршрутов. ' |
| | ) |
| | ) |
| |
|
| |
|
| | def test_retrieval(): |
| | """Тестирует кросс-энкодер vs BM25 на всех документах.""" |
| | |
| | retrieval = Retrieval() |
| | |
| | good_paragraphs = [] |
| | bm25_scores = [] |
| | semantic_scores = [] |
| | for test_case in test_cases: |
| | good_paragraphs_for_case = [] |
| | for i, text in enumerate(retrieval.paragraphs_df[test_case.compare_with]): |
| | if test_case.good_answer in text: |
| | good_paragraphs_for_case.append(i) |
| | assert good_paragraphs_for_case, breakpoint() |
| | good_paragraphs.append(good_paragraphs_for_case) |
| | |
| | bm25_scores.append(retrieval.bm25_search(query=test_case.query)) |
| | semantic_scores.append(retrieval.semantic_search(query=test_case.query)) |
| |
|
| | bm25_scores = np.array(bm25_scores) |
| | semantic_scores = np.array(semantic_scores) |
| |
|
| | def func(params): |
| | scores = bm25_scores * params[0] + semantic_scores |
| | ranks = np.argsort(scores, axis=1) |
| |
|
| | |
| | index_maps = [{v: i for i, v in enumerate(ranks_case)} for ranks_case in ranks] |
| | relevant_indices = [ |
| | [index_map[x] for x in good_paragraphs_for_case] for index_map, good_paragraphs_for_case in zip(index_maps, good_paragraphs, strict=True)] |
| | |
| | |
| | metrics = np.array([np.mean(relevant_ids_case) + np.min(relevant_ids_case) for relevant_ids_case in relevant_indices]) |
| | return -(np.mean(metrics) + np.min(metrics)) |
| |
|
| | from scipy.optimize import differential_evolution, minimize |
| | params = differential_evolution(func, bounds=[(0.2, 5)]) |
| | print(params) |
| | print(f"Оптимальный вес для BM25: {params.x[0]:.4f}") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | test_retrieval() |
| |
|