# BEAVER / PageEncoder.py
# NOTE(review): the following Hugging Face page artifacts (avatar caption,
# commit message, commit hash 035ad02 by JusperLee) were scrape residue that
# made the file unparseable; they are preserved here as comments.
#   "Update README.md with new project description, features, and configuration
#    details for BEAVER, including a new emoji and enhanced visualization of
#    the compression process."
# PageEncoder.py
# BEAVER page encoding module.
# Pools token embeddings into per-page representations with optional weighting.
from typing import Optional
import torch
import torch.nn as nn
from Segmenter import HSPPlannerConfig, SegmentPageLayout
class PageEncoder(nn.Module):
    """Encode pages by pooling token embeddings under a segment/page layout.

    Each page representation is a convex combination of a (possibly
    weighted) mean over the page's valid token embeddings and an
    element-wise max, with mixing coefficients read from the planner
    config (``identity_mean_weight`` / ``identity_max_weight``).

    Args:
        cfg: planner config; only read via ``getattr`` with defaults
            0.7 (mean weight) and 0.3 (max weight).
        hidden_dim: embedding dimension ``D`` of the token hidden states.
        idf_weights: optional 1-D tensor of per-vocabulary-id IDF weights
            used to weight tokens within a page. Registered as a buffer so
            it moves with the module across devices.

    Raises:
        ValueError: if ``idf_weights`` is not one-dimensional.
    """

    def __init__(
        self,
        cfg: "HSPPlannerConfig",
        hidden_dim: int,
        idf_weights: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.cfg = cfg
        self.hidden_dim = int(hidden_dim)
        if idf_weights is not None:
            # Validate with an explicit exception: `assert` is stripped
            # when Python runs with -O, silently skipping the check.
            if idf_weights.dim() != 1:
                raise ValueError(
                    f"idf_weights must be 1-D, got {idf_weights.dim()}-D"
                )
            self.register_buffer("idf_weights", idf_weights.float())
        else:
            self.idf_weights = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        layout: "SegmentPageLayout",
        input_ids: Optional[torch.Tensor] = None,
        token_level_weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool token embeddings into per-page representations.

        Args:
            hidden_states: [B, L, D] token embeddings.
            layout: provides ``page_indices`` [B, N, P] (token index for
                each page slot, -1 for padding) and ``page_valid``
                [B, N, P] boolean mask of real slots.
            input_ids: [B, L] token ids; required to apply IDF weighting.
            token_level_weights: optional [B, L] per-token weights.

        Returns:
            [B, N, D] page representations cast to ``hidden_states.dtype``.

        Raises:
            ValueError: if ``layout`` batch size disagrees with
                ``hidden_states``.
        """
        B, L, D = hidden_states.shape
        page_indices = layout.page_indices
        page_valid = layout.page_valid
        B2, N, P = page_indices.shape
        if B2 != B:
            raise ValueError(
                f"layout batch size {B2} != hidden_states batch size {B}"
            )

        # Gather token embeddings into page slots in one vectorized pass
        # (replaces a Python loop over the batch). Padding slots (index < 0)
        # are clamped to 0 for the gather and then zeroed; slots with a
        # non-negative index keep their embedding even if page_valid is
        # False — exactly like the original scatter — because downstream
        # pooling masks them with page_valid anyway.
        slot_is_token = (page_indices >= 0).unsqueeze(-1)  # [B, N, P, 1]
        gather_idx = page_indices.clamp(min=0).reshape(B, N * P, 1).expand(-1, -1, D)
        x_paged = hidden_states.gather(1, gather_idx).view(B, N, P, D)
        x_paged = x_paged * slot_is_token.to(x_paged.dtype)

        # Uniform (unweighted) mean over valid slots; fully-empty pages
        # divide by a clamped count of 1 and yield zeros.
        mask_exp = page_valid.unsqueeze(-1)  # [B, N, P, 1]
        x_sum = (x_paged * mask_exp).sum(dim=2)  # [B, N, D]
        count = mask_exp.sum(dim=2).clamp(min=1)  # [B, N, 1]
        x_mean_uniform = x_sum / count

        # Optional per-token weighting: IDF (from vocabulary ids) and/or
        # caller-supplied token-level weights, both masked by page_valid.
        use_idf = self.idf_weights is not None and input_ids is not None
        weights_eff = page_valid.float()
        if use_idf:
            idx_flat = page_indices.clamp(min=0).view(B, -1)
            tokens_paged = input_ids.gather(1, idx_flat).view(B, N, P)
            weights_eff = weights_eff * self.idf_weights[tokens_paged]
        if token_level_weights is not None:
            idx_flat = page_indices.clamp(min=0).view(B, -1)
            w_paged = token_level_weights.gather(1, idx_flat).view(B, N, P)
            weights_eff = weights_eff * w_paged

        if use_idf or token_level_weights is not None:
            w_sum = weights_eff.sum(dim=2, keepdim=True)  # [B, N, 1]
            # Pages whose total weight is near zero fall back to the
            # uniform mean instead of amplifying noise via a tiny divisor.
            low = w_sum < 1e-4
            w_sum_safe = w_sum.masked_fill(w_sum < 1e-6, 1.0)
            x_weighted = (x_paged * weights_eff.unsqueeze(-1)).sum(dim=2)
            x_mean_weighted = x_weighted / w_sum_safe
            x_mean = torch.where(
                low.expand(-1, -1, D), x_mean_uniform, x_mean_weighted
            )
        else:
            x_mean = x_mean_uniform

        # Masked max-pool: invalid slots are pushed to a large negative
        # value so they never win the max. NOTE(review): a page with no
        # valid slots yields -1e4 here, as in the original — presumably
        # such pages are masked downstream; confirm against callers.
        neg_fill = hidden_states.new_full((), -1e4)
        x_max = x_paged.masked_fill(~mask_exp, neg_fill).max(dim=2).values

        # Blend mean and max with normalized config weights; a degenerate
        # non-positive total falls back to the mean alone.
        w_mean = float(getattr(self.cfg, "identity_mean_weight", 0.7))
        w_max = float(getattr(self.cfg, "identity_max_weight", 0.3))
        total = w_mean + w_max
        if total <= 0:
            block_repr = x_mean
        else:
            block_repr = (w_mean / total) * x_mean + (w_max / total) * x_max
        return block_repr.to(hidden_states.dtype)