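# The training loop below references embedding_layer, transformer_encoderLayer,
# output_layer, optimizer, criterion, tokenizer, input_ids, attention_mask and
# target_ids without defining them. A minimal setup sketch follows, assuming a
# Hugging Face tokenizer and a single nn.TransformerEncoderLayer; every name and
# hyperparameter in this block is an illustrative guess, not part of the original.
import os

import torch
import torch.nn as nn
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed tokenizer
vocab_size = tokenizer.vocab_size
embedding_dim = 256  # assumed model width

embedding_layer = nn.Embedding(vocab_size, embedding_dim)
transformer_encoderLayer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=8)
output_layer = nn.Linear(embedding_dim, vocab_size)

criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.Adam(
    list(embedding_layer.parameters())
    + list(transformer_encoderLayer.parameters())
    + list(output_layer.parameters()),
    lr=1e-4,
)

# Toy batch: the model is trained to reproduce its own input (placeholder data).
texts = ["hello world", "transformers are fun"]
batch = tokenizer(texts, return_tensors="pt", padding=True)
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
target_ids = input_ids.clone()
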
checkpoint_path = "model_checkpoint.pt"

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    embedding_layer.load_state_dict(checkpoint['embedding_state'])
    transformer_encoderLayer.load_state_dict(checkpoint['transformer_state'])
    output_layer.load_state_dict(checkpoint['output_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    start_epoch = checkpoint['epoch'] + 1
    print(f" Модель загружена, продолжаем с эпохи {start_epoch}")
else:
    start_epoch = 0
    print(" Чекпоинт не найден, начинаем обучение с нуля")


epochNum = 20
for epoch in range(epochNum):
    optimizer.zero_grad()
    epochmy = start_epoch + epoch  # absolute epoch index, stored in the checkpoint
    embedded = embedding_layer(input_ids)  # [batch, seq_len, embedding_dim]
    src = embedded.transpose(0, 1)  # [seq_len, batch, embedding_dim] for the encoder

    outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=(attention_mask == 0))
    outputTransformer = outputTransformer.transpose(0, 1)  # back to [batch, seq_len, embedding_dim]

    logits = output_layer(outputTransformer)
    loss = criterion(logits.view(-1, vocab_size), target_ids.view(-1))
    loss.backward()
    optimizer.step() 

    # Evaluation pass: repeat the forward computation without gradients and decode greedily.
    with torch.no_grad():
        embedded = embedding_layer(input_ids)
        src = embedded.transpose(0, 1)
        outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=(attention_mask == 0))
        outputTransformer = outputTransformer.transpose(0, 1)
        logits = output_layer(outputTransformer)  # [batch, seq_len, vocab_size]


        predicted_token_ids = torch.argmax(logits, dim=-1)  # [batch, seq_len]

        predicted_text = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)
        print("Predicted text:", predicted_text[0])

print(f"Epoch [{epoch + 1}/{epochNum}] — Loss: {loss.item():.4f}")
torch.save({
    'embedding_state': embedding_layer.state_dict(),
    'transformer_state': transformer_encoderLayer.state_dict(),
    'output_state': output_layer.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'epoch': epochmy
}, "model_checkpoint.pt")