# BEAVER / Wrapper.py
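"""Black-box HSP wrapper for BEAVER.

Segments a long context into pages, scores the pages against the query,
drops low-scoring pages, and feeds the compressed prompt to an unmodified
Hugging Face causal LM for generation."""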
from dataclasses import dataclass
from typing import List, Optional, Dict, Any, Tuple
import time
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from Segmenter import HSPPlannerConfig, Segmenter
from PageEncoder import PageEncoder
from QueryPlanner import QueryPlanner
@dataclass
class HSPWrapperConfig:
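    """Configuration for HSPBlackBoxWrapper.

    Paging/selection fields (page_size, flow_window, flash_top_k,
    anchor_pages) are forwarded to HSPPlannerConfig. Token-id fields left
    as None are auto-detected from the tokenizer in __init__; the
    newline_token_id default of 198 is the newline token in GPT-2-style
    BPE vocabularies."""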
page_size: int = 64
flow_window: int = 4
flash_top_k: int = 22
anchor_pages: int = 4
pad_token_id: Optional[int] = None
newline_token_ids: Optional[Tuple[int, ...]] = None
newline_token_id: int = 198
sentence_boundary_ids: Optional[Tuple[int, ...]] = None
allow_implicit_query: bool = False
query_block_size: int = 64
use_dynamic_token_weights: bool = True
min_length_for_dynamic_weights: int = 256
use_query_multitoken_semantic: bool = True
min_query_tokens_for_multi: int = 4
max_query_tokens_for_multi: int = 32
identity_mean_weight: float = 0.7
identity_max_weight: float = 0.3
use_lexical_overlap_score: bool = True
semantic_score_weight: float = 0.7
lexical_score_weight: float = 0.3
class HSPBlackBoxWrapper:
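    """Wraps a causal LM and its tokenizer with page-level prompt compression.

    The wrapped model is treated as a black box: compression operates only on
    token ids and input embeddings, via Segmenter, PageEncoder, and
    QueryPlanner, so no model weights are modified."""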
def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
cfg: HSPWrapperConfig,
idf_weights: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
):
self.model = model
self.tokenizer = tokenizer
self.wrap_cfg = cfg
if device is None:
device = getattr(model, "device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
self.device = device
self.model.to(self.device)
if self.wrap_cfg.pad_token_id is None:
if self.tokenizer.pad_token_id is not None:
self.wrap_cfg.pad_token_id = int(self.tokenizer.pad_token_id)
else:
self.wrap_cfg.pad_token_id = 0
else:
if (
self.tokenizer.pad_token_id is not None
and int(self.wrap_cfg.pad_token_id) != int(self.tokenizer.pad_token_id)
):
raise ValueError(
f"pad_token_id mismatch: wrapper cfg={self.wrap_cfg.pad_token_id}, tokenizer={self.tokenizer.pad_token_id}."
)
if self.wrap_cfg.newline_token_ids is None:
auto_newline_ids: List[int] = []
for s in ["\n", "\r\n", "\n\n"]:
ids = self.tokenizer.encode(s, add_special_tokens=False)
if len(ids) == 1:
tid = int(ids[0])
if tid not in auto_newline_ids:
auto_newline_ids.append(tid)
if auto_newline_ids:
self.wrap_cfg.newline_token_ids = tuple(auto_newline_ids)
self.wrap_cfg.newline_token_id = auto_newline_ids[0]
if self.wrap_cfg.sentence_boundary_ids is None:
boundary_ids: List[int] = []
if self.wrap_cfg.newline_token_ids is not None:
boundary_ids.extend(int(x) for x in self.wrap_cfg.newline_token_ids)
punct_candidates = ["。", "!", "?", ".", "!", "?"]
for s in punct_candidates:
ids = self.tokenizer.encode(s, add_special_tokens=False)
if len(ids) == 1:
tid = int(ids[0])
if tid not in boundary_ids:
boundary_ids.append(tid)
if boundary_ids:
self.wrap_cfg.sentence_boundary_ids = tuple(boundary_ids)
hsp_cfg = HSPPlannerConfig(
page_size=cfg.page_size,
flow_window=cfg.flow_window,
flash_top_k=cfg.flash_top_k,
anchor_pages=cfg.anchor_pages,
pad_token_id=cfg.pad_token_id,
newline_token_id=cfg.newline_token_id,
newline_token_ids=cfg.newline_token_ids,
sentence_boundary_ids=cfg.sentence_boundary_ids,
identity_mean_weight=cfg.identity_mean_weight,
identity_max_weight=cfg.identity_max_weight,
lambda_semantic=cfg.semantic_score_weight,
lambda_lexical=cfg.lexical_score_weight,
min_query_tokens_for_multi=cfg.min_query_tokens_for_multi,
max_query_tokens_for_multi=cfg.max_query_tokens_for_multi,
)
self.hsp_cfg = hsp_cfg
        # Resolve boundary ids once and cache them: explicit sentence boundaries,
        # else the detected newline ids, else the single newline id.
        if hsp_cfg.sentence_boundary_ids is not None:
            boundary_ids = hsp_cfg.sentence_boundary_ids
        elif hsp_cfg.newline_token_ids:
            boundary_ids = hsp_cfg.newline_token_ids
        elif hsp_cfg.newline_token_id is not None:
            boundary_ids = (hsp_cfg.newline_token_id,)
        else:
            boundary_ids = tuple()
        self._boundary_ids = boundary_ids
        self.segmenter = Segmenter(
            cfg=hsp_cfg,
            query_block_size=cfg.query_block_size,
            boundary_ids=boundary_ids,
            align_explicit_query_pos=False,
        )
hidden_dim = model.config.hidden_size
self.page_encoder = PageEncoder(hsp_cfg, hidden_dim, idf_weights=idf_weights)
self.query_planner = QueryPlanner(hsp_cfg, query_dim=hidden_dim)
self.to(self.device)
def to(self, device: torch.device):
self.device = device
self.model.to(device)
self.segmenter.to(device)
self.page_encoder.to(device)
self.query_planner.to(device)
return self
# Token weights (tokenizer-dependent)
def _compute_local_token_weights(
self,
input_ids: torch.Tensor, # [B, L]
attention_mask: torch.Tensor, # [B, L]
) -> torch.Tensor:
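        """Per-token importance weights based on within-sequence rarity.

        Rare token ids get weights near 1 after max-normalization, frequent
        ids lower ones; sequences shorter than min_length_for_dynamic_weights
        fall back to uniform weights."""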
if not self.wrap_cfg.use_dynamic_token_weights:
return attention_mask.float()
B, L = input_ids.shape
pad_id = self.wrap_cfg.pad_token_id
device = input_ids.device
weights = torch.zeros_like(input_ids, dtype=torch.float, device=device)
for b in range(B):
valid = attention_mask[b].bool() if attention_mask is not None else (input_ids[b] != pad_id)
ids_valid = input_ids[b, valid]
L_b = ids_valid.numel()
if L_b == 0 or L_b < self.wrap_cfg.min_length_for_dynamic_weights:
weights[b, valid] = 1.0
continue
            # Rarity weight per unique id: log1p(L_b / (1 + count)), scattered
            # back with the inverse index instead of a per-id Python loop.
            _, inverse, counts = ids_valid.unique(return_inverse=True, return_counts=True)
            w = torch.log1p(L_b / (1.0 + counts.float()))
            w_pos = w[inverse]
            maxv = float(w_pos.max().item())
            if maxv > 0:
                w_pos = w_pos / maxv
            weights[b, valid] = w_pos
weights = weights + 1e-6
return weights
# Build tokenized inputs (tokenizer-dependent)
def _build_inputs_from_texts(
self,
contexts: List[str],
questions: Optional[List[str]],
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
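        """Tokenize and right-pad context+question pairs.

        Returns (input_ids, attention_mask, explicit_query_pos), where the
        query position is the context length in tokens. In implicit-query
        mode (allow_implicit_query=True and all questions empty) the query
        position is None and contexts are encoded alone."""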
B = len(contexts)
if questions is None:
questions = [""] * B
else:
assert len(contexts) == len(questions)
ctx_enc = [self.tokenizer.encode(c, add_special_tokens=False) for c in contexts]
implicit_mode = False
if self.wrap_cfg.allow_implicit_query:
no_query_flags = [((q is None) or (str(q).strip() == "")) for q in questions]
if all(no_query_flags):
implicit_mode = True
if implicit_mode:
q_enc = [[] for _ in range(B)]
else:
q_enc = [self.tokenizer.encode((q if q is not None else ""), add_special_tokens=False) for q in questions]
input_ids_list = []
explicit_qp_list = []
for i in range(B):
ids_ctx = ctx_enc[i]
ids_q = q_enc[i]
all_ids = ids_ctx if implicit_mode else (ids_ctx + ids_q)
if len(all_ids) == 0:
all_ids = [self.wrap_cfg.pad_token_id]
qp = 0
else:
qp = len(ids_ctx)
input_ids_list.append(all_ids)
if not implicit_mode:
explicit_qp_list.append(qp)
max_len = max(len(x) for x in input_ids_list)
pad_id = self.wrap_cfg.pad_token_id
input_ids = torch.full((B, max_len), pad_id, dtype=torch.long, device=self.device)
attention_mask = torch.zeros((B, max_len), dtype=torch.long, device=self.device)
for i in range(B):
L_i = len(input_ids_list[i])
input_ids[i, :L_i] = torch.tensor(input_ids_list[i], dtype=torch.long, device=self.device)
attention_mask[i, :L_i] = 1
explicit_qp = None if implicit_mode else torch.tensor(explicit_qp_list, dtype=torch.long, device=self.device)
return input_ids, attention_mask, explicit_qp
# Compress inputs for prefill
@torch.no_grad()
def compress_inputs_for_prefill(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
explicit_query_pos: Optional[torch.Tensor],
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
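        """Compress a batch of prompts ahead of prefill.

        Pipeline: build the page layout, embed tokens, pool page
        representations, pool a query vector (IDF-weighted mean + max),
        let the QueryPlanner choose pages to keep, expand kept tokens to
        whole sentences, then materialize the compressed inputs.
        Returns (compressed_inputs, stats)."""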
B, L = input_ids.shape
aligned_qp, split_info, layout = self.segmenter.build_layout(
input_ids=input_ids,
attention_mask=attention_mask,
explicit_query_pos=explicit_query_pos,
)
if hasattr(self.model, "get_input_embeddings") and self.model.get_input_embeddings() is not None:
embed = self.model.get_input_embeddings()
hidden = embed(input_ids)
elif hasattr(self.model, "model") and hasattr(self.model.model, "embed_tokens"):
hidden = self.model.model.embed_tokens(input_ids)
elif hasattr(self.model, "transformer") and hasattr(self.model.transformer, "wte"):
hidden = self.model.transformer.wte(input_ids)
else:
raise RuntimeError("Cannot locate input embedding layer for wrapped model.")
token_level_weights = self._compute_local_token_weights(input_ids, attention_mask)
block_repr = self.page_encoder(
hidden_states=hidden,
layout=layout,
input_ids=input_ids,
token_level_weights=token_level_weights,
)
token_valid = layout.token_valid
query_hidden_list = []
pos_idx = torch.arange(L, device=self.device)
w_mean = getattr(self.hsp_cfg, "identity_mean_weight", 0.7)
w_max = getattr(self.hsp_cfg, "identity_max_weight", 0.3)
w_sum = w_mean + w_max
if w_sum <= 0:
w_mean_eff, w_max_eff = 1.0, 0.0
else:
w_mean_eff, w_max_eff = w_mean / w_sum, w_max / w_sum
has_idf = getattr(self.page_encoder, "idf_weights", None) is not None
for b in range(B):
split_b = split_info[b]
qs, qe = int(split_b.query_start), int(split_b.query_end)
            if qe < qs:
                # Empty query span: fall back to the hidden state at the
                # (clamped) aligned query position to avoid indexing past L.
                qp = min(max(int(aligned_qp[b].item()), 0), L - 1)
                query_hidden_list.append(hidden[b, qp])
                continue
            span_mask = (pos_idx >= qs) & (pos_idx <= qe) & token_valid[b]
            if not span_mask.any():
                qp = min(max(int(aligned_qp[b].item()), 0), L - 1)
                query_hidden_list.append(hidden[b, qp])
                continue
h_b = hidden[b]
ids_b = input_ids[b]
            if has_idf:  # input_ids is a required tensor here, no None check needed
idf_vec = self.page_encoder.idf_weights
idf_b = idf_vec[ids_b]
weights = idf_b * span_mask.float()
w_sum_b = float(weights.sum().item())
if w_sum_b > 1e-6:
mean_b = (h_b * weights.unsqueeze(-1)).sum(dim=0) / w_sum_b
else:
mean_b = h_b[span_mask].mean(dim=0)
else:
mean_b = h_b[span_mask].mean(dim=0)
max_b, _ = h_b[span_mask].max(dim=0)
pooled_b = w_mean_eff * mean_b + w_max_eff * max_b
query_hidden_list.append(pooled_b)
query_hidden = torch.stack(query_hidden_list, dim=0)
query_token_hidden_list = []
query_token_weight_list = []
min_q = getattr(self.hsp_cfg, "min_query_tokens_for_multi", 4)
max_q = getattr(self.hsp_cfg, "max_query_tokens_for_multi", 32)
use_multi = self.wrap_cfg.use_query_multitoken_semantic and min_q > 0
for b in range(B):
split_b = split_info[b]
qs, qe = int(split_b.query_start), int(split_b.query_end)
if qe < qs or not use_multi:
query_token_hidden_list.append(None)
query_token_weight_list.append(None)
continue
span_mask = (pos_idx >= qs) & (pos_idx <= qe) & token_valid[b]
idx_span = pos_idx[span_mask]
Tq = idx_span.numel()
if Tq < min_q:
query_token_hidden_list.append(None)
query_token_weight_list.append(None)
continue
h_b = hidden[b]
w_b = token_level_weights[b]
w_span = w_b[idx_span]
top_k = min(Tq, max_q)
vals, indices = torch.topk(w_span, k=top_k, largest=True, sorted=True)
idx_top = idx_span[indices]
h_top = h_b[idx_top]
w_top = vals / (vals.sum() + 1e-6)
query_token_hidden_list.append(h_top)
query_token_weight_list.append(w_top)
keep_pages = self.query_planner(
block_repr=block_repr,
layout=layout,
query_hidden=query_hidden,
query_pos=aligned_qp,
input_ids=input_ids,
token_level_weights=token_level_weights,
split_results=split_info,
query_token_hidden_list=query_token_hidden_list,
query_token_weight_list=query_token_weight_list,
)
token2page = layout.token2page # [B, L]
token_valid = layout.token_valid # [B, L] bool
keep_pages_bool = keep_pages.bool() # [B, P] bool
token2page_clamped = token2page.clamp(min=0) # avoid -1 gather; token_valid will mask anyway
keep_token = token_valid & keep_pages_bool.gather(1, token2page_clamped)
        boundary_ids = self._boundary_ids  # same fallback chain resolved in __init__
kept_context_token_indices: List[List[int]] = []
context_lens: List[int] = []
for b in range(B):
qp = int(aligned_qp[b].item())
qp = max(0, min(qp, L))
context_lens.append(qp)
if qp <= 0:
kept_context_token_indices.append([])
continue
keep_ctx = keep_token[b, :qp].clone()
valid_ctx = token_valid[b, :qp]
            # Snap the kept set to sentence boundaries: if any token in a
            # boundary-delimited segment survives, keep the whole segment.
            if boundary_ids:
                ids_slice = input_ids[b, :qp]
is_boundary = torch.zeros(qp, dtype=torch.bool, device=input_ids.device)
for bid in boundary_ids:
is_boundary |= (ids_slice == int(bid))
boundary_pos = torch.nonzero(is_boundary, as_tuple=False).flatten().tolist()
start = 0
for p in boundary_pos:
end = int(p) + 1
if end <= start:
continue
if keep_ctx[start:end].any():
keep_ctx[start:end] = valid_ctx[start:end]
start = end
if start < qp and keep_ctx[start:qp].any():
keep_ctx[start:qp] = valid_ctx[start:qp]
kept_idx = torch.nonzero(keep_ctx & valid_ctx, as_tuple=False).flatten().detach().cpu().tolist()
kept_context_token_indices.append([int(x) for x in kept_idx])
compressed = self.segmenter.compress(
input_ids=input_ids,
attention_mask=attention_mask,
layout=layout,
keep_pages=keep_pages,
query_pos=aligned_qp,
)
L_comp = compressed["input_ids"].size(1)
stats = {
"aligned_query_pos": [int(x) for x in aligned_qp.detach().cpu().tolist()],
"context_len": context_lens,
"kept_context_token_indices": kept_context_token_indices,
"original_len": int(L),
"compressed_len": int(L_comp),
"compression_ratio": float(L_comp / max(L, 1)),
}
return compressed, stats
# Batch generation (dense or BEAVER)
@torch.no_grad()
def generate_batch(
self,
contexts: List[str],
questions: List[str],
max_new_tokens: int = 128,
use_hsp: bool = True,
**gen_kwargs,
) -> Dict[str, Any]:
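        """Generate answers for paired contexts and questions.

        With use_hsp=False this is plain model.generate on the full prompt;
        with use_hsp=True the prompt is first compressed via
        compress_inputs_for_prefill. Returns {"outputs": ..., "meta": ...}
        with timing and compression statistics."""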
assert len(contexts) == len(questions)
self.model.eval()
input_ids, attention_mask, explicit_qp = self._build_inputs_from_texts(contexts, questions)
prompt_lens = attention_mask.sum(dim=1).tolist()
if not use_hsp:
t0 = time.time()
outputs = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
use_cache=True,
**gen_kwargs,
)
t1 = time.time()
texts: List[str] = []
for i in range(outputs.size(0)):
pl = int(prompt_lens[i]) if i < len(prompt_lens) else 0
answer_tokens = outputs[i, pl:]
texts.append(self.tokenizer.decode(answer_tokens, skip_special_tokens=True))
meta = {
"mode": "dense",
"time_total": float(t1 - t0),
"original_len": int(input_ids.size(1)),
"compressed_len": int(input_ids.size(1)),
"compression_ratio": 1.0,
}
return {"outputs": texts, "meta": meta}
t0 = time.time()
compressed_inputs, stats = self.compress_inputs_for_prefill(
input_ids=input_ids,
attention_mask=attention_mask,
explicit_query_pos=explicit_qp,
)
t1 = time.time()
comp_attn = compressed_inputs["attention_mask"]
comp_prompt_lens = comp_attn.sum(dim=1).tolist()
gen_outputs = self.model.generate(
input_ids=compressed_inputs["input_ids"],
attention_mask=compressed_inputs["attention_mask"],
max_new_tokens=max_new_tokens,
use_cache=True,
**gen_kwargs,
)
t2 = time.time()
texts: List[str] = []
for i in range(gen_outputs.size(0)):
pl = int(comp_prompt_lens[i]) if i < len(comp_prompt_lens) else 0
answer_tokens = gen_outputs[i, pl:]
texts.append(self.tokenizer.decode(answer_tokens, skip_special_tokens=True))
meta = {
"mode": "hsp",
"time_prefill": float(t1 - t0),
"time_total": float(t2 - t0),
"original_len": stats["original_len"],
"compressed_len": stats["compressed_len"],
"compression_ratio": stats["compression_ratio"],
}
return {"outputs": texts, "meta": meta}