Nova / train.py
HeavensHackDev's picture
Create train.py
6d5a01a verified
raw
history blame contribute delete
930 Bytes
def train_model(
    model,
    data,
    epochs=5,
    *,
    lr=0.001,
    max_len=50,
    save_path="model.pt",
):
    """Train ``model`` on next-token prediction and save its weights.

    Each text in ``data`` is tokenized with the module-level ``tokenizer``,
    mapped to indices through the module-level ``vocab``, truncated to
    ``max_len`` indices, and used as a single-sequence batch: the input is
    every index but the last, the target is every index but the first.
    Texts that yield fewer than two indices are skipped.

    Args:
        model: Module whose ``parameters()`` are optimized with Adam. It is
            called with a ``(1, seq_len)`` long tensor; its output is
            reshaped to ``(-1, VOCAB_SIZE)`` before the loss is computed.
        data: Iterable of texts. It is iterated once per epoch, so it must
            be re-iterable (e.g. a list) when ``epochs > 1``.
        epochs: Number of passes over ``data``.
        lr: Learning rate passed to Adam.
        max_len: Maximum number of indices kept per text.
        save_path: Destination given to ``torch.save`` for the state dict.

    Returns:
        None. The mean loss of each epoch is printed and the model's
        ``state_dict`` is written to ``save_path`` after the last epoch.

    Note:
        Indices are looked up with ``vocab[token]``; what happens for a
        token missing from ``vocab`` depends on the type of ``vocab``,
        which is defined outside this function.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        steps = 0  # number of samples that actually produced a gradient step

        for text in data:
            tokens = tokenizer(text)
            indices = [vocab[token] for token in tokens][:max_len]  # length limit
            # At least two indices are needed to form one (input, target) pair.
            if len(indices) < 2:
                continue

            src = torch.tensor(indices[:-1], dtype=torch.long).unsqueeze(0)
            tgt = torch.tensor(indices[1:], dtype=torch.long).unsqueeze(0)

            optimizer.zero_grad()
            output = model(src)
            loss = criterion(output.view(-1, VOCAB_SIZE), tgt.view(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            steps += 1

        # Average over the samples that were trained on, not over len(data):
        # skipped samples would otherwise drag the reported loss down, and an
        # empty dataset would raise ZeroDivisionError.
        mean_loss = total_loss / steps if steps else 0.0
        print(f"Epoch {epoch+1}, Loss: {mean_loss}")

    torch.save(model.state_dict(), save_path)