|
|
|
|
|
"""Quick test to verify the optimized model works correctly.""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) |
|
|
|
|
|
from src.model import RippleGPT |
|
|
from src.config import RippleConfig |
|
|
import torch |
|
|
|
|
|
def test_model():
    """Smoke-test the optimized model with forward passes at growing lengths.

    Builds a tiny model (2 layers, 64-dim, block_size=256) and runs it at
    100 tokens, 256 tokens (exactly block_size), and 512 tokens (2x the
    training context) to check that the forward pass works and that the
    model extrapolates beyond its configured context length.

    Returns:
        int: 0 on success, suitable for use as a process exit code.
    """
    print("🔧 Testando modelo otimizado...")

    # Tiny configuration so the smoke test runs quickly on CPU.
    config = RippleConfig(vocab_size=65, block_size=256, n_layer=2, n_head=2, n_embd=64)
    model = RippleGPT(config)
    # Disable dropout/batch-norm training behavior for a deterministic check.
    model.eval()

    # (sequence length, success-message template) — the 512-token case is
    # 2x block_size and exercises context-length extrapolation.
    cases = [
        (100, "✅ Forward pass OK - Shape: {shape}"),
        (256, "✅ Forward pass (256 tokens) OK - Shape: {shape}"),
        (512, "🔬 Forward pass (512 tokens - 2x!) OK - Shape: {shape}"),
    ]
    for seq_len, message in cases:
        x = torch.randint(0, config.vocab_size, (1, seq_len))
        with torch.no_grad():
            logits, _ = model(x)
        print(message.format(shape=logits.shape))

    print()
    print("✅ Modelo otimizado funcionando corretamente!")
    print("✅ Extrapolação para 2x contexto: SUCESSO")
    return 0
|
|
|
|
|
if __name__ == "__main__":
    # Use sys.exit rather than the bare `exit` builtin: `exit` is injected
    # by the `site` module as an interactive convenience and is not
    # guaranteed to exist in every runtime environment.
    sys.exit(test_model())
|
|
|