Testing_model / README.md
MarkProMaster229's picture
Update README.md
333e7f3 verified
|
raw
history blame
2.5 kB
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    embedding_layer.load_state_dict(checkpoint['embedding_state'])
    pos_encoding.load_state_dict(checkpoint['pos_encoding_state']) 
    transformer_encoderLayer.load_state_dict(checkpoint['transformer_state'])
    output_layer.load_state_dict(checkpoint['output_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    start_epoch = checkpoint['epoch'] + 1
    print(f" Модель загружена, продолжаем с эпохи {start_epoch}")
else:
    start_epoch = 0
    print(" Чекпоинт не найден, начинаем обучение с нуля")

epochNum = 10
for epoch in range(epochNum):
    optimizer.zero_grad()
    epochmy = start_epoch + epoch
    embedded = embedding_layer(input_ids)
    embedded = pos_encoding(embedded)
    src = embedded.transpose(0, 1)

    outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=(attention_mask == 0))
    outputTransformer = outputTransformer.transpose(0, 1)  # обратно [batch, seq_len, embedding_dim]

    logits = output_layer(outputTransformer)
    loss = criterion(logits.view(-1, vocab_size), target_ids.view(-1))
    before = pos_encoding.pos_embedding.weight.clone()
    loss.backward()
    optimizer.step()  # обновляем веса
    after = pos_encoding.pos_embedding.weight
    print(f"Изменение весов pos_encoding: {(after - before).abs().sum():.6f}")
    print("Loss:", loss.item())

    # После обучения (или внутри цикла, чтобы смотреть динамику)
    with torch.no_grad():
        embedded = embedding_layer(input_ids)
        embedded = pos_encoding(embedded)
        src = embedded.transpose(0, 1)
        outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=(attention_mask == 0))
        outputTransformer = outputTransformer.transpose(0, 1)
        logits = output_layer(outputTransformer)  # [batch, seq_len, vocab_size]

        # Берём самый вероятный токен для каждого положения
        predicted_token_ids = torch.argmax(logits, dim=-1)  # [batch, seq_len]

        # Переводим индексы обратно в текст
        predicted_text = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=False)
        print("Predicted text:", predicted_text[0])

        print("Loss before backward:", loss.item())