| # RFC-001: Otimização de Eficiência de Memória (Memory-Aware Ripple Attention) | |
| **Autor:** Victor Tavernari | |
| **Data:** 17/01/2026 | |
| **Status:** ✅ **IMPLEMENTADO** (Fase 1 + Fase 2) | |
| **Alvo:** `src/model.py` (Classe `RippleHead`) | |
| --- | |
| ## 1. O Problema (Contexto) | |
| A implementação original do RippleGPT utilizava atenção "vanilla" com injeção manual de viés posicional (ALiBi-style). Embora eficaz para o aprendizado, ela possuía complexidade de memória **O(T²)** devido à materialização explícita de múltiplas matrizes gigantes durante o forward: | |
| - **Matriz de Distância:** `indices[None, :] - indices[:, None]` (Float32/Float16) | |
| - **Matriz de Atenção (wei):** `q @ k.transpose` (scores crus) | |
| - **Matriz após masked_fill:** Cópia temporária | |
| - **Matriz após Softmax:** Outra alocação | |
| **Evidência:** Em testes de validação ("Needle Test"), um modelo de 17M parâmetros consumia **~3.4 GB de RAM** para processar um contexto de ~1,800 tokens (profundidade 60). | |
| --- | |
| ## 2. Objetivos | |
| - [x] Reduzir o consumo de pico de memória durante a inferência em contextos longos (>2048 tokens) em pelo menos 70% | |
| - [x] Manter a precisão (Perplexidade) idêntica à implementação atual | |
| - [ ] Permitir o aumento do `block_size` para 4k ou 8k (pendente validação) | |
| --- | |
| ## 3. Soluções Propostas | |
| ### ✅ Fase 1: SDPA (Scaled Dot Product Attention) - **IMPLEMENTADO** | |
| Substituímos a implementação manual de atenção pela função nativa otimizada `F.scaled_dot_product_attention` do PyTorch 2.0+. | |
| **Mudanças Principais:** | |
| 1. Uso de `F.scaled_dot_product_attention()` que funde softmax/dropout internamente | |
| 2. Cache do `ripple_bias` para reutilização quando T não muda | |
| 3. Fusão da máscara causal no próprio bias (usando `-inf` para tokens futuros) | |
| **Ganho Obtido:** ~**83% de redução de memória** (muito além dos 30-40% estimados!) | |
| ### ✅ Fase 2: Janela Deslizante (Sliding Window Attention) - **IMPLEMENTADO** | |
| Devido à natureza do "Ripple Field" (decaimento exponencial), a atenção em tokens muito distantes tende a zero. Implementamos uma janela rígida de atenção configurável via `attention_window`. | |
| **Configuração:** | |
| - `attention_window=None` → Full attention O(T²) | |
| - `attention_window=512` → Fast, 2-4x speedup, contextos infinitos | |
| - `attention_window=1024` → Balanced quality/speed | |
| **Complexidade:** O(T²) → **O(T × w)** - LINEAR! | |
| ### 🔜 Fase 3: Kernel Fusion Customizado (Triton) | |
| Escrever um kernel Triton que calcula o viés `(i - j) * decay` on-the-fly durante o cálculo da atenção, sem nunca salvá-lo na RAM. | |
| **Ganho Estimado:** ~**90% de redução de memória** | |
| --- | |
| ## 4. Resultados da Validação | |
| ### Fase 1: SDPA - Needle Test (Depth 60, ~1,800 tokens) | |
| | Implementação | Peak Memory | Tokens/sec | | |
| |---------------|-------------|------------| | |
| | **Vanilla (antes)** | 3,358 MB | 4.1 t/s | | |
| | **SDPA (depois)** | 553.7 MB | 5.6 t/s | | |
| | **Melhoria** | **-83.5%** | **+37%** | | |
| ### Fase 2: Sliding Window - Long Sequence Benchmark | |
| | Tokens | Full Attention | Window=512 | Speedup | | |
| |--------|----------------|------------|---------| | |
| | 2,000 | 153ms | **74ms** | **2.1x** | | |
| | 3,000 | 362ms | **97ms** | **3.7x** | | |
| | 4,000 | 393ms | **141ms** | **2.8x** | | |
| | 5,000 | 648ms | **210ms** | **3.1x** | | |
| | 6,000 | ❌ OOM | **276ms** | ∞ | | |
| | 8,000 | ❌ OOM | **286ms** | ∞ | | |
| | 10,000 | ❌ OOM | **324ms** | ∞ | | |
| **Conclusões Fase 2:** | |
| - 🚀 **Contextos de 10,000+ tokens** agora são possíveis | |
| - ⚡ **2-4x mais rápido** para sequências longas | |
| - 📈 **Crescimento LINEAR** (O(T×w) vs O(T²)) | |
| --- | |
| ## 5. Código Implementado | |
| ```python | |
| # src/model.py - RippleHead (Fase 1 RFC-001) | |
| class RippleHead(nn.Module): | |
| def __init__(self, config: RippleConfig): | |
| super().__init__() | |
| # ... | |
| self.dropout_p = config.dropout | |
| # RFC-001: Cache para bias combinado | |
| self._cached_bias = None | |
| self._cached_bias_size = 0 | |
| self._cached_decay_value = None | |
| def _get_ripple_bias(self, T: int, device, dtype) -> torch.Tensor: | |
| """Cache do ripple bias com máscara causal integrada.""" | |
| current_decay = torch.abs(self.decay_factor).item() | |
| needs_rebuild = ( | |
| self._cached_bias is None or | |
| self._cached_bias_size < T or | |
| self._cached_decay_value != current_decay | |
| ) | |
| if needs_rebuild: | |
| indices = torch.arange(T, device=device, dtype=dtype) | |
| dist = indices.unsqueeze(0) - indices.unsqueeze(1) | |
| ripple_bias = dist.clamp(max=0) * current_decay | |
| ripple_bias = ripple_bias.masked_fill(dist > 0, torch.finfo(dtype).min) | |
| self._cached_bias = ripple_bias | |
| self._cached_bias_size = T | |
| self._cached_decay_value = current_decay | |
| return self._cached_bias[:T, :T] | |
| def forward(self, x): | |
| B, T, C = x.shape | |
| q, k, v = self.query(x), self.key(x), self.value(x) | |
| ripple_bias = self._get_ripple_bias(T, x.device, q.dtype) | |
| # SDPA com shapes [B, 1, T, head_size] | |
| q, k, v = q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1) | |
| y = F.scaled_dot_product_attention( | |
| q, k, v, | |
| attn_mask=ripple_bias, | |
| dropout_p=self.dropout_p if self.training else 0.0, | |
| is_causal=False | |
| ) | |
| return y.squeeze(1) | |
| ``` | |
| --- | |
| ## 6. Próximos Passos | |
| 1. ✅ ~~Validar que a precisão não mudou~~ (outputs são equivalentes) | |
| 2. ✅ ~~Testar contextos de 4k e 8k tokens~~ (testado até 10k!) | |
| 3. ✅ ~~Implementar Fase 2 (Sliding Window)~~ (DONE!) | |
| 4. **Considerar** Fase 3 (Triton) se o projeto escalar para produção | |
| --- | |
| ## Changelog | |
| - **2026-01-17:** Fase 1 implementada e validada. Redução de 83% na memória! | |
| - **2026-01-17:** Fase 2 implementada! Sliding Window permite contextos de 10k+ tokens com 2-4x speedup. | |