from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Beta
from transformers import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature

from starVLA.model.modules.action_model.flow_matching_head.action_encoder import (
    SinusoidalPositionalEncoding,
    swish,
)
from starVLA.model.modules.action_model.flow_matching_head.cross_attention_dit import DiT, SelfAttentionTransformer


class CategorySpecificLinear(nn.Module):
    """A linear layer that keeps a separate weight matrix and bias per category."""

    def __init__(self, num_categories, input_dim, hidden_dim):
        super().__init__()
        self.num_categories = num_categories
        # Small-std init keeps every per-category projection close to zero at the start.
        self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
        self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))

    def forward(self, x, cat_ids):
        # x: (B, T, input_dim); cat_ids: (B,) integer category index per batch item.
        selected_W = self.W[cat_ids]  # (B, input_dim, hidden_dim)
        selected_b = self.b[cat_ids]  # (B, hidden_dim)
        return torch.bmm(x, selected_W) + selected_b.unsqueeze(1)
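
# Illustrative usage (shapes only; not part of the pipeline):
#
#   layer = CategorySpecificLinear(num_categories=4, input_dim=8, hidden_dim=16)
#   x = torch.randn(3, 5, 8)            # (B=3, T=5, input_dim=8)
#   cat_ids = torch.tensor([0, 2, 2])   # one category id per batch item
#   out = layer(x, cat_ids)             # (3, 5, 16)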


class CategorySpecificMLP(nn.Module):
    """Two-layer MLP built from CategorySpecificLinear layers, with a ReLU in between."""

    def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.num_categories = num_categories
        self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
        self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)

    def forward(self, x, cat_ids):
        hidden = F.relu(self.layer1(x, cat_ids))
        return self.layer2(hidden, cat_ids)


class MLP(nn.Module):
    """Plain two-layer MLP with a ReLU in between."""

    def __init__(self, input_dim, hidden_dim=1024, output_dim=2048):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.layer2(F.relu(self.layer1(x)))


class ActionEncoder(nn.Module):
    """Embeds a noisy action chunk together with its flow-matching timestep."""

    def __init__(self, action_dim, hidden_size=1024):
        super().__init__()
        self.hidden_size = hidden_size
        self.action_dim = action_dim
        self.layer1 = nn.Linear(action_dim, hidden_size)
        self.layer2 = nn.Linear(2 * hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, hidden_size)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps):
        """
        actions: shape (B, T, action_dim)
        timesteps: shape (B,) -- a single scalar timestep per batch item
        returns: shape (B, T, hidden_size)
        """
        B, T, _ = actions.shape

        # Replicate each batch item's timestep across the T action steps.
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # Project actions into the hidden space.
        a_emb = self.layer1(actions)

        # Sinusoidal embedding of the timesteps, matched to a_emb's dtype.
        tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)

        # Fuse action and timestep embeddings.
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = swish(self.layer2(x))
        x = self.layer3(x)
        return x
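
# Illustrative shapes (assuming SinusoidalPositionalEncoding maps (B, T) integer
# timesteps to (B, T, hidden_size) embeddings, as its use above implies):
#
#   enc = ActionEncoder(action_dim=7, hidden_size=1024)
#   a = torch.randn(2, 16, 7)          # (B, T, action_dim)
#   t = torch.randint(0, 1000, (2,))   # one timestep per batch item
#   feats = enc(a, t)                  # (2, 16, 1024)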


class MultiEmbodimentActionEncoder(nn.Module):
    """ActionEncoder variant with per-embodiment projection weights."""

    def __init__(self, action_dim, hidden_size=1024, num_embodiments=8):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_embodiments = num_embodiments
        self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)
        self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)
        self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps, cat_ids):
        """
        actions: shape (B, T, action_dim)
        timesteps: shape (B,) -- a single scalar timestep per batch item
        cat_ids: shape (B,) -- embodiment index per batch item
        returns: shape (B, T, hidden_size)
        """
        B, T, _ = actions.shape

        # Replicate each batch item's timestep across the T action steps.
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        a_emb = self.W1(actions, cat_ids)
        tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = swish(self.W2(x, cat_ids))
        x = self.W3(x, cat_ids)
        return x
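
# Illustrative usage mirroring ActionEncoder, with an extra embodiment id:
#
#   enc = MultiEmbodimentActionEncoder(action_dim=7, hidden_size=1024, num_embodiments=8)
#   feats = enc(torch.randn(2, 16, 7), torch.randint(0, 1000, (2,)), torch.tensor([0, 3]))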


@dataclass
class FlowmatchingActionHeadConfig(PretrainedConfig):
    """NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""

    add_pos_embed: bool = field(
        default=True, metadata={"help": "Whether to add positional embedding."}
    )
    diffusion_model_cfg: dict = field(
        default=None, metadata={"help": "Diffusion model configuration."}
    )
    input_embedding_dim: int = field(
        default=1536, metadata={"help": "Input embedding channel dimension."}
    )

    hidden_size: int = field(default=1024, metadata={"help": "Hidden embedding dimension."})
    max_seq_len: int = field(default=1024, metadata={"help": "Maximum sequence length."})
    action_dim: int = field(default=None, metadata={"help": "Action dimension."})
    action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
    noise_beta_alpha: float = field(
        default=1.5, metadata={"help": "Alpha of the Beta distribution used to sample flow timesteps."}
    )
    noise_beta_beta: float = field(
        default=1.0, metadata={"help": "Beta of the Beta distribution used to sample flow timesteps."}
    )
    noise_s: float = field(
        default=0.999, metadata={"help": "Flow matching noise Beta distribution s."}
    )
    num_timestep_buckets: int = field(
        default=1000, metadata={"help": "Number of timestep discretization buckets."}
    )
    num_inference_timesteps: int = field(
        default=None,
        metadata={"help": "Number of inference steps for noise diffusion."},
    )
    max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
    tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
    tune_diffusion_model: bool = field(
        default=True, metadata={"help": "Whether to tune the diffusion model."}
    )
    load_pretrained_det_decode_layer_path: str = field(
        default=None, metadata={"help": "Path to pretrained detection model."}
    )
    detection_coeff: float = field(default=1.0, metadata={"help": "Detection loss coefficient."})

    freeze_decode_layer: bool = field(default=False)
    expand_batch: int = field(default=None)
    use_vlln: bool = field(default=True)

    vl_self_attention_cfg: dict = field(default=None)
    num_target_vision_tokens: int = field(
        default=32, metadata={"help": "Number of target vision tokens."}
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for key, value in kwargs.items():
            setattr(self, key, value)
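
# Illustrative construction (values are placeholders, not recommended defaults):
#
#   cfg = FlowmatchingActionHeadConfig(
#       action_dim=7,
#       action_horizon=16,
#       num_inference_timesteps=4,
#       diffusion_model_cfg={"attention_head_dim": 64},
#   )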


# Default DiT geometry. num_layers, input_embedding_dim, and num_attention_heads
# are overwritten from the VLM config in LayerwiseFlowmatchingActionHead.__init__.
DiTConfig = {"num_layers": 36, "input_embedding_dim": 2048, "attention_head_dim": 64, "num_attention_heads": 32}
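
# For example, with vl_hidden_dim = 2048 and attention_head_dim = 64, the head
# count becomes 2048 // 64 = 32 (hypothetical values; the real ones come from
# global_config.framework.qwenvl at runtime).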


class LayerwiseFlowmatchingActionHead(nn.Module):
    """Flow-matching action head whose DiT blocks cross-attend to one VLM layer each."""

    def __init__(
        self,
        global_config,
        **kwargs,
    ):
        super().__init__()
        action_config = global_config.framework.action_model
        diffusion_model_cfg = action_config.diffusion_model_cfg

        # Tie the DiT geometry to the VLM backbone: one DiT block per VL layer,
        # matching hidden width, and a head count derived from the head dim.
        DiTConfig["num_layers"] = global_config.framework.qwenvl.num_vl_layers
        DiTConfig["input_embedding_dim"] = global_config.framework.qwenvl.vl_hidden_dim
        DiTConfig["num_attention_heads"] = DiTConfig["input_embedding_dim"] // DiTConfig["attention_head_dim"]
        diffusion_model_cfg.update(DiTConfig)
        diffusion_model_cfg.cross_attention_dim = DiTConfig["input_embedding_dim"]

        self.input_embedding_dim = global_config.framework.qwenvl.vl_hidden_dim
        self.model = DiT(**diffusion_model_cfg)
        self.dit_out_hidden_size = self.input_embedding_dim
        self.action_dim = action_config.action_dim
        self.action_horizon = action_config.future_action_window_size + 1
        self.num_inference_timesteps = action_config.num_inference_timesteps

        # Optional proprioceptive state encoder.
        self.state_encoder = (
            MLP(
                input_dim=action_config.state_dim,
                output_dim=self.input_embedding_dim,
            )
            if action_config.state_dim
            else None
        )

        self.action_encoder = ActionEncoder(
            action_dim=action_config.action_dim,
            hidden_size=self.input_embedding_dim,
        )
        self.action_decoder = MLP(
            input_dim=self.input_embedding_dim,
            hidden_dim=1024,
            output_dim=self.action_dim,
        )
        # Learnable query tokens for future (target) vision content.
        self.future_tokens = nn.Embedding(action_config.num_target_vision_tokens, self.input_embedding_dim)
        nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)

        if action_config.add_pos_embed:
            self.position_embedding = nn.Embedding(action_config.max_seq_len, self.input_embedding_dim)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

        self.beta_dist = Beta(action_config.noise_beta_alpha, action_config.noise_beta_beta)
        self.num_timestep_buckets = action_config.num_timestep_buckets
        self.config = action_config

    def sample_time(self, batch_size, device, dtype):
        # Sample u ~ Beta(alpha, beta) and map it to t = (s - u) / s.
        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        return (self.config.noise_s - sample) / self.config.noise_s
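
    # With s = 0.999, t ranges over [(s - 1) / s, 1] ≈ [-0.001, 1]; e.g. a draw of
    # u = 0.5 gives t = (0.999 - 0.5) / 0.999 ≈ 0.4995.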

    def prepare_input(self, batch: dict) -> BatchFeature:
        return BatchFeature(data=batch)
    def forward(self, vl_embs_list: list, actions: torch.Tensor, state: torch.Tensor = None, encoder_attention_mask: torch.Tensor = None):
        """
        vl_embs_list: list of torch.Tensor, one per VLM layer, each of shape (B, seq_length, feature_dim)
        actions: shape (B, action_horizon, action_dim)
        state: optional (B, state_dim) proprioceptive state
        encoder_attention_mask: optional (B, seq_length) mask for VLM padding tokens
        """
        device = actions.device
        B, L, D = vl_embs_list[0].shape

        # Sample noise and a flow-matching time t, then interpolate:
        # x_t = (1 - t) * noise + t * actions, with target velocity actions - noise.
        noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
        t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
        t = t[:, None, None]

        noisy_trajectory = (1 - t) * noise + t * actions
        velocity = actions - noise

        # Discretize t into bucket indices for the timestep embedders.
        t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
        action_features = self.action_encoder(noisy_trajectory, t_discretized)

        # Optional state token(s).
        state_features = self.state_encoder(state) if state is not None else None

        if self.config.add_pos_embed:
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
            action_features = action_features + pos_embs

        # Assemble the DiT input sequence: [state] + future tokens + noisy actions.
        future_tokens = self.future_tokens.weight.unsqueeze(0).expand(B, -1, -1)
        sa_embs = (
            torch.cat((state_features, future_tokens, action_features), dim=1)
            if state_features is not None
            else torch.cat((future_tokens, action_features), dim=1)
        )

        temb = self.model.timestep_encoder(t_discretized)

        if encoder_attention_mask is not None:
            encoder_attention_mask = encoder_attention_mask.bool()

        # Layerwise cross-attention: block i attends to the i-th VLM layer's embeddings.
        model_output = sa_embs
        for layer_idx, layer in enumerate(self.model.transformer_blocks):
            model_output = layer(
                hidden_states=model_output,
                encoder_hidden_states=vl_embs_list[layer_idx],
                temb=temb,
                encoder_attention_mask=encoder_attention_mask,
            )

        # Decode only the trailing action positions and regress the velocity.
        pred = self.action_decoder(model_output)
        pred_actions = pred[:, -actions.shape[1]:]

        loss = ((pred_actions - velocity) ** 2).mean()
        return loss
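
    # The objective above is standard flow matching: sample t, form
    # x_t = (1 - t) * x_0 + t * x_1 with x_0 ~ N(0, I), and regress the constant
    # velocity x_1 - x_0. A self-contained sketch of just that objective
    # (`model` is any hypothetical velocity predictor, not this class):
    #
    #   def flow_matching_loss(model, x1):
    #       x0 = torch.randn_like(x1)
    #       t = torch.rand(x1.shape[0], device=x1.device)[:, None, None]
    #       xt = (1 - t) * x0 + t * x1
    #       return ((model(xt, t) - (x1 - x0)) ** 2).mean()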

    @torch.no_grad()
    def predict_action(self, vl_embs_list: list, state: torch.Tensor = None, encoder_attention_mask: torch.Tensor = None) -> torch.Tensor:
        # Start from pure noise and integrate the predicted velocity field to t = 1.
        batch_size = vl_embs_list[0].shape[0]
        device = vl_embs_list[0].device
        actions = torch.randn(
            size=(batch_size, self.action_horizon, self.action_dim),
            dtype=vl_embs_list[0].dtype,
            device=device,
        )

        num_steps = self.num_inference_timesteps
        dt = 1.0 / num_steps

        state_features = self.state_encoder(state) if state is not None else None

        if encoder_attention_mask is not None:
            encoder_attention_mask = encoder_attention_mask.bool()

        for t in range(num_steps):
            t_cont = t / float(num_steps)
            t_discretized_int = int(t_cont * self.num_timestep_buckets)
            timesteps_tensor = torch.full(
                size=(batch_size,), fill_value=t_discretized_int, device=device, dtype=torch.long
            )

            action_features = self.action_encoder(actions, timesteps_tensor)

            if self.config.add_pos_embed:
                pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
                pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
                action_features = action_features + pos_embs

            future_tokens = self.future_tokens.weight.unsqueeze(0).expand(batch_size, -1, -1)
            sa_embs = (
                torch.cat((state_features, future_tokens, action_features), dim=1)
                if state_features is not None
                else torch.cat((future_tokens, action_features), dim=1)
            )

            temb = self.model.timestep_encoder(timesteps_tensor)

            # Layerwise cross-attention, mirroring forward().
            model_output = sa_embs
            for layer_idx, layer in enumerate(self.model.transformer_blocks):
                model_output = layer(
                    hidden_states=model_output,
                    encoder_hidden_states=vl_embs_list[layer_idx],
                    temb=temb,
                    encoder_attention_mask=encoder_attention_mask,
                )

            pred = self.action_decoder(model_output)
            pred_velocity = pred[:, -self.action_horizon:]

            # Forward Euler step: x_{k+1} = x_k + dt * v(x_k, t_k).
            actions = actions + dt * pred_velocity
        return actions
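
    # Note: inference is plain forward Euler on dx/dt = v(x, t) from t = 0 (noise)
    # to t = 1 (actions); e.g. with num_inference_timesteps = 4 the solver
    # evaluates the velocity at t = 0, 0.25, 0.5, 0.75 with dt = 0.25.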

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype


def get_action_model(config=None):
    """
    Factory: build a LayerwiseFlowmatchingActionHead from the global framework config.

    Args:
        config: Global config (expects the config.framework.action_model and
            config.framework.qwenvl namespaces).

    Returns:
        LayerwiseFlowmatchingActionHead: the initialized action head.
    """
    return LayerwiseFlowmatchingActionHead(global_config=config)


if __name__ == "__main__":
    pass
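    # A minimal smoke-test sketch (assumes a stub global_config exposing the
    # framework.action_model and framework.qwenvl attributes read in __init__;
    # left as comments because those configs live elsewhere in starVLA):
    #
    #   head = get_action_model(my_config)  # my_config: hypothetical global config
    #   n_layers = len(head.model.transformer_blocks)
    #   vl_embs_list = [torch.randn(2, 64, head.input_embedding_dim) for _ in range(n_layers)]
    #   actions = torch.randn(2, head.action_horizon, head.action_dim)
    #   loss = head(vl_embs_list, actions)
    #   chunk = head.predict_action(vl_embs_list)  # (2, action_horizon, action_dim)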