# ULFBERTO's picture
# Upload chat.py with huggingface_hub
# fd58b5f verified
import torch
from model import TransformerKiller
from tokenizer import CharacterTokenizer
# Configuration (must match train.py)
DIM = 128        # model width — presumably the embedding dimension; confirm against model.py
STATE_DIM = 16   # SSM state dimension passed to TransformerKiller
N_LAYERS = 4     # number of stacked layers
DEVICE = "cpu"  # Force CPU so inference does not interfere with training
def load_model(checkpoint_path="ssm_checkpoint.pth"):
    """Load the SSM model and its tokenizer from a training checkpoint.

    Args:
        checkpoint_path: Path to the ``.pth`` checkpoint written by train.py.
            Defaults to the path the original script hard-coded.

    Returns:
        ``(model, tokenizer)`` — the model on CPU in eval mode, and the
        character tokenizer rebuilt from the saved vocabulary.
    """
    print("Cargando modelo en CPU...")
    # weights_only=False is required: the checkpoint stores the tokenizer's
    # character list and an iteration counter alongside the state dict, and
    # torch >= 2.6 defaults weights_only to True, which would reject or
    # mangle such payloads. NOTE: this unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    cp = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)

    # Rebuild the tokenizer from the saved character vocabulary.
    tokenizer = CharacterTokenizer()
    tokenizer.chars = cp['tokenizer_chars']
    tokenizer.vocab_size = len(tokenizer.chars)
    tokenizer.stoi = {ch: i for i, ch in enumerate(tokenizer.chars)}
    tokenizer.itos = {i: ch for i, ch in enumerate(tokenizer.chars)}

    # Rebuild the model with the same hyperparameters train.py used,
    # then restore its weights.
    model = TransformerKiller(
        vocab_size=tokenizer.vocab_size,
        dim=DIM,
        n_layers=N_LAYERS,
        state_dim=STATE_DIM,
    ).to(DEVICE)
    model.load_state_dict(cp['model_state_dict'])
    model.eval()

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Modelo: Transformer Killer (SSM)")
    print(f"Parámetros: {n_params:,}")
    print(f"Checkpoint: iter {cp.get('iter', '?')}")
    print(f"Vocabulario: {tokenizer.vocab_size} tokens")
    return model, tokenizer
def generate(model, tokenizer, prompt, max_tokens=150, temperature=0.8,
             eos_token="<|end|>", device="cpu"):
    """Autoregressively sample a continuation of *prompt*.

    Args:
        model: Callable mapping a ``(1, T)`` LongTensor of token ids to
            ``(1, T, vocab)`` logits.
        tokenizer: Object with ``encode``/``decode`` and an ``itos``
            id -> char mapping.
        prompt: Seed text to condition on.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Softmax temperature; must be > 0 (lower = greedier).
        eos_token: Sampling stops once this token is produced.
        device: Torch device for the sampling loop; defaults to CPU,
            matching the module-level DEVICE.

    Returns:
        The decoded string, prompt included.

    Raises:
        ValueError: If ``temperature`` is not strictly positive (the
            original code would produce NaN probabilities instead).
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    idx = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(device)
    with torch.no_grad():
        for _ in range(max_tokens):
            # Full forward pass each step; only the last position's logits
            # are used to sample the next token.
            logits = model(idx)[:, -1, :] / temperature
            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            # Stop as soon as the end-of-text token is emitted.
            if tokenizer.itos.get(idx_next.item(), "") == eos_token:
                break
    return tokenizer.decode(idx[0].tolist())
def main():
    """Run the interactive CPU chat loop for the SSM model."""
    model, tokenizer = load_model()

    banner = "=" * 50
    print("\n" + banner)
    print(" Transformer Killer - Chat (CPU)")
    print(" Escribe 'salir' para terminar")
    print(" Escribe 'reload' para recargar el modelo")
    print(banner + "\n")

    while True:
        try:
            prompt = input("Tú: ").strip()
            command = prompt.lower()
            if command == "salir":
                print("¡Hasta luego!")
                break
            if command == "reload":
                # Hot-swap the checkpoint without restarting the process.
                model, tokenizer = load_model()
                print("Modelo recargado.\n")
                continue
            if not prompt:
                continue
            response = generate(model, tokenizer, prompt)
            print(f"SSM: {response}\n")
        except KeyboardInterrupt:
            print("\n¡Hasta luego!")
            break
        except Exception as e:
            # Surface the error but keep the chat session alive.
            print(f"Error: {e}\n")


if __name__ == "__main__":
    main()