File size: 930 Bytes
6d5a01a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def train_model(model, data, epochs=5, lr=0.001, max_len=50, save_path="model.pt"):
    """Train a next-token language model on an iterable of text samples.

    Each sample is tokenized, mapped through the module-level ``vocab``,
    truncated to ``max_len`` tokens, and split into (input, shifted-target)
    pairs for a cross-entropy language-modeling objective. The trained
    weights are saved to ``save_path`` when training finishes.

    Args:
        model: A torch.nn.Module mapping a (1, seq_len) LongTensor of token
            indices to (1, seq_len, VOCAB_SIZE) logits — assumed from the
            ``criterion`` reshape below; confirm against the model class.
        data: Iterable of raw text samples. Samples that tokenize to fewer
            than 2 indices are skipped (no (src, tgt) pair can be formed).
        epochs: Number of full passes over ``data``.
        lr: Adam learning rate (default preserves the original 0.001).
        max_len: Maximum token sequence length per sample (default 50,
            matching the original hard-coded truncation).
        save_path: Destination file for the saved state_dict.

    Side effects:
        Prints per-epoch average loss; writes model weights to ``save_path``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0  # count only samples actually trained on
        for text in data:
            tokens = tokenizer(text)
            indices = [vocab[token] for token in tokens][:max_len]  # cap sequence length
            if len(indices) < 2:
                continue  # too short to form an (input, target) pair
            src = torch.tensor(indices[:-1], dtype=torch.long).unsqueeze(0)
            tgt = torch.tensor(indices[1:], dtype=torch.long).unsqueeze(0)
            optimizer.zero_grad()
            output = model(src)
            loss = criterion(output.view(-1, VOCAB_SIZE), tgt.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Average over batches actually trained on. The original divided by
        # len(data), which understates the loss when short samples are
        # skipped, crashes on empty data, and forbids generator inputs.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}, Loss: {avg_loss}")
    torch.save(model.state_dict(), save_path)