File size: 13,376 Bytes
5ccf219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
"""
Ablation Projector: A configurable projector for ablation studies based on C2CProjector.
Allows gradual removal of components to study their individual contributions.
"""

import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Tuple, Literal

from rosetta.utils.registry import register_model, capture_init_args
from rosetta.model.projector import Projector
from rosetta.model.projector import RegularMLP


@register_model
@capture_init_args
class AblationProjector(Projector):
    """
    Ablation study projector based on C2CProjector with configurable component removal.

    Ablation levels:
    0. Full C2C (baseline)
    1. Remove scalar weights (set to 1.0)
    2. Remove gates (set to 1.0)
    3. Remove target contribution (only use source)
    4. Remove gates only (gates=1.0), keep scalars and target

    Levels 1-3 are cumulative (each builds on the previous one), allowing a
    gradual degradation study; level 4 is an independent variant that disables
    only the gates while keeping scalar weights and the target residual.
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        source_num_heads: int = 1,
        target_num_heads: int = 1,
        intermediate_dim: int = 1024,
        hidden_dim: int = 1024,
        num_layers: int = 3,
        dropout: float = 0.1,
        initial_temperature: float = 1.0,
        final_temperature: float = 0.001,
        anneal_steps: int = 1929,
        dtype: torch.dtype = torch.float32,
        
        # Ablation configuration
        ablation_level: int = 0,  # 0=full, 1=no_scalar, 2=no_gate+no_scalar, 3=no_target, 4=no_gate_only
        use_scalar_weights: bool = True,  # Can be overridden by ablation_level
        use_gates: bool = True,          # Can be overridden by ablation_level  
        use_target: bool = True,         # Can be overridden by ablation_level
    ):
        """
        Args:
            source_dim: Per-head feature dim of the source KV cache.
            target_dim: Per-head feature dim of the target KV cache.
            source_num_heads: Number of source attention heads.
            target_num_heads: Number of target attention heads.
            intermediate_dim: Inner width of the embedding/projection MLPs.
            hidden_dim: Width of the shared hidden representation.
            num_layers: Total MLP depth; the projection path uses
                ``num_layers - 2`` layers (callers should pass >= 2; behavior
                for smaller values depends on RegularMLP — TODO confirm).
            dropout: Dropout rate inside the RegularMLPs.
            initial_temperature: Starting gate (Gumbel) temperature.
            final_temperature: Temperature reached after ``anneal_steps``.
            anneal_steps: Steps over which the gate temperature is annealed.
            dtype: Parameter dtype for all submodules.
            ablation_level: 0-4; may override the three ``use_*`` flags below.
            use_scalar_weights: Enable the per-token scalar-weight path.
            use_gates: Enable the learned key/value gates.
            use_target: Concatenate target features and add the target residual.
        """
        super().__init__()

        assert 0 <= ablation_level <= 4, "ablation_level must be 0, 1, 2, 3, or 4"

        # Dimensions
        self.source_dim = source_dim
        self.target_dim = target_dim
        self.source_num_heads = source_num_heads
        self.target_num_heads = target_num_heads
        self.ablation_level = ablation_level

        # Override component usage based on ablation level
        if ablation_level == 4:
            # Special case: disable gates only, keep scalars and target
            use_scalar_weights = True
            use_gates = False
            use_target = True
        else:
            # Levels 1-3 are cumulative: each disables one more component.
            if ablation_level >= 1:
                use_scalar_weights = False
            if ablation_level >= 2:
                use_gates = False
            if ablation_level >= 3:
                use_target = False

        self.use_scalar_weights = use_scalar_weights
        self.use_gates = use_gates
        self.use_target = use_target

        # Flattened (heads * head_dim) sizes
        in_dim = source_dim * source_num_heads
        out_dim = target_dim * target_num_heads

        # 1) concat(source_X, target_X) then project to hidden_dim.
        #    Without the target (level 3) only source features are consumed,
        #    so the input projection is correspondingly narrower.
        if self.use_target:
            self.key_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype)
            self.value_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype)
        else:
            self.key_in = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype)
            self.value_in = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype)

        # 2) one-layer common embedding MLP to get intermediate representation (at hidden_dim)
        self.key_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)
        self.value_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)

        # 3a) intermediate representation -> MLP for per-token scalar weights
        #     -> project to one scalar per target head. Only built when the
        #     scalar-weight path is enabled, so ablated checkpoints carry no
        #     dead parameters.
        if self.use_scalar_weights:
            self.key_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
            self.value_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
            self.key_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)
            self.value_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)

        # 3b) intermediate representation -> (num_layers-2)-layer MLP for the
        #     projected features -> final hidden_dim -> out_dim projection.
        self.key_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers-2, dropout=dropout, dtype=dtype)
        self.value_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers-2, dropout=dropout, dtype=dtype)
        self.key_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)
        self.value_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)

        # Scalar key/value gate parameters and temperature schedule.
        # Only built when gates are enabled.
        if self.use_gates:
            self.key_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
            self.value_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
            self.use_gumbel = True
            self.register_buffer("gate_temperature", torch.tensor(initial_temperature, dtype=dtype))
            self.initial_temperature = initial_temperature
            self.final_temperature = final_temperature
            self.anneal_steps = anneal_steps

        # Temperature for weight normalization.
        # NOTE(review): not read anywhere in this class (scalars use a plain
        # sigmoid) — kept for interface compatibility; confirm before removal.
        self.scalar_temperature = 1.0

    def update_temperature(self, step: int):
        """Update the gate temperature with an exponential annealing schedule.

        No-op when gates are ablated. ``step`` values past ``anneal_steps``
        clamp at ``final_temperature``.
        """
        if self.use_gates:
            # Guard against anneal_steps == 0 (would otherwise divide by zero).
            ratio = min(step / max(self.anneal_steps, 1), 1.0)
            temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
            self.gate_temperature.fill_(temp)

    def forward(
        self,
        source_kv: Tuple[Tensor, Tensor],
        target_kv: Tuple[Tensor, Tensor],
        position_ids: Optional[Tensor] = None,
        max_pos: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Fuse the source KV cache into the target KV space.

        Args:
            source_kv: (key, value), each of shape (B, Hs, N, Ds).
            target_kv: (key, value), each of shape (B, Ht, N, Dt).
                NOTE(review): the target sequence length is assumed to equal
                the source's N — confirm upstream callers guarantee this.
            position_ids: Accepted for interface compatibility; unused here.
            max_pos: Accepted for interface compatibility; unused here.

        Returns:
            (output_key, output_value), each of shape (B, Ht, N, Dt).
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv

        B, Hs, N, Ds = source_key.shape
        _, Ht, _, Dt = target_key.shape

        # Flatten heads: (B, H, N, D) -> (B, N, H*D)
        source_key_flat = source_key.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
        source_value_flat = source_value.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
        target_key_flat = target_key.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
        target_value_flat = target_value.transpose(1, 2).contiguous().view(B, N, Ht * Dt)

        # 1) Prepare input features based on ablation level
        if self.use_target:
            # Full C2C: concat source and target features
            key_cat = torch.cat([source_key_flat, target_key_flat], dim=-1)
            value_cat = torch.cat([source_value_flat, target_value_flat], dim=-1)
        else:
            # Ablation level 3: only use source features
            key_cat = source_key_flat
            value_cat = source_value_flat

        # 2) project to hidden dim
        key_hidden = self.key_in(key_cat)
        value_hidden = self.value_in(value_cat)

        # 3) one-layer common embedding MLP to get intermediate representation
        key_hidden = self.key_mlp1(key_hidden)
        value_hidden = self.value_mlp1(value_hidden)

        # 4b) intermediate representation -> projected feature path
        key_proj_hidden = self.key_proj_out(self.key_proj_mlp2(key_hidden))        # (B, N, Ht * Dt)
        value_proj_hidden = self.value_proj_out(self.value_proj_mlp2(value_hidden))  # (B, N, Ht * Dt)
        projected_key = key_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2)    # (B, Ht, N, Dt)
        projected_value = value_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2)  # (B, Ht, N, Dt)

        # Disabled components contribute an exact factor of 1.0, so instead of
        # materializing (B, Ht, N, 1) all-ones tensors and multiplying by them
        # (an identity), we simply skip the multiplication.
        projected_key_term = projected_key
        projected_value_term = projected_value

        # 4a) intermediate representation -> scalar path (if using scalar weights)
        if self.use_scalar_weights:
            key_scalar = self.key_scalar_head(self.key_scalar_mlp2(key_hidden))          # (B, N, Ht)
            value_scalar = self.value_scalar_head(self.value_scalar_mlp2(value_hidden))  # (B, N, Ht)
            key_scalar = key_scalar.permute(0, 2, 1).unsqueeze(-1)      # (B, Ht, N, 1)
            value_scalar = value_scalar.permute(0, 2, 1).unsqueeze(-1)  # (B, Ht, N, 1)
            # Normalize scalars into (0, 1)
            projected_key_term = torch.sigmoid(key_scalar) * projected_key_term
            projected_value_term = torch.sigmoid(value_scalar) * projected_value_term

        # Key/value gates: relaxed Gumbel-sigmoid samples during training,
        # hard 0/1 threshold at evaluation.
        if self.use_gates:
            key_gate_logit = self.key_gate_logit.view(1, 1, 1, 1)
            value_gate_logit = self.value_gate_logit.view(1, 1, 1, 1)
            if self.training and self.use_gumbel:
                u1 = torch.rand(B, Ht, N, 1, device=key_gate_logit.device, dtype=key_gate_logit.dtype)
                u2 = torch.rand(B, Ht, N, 1, device=value_gate_logit.device, dtype=value_gate_logit.dtype)
                # Standard Gumbel noise; 1e-20 guards log(0).
                g1 = -torch.log(-torch.log(u1 + 1e-20) + 1e-20)
                g2 = -torch.log(-torch.log(u2 + 1e-20) + 1e-20)
                key_gate = torch.sigmoid((key_gate_logit + g1) / self.gate_temperature)
                value_gate = torch.sigmoid((value_gate_logit + g2) / self.gate_temperature)
            else:
                # BUGFIX: cast to the parameter dtype instead of .float() so
                # eval-mode outputs keep the module dtype (e.g. bfloat16)
                # and match the training-mode gate dtype.
                key_gate = (key_gate_logit > 0).to(key_gate_logit.dtype)
                value_gate = (value_gate_logit > 0).to(value_gate_logit.dtype)
            projected_key_term = key_gate * projected_key_term
            projected_value_term = value_gate * projected_value_term

        # 5) target residual (if using target)
        if self.use_target:
            # Full C2C: add target with projected
            output_key = target_key + projected_key_term
            output_value = target_value + projected_value_term
        else:
            # Ablation level 3: only use projected (no target)
            output_key = projected_key_term
            output_value = projected_value_term

        return output_key, output_value

    def get_ablation_info(self) -> dict:
        """Return information about current ablation configuration."""
        return {
            'ablation_level': self.ablation_level,
            'use_scalar_weights': self.use_scalar_weights,
            'use_gates': self.use_gates,
            'use_target': self.use_target,
            'description': self._get_ablation_description()
        }

    def _get_ablation_description(self) -> str:
        """Get human-readable description of current ablation level."""
        descriptions = {
            0: "Full C2C (baseline)",
            1: "No scalar weights (scalars=1.0)",
            2: "No gates (gates=1.0) + No scalar weights",
            3: "No target (source-only) + No gates + No scalar weights",
            4: "No gates (gates=1.0), keep scalars and target"
        }
        return descriptions.get(self.ablation_level, "Unknown ablation level")


# Convenience functions for creating specific ablation levels
def create_ablation_projector(
    source_dim: int,
    target_dim: int,
    source_num_heads: int = 1,
    target_num_heads: int = 1,
    ablation_level: int = 0,
    **kwargs
) -> AblationProjector:
    """Factory helper: build an AblationProjector at the given ablation level.

    All additional keyword arguments are forwarded to the constructor
    unchanged.
    """
    return AblationProjector(
        source_dim,
        target_dim,
        source_num_heads,
        target_num_heads,
        ablation_level=ablation_level,
        **kwargs,
    )


def create_full_c2c_projector(**kwargs) -> AblationProjector:
    """Shorthand for the unablated C2C baseline (ablation level 0)."""
    projector = create_ablation_projector(ablation_level=0, **kwargs)
    return projector


def create_no_scalar_projector(**kwargs) -> AblationProjector:
    """Shorthand for a projector with scalar weights disabled (ablation level 1)."""
    projector = create_ablation_projector(ablation_level=1, **kwargs)
    return projector


def create_no_gate_projector(**kwargs) -> AblationProjector:
    """Shorthand for a projector with gates and scalar weights disabled (ablation level 2)."""
    projector = create_ablation_projector(ablation_level=2, **kwargs)
    return projector


def create_source_only_projector(**kwargs) -> AblationProjector:
    """Shorthand for the source-only projector (ablation level 3)."""
    projector = create_ablation_projector(ablation_level=3, **kwargs)
    return projector


def create_no_gate_only_projector(**kwargs) -> AblationProjector:
    """Shorthand for disabling gates only, keeping scalars and target (ablation level 4)."""
    projector = create_ablation_projector(ablation_level=4, **kwargs)
    return projector