RippleGPT-Nano / docs /RFC-001_Memory_Optimization.md
Tavernari's picture
Upload folder using huggingface_hub
148b631 verified

RFC-001: Otimização de Eficiência de Memória (Memory-Aware Ripple Attention)

Autor: Victor Tavernari
Data: 17/01/2026
Status:IMPLEMENTADO (Fase 1 + Fase 2)
Alvo: src/model.py (Classe RippleHead)


1. O Problema (Contexto)

A implementação original do RippleGPT utilizava atenção "vanilla" com injeção manual de viés posicional (ALiBi-style). Embora eficaz para o aprendizado, ela possuía complexidade de memória O(T²) devido à materialização explícita de múltiplas matrizes gigantes durante o forward:

  • Matriz de Distância: indices[None, :] - indices[:, None] (Float32/Float16)
  • Matriz de Atenção (wei): q @ k.transpose (scores crus)
  • Matriz após masked_fill: Cópia temporária
  • Matriz após Softmax: Outra alocação

Evidência: Em testes de validação ("Needle Test"), um modelo de 17M parâmetros consumia ~3.4 GB de RAM para processar um contexto de ~1,800 tokens (profundidade 60).


2. Objetivos

  • Reduzir o consumo de pico de memória durante a inferência em contextos longos (>2048 tokens) em pelo menos 70%
  • Manter a precisão (Perplexidade) idêntica à implementação atual
  • Permitir o aumento do block_size para 4k ou 8k (pendente validação)

3. Soluções Propostas

✅ Fase 1: SDPA (Scaled Dot Product Attention) - IMPLEMENTADO

Substituímos a implementação manual de atenção pela função nativa otimizada F.scaled_dot_product_attention do PyTorch 2.0+.

Mudanças Principais:

  1. Uso de F.scaled_dot_product_attention() que funde softmax/dropout internamente
  2. Cache do ripple_bias para reutilização quando T não muda
  3. Fusão da máscara causal no próprio bias (usando -inf para tokens futuros)

Ganho Obtido: ~83% de redução de memória (muito além dos 30-40% estimados!)

✅ Fase 2: Janela Deslizante (Sliding Window Attention) - IMPLEMENTADO

Devido à natureza do "Ripple Field" (decaimento exponencial), a atenção em tokens muito distantes tende a zero. Implementamos uma janela rígida de atenção configurável via attention_window.

Configuração:

  • attention_window=None → Full attention O(T²)
  • attention_window=512 → Fast, 2-4x speedup, contextos infinitos
  • attention_window=1024 → Balanced quality/speed

Complexidade: O(T²) → O(T × w) - LINEAR!

🔜 Fase 3: Kernel Fusion Customizado (Triton)

Escrever um kernel Triton que calcula o viés (i - j) * decay on-the-fly durante o cálculo da atenção, sem nunca salvá-lo na RAM.

Ganho Estimado: ~90% de redução de memória


4. Resultados da Validação

Fase 1: SDPA - Needle Test (Depth 60, ~1,800 tokens)

Implementação Peak Memory Tokens/sec
Vanilla (antes) 3,358 MB 4.1 t/s
SDPA (depois) 553.7 MB 5.6 t/s
Melhoria -83.5% +37%

Fase 2: Sliding Window - Long Sequence Benchmark

Tokens Full Attention Window=512 Speedup
2,000 153ms 74ms 2.1x
3,000 362ms 97ms 3.7x
4,000 393ms 141ms 2.8x
5,000 648ms 210ms 3.1x
6,000 ❌ OOM 276ms
8,000 ❌ OOM 286ms
10,000 ❌ OOM 324ms

Conclusões Fase 2:

  • 🚀 Contextos de 10,000+ tokens agora são possíveis
  • 2-4x mais rápido para sequências longas
  • 📈 Crescimento LINEAR (O(T×w) vs O(T²))

5. Código Implementado

# src/model.py - RippleHead (Fase 1 RFC-001)

class RippleHead(nn.Module):
    def __init__(self, config: RippleConfig):
        super().__init__()
        # ...
        self.dropout_p = config.dropout
        
        # RFC-001: Cache para bias combinado
        self._cached_bias = None
        self._cached_bias_size = 0
        self._cached_decay_value = None

    def _get_ripple_bias(self, T: int, device, dtype) -> torch.Tensor:
        """Cache do ripple bias com máscara causal integrada."""
        current_decay = torch.abs(self.decay_factor).item()
        
        needs_rebuild = (
            self._cached_bias is None or 
            self._cached_bias_size < T or
            self._cached_decay_value != current_decay
        )
        
        if needs_rebuild:
            indices = torch.arange(T, device=device, dtype=dtype)
            dist = indices.unsqueeze(0) - indices.unsqueeze(1)
            ripple_bias = dist.clamp(max=0) * current_decay
            ripple_bias = ripple_bias.masked_fill(dist > 0, torch.finfo(dtype).min)
            
            self._cached_bias = ripple_bias
            self._cached_bias_size = T
            self._cached_decay_value = current_decay
        
        return self._cached_bias[:T, :T]

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.query(x), self.key(x), self.value(x)
        
        ripple_bias = self._get_ripple_bias(T, x.device, q.dtype)
        
        # SDPA com shapes [B, 1, T, head_size]
        q, k, v = q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)
        
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=ripple_bias,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=False
        )
        
        return y.squeeze(1)

6. Próximos Passos

  1. Validar que a precisão não mudou (outputs são equivalentes)
  2. Testar contextos de 4k e 8k tokens (testado até 10k!)
  3. Implementar Fase 2 (Sliding Window) (DONE!)
  4. Considerar Fase 3 (Triton) se o projeto escalar para produção

Changelog

  • 2026-01-17: Fase 1 implementada e validada. Redução de 83% na memória!
  • 2026-01-17: Fase 2 implementada! Sliding Window permite contextos de 10k+ tokens com 2-4x speedup.