# QueryPlanner.py
# BEAVER query planning module.
# Scores pages with semantic + lexical signals and selects pages via Anchor / Flow / Flash.

from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from Segmenter import HSPPlannerConfig, SegmentPageLayout, QuerySplitResult


class QueryPlanner(nn.Module):
    """Plan kept pages based on semantic and lexical relevance."""

    def __init__(self, cfg: HSPPlannerConfig, query_dim: int):
        super().__init__()
        self.cfg = cfg

        # Mixing weights for the semantic and lexical page scores.
        self.lambda_sem = float(getattr(cfg, "lambda_semantic", 0.7))
        self.lambda_lex = float(getattr(cfg, "lambda_lexical", 0.3))
        # Bounds on how many query tokens participate in multi-vector scoring.
        self.min_q_multi = int(getattr(cfg, "min_query_tokens_for_multi", 4))
        self.max_q_multi = int(getattr(cfg, "max_query_tokens_for_multi", 32))

    def forward(
        self,
        block_repr: torch.Tensor,       # [B, N, D]
        layout: SegmentPageLayout,
        query_hidden: torch.Tensor,     # [B, D]
        query_pos: torch.Tensor,        # [B]
        input_ids: Optional[torch.Tensor] = None,                    # [B, L]
        token_level_weights: Optional[torch.Tensor] = None,          # [B, L]
        split_results: Optional[Tuple[QuerySplitResult, ...]] = None,
        query_token_hidden_list: Optional[List[Optional[torch.Tensor]]] = None,  # len=B, [K_b, D]
        query_token_weight_list: Optional[List[Optional[torch.Tensor]]] = None,  # len=B, [K_b]
    ) -> torch.Tensor:
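        """Return a [B, N] boolean mask of pages to keep for the query.

        Kept pages are the union of three sets:
          * Anchor: the first ``cfg.anchor_pages`` pages of segment 0.
          * Flow:   a window of ``cfg.flow_window`` pages ending at the
                    query's own page (all prior pages if the window is < 0).
          * Flash:  the ``cfg.flash_top_k`` best-scoring remaining pages,
                    ranked by a mix of semantic and lexical relevance.
        """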
        B, N, _ = block_repr.shape
        device = block_repr.device

        # Align query dtypes with the page representations (e.g. under AMP).
        common_dtype = block_repr.dtype
        if query_hidden.dtype != common_dtype:
            query_hidden = query_hidden.to(common_dtype)

        if query_token_hidden_list is not None:
            query_token_hidden_list = [
                None if qt is None else qt.to(common_dtype)
                for qt in query_token_hidden_list
            ]

        # Unpack the page layout.
        segment_ids = layout.segment_ids                 # [B, N]
        page_valid_any = layout.page_valid.any(dim=-1)   # [B, N] page has a valid token
        token2page = layout.token2page                   # [B, L]
        token_valid = layout.token_valid                 # [B, L]

        # Page holding the query token; the clamp guards invalid (-1) mappings.
        query_page = torch.gather(token2page, dim=1, index=query_pos.view(B, 1)).squeeze(1)
        query_page = query_page.clamp(min=0)

        # Segment containing the query page (currently unused downstream).
        query_seg = torch.gather(segment_ids, dim=1, index=query_page.view(B, 1)).squeeze(1)

        # Semantic scores: cosine similarity between the query and the
        # normalized page representations.
        scores_sem = block_repr.new_full((B, N), -1e4)
        k_vec = F.normalize(block_repr, dim=-1)

        use_multi = (query_token_hidden_list is not None and self.min_q_multi > 0)

        for b in range(B):
            q_tok = query_token_hidden_list[b] if use_multi else None
            if q_tok is not None and q_tok.size(0) >= self.min_q_multi:
                # Multi-vector scoring: match each query token against every
                # page, then aggregate; long queries are capped at
                # max_q_multi tokens to bound the cost.
                q_vec = F.normalize(q_tok[: self.max_q_multi], dim=-1)  # [K, D]
                scores_bn = torch.matmul(q_vec, k_vec[b].t())           # [K, N]

                w_q = None
                if query_token_weight_list is not None:
                    w_q = query_token_weight_list[b]
                if w_q is not None and w_q.numel() >= scores_bn.size(0):
                    # Weighted mean over query tokens.
                    w_q = w_q[: scores_bn.size(0)]
                    w_q = w_q / (w_q.sum() + 1e-6)
                    scores_sem[b] = (scores_bn * w_q.unsqueeze(-1)).sum(dim=0)
                else:
                    scores_sem[b] = scores_bn.mean(dim=0)
            else:
                # Single-vector fallback (also taken when per-token hidden
                # states are missing or the query is shorter than min_q_multi).
                q_vec = F.normalize(query_hidden[b:b + 1], dim=-1)      # [1, D]
                scores_sem[b] = torch.matmul(q_vec, k_vec[b].t())[0]

        # Lexical scores: weighted token overlap between each page and the
        # query span. Pages that share more heavily weighted token ids with
        # the query score higher.
        scores_lex = block_repr.new_zeros((B, N))
        use_lex = (
            self.lambda_lex > 0
            and input_ids is not None
            and token_level_weights is not None
            and split_results is not None
        )

        if use_lex:
            for b in range(B):
                sr = split_results[b]
                qs = int(sr.query_start)
                qe = int(sr.query_end)
                ids_b = input_ids[b]
                valid_b = token_valid[b]
                if qe < qs:
                    continue
                # Clamp the query span to the sequence bounds.
                seq_len = ids_b.size(0)
                qs = max(0, min(qs, seq_len - 1))
                qe = max(0, min(qe, seq_len - 1))

                span_mask = torch.zeros_like(valid_b, dtype=torch.bool)
                span_mask[qs:qe + 1] = True
                span_mask &= valid_b
                if not span_mask.any():
                    continue

                q_ids = ids_b[span_mask]      # token ids inside the query span
                w_b = token_level_weights[b]

                for n in range(N):
                    idx_n = layout.page_indices[b, n]
                    valid_n = layout.page_valid[b, n]
                    if not valid_n.any():
                        continue
                    # Sum the weights of page tokens that also occur in the query.
                    pos_n = idx_n[valid_n]
                    tok_n = ids_b[pos_n]
                    mask_in_q = torch.isin(tok_n, q_ids)
                    if mask_in_q.any():
                        scores_lex[b, n] = w_b[pos_n][mask_in_q].sum()

        # A page is a candidate only if it is non-empty, belongs to a real
        # segment, and does not lie after the query page (causality).
        page_idx = torch.arange(N, device=device).view(1, N).expand(B, N)
        causal_mask = (page_idx <= query_page.view(B, 1))
        valid_mask = page_valid_any & (segment_ids >= 0) & causal_mask

        def _norm_scores(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
            """Min-max normalize each sample's scores to [0, 1] over the
            masked-in pages; masked-out pages are zeroed."""
            out = scores.clone()
            for bb in range(scores.size(0)):
                m = mask[bb]
                if not m.any():
                    continue
                v = out[bb, m]
                v_min, v_max = v.min(), v.max()
                if float(v_max - v_min) < 1e-6:
                    # Degenerate case: all valid scores (nearly) identical.
                    out[bb, m] = 0.0
                else:
                    out[bb, m] = (v - v_min) / (v_max - v_min)
            return out.masked_fill(~mask, 0.0)

        scores_sem_norm = _norm_scores(scores_sem, valid_mask)
        scores_mix = scores_sem_norm

        if use_lex:
            # Convex mix of the two normalized signals:
            #   mix = (lambda_sem * sem + lambda_lex * lex) / (lambda_sem + lambda_lex)
            scores_lex_norm = _norm_scores(scores_lex, valid_mask)
            if self.lambda_sem + self.lambda_lex > 0:
                lam_s, lam_l = self.lambda_sem, self.lambda_lex
                scores_mix = (lam_s * scores_sem_norm + lam_l * scores_lex_norm) / (lam_s + lam_l)

        scores_final = scores_mix.masked_fill(~valid_mask, -1e4)

        # Anchor: always keep the first `anchor_pages` pages of segment 0
        # (e.g. a shared instruction prefix).
        anchor = torch.zeros_like(valid_mask)
        is_seg0 = (segment_ids == 0) & valid_mask
        for b in range(B):
            idx_seg0 = torch.nonzero(is_seg0[b], as_tuple=False).flatten()
            if idx_seg0.numel() > 0:
                k = min(self.cfg.anchor_pages, idx_seg0.numel())
                anchor[b, idx_seg0[:k]] = True

        # Flow: keep a recency window of `flow_window` pages ending at the
        # query page; a negative window keeps every valid page.
        if self.cfg.flow_window >= 0:
            lower = (query_page.view(B, 1) - self.cfg.flow_window).clamp(min=0)
            upper = query_page.view(B, 1)
            flow = (page_idx >= lower) & (page_idx <= upper) & valid_mask
        else:
            flow = valid_mask.clone()

        # The query's own page is always kept.
        flow.scatter_(1, query_page.view(B, 1), True)

        base_keep = anchor | flow
        candidate = valid_mask & (~base_keep)

        # Flash: among the remaining candidates, keep the top-k pages by
        # mixed relevance score.
        scores_candidate = scores_final.masked_fill(~candidate, -1e4)
        effective_k = min(self.cfg.flash_top_k, N)
        if effective_k > 0:
            _, topk_idx = torch.topk(scores_candidate, k=effective_k, dim=1)
            flash = torch.zeros_like(candidate)
            flash.scatter_(1, topk_idx, True)
            flash = flash & candidate  # discard filler picks from masked slots
        else:
            flash = torch.zeros_like(candidate)

        keep_pages = base_keep | flash
        return keep_pages
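

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of BEAVER): the SimpleNamespace
# stubs below are hypothetical stand-ins for HSPPlannerConfig and
# SegmentPageLayout from Segmenter, assumed to expose exactly the attributes
# this module reads.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    B, N, P, D = 1, 4, 8, 16   # batch, pages, tokens per page, hidden size
    L = N * P                  # total sequence length

    cfg = SimpleNamespace(
        lambda_semantic=0.7, lambda_lexical=0.3,
        min_query_tokens_for_multi=4, max_query_tokens_for_multi=32,
        anchor_pages=1, flow_window=1, flash_top_k=1,
    )
    layout = SimpleNamespace(
        segment_ids=torch.zeros(B, N, dtype=torch.long),
        page_valid=torch.ones(B, N, P, dtype=torch.bool),
        token2page=torch.arange(L).view(B, L) // P,   # contiguous pages of P tokens
        token_valid=torch.ones(B, L, dtype=torch.bool),
        page_indices=torch.arange(L).view(B, N, P),
    )

    planner = QueryPlanner(cfg, query_dim=D)
    keep = planner(
        block_repr=torch.randn(B, N, D),
        layout=layout,
        query_hidden=torch.randn(B, D),
        query_pos=torch.tensor([L - 1]),   # query sits on the last page
    )
    print(keep)  # [B, N] bool mask: union of anchor, flow, and flash pages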