# Resume support: if a checkpoint file exists, restore the three model parts
# and the optimizer from it and continue after the epoch it recorded;
# otherwise start training from epoch 0.
checkpoint_path = "model_checkpoint.pt"
if os.path.exists(checkpoint_path):
    # Deserialize onto the CPU rather than onto whatever device the tensors
    # were saved from: this keeps the load working on a machine without that
    # device and avoids a temporary GPU memory spike. load_state_dict() then
    # copies the values into the already-constructed modules/optimizer.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    embedding_layer.load_state_dict(checkpoint['embedding_state'])
    transformer_encoderLayer.load_state_dict(checkpoint['transformer_state'])
    output_layer.load_state_dict(checkpoint['output_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    # 'epoch' holds the epoch index stored with the checkpoint, so training
    # resumes on the one after it.
    start_epoch = checkpoint['epoch'] + 1
    print(f" Модель загружена, продолжаем с эпохи {start_epoch}")
else:
    start_epoch = 0
    print(" Чекпоинт не найден, начинаем обучение с нуля")
# Training loop: runs `epochNum` optimisation steps on the single
# (input_ids, attention_mask, target_ids) batch, prints a decoded prediction
# and the loss after each step, and checkpoints the model and optimizer.
epochNum = 20

# Entries where attention_mask is 0 are handed to the encoder layer as the
# key-padding mask. The mask is not modified inside the loop, so build it once.
src_key_padding_mask = attention_mask == 0

for epoch in range(epochNum):
    optimizer.zero_grad()
    # Global epoch index: the offset restored from the checkpoint plus the
    # position within this run. This is the value stored in the checkpoint.
    epochmy = start_epoch + epoch

    # Forward pass. The two transposes swap dims 0 and 1 around the encoder
    # layer — presumably batch-first <-> sequence-first; TODO confirm against
    # how transformer_encoderLayer was constructed.
    embedded = embedding_layer(input_ids)
    src = embedded.transpose(0, 1)
    outputTransformer = transformer_encoderLayer(
        src, src_key_padding_mask=src_key_padding_mask
    )
    outputTransformer = outputTransformer.transpose(0, 1)
    logits = output_layer(outputTransformer)

    # Flatten to (N, vocab_size) logits against (N,) targets for the criterion.
    loss = criterion(logits.view(-1, vocab_size), target_ids.view(-1))
    loss.backward()
    optimizer.step()

    # Re-run the forward pass with the freshly updated weights (no gradient
    # tracking) to show what the model currently predicts.
    # NOTE(review): the modules are not switched to eval() here, so any
    # dropout layers stay active for this preview — confirm that is intended.
    with torch.no_grad():
        embedded = embedding_layer(input_ids)
        src = embedded.transpose(0, 1)
        outputTransformer = transformer_encoderLayer(
            src, src_key_padding_mask=src_key_padding_mask
        )
        outputTransformer = outputTransformer.transpose(0, 1)
        logits = output_layer(outputTransformer)
        predicted_token_ids = torch.argmax(logits, dim=-1)
        predicted_text = tokenizer.batch_decode(
            predicted_token_ids, skip_special_tokens=True
        )
        print("Predicted text:", predicted_text[0])

    # `loss` is the value computed before this step's weight update.
    print(f"Epoch [{epoch + 1}/{epochNum}] — Loss: {loss.item():.4f}")

    # Checkpoint after every epoch so an interrupted run can resume from the
    # last finished epoch. Written to the same path the resume logic reads
    # (checkpoint_path) instead of a second hard-coded copy of the file name.
    torch.save({
        'embedding_state': embedding_layer.state_dict(),
        'transformer_state': transformer_encoderLayer.state_dict(),
        'output_state': output_layer.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'epoch': epochmy,
    }, checkpoint_path)