|
|
import torch
|
|
|
from model import TransformerKiller
|
|
|
from tokenizer import CharacterTokenizer
|
|
|
|
|
|
|
|
|
# Model hyperparameters — these must match the values the checkpoint was
# trained with, since the state_dict tensor shapes depend on them.
DIM = 128        # model width — presumably the embedding/hidden size; confirm against model.py
STATE_DIM = 16   # per-layer SSM state size — TODO confirm semantics in model.py
N_LAYERS = 4     # number of stacked layers in TransformerKiller
DEVICE = "cpu"   # inference device; this script is CPU-only
|
|
|
|
|
|
def load_model():
    """Load the SSM checkpoint from disk and rebuild its tokenizer.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is a
        ``TransformerKiller`` in eval mode on ``DEVICE`` and ``tokenizer``
        is a ``CharacterTokenizer`` restored from the vocabulary stored
        inside the checkpoint.
    """
    checkpoint_path = "ssm_checkpoint.pth"

    print("Cargando modelo en CPU...")
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (weights_only=True would be safer
    # if the installed torch version supports it).
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

    # Rebuild the tokenizer from the character vocabulary saved at
    # training time, so encode/decode match the trained embedding table.
    tokenizer = CharacterTokenizer()
    vocab = checkpoint['tokenizer_chars']
    tokenizer.chars = vocab
    tokenizer.vocab_size = len(vocab)
    tokenizer.stoi = {ch: pos for pos, ch in enumerate(vocab)}
    tokenizer.itos = dict(enumerate(vocab))

    # Instantiate the architecture with the same hyperparameters used in
    # training, then restore the learned weights.
    model = TransformerKiller(
        vocab_size=tokenizer.vocab_size,
        dim=DIM,
        n_layers=N_LAYERS,
        state_dim=STATE_DIM,
    ).to(DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Modelo: Transformer Killer (SSM)")
    print(f"Parámetros: {total_params:,}")
    print(f"Checkpoint: iter {checkpoint.get('iter', '?')}")
    print(f"Vocabulario: {tokenizer.vocab_size} tokens")

    return model, tokenizer
|
|
|
|
|
|
def generate(model, tokenizer, prompt, max_tokens=150, temperature=0.8, device=None):
    """Autoregressively sample a text continuation from the model.

    Args:
        model: callable mapping a ``(1, T)`` LongTensor of token ids to
            ``(1, T, vocab_size)`` logits.
        tokenizer: object exposing ``encode(str) -> list[int]``,
            ``decode(list[int]) -> str`` and an ``itos`` id->string map.
        prompt: seed text the generation is conditioned on.
        max_tokens: maximum number of new tokens to sample.
        temperature: softmax temperature; must be > 0 (lower = greedier).
        device: torch device for the token tensor. Defaults to the
            module-level ``DEVICE`` when ``None`` (backward compatible
            with the previous hard-coded global).

    Returns:
        The decoded string: the prompt followed by the sampled tokens
        (including the ``<|end|>`` token if one was produced).
    """
    if device is None:
        device = DEVICE

    idx = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(device)

    with torch.no_grad():
        for _ in range(max_tokens):
            # NOTE: the whole sequence is re-run every step (O(T^2) total
            # work). Fine for short chats; a cached recurrent step would
            # be needed for long generations.
            logits = model(idx)
            # Keep only the final position's logits and rescale by the
            # sampling temperature.
            logits = logits[:, -1, :] / temperature
            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

            # Stop as soon as the end-of-text marker is sampled.
            if tokenizer.itos.get(idx_next.item(), "") == "<|end|>":
                break

    return tokenizer.decode(idx[0].tolist())
|
|
|
|
|
|
def main():
    """Run the interactive chat loop.

    Commands: 'salir' exits, 'reload' re-loads the checkpoint from disk,
    empty input is ignored. Ctrl-C also exits cleanly; any other error
    is printed and the loop continues.
    """
    model, tokenizer = load_model()

    separator = "=" * 50
    print("\n" + separator)
    print(" Transformer Killer - Chat (CPU)")
    print(" Escribe 'salir' para terminar")
    print(" Escribe 'reload' para recargar el modelo")
    print(separator + "\n")

    while True:
        try:
            text = input("Tú: ").strip()
            command = text.lower()

            if command == "salir":
                print("¡Hasta luego!")
                break

            if command == "reload":
                model, tokenizer = load_model()
                print("Modelo recargado.\n")
                continue

            if not text:
                continue

            print(f"SSM: {generate(model, tokenizer, text)}\n")

        except KeyboardInterrupt:
            print("\n¡Hasta luego!")
            break
        except Exception as exc:
            print(f"Error: {exc}\n")
|
|
|
|
|
|
# Entry point: start the interactive chat only when run as a script.
if __name__ == "__main__":
    main()
|
|
|
|