YAML Metadata Warning:empty or missing yaml metadata in repo card
Check out the documentation for more information.
Resultados de Entrenamiento SAE - Llama 3.2 1B
Este directorio contiene los resultados y artefactos del entrenamiento de Sparse Autoencoder (SAE) sobre el modelo Llama 3.2 1B.
Estructura del Directorio
prevalences/- Datos de prevalencia de activaciones neuronales guardados durante el entrenamiento a intervalos regularesruns/- Logs de TensorBoard y datos de ejecución del entrenamiento*.ptharchivos - Checkpoints del modelo guardados en varios pasos de entrenamientosae_*.pth- Pesos finales del modelo SAE entrenado
Experimentos de Inicialización
Este repositorio contiene dos experimentos comparando diferentes esquemas de inicialización de pesos:
Rama Principal (main): Inicialización d_sae_std
- Varianza por entrada:
1/d_sae - Norma de columnas del decoder:
√(d_model/d_sae) ≈ 0.204 - Comportamiento: Entrenamiento estable desde el paso 0, sin necesidad de warmup especial
- Ventajas: Estabilidad inmediata, 0% features muertas, distribución de prevalencia unimodal
- Resultados: L0 se estabiliza rápidamente (~3300→<500), pérdida de reconstrucción se estabiliza en ~0.19
Rama d-model-std-normalization: Inicialización d_model_std (estilo Gemmascope)
- Varianza por entrada:
1/d_model - Norma de columnas del decoder:
1(renormalizadas cada paso) - Comportamiento: Oscilaciones grandes en L0, ráfagas de features muertas
- Ventajas: Columnas de norma unitaria para interpretabilidad
- Resultados: L0 con oscilaciones (~2400→<500), pérdida de reconstrucción sube a 0.21
Conclusión: Aunque ambos experimentos convergen a resultados similares, encontramos que la inicialización d_sae_std ofrece mejor estabilidad de entrenamiento sin trucos adicionales, mientras que d_model_std logra estabilizarse pero con un comportamiento más errático inicialmente.
Configuración del Entrenamiento
- Modelo: Llama 3.2 1B
- Capa objetivo: Salida de la MLP intermedia (capa 8)
- Dimensión de entrada: 2048 (d_in)
- Factor de expansión: 24 (d_sae = 49,152)
- Pasos de entrenamiento: 256,000
- Coeficiente de sparsity: 0.001
- Learning rate: 7e-5
- Warmup: Warmup completo de sparsity durante toda la duración del entrenamiento
- Tipo de datos: Parámetros del SAE en fp32, entrenamiento mixto con autocast bf16
Arquitectura del SAE
Parámetros del Modelo
- Encoder:
nn.Linear(2048, 49152, dtype=torch.float32)con bias - Decoder:
nn.Linear(49152, 2048, dtype=torch.float32)con bias - log_threshold:
nn.Parameterde forma (49152,) inicializado con log(0.001) - Inicialización: Pesos compartidos entre encoder y decoder (decoder.weight = encoder.weight.T), vectores del diccionario normalizados
Forward Pass Detallado
def forward(self, x):
# x: (batch_size, 2048) - salidas de MLP de Llama
d = {}
original_input = x
# 1. Pre-procesamiento (centrado)
if self.use_pre_enc_bias:
x = x - self.dec.bias # (batch_size, 2048)
# 2. Encoding lineal
x = self.enc(x) # (batch_size, 49152)
# 3. Thresholding con función Step personalizada
threshold = torch.exp(self.log_threshold) # (49152,)
s = Step.apply(x, threshold) # (batch_size, 49152) - máscara binaria
# 4. Aplicar sparsity
x = x * s # (batch_size, 49152) - activaciones sparse
# 5. Decoding
x = self.dec(x) # (batch_size, 2048) - reconstrucción
# 6. Calcular métricas
d['mask'] = s # máscara de activaciones activas
d['reconstruction'] = ((x - original_input).pow(2)).mean(0).sum() # MSE loss
return d
Función Step Personalizada
- Forward:
(x > threshold).to(x.dtype)- función escalón binaria - Backward: Gradiente aproximado con bandwidth=0.001 para threshold aprendible
- Propósito: Permitir gradientes a través del thresholding para entrenar los thresholds
Archivos de Datos
Checkpoints del Modelo
step_*.pth- Checkpoints guardados cada 5,000 pasos después del paso 40,000sae_exp24_sparse0.001_d_sae_std_fullwarmup_steps256000_lr7e-05.pth- Modelo final entrenado
Datos de Prevalencia
prevalences/step_*.npy- Datos de prevalencia de activaciones para análisis de neuronas muertas- Forma: (49152,) - prevalencia promedio de cada feature del SAE
- Intervalos: Guardados cada 5,000 pasos durante 2,000 batches de evaluación
- Contenido: Fracción de ejemplos donde cada feature se activa (s > 0)
prevalences/bin_edges.npy- Bordes de bins del histograma para análisis de prevalencia- Forma: (101,) - bordes logarítmicos desde 1e-8 hasta 100
Datos de TensorBoard
Los logs contienen métricas de entrenamiento incluyendo:
- Pérdidas:
- Reconstruction loss: MSE entre entrada original y reconstrucción
- L0 loss: Suma de activaciones activas (sparsity penalty)
- Pérdida total: reconstruction_loss + sparsity_coefficient * l0
- Learning rate: Programa de coseno con warmup lineal (10% a 100% en 2000 pasos)
- Coeficiente de sparsity: Programa de warmup lineal hasta valor máximo durante entrenamiento completo
- Prevalencias: Histogramas de activación de features cada 5,000 pasos
- Porcentaje de neuronas muertas: Features con prevalencia < 1e-7
- Productos punto de features: Visualización de similaridad entre vectores del diccionario (top 256 features)
Instalación para Profiling
Para ver los datos de profiling en TensorBoard:
pip install torch-tb-profiler tensorboard-plugin-profile
Uso
Para cargar un modelo SAE entrenado:
import torch
from torch import nn
from math import sqrt
class Sae(nn.Module):
# ... implementación completa necesaria ...
# Cargar modelo
model = Sae(d_in=2048, d_sae=49152)
model.load_state_dict(torch.load('sae_exp24_sparse0.001_d_sae_std_fullwarmup_steps256000_lr7e-05.pth'))
model.eval()
# Compilar para rendimiento óptimo
model = torch.compile(model, mode="max-autotune")