YAML Metadata Warning:empty or missing yaml metadata in repo card

Check out the documentation for more information.

Resultados de Entrenamiento SAE - Llama 3.2 1B

Este directorio contiene los resultados y artefactos del entrenamiento de Sparse Autoencoder (SAE) sobre el modelo Llama 3.2 1B.

Estructura del Directorio

  • prevalences/ - Datos de prevalencia de activaciones neuronales guardados durante el entrenamiento a intervalos regulares
  • runs/ - Logs de TensorBoard y datos de ejecución del entrenamiento
  • *.pth archivos - Checkpoints del modelo guardados en varios pasos de entrenamiento
  • sae_*.pth - Pesos finales del modelo SAE entrenado

Experimentos de Inicialización

Este repositorio contiene dos experimentos comparando diferentes esquemas de inicialización de pesos:

Rama Principal (main): Inicialización d_sae_std

  • Varianza por entrada: 1/d_sae
  • Norma de columnas del decoder: √(d_model/d_sae) ≈ 0.204
  • Comportamiento: Entrenamiento estable desde el paso 0, sin necesidad de warmup especial
  • Ventajas: Estabilidad inmediata, 0% features muertas, distribución de prevalencia unimodal
  • Resultados: L0 se estabiliza rápidamente (~3300→<500), pérdida de reconstrucción se estabiliza en ~0.19

Rama d-model-std-normalization: Inicialización d_model_std (estilo Gemmascope)

  • Varianza por entrada: 1/d_model
  • Norma de columnas del decoder: 1 (renormalizadas cada paso)
  • Comportamiento: Oscilaciones grandes en L0, ráfagas de features muertas
  • Ventajas: Columnas de norma unitaria para interpretabilidad
  • Resultados: L0 con oscilaciones (~2400→<500), pérdida de reconstrucción sube a 0.21

Conclusión: Aunque ambos experimentos convergen a resultados similares, encontramos que la inicialización d_sae_std ofrece mejor estabilidad de entrenamiento sin trucos adicionales, mientras que d_model_std logra estabilizarse pero con un comportamiento más errático inicialmente.

Configuración del Entrenamiento

  • Modelo: Llama 3.2 1B
  • Capa objetivo: Salida de la MLP intermedia (capa 8)
  • Dimensión de entrada: 2048 (d_in)
  • Factor de expansión: 24 (d_sae = 49,152)
  • Pasos de entrenamiento: 256,000
  • Coeficiente de sparsity: 0.001
  • Learning rate: 7e-5
  • Warmup: Warmup completo de sparsity durante toda la duración del entrenamiento
  • Tipo de datos: Parámetros del SAE en fp32, entrenamiento mixto con autocast bf16

Arquitectura del SAE

Parámetros del Modelo

  • Encoder: nn.Linear(2048, 49152, dtype=torch.float32) con bias
  • Decoder: nn.Linear(49152, 2048, dtype=torch.float32) con bias
  • log_threshold: nn.Parameter de forma (49152,) inicializado con log(0.001)
  • Inicialización: Pesos compartidos entre encoder y decoder (decoder.weight = encoder.weight.T), vectores del diccionario normalizados

Forward Pass Detallado

def forward(self, x):
    # x: (batch_size, 2048) - salidas de MLP de Llama
    d = {}
    original_input = x
    
    # 1. Pre-procesamiento (centrado)
    if self.use_pre_enc_bias:
        x = x - self.dec.bias  # (batch_size, 2048)
    
    # 2. Encoding lineal
    x = self.enc(x)  # (batch_size, 49152)
    
    # 3. Thresholding con función Step personalizada
    threshold = torch.exp(self.log_threshold)  # (49152,)
    s = Step.apply(x, threshold)  # (batch_size, 49152) - máscara binaria
    
    # 4. Aplicar sparsity
    x = x * s  # (batch_size, 49152) - activaciones sparse
    
    # 5. Decoding
    x = self.dec(x)  # (batch_size, 2048) - reconstrucción
    
    # 6. Calcular métricas
    d['mask'] = s  # máscara de activaciones activas
    d['reconstruction'] = ((x - original_input).pow(2)).mean(0).sum()  # MSE loss
    
    return d

Función Step Personalizada

  • Forward: (x > threshold).to(x.dtype) - función escalón binaria
  • Backward: Gradiente aproximado con bandwidth=0.001 para threshold aprendible
  • Propósito: Permitir gradientes a través del thresholding para entrenar los thresholds

Archivos de Datos

Checkpoints del Modelo

  • step_*.pth - Checkpoints guardados cada 5,000 pasos después del paso 40,000
  • sae_exp24_sparse0.001_d_sae_std_fullwarmup_steps256000_lr7e-05.pth - Modelo final entrenado

Datos de Prevalencia

  • prevalences/step_*.npy - Datos de prevalencia de activaciones para análisis de neuronas muertas
    • Forma: (49152,) - prevalencia promedio de cada feature del SAE
    • Intervalos: Guardados cada 5,000 pasos durante 2,000 batches de evaluación
    • Contenido: Fracción de ejemplos donde cada feature se activa (s > 0)
  • prevalences/bin_edges.npy - Bordes de bins del histograma para análisis de prevalencia
    • Forma: (101,) - bordes logarítmicos desde 1e-8 hasta 100

Datos de TensorBoard

Los logs contienen métricas de entrenamiento incluyendo:

  • Pérdidas:
    • Reconstruction loss: MSE entre entrada original y reconstrucción
    • L0 loss: Suma de activaciones activas (sparsity penalty)
    • Pérdida total: reconstruction_loss + sparsity_coefficient * l0
  • Learning rate: Programa de coseno con warmup lineal (10% a 100% en 2000 pasos)
  • Coeficiente de sparsity: Programa de warmup lineal hasta valor máximo durante entrenamiento completo
  • Prevalencias: Histogramas de activación de features cada 5,000 pasos
  • Porcentaje de neuronas muertas: Features con prevalencia < 1e-7
  • Productos punto de features: Visualización de similaridad entre vectores del diccionario (top 256 features)

Instalación para Profiling

Para ver los datos de profiling en TensorBoard:

pip install torch-tb-profiler tensorboard-plugin-profile

Uso

Para cargar un modelo SAE entrenado:

import torch
from torch import nn
from math import sqrt

class Sae(nn.Module):
    # ... implementación completa necesaria ...

# Cargar modelo
model = Sae(d_in=2048, d_sae=49152)
model.load_state_dict(torch.load('sae_exp24_sparse0.001_d_sae_std_fullwarmup_steps256000_lr7e-05.pth'))
model.eval()

# Compilar para rendimiento óptimo
model = torch.compile(model, mode="max-autotune")
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support