File size: 9,413 Bytes
6a7ab41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8109cc7
6a7ab41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
891d9fd
6a7ab41
 
891d9fd
6a7ab41
 
 
 
 
 
 
 
 
 
 
 
 
 
8109cc7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""В данном файле определяются параметры для оптимального объединения BM25 и семантического поиска.
Параметры оптимизируются на тестовых примерах"""
from dataclasses import dataclass
from retrieval import Retrieval

import numpy as np


@dataclass
class TestCaseFindSent:
    query: str
    good_answer: str
    compare_with: str = 'text'


@dataclass
class TestCaseFindSummary:
    query: str
    good_answer: str
    compare_with: str = 'summary'


test_cases = (
    TestCaseFindSummary(
        'Какие изменения в транспорте были бы полезны на текущий момент?',
        'Актуальные проекты новых троллейбусных линий, которые полезно бы построить',
    ),
    TestCaseFindSummary(
        'Какие продления троллейбусной сети были бы полезны на текущий момент?',
        'Актуальные проекты новых троллейбусных линий, которые полезно бы построить',
    ),
    TestCaseFindSummary(
        'Расскажи о провалившихся экспериментах',
        'Попытки (все из которых неудачные) запустить городскую электричку в истории',
    ),
    TestCaseFindSummary(
        'Расскажи историю маршрута маршрутки № 92',
        'история ныне закрытой маршрутки № 92'),
    TestCaseFindSummary(
        'Какой маршрут закрылся из-за плохой трассировки?',
        'У троллейбусного маршрута №2 была неудачная трасса - в объезд основных узлов города',
    ),
    TestCaseFindSummary(
        'Когда маршрут троллейбуса №10 продлили до площади Попова?',
        'история троллейбусного маршрута № 10'),
    TestCaseFindSummary(
        'Какой троллейбусный маршрут закрыли из-за того, что после продлений он стал слишком длинным и нестабильным',
        'История троллейбусного маршрута № 8, этап с увеличением трассы и снижение числа троллейбусов на маршруте № 8, его постепенная деградация и в конце концов закрытие.'
    ),
    TestCaseFindSent(
        'Какие называлось МУП УРТ, пока в её составе был трамвай?',
        'УРТ стало троллейбусно-трамвайным управлением (ТТУ) только в 2007 году и таковым являлась до 2010 (когда закрыли трамвай)'),
    TestCaseFindSent(
        'Когда был закрыт троллейбусный маршрут № 18 и почему?',
        'Соответственно, в феврале 2013 года 18 троллейбус был закрыт как дублирующий 8, а 8 троллейбус пущен через Канищево до телезавода, теперь совершенно объезжая Приокский.',
    ),
    TestCaseFindSent(
        'Почему на bus62 присутствует остановка на улице Кудрявцева? Там когда-то что-то было?',
        'В таком виде он застал начало эпохи Глонасса и поэтому автобусная конечная "Центральный рынок" на улице Кудрявцева до сих пор есть в "Умном транспорте" (он же bus62), но на ней ничего не разворачивается уже 13 лет.'
    ),
    TestCaseFindSent(
        'Сколько рейсов делала городская электричка, запущенная в 2009 году?',
        'Ну во-первых электричка ходила редко, всего пять раз в день.'
    ),
    TestCaseFindSent(
        'Когда закрылась версия третьего автобуса, когда он ходил в Канищево?',
        'И с 17 октября 1987 года третий автобус закрывается, причем делается это на редкость официально, со всеми объявлениями.'
    ),
    TestCaseFindSent(
        'Когда в Рязань поступили автобусы ЗИЛ-158?',
        'Впервые ЗИЛ-158 появились в автоколонне 73 в декабре 1957 года'
    ),
    TestCaseFindSent(
        'Когда пропали двухвагонные трамвайные поезда?',
        'А ещё в 2004 году на таком маршруте ездили двухвагонные поезда из трамваев.',
    ),
    TestCaseFindSent(
        'Когда общественный транспорт впервые пришёл в Ворошиловку?',
        'Поэтому в 1980 году были открыты маршруты № 22 в Карцево и № 23 до улицы Берёзовой, а в 1981-м — 24-й автобус на Ворошиловку.'
    ),
    TestCaseFindSent(
        'Когда общественный транспорт впервые пришёл в Карцево?',
        'Поэтому в 1980 году были открыты маршруты № 22 в Карцево и № 23 до улицы Берёзовой, а в 1981-м — 24-й автобус на Ворошиловку.'
    ),
    TestCaseFindSent(
        'Когда закрыли депо 1?',
        'Год спустя, в 2016 году, закрыли первое троллейбусное депо.'),
    TestCaseFindSent(
        'Когда закрылась 91-я маршрутка?',
        'В ноябре 2021 года 91 маршрутка прекратила свою работу.'
    ),
    TestCaseFindSent(
        'Когда на остановочных табличках Рязани появились конечные остановки маршрутов?',
        'Наконец, в 2015 году в истории остановочных табличек Рязани происходит настоящая революция. Впервые за 60 с лишним лет их истории на них наконец-то появились конечные остановки маршрутов. '
    )
)


def test_retrieval():
    """Тестирует кросс-энкодер vs BM25 на всех документах."""   
    # Создаем объект Retrieval (загружает корпус автоматически)
    retrieval = Retrieval()
    
    good_paragraphs = []
    bm25_scores = []
    semantic_scores = []
    for test_case in test_cases:
        good_paragraphs_for_case = []
        for i, text in enumerate(retrieval.paragraphs_df[test_case.compare_with]):
            if test_case.good_answer in text:
                good_paragraphs_for_case.append(i)
        assert good_paragraphs_for_case, breakpoint()
        good_paragraphs.append(good_paragraphs_for_case)
        # masks.append(np.isin(test_case.query, retrieval.paragraphs_df['text']))
        bm25_scores.append(retrieval.bm25_search(query=test_case.query))
        semantic_scores.append(retrieval.semantic_search(query=test_case.query))

    bm25_scores = np.array(bm25_scores)
    semantic_scores = np.array(semantic_scores)

    def func(params):
        scores = bm25_scores * params[0] + semantic_scores
        ranks = np.argsort(scores, axis=1) # важно: сортировка по возрастанию от нерелевантных к релевантным

        # Получаем ранги релевантных параграфов для каждого testcase
        index_maps = [{v: i for i, v in enumerate(ranks_case)} for ranks_case in ranks]
        relevant_indices = [
            [index_map[x] for x in good_paragraphs_for_case] for index_map, good_paragraphs_for_case in zip(index_maps, good_paragraphs, strict=True)]
        
        # Метрика "среднее+минимум" внутри каждого testcase, а затем "среднее+минимум" по всем testcases
        metrics = np.array([np.mean(relevant_ids_case) + np.min(relevant_ids_case) for relevant_ids_case in relevant_indices])
        return -(np.mean(metrics) + np.min(metrics)) # чем больше метрика, тем скор (и ранг) выше, поэтому берём минус для оптимизации на минимум

    from scipy.optimize import differential_evolution, minimize
    params = differential_evolution(func, bounds=[(0.2, 5)])
    print(params)
    print(f"Оптимальный вес для BM25: {params.x[0]:.4f}")


if __name__ == "__main__":
    test_retrieval()