Spaces:
Build error
Build error
def train_model(model, data, epochs=5, *, lr=0.001, max_len=50,
                save_path="model.pt"):
    """Train ``model`` for next-token prediction and save its weights.

    Each text in ``data`` is tokenized with the module-level ``tokenizer``,
    mapped to indices through the module-level ``vocab``, and truncated to
    ``max_len`` indices. The input is the sequence without its last index
    and the target is the sequence without its first index, so the model
    is trained to predict each following token. One optimizer step is
    taken per text (batch size 1).

    Args:
        model: Module called as ``model(src)`` with a ``(1, seq_len)``
            long tensor. Its output is reshaped to ``(-1, VOCAB_SIZE)``,
            so it is assumed to hold one row of ``VOCAB_SIZE`` logits per
            target position -- confirm against the model definition.
        data: Iterable of raw texts. It is iterated once per epoch, so it
            should be re-iterable (e.g. a list) when ``epochs > 1``.
        epochs: Number of passes over ``data``.
        lr: Learning rate handed to ``torch.optim.Adam``.
        max_len: Maximum number of token indices kept per text.
        save_path: File the model's ``state_dict`` is written to once
            training finishes.

    Texts that yield fewer than two indices are skipped because they
    cannot form an (input, target) pair.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        # Count only the samples that actually contributed a loss value so
        # the reported mean is not diluted by skipped texts.
        num_samples = 0
        for text in data:
            tokens = tokenizer(text)
            # Cap the sequence length.
            indices = [vocab[token] for token in tokens][:max_len]
            if len(indices) < 2:
                continue

            # Shift by one position: the model sees tokens [0..n-2] and is
            # scored against tokens [1..n-1].
            src = torch.tensor(indices[:-1], dtype=torch.long).unsqueeze(0)
            tgt = torch.tensor(indices[1:], dtype=torch.long).unsqueeze(0)

            optimizer.zero_grad()
            output = model(src)
            loss = criterion(output.view(-1, VOCAB_SIZE), tgt.view(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_samples += 1

        # Guard against empty data / all texts skipped instead of raising
        # ZeroDivisionError; also avoids requiring len(data).
        avg_loss = total_loss / num_samples if num_samples else 0.0
        print(f"Epoch {epoch+1}, Loss: {avg_loss}")

    torch.save(model.state_dict(), save_path)