# QueryPlanner.py
# BEAVER query planning module.
# Scores pages with semantic + lexical signals and selects pages via Anchor / Flow / Flash.
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from Segmenter import HSPPlannerConfig, SegmentPageLayout, QuerySplitResult


class QueryPlanner(nn.Module):
"""Plan kept pages based on semantic and lexical relevance."""

    def __init__(self, cfg: HSPPlannerConfig, query_dim: int):
super().__init__()
self.cfg = cfg
self.lambda_sem = float(getattr(cfg, "lambda_semantic", 0.7))
self.lambda_lex = float(getattr(cfg, "lambda_lexical", 0.3))
self.min_q_multi = int(getattr(cfg, "min_query_tokens_for_multi", 4))
self.max_q_multi = int(getattr(cfg, "max_query_tokens_for_multi", 32))

    def forward(
self,
block_repr: torch.Tensor, # [B, N, D]
layout: SegmentPageLayout,
query_hidden: torch.Tensor, # [B, D]
query_pos: torch.Tensor, # [B]
input_ids: Optional[torch.Tensor] = None, # [B, L]
token_level_weights: Optional[torch.Tensor] = None, # [B, L]
split_results: Optional[Tuple[QuerySplitResult, ...]] = None,
query_token_hidden_list: Optional[List[Optional[torch.Tensor]]] = None, # len=B, [K_b, D]
query_token_weight_list: Optional[List[Optional[torch.Tensor]]] = None, # len=B, [K_b]
) -> torch.Tensor:
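        """Return a boolean [B, N] mask of pages to keep.

        Pages survive through three channels: Anchor (leading pages of
        segment 0), Flow (a recency window ending at the query page), and
        Flash (top-k remaining pages by mixed semantic/lexical score).
        """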
B, N, Dp = block_repr.shape
device = block_repr.device
common_dtype = block_repr.dtype
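        # Cast query-side tensors to the page-representation dtype so all
        # similarity math runs in one dtype.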
if query_hidden.dtype != common_dtype:
query_hidden = query_hidden.to(common_dtype)
if query_token_hidden_list is not None:
new_list: List[Optional[torch.Tensor]] = []
for qt in query_token_hidden_list:
new_list.append(None if qt is None else qt.to(common_dtype))
query_token_hidden_list = new_list
segment_ids = layout.segment_ids
page_valid_any = layout.page_valid.any(dim=-1)
token2page = layout.token2page
token_valid = layout.token_valid
        # Locate the page containing the query token; clamp(min=0) guards
        # tokens that are not mapped to any page.
        query_page = torch.gather(token2page, dim=1, index=query_pos.view(B, 1)).squeeze(1)
        query_page = query_page.clamp(min=0)
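        # Semantic channel: cosine similarity between query states and
        # L2-normalized page representations (token-level when available,
        # pooled otherwise).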
scores_sem = block_repr.new_full((B, N), -1e4)
k_vec = F.normalize(block_repr, dim=-1)
use_multi = (query_token_hidden_list is not None and self.min_q_multi > 0)
        for b in range(B):
            q_tok = query_token_hidden_list[b] if use_multi else None
            if q_tok is not None and q_tok.numel() > 0:
                # Token-level scoring: each query token is compared against
                # every page, then averaged (weighted when weights are given).
                q_vec = F.normalize(q_tok, dim=-1)  # [K_b, D]
                scores_bn = torch.matmul(q_vec, k_vec[b].t())  # [K_b, N]
                w_q = None
                if query_token_weight_list is not None:
                    w_q = query_token_weight_list[b]
                if w_q is not None and w_q.numel() == scores_bn.size(0):
                    w_q = w_q / (w_q.sum() + 1e-6)
                    scores_sem[b] = (scores_bn * w_q.unsqueeze(-1)).sum(dim=0)
                else:
                    scores_sem[b] = scores_bn.mean(dim=0)
            else:
                # Fall back to the pooled query vector when token-level
                # states are missing or empty.
                q_vec = F.normalize(query_hidden[b:b + 1], dim=-1)  # [1, D]
                scores_sem[b] = torch.matmul(q_vec, k_vec[b].t())[0]
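        # Lexical channel: for each page, sum the token-level weights of page
        # tokens whose ids also appear in the query span (exact-match overlap).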
scores_lex = block_repr.new_zeros((B, N))
use_lex = (
self.lambda_lex > 0
and input_ids is not None
and token_level_weights is not None
and split_results is not None
)
if use_lex:
for b in range(B):
sr = split_results[b]
qs = int(sr.query_start)
qe = int(sr.query_end)
ids_b = input_ids[b]
valid_b = token_valid[b]
if qe < qs:
continue
L = ids_b.size(0)
qs = max(0, min(qs, L - 1))
qe = max(0, min(qe, L - 1))
span_mask = torch.zeros_like(valid_b, dtype=torch.bool)
span_mask[qs:qe+1] = True
span_mask &= valid_b
if not span_mask.any():
continue
q_ids = ids_b[span_mask]
w_b = token_level_weights[b]
for n in range(N):
idx_n = layout.page_indices[b, n]
valid_n = layout.page_valid[b, n]
if not valid_n.any():
continue
pos_n = idx_n[valid_n]
tok_n = ids_b[pos_n]
mask_in_q = torch.isin(tok_n, q_ids)
if mask_in_q.any():
w_page = w_b[pos_n]
scores_lex[b, n] = (w_page[mask_in_q]).sum()
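        # Eligibility: a page must be non-empty, belong to a real segment,
        # and not lie after the query page (causal constraint).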
page_idx = torch.arange(N, device=device).view(1, N).expand(B, N)
causal_mask = (page_idx <= query_page.view(B, 1))
valid_mask = page_valid_any & (segment_ids >= 0) & causal_mask
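        # Min-max normalize each channel per sample over eligible pages so
        # semantic and lexical scores share a [0, 1] range before mixing.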
def _norm_scores(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
out = scores.clone()
for bb in range(scores.size(0)):
m = mask[bb]
if not m.any():
continue
v = out[bb, m]
v_min = v.min()
v_max = v.max()
if float(v_max - v_min) < 1e-6:
out[bb, m] = 0.0
else:
out[bb, m] = (v - v_min) / (v_max - v_min)
out = out.masked_fill(~mask, 0.0)
return out
scores_sem_norm = _norm_scores(scores_sem, valid_mask)
scores_mix = scores_sem_norm
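        # Convex mix: (lam_s * sem + lam_l * lex) / (lam_s + lam_l);
        # semantic-only when lexical signals are unavailable.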
if use_lex:
scores_lex_norm = _norm_scores(scores_lex, valid_mask)
if self.lambda_sem + self.lambda_lex > 0:
lam_s = self.lambda_sem
lam_l = self.lambda_lex
scores_mix = (lam_s * scores_sem_norm + lam_l * scores_lex_norm) / (lam_s + lam_l)
scores_final = scores_mix.masked_fill(~valid_mask, -1e4)
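        # Anchor channel: always keep the first `anchor_pages` eligible
        # pages of segment 0.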
anchor = torch.zeros_like(valid_mask)
is_seg0 = (segment_ids == 0) & valid_mask
for b in range(B):
idx_seg0 = torch.nonzero(is_seg0[b], as_tuple=False).flatten()
if idx_seg0.numel() > 0:
k = min(self.cfg.anchor_pages, idx_seg0.numel())
anchor[b, idx_seg0[:k]] = True
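        # Flow channel: keep a window of `flow_window` pages ending at the
        # query page (all eligible pages when flow_window < 0); the scatter
        # below guarantees the query page itself is kept.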
if self.cfg.flow_window >= 0:
lower = (query_page.view(B, 1) - self.cfg.flow_window).clamp(min=0)
upper = query_page.view(B, 1)
flow = (page_idx >= lower) & (page_idx <= upper) & valid_mask
else:
flow = valid_mask.clone()
flow.scatter_(1, query_page.view(B, 1), True)
base_keep = anchor | flow
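        # Flash channel: among pages not already kept, retain the top
        # `flash_top_k` by mixed score.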
candidate = valid_mask & (~base_keep)
scores_candidate = scores_final.masked_fill(~candidate, -1e4)
effective_k = min(self.cfg.flash_top_k, N)
if effective_k > 0:
_, topk_idx = torch.topk(scores_candidate, k=effective_k, dim=1)
flash = torch.zeros_like(candidate)
flash.scatter_(1, topk_idx, True)
flash = flash & candidate
else:
flash = torch.zeros_like(candidate)
keep_pages = base_keep | flash
return keep_pages
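

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the BEAVER pipeline. The
    # real HSPPlannerConfig and SegmentPageLayout come from Segmenter; the
    # SimpleNamespace stand-ins below only provide the fields this module
    # actually reads, with hypothetical values.
    from types import SimpleNamespace

    B, N, D, L, page_len = 1, 4, 8, 16, 4
    cfg = SimpleNamespace(
        lambda_semantic=0.7, lambda_lexical=0.3,
        min_query_tokens_for_multi=4, max_query_tokens_for_multi=32,
        anchor_pages=1, flow_window=1, flash_top_k=1,
    )
    layout = SimpleNamespace(
        segment_ids=torch.zeros(B, N, dtype=torch.long),          # one segment
        page_valid=torch.ones(B, N, page_len, dtype=torch.bool),  # all slots valid
        token2page=torch.arange(L).view(B, L) // page_len,        # 4 tokens per page
        token_valid=torch.ones(B, L, dtype=torch.bool),
        page_indices=torch.arange(L).view(B, N, page_len),
    )
    planner = QueryPlanner(cfg, query_dim=D)
    keep = planner(
        block_repr=torch.randn(B, N, D),
        layout=layout,
        query_hidden=torch.randn(B, D),
        query_pos=torch.full((B,), L - 1, dtype=torch.long),  # query at last token
    )
    print(keep)  # boolean [B, N] mask of kept pages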