```python
# Resume training from the last checkpoint when one exists; otherwise start
# from scratch. The checkpoint bundles the state of all three sub-modules,
# the optimizer, and the last completed (global) epoch index.
checkpoint_path = "model_checkpoint.pt"
if os.path.exists(checkpoint_path):
    # map_location="cpu" prevents a crash when the checkpoint was written on
    # a different device (e.g. saved on GPU, loaded on a CPU-only machine);
    # load_state_dict then copies tensors onto each module's own device.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    embedding_layer.load_state_dict(checkpoint['embedding_state'])
    transformer_encoderLayer.load_state_dict(checkpoint['transformer_state'])
    output_layer.load_state_dict(checkpoint['output_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    # 'epoch' stores the last finished epoch, so resume at the next one.
    start_epoch = checkpoint['epoch'] + 1
    print(f" Модель загружена, продолжаем с эпохи {start_epoch}")
else:
    start_epoch = 0
    print(" Чекпоинт не найден, начинаем обучение с нуля")
epochNum = 20  # number of additional epochs to run in this session

# Padding positions are excluded from attention. The mask is invariant
# across epochs, so build it once instead of twice per iteration.
padding_mask = (attention_mask == 0)

for epoch in range(epochNum):
    # Global epoch index (continues across resumed runs); recorded in the
    # checkpoint so the next run knows where to pick up.
    epochmy = start_epoch + epoch

    # --- training step: one full-batch forward/backward/update per epoch ---
    optimizer.zero_grad()
    embedded = embedding_layer(input_ids)
    # Transformer encoder here expects [seq_len, batch, dim]
    # (batch_first=False layout — TODO confirm against the layer's config).
    src = embedded.transpose(0, 1)
    outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=padding_mask)
    outputTransformer = outputTransformer.transpose(0, 1)  # back to [batch, seq_len, embedding_dim]
    logits = output_layer(outputTransformer)
    loss = criterion(logits.view(-1, vocab_size), target_ids.view(-1))
    loss.backward()
    optimizer.step()

    # --- monitoring: re-run the forward pass under no_grad so the sample
    # prediction reflects the weights *after* this epoch's update ---
    with torch.no_grad():
        embedded = embedding_layer(input_ids)
        src = embedded.transpose(0, 1)
        outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=padding_mask)
        outputTransformer = outputTransformer.transpose(0, 1)
        logits = output_layer(outputTransformer)  # [batch, seq_len, vocab_size]
        predicted_token_ids = torch.argmax(logits, dim=-1)  # [batch, seq_len]
        predicted_text = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)
        print("Predicted text:", predicted_text[0])

    print(f"Epoch [{epoch + 1}/{epochNum}] — Loss: {loss.item():.4f}")

    # Checkpoint every epoch so an interrupted run can resume.
    # NOTE(review): the paste lost indentation, so "inside the loop" is an
    # assumption — it matches the resume logic above; confirm if intended
    # to save only once after training.
    # Reuse checkpoint_path so the load and save paths cannot drift apart.
    torch.save({
        'embedding_state': embedding_layer.state_dict(),
        'transformer_state': transformer_encoderLayer.state_dict(),
        'output_state': output_layer.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'epoch': epochmy,
    }, checkpoint_path)