File size: 2,587 Bytes
f37be5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import annotations

import torch
import torch.nn as nn

from src.model.config import ModelConfig


class ProposalRolloutBranch(nn.Module):
    """Roll a proposal representation forward through a small GRU recurrence.

    The anchor, proposal and context representations are fused (together with
    their difference and element-wise product) into a conditioning vector and
    an initial recurrent state. The state is then advanced for a fixed number
    of steps with a ``GRUCell``; every step is layer-normalised and the steps
    are pooled into one summary vector by a learned sigmoid gate.

    Config fields read from ``cfg``: ``d_model``,
    ``anchor_proposal_rollout_hidden`` (floored at 32),
    ``anchor_proposal_rollout_steps`` (floored at 1) and
    ``anchor_proposal_rollout_residual_scale``.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        hidden_dim = max(32, int(cfg.anchor_proposal_rollout_hidden))
        fusion_dim = cfg.d_model * 5  # five d_model-wide features are concatenated in forward()

        # NOTE: sub-module names and construction order are kept stable so that
        # state_dict keys and seeded initialisation stay unchanged.
        self.seed_proj = self._mlp(fusion_dim, hidden_dim, cfg.d_model)
        self.cond_proj = self._mlp(fusion_dim, hidden_dim, cfg.d_model)
        # Per-step input is [condition, state, step embedding] -> 3 * d_model.
        self.input_proj = self._mlp(cfg.d_model * 3, hidden_dim, cfg.d_model)
        self.step_emb = nn.Embedding(max(1, int(cfg.anchor_proposal_rollout_steps)), cfg.d_model)
        self.cell = nn.GRUCell(cfg.d_model, cfg.d_model)
        self.state_norm = nn.LayerNorm(cfg.d_model)
        self.summary_gate = nn.Linear(cfg.d_model * 2, 1)

    @staticmethod
    def _mlp(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
        """Build the Linear -> GELU -> Linear projection used by this branch."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(
        self,
        anchor_repr: torch.Tensor,
        proposal_repr: torch.Tensor,
        context_repr: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Run the rollout and return per-step states plus a gated summary.

        Args:
            anchor_repr: Tensor of shape ``(..., d_model)``.
            proposal_repr: Tensor of shape ``(..., d_model)``.
            context_repr: Tensor of shape ``(..., d_model)``.

        All three inputs must share the same shape because they are
        concatenated along the last dimension. Any number of leading
        dimensions (including none) is accepted.

        Returns:
            A dict with ``"states"`` of shape ``(steps, ..., d_model)`` holding
            the layer-normalised state after every step, and ``"summary"`` of
            shape ``(..., d_model)`` holding their gate-weighted average.
        """
        fusion = torch.cat(
            [
                anchor_repr,
                proposal_repr,
                context_repr,
                proposal_repr - anchor_repr,
                proposal_repr * anchor_repr,
            ],
            dim=-1,
        )
        condition = self.cond_proj(fusion)
        # The initial state is the proposal plus a scaled learned correction.
        state = proposal_repr + float(self.cfg.anchor_proposal_rollout_residual_scale) * self.seed_proj(fusion)

        state_shape = state.shape
        d_model = state_shape[-1]
        num_steps = max(1, int(self.cfg.anchor_proposal_rollout_steps))

        states: list[torch.Tensor] = []
        for step_idx in range(num_steps):
            # Broadcast the 1-D step embedding over any leading dims so it can
            # be concatenated with batched condition/state tensors.
            step_vec = self.step_emb.weight[step_idx].expand_as(state)
            step_input = self.input_proj(torch.cat([condition, state, step_vec], dim=-1))
            # GRUCell takes (N, d_model); collapse leading dims (a 1-D input
            # becomes N == 1) and restore the original layout afterwards.
            next_state = self.cell(
                step_input.reshape(-1, d_model),
                state.reshape(-1, d_model),
            )
            state = next_state.reshape(state_shape)
            # Only the stored copy is normalised; the recurrence itself keeps
            # running on the raw GRU state.
            states.append(self.state_norm(state))

        rollout_states = torch.stack(states, dim=0)
        gate_in = torch.cat(
            [rollout_states, condition.unsqueeze(0).expand_as(rollout_states)],
            dim=-1,
        )
        summary_gate = torch.sigmoid(self.summary_gate(gate_in))
        # Gate-weighted mean over the step axis; the clamp guards against a
        # vanishing denominator when every gate saturates towards zero.
        summary = (summary_gate * rollout_states).sum(dim=0) / summary_gate.sum(dim=0).clamp_min(1e-6)
        return {
            "states": rollout_states,
            "summary": summary,
        }