from typing import Optional

import torch
import torch.nn as nn

from Segmenter import HSPPlannerConfig, SegmentPageLayout


class PageEncoder(nn.Module):
| """Encode pages by pooling token embeddings under a segment/page layout.""" |
|
|
| def __init__( |
| self, |
| cfg: HSPPlannerConfig, |
| hidden_dim: int, |
| idf_weights: Optional[torch.Tensor] = None, |
| ): |
        super().__init__()
        self.cfg = cfg
        self.hidden_dim = int(hidden_dim)

        # Optional per-token IDF weights indexed by vocabulary id. Registered
        # as a buffer so they move with the module across devices.
        if idf_weights is not None:
            assert idf_weights.dim() == 1, "idf_weights must be 1-D (vocab_size,)"
            self.register_buffer("idf_weights", idf_weights.float())
        else:
            self.idf_weights = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        layout: SegmentPageLayout,
        input_ids: Optional[torch.Tensor] = None,
        token_level_weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, L, D = hidden_states.shape

        # page_indices: (B, N, P) token positions per page, -1 for padding.
        # page_valid:   (B, N, P) bool mask of real (non-padding) slots.
        page_indices = layout.page_indices
        page_valid = layout.page_valid
        B2, N, P = page_indices.shape
        assert B2 == B, "layout batch size must match hidden_states"

        # Gather token embeddings into page-shaped slots; padded slots
        # (index -1) are left as zeros.
        x_paged = hidden_states.new_zeros(B, N, P, D)
        for b in range(B):
            flat_idx = page_indices[b].view(-1)
            mask = flat_idx >= 0
            if mask.any():
                x_paged[b].view(-1, D)[mask] = hidden_states[b, flat_idx[mask]]
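        # A vectorized equivalent of the loop above (a sketch, assuming
        # page_indices is int64; kept as a comment so the reference loop
        # stays the executable path):
        #   idx = page_indices.clamp(min=0).view(B, -1, 1).expand(-1, -1, D)
        #   x_paged = hidden_states.gather(1, idx).view(B, N, P, D)
        #   x_paged = x_paged * page_valid.unsqueeze(-1)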

        # Uniform mean over the valid tokens of each page.
        mask_exp = page_valid.unsqueeze(-1)
        x_sum = (x_paged * mask_exp).sum(dim=2)
        count = mask_exp.sum(dim=2).clamp(min=1)
        x_mean_uniform = x_sum / count

        # Effective per-token weights: start from the validity mask, then
        # optionally modulate by IDF and/or caller-provided weights.
        weights_eff = page_valid.float()

        use_idf = self.idf_weights is not None and input_ids is not None
        if use_idf:
            # Look up an IDF weight for every token on every page. Padded
            # slots gather position 0, but their weight is already zero.
            idx_clamped = page_indices.clamp(min=0)
            tokens_paged = input_ids.gather(1, idx_clamped.view(B, -1)).view(B, N, P)
            weights_eff = weights_eff * self.idf_weights[tokens_paged]

        if token_level_weights is not None:
            # Caller-provided per-position weights, gathered the same way.
            idx_clamped = page_indices.clamp(min=0)
            w_paged = token_level_weights.gather(1, idx_clamped.view(B, -1)).view(B, N, P)
            weights_eff = weights_eff * w_paged

        if use_idf or token_level_weights is not None:
            # Weighted mean, falling back to the uniform mean for pages whose
            # total weight is too small for a stable division.
            thr = 1e-4
            w_sum = weights_eff.sum(dim=2, keepdim=True)
            low = w_sum < thr
            x_weighted = (x_paged * weights_eff.unsqueeze(-1)).sum(dim=2)
            x_mean_weighted = x_weighted / w_sum.clamp(min=thr)
            x_mean = torch.where(low.expand(-1, -1, D), x_mean_uniform, x_mean_weighted)
        else:
            x_mean = x_mean_uniform

        # Max pool over valid tokens. A large negative fill (rather than -inf)
        # keeps fully padded pages finite instead of yielding NaNs downstream.
        x_for_max = x_paged.masked_fill(~mask_exp, -1e4)
        x_max = x_for_max.max(dim=2).values

        # Blend mean- and max-pooled representations; the two weights are
        # renormalized to sum to 1, and degenerate (non-positive) settings
        # fall back to the mean alone.
        w_mean = float(getattr(self.cfg, "identity_mean_weight", 0.7))
        w_max = float(getattr(self.cfg, "identity_max_weight", 0.3))
        s = w_mean + w_max
        if s <= 0:
            block_repr = x_mean
        else:
            block_repr = (w_mean / s) * x_mean + (w_max / s) * x_max

        return block_repr.to(hidden_states.dtype)
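

# Minimal smoke test: a sketch, not part of the original module. It assumes
# only that SegmentPageLayout exposes `page_indices` (int64, B x N x P, -1 for
# padding) and `page_valid` (bool, B x N x P); SimpleNamespace objects stand in
# for the real SegmentPageLayout and HSPPlannerConfig, whose constructors are
# not shown here.
if __name__ == "__main__":
    from types import SimpleNamespace

    B, L, D, N, P = 2, 16, 8, 3, 4
    hidden = torch.randn(B, L, D)
    input_ids = torch.randint(0, 100, (B, L))

    page_indices = torch.randint(0, L, (B, N, P))
    page_indices[:, :, -1] = -1  # mark the last slot of each page as padding
    layout = SimpleNamespace(page_indices=page_indices, page_valid=page_indices >= 0)

    cfg = SimpleNamespace(identity_mean_weight=0.7, identity_max_weight=0.3)
    enc = PageEncoder(cfg, hidden_dim=D, idf_weights=torch.rand(100))

    out = enc(hidden, layout, input_ids=input_ids)
    assert out.shape == (B, N, D)
    print("page representations:", tuple(out.shape))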