# QueryPlanner.py
# BEAVER query planning module.
# Scores pages with semantic + lexical signals and selects pages via Anchor / Flow / Flash.
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from Segmenter import HSPPlannerConfig, SegmentPageLayout, QuerySplitResult
class QueryPlanner(nn.Module):
    """Plan kept pages based on semantic and lexical relevance.

    Every page receives a semantic score (cosine similarity between the page
    representation and the query) and, when the lexical inputs are supplied,
    a lexical score (summed token weights of page tokens that also occur in
    the query span).  Both are min-max normalised per sample and blended with
    ``lambda_semantic`` / ``lambda_lexical`` read from ``cfg`` (defaults 0.7
    and 0.3).

    The returned keep-mask is the union of three selections:

    * Anchor: the first ``cfg.anchor_pages`` valid pages whose segment id is 0.
    * Flow: the ``cfg.flow_window`` pages preceding the query page plus the
      query page itself (a negative window keeps every valid page).
    * Flash: the ``cfg.flash_top_k`` best-scoring remaining valid pages.
    """

    # Sentinel score for pages that must never win a ranking.
    _MASKED_SCORE = -1e4

    def __init__(self, cfg: HSPPlannerConfig, query_dim: int):
        """Store the config and read the score-mixing hyper-parameters.

        Args:
            cfg: Planner configuration.  ``anchor_pages``, ``flow_window`` and
                ``flash_top_k`` are read from it in ``forward``; the optional
                attributes below fall back to defaults when absent.
            query_dim: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.cfg = cfg
        self.lambda_sem = float(getattr(cfg, "lambda_semantic", 0.7))
        self.lambda_lex = float(getattr(cfg, "lambda_lexical", 0.3))
        self.min_q_multi = int(getattr(cfg, "min_query_tokens_for_multi", 4))
        self.max_q_multi = int(getattr(cfg, "max_query_tokens_for_multi", 32))

    def forward(
        self,
        block_repr: torch.Tensor,  # [B, N, D]
        layout: SegmentPageLayout,
        query_hidden: torch.Tensor,  # [B, D]
        query_pos: torch.Tensor,  # [B]
        input_ids: Optional[torch.Tensor] = None,  # [B, L]
        token_level_weights: Optional[torch.Tensor] = None,  # [B, L]
        split_results: Optional[Tuple[QuerySplitResult, ...]] = None,
        query_token_hidden_list: Optional[List[Optional[torch.Tensor]]] = None,  # len=B, [K_b, D]
        query_token_weight_list: Optional[List[Optional[torch.Tensor]]] = None,  # len=B, [K_b]
    ) -> torch.Tensor:
        """Return a boolean ``[B, N]`` mask of the pages to keep.

        Args:
            block_repr: Page representations.
            layout: Provides ``segment_ids``, ``page_valid``, ``page_indices``,
                ``token2page`` and ``token_valid``.
            query_hidden: Pooled query vector per sample.
            query_pos: Token position of the query per sample; mapped to a
                page through ``layout.token2page``.
            input_ids: Token ids, needed only for lexical scoring.
            token_level_weights: Per-token weights, needed only for lexical
                scoring.
            split_results: Per-sample objects exposing ``query_start`` and
                ``query_end``, needed only for lexical scoring.
            query_token_hidden_list: Optional per-sample query token vectors
                used for multi-vector semantic scoring.
            query_token_weight_list: Optional per-sample weights for those
                query token vectors.

        Returns:
            Boolean tensor of shape ``[B, N]``; ``True`` marks a kept page.
        """
        B, N, _ = block_repr.shape
        device = block_repr.device
        common_dtype = block_repr.dtype

        # Bring all query-side tensors to the dtype of the page representations.
        if query_hidden.dtype != common_dtype:
            query_hidden = query_hidden.to(common_dtype)
        if query_token_hidden_list is not None:
            query_token_hidden_list = [
                None if qt is None else qt.to(common_dtype)
                for qt in query_token_hidden_list
            ]

        segment_ids = layout.segment_ids
        page_valid_any = layout.page_valid.any(dim=-1)

        # Page holding the query token.  gather/scatter_ require int64 indices,
        # so cast explicitly in case the layout stores narrower integers.
        query_page = torch.gather(
            layout.token2page, dim=1, index=query_pos.view(B, 1).long()
        ).squeeze(1)
        query_page = query_page.clamp(min=0).long()
        query_col = query_page.view(B, 1)

        # A page is eligible when it holds at least one valid token, belongs
        # to a segment (id >= 0) and does not come after the query page.
        page_idx = torch.arange(N, device=device).view(1, N).expand(B, N)
        causal_mask = page_idx <= query_col
        valid_mask = page_valid_any & (segment_ids >= 0) & causal_mask

        scores_sem = self._semantic_scores(
            block_repr, query_hidden, query_token_hidden_list, query_token_weight_list
        )
        scores_mix = self._minmax_normalize(scores_sem, valid_mask)

        use_lex = (
            self.lambda_lex > 0
            and input_ids is not None
            and token_level_weights is not None
            and split_results is not None
        )
        if use_lex:
            scores_lex = self._lexical_scores(
                block_repr, layout, input_ids, token_level_weights, split_results
            )
            scores_lex_norm = self._minmax_normalize(scores_lex, valid_mask)
            lam_s = self.lambda_sem
            lam_l = self.lambda_lex
            if lam_s + lam_l > 0:
                scores_mix = (lam_s * scores_mix + lam_l * scores_lex_norm) / (lam_s + lam_l)

        # Anchor: leading valid pages of segment 0.  The budget is clamped so
        # a negative setting selects nothing instead of slicing from the end.
        anchor = torch.zeros_like(valid_mask)
        is_seg0 = (segment_ids == 0) & valid_mask
        anchor_budget = max(0, int(self.cfg.anchor_pages))
        for b in range(B):
            idx_seg0 = torch.nonzero(is_seg0[b], as_tuple=False).flatten()
            anchor[b, idx_seg0[:anchor_budget]] = True

        # Flow: window of pages ending at the query page.
        flow_window = self.cfg.flow_window
        if flow_window >= 0:
            lower = (query_col - flow_window).clamp(min=0)
            flow = (page_idx >= lower) & (page_idx <= query_col) & valid_mask
        else:
            flow = valid_mask.clone()
        # The query page is always kept, even when valid_mask excludes it.
        flow.scatter_(1, query_col, True)

        # Flash: best-scoring pages among those not already kept.
        base_keep = anchor | flow
        candidate = valid_mask & ~base_keep
        flash = torch.zeros_like(candidate)
        effective_k = min(self.cfg.flash_top_k, N)
        if effective_k > 0:
            scores_candidate = scores_mix.masked_fill(~candidate, self._MASKED_SCORE)
            _, topk_idx = torch.topk(scores_candidate, k=effective_k, dim=1)
            flash.scatter_(1, topk_idx, True)
            # topk may return masked slots when there are fewer than k
            # candidates; drop them again.
            flash &= candidate

        return base_keep | flash

    def _semantic_scores(
        self,
        block_repr: torch.Tensor,
        query_hidden: torch.Tensor,
        query_token_hidden_list: Optional[List[Optional[torch.Tensor]]],
        query_token_weight_list: Optional[List[Optional[torch.Tensor]]],
    ) -> torch.Tensor:
        """Return ``[B, N]`` cosine similarities between pages and the query."""
        B, N, _ = block_repr.shape
        scores = block_repr.new_full((B, N), self._MASKED_SCORE)
        pages = F.normalize(block_repr, dim=-1)
        use_multi = query_token_hidden_list is not None and self.min_q_multi > 0

        for b in range(B):
            q_tok = query_token_hidden_list[b] if use_multi else None
            if q_tok is None or q_tok.numel() == 0:
                # No usable query tokens (missing or empty): fall back to the
                # pooled query vector rather than leaving the sentinel score.
                q_vec = F.normalize(query_hidden[b], dim=-1)
                scores[b] = pages[b] @ q_vec
                continue

            q_vec = F.normalize(q_tok, dim=-1)
            sim = torch.matmul(q_vec, pages[b].t())  # one row per query token

            w_q = None
            if query_token_weight_list is not None:
                w_q = query_token_weight_list[b]
            if w_q is not None and w_q.numel() == sim.size(0):
                # Weighted average over query tokens.
                w_q = w_q / (w_q.sum() + 1e-6)
                scores[b] = (sim * w_q.unsqueeze(-1)).sum(dim=0)
            else:
                # Missing or mismatched weights: plain average.
                scores[b] = sim.mean(dim=0)
        return scores

    def _lexical_scores(
        self,
        block_repr: torch.Tensor,
        layout: SegmentPageLayout,
        input_ids: torch.Tensor,
        token_level_weights: torch.Tensor,
        split_results: Tuple[QuerySplitResult, ...],
    ) -> torch.Tensor:
        """Return ``[B, N]`` summed weights of page tokens found in the query span."""
        B, N, _ = block_repr.shape
        scores = block_repr.new_zeros((B, N))

        for b in range(B):
            sr = split_results[b]
            qs = int(sr.query_start)
            qe = int(sr.query_end)
            if qe < qs:
                continue

            ids_b = input_ids[b]
            valid_b = layout.token_valid[b]
            last = ids_b.size(0) - 1
            qs = max(0, min(qs, last))
            qe = max(0, min(qe, last))

            # The slice runs to qe + 1, i.e. query_end is treated as inclusive.
            span_mask = torch.zeros_like(valid_b, dtype=torch.bool)
            span_mask[qs:qe + 1] = True
            span_mask &= valid_b
            if not span_mask.any():
                continue

            q_ids = ids_b[span_mask]
            w_b = token_level_weights[b]
            page_indices_b = layout.page_indices[b]
            page_valid_b = layout.page_valid[b]
            for n in range(N):
                valid_n = page_valid_b[n]
                if not valid_n.any():
                    continue
                pos_n = page_indices_b[n][valid_n]
                hits = torch.isin(ids_b[pos_n], q_ids)
                if hits.any():
                    scores[b, n] = w_b[pos_n][hits].sum()
        return scores

    @staticmethod
    def _minmax_normalize(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Min-max scale each row over its masked entries; unmasked entries become 0."""
        out = scores.clone()
        for row in range(scores.size(0)):
            m = mask[row]
            if not m.any():
                continue
            v = out[row, m]
            v_min = v.min()
            span = v.max() - v_min
            if float(span) < 1e-6:
                # Effectively constant row: carries no ranking information.
                out[row, m] = 0.0
            else:
                out[row, m] = (v - v_min) / span
        return out.masked_fill(~mask, 0.0)