import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist


class CrossAttentionBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, batch_first=True, dropout=dropout
        )

        self.norm2 = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, int(hidden_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(hidden_dim * mlp_ratio), hidden_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, encoder_hidden_state, encoder_attention_mask=None):
        """
        Cross-attention block forward.
        Args:
            query (Tensor): Shape [B, Q, D]. Learnable query tokens propagated across layers.
            encoder_hidden_state (Tensor): Shape [B, L, D]. Features from one encoder layer.
            encoder_attention_mask (Tensor | None): Shape [B, L]. 1/True=keep (visible), 0/False=mask. None disables masking.
        Returns:
            Tensor: Updated query tokens of shape [B, Q, D].
        Details:
            1. LayerNorm + MultiheadAttention (Q = query, K/V = encoder_hidden_state).
            2. Residual path: query = query + attn_output, then add MLP residual.
            3. The explicit Dropout wraps only the MLP output; MultiheadAttention applies its own internal attention dropout.
        """
        q = self.norm1(query)
        kv = encoder_hidden_state

        if encoder_attention_mask is not None:
            # nn.MultiheadAttention expects key_padding_mask of shape [B, L] with
            # True marking positions to IGNORE, so invert the 1/True = keep mask.
            key_padding_mask = ~encoder_attention_mask.to(dtype=torch.bool)  # [B, L]
        else:
            key_padding_mask = None

        attn_output, _ = self.cross_attn(q, kv, kv, key_padding_mask=key_padding_mask)
        query = query + attn_output
        query = query + self.dropout(self.mlp(self.norm2(query)))
        return query
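

# Usage sketch (illustrative, not part of the original module; the batch size,
# token counts, and feature width below are arbitrary assumptions):
def _demo_cross_attention_block():
    block = CrossAttentionBlock(hidden_dim=768, num_heads=8)
    query = torch.randn(2, 64, 768)  # [B, Q, D] query tokens
    feats = torch.randn(2, 196, 768)  # [B, L, D] one encoder layer's features
    mask = torch.ones(2, 196, dtype=torch.bool)  # 1/True = keep (visible)
    out = block(query, feats, mask)
    assert out.shape == (2, 64, 768)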


class LayerwiseQFormer(nn.Module):
    def __init__(
        self, input_hidden_dim=2048, output_hidden_dim=768, num_query_tokens=64, num_layers=37, num_heads=8, config=None
    ):
        super().__init__()
        self.input_hidden_dim = input_hidden_dim
        self.output_hidden_dim = output_hidden_dim
        self.num_query_tokens = num_query_tokens
        self.num_layers = num_layers
        self.config = config
        # Project input to output dimension
        self.proj = nn.Linear(input_hidden_dim, output_hidden_dim)
        # Learnable query tokens
        self.query_tokens = nn.Parameter(torch.randn(num_query_tokens, output_hidden_dim))

        # Independent cross-attention blocks (one per encoder layer)
        self.layers = nn.ModuleList([CrossAttentionBlock(output_hidden_dim, num_heads) for _ in range(num_layers)])

    def forward(self, hidden_states_list, encoder_attention_mask=None):
        """
        Layer-wise Q-Former forward pass.
        Args:
            hidden_states_list (List[Tensor]): Length == num_layers. Each tensor is [B, L, Din], raw encoder layer outputs (before projection).
            encoder_attention_mask (Tensor | None): Shape [B, L]. Same semantics as in CrossAttentionBlock.
        Returns:
            Tensor: Aggregated query tokens of shape [B, Q, Dout].
        Pipeline:
            1. Stack per-layer features to [B, N, L, Din] and linearly project to Dout.
            2. Expand global learnable query tokens to batch: [B, Q, Dout].
            3. Apply cross-attention layer-by-layer: each query attends only to the corresponding encoder layer features.
        Notes:
            - Asserts len(hidden_states_list) == num_layers.
            - Does not modify gradient flow of hidden_states_list.
        """
        # hidden_states_list = self.scale_hook(hidden_states_list)

        assert (
            len(hidden_states_list) == self.num_layers
        ), f"Expected {self.num_layers} layers, got {len(hidden_states_list)}"

        B = hidden_states_list[0].size(0)
        # 1) Stack per-layer features: [B, N, L, Din]
        hs = torch.stack(hidden_states_list, dim=1)
        # 2) Project to the output dimension: [B, N, L, Dout]
        proj_hs = self.proj(hs)
        # 3) Unbind back to a list of N tensors, each [B, L, Dout]
        hidden_states_list = list(proj_hs.unbind(dim=1))

        # Expand query tokens for each batch
        query = self.query_tokens.unsqueeze(0).expand(B, -1, -1)  # [B, Q, D]

        # Iterate through each layer and apply cross-attention
        for i, layer in enumerate(self.layers):
            query = layer(query, hidden_states_list[i], encoder_attention_mask)

        return query

    def scale_hook(self, hidden_states_list, scale_factor=0.1):
        """
        (Experimental / optional) Register gradient scaling hooks on each layer's hidden states.
        Args:
            hidden_states_list (List[Tensor]): Per-layer feature tensors.
            scale_factor (float): Default scaling factor; overridden by config.vla.layer_qformer.grad_scale when set (the function returns early otherwise, so the default is effectively unused).
        Returns:
            List[Tensor]: Original list (no data copy); hooks may be attached in-place.
        Design:
            - Returns the input list unchanged unless grad_scale is configured (and != 1); effectively a no-op placeholder by default.
            - Uses attribute _scaled_hook to avoid duplicate hook registration in distributed settings.
            - Can be enabled later for gradient dampening or perturbation experiments.
        Performance:
            - Excessive hook registrations can hurt speed; kept lazy by default.
        """
        # --- 1. Register gradient scaling hooks on input hidden_states_list ---
        if (
            self.config
            and hasattr(self.config, "vla")
            and hasattr(self.config.vla, "layer_qformer")
            and hasattr(self.config.vla.layer_qformer, "grad_scale")
            and self.config.vla.layer_qformer.grad_scale != 1
        ):
            scale_factor = self.config.vla.layer_qformer.grad_scale
        else:
            return hidden_states_list  # grad_scale not configured: return the original list unchanged

        for hidden_states in hidden_states_list:
            if hidden_states.requires_grad:
                # Register at most one hook per tensor (duplicate registration was
                # observed to slow things down, e.g. in distributed settings).
                if not hasattr(hidden_states, "_scaled_hook"):
                    hidden_states.register_hook(lambda grad: grad * scale_factor)
                    hidden_states._scaled_hook = True  # mark as processed

        return hidden_states_list
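

# End-to-end sketch (illustrative; four layers and the sizes below are assumptions
# chosen to keep the example small): aggregate per-layer features into query tokens.
def _demo_layerwise_qformer():
    model = LayerwiseQFormer(
        input_hidden_dim=2048, output_hidden_dim=768, num_query_tokens=64, num_layers=4
    )
    hidden_states_list = [torch.randn(2, 196, 2048) for _ in range(4)]  # N x [B, L, Din]
    mask = torch.ones(2, 196, dtype=torch.bool)
    out = model(hidden_states_list, mask)
    assert out.shape == (2, 64, 768)  # [B, Q, Dout]


# Gradient-scaling sketch mirroring the mechanism scale_hook registers: a tensor
# hook rescales the gradient flowing through y without changing forward values.
def _demo_grad_scale_hook():
    x = torch.randn(3, requires_grad=True)
    y = x * 2.0
    y.register_hook(lambda grad: grad * 0.1)  # same pattern as scale_hook
    y.sum().backward()
    # d(sum(2x))/dx = 2, scaled by 0.1 -> 0.2 everywhere
    assert torch.allclose(x.grad, torch.full_like(x, 0.2))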


def get_layerwise_qformer(num_heads=8, config=None, **kwargs):
    """
    Build a LayerwiseQFormer instance.
    Args:
        num_heads (int): Number of attention heads for CrossAttentionBlock.
        config: Configuration object; must contain config.framework.layer_qformer with:
            - qformer_start_layer / qformer_end_layer: range of layers (start inclusive, end exclusive).
            - num_query_tokens: Number of learnable query tokens.
            - input_dim: Input feature dimension (Din).
            - ouptput_dim (sic): Output feature dimension (Dout).
        **kwargs: Reserved for future extensions (unused).
    Returns:
        LayerwiseQFormer: Instantiated model.
    Notes:
        - num_layers = end_layer - start_layer (half-open interval).
        - Does not perform weight loading or device moves here.
    """
    # dist.barrier()
    qformer_cfg = config.framework.layer_qformer
    num_layers = qformer_cfg.qformer_end_layer - qformer_cfg.qformer_start_layer  # half-open interval
    num_query_tokens = qformer_cfg.num_query_tokens
    input_hidden_dim = qformer_cfg.input_dim
    output_hidden_dim = qformer_cfg.ouptput_dim  # sic: the config key is spelled "ouptput_dim"

    qformer = LayerwiseQFormer(
        input_hidden_dim=input_hidden_dim,
        output_hidden_dim=output_hidden_dim,
        num_query_tokens=num_query_tokens,
        num_layers=num_layers,
        num_heads=num_heads,
        config=config,
    )
    return qformer
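

# Construction sketch (the SimpleNamespace config is a stand-in assumption; the
# real project presumably supplies its own config object with these attribute paths):
def _demo_get_layerwise_qformer():
    from types import SimpleNamespace

    layer_qformer = SimpleNamespace(
        qformer_start_layer=0,
        qformer_end_layer=4,
        num_query_tokens=64,
        input_dim=2048,
        ouptput_dim=768,  # the config key really is spelled "ouptput_dim" (see above)
    )
    config = SimpleNamespace(framework=SimpleNamespace(layer_qformer=layer_qformer))
    return get_layerwise_qformer(num_heads=8, config=config)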