MarkProMaster229 commited on
Commit
da73587
·
verified ·
1 Parent(s): 2b27003

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +52 -0
README.md ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ checkpoint_path = "model_checkpoint.pt"
3
+
4
+ if os.path.exists(checkpoint_path):
5
+ checkpoint = torch.load(checkpoint_path)
6
+ embedding_layer.load_state_dict(checkpoint['embedding_state'])
7
+ transformer_encoderLayer.load_state_dict(checkpoint['transformer_state'])
8
+ output_layer.load_state_dict(checkpoint['output_state'])
9
+ optimizer.load_state_dict(checkpoint['optimizer_state'])
10
+ start_epoch = checkpoint['epoch'] + 1
11
+ print(f" Модель загружена, продолжаем с эпохи {start_epoch}")
12
+ else:
13
+ start_epoch = 0
14
+ print(" Чекпоинт не найден, начинаем обучение с нуля")
15
+
16
+
17
+ epochNum = 20
18
+ for epoch in range(epochNum):
19
+ optimizer.zero_grad()
20
+ epochmy = start_epoch + epoch
21
+ embedded = embedding_layer(input_ids)
22
+ src = embedded.transpose(0, 1)
23
+
24
+ outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=(attention_mask == 0))
25
+ outputTransformer = outputTransformer.transpose(0, 1) # обратно [batch, seq_len, embedding_dim]
26
+
27
+ logits = output_layer(outputTransformer)
28
+ loss = criterion(logits.view(-1, vocab_size), target_ids.view(-1))
29
+ loss.backward()
30
+ optimizer.step()
31
+
32
+ with torch.no_grad():
33
+ embedded = embedding_layer(input_ids)
34
+ src = embedded.transpose(0, 1)
35
+ outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=(attention_mask == 0))
36
+ outputTransformer = outputTransformer.transpose(0, 1)
37
+ logits = output_layer(outputTransformer) # [batch, seq_len, vocab_size]
38
+
39
+
40
+ predicted_token_ids = torch.argmax(logits, dim=-1) # [batch, seq_len]
41
+
42
+ predicted_text = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)
43
+ print("Predicted text:", predicted_text[0])
44
+
45
+ print(f"Epoch [{epoch + 1}/{epochNum}] — Loss: {loss.item():.4f}")
46
+ torch.save({
47
+ 'embedding_state': embedding_layer.state_dict(),
48
+ 'transformer_state': transformer_encoderLayer.state_dict(),
49
+ 'output_state': output_layer.state_dict(),
50
+ 'optimizer_state': optimizer.state_dict(),
51
+ 'epoch': epochmy
52
+ }, "model_checkpoint.pt")