# File: C2C_demo/rosetta/model/ablation_projector.py
# Revision: "[init] demo" (5ccf219), committed by fuvty
"""
Ablation Projector: A configurable projector for ablation studies based on C2CProjector.
Allows gradual removal of components to study their individual contributions.
"""
import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Tuple, Literal
from rosetta.utils.registry import register_model, capture_init_args
from rosetta.model.projector import Projector
from rosetta.model.projector import RegularMLP
@register_model
@capture_init_args
class AblationProjector(Projector):
    """
    Ablation study projector based on C2CProjector with configurable component removal.

    Ablation levels:
        0. Full C2C (baseline): scalar weights, gates and target contribution.
        1. Remove scalar weights (set to 1.0).
        2. Remove scalar weights and gates (both set to 1.0).
        3. Remove scalar weights, gates and the target contribution
           (only the source is used).
        4. Remove gates only (gates=1.0); scalar weights and target are kept.

    Levels 0-3 are cumulative: each one removes a component in addition to
    everything the previous level removed. Level 4 is a standalone
    configuration and does NOT build on level 3.
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        source_num_heads: int = 1,
        target_num_heads: int = 1,
        intermediate_dim: int = 1024,
        hidden_dim: int = 1024,
        num_layers: int = 3,
        dropout: float = 0.1,
        initial_temperature: float = 1.0,
        final_temperature: float = 0.001,
        anneal_steps: int = 1929,
        dtype: torch.dtype = torch.float32,
        # Ablation configuration
        ablation_level: int = 0,  # 0=full, 1=no_scalar, 2=no_gate+no_scalar, 3=no_target, 4=no_gate_only
        use_scalar_weights: bool = True,  # Can be overridden by ablation_level
        use_gates: bool = True,  # Can be overridden by ablation_level
        use_target: bool = True,  # Can be overridden by ablation_level
    ):
        """
        Build the projector sub-modules required by the selected ablation.

        Args:
            source_dim: Per-head feature size of the source KV; the flattened
                source width is ``source_dim * source_num_heads``.
            target_dim: Per-head feature size of the target KV; the flattened
                target width is ``target_dim * target_num_heads``.
            source_num_heads: Number of source KV heads.
            target_num_heads: Number of target KV heads (also the number of
                scalar weights predicted per token).
            intermediate_dim: ``intermediate_dim`` passed to the embedding and
                projection ``RegularMLP`` modules.
            hidden_dim: Width of the shared hidden representation.
            num_layers: Total layer budget; the projection MLP is built with
                ``num_layers - 2`` layers.
            dropout: Dropout value passed to every ``RegularMLP``.
            initial_temperature: Gate temperature at step 0.
            final_temperature: Gate temperature once annealing has finished.
            anneal_steps: Number of steps over which the temperature anneals.
            dtype: dtype used for all parameters and the temperature buffer.
            ablation_level: Integer in ``[0, 4]``; see the class docstring.
            use_scalar_weights: Enable the scalar-weight path. Forced to
                ``False`` for levels 1-3 and to ``True`` for level 4.
            use_gates: Enable the key/value gates. Forced to ``False`` for
                levels 2-4.
            use_target: Use target features as input and add the target KV to
                the output. Forced to ``False`` for level 3 and to ``True``
                for level 4.

        Raises:
            AssertionError: If ``ablation_level`` is outside ``[0, 4]``.
        """
        super().__init__()
        assert 0 <= ablation_level <= 4, "ablation_level must be 0, 1, 2, 3, or 4"

        # Dimensions
        self.source_dim = source_dim
        self.target_dim = target_dim
        self.source_num_heads = source_num_heads
        self.target_num_heads = target_num_heads
        self.ablation_level = ablation_level

        # Override component usage based on ablation level
        if ablation_level == 4:
            # Special case: disable gates only, keep scalars and target
            use_scalar_weights = True
            use_gates = False
            use_target = True
        else:
            # Levels 1-3 are cumulative; level 0 keeps the caller's flags as-is.
            if ablation_level >= 1:
                use_scalar_weights = False
            if ablation_level >= 2:
                use_gates = False
            if ablation_level >= 3:
                use_target = False

        self.use_scalar_weights = use_scalar_weights
        self.use_gates = use_gates
        self.use_target = use_target

        # Sizes
        in_dim = source_dim * source_num_heads
        out_dim = target_dim * target_num_heads

        # 1) concat(source_X, target_X) then project to hidden_dim.
        #    If not using target, only the source features are fed in.
        first_layer_in = in_dim + out_dim if self.use_target else in_dim
        self.key_in = nn.Linear(first_layer_in, hidden_dim, bias=True, dtype=dtype)
        self.value_in = nn.Linear(first_layer_in, hidden_dim, bias=True, dtype=dtype)

        # 2) one-layer common embedding MLP to get intermediate representation (at hidden_dim)
        self.key_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)
        self.value_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)

        # 3a) intermediate representation -> one-layer MLP for weights -> project to head dim.
        #     Only built when scalar weights are in use.
        if self.use_scalar_weights:
            self.key_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
            self.value_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
            self.key_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)
            self.value_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)

        # 3b) intermediate representation -> (L-2)-layer MLP for projected_X -> hidden_dim -> out_dim
        self.key_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers - 2, dropout=dropout, dtype=dtype)
        self.value_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers - 2, dropout=dropout, dtype=dtype)
        self.key_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)
        self.value_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)

        # Scalar key/value gate parameters and temperature schedule.
        # Only built when gates are in use.
        if self.use_gates:
            self.key_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
            self.value_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
            self.use_gumbel = True
            self.register_buffer("gate_temperature", torch.tensor(initial_temperature, dtype=dtype))
            self.initial_temperature = initial_temperature
            self.final_temperature = final_temperature
            self.anneal_steps = anneal_steps

        # Temperature for weight normalization
        self.scalar_temperature = 1.0

    def update_temperature(self, step: int):
        """
        Update the gate temperature using an exponential annealing schedule.

        The temperature moves from ``initial_temperature`` to
        ``final_temperature`` over ``anneal_steps`` steps and stays at the
        final value afterwards. This is a no-op when gates are disabled.

        Args:
            step: Current training step.
        """
        if not self.use_gates:
            return
        # Guard against anneal_steps == 0 (or negative), which would otherwise
        # raise ZeroDivisionError: with nothing to anneal over, use the final value.
        if self.anneal_steps > 0:
            ratio = min(step / self.anneal_steps, 1.0)
        else:
            ratio = 1.0
        temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
        self.gate_temperature.fill_(temp)

    def forward(
        self,
        source_kv: Tuple[Tensor, Tensor],
        target_kv: Tuple[Tensor, Tensor],
        position_ids: Optional[Tensor] = None,
        max_pos: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Project the source KV into the target KV space.

        Args:
            source_kv: ``(key, value)`` tensors, each of shape (B, Hs, N, Ds).
            target_kv: ``(key, value)`` tensors, each of shape (B, Ht, N, Dt).
            position_ids: Accepted for interface compatibility; not used here.
            max_pos: Accepted for interface compatibility; not used here.

        Returns:
            ``(output_key, output_value)``, each of shape (B, Ht, N, Dt). With
            the target enabled this is ``target + gate * scalar * projected``;
            otherwise it is ``gate * scalar * projected`` only.
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv
        B, Hs, N, Ds = source_key.shape
        _, Ht, _, Dt = target_key.shape

        # Flatten heads: (B, H, N, D) -> (B, N, H * D)
        source_key_flat = source_key.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
        source_value_flat = source_value.transpose(1, 2).contiguous().view(B, N, Hs * Ds)

        # 1) Prepare input features based on ablation level
        if self.use_target:
            # Full C2C: concat source and target features.
            # Target flattening is only needed (and only done) on this path.
            target_key_flat = target_key.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
            target_value_flat = target_value.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
            key_cat = torch.cat([source_key_flat, target_key_flat], dim=-1)
            value_cat = torch.cat([source_value_flat, target_value_flat], dim=-1)
        else:
            # Ablation level 3: only use source features
            key_cat = source_key_flat
            value_cat = source_value_flat

        # 2) project to hidden dim
        key_hidden = self.key_in(key_cat)
        value_hidden = self.value_in(value_cat)

        # 3) one-layer common embedding MLP to get intermediate representation (at hidden_dim)
        key_hidden = self.key_mlp1(key_hidden)
        value_hidden = self.value_mlp1(value_hidden)

        # 4b) intermediate representation -> projected feature path
        key_proj_hidden = self.key_proj_out(self.key_proj_mlp2(key_hidden))  # (B, N, Ht * Dt)
        value_proj_hidden = self.value_proj_out(self.value_proj_mlp2(value_hidden))  # (B, N, Ht * Dt)
        projected_key = key_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2)  # (B, Ht, N, Dt)
        projected_value = value_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2)  # (B, Ht, N, Dt)

        # 4a) intermediate representation -> scalar path (if using scalar weights)
        if self.use_scalar_weights:
            key_scalar = self.key_scalar_head(self.key_scalar_mlp2(key_hidden))  # (B, N, Ht)
            value_scalar = self.value_scalar_head(self.value_scalar_mlp2(value_hidden))  # (B, N, Ht)
            key_scalar = key_scalar.permute(0, 2, 1).unsqueeze(-1)  # (B, Ht, N, 1)
            value_scalar = value_scalar.permute(0, 2, 1).unsqueeze(-1)  # (B, Ht, N, 1)
            # Normalize scalars into (0, 1)
            norm_key_scalar = torch.sigmoid(key_scalar)
            norm_value_scalar = torch.sigmoid(value_scalar)
        else:
            # Scalar weights disabled: set them to 1.0
            norm_key_scalar = torch.ones(B, Ht, N, 1, device=projected_key.device, dtype=projected_key.dtype)
            norm_value_scalar = torch.ones(B, Ht, N, 1, device=projected_value.device, dtype=projected_value.dtype)

        # Key/value gates (if using gates)
        if self.use_gates:
            key_gate_logit = self.key_gate_logit.view(1, 1, 1, 1)
            value_gate_logit = self.value_gate_logit.view(1, 1, 1, 1)
            if self.training and self.use_gumbel:
                # Training: relaxed gate = sigmoid((logit + Gumbel noise) / temperature)
                u1 = torch.rand(B, Ht, N, 1, device=key_gate_logit.device, dtype=key_gate_logit.dtype)
                u2 = torch.rand(B, Ht, N, 1, device=value_gate_logit.device, dtype=value_gate_logit.dtype)
                g1 = -torch.log(-torch.log(u1 + 1e-20) + 1e-20)
                g2 = -torch.log(-torch.log(u2 + 1e-20) + 1e-20)
                key_gate = torch.sigmoid((key_gate_logit + g1) / self.gate_temperature)
                value_gate = torch.sigmoid((value_gate_logit + g2) / self.gate_temperature)
            else:
                # Eval: hard 0/1 gate. Cast to the projected tensor's dtype rather
                # than float32, otherwise type promotion silently upcasts the
                # output of a half/bfloat16 projector to float32.
                key_gate = (key_gate_logit > 0).to(projected_key.dtype)
                value_gate = (value_gate_logit > 0).to(projected_value.dtype)
        else:
            # Gates disabled: set gates to 1.0 (always open)
            key_gate = torch.ones(B, Ht, N, 1, device=projected_key.device, dtype=projected_key.dtype)
            value_gate = torch.ones(B, Ht, N, 1, device=projected_value.device, dtype=projected_value.dtype)

        # Compute projected contribution
        projected_key_term = key_gate * norm_key_scalar * projected_key
        projected_value_term = value_gate * norm_value_scalar * projected_value

        # Compute target contribution (if using target)
        if self.use_target:
            # Full C2C: add target with projected
            output_key = target_key + projected_key_term
            output_value = target_value + projected_value_term
        else:
            # Ablation level 3: only use projected (no target)
            output_key = projected_key_term
            output_value = projected_value_term
        return output_key, output_value

    def get_ablation_info(self) -> dict:
        """
        Return information about the current ablation configuration.

        Returns:
            Dict with the ablation level, the three effective component flags
            and a human-readable description of the level.
        """
        return {
            'ablation_level': self.ablation_level,
            'use_scalar_weights': self.use_scalar_weights,
            'use_gates': self.use_gates,
            'use_target': self.use_target,
            'description': self._get_ablation_description()
        }

    def _get_ablation_description(self) -> str:
        """
        Get a human-readable description of the current ablation level.

        The text is keyed on ``ablation_level`` only; it does not reflect
        component flags that were switched off manually at level 0.
        """
        descriptions = {
            0: "Full C2C (baseline)",
            1: "No scalar weights (scalars=1.0)",
            2: "No gates (gates=1.0) + No scalar weights",
            3: "No target (source-only) + No gates + No scalar weights",
            4: "No gates (gates=1.0), keep scalars and target"
        }
        return descriptions.get(self.ablation_level, "Unknown ablation level")
# Convenience functions for creating specific ablation levels
def create_ablation_projector(
    source_dim: int,
    target_dim: int,
    source_num_heads: int = 1,
    target_num_heads: int = 1,
    ablation_level: int = 0,
    **kwargs
) -> AblationProjector:
    """
    Build an AblationProjector for the requested ablation level.

    Args:
        source_dim: Forwarded as ``source_dim``.
        target_dim: Forwarded as ``target_dim``.
        source_num_heads: Forwarded as ``source_num_heads``.
        target_num_heads: Forwarded as ``target_num_heads``.
        ablation_level: Forwarded as ``ablation_level``.
        **kwargs: Any further keyword arguments, forwarded unchanged to the
            ``AblationProjector`` constructor.

    Returns:
        The newly constructed projector.
    """
    # Gather the explicit arguments first, then the pass-through extras,
    # and hand everything to the constructor in one go.
    init_args = dict(
        source_dim=source_dim,
        target_dim=target_dim,
        source_num_heads=source_num_heads,
        target_num_heads=target_num_heads,
        ablation_level=ablation_level,
        **kwargs,
    )
    return AblationProjector(**init_args)
def create_full_c2c_projector(**kwargs) -> AblationProjector:
    """
    Build the full C2C baseline projector (ablation level 0).

    All keyword arguments are forwarded to ``create_ablation_projector``.
    """
    projector = create_ablation_projector(ablation_level=0, **kwargs)
    return projector
def create_no_scalar_projector(**kwargs) -> AblationProjector:
    """
    Build a projector with scalar weights removed (ablation level 1).

    All keyword arguments are forwarded to ``create_ablation_projector``.
    """
    projector = create_ablation_projector(ablation_level=1, **kwargs)
    return projector
def create_no_gate_projector(**kwargs) -> AblationProjector:
    """
    Build a projector at ablation level 2 (gates removed).

    All keyword arguments are forwarded to ``create_ablation_projector``.
    """
    projector = create_ablation_projector(ablation_level=2, **kwargs)
    return projector
def create_source_only_projector(**kwargs) -> AblationProjector:
    """
    Build a source-only projector (ablation level 3).

    All keyword arguments are forwarded to ``create_ablation_projector``.
    """
    projector = create_ablation_projector(ablation_level=3, **kwargs)
    return projector
def create_no_gate_only_projector(**kwargs) -> AblationProjector:
    """
    Build a projector without gates that keeps scalar weights and target
    (ablation level 4).

    All keyword arguments are forwarded to ``create_ablation_projector``.
    """
    projector = create_ablation_projector(ablation_level=4, **kwargs)
    return projector