# --- Resume from checkpoint if one exists, otherwise start from scratch ---
if os.path.exists(checkpoint_path):
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources (or pass weights_only=True and keep
    # only tensors/state_dicts in the file).
    checkpoint = torch.load(checkpoint_path)
    embedding_layer.load_state_dict(checkpoint['embedding_state'])
    pos_encoding.load_state_dict(checkpoint['pos_encoding_state'])
    transformer_encoderLayer.load_state_dict(checkpoint['transformer_state'])
    output_layer.load_state_dict(checkpoint['output_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    start_epoch = checkpoint['epoch'] + 1
    print(f" Модель загружена, продолжаем с эпохи {start_epoch}")
else:
    start_epoch = 0
    print(" Чекпоинт не найден, начинаем обучение с нуля")

epochNum = 10  # number of additional epochs to train in this run

# TODO(review): nothing below ever saves a checkpoint, so start_epoch
# bookkeeping has no effect across runs — add a torch.save of the same
# keys read above ('embedding_state', ..., 'epoch') at the end of the loop.
for epoch in range(epochNum):
    optimizer.zero_grad()

    # Forward pass: token embeddings + positional encoding -> encoder -> logits.
    embedded = embedding_layer(input_ids)
    embedded = pos_encoding(embedded)
    # Transpose to [seq_len, batch, dim] — assumes the encoder was built
    # without batch_first=True; confirm against its construction.
    src = embedded.transpose(0, 1)
    # Mask out padding positions (attention_mask == 0 where padded).
    outputTransformer = transformer_encoderLayer(
        src, src_key_padding_mask=(attention_mask == 0)
    )
    outputTransformer = outputTransformer.transpose(0, 1)  # back to [batch, seq_len, embedding_dim]
    logits = output_layer(outputTransformer)

    # Flatten to [batch*seq_len, vocab_size] vs [batch*seq_len] for the loss.
    loss = criterion(logits.view(-1, vocab_size), target_ids.view(-1))

    # Snapshot positional-embedding weights to verify they actually update.
    before = pos_encoding.pos_embedding.weight.clone()
    loss.backward()
    optimizer.step()  # update weights
    after = pos_encoding.pos_embedding.weight
    print(f"Изменение весов pos_encoding: {(after - before).abs().sum():.6f}")
    print("Loss:", loss.item())

# After training (or move inside the loop to watch the dynamics over epochs).
with torch.no_grad():
    embedded = embedding_layer(input_ids)
    embedded = pos_encoding(embedded)
    src = embedded.transpose(0, 1)
    outputTransformer = transformer_encoderLayer(
        src, src_key_padding_mask=(attention_mask == 0)
    )
    outputTransformer = outputTransformer.transpose(0, 1)
    logits = output_layer(outputTransformer)  # [batch, seq_len, vocab_size]

    # Greedy decode: take the most probable token at every position.
    predicted_token_ids = torch.argmax(logits, dim=-1)  # [batch, seq_len]

    # Map token ids back to text.
    predicted_text = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=False)

print("Predicted text:", predicted_text[0])
# Original label said "Loss before backward", but backward/step have already
# run for the final epoch — this is the last epoch's training loss.
print("Final training loss:", loss.item())