import os

import torch

# Resume from a checkpoint if one exists; otherwise start training from scratch.
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    embedding_layer.load_state_dict(checkpoint['embedding_state'])
    pos_encoding.load_state_dict(checkpoint['pos_encoding_state'])
    transformer_encoderLayer.load_state_dict(checkpoint['transformer_state'])
    output_layer.load_state_dict(checkpoint['output_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Model loaded, resuming from epoch {start_epoch}")
else:
    start_epoch = 0
    print("Checkpoint not found, starting training from scratch")
epochNum = 10
for epoch in range(epochNum):
    optimizer.zero_grad()
    epochmy = start_epoch + epoch  # absolute epoch number (used when saving the checkpoint)

    # Forward pass: token embeddings + positional encoding -> transformer encoder -> vocabulary logits.
    embedded = embedding_layer(input_ids)
    embedded = pos_encoding(embedded)
    src = embedded.transpose(0, 1)  # (batch, seq, dim) -> (seq, batch, dim) expected by the encoder layer
    outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=(attention_mask == 0))
    outputTransformer = outputTransformer.transpose(0, 1)
    logits = output_layer(outputTransformer)
    loss = criterion(logits.view(-1, vocab_size), target_ids.view(-1))

    # Sanity check: confirm the positional-encoding weights actually change after the optimizer step.
    before = pos_encoding.pos_embedding.weight.detach().clone()
    loss.backward()
    optimizer.step()
    after = pos_encoding.pos_embedding.weight
    print(f"pos_encoding weight change: {(after - before).abs().sum():.6f}")
    print("Loss:", loss.item())

# Quick qualitative check: run a forward pass without gradients and greedily decode the predictions.
# Note: the modules are still in train mode here, so dropout inside the encoder layer remains active.
with torch.no_grad():
    embedded = embedding_layer(input_ids)
    embedded = pos_encoding(embedded)
    src = embedded.transpose(0, 1)
    outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=(attention_mask == 0))
    outputTransformer = outputTransformer.transpose(0, 1)
    logits = output_layer(outputTransformer)
    predicted_token_ids = torch.argmax(logits, dim=-1)  # greedy: most likely token at every position
    predicted_text = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=False)
    print("Predicted text:", predicted_text[0])
    print("Loss from the last training step:", loss.item())