import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist


class CrossAttentionBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, batch_first=True, dropout=dropout
        )
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, int(hidden_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(hidden_dim * mlp_ratio), hidden_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, encoder_hidden_state, encoder_attention_mask=None):
        """
        Cross-attention block forward.

        Args:
            query (Tensor): Shape [B, Q, D]. Learnable query tokens propagated across layers.
            encoder_hidden_state (Tensor): Shape [B, L, D]. Features from one encoder layer.
            encoder_attention_mask (Tensor | None): Shape [B, L]. 1/True = keep (visible), 0/False = mask. None disables masking.

        Returns:
            Tensor: Updated query tokens of shape [B, Q, D].

        Details:
            1. LayerNorm + MultiheadAttention (Q = query, K/V = encoder_hidden_state).
            2. Residual path: query = query + attn_output, then add the MLP residual.
            3. Dropout is applied only to the MLP output.
        """
        q = self.norm1(query)
        kv = encoder_hidden_state
        if encoder_attention_mask is not None:
            # nn.MultiheadAttention expects key_padding_mask of shape [B, L] with True = positions to ignore,
            # so invert the 1/True = keep convention documented above.
            key_padding_mask = ~encoder_attention_mask.to(dtype=torch.bool)
        else:
            key_padding_mask = None
        attn_output, _ = self.cross_attn(q, kv, kv, key_padding_mask=key_padding_mask)
        query = query + attn_output
        query = query + self.dropout(self.mlp(self.norm2(query)))
        return query
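
# Usage sketch (illustrative only; the shapes below are assumptions, not values from this
# repo's configs). It shows how one block consumes query tokens plus a single encoder layer:
#
#   block = CrossAttentionBlock(hidden_dim=768, num_heads=8)
#   query = torch.randn(2, 64, 768)               # [B=2, Q=64, D=768] query tokens
#   feats = torch.randn(2, 197, 768)              # [B=2, L=197, D=768] one encoder layer
#   mask = torch.ones(2, 197, dtype=torch.bool)   # 1/True = visible token
#   out = block(query, feats, mask)               # -> [2, 64, 768]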


class LayerwiseQFormer(nn.Module):
    def __init__(
        self, input_hidden_dim=2048, output_hidden_dim=768, num_query_tokens=64, num_layers=37, num_heads=8, config=None
    ):
        super().__init__()
        self.input_hidden_dim = input_hidden_dim
        self.output_hidden_dim = output_hidden_dim
        self.num_query_tokens = num_query_tokens
        self.num_layers = num_layers
        self.config = config
        # Project encoder features to the output dimension
        self.proj = nn.Linear(input_hidden_dim, output_hidden_dim)
        # Learnable query tokens
        self.query_tokens = nn.Parameter(torch.randn(num_query_tokens, output_hidden_dim))
        # Independent cross-attention blocks (one per encoder layer)
        self.layers = nn.ModuleList([CrossAttentionBlock(output_hidden_dim, num_heads) for _ in range(num_layers)])

    def forward(self, hidden_states_list, encoder_attention_mask=None):
        """
        Layer-wise Q-Former forward pass.

        Args:
            hidden_states_list (List[Tensor]): Length == num_layers. Each tensor is [B, L, Din], raw encoder layer outputs (before projection).
            encoder_attention_mask (Tensor | None): Shape [B, L]. Same semantics as in CrossAttentionBlock.

        Returns:
            Tensor: Aggregated query tokens of shape [B, Q, Dout].

        Pipeline:
            1. Stack per-layer features to [B, N, L, Din] and linearly project to Dout.
            2. Expand the global learnable query tokens to the batch: [B, Q, Dout].
            3. Apply cross-attention layer by layer: the queries attend only to the corresponding encoder layer's features.

        Notes:
            - Asserts len(hidden_states_list) == num_layers.
            - Does not modify the gradient flow of hidden_states_list.
        """
        # hidden_states_list = self.scale_hook(hidden_states_list)
        assert (
            len(hidden_states_list) == self.num_layers
        ), f"Expected {self.num_layers} layers, got {len(hidden_states_list)}"
        B = hidden_states_list[0].size(0)
        # Stack per-layer hidden states: [B, N, L, Din]
        hs = torch.stack(hidden_states_list, dim=1)
        # Project to the output dimension: [B, N, L, Dout]
        proj_hs = self.proj(hs)
        # Unbind back to a list; each element is restored to [B, L, Dout]
        hidden_states_list = list(proj_hs.unbind(dim=1))
        # Expand query tokens for each batch element
        query = self.query_tokens.unsqueeze(0).expand(B, -1, -1)  # [B, Q, D]
        # Iterate through the layers and apply cross-attention
        for i, layer in enumerate(self.layers):
            query = layer(query, hidden_states_list[i], encoder_attention_mask)
        return query

    def scale_hook(self, hidden_states_list, scale_factor=0.1):
        """
        (Experimental / optional) Register gradient-scaling hooks on each layer's hidden states.

        Args:
            hidden_states_list (List[Tensor]): Per-layer feature tensors.
            scale_factor (float): Default gradient scaling factor; overridden by the configured grad_scale when enabled.

        Returns:
            List[Tensor]: The original list (no data copy); hooks may be attached in place.

        Design:
            - Returns the list unchanged unless config.vla.layer_qformer.grad_scale is set and != 1; the call site in forward is commented out, so this is currently a placeholder.
            - Uses the attribute _scaled_hook to avoid duplicate hook registration in distributed settings.
            - Can be enabled later for gradient dampening or perturbation experiments.

        Performance:
            - Excessive hook registrations can hurt speed; kept lazy by default.
        """
        # --- 1. Register gradient-scaling hooks on the input hidden_states_list ---
        if (
            self.config
            and hasattr(self.config, "vla")
            and hasattr(self.config.vla, "layer_qformer")
            and hasattr(self.config.vla.layer_qformer, "grad_scale")
            and self.config.vla.layer_qformer.grad_scale != 1
        ):
            scale_factor = self.config.vla.layer_qformer.grad_scale
        else:
            return hidden_states_list  # If grad_scale is not configured, return the original list
        for hidden_states in hidden_states_list:
            if hidden_states.requires_grad:
                # Ensure gradient scaling is registered only once in distributed settings
                if not hasattr(hidden_states, "_scaled_hook"):  # Prevent duplicate registration --> seems to accelerate
                    hidden_states.register_hook(lambda grad: grad * scale_factor)
                    hidden_states._scaled_hook = True  # Mark as processed
        return hidden_states_list
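
    # Gradient-scaling sketch (illustrative toy values, not from this repo): register_hook
    # multiplies the incoming gradient by a factor during backward without changing the
    # forward activations, which is the same pattern used above.
    #
    #   x = torch.ones(3, requires_grad=True)
    #   y = x * 2.0
    #   y.register_hook(lambda grad: grad * 0.1)
    #   y.sum().backward()
    #   # x.grad is now 0.2 per element instead of 2.0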


def get_layerwise_qformer(num_heads=8, config=None, **kwargs):
    """
    Build a LayerwiseQFormer instance.

    Args:
        num_heads (int): Number of attention heads for each CrossAttentionBlock.
        config: Configuration object; must contain config.framework.layer_qformer with:
            - qformer_start_layer / qformer_end_layer: range of encoder layers (start inclusive, end exclusive).
            - num_query_tokens: Number of learnable query tokens.
            - input_dim: Input feature dimension (Din).
            - ouptput_dim: Output feature dimension (Dout).
        **kwargs: Reserved for future extensions (unused).

    Returns:
        LayerwiseQFormer: Instantiated model.

    Notes:
        - num_layers = qformer_end_layer - qformer_start_layer (half-open interval).
        - Does not perform weight loading or device moves here.
    """
    # dist.barrier()
    qformer_cfg = config.framework.layer_qformer
    num_layers = qformer_cfg.qformer_end_layer - qformer_cfg.qformer_start_layer
    num_query_tokens = qformer_cfg.num_query_tokens
    input_hidden_dim = qformer_cfg.input_dim
    output_hidden_dim = qformer_cfg.ouptput_dim
    qformer = LayerwiseQFormer(
        input_hidden_dim=input_hidden_dim,
        output_hidden_dim=output_hidden_dim,
        num_query_tokens=num_query_tokens,
        num_layers=num_layers,
        num_heads=num_heads,
        config=config,
    )
    return qformer
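

# Minimal smoke test (a sketch under assumed shapes and layer count; the values below are
# illustrative and not taken from any config in this repo).
if __name__ == "__main__":
    num_layers, B, L = 4, 2, 197
    model = LayerwiseQFormer(
        input_hidden_dim=2048,
        output_hidden_dim=768,
        num_query_tokens=64,
        num_layers=num_layers,
        num_heads=8,
    )
    dummy_layers = [torch.randn(B, L, 2048) for _ in range(num_layers)]
    mask = torch.ones(B, L, dtype=torch.bool)
    out = model(dummy_layers, mask)
    print(out.shape)  # expected: torch.Size([2, 64, 768])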