File size: 3,713 Bytes
035ad02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# PageEncoder.py
# BEAVER page encoding module.
# Pools token embeddings into per-page representations with optional weighting.
from typing import Optional

import torch
import torch.nn as nn

from Segmenter import HSPPlannerConfig, SegmentPageLayout


class PageEncoder(nn.Module):
    """Encode pages by pooling token embeddings under a segment/page layout.

    ``forward`` gathers each page's tokens from ``hidden_states`` using
    ``layout.page_indices`` and reduces them to one vector per page: a blend of
    a mean pool (optionally IDF- and/or token-weighted) and a max pool over the
    page's valid slots (``layout.page_valid``). The blend weights are read from
    ``cfg.identity_mean_weight`` / ``cfg.identity_max_weight`` (0.7 / 0.3 when
    the config lacks them) and normalized to sum to 1.
    """

    # Project types are annotated as strings (forward references) so they are
    # not evaluated when the class body executes.
    def __init__(
        self,
        cfg: "HSPPlannerConfig",
        hidden_dim: int,
        idf_weights: Optional[torch.Tensor] = None,
    ):
        """Store the config and the optional IDF table.

        Args:
            cfg: Planner config. Only ``identity_mean_weight`` and
                ``identity_max_weight`` are read (via ``getattr`` with defaults).
            hidden_dim: Embedding width, stored as ``self.hidden_dim``;
                ``forward`` takes the width from ``hidden_states`` instead.
            idf_weights: Optional 1-D table indexed by token id (``input_ids``
                in ``forward``). Registered as a float32 buffer so it follows
                the module across devices and is saved in its state dict.

        Raises:
            ValueError: If ``idf_weights`` is not 1-D.
        """
        super().__init__()
        self.cfg = cfg
        self.hidden_dim = int(hidden_dim)

        if idf_weights is not None:
            if idf_weights.dim() != 1:
                raise ValueError(
                    f"idf_weights must be 1-D, got shape {tuple(idf_weights.shape)}"
                )
            self.register_buffer("idf_weights", idf_weights.float())
        else:
            self.idf_weights = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        layout: "SegmentPageLayout",
        input_ids: Optional[torch.Tensor] = None,
        token_level_weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool token embeddings into one vector per page.

        Args:
            hidden_states: ``[B, L, D]`` token embeddings.
            layout: Object exposing ``page_indices`` (``[B, N, P]`` integer
                positions into ``L``; negative entries are padding) and
                ``page_valid`` (``[B, N, P]`` boolean mask of slots to pool).
            input_ids: ``[B, L]`` token ids; only used together with the IDF
                table passed at construction.
            token_level_weights: Optional ``[B, L]`` per-position weights.

        Returns:
            ``[B, N, D]`` page representations in ``hidden_states.dtype``.
            Pages with no valid slot are all zeros.

        Raises:
            ValueError: If the layout batch size differs from ``hidden_states``.
        """
        B, _, D = hidden_states.shape

        page_indices = layout.page_indices
        page_valid = layout.page_valid
        B2, N, P = page_indices.shape
        if B2 != B:
            raise ValueError(
                f"layout batch size {B2} does not match hidden_states batch size {B}"
            )

        # Gather every page slot in one vectorized op (no per-batch Python loop
        # or host sync). Negative indices are padding: read position 0 for them
        # and zero the result afterwards.
        idx_flat = page_indices.reshape(B, N * P)
        idx_safe = idx_flat.clamp(min=0)  # also reused for id/weight lookups
        present = (idx_flat >= 0).unsqueeze(-1)  # [B, N*P, 1]
        x_paged = hidden_states.gather(1, idx_safe.unsqueeze(-1).expand(-1, -1, D))
        x_paged = x_paged.masked_fill(~present, 0).view(B, N, P, D)

        mask_exp = page_valid.unsqueeze(-1)  # [B, N, P, 1]

        # Uniform mean over valid slots; the clamp avoids 0/0 on empty pages.
        x_sum = (x_paged * mask_exp).sum(dim=2)  # [B, N, D]
        count = mask_exp.sum(dim=2).clamp(min=1)  # [B, N, 1]
        x_mean_uniform = x_sum / count

        use_idf = self.idf_weights is not None and input_ids is not None
        use_token_w = token_level_weights is not None

        if use_idf or use_token_w:
            weights_eff = page_valid.float()  # [B, N, P]; zero on invalid slots
            if use_idf:
                tokens_paged = input_ids.gather(1, idx_safe).view(B, N, P)
                weights_eff = weights_eff * self.idf_weights[tokens_paged]
            if use_token_w:
                w_paged = token_level_weights.gather(1, idx_safe).view(B, N, P)
                weights_eff = weights_eff * w_paged

            w_sum = weights_eff.sum(dim=2, keepdim=True)  # [B, N, 1]
            # Pages whose total weight is (near) zero fall back to the uniform
            # mean; their divisor is set to 1 so the discarded branch stays finite.
            low = w_sum < 1e-4
            x_weighted = (x_paged * weights_eff.unsqueeze(-1)).sum(dim=2)  # [B, N, D]
            x_mean_weighted = x_weighted / w_sum.masked_fill(low, 1.0)
            x_mean = torch.where(low, x_mean_uniform, x_mean_weighted)
        else:
            x_mean = x_mean_uniform

        w_mean = float(getattr(self.cfg, "identity_mean_weight", 0.7))
        w_max = float(getattr(self.cfg, "identity_max_weight", 0.3))
        total = w_mean + w_max
        if total <= 0:
            # Degenerate blend weights: fall back to the mean pool alone.
            block_repr = x_mean
        else:
            x_max = self._masked_max(x_paged, page_valid)
            block_repr = (w_mean / total) * x_mean + (w_max / total) * x_max

        return block_repr.to(hidden_states.dtype)

    @staticmethod
    def _masked_max(x_paged: torch.Tensor, page_valid: torch.Tensor) -> torch.Tensor:
        """Max-pool ``x_paged`` ([B, N, P, D]) over valid slots; empty pages give 0."""
        # Use the dtype's most negative value so padding never beats a real
        # activation (a fixed -1e4 sentinel wins over activations below -1e4).
        fill = torch.finfo(x_paged.dtype).min
        x_max = x_paged.masked_fill(~page_valid.unsqueeze(-1), fill).max(dim=2).values
        # A page with no valid slot would otherwise return the sentinel itself
        # and become a huge negative vector after blending; zero it so it
        # matches the (zero) mean of an empty page.
        has_tokens = page_valid.any(dim=2, keepdim=True)  # [B, N, 1]
        return x_max.masked_fill(~has_tokens, 0.0)