import torch
import torch.nn as nn


class CrossAttentionBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, batch_first=True, dropout=dropout
        )

        self.norm2 = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, int(hidden_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(hidden_dim * mlp_ratio), hidden_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, encoder_hidden_state, encoder_attention_mask=None):
        """
        Cross-attention block forward.
        Args:
            query (Tensor): Shape [B, Q, D]. Learnable query tokens propagated across layers.
            encoder_hidden_state (Tensor): Shape [B, L, D]. Features from one encoder layer.
            encoder_attention_mask (Tensor | None): Shape [B, L]. 1/True = keep (visible), 0/False = mask. None disables masking.
        Returns:
            Tensor: Updated query tokens of shape [B, Q, D].
        Details:
            1. LayerNorm + MultiheadAttention (Q = query, K/V = encoder_hidden_state).
            2. Residual path: query = query + attn_output, then add MLP residual.
            3. Dropout is applied only on the MLP output.
        """
        q = self.norm1(query)
        kv = encoder_hidden_state

        if encoder_attention_mask is not None:
            # nn.MultiheadAttention expects key_padding_mask of shape [B, L] with True at
            # positions to ignore, while encoder_attention_mask marks visible positions
            # with 1/True, so convert to bool and invert.
            key_padding_mask = ~encoder_attention_mask.to(dtype=torch.bool)
        else:
            key_padding_mask = None

        attn_output, _ = self.cross_attn(q, kv, kv, key_padding_mask=key_padding_mask)
        query = query + attn_output
        query = query + self.dropout(self.mlp(self.norm2(query)))
        return query


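# Minimal shape sketch for CrossAttentionBlock. The batch size, sequence length, and
# dimensions below are illustrative assumptions, not values from any project config:
#
#     block = CrossAttentionBlock(hidden_dim=768, num_heads=8)
#     query = torch.zeros(2, 64, 768)                  # [B, Q, D] query tokens
#     feats = torch.zeros(2, 196, 768)                 # [B, L, D] one encoder layer
#     mask = torch.ones(2, 196, dtype=torch.bool)      # True = visible token
#     out = block(query, feats, mask)                  # -> [B, Q, D] == [2, 64, 768]

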
class LayerwiseQFormer(nn.Module):
    def __init__(
        self, input_hidden_dim=2048, output_hidden_dim=768, num_query_tokens=64, num_layers=37, num_heads=8, config=None
    ):
        super().__init__()
        self.input_hidden_dim = input_hidden_dim
        self.output_hidden_dim = output_hidden_dim
        self.num_query_tokens = num_query_tokens
        self.num_layers = num_layers
        self.config = config

        # Project per-layer encoder features from input_hidden_dim (Din) to output_hidden_dim (Dout).
        self.proj = nn.Linear(input_hidden_dim, output_hidden_dim)

        # Learnable query tokens shared across the batch.
        self.query_tokens = nn.Parameter(torch.randn(num_query_tokens, output_hidden_dim))

        # One cross-attention block per encoder layer.
        self.layers = nn.ModuleList([CrossAttentionBlock(output_hidden_dim, num_heads) for _ in range(num_layers)])

    def forward(self, hidden_states_list, encoder_attention_mask=None):
        """
        Layer-wise Q-Former forward pass.
        Args:
            hidden_states_list (List[Tensor]): Length == num_layers. Each tensor is [B, L, Din], raw encoder layer outputs (before projection).
            encoder_attention_mask (Tensor | None): Shape [B, L]. Same semantics as in CrossAttentionBlock.
        Returns:
            Tensor: Aggregated query tokens of shape [B, Q, Dout].
        Pipeline:
            1. Stack per-layer features to [B, N, L, Din] and linearly project to Dout.
            2. Expand global learnable query tokens to batch: [B, Q, Dout].
            3. Apply cross-attention layer-by-layer: each query attends only to the corresponding encoder layer features.
        Notes:
            - Asserts len(hidden_states_list) == num_layers.
            - Does not modify gradient flow of hidden_states_list.
        """
        assert (
            len(hidden_states_list) == self.num_layers
        ), f"Expected {self.num_layers} layers, got {len(hidden_states_list)}"

        B = hidden_states_list[0].size(0)

        hs = torch.stack(hidden_states_list, dim=1)        # [B, N, L, Din]
        proj_hs = self.proj(hs)                            # [B, N, L, Dout]
        hidden_states_list = list(proj_hs.unbind(dim=1))   # N tensors of [B, L, Dout]

        query = self.query_tokens.unsqueeze(0).expand(B, -1, -1)  # [B, Q, Dout]

        for i, layer in enumerate(self.layers):
            query = layer(query, hidden_states_list[i], encoder_attention_mask)

        return query

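    # How hidden_states_list is produced is outside this module; as one hedged example,
    # assuming a Hugging Face-style encoder called with output_hidden_states=True:
    #
    #     outputs = encoder(pixel_values, output_hidden_states=True)
    #     hidden_states_list = list(outputs.hidden_states)   # tuple of [B, L, Din] tensors
    #     queries = qformer(hidden_states_list[start:end], attention_mask)
    #
    # where `encoder`, `pixel_values`, `start`, and `end` are placeholders, not names from
    # this repository.
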
    def scale_hook(self, hidden_states_list, scale_factor=0.1):
        """
        (Experimental / optional) Register gradient scaling hooks on each layer's hidden states.
        Args:
            hidden_states_list (List[Tensor]): Per-layer feature tensors.
            scale_factor (float): Gradient scaling factor (effective only if enabled via config and != 1).
        Returns:
            List[Tensor]: Original list (no data copy); hooks may be attached in-place.
        Design:
            - Returns the list unchanged unless config.vla.layer_qformer.grad_scale is set and != 1.
            - Uses the attribute _scaled_hook to avoid duplicate hook registration in distributed settings.
            - Can be enabled later for gradient dampening or perturbation experiments.
        Performance:
            - Excessive hook registrations can hurt speed; kept lazy by default.
        """
        if (
            self.config
            and hasattr(self.config, "vla")
            and hasattr(self.config.vla, "layer_qformer")
            and hasattr(self.config.vla.layer_qformer, "grad_scale")
            and self.config.vla.layer_qformer.grad_scale != 1
        ):
            scale_factor = self.config.vla.layer_qformer.grad_scale
        else:
            return hidden_states_list

        for hidden_states in hidden_states_list:
            # Only tensors that participate in autograd can carry hooks; tag each tensor so
            # repeated calls do not stack multiple hooks on the same tensor.
            if hidden_states.requires_grad and not hasattr(hidden_states, "_scaled_hook"):
                hidden_states.register_hook(lambda grad: grad * scale_factor)
                hidden_states._scaled_hook = True

        return hidden_states_list


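# Illustrative behaviour of the gradient hook used by scale_hook (made-up values, not
# taken from this project):
#
#     x = torch.ones(3, requires_grad=True)
#     y = x * 2
#     y.register_hook(lambda g: g * 0.1)   # gradients flowing through y are scaled by 0.1
#     y.sum().backward()
#     x.grad                               # tensor([0.2, 0.2, 0.2]) instead of [2., 2., 2.]

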
def get_layerwise_qformer(num_heads=8, config=None, **kwargs):
    """
    Build a LayerwiseQFormer instance.
    Args:
        num_heads (int): Number of attention heads for CrossAttentionBlock.
        config: Configuration object; must contain config.framework.layer_qformer with:
            - qformer_start_layer / qformer_end_layer: range of layers (start inclusive, end exclusive).
            - num_query_tokens: Number of learnable query tokens.
            - input_dim: Input feature dimension (Din).
            - output_dim: Output feature dimension (Dout).
        **kwargs: Reserved for future extensions (unused).
    Returns:
        LayerwiseQFormer: Instantiated model.
    Notes:
        - num_layers = end_layer - start_layer (half-open interval).
        - Does not perform weight loading or device moves here.
    """
    qformer_cfg = config.framework.layer_qformer
    num_layers = qformer_cfg.qformer_end_layer - qformer_cfg.qformer_start_layer
    num_query_tokens = qformer_cfg.num_query_tokens
    input_hidden_dim = qformer_cfg.input_dim
    output_hidden_dim = qformer_cfg.output_dim

    qformer = LayerwiseQFormer(
        input_hidden_dim=input_hidden_dim,
        output_hidden_dim=output_hidden_dim,
        num_query_tokens=num_query_tokens,
        num_layers=num_layers,
        num_heads=num_heads,
        config=config,
    )
    return qformer
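

# Hedged smoke test: builds a throwaway config that mirrors the fields read by
# get_layerwise_qformer (real project configs may differ) and checks the output shape
# with random inputs. All sizes below are illustrative, not values from a real config.
if __name__ == "__main__":
    from types import SimpleNamespace

    layer_qformer_cfg = SimpleNamespace(
        qformer_start_layer=0,
        qformer_end_layer=4,  # -> num_layers = 4
        num_query_tokens=64,
        input_dim=2048,
        output_dim=768,
    )
    demo_config = SimpleNamespace(framework=SimpleNamespace(layer_qformer=layer_qformer_cfg))

    qformer = get_layerwise_qformer(num_heads=8, config=demo_config)

    B, L = 2, 16
    hidden_states_list = [torch.randn(B, L, 2048) for _ in range(4)]
    attention_mask = torch.ones(B, L, dtype=torch.bool)  # all tokens visible

    queries = qformer(hidden_states_list, attention_mask)
    print(queries.shape)  # expected: torch.Size([2, 64, 768])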