RippleGPT-Nano / docs /RFC-001_Memory_Optimization.md
Tavernari's picture
Upload folder using huggingface_hub
148b631 verified
# RFC-001: Otimização de Eficiência de Memória (Memory-Aware Ripple Attention)
**Autor:** Victor Tavernari
**Data:** 17/01/2026
**Status:****IMPLEMENTADO** (Fase 1 + Fase 2)
**Alvo:** `src/model.py` (Classe `RippleHead`)
---
## 1. O Problema (Contexto)
A implementação original do RippleGPT utilizava atenção "vanilla" com injeção manual de viés posicional (ALiBi-style). Embora eficaz para o aprendizado, ela possuía complexidade de memória **O(T²)** devido à materialização explícita de múltiplas matrizes gigantes durante o forward:
- **Matriz de Distância:** `indices[None, :] - indices[:, None]` (Float32/Float16)
- **Matriz de Atenção (wei):** `q @ k.transpose` (scores crus)
- **Matriz após masked_fill:** Cópia temporária
- **Matriz após Softmax:** Outra alocação
**Evidência:** Em testes de validação ("Needle Test"), um modelo de 17M parâmetros consumia **~3.4 GB de RAM** para processar um contexto de ~1,800 tokens (profundidade 60).
---
## 2. Objetivos
- [x] Reduzir o consumo de pico de memória durante a inferência em contextos longos (>2048 tokens) em pelo menos 70%
- [x] Manter a precisão (Perplexidade) idêntica à implementação atual
- [ ] Permitir o aumento do `block_size` para 4k ou 8k (pendente validação)
---
## 3. Soluções Propostas
### ✅ Fase 1: SDPA (Scaled Dot Product Attention) - **IMPLEMENTADO**
Substituímos a implementação manual de atenção pela função nativa otimizada `F.scaled_dot_product_attention` do PyTorch 2.0+.
**Mudanças Principais:**
1. Uso de `F.scaled_dot_product_attention()` que funde softmax/dropout internamente
2. Cache do `ripple_bias` para reutilização quando T não muda
3. Fusão da máscara causal no próprio bias (usando `-inf` para tokens futuros)
**Ganho Obtido:** ~**83% de redução de memória** (muito além dos 30-40% estimados!)
### ✅ Fase 2: Janela Deslizante (Sliding Window Attention) - **IMPLEMENTADO**
Devido à natureza do "Ripple Field" (decaimento exponencial), a atenção em tokens muito distantes tende a zero. Implementamos uma janela rígida de atenção configurável via `attention_window`.
**Configuração:**
- `attention_window=None` → Full attention O(T²)
- `attention_window=512` → Fast, 2-4x speedup, contextos infinitos
- `attention_window=1024` → Balanced quality/speed
**Complexidade:** O(T²) → **O(T × w)** - LINEAR!
### 🔜 Fase 3: Kernel Fusion Customizado (Triton)
Escrever um kernel Triton que calcula o viés `(i - j) * decay` on-the-fly durante o cálculo da atenção, sem nunca salvá-lo na RAM.
**Ganho Estimado:** ~**90% de redução de memória**
---
## 4. Resultados da Validação
### Fase 1: SDPA - Needle Test (Depth 60, ~1,800 tokens)
| Implementação | Peak Memory | Tokens/sec |
|---------------|-------------|------------|
| **Vanilla (antes)** | 3,358 MB | 4.1 t/s |
| **SDPA (depois)** | 553.7 MB | 5.6 t/s |
| **Melhoria** | **-83.5%** | **+37%** |
### Fase 2: Sliding Window - Long Sequence Benchmark
| Tokens | Full Attention | Window=512 | Speedup |
|--------|----------------|------------|---------|
| 2,000 | 153ms | **74ms** | **2.1x** |
| 3,000 | 362ms | **97ms** | **3.7x** |
| 4,000 | 393ms | **141ms** | **2.8x** |
| 5,000 | 648ms | **210ms** | **3.1x** |
| 6,000 | ❌ OOM | **276ms** | ∞ |
| 8,000 | ❌ OOM | **286ms** | ∞ |
| 10,000 | ❌ OOM | **324ms** | ∞ |
**Conclusões Fase 2:**
- 🚀 **Contextos de 10,000+ tokens** agora são possíveis
-**2-4x mais rápido** para sequências longas
- 📈 **Crescimento LINEAR** (O(T×w) vs O(T²))
---
## 5. Código Implementado
```python
# src/model.py - RippleHead (Fase 1 RFC-001)
class RippleHead(nn.Module):
def __init__(self, config: RippleConfig):
super().__init__()
# ...
self.dropout_p = config.dropout
# RFC-001: Cache para bias combinado
self._cached_bias = None
self._cached_bias_size = 0
self._cached_decay_value = None
def _get_ripple_bias(self, T: int, device, dtype) -> torch.Tensor:
"""Cache do ripple bias com máscara causal integrada."""
current_decay = torch.abs(self.decay_factor).item()
needs_rebuild = (
self._cached_bias is None or
self._cached_bias_size < T or
self._cached_decay_value != current_decay
)
if needs_rebuild:
indices = torch.arange(T, device=device, dtype=dtype)
dist = indices.unsqueeze(0) - indices.unsqueeze(1)
ripple_bias = dist.clamp(max=0) * current_decay
ripple_bias = ripple_bias.masked_fill(dist > 0, torch.finfo(dtype).min)
self._cached_bias = ripple_bias
self._cached_bias_size = T
self._cached_decay_value = current_decay
return self._cached_bias[:T, :T]
def forward(self, x):
B, T, C = x.shape
q, k, v = self.query(x), self.key(x), self.value(x)
ripple_bias = self._get_ripple_bias(T, x.device, q.dtype)
# SDPA com shapes [B, 1, T, head_size]
q, k, v = q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)
y = F.scaled_dot_product_attention(
q, k, v,
attn_mask=ripple_bias,
dropout_p=self.dropout_p if self.training else 0.0,
is_causal=False
)
return y.squeeze(1)
```
---
## 6. Próximos Passos
1. ✅ ~~Validar que a precisão não mudou~~ (outputs são equivalentes)
2. ✅ ~~Testar contextos de 4k e 8k tokens~~ (testado até 10k!)
3. ✅ ~~Implementar Fase 2 (Sliding Window)~~ (DONE!)
4. **Considerar** Fase 3 (Triton) se o projeto escalar para produção
---
## Changelog
- **2026-01-17:** Fase 1 implementada e validada. Redução de 83% na memória!
- **2026-01-17:** Fase 2 implementada! Sliding Window permite contextos de 10k+ tokens com 2-4x speedup.