```python
import os

import torch
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    embedding_layer.load_state_dict(checkpoint['embedding_state'])
    pos_encoding.load_state_dict(checkpoint['pos_encoding_state'])
    transformer_encoderLayer.load_state_dict(checkpoint['transformer_state'])
    output_layer.load_state_dict(checkpoint['output_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Model loaded, resuming from epoch {start_epoch}")
else:
    start_epoch = 0
    print("Checkpoint not found, training from scratch")
epochNum = 10
for epoch in range(epochNum):
    optimizer.zero_grad()
    epochmy = start_epoch + epoch  # global epoch index, counting resumed epochs
    embedded = embedding_layer(input_ids)
    embedded = pos_encoding(embedded)
    src = embedded.transpose(0, 1)  # [batch, seq_len, dim] -> [seq_len, batch, dim]
    outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=(attention_mask == 0))
    outputTransformer = outputTransformer.transpose(0, 1)  # back to [batch, seq_len, embedding_dim]
    logits = output_layer(outputTransformer)
    loss = criterion(logits.view(-1, vocab_size), target_ids.view(-1))
    before = pos_encoding.pos_embedding.weight.detach().clone()  # snapshot before the update
    loss.backward()
    optimizer.step()  # update the weights
    after = pos_encoding.pos_embedding.weight
    print(f"pos_encoding weight change: {(after - before).abs().sum():.6f}")
    print(f"Epoch {epochmy}, loss:", loss.item())
# After training (or inside the loop, to watch the dynamics)
with torch.no_grad():
    embedded = embedding_layer(input_ids)
    embedded = pos_encoding(embedded)
    src = embedded.transpose(0, 1)
    outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=(attention_mask == 0))
    outputTransformer = outputTransformer.transpose(0, 1)
    logits = output_layer(outputTransformer)  # [batch, seq_len, vocab_size]
    # Take the most likely token at each position (greedy decoding)
    predicted_token_ids = torch.argmax(logits, dim=-1)  # [batch, seq_len]
    # Convert the indices back to text
    predicted_text = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=False)
    print("Predicted text:", predicted_text[0])
| print("Loss before backward:", loss.item()) | |