ULFBERTO commited on
Commit
fd58b5f
·
verified ·
1 Parent(s): 677f594

Upload chat.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. chat.py +95 -0
chat.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model import TransformerKiller
3
+ from tokenizer import CharacterTokenizer
4
+
5
+ # Configuración (debe coincidir con train.py)
6
+ DIM = 128
7
+ STATE_DIM = 16
8
+ N_LAYERS = 4
9
+ DEVICE = "cpu" # Forzar CPU para no interferir con el entrenamiento
10
+
11
+ def load_model():
12
+ checkpoint_path = "ssm_checkpoint.pth"
13
+
14
+ print("Cargando modelo en CPU...")
15
+ cp = torch.load(checkpoint_path, map_location=DEVICE)
16
+
17
+ # Reconstruir tokenizer
18
+ tokenizer = CharacterTokenizer()
19
+ tokenizer.chars = cp['tokenizer_chars']
20
+ tokenizer.vocab_size = len(tokenizer.chars)
21
+ tokenizer.stoi = {ch: i for i, ch in enumerate(tokenizer.chars)}
22
+ tokenizer.itos = {i: ch for i, ch in enumerate(tokenizer.chars)}
23
+
24
+ # Cargar modelo
25
+ model = TransformerKiller(
26
+ vocab_size=tokenizer.vocab_size,
27
+ dim=DIM,
28
+ n_layers=N_LAYERS,
29
+ state_dim=STATE_DIM
30
+ ).to(DEVICE)
31
+
32
+ model.load_state_dict(cp['model_state_dict'])
33
+ model.eval()
34
+
35
+ n_params = sum(p.numel() for p in model.parameters())
36
+ print(f"Modelo: Transformer Killer (SSM)")
37
+ print(f"Parámetros: {n_params:,}")
38
+ print(f"Checkpoint: iter {cp.get('iter', '?')}")
39
+ print(f"Vocabulario: {tokenizer.vocab_size} tokens")
40
+
41
+ return model, tokenizer
42
+
43
+ def generate(model, tokenizer, prompt, max_tokens=150, temperature=0.8):
44
+ idx = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(DEVICE)
45
+
46
+ with torch.no_grad():
47
+ for _ in range(max_tokens):
48
+ logits = model(idx)
49
+ logits = logits[:, -1, :] / temperature
50
+ probs = torch.nn.functional.softmax(logits, dim=-1)
51
+ idx_next = torch.multinomial(probs, num_samples=1)
52
+ idx = torch.cat((idx, idx_next), dim=1)
53
+
54
+ # Parar si genera token de fin
55
+ if tokenizer.itos.get(idx_next.item(), "") == "<|end|>":
56
+ break
57
+
58
+ return tokenizer.decode(idx[0].tolist())
59
+
60
+ def main():
61
+ model, tokenizer = load_model()
62
+
63
+ print("\n" + "="*50)
64
+ print(" Transformer Killer - Chat (CPU)")
65
+ print(" Escribe 'salir' para terminar")
66
+ print(" Escribe 'reload' para recargar el modelo")
67
+ print("="*50 + "\n")
68
+
69
+ while True:
70
+ try:
71
+ prompt = input("Tú: ").strip()
72
+
73
+ if prompt.lower() == "salir":
74
+ print("¡Hasta luego!")
75
+ break
76
+
77
+ if prompt.lower() == "reload":
78
+ model, tokenizer = load_model()
79
+ print("Modelo recargado.\n")
80
+ continue
81
+
82
+ if not prompt:
83
+ continue
84
+
85
+ response = generate(model, tokenizer, prompt)
86
+ print(f"SSM: {response}\n")
87
+
88
+ except KeyboardInterrupt:
89
+ print("\n¡Hasta luego!")
90
+ break
91
+ except Exception as e:
92
+ print(f"Error: {e}\n")
93
+
94
+ if __name__ == "__main__":
95
+ main()