Spaces:
Build error
Build error
File size: 930 Bytes
6d5a01a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
def train_model(
    model,
    data,
    epochs=5,
    *,
    lr=0.001,
    max_len=50,
    save_path="model.pt",
    tokenize=None,
    token_to_index=None,
):
    """Train ``model`` for next-token prediction and save its weights.

    Each text in ``data`` is tokenized, truncated to ``max_len`` tokens and
    mapped to integer indices. The model is fed ``indices[:-1]`` as a batch
    of one and trained with cross-entropy to predict ``indices[1:]``. Texts
    that yield fewer than two indices are skipped. After the last epoch the
    model's ``state_dict`` is written to ``save_path``.

    Args:
        model: ``torch.nn.Module`` taking a ``(1, seq)`` LongTensor and
            returning logits whose last dimension is the vocabulary size
            (presumably shaped ``(1, seq, vocab)`` -- confirm against the
            model definition).
        data: collection of texts. It is iterated once per epoch, so pass a
            re-iterable container such as a list, not a one-shot generator.
        epochs: number of passes over ``data``.
        lr: Adam learning rate.
        max_len: maximum number of tokens kept from each text.
        save_path: destination file for ``model.state_dict()``.
        tokenize: callable ``text -> iterable of tokens``. Defaults to the
            module-level ``tokenizer``.
        token_to_index: mapping ``token -> int``. Defaults to the
            module-level ``vocab``.

    Returns:
        List of per-epoch mean losses, averaged over the texts actually
        trained on. An epoch with no usable text reports ``nan``.
    """
    import torch
    from itertools import islice

    # Module-level globals are only looked up when nothing was injected.
    tok_fn = tokenize if tokenize is not None else tokenizer
    lookup = token_to_index if token_to_index is not None else vocab

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        steps = 0  # texts actually trained on; skipped ones must not dilute the mean
        for text in data:
            # Limit sequence length BEFORE the lookup so tokens past max_len
            # are never indexed (no wasted work, no KeyError from the tail).
            indices = [lookup[token] for token in islice(tok_fn(text), max_len)]
            if len(indices) < 2:
                continue  # need at least one (input, target) pair
            src = torch.tensor(indices[:-1], dtype=torch.long).unsqueeze(0)
            tgt = torch.tensor(indices[1:], dtype=torch.long).unsqueeze(0)
            optimizer.zero_grad()
            output = model(src)
            # reshape (unlike view) also accepts non-contiguous outputs; the
            # class count is taken from the logits instead of a global.
            loss = criterion(output.reshape(-1, output.size(-1)), tgt.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            steps += 1
        avg_loss = total_loss / steps if steps else float("nan")
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1}, Loss: {avg_loss}")
    torch.save(model.state_dict(), save_path)
    return epoch_losses