mihretgold commited on
Commit
5a5235d
·
verified ·
1 Parent(s): d39892f

Upload 4 files

Browse files
Files changed (4) hide show
  1. msae/__init__.py +2 -0
  2. msae/sae.py +979 -0
  3. msae/utils.py +545 -0
  4. vocab/clip_disect_20k.txt +0 -0
msae/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # MSAE (Multiscale Sparse Autoencoder) model package
2
+ # Bundled from heirarchical/MSAE/ for deployment
msae/sae.py ADDED
@@ -0,0 +1,979 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Any
2
+ from functools import partial
3
+ import warnings
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from msae.utils import normalize_data, JumpReLUFunction, StepFunction
10
+
11
+ """
12
+ Sparse Autoencoder (SAE) Implementation
13
+
14
+ This module implements various sparse autoencoder architectures and activation functions
15
+ designed to learn interpretable features in high-dimensional data.
16
+ """
17
+
18
+
19
class SoftCapping(nn.Module):
    """
    Smooth saturation layer for latent activations.

    Rescales inputs through a tanh so that outputs stay within roughly
    [-soft_cap, soft_cap] without any hard clipping, which keeps gradients
    well-behaved and helps stabilize training.

    Args:
        soft_cap (float): Scale of the tanh saturation.
    """

    def __init__(self, soft_cap):
        super().__init__()
        self.soft_cap = soft_cap

    def forward(self, logits):
        """
        Smoothly cap the input values.

        Args:
            logits (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: Values squashed into approximately
            [-soft_cap, soft_cap].
        """
        squashed = torch.tanh(logits / self.soft_cap)
        return squashed * self.soft_cap
46
+
47
class TopK(nn.Module):
    """
    Per-sample Top-K sparsification.

    Keeps only the k largest entries of each input row (optionally ranked on
    absolute values), applies a secondary activation to the surviving entries,
    and zeros out everything else.

    Args:
        k (int): Number of entries to keep per sample.
        act_fn (Callable, optional): Activation applied to the kept values.
            Defaults to nn.ReLU().
        use_abs (bool, optional): If True, take absolute values before
            selection. Defaults to False.
    """

    def __init__(self, k: int, act_fn: Callable = nn.ReLU(), use_abs: bool = False) -> None:
        super().__init__()
        self.k = k
        self.act_fn = act_fn
        self.use_abs = use_abs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Keep only the top-k activations of each sample.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, features]

        Returns:
            torch.Tensor: Tensor of the same shape where all but the k
            largest values per sample are zero.
        """
        if self.use_abs:
            x = torch.abs(x)

        # topk returns both the k largest values and their positions along
        # the feature dimension.
        top_values, top_indices = torch.topk(x, k=self.k, dim=-1)

        # Activate the surviving entries and scatter them back into a
        # zero-filled tensor at their original positions.
        sparse = torch.zeros_like(x)
        sparse.scatter_(-1, top_indices, self.act_fn(top_values))

        # Sanity check: no row may carry more than k non-zero entries.
        assert (sparse != 0.0).sum(dim=-1).max() <= self.k
        return sparse

    def forward_eval(self, x: torch.Tensor) -> torch.Tensor:
        """
        Dense evaluation pass without the top-k constraint.

        Used to compute full activations for evaluation or visualization.

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Activation-function output with no sparsification.
        """
        if self.use_abs:
            x = torch.abs(x)
        return self.act_fn(x)
116
+
117
class BatchTopK(TopK):
    """
    Top-K sparsification applied across the whole batch.

    Instead of keeping k entries per sample, this keeps the k * batch_size
    largest entries over the flattened batch, so sparsity is enforced
    globally and individual samples may retain more or fewer than k
    activations depending on relative magnitudes.

    Args:
        k (int): Average number of activations to keep per sample.
        act_fn (Callable, optional): Activation applied to kept values.
            Defaults to nn.Identity().
        use_abs (bool, optional): If True, rank entries by absolute value.
            Defaults to False.
    """

    def __init__(self, k: int, act_fn: Callable = nn.Identity(), use_abs: bool = False) -> None:
        # Delegate all state to the per-sample TopK base class.
        super().__init__(k, act_fn, use_abs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Keep the k * batch_size largest activations over the whole batch.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, features]

        Returns:
            torch.Tensor: Sparse tensor with the same shape, with roughly
            k * batch_size non-zero entries in total.
        """
        # Cap the kept count at the total number of elements.
        n_keep = min(self.k * x.shape[0], x.numel())

        # Selection scores may differ from the emitted values: with use_abs
        # the ranking uses |x| but the original signed entries are written out.
        scores = torch.abs(x) if self.use_abs else x

        flat_scores = scores.flatten()
        flat_input = x.flatten()

        # Global top-k over the flattened batch.
        _, keep_idx = torch.topk(flat_scores, k=n_keep, dim=-1)

        # Activate the selected (original) entries and scatter them into a
        # zero-filled flat buffer at their original positions.
        sparse = torch.zeros_like(flat_input)
        sparse.scatter_(-1, keep_idx, self.act_fn(flat_input[keep_idx]))

        return sparse.reshape(x.shape)

    def forward_eval(self, x: torch.Tensor) -> torch.Tensor:
        """
        Dense evaluation pass without the top-k constraint.

        Used to compute full activations for evaluation or visualization.

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Activation-function output with no sparsification.
        """
        return self.act_fn(torch.abs(x) if self.use_abs else x)
190
+
191
class JumpReLU(nn.Module):
    """
    JumpReLU activation with per-unit learned thresholds.

    Values below a learned threshold are suppressed; the bandwidth parameter
    controls how sharp the transition at the threshold is when gradients are
    estimated through it.

    Args:
        hidden_dim (int): Number of latent units (one threshold per unit).
        init_threshold (float, optional): Initial threshold value.
            Defaults to 0.001.
        bandwidth (float, optional): Transition sharpness. Defaults to 0.001.
    """

    def __init__(self, hidden_dim: int, init_threshold: float = 0.001, bandwidth: float = 0.001) -> None:
        """
        Initialize JumpReLU activation with specified parameters.

        Args:
            hidden_dim (int): Number of latent units.
            init_threshold (float, optional): Initial threshold.
                Defaults to 0.001.
            bandwidth (float, optional): Transition sharpness.
                Defaults to 0.001.
        """
        super().__init__()
        # Thresholds are stored in log-space so they remain positive.
        self.log_threshold = nn.Parameter(torch.full((hidden_dim,), np.log(init_threshold)))
        self.bandwidth = bandwidth

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Inference pass: ReLU followed by the jump (threshold) mechanism.

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Activated tensor
        """
        rectified = torch.relu(x)
        return JumpReLUFunction.apply(rectified, self.log_threshold, self.bandwidth)

    def forward_train(self, x: torch.Tensor) -> torch.Tensor:
        """
        Training pass using a step-function surrogate with usable gradients.

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Activated tensor with gradient-friendly step function
        """
        return StepFunction.apply(x, self.log_threshold, self.bandwidth)
248
+
249
# Mapping of activation function names to their corresponding classes.
# Naming convention: "abs" variants rank entries by absolute value; the
# trailing ReLU/Identity names the secondary activation applied to the
# surviving entries; "Batch" variants enforce sparsity across the whole
# batch instead of per sample.
ACTIVATIONS_CLASSES = {
    "ReLU": nn.ReLU,
    "JumpReLU": JumpReLU,
    "Identity": nn.Identity,
    "TopK": partial(TopK, act_fn=nn.Identity()),
    "TopKReLU": partial(TopK, act_fn=nn.ReLU()),
    "TopKabs": partial(TopK, use_abs=True, act_fn=nn.Identity()),
    "TopKabsReLU": partial(TopK, use_abs=True, act_fn=nn.ReLU()),
    "BatchTopK": partial(BatchTopK, act_fn=nn.Identity()),
    "BatchTopKReLU": partial(BatchTopK, act_fn=nn.ReLU()),
    "BatchTopKabs": partial(BatchTopK, use_abs=True, act_fn=nn.Identity()),
    "BatchTopKabsReLU": partial(BatchTopK, use_abs=True, act_fn=nn.ReLU()),
}
264
+
265
def get_activation(activation: str) -> nn.Module:
    """
    Factory function to create activation function instances by name.

    Parameterized names carry a trailing underscore-separated integer:
    "TopKReLU_64" builds TopKReLU with k=64, "JumpReLU_4096" builds
    JumpReLU with hidden_dim=4096. Plain names ("ReLU", "Identity", ...)
    are instantiated with their defaults.

    Args:
        activation (str): Name of the activation function, with optional
            trailing "_<int>" parameter.

    Returns:
        nn.Module: Instantiated activation function.

    Raises:
        KeyError: If the base name is not in ACTIVATIONS_CLASSES.
        ValueError: If the trailing parameter is not an integer.
    """
    if "_" in activation:
        # Split only on the LAST underscore so the numeric suffix is isolated
        # correctly even if an activation name itself ever contains an
        # underscore (plain split() would raise on multiple underscores).
        activation, arg = activation.rsplit("_", 1)
        if "TopK" in activation:
            return ACTIVATIONS_CLASSES[activation](k=int(arg))
        elif "JumpReLU" in activation:
            return ACTIVATIONS_CLASSES[activation](hidden_dim=int(arg))
    return ACTIVATIONS_CLASSES[activation]()
286
+
287
class Autoencoder(nn.Module):
    """
    Sparse autoencoder base class.

    Implements the standard sparse autoencoder architecture:
    latents = activation(encoder(x - pre_bias) + latent_bias)
    recons = decoder(latents) + pre_bias

    Includes various options for controlling activation functions, weight initialization,
    and feature normalization.

    Attributes:
        n_latents (int): Number of latent features (neurons)
        n_inputs (int): Dimensionality of the input data
        tied (bool): Whether decoder weights are tied to encoder weights
        normalize (bool): Whether to normalize input data
        encoder (nn.Parameter): Encoder weight matrix [n_inputs, n_latents]
        decoder (nn.Parameter): Decoder weight matrix [n_latents, n_inputs] (if not tied)
        pre_bias (nn.Parameter): Input bias/offset [n_inputs]
        latent_bias (nn.Parameter): Latent bias [n_latents]
        activation (nn.Module): Activation function for the latent layer
        latents_activation_frequency (torch.Tensor): Tracks how often neurons activate
    """

    def __init__(
        self,
        n_latents: int,
        n_inputs: int,
        activation: Callable = nn.ReLU(),
        tied: bool = False,
        normalize: bool = False,
        bias_init: torch.Tensor | float = 0.0,
        init_method: str = "kaiming",
        latent_soft_cap: float = 30.0,
        threshold: torch.Tensor | None = None,
        *args,
        **kwargs,
    ) -> None:
        """
        Initialize the sparse autoencoder.

        Args:
            n_latents (int): Dimension of the autoencoder latent space
            n_inputs (int): Dimensionality of the original data
            activation (Callable or str): Activation function or name
            tied (bool, optional): Whether to tie encoder and decoder weights. Defaults to False.
            normalize (bool, optional): Whether to normalize input data. Defaults to False.
            bias_init (torch.Tensor | float, optional): Initial bias value. Defaults to 0.0.
            init_method (str, optional): Weight initialization method. Defaults to "kaiming".
            latent_soft_cap (float, optional): Soft cap value for latent activations. Defaults to 30.0.
            threshold (torch.Tensor, optional): Threshold for JumpReLU. Defaults to None.
        """
        super().__init__()
        # Resolve string names ("TopKReLU_64", ...) through the factory.
        if isinstance(activation, str):
            activation = get_activation(activation)

        # Store configuration
        self.tied = tied
        self.n_latents = n_latents
        self.n_inputs = n_inputs
        self.init_method = init_method
        self.bias_init = bias_init
        self.normalize = normalize

        # Initialize parameters. A float bias_init is broadcast to all inputs;
        # a tensor is used as-is.
        self.pre_bias = nn.Parameter(torch.full((n_inputs,), bias_init) if isinstance(bias_init, float) else bias_init)
        self.encoder = nn.Parameter(torch.zeros((n_inputs, n_latents)))
        self.latent_bias = nn.Parameter(
            torch.zeros(
                n_latents,
            )
        )

        # For tied weights, decoder is derived from encoder (the "decoder"
        # parameter slot is registered but holds None).
        if tied:
            self.register_parameter("decoder", None)
        else:
            self.decoder = nn.Parameter(torch.zeros((n_latents, n_inputs)))

        # Set up activation functions; a non-positive soft cap disables capping.
        self.latent_soft_cap = SoftCapping(latent_soft_cap) if latent_soft_cap > 0 else nn.Identity()
        self.activation = activation

        # Set threshold for JumpReLU if needed.
        # NOTE(review): assigned directly to log_threshold, so `threshold` is
        # presumably already in log-space and an nn.Parameter — confirm with callers.
        if isinstance(self.activation, JumpReLU) and threshold is not None:
            self.activation.log_threshold = threshold

        # NOTE(review): stored but never read inside this class — presumably
        # used by external dead-neuron handling; confirm before removing.
        self.dead_activations = activation

        # Initialize weights
        self._init_weights()

        # Set up activation tracking
        self.latents_activation_frequency: torch.Tensor
        self.register_buffer("latents_activation_frequency", torch.zeros(n_latents, dtype=torch.int64, requires_grad=False))
        self.num_updates = 0

        self.dead_latents = []

    def get_and_reset_stats(self) -> torch.Tensor:
        """
        Get activation statistics and reset the counters.

        Returns:
            torch.Tensor: Proportion of samples that activated each neuron
        """
        # NOTE(review): divides by num_updates; raises ZeroDivisionError only
        # for int zero — here num_updates is int, so calling this before any
        # forward pass divides a tensor by 0 and yields NaNs, not an error.
        activations = self.latents_activation_frequency.detach().cpu().float() / self.num_updates
        self.latents_activation_frequency.zero_()
        self.num_updates = 0
        return activations

    @torch.no_grad()
    def _init_weights(self, norm=0.1, neuron_indices: list[int] | None = None) -> None:
        """
        Initialize network weights.

        Args:
            norm (float, optional): Target norm for the weights. Defaults to 0.1.
            neuron_indices (list[int] | None, optional): Indices of neurons to initialize.
                If None, initialize all neurons.

        Raises:
            ValueError: If invalid initialization method is specified
        """
        if self.init_method not in ["kaiming", "xavier", "uniform", "normal"]:
            raise ValueError(f"Invalid init_method: {self.init_method}")

        # Use transposed encoder if weights are tied
        if self.tied:
            decoder_weight = self.encoder.t()
        else:
            decoder_weight = self.decoder

        # Initialize with specified method
        if self.init_method == "kaiming":
            new_W_dec = nn.init.kaiming_uniform_(torch.zeros_like(decoder_weight), nonlinearity="relu")
        elif self.init_method == "xavier":
            new_W_dec = nn.init.xavier_uniform_(torch.zeros_like(decoder_weight), gain=nn.init.calculate_gain("relu"))
        elif self.init_method == "uniform":
            new_W_dec = nn.init.uniform_(torch.zeros_like(decoder_weight), a=-1, b=1)
        elif self.init_method == "normal":
            new_W_dec = nn.init.normal_(torch.zeros_like(decoder_weight))
        else:
            # Unreachable given the validation above; kept as a guard.
            raise ValueError(f"Invalid init_method: {self.init_method}")

        # Normalize each decoder row to the target norm
        new_W_dec *= norm / new_W_dec.norm(p=2, dim=-1, keepdim=True)

        # Initialize bias to zero
        new_l_bias = torch.zeros_like(self.latent_bias)

        # Transpose for encoder
        new_W_enc = new_W_dec.t().clone()

        # Apply initialization to all or specific neurons. Encoder and latent
        # bias are written in both tied and untied modes; the decoder
        # parameter only exists when untied.
        if neuron_indices is None:
            if not self.tied:
                self.decoder.data = new_W_dec
            self.encoder.data = new_W_enc
            self.latent_bias.data = new_l_bias
        else:
            if not self.tied:
                self.decoder.data[neuron_indices] = new_W_dec[neuron_indices]
            self.encoder.data[:, neuron_indices] = new_W_enc[:, neuron_indices]
            self.latent_bias.data[neuron_indices] = new_l_bias[neuron_indices]

    @torch.no_grad()
    def project_grads_decode(self):
        """
        Project out components of decoder gradient that would change its norm.

        This helps maintain normalized decoder norms during training.
        """
        if self.tied:
            weights = self.encoder.data.T
            grad = self.encoder.grad.T
        else:
            weights = self.decoder.data
            grad = self.decoder.grad

        # Project out the component parallel to weights
        grad_proj = (grad * weights).sum(dim=-1, keepdim=True) * weights

        # Update gradients
        if self.tied:
            self.encoder.grad -= grad_proj.T
        else:
            self.decoder.grad -= grad_proj

    @torch.no_grad()
    def scale_to_unit_norm(self) -> None:
        """
        Scale decoder rows to unit norm, and adjust other parameters accordingly.

        This normalization helps with feature interpretability and training stability.
        """
        # NOTE(review): reads self.decoder.dtype — with tied=True the decoder
        # parameter is None, so this would raise; presumably only called on
        # untied models. Confirm with call sites.
        eps = torch.finfo(self.decoder.dtype).eps

        # Normalize tied or untied weights
        if self.tied:
            norm = self.encoder.data.T.norm(p=2, dim=-1, keepdim=True) + eps
            self.encoder.data.T /= norm
        else:
            norm = self.decoder.data.norm(p=2, dim=-1, keepdim=True) + eps
            self.decoder.data /= norm
            # Fold the removed decoder scale into the encoder so the
            # end-to-end mapping is preserved.
            self.encoder.data *= norm.t()

        # Scale biases accordingly
        self.latent_bias.data *= norm.squeeze()

        # Adjust JumpReLU thresholds if present (thresholds live in log-space)
        if isinstance(self.activation, JumpReLU):
            cur_threshold = torch.exp(self.activation.log_threshold.data)
            self.activation.log_threshold.data = torch.log(cur_threshold * norm.squeeze())

    def encode_pre_act(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute pre-activation latent values.

        Args:
            x (torch.Tensor): Input data [batch, n_inputs]

        Returns:
            torch.Tensor: Pre-activation latent values [batch, n_latents]
        """
        x = x - self.pre_bias
        latents_pre_act_full = x @ self.encoder + self.latent_bias
        return latents_pre_act_full

    def preprocess(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
        """
        Preprocess input data, optionally normalizing it.

        Args:
            x (torch.Tensor): Input data [batch, n_inputs]

        Returns:
            tuple: (preprocessed_data, normalization_info)
                - preprocessed_data: Processed input data
                - normalization_info: Dict with normalization parameters (if normalize=True)
        """
        if not self.normalize:
            return x, dict()
        x_processed, mu, std = normalize_data(x)
        return x_processed, dict(mu=mu, std=std)

    def encode(self, x: torch.Tensor, topk_number: int | None = None) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
        """
        Encode input data to latent representations.

        Args:
            x (torch.Tensor): Input data [batch, n_inputs]
            topk_number (int | None, optional): Number of top-k activations to keep (for inference).
                Defaults to None.

        Returns:
            tuple: (encoded, full_encoded, info)
                - encoded: Latent activations with sparsity constraints [batch, n_latents]
                - full_encoded: Latent activations without sparsity (for analysis) [batch, n_latents]
                - info: Normalization information dictionary
        """
        x, info = self.preprocess(x)
        pre_encoded = self.encode_pre_act(x)
        encoded = self.activation(pre_encoded)

        # Get full activations (for analysis) depending on activation type
        if isinstance(self.activation, TopK):
            full_encoded = self.activation.forward_eval(pre_encoded)
        else:
            full_encoded = torch.clone(encoded)

        # Apply topk filtering for inference if requested
        if topk_number is not None:
            _, indices = torch.topk(full_encoded, k=topk_number, dim=-1)
            values = torch.gather(full_encoded, -1, indices)
            full_encoded = torch.zeros_like(full_encoded)
            full_encoded.scatter_(-1, indices, values)

        # Apply soft capping to both outputs
        caped_encoded = self.latent_soft_cap(encoded)
        capped_full_encoded = self.latent_soft_cap(full_encoded)

        return caped_encoded, capped_full_encoded, info

    def decode(self, latents: torch.Tensor, info: dict[str, Any] | None = None) -> torch.Tensor:
        """
        Decode latent representations to reconstructed inputs.

        Args:
            latents (torch.Tensor): Latent activations [batch, n_latents]
            info (dict[str, Any] | None, optional): Normalization information. Defaults to None.

        Returns:
            torch.Tensor: Reconstructed input data [batch, n_inputs]
        """
        # Decode using tied or untied weights
        if self.tied:
            ret = latents @ self.encoder.t() + self.pre_bias
        else:
            ret = latents @ self.decoder + self.pre_bias

        # Denormalize if needed
        if self.normalize:
            assert info is not None
            ret = ret * info["std"] + info["mu"]

        return ret

    @torch.no_grad()
    def update_latent_statistics(self, latents: torch.Tensor) -> None:
        """
        Update statistics on latent activations.

        Args:
            latents (torch.Tensor): Latent activations [batch, n_latents]
        """
        self.num_updates += latents.shape[0]
        current_activation_frequency = (latents != 0).to(torch.int64).sum(dim=0)
        self.latents_activation_frequency += current_activation_frequency

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass through the autoencoder.

        Args:
            x (torch.Tensor): Input data [batch, n_inputs]

        Returns:
            tuple: (recons, latents, all_recons, all_latents)
                - recons: Reconstructed data [batch, n_inputs]
                - latents: Latent activations [batch, n_latents]
                - all_recons: Reconstructed data without sparsity constraints (for analysis)
                - all_latents: Latent activations without sparsity constraints (for analysis)
        """
        # Preprocess data
        x_processed, info = self.preprocess(x)

        # Compute pre-activations
        latents_pre_act = self.encode_pre_act(x_processed)

        # Apply activation function
        latents = self.activation(latents_pre_act)
        latents_caped = self.latent_soft_cap(latents)

        # Decode to reconstruction
        recons = self.decode(latents_caped, info)

        # Update activation statistics
        self.update_latent_statistics(latents_caped)

        # Handle different activation function types for analysis outputs
        if isinstance(self.activation, TopK):
            # For TopK, also return dense (un-sparsified) activations/recons
            all_latents = self.activation.forward_eval(latents_pre_act)
            all_latents_caped = self.latent_soft_cap(all_latents)
            all_recons = self.decode(all_latents_caped, info)
            return recons, latents_caped, all_recons, all_latents_caped
        elif isinstance(self.activation, JumpReLU) and self.training:
            # For JumpReLU in training mode, the second output carries the
            # step-function surrogate used by the loss
            loss_latents = self.activation.forward_train(latents)
            return recons, loss_latents, recons, latents_caped
        else:
            # For other activations, sparse and "full" outputs coincide
            return recons, latents_caped, recons, latents_caped
652
+
653
class MatryoshkaAutoencoder(Autoencoder):
    """
    Matryoshka Sparse Autoencoder.

    This extends the base Autoencoder with a nested structure of latent representations,
    where different numbers of features can be used depending on computational budget
    or desired level of detail.

    The model uses multiple TopK activations with different k values and maintains
    relative importance weights for each level of the hierarchy.
    """

    def __init__(
        self,
        n_latents: int,
        n_inputs: int,
        activation: str = "TopKReLU",
        tied: bool = False,
        normalize: bool = False,
        bias_init: torch.Tensor | float = 0.0,
        init_method: str = "kaiming",
        latent_soft_cap: float = 30.0,
        nesting_list: list[int] = [16, 32],  # NOTE(review): mutable default; safe here because it is only read (sorted copies it), never mutated
        relative_importance: list[float] | None = None,
        *args,
        **kwargs,
    ) -> None:
        """
        Initialize the Matryoshka Sparse Autoencoder.

        Args:
            n_latents (int): Dimension of the autoencoder latent space
            n_inputs (int): Dimensionality of the original data
            activation (str, optional): Base activation function name. Defaults to "TopKReLU".
            tied (bool, optional): Whether to tie encoder and decoder weights. Defaults to False.
            normalize (bool, optional): Whether to normalize input data. Defaults to False.
            bias_init (torch.Tensor | float, optional): Initial bias value. Defaults to 0.0.
            init_method (str, optional): Weight initialization method. Defaults to "kaiming".
            latent_soft_cap (float, optional): Soft cap value for latent activations. Defaults to 30.0.
            nesting_list (list[int], optional): List of k values for nested representations. Defaults to [16, 32].
            relative_importance (list[float] | None, optional): Importance weights for each nesting level.
                Defaults to equal weights.
        """
        # Initialize nesting hierarchy (ascending k values)
        self.nesting_list = sorted(nesting_list)
        self.relative_importance = relative_importance if relative_importance is not None else [1.0] * len(nesting_list)
        assert len(self.relative_importance) == len(self.nesting_list)

        # Ensure activation is TopK-based
        if "TopK" not in activation:
            warnings.warn(f"MatryoshkaAutoencoder: activation {activation} is not a TopK activation. We are changing it to TopKReLU")
            activation = "TopKReLU"

        # Initialize the base class with the smallest-k activation
        base_activation = activation + f"_{self.nesting_list[0]}"
        super().__init__(n_latents, n_inputs, base_activation, tied, normalize, bias_init, init_method, latent_soft_cap)

        # Replace the single activation with one TopK per nesting level
        self.activation = nn.ModuleList([get_activation(activation + f"_{nesting}") for nesting in self.nesting_list])

    def encode(self, x: torch.Tensor, topk_number: int | None = None) -> tuple[list[torch.Tensor], torch.Tensor, dict[str, Any]]:
        """
        Encode input data to multiple latent representations with different sparsity levels.

        Args:
            x (torch.Tensor): Input data [batch, n_inputs]
            topk_number (int | None, optional): Number of top-k activations to keep (for inference).
                Defaults to None.

        Returns:
            tuple: (encoded_list, last_encoded, info)
                - encoded_list: List of latent activations with different sparsity levels
                - last_encoded: The least sparse latent activations (from largest k value)
                - info: Normalization information dictionary
        """
        x, info = self.preprocess(x)
        pre_encoded = self.encode_pre_act(x)

        # Apply each activation function in the hierarchy (shared encoder,
        # different sparsity per level)
        encoded = [activation(pre_encoded) for activation in self.activation]
        caped_encoded = [self.latent_soft_cap(enc) for enc in encoded]

        # Apply additional top-k filtering for inference if requested
        # (filters the least-sparse level, i.e. the largest k)
        if topk_number is not None:
            last_encoded = caped_encoded[-1]
            _, indices = torch.topk(last_encoded, k=topk_number, dim=-1)
            values = torch.gather(last_encoded, -1, indices)
            last_encoded = torch.zeros_like(last_encoded)
            last_encoded.scatter_(-1, indices, values)
        else:
            last_encoded = caped_encoded[-1]

        return caped_encoded, last_encoded, info

    def decode(self, latents: list[torch.Tensor], info: dict[str, Any] | None = None) -> list[torch.Tensor]:
        """
        Decode multiple latent representations to reconstructions.

        Args:
            latents (list[torch.Tensor]): List of latent activations at different sparsity levels
            info (dict[str, Any] | None, optional): Normalization information. Defaults to None.

        Returns:
            list[torch.Tensor]: List of reconstructed inputs at different sparsity levels
        """
        # Decode each latent representation with the shared decoder
        if self.tied:
            ret = [latent @ self.encoder.t() + self.pre_bias for latent in latents]
        else:
            ret = [latent @ self.decoder + self.pre_bias for latent in latents]

        # Denormalize if needed
        if self.normalize:
            assert info is not None
            ret = [re * info["std"] + info["mu"] for re in ret]

        return ret

    def forward(self, x: torch.Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor], torch.Tensor, torch.Tensor]:
        """
        Forward pass through the Matryoshka autoencoder.

        Args:
            x (torch.Tensor): Input data [batch, n_inputs]

        Returns:
            tuple: (recons_list, latents_list, final_recon, final_latent)
                - recons_list: List of reconstructions at different sparsity levels
                - latents_list: List of latent activations at different sparsity levels
                - final_recon: Reconstruction from the largest k value
                - final_latent: Latent activations from the largest k value
        """
        # Preprocess data
        x_processed, info = self.preprocess(x)
        latents_pre_act = self.encode_pre_act(x_processed)

        # Apply each activation in the hierarchy
        latents = [activation(latents_pre_act) for activation in self.activation]
        assert len(latents) == len(self.activation)
        latents_caped = [self.latent_soft_cap(latent) for latent in latents]

        # Decode each level
        recons = self.decode(latents_caped, info)
        assert len(recons) == len(latents)

        # Update activation statistics using the largest k
        self.update_latent_statistics(latents_caped[-1])

        # Get full activations for analysis. Index [0] is arbitrary here:
        # TopK.forward_eval applies no top-k filtering, so any level yields
        # the same dense activations.
        all_latents = self.activation[0].forward_eval(latents_pre_act)
        all_latents_caped = self.latent_soft_cap(all_latents)
        all_recons = self.decode([all_latents_caped], info)[0]

        # Return all reconstructions and the final ones
        return recons, latents_caped, all_recons, all_latents_caped
809
+
810
def load_model(path):
    """
    Load a saved sparse autoencoder model from a file.

    The model configuration is parsed from the checkpoint filename, expected as
    ``{n_latents}_{n_inputs}_{activation}[_{k}][_{UW|RW}]_{tied}_{normalize}_{soft_cap}[...].pt``.

    Args:
        path (str): Path to the saved model file (.pt)

    Returns:
        tuple: (model, mean_center, scaling_factor, target_norm)
            - model: The loaded Autoencoder model
            - mean_center: Mean used to center the data (tensor or flag as stored)
            - scaling_factor: Scaling factor applied to the data
            - target_norm: Target normalization factor for the data
    """
    import os

    # Extract configuration from filename.
    # Use os.path.basename so Windows-style separators also work
    # (the previous split("/") only handled POSIX paths).
    path_head = os.path.basename(path)
    path_name = path_head[: path_head.find(".pt")]
    fields = path_name.split("_")

    n_latents = int(fields.pop(0))
    n_inputs = int(fields.pop(0))
    activation = fields.pop(0)
    if "TopK" in activation:
        # TopK activations carry their k value as one extra component.
        activation += "_" + fields.pop(0)
    elif "ReLU" == activation:
        # Plain ReLU checkpoints carry one unused placeholder component.
        fields.pop(0)
    if "UW" in fields[0] or "RW" in fields[0]:
        # Skip the optional weighting-scheme tag.
        fields.pop(0)
    # Anything other than the literal string "False" counts as True
    # (preserves the original parsing semantics).
    tied = fields.pop(0) != "False"
    normalize = fields.pop(0) != "False"
    latent_soft_cap = float(fields.pop(0))

    # Create and load the model
    model = Autoencoder(n_latents, n_inputs, activation, tied=tied, normalize=normalize, latent_soft_cap=latent_soft_cap)
    map_location = "cuda" if torch.cuda.is_available() else "cpu"
    # PyTorch 2.6+ defaults `weights_only=True` in torch.load, which can fail
    # on checkpoints containing numpy scalars (common in older training code).
    # This repo's checkpoints are expected to be trusted model artifacts.
    try:
        model_state_dict = torch.load(path, map_location=map_location, weights_only=False)
    except TypeError:
        # Older PyTorch versions don't support `weights_only`.
        model_state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(model_state_dict["model"])
    mean_center = model_state_dict["mean_center"]
    scaling_factor = model_state_dict["scaling_factor"]
    target_norm = model_state_dict["target_norm"]
    return model, mean_center, scaling_factor, target_norm
860
+
861
+
862
class SAE(nn.Module):
    """
    Thin inference wrapper around a saved sparse autoencoder checkpoint.

    Loads the underlying Autoencoder via ``load_model`` and stores the dataset
    preprocessing statistics (mean and scaling factor) as buffers so they move
    with the module across devices.
    """

    def __init__(self, path: str) -> None:
        """
        Initialize the Sparse Autoencoder (SAE) model.

        Args:
            path (str): Path to the saved model file (.pt)
        """
        super().__init__()
        self.model, mean, scaling_factor, _ = load_model(path)
        # Buffers (not parameters): preprocessing statistics from the checkpoint.
        self.register_buffer("mean", mean.clone().detach() if isinstance(mean, torch.Tensor) else torch.tensor(mean))
        self.register_buffer("scaling_factor", torch.tensor(scaling_factor))

    @property
    def input_dim(self) -> int:
        """Return input dimension of the model."""
        return self.model.n_inputs

    @property
    def latent_dim(self) -> int:
        """Return latent dimension of the model."""
        return self.model.n_latents

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """
        Preprocess input data (mean-centering and scaling).

        Args:
            x: Input tensor

        Returns:
            Preprocessed tensor
        """
        return (x - self.mean) * self.scaling_factor

    def postprocess(self, x: torch.Tensor) -> torch.Tensor:
        """
        Post-process output data (denormalization).

        Args:
            x: Output tensor

        Returns:
            Denormalized tensor
        """
        return x / self.scaling_factor + self.mean

    def encode(self, x: torch.Tensor, topk: int = -1) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode input data to latent representation.

        Args:
            x: Input tensor
            topk (int, optional): Number of top-k activations to keep. Defaults to -1 (no sparsity).

        Returns:
            Encoded latents (per sparsity level) and full latents
        """
        x = self.preprocess(x)

        # Only forward a top-k restriction when it actually restricts something.
        topk_number = topk if 0 < topk < self.model.n_latents else None

        latents, full_latents, _ = self.model.encode(x, topk_number=topk_number)
        return latents, full_latents

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """
        Decode latent representation to input space.

        Args:
            latents: Latent tensor [batch, n_latents]

        Returns:
            Reconstructed input tensor
        """
        # BUG FIX: Autoencoder.decode expects a *list* of latent tensors and
        # returns a list of reconstructions. Passing a bare tensor made it
        # iterate over the batch dimension, and postprocess then failed on the
        # resulting Python list. Wrap and unwrap instead.
        reconstructed = self.model.decode([latents])[0]
        return self.postprocess(reconstructed)

    def forward(self, x: torch.Tensor, topk: int = -1) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass through the SAE.

        Args:
            x: Input tensor
            topk (int, optional): Number of top-k activations to keep. Defaults to -1 (no sparsity).

        Returns:
            - Post-processed reconstructed tensor
            - Reconstructed tensor (model space, before postprocessing)
            - Full latent activations
        """
        # Encode to latent space.
        _, full_latents = self.encode(x, topk=topk)

        # Decode back to input space (see decode(): the model expects a list).
        reconstructed = self.model.decode([full_latents])[0]

        # Undo the preprocessing normalization.
        post_reconstructed = self.postprocess(reconstructed)

        return post_reconstructed, reconstructed, full_latents
msae/utils.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from tqdm import tqdm
4
+
5
+ import torch
6
+ import random
7
+ import numpy as np
8
+
9
+ """
10
+ Sparse Autoencoder (SAE) Utilities
11
+
12
+ This module provides utility functions and classes for training and using
13
+ Sparse Autoencoders, including dataset handling, learning rate schedulers,
14
+ custom activation functions, and various mathematical operations.
15
+ """
16
+
17
+ class SAEDataset(torch.utils.data.Dataset):
18
+ """
19
+ Memory-efficient dataset implementation for Sparse Autoencoders.
20
+
21
+ This class loads data from memory-mapped numpy arrays to efficiently handle
22
+ large datasets without loading everything into memory at once. It also
23
+ handles preprocessing like mean centering and normalization.
24
+
25
+ The class automatically parses dataset dimensions from the filename,
26
+ which is expected to contain the data shape as the last two underscored
27
+ components (e.g., "dataset_name_10000_768.npy" for 10000 vectors of size 768).
28
+
29
+ Args:
30
+ data_path (str): Path to the memory-mapped numpy array file
31
+ dtype (torch.dtype, optional): Data type for tensors. Defaults to torch.float32.
32
+ mean_center (bool, optional): Whether to center the data by subtracting the mean.
33
+ Defaults to False.
34
+ target_norm (float, optional): Target norm for normalization. If None, uses sqrt(vector_size).
35
+ If 0.0, no normalization is applied. Defaults to None.
36
+ """
37
+ def __init__(self, data_path: str, dtype: torch.dtype = torch.float32, mean_center: bool = False, target_norm: float = None):
38
+ # Parse vector dimensions from filename
39
+ parts = data_path.split("/")[-1].split(".")[0].split("_")
40
+ self.len, self.vector_size = map(int, parts[-2:])
41
+
42
+ # Set core attributes
43
+ self.dtype = dtype
44
+ self.data = np.memmap(data_path, dtype="float32", mode="r",
45
+ shape=(self.len, self.vector_size))
46
+
47
+ # Special case for representation files (already preprocessed)
48
+ if "repr" in data_path:
49
+ self.mean = torch.zeros(self.vector_size, dtype=dtype)
50
+ self.mean_center = False
51
+ self.scaling_factor = 1.0
52
+ return
53
+
54
+ # Set preprocessing configuration
55
+ self.mean_center = mean_center
56
+ self.target_norm = np.sqrt(self.vector_size) if target_norm is None else target_norm
57
+
58
+ # Compute statistics if needed
59
+ if self.mean_center or self.target_norm != 0.0:
60
+ self._compute_statistics()
61
+ else:
62
+ self.mean = torch.zeros(self.vector_size, dtype=dtype)
63
+ self.scaling_factor = 1.0
64
+
65
+ def _compute_statistics(self, batch_size: int = 10000):
66
+ """
67
+ Compute dataset statistics (mean and scaling factor) in memory-efficient batches.
68
+
69
+ Args:
70
+ batch_size (int, optional): Number of samples to process at once. Defaults to 10000.
71
+ """
72
+ # Compute mean if mean centering is enabled
73
+ if self.mean_center:
74
+ mean_acc = np.zeros(self.vector_size, dtype=np.float32)
75
+ total = 0
76
+
77
+ for start in range(0, self.len, batch_size):
78
+ end = min(start + batch_size, self.len)
79
+ batch = self.data[start:end].copy()
80
+ mean_acc += np.sum(batch, axis=0)
81
+ total += (end - start)
82
+
83
+ self.mean = torch.from_numpy(mean_acc / total).to(self.dtype)
84
+ else:
85
+ self.mean = torch.zeros(self.vector_size, dtype=self.dtype)
86
+
87
+ # Compute scaling factor if normalization is enabled
88
+ if self.target_norm != 0.0:
89
+ squared_norm_sum = 0.0
90
+ total = 0
91
+
92
+ for start in range(0, self.len, batch_size):
93
+ end = min(start + batch_size, self.len)
94
+ batch = self.data[start:end].copy()
95
+ # Center the batch if needed
96
+ batch = batch - self.mean.numpy()
97
+ squared_norm_sum += np.sum(np.square(batch))
98
+ total += (end - start)
99
+
100
+ avg_squared_norm = squared_norm_sum / total
101
+ self.scaling_factor = float(self.target_norm / np.sqrt(avg_squared_norm))
102
+ else:
103
+ self.scaling_factor = 1.0
104
+
105
+ def __len__(self):
106
+ """Return the number of samples in the dataset."""
107
+ return self.len
108
+
109
+ def process_data(self, data: torch.Tensor) -> torch.Tensor:
110
+ """
111
+ Process data for the autoencoder (subtract mean and apply scaling).
112
+
113
+ Args:
114
+ data (torch.Tensor): Input data tensor
115
+
116
+ Returns:
117
+ torch.Tensor: Processed data tensor
118
+ """
119
+ data.sub_(self.mean)
120
+ data.mul_(self.scaling_factor)
121
+
122
+ return data
123
+
124
+ def unprocess_data(self, data: torch.Tensor) -> torch.Tensor:
125
+ """
126
+ Reverse the processing of data (apply inverse scaling and add mean).
127
+
128
+ Args:
129
+ data (torch.Tensor): Input data tensor
130
+
131
+ Returns:
132
+ torch.Tensor: Unprocessed data tensor
133
+ """
134
+ data.div_(self.scaling_factor)
135
+ data.add_(self.mean)
136
+
137
+ return data
138
+
139
+ @torch.no_grad()
140
+ def __getitem__(self, idx):
141
+ """
142
+ Get a preprocessed data sample at the specified index.
143
+
144
+ Args:
145
+ idx (int): Index of the sample to retrieve
146
+
147
+ Returns:
148
+ torch.Tensor: Preprocessed data sample
149
+ """
150
+ torch_data = torch.tensor(self.data[idx])
151
+ output = self.process_data(torch_data.clone())
152
+ return output.to(self.dtype)
153
+
154
+
155
class LinearDecayLR(torch.optim.lr_scheduler.LambdaLR):
    """
    Keep the learning rate constant, then decay it linearly to zero.

    For the first ``decay_time`` fraction of ``total_epochs`` the LR multiplier
    is 1.0; afterwards it falls linearly, reaching 0.0 at ``total_epochs``.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose LR is scheduled
        total_epochs (int): Total number of training epochs
        decay_time (float, optional): Fraction of total epochs at constant LR.
            Defaults to 0.8 (80% of training).
        last_epoch (int, optional): The index of the last epoch. Defaults to -1.
    """
    def __init__(self, optimizer, total_epochs, decay_time=0.8, last_epoch=-1):
        # Hoist the invariants out of the per-step lambda.
        constant_epochs = int(decay_time * total_epochs)
        decay_span = (1 - decay_time) * total_epochs

        def scale(epoch):
            # Multiplier applied to each param group's base LR.
            if epoch < constant_epochs:
                return 1.0
            return max(0.0, (total_epochs - epoch) / decay_span)

        super().__init__(optimizer, scale, last_epoch)
176
+
177
+
178
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """
    Learning rate scheduler with linear warmup followed by cosine annealing.

    The LR climbs linearly from ``max_lr * final_lr_factor`` to ``max_lr``
    during the warmup epochs, then follows a cosine curve back down to
    ``max_lr * final_lr_factor`` over the remaining epochs.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to adjust
        max_lr (float): Maximum learning rate after warmup
        total_epochs (int): Total number of training epochs
        warmup_epoch (int, optional): Number of warmup epochs. Defaults to 1.
        final_lr_factor (float, optional): Ratio of final LR to max LR. Defaults to 0.1.
        last_epoch (int, optional): The index of the last epoch. Defaults to -1.
    """
    def __init__(self, optimizer, max_lr, total_epochs, warmup_epoch=1,
                 final_lr_factor=0.1, last_epoch=-1):
        self.max_lr = max_lr
        self.total_epochs = total_epochs
        self.warmup_epoch = warmup_epoch
        # Warmup starts and annealing ends at the same scaled-down LR.
        self.initial_lr = max_lr * final_lr_factor
        self.final_lr = max_lr * final_lr_factor
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """
        Calculate the learning rate for the current epoch.

        Returns:
            list: Learning rates for each parameter group
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        if self.last_epoch < self.warmup_epoch:
            # Warmup phase: linear ramp from initial_lr up to max_lr.
            fraction = self.last_epoch / self.warmup_epoch
            lr = self.initial_lr + (self.max_lr - self.initial_lr) * fraction
        else:
            # Annealing phase: cosine curve from max_lr down to final_lr,
            # with the epoch counter re-based to the end of warmup.
            progress = (self.last_epoch - self.warmup_epoch) / (self.total_epochs - self.warmup_epoch)
            weight = (1 + math.cos(math.pi * progress)) / 2
            lr = self.final_lr + (self.max_lr - self.final_lr) * weight

        # Same LR for every parameter group.
        return [lr for _ in self.base_lrs]
231
+
232
+
233
def set_seed(seed: int) -> None:
    """
    Seed every random number generator used in this project for reproducibility.

    Covers Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices).

    Args:
        seed (int): The seed value to use
    """
    # Same seeding order as before: python, numpy, torch CPU, torch CUDA.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
244
+
245
+
246
def get_device() -> torch.device:
    """
    Determine the best available device for PyTorch computation.

    Returns:
        torch.device: CUDA if available, Apple MPS as a fallback, CPU otherwise
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # hasattr guards older torch builds that lack the mps backend attribute.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
259
+
260
+
261
def normalize_data(x: torch.Tensor, eps: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Normalize input data to zero mean and unit variance along the last dimension.

    Args:
        x (torch.Tensor): Input tensor to normalize
        eps (float, optional): Small constant for numerical stability. Defaults to 1e-5.

    Returns:
        tuple: (normalized_data, mean, std)
            - normalized_data: Data normalized to zero mean and unit variance
            - mean: Mean of the original data (for denormalization)
            - std: Standard deviation of the centered data (for denormalization)
    """
    mu = x.mean(dim=-1, keepdim=True)
    centered = x - mu
    # std is taken after centering; eps keeps the division finite.
    std = centered.std(dim=-1, keepdim=True)
    return centered / (std + eps), mu, std
280
+
281
+
282
@torch.no_grad()
def geometric_median(dataset: torch.utils.data.Dataset, eps: float = 1e-5,
                     device: torch.device = torch.device("cpu"),
                     max_number: int = 925117, max_iter: int = 1000) -> torch.Tensor:
    """
    Compute the geometric median of a dataset using Weiszfeld's algorithm.

    The geometric median is a generalization of the median to multiple dimensions
    and is robust to outliers. This implementation uses iterative approximation
    with early stopping based on convergence.

    Args:
        dataset (torch.utils.data.Dataset): The dataset to compute median for.
            NOTE(review): it is indexed below with a whole tensor of indices at
            once, so it must support batched fancy indexing returning a 2D
            tensor — confirm for the dataset classes used with this function.
        eps (float, optional): Convergence threshold. Defaults to 1e-5.
        device (torch.device, optional): Computation device. Defaults to CPU.
        max_number (int, optional): Maximum number of samples to use. Defaults to 925117.
        max_iter (int, optional): Maximum number of iterations. Defaults to 1000.

    Returns:
        torch.Tensor: The geometric median vector
    """
    # Sample a random subset of the dataset if it's large (permute, take prefix)
    indices = torch.randperm(len(dataset))[:min(len(dataset), max_number)]
    X = dataset[indices]

    # Move data to device; on failure fall back to wherever the data already lives
    try:
        X = X.to(device)
    except Exception as e:
        warnings.warn(f"Error moving dataset to device: {device}, using default device {X.device}")

    # Initialize with the arithmetic mean (cheap, reasonable starting point)
    y = torch.mean(X, dim=0)
    progress_bar = tqdm(range(max_iter), desc="Geometric Median Iteration", leave=False)

    # Weiszfeld's algorithm: iteratively re-weight points by inverse distance
    for _ in progress_bar:
        # Distances from every sample to the current estimate
        D = torch.norm(X - y, dim=1)
        nonzeros = (D != 0)  # Avoid division by zero for points equal to y

        # Inverse-distance weights over the points distinct from the estimate
        Dinv = 1 / D[nonzeros]
        Dinv_sum = torch.sum(Dinv)
        W = Dinv / Dinv_sum

        # Weighted average of points — the plain Weiszfeld update
        T = torch.sum(W.view(-1, 1) * X[nonzeros], dim=0)

        # Handle the special case when some points equal the current estimate.
        # NOTE(review): this rescaling resembles the Vardi–Zhang modification of
        # Weiszfeld's method — confirm against the intended reference before
        # relying on its exact behavior at coincident points.
        num_zeros = len(X) - torch.sum(nonzeros)
        if num_zeros == 0:
            # No points equal the current estimate: take the plain update
            y1 = T
        else:
            # Some points equal the current estimate: rescale the update
            R = T * Dinv_sum / (Dinv_sum - num_zeros)
            r = torch.norm(R - y)
            progress_bar.set_postfix({"r": r.item()})
            if r < eps:
                return y
            y1 = R

        # Converged once the estimate moves less than eps between iterations
        if torch.norm(y - y1) < eps:
            return y1

        y = y1

    # Return best estimate after max iterations
    return y
353
+
354
+
355
def calculate_vector_mean(dataset: torch.utils.data.Dataset,
                          batch_size: int = 10000,
                          num_workers: int = 4) -> torch.Tensor:
    """
    Efficiently calculate the mean of vectors in a dataset.

    The dataset is streamed batch by batch through a DataLoader, so datasets
    larger than memory can still be averaged.

    Args:
        dataset (torch.utils.data.Dataset): Dataset containing vectors
        batch_size (int, optional): Batch size for processing. Defaults to 10000.
        num_workers (int, optional): Number of worker processes for data loading. Defaults to 4.

    Returns:
        torch.Tensor: Mean vector of the dataset
    """
    # Order doesn't matter for a mean, so no shuffling.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False
    )

    # Accumulate a running sum and the number of vectors seen.
    running_sum = torch.zeros_like(dataset[0])
    n_seen = 0

    for chunk in tqdm(loader, desc="Calculating Mean Vector", leave=False):
        running_sum += chunk.sum(dim=0)
        n_seen += chunk.size(0)

    return running_sum / n_seen
394
+
395
+
396
class RectangleFunction(torch.autograd.Function):
    """
    Indicator of the open interval (-0.5, 0.5) with a pass-through gradient.

    ``forward(x)`` is 1.0 where -0.5 < x < 0.5 and 0.0 elsewhere; ``backward``
    lets the incoming gradient through only inside that interval. Used as a
    building block for activations with custom threshold gradients.
    """
    @staticmethod
    def forward(ctx, x):
        """
        Return the 0/1 indicator of ``x`` lying strictly inside (-0.5, 0.5).

        Args:
            ctx: Context for saving variables for backward
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Output tensor with values in {0.0, 1.0}
        """
        ctx.save_for_backward(x)
        inside = (x > -0.5) & (x < 0.5)
        return inside.float()

    @staticmethod
    def backward(ctx, grad_output):
        """
        Propagate gradients only where the input was strictly inside the interval.

        Args:
            ctx: Context with saved variables
            grad_output (torch.Tensor): Gradient from subsequent layers

        Returns:
            torch.Tensor: Gradient with respect to input
        """
        (x,) = ctx.saved_tensors
        # Multiplying by the indicator zeroes the gradient outside (-0.5, 0.5).
        inside = (x > -0.5) & (x < 0.5)
        return grad_output * inside.to(grad_output.dtype)
436
+
437
+
438
class JumpReLUFunction(torch.autograd.Function):
    """
    ReLU-like activation that only passes values strictly above a learned threshold.

    The threshold is stored in log-space so it stays positive; its gradient is
    approximated with a rectangle window of the given bandwidth.
    """
    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        """
        Zero out entries of ``x`` that do not exceed exp(log_threshold).

        Args:
            ctx: Context for saving variables for backward
            x (torch.Tensor): Input tensor
            log_threshold (torch.Tensor): Log of the threshold value (learned parameter)
            bandwidth (float): Bandwidth parameter for gradient approximation

        Returns:
            torch.Tensor: Output tensor
        """
        ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
        threshold = log_threshold.exp()
        return x * (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute gradients for (x, log_threshold); bandwidth gets none.

        Args:
            ctx: Context with saved variables
            grad_output (torch.Tensor): Gradient from subsequent layers

        Returns:
            tuple: (input_gradient, threshold_gradient, None)
        """
        x, log_threshold, bandwidth_tensor = ctx.saved_tensors
        bandwidth = bandwidth_tensor.item()
        threshold = log_threshold.exp()

        # Pass-through gradient where the input cleared the threshold.
        x_grad = grad_output * (x > threshold).float()

        # Rectangle window approximates the Dirac delta at the threshold.
        window = RectangleFunction.apply((x - threshold) / bandwidth)
        threshold_grad = -(threshold / bandwidth) * window * grad_output

        return x_grad, threshold_grad, None  # bandwidth is not learnable
491
+
492
+
493
class StepFunction(torch.autograd.Function):
    """
    Heaviside step at a learned (log-space) threshold.

    Outputs 1.0 where the input strictly exceeds exp(log_threshold) and 0.0
    elsewhere; the threshold gradient is approximated with a rectangle window.
    """
    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        """
        Return the 0/1 indicator of ``x`` exceeding exp(log_threshold).

        Args:
            ctx: Context for saving variables for backward
            x (torch.Tensor): Input tensor
            log_threshold (torch.Tensor): Log of the threshold value (learned parameter)
            bandwidth (float): Bandwidth parameter for gradient approximation

        Returns:
            torch.Tensor: Binary output tensor with values in {0.0, 1.0}
        """
        ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
        return (x > log_threshold.exp()).float()

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute gradients for (x, log_threshold); bandwidth gets none.

        Args:
            ctx: Context with saved variables
            grad_output (torch.Tensor): Gradient from subsequent layers

        Returns:
            tuple: (input_gradient, threshold_gradient, None)
        """
        x, log_threshold, bandwidth_tensor = ctx.saved_tensors
        bandwidth = bandwidth_tensor.item()
        threshold = log_threshold.exp()

        # The step is flat almost everywhere, so x receives no gradient.
        x_grad = torch.zeros_like(x)

        # Rectangle window approximates the Dirac delta at the threshold.
        window = RectangleFunction.apply((x - threshold) / bandwidth)
        threshold_grad = -(1.0 / bandwidth) * window * grad_output

        return x_grad, threshold_grad, None  # bandwidth is not learnable
vocab/clip_disect_20k.txt ADDED
The diff for this file is too large to render. See raw diff