File size: 4,267 Bytes
b47957e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import pickle
import torch
from src.model import RippleGPT
from src.config import RippleConfig

# -----------------------------------------------------------------------------
out_dir = 'out'
num_samples = 1 # Quantas variações de cada prompt
max_new_tokens = 200 # Curto para testar várias coisas rápido
temperature = 0.8 # Criatividade equilibrada
top_k = 200 
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
# -----------------------------------------------------------------------------

def main():
    torch.manual_seed(1337)
    
    # 1. Carrega o Melhor Modelo (Checkpoint com menor Loss)
    ckpt_path = os.path.join(out_dir, 'ckpt_best.pt')
    if not os.path.exists(ckpt_path):
        ckpt_path = os.path.join(out_dir, 'ckpt.pt')
        print("⚠️ Aviso: 'ckpt_best.pt' não encontrado, usando o último 'ckpt.pt'")
    
    print(f"Loading checkpoint from {ckpt_path}...")
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    
    # Configuração e Modelo
    gptconf = RippleConfig(**checkpoint['model_args'])
    model = RippleGPT(gptconf)
    
    # Limpeza de chaves do state_dict (caso venha de compile)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    
    model.eval()
    model.to(device)
    
    # 2. Carrega o Vocabulário (Meta)
    meta_path = os.path.join('data', 'meta.pkl')
    if os.path.exists(meta_path):
        print(f"Loading meta from {meta_path}...")
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
        stoi, itos = meta['stoi'], meta['itos']
        # Safe encode: uses '?' (if available) or ignores unknown chars.
        # Fallback to 0 if '?' not in vocab (unlikely for english text but possible)
        unknown_token = stoi.get('?', 0) 
        encode = lambda s: [stoi.get(c, unknown_token) for c in s]
        decode = lambda l: ''.join([itos[i] for i in l])
    else:
        print("❌ ERRO: meta.pkl não encontrado! Rode prepare_data.py primeiro.")
        return

    # 3. OS TESTES DE INTELIGÊNCIA (Gatilhos Fortes)
    test_cases = [
        # A. Teste de Código (Python)
        # TRUQUE: Adicionar um comentário ou docstring ajuda a firmar o contexto
        {
            "domain": "🐍 PYTHON CODING",
            "prompt": "# Function to calculate factorial\ndef factorial(n):\n    if n == 0:\n        return 1\n    else:\n        return"
        },
        
        # B. Teste de Matemática (Algebra)
        # TRUQUE: Dar um exemplo antes (Few-shot prompting)
        {
            "domain": "🧮 MATH LOGIC",
            "prompt": "Q: Solve 2x = 10\nA: x = 5\n\nQ: Solve -5k + 5 = -10\nA:"
        },
        
        # C. Teste de TinyStories
        {
            "domain": "📖 TINY STORY",
            "prompt": "Once upon a time, there was a little frog. The frog liked to jump. One day,"
        },
        
        # D. Teste de Literatura
        {
            "domain": "⚔️ LITERATURE",
            "prompt": "The General looked at the map and shouted,"
        }
    ]

    # 4. Loop de Geração
    print("\n" + "="*40)
    print(f"🤖 RIPPLE GPT: MULTI-DOMAIN TEST")
    print("="*40)

    with torch.no_grad():
        for case in test_cases:
            prompt = case["prompt"]
            domain = case["domain"]
            
            print(f"\n[{domain}] Prompt: {prompt.strip()}")
            print("-" * 20)
            
            # Encode
            start_ids = encode(prompt)
            x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

            # Generate
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            
            # Decode e Print
            generated_text = decode(y[0].tolist())
            
            # Destaca o que foi gerado vs o que era prompt
            new_content = generated_text[len(prompt):]
            print(f"{prompt}\033[94m{new_content}\033[0m") # Azul para o gerado (no terminal)
            print("-" * 40)

if __name__ == '__main__':
    main()