RFC-001: Otimização de Eficiência de Memória (Memory-Aware Ripple Attention)
Autor: Victor Tavernari
Data: 17/01/2026
Status: ✅ IMPLEMENTADO (Fase 1 + Fase 2)
Alvo: src/model.py (Classe RippleHead)
1. O Problema (Contexto)
A implementação original do RippleGPT utilizava atenção "vanilla" com injeção manual de viés posicional (ALiBi-style). Embora eficaz para o aprendizado, ela possuía complexidade de memória O(T²) devido à materialização explícita de múltiplas matrizes gigantes durante o forward:
- Matriz de Distância:
indices[None, :] - indices[:, None](Float32/Float16) - Matriz de Atenção (wei):
q @ k.transpose(scores crus) - Matriz após masked_fill: Cópia temporária
- Matriz após Softmax: Outra alocação
Evidência: Em testes de validação ("Needle Test"), um modelo de 17M parâmetros consumia ~3.4 GB de RAM para processar um contexto de ~1,800 tokens (profundidade 60).
2. Objetivos
- Reduzir o consumo de pico de memória durante a inferência em contextos longos (>2048 tokens) em pelo menos 70%
- Manter a precisão (Perplexidade) idêntica à implementação atual
- Permitir o aumento do
block_sizepara 4k ou 8k (pendente validação)
3. Soluções Propostas
✅ Fase 1: SDPA (Scaled Dot Product Attention) - IMPLEMENTADO
Substituímos a implementação manual de atenção pela função nativa otimizada F.scaled_dot_product_attention do PyTorch 2.0+.
Mudanças Principais:
- Uso de
F.scaled_dot_product_attention()que funde softmax/dropout internamente - Cache do
ripple_biaspara reutilização quando T não muda - Fusão da máscara causal no próprio bias (usando
-infpara tokens futuros)
Ganho Obtido: ~83% de redução de memória (muito além dos 30-40% estimados!)
✅ Fase 2: Janela Deslizante (Sliding Window Attention) - IMPLEMENTADO
Devido à natureza do "Ripple Field" (decaimento exponencial), a atenção em tokens muito distantes tende a zero. Implementamos uma janela rígida de atenção configurável via attention_window.
Configuração:
attention_window=None→ Full attention O(T²)attention_window=512→ Fast, 2-4x speedup, contextos infinitosattention_window=1024→ Balanced quality/speed
Complexidade: O(T²) → O(T × w) - LINEAR!
🔜 Fase 3: Kernel Fusion Customizado (Triton)
Escrever um kernel Triton que calcula o viés (i - j) * decay on-the-fly durante o cálculo da atenção, sem nunca salvá-lo na RAM.
Ganho Estimado: ~90% de redução de memória
4. Resultados da Validação
Fase 1: SDPA - Needle Test (Depth 60, ~1,800 tokens)
| Implementação | Peak Memory | Tokens/sec |
|---|---|---|
| Vanilla (antes) | 3,358 MB | 4.1 t/s |
| SDPA (depois) | 553.7 MB | 5.6 t/s |
| Melhoria | -83.5% | +37% |
Fase 2: Sliding Window - Long Sequence Benchmark
| Tokens | Full Attention | Window=512 | Speedup |
|---|---|---|---|
| 2,000 | 153ms | 74ms | 2.1x |
| 3,000 | 362ms | 97ms | 3.7x |
| 4,000 | 393ms | 141ms | 2.8x |
| 5,000 | 648ms | 210ms | 3.1x |
| 6,000 | ❌ OOM | 276ms | ∞ |
| 8,000 | ❌ OOM | 286ms | ∞ |
| 10,000 | ❌ OOM | 324ms | ∞ |
Conclusões Fase 2:
- 🚀 Contextos de 10,000+ tokens agora são possíveis
- ⚡ 2-4x mais rápido para sequências longas
- 📈 Crescimento LINEAR (O(T×w) vs O(T²))
5. Código Implementado
# src/model.py - RippleHead (Fase 1 RFC-001)
class RippleHead(nn.Module):
def __init__(self, config: RippleConfig):
super().__init__()
# ...
self.dropout_p = config.dropout
# RFC-001: Cache para bias combinado
self._cached_bias = None
self._cached_bias_size = 0
self._cached_decay_value = None
def _get_ripple_bias(self, T: int, device, dtype) -> torch.Tensor:
"""Cache do ripple bias com máscara causal integrada."""
current_decay = torch.abs(self.decay_factor).item()
needs_rebuild = (
self._cached_bias is None or
self._cached_bias_size < T or
self._cached_decay_value != current_decay
)
if needs_rebuild:
indices = torch.arange(T, device=device, dtype=dtype)
dist = indices.unsqueeze(0) - indices.unsqueeze(1)
ripple_bias = dist.clamp(max=0) * current_decay
ripple_bias = ripple_bias.masked_fill(dist > 0, torch.finfo(dtype).min)
self._cached_bias = ripple_bias
self._cached_bias_size = T
self._cached_decay_value = current_decay
return self._cached_bias[:T, :T]
def forward(self, x):
B, T, C = x.shape
q, k, v = self.query(x), self.key(x), self.value(x)
ripple_bias = self._get_ripple_bias(T, x.device, q.dtype)
# SDPA com shapes [B, 1, T, head_size]
q, k, v = q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)
y = F.scaled_dot_product_attention(
q, k, v,
attn_mask=ripple_bias,
dropout_p=self.dropout_p if self.training else 0.0,
is_causal=False
)
return y.squeeze(1)
6. Próximos Passos
- ✅
Validar que a precisão não mudou(outputs são equivalentes) - ✅
Testar contextos de 4k e 8k tokens(testado até 10k!) - ✅
Implementar Fase 2 (Sliding Window)(DONE!) - Considerar Fase 3 (Triton) se o projeto escalar para produção
Changelog
- 2026-01-17: Fase 1 implementada e validada. Redução de 83% na memória!
- 2026-01-17: Fase 2 implementada! Sliding Window permite contextos de 10k+ tokens com 2-4x speedup.