# RFC-001: Otimização de Eficiência de Memória (Memory-Aware Ripple Attention) **Autor:** Victor Tavernari **Data:** 17/01/2026 **Status:** ✅ **IMPLEMENTADO** (Fase 1 + Fase 2) **Alvo:** `src/model.py` (Classe `RippleHead`) --- ## 1. O Problema (Contexto) A implementação original do RippleGPT utilizava atenção "vanilla" com injeção manual de viés posicional (ALiBi-style). Embora eficaz para o aprendizado, ela possuía complexidade de memória **O(T²)** devido à materialização explícita de múltiplas matrizes gigantes durante o forward: - **Matriz de Distância:** `indices[None, :] - indices[:, None]` (Float32/Float16) - **Matriz de Atenção (wei):** `q @ k.transpose` (scores crus) - **Matriz após masked_fill:** Cópia temporária - **Matriz após Softmax:** Outra alocação **Evidência:** Em testes de validação ("Needle Test"), um modelo de 17M parâmetros consumia **~3.4 GB de RAM** para processar um contexto de ~1,800 tokens (profundidade 60). --- ## 2. Objetivos - [x] Reduzir o consumo de pico de memória durante a inferência em contextos longos (>2048 tokens) em pelo menos 70% - [x] Manter a precisão (Perplexidade) idêntica à implementação atual - [ ] Permitir o aumento do `block_size` para 4k ou 8k (pendente validação) --- ## 3. Soluções Propostas ### ✅ Fase 1: SDPA (Scaled Dot Product Attention) - **IMPLEMENTADO** Substituímos a implementação manual de atenção pela função nativa otimizada `F.scaled_dot_product_attention` do PyTorch 2.0+. **Mudanças Principais:** 1. Uso de `F.scaled_dot_product_attention()` que funde softmax/dropout internamente 2. Cache do `ripple_bias` para reutilização quando T não muda 3. Fusão da máscara causal no próprio bias (usando `-inf` para tokens futuros) **Ganho Obtido:** ~**83% de redução de memória** (muito além dos 30-40% estimados!) ### ✅ Fase 2: Janela Deslizante (Sliding Window Attention) - **IMPLEMENTADO** Devido à natureza do "Ripple Field" (decaimento exponencial), a atenção em tokens muito distantes tende a zero. Implementamos uma janela rígida de atenção configurável via `attention_window`. **Configuração:** - `attention_window=None` → Full attention O(T²) - `attention_window=512` → Fast, 2-4x speedup, contextos infinitos - `attention_window=1024` → Balanced quality/speed **Complexidade:** O(T²) → **O(T × w)** - LINEAR! ### 🔜 Fase 3: Kernel Fusion Customizado (Triton) Escrever um kernel Triton que calcula o viés `(i - j) * decay` on-the-fly durante o cálculo da atenção, sem nunca salvá-lo na RAM. **Ganho Estimado:** ~**90% de redução de memória** --- ## 4. Resultados da Validação ### Fase 1: SDPA - Needle Test (Depth 60, ~1,800 tokens) | Implementação | Peak Memory | Tokens/sec | |---------------|-------------|------------| | **Vanilla (antes)** | 3,358 MB | 4.1 t/s | | **SDPA (depois)** | 553.7 MB | 5.6 t/s | | **Melhoria** | **-83.5%** | **+37%** | ### Fase 2: Sliding Window - Long Sequence Benchmark | Tokens | Full Attention | Window=512 | Speedup | |--------|----------------|------------|---------| | 2,000 | 153ms | **74ms** | **2.1x** | | 3,000 | 362ms | **97ms** | **3.7x** | | 4,000 | 393ms | **141ms** | **2.8x** | | 5,000 | 648ms | **210ms** | **3.1x** | | 6,000 | ❌ OOM | **276ms** | ∞ | | 8,000 | ❌ OOM | **286ms** | ∞ | | 10,000 | ❌ OOM | **324ms** | ∞ | **Conclusões Fase 2:** - 🚀 **Contextos de 10,000+ tokens** agora são possíveis - ⚡ **2-4x mais rápido** para sequências longas - 📈 **Crescimento LINEAR** (O(T×w) vs O(T²)) --- ## 5. Código Implementado ```python # src/model.py - RippleHead (Fase 1 RFC-001) class RippleHead(nn.Module): def __init__(self, config: RippleConfig): super().__init__() # ... self.dropout_p = config.dropout # RFC-001: Cache para bias combinado self._cached_bias = None self._cached_bias_size = 0 self._cached_decay_value = None def _get_ripple_bias(self, T: int, device, dtype) -> torch.Tensor: """Cache do ripple bias com máscara causal integrada.""" current_decay = torch.abs(self.decay_factor).item() needs_rebuild = ( self._cached_bias is None or self._cached_bias_size < T or self._cached_decay_value != current_decay ) if needs_rebuild: indices = torch.arange(T, device=device, dtype=dtype) dist = indices.unsqueeze(0) - indices.unsqueeze(1) ripple_bias = dist.clamp(max=0) * current_decay ripple_bias = ripple_bias.masked_fill(dist > 0, torch.finfo(dtype).min) self._cached_bias = ripple_bias self._cached_bias_size = T self._cached_decay_value = current_decay return self._cached_bias[:T, :T] def forward(self, x): B, T, C = x.shape q, k, v = self.query(x), self.key(x), self.value(x) ripple_bias = self._get_ripple_bias(T, x.device, q.dtype) # SDPA com shapes [B, 1, T, head_size] q, k, v = q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1) y = F.scaled_dot_product_attention( q, k, v, attn_mask=ripple_bias, dropout_p=self.dropout_p if self.training else 0.0, is_causal=False ) return y.squeeze(1) ``` --- ## 6. Próximos Passos 1. ✅ ~~Validar que a precisão não mudou~~ (outputs são equivalentes) 2. ✅ ~~Testar contextos de 4k e 8k tokens~~ (testado até 10k!) 3. ✅ ~~Implementar Fase 2 (Sliding Window)~~ (DONE!) 4. **Considerar** Fase 3 (Triton) se o projeto escalar para produção --- ## Changelog - **2026-01-17:** Fase 1 implementada e validada. Redução de 83% na memória! - **2026-01-17:** Fase 2 implementada! Sliding Window permite contextos de 10k+ tokens com 2-4x speedup.