File size: 3,041 Bytes
fd58b5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
from model import TransformerKiller
from tokenizer import CharacterTokenizer

# Configuration (must match the values used in train.py)
DIM = 128
STATE_DIM = 16
N_LAYERS = 4
DEVICE = "cpu"  # Force CPU so inference does not interfere with training

def load_model():
    """Load the SSM model and its tokenizer from the checkpoint, on CPU.

    Returns:
        (model, tokenizer): the model in eval mode and the rebuilt
        CharacterTokenizer.
    """
    print("Cargando modelo en CPU...")
    cp = torch.load("ssm_checkpoint.pth", map_location=DEVICE)

    # Rebuild the tokenizer from the character list stored in the checkpoint.
    tokenizer = CharacterTokenizer()
    chars = cp['tokenizer_chars']
    tokenizer.chars = chars
    tokenizer.vocab_size = len(chars)
    tokenizer.stoi = {}
    tokenizer.itos = {}
    for index, ch in enumerate(chars):
        tokenizer.stoi[ch] = index
        tokenizer.itos[index] = ch

    # Instantiate the architecture, restore trained weights, freeze for eval.
    model = TransformerKiller(
        vocab_size=tokenizer.vocab_size,
        dim=DIM,
        n_layers=N_LAYERS,
        state_dim=STATE_DIM,
    ).to(DEVICE)
    model.load_state_dict(cp['model_state_dict'])
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Modelo: Transformer Killer (SSM)")
    print(f"Parámetros: {total_params:,}")
    print(f"Checkpoint: iter {cp.get('iter', '?')}")
    print(f"Vocabulario: {tokenizer.vocab_size} tokens")

    return model, tokenizer

def generate(model, tokenizer, prompt, max_tokens=150, temperature=0.8):
    """Autoregressively sample up to *max_tokens* continuation tokens.

    The prompt is encoded, fed through the model, and one token is sampled
    per step from the temperature-scaled distribution over the last
    position. Sampling stops early if the "<|end|>" token is produced.

    Returns the decoded string, prompt included.
    """
    token_ids = tokenizer.encode(prompt)
    idx = torch.tensor([token_ids], dtype=torch.long, device=DEVICE)

    with torch.no_grad():
        for _ in range(max_tokens):
            # Only the last position's logits matter for next-token sampling.
            last_logits = model(idx)[:, -1, :] / temperature
            dist = torch.nn.functional.softmax(last_logits, dim=-1)
            next_id = torch.multinomial(dist, num_samples=1)
            idx = torch.cat((idx, next_id), dim=1)

            # Stop as soon as the end-of-text token appears.
            if tokenizer.itos.get(next_id.item(), "") == "<|end|>":
                break

    return tokenizer.decode(idx[0].tolist())

def main():
    """Interactive chat REPL: load the model, then read/generate until quit.

    Commands: 'salir' exits, 'reload' re-loads the checkpoint from disk.
    Any per-turn error is printed and the loop continues.
    """
    model, tokenizer = load_model()

    banner = "=" * 50
    print("\n" + banner)
    print("  Transformer Killer - Chat (CPU)")
    print("  Escribe 'salir' para terminar")
    print("  Escribe 'reload' para recargar el modelo")
    print(banner + "\n")

    while True:
        try:
            prompt = input("Tú: ").strip()
            command = prompt.lower()

            if command == "salir":
                print("¡Hasta luego!")
                break

            if command == "reload":
                # Hot-swap the weights without restarting the chat loop,
                # e.g. to pick up a newer checkpoint mid-training.
                model, tokenizer = load_model()
                print("Modelo recargado.\n")
                continue

            if not prompt:
                continue

            print(f"SSM: {generate(model, tokenizer, prompt)}\n")

        except KeyboardInterrupt:
            print("\n¡Hasta luego!")
            break
        except Exception as e:
            # Keep the session alive on any single-turn failure.
            print(f"Error: {e}\n")

if __name__ == "__main__":
    main()