| """ | |
| Ablation Projector: A configurable projector for ablation studies based on C2CProjector. | |
| Allows gradual removal of components to study their individual contributions. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from typing import Optional, Tuple, Literal | |
| from rosetta.utils.registry import register_model, capture_init_args | |
| from rosetta.model.projector import Projector | |
| from rosetta.model.projector import RegularMLP | |
class AblationProjector(Projector):
    """
    Ablation study projector based on C2CProjector with configurable component removal.

    Ablation levels:
        0. Full C2C (baseline)
        1. Remove scalar weights (set to 1.0)
        2. Remove gates (set to 1.0), in addition to level 1
        3. Remove the target contribution (source-only), in addition to level 2
        4. Remove gates only (gates=1.0); keep scalar weights and the target contribution

    Levels 1-3 build cumulatively on the previous level, allowing a gradual
    degradation study; level 4 is a standalone variant.
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        source_num_heads: int = 1,
        target_num_heads: int = 1,
        intermediate_dim: int = 1024,
        hidden_dim: int = 1024,
        num_layers: int = 3,
        dropout: float = 0.1,
        initial_temperature: float = 1.0,
        final_temperature: float = 0.001,
        anneal_steps: int = 1929,
        dtype: torch.dtype = torch.float32,
        # Ablation configuration
        ablation_level: int = 0,  # 0=full, 1=no_scalar, 2=no_gate+no_scalar, 3=no_target, 4=no_gate_only
        use_scalar_weights: bool = True,  # Can be overridden by ablation_level
        use_gates: bool = True,  # Can be overridden by ablation_level
        use_target: bool = True,  # Can be overridden by ablation_level
    ):
        super().__init__()
        assert 0 <= ablation_level <= 4, "ablation_level must be 0, 1, 2, 3, or 4"
        # Dimensions
        self.source_dim = source_dim
        self.target_dim = target_dim
        self.source_num_heads = source_num_heads
        self.target_num_heads = target_num_heads
        self.ablation_level = ablation_level
        # Override component usage based on ablation level
        if ablation_level == 4:
            # Special case: disable gates only, keep scalars and target
            use_scalar_weights = True
            use_gates = False
            use_target = True
        else:
            if ablation_level >= 1:
                use_scalar_weights = False
            if ablation_level >= 2:
                use_gates = False
            if ablation_level >= 3:
                use_target = False
        self.use_scalar_weights = use_scalar_weights
        self.use_gates = use_gates
        self.use_target = use_target
        # Sizes
        in_dim = source_dim * source_num_heads
        out_dim = target_dim * target_num_heads
        # 1) concat(source_X, target_X), then project to hidden_dim.
        #    If the target path is ablated, only source features are used.
        if self.use_target:
            self.key_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype)
            self.value_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype)
        else:
            # Only use source features
            self.key_in = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype)
            self.value_in = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype)
        # 2) one-layer common embedding MLP producing the intermediate representation (at hidden_dim)
        self.key_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)
        self.value_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)
        # 3a) intermediate representation -> one-layer MLP for scalar weights -> project to head dim.
        #     Only built when scalar weights are enabled.
        if self.use_scalar_weights:
            self.key_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
            self.value_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
            self.key_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)
            self.value_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)
        # 3b) intermediate representation -> (num_layers - 2)-layer MLP for projected_X -> final projection hidden_dim -> out_dim
        self.key_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers - 2, dropout=dropout, dtype=dtype)
        self.value_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers - 2, dropout=dropout, dtype=dtype)
        self.key_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)
        self.value_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)
        # Scalar key/value gate parameters and temperature schedule.
        # Only built when gates are enabled.
        if self.use_gates:
            self.key_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
            self.value_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
            self.use_gumbel = True
            self.register_buffer("gate_temperature", torch.tensor(initial_temperature, dtype=dtype))
        self.initial_temperature = initial_temperature
        self.final_temperature = final_temperature
        self.anneal_steps = anneal_steps
        # Temperature for weight normalization
        self.scalar_temperature = 1.0

    def update_temperature(self, step: int):
        """Update the gate temperature using an exponential annealing schedule."""
        if self.use_gates:
            ratio = min(step / self.anneal_steps, 1.0)
            temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
            self.gate_temperature.fill_(temp)
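
    # Schedule illustration (added note, not in the original): with the defaults
    # initial_temperature=1.0, final_temperature=0.001, anneal_steps=1929, the
    # schedule reduces to temp(step) = 0.001 ** min(step / 1929, 1.0), e.g.
    # temp(0) = 1.0, temp(965) ~= 0.032, temp(1929) = 0.001, after which the
    # temperature stays clamped at final_temperature.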

    def forward(
        self,
        source_kv: Tuple[Tensor, Tensor],
        target_kv: Tuple[Tensor, Tensor],
        position_ids: Optional[Tensor] = None,
        max_pos: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        source_key, source_value = source_kv
        target_key, target_value = target_kv
        B, Hs, N, Ds = source_key.shape
        _, Ht, _, Dt = target_key.shape
        # Flatten heads
        source_key_flat = source_key.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
        source_value_flat = source_value.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
        target_key_flat = target_key.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
        target_value_flat = target_value.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
        # 1) Prepare input features based on ablation level
        if self.use_target:
            # Full C2C: concat source and target features
            key_cat = torch.cat([source_key_flat, target_key_flat], dim=-1)
            value_cat = torch.cat([source_value_flat, target_value_flat], dim=-1)
        else:
            # Ablation level 3: only use source features
            key_cat = source_key_flat
            value_cat = source_value_flat
        # 2) Project to hidden dim
        key_hidden = self.key_in(key_cat)
        value_hidden = self.value_in(value_cat)
        # 3) One-layer common embedding MLP to get the intermediate representation (at hidden_dim)
        key_hidden = self.key_mlp1(key_hidden)
        value_hidden = self.value_mlp1(value_hidden)
        # 4b) Intermediate representation -> projected feature path
        key_proj_hidden = self.key_proj_out(self.key_proj_mlp2(key_hidden))  # (B, N, Ht * Dt)
        value_proj_hidden = self.value_proj_out(self.value_proj_mlp2(value_hidden))  # (B, N, Ht * Dt)
        projected_key = key_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2)  # (B, Ht, N, Dt)
        projected_value = value_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2)  # (B, Ht, N, Dt)
        # 4a) Intermediate representation -> scalar path (if using scalar weights)
        if self.use_scalar_weights:
            key_scalar = self.key_scalar_head(self.key_scalar_mlp2(key_hidden))  # (B, N, Ht)
            value_scalar = self.value_scalar_head(self.value_scalar_mlp2(value_hidden))  # (B, N, Ht)
            key_scalar = key_scalar.permute(0, 2, 1).unsqueeze(-1)  # (B, Ht, N, 1)
            value_scalar = value_scalar.permute(0, 2, 1).unsqueeze(-1)  # (B, Ht, N, 1)
            # Normalize scalars into (0, 1)
            norm_key_scalar = torch.sigmoid(key_scalar)
            norm_value_scalar = torch.sigmoid(value_scalar)
        else:
            # Ablation level 1+: set scalar weights to 1.0
            norm_key_scalar = torch.ones(B, Ht, N, 1, device=projected_key.device, dtype=projected_key.dtype)
            norm_value_scalar = torch.ones(B, Ht, N, 1, device=projected_value.device, dtype=projected_value.dtype)
        # Key/value gates (if using gates)
        if self.use_gates:
            key_gate_logit = self.key_gate_logit.view(1, 1, 1, 1)
            value_gate_logit = self.value_gate_logit.view(1, 1, 1, 1)
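            # Background note (added, not in the original): adding Gumbel noise
            # g = -log(-log(u)), u ~ Uniform(0, 1), to a logit and applying a
            # temperature-scaled sigmoid yields a differentiable relaxation of a
            # Bernoulli gate that hardens toward 0/1 as the temperature anneals.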
            if self.training and self.use_gumbel:
                u1 = torch.rand(B, Ht, N, 1, device=key_gate_logit.device, dtype=key_gate_logit.dtype)
                u2 = torch.rand(B, Ht, N, 1, device=value_gate_logit.device, dtype=value_gate_logit.dtype)
                g1 = -torch.log(-torch.log(u1 + 1e-20) + 1e-20)
                g2 = -torch.log(-torch.log(u2 + 1e-20) + 1e-20)
                key_gate = torch.sigmoid((key_gate_logit + g1) / self.gate_temperature)
                value_gate = torch.sigmoid((value_gate_logit + g2) / self.gate_temperature)
            else:
                # Hard 0/1 gates at eval time; cast to the logit dtype (rather than
                # .float()) so the output dtype is not silently promoted to float32
                key_gate = (key_gate_logit > 0).to(key_gate_logit.dtype)
                value_gate = (value_gate_logit > 0).to(value_gate_logit.dtype)
        else:
            # Gates disabled: set gates to 1.0 (always open)
            key_gate = torch.ones(B, Ht, N, 1, device=projected_key.device, dtype=projected_key.dtype)
            value_gate = torch.ones(B, Ht, N, 1, device=projected_value.device, dtype=projected_value.dtype)
        # Compute projected contribution
        projected_key_term = key_gate * norm_key_scalar * projected_key
        projected_value_term = value_gate * norm_value_scalar * projected_value
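        # Combination rule (added note): output = target + gate * sigmoid(scalar) * projected,
        # where each ablated component is replaced by the constant 1.0 (or, for the
        # target path, dropped entirely).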
        # Compute target contribution (if using target)
        if self.use_target:
            # Full C2C: add the target KV to the projected contribution
            output_key = target_key + projected_key_term
            output_value = target_value + projected_value_term
        else:
            # Ablation level 3: projected contribution only (no target)
            output_key = projected_key_term
            output_value = projected_value_term
        return output_key, output_value

    def get_ablation_info(self) -> dict:
        """Return information about the current ablation configuration."""
        return {
            'ablation_level': self.ablation_level,
            'use_scalar_weights': self.use_scalar_weights,
            'use_gates': self.use_gates,
            'use_target': self.use_target,
            'description': self._get_ablation_description(),
        }

    def _get_ablation_description(self) -> str:
        """Get a human-readable description of the current ablation level."""
        descriptions = {
            0: "Full C2C (baseline)",
            1: "No scalar weights (scalars=1.0)",
            2: "No gates (gates=1.0) + no scalar weights",
            3: "No target (source-only) + no gates + no scalar weights",
            4: "No gates (gates=1.0), keep scalars and target",
        }
        return descriptions.get(self.ablation_level, "Unknown ablation level")


# Convenience functions for creating specific ablation levels
def create_ablation_projector(
    source_dim: int,
    target_dim: int,
    source_num_heads: int = 1,
    target_num_heads: int = 1,
    ablation_level: int = 0,
    **kwargs,
) -> AblationProjector:
    """Create an AblationProjector with the specified ablation level."""
    return AblationProjector(
        source_dim=source_dim,
        target_dim=target_dim,
        source_num_heads=source_num_heads,
        target_num_heads=target_num_heads,
        ablation_level=ablation_level,
        **kwargs,
    )


def create_full_c2c_projector(**kwargs) -> AblationProjector:
    """Create the full C2C projector (ablation level 0)."""
    return create_ablation_projector(ablation_level=0, **kwargs)


def create_no_scalar_projector(**kwargs) -> AblationProjector:
    """Create a projector without scalar weights (ablation level 1)."""
    return create_ablation_projector(ablation_level=1, **kwargs)


def create_no_gate_projector(**kwargs) -> AblationProjector:
    """Create a projector without gates or scalar weights (ablation level 2)."""
    return create_ablation_projector(ablation_level=2, **kwargs)


def create_source_only_projector(**kwargs) -> AblationProjector:
    """Create a source-only projector (ablation level 3)."""
    return create_ablation_projector(ablation_level=3, **kwargs)


def create_no_gate_only_projector(**kwargs) -> AblationProjector:
    """Create a projector without gates but with scalar weights and target (ablation level 4)."""
    return create_ablation_projector(ablation_level=4, **kwargs)
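

# Minimal smoke test (added sketch, not part of the original module): exercises
# every ablation level with small, made-up dimensions. KV tensors follow the
# forward() contract above: (batch, num_heads, seq_len, head_dim).
if __name__ == "__main__":
    B, N = 2, 5
    Hs, Ds = 4, 32  # source heads / head dim (assumed values for illustration)
    Ht, Dt = 2, 64  # target heads / head dim (assumed values for illustration)
    source_kv = (torch.randn(B, Hs, N, Ds), torch.randn(B, Hs, N, Ds))
    target_kv = (torch.randn(B, Ht, N, Dt), torch.randn(B, Ht, N, Dt))
    for level in range(5):
        proj = create_ablation_projector(
            source_dim=Ds,
            target_dim=Dt,
            source_num_heads=Hs,
            target_num_heads=Ht,
            ablation_level=level,
        )
        proj.update_temperature(step=100)  # no-op when gates are ablated
        out_key, out_value = proj(source_kv, target_kv)
        assert out_key.shape == (B, Ht, N, Dt) and out_value.shape == (B, Ht, N, Dt)
        print(level, proj.get_ablation_info()["description"], tuple(out_key.shape))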