| import torch
|
| import pickle
|
| import torch.optim as optim
|
| from model import BharatAI
|
| from tokenizer import Tokenizer
|
|
|
|
|
# Select the compute device: prefer CUDA when a GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Tokenizer loaded from a pretrained model file ('botchan.model').
# NOTE(review): `tokenizer` is never used in the visible code — batches are
# random integers (see get_batch). Confirm whether this load is still needed.
tokenizer = Tokenizer('botchan.model')
|
|
|
|
|
def get_batch(split, batch_size=16, block_size=256, vocab_size=1000, device=None):
    """Return a random (input, target) batch of token ids.

    NOTE(review): `split` is currently ignored — 'train' and 'val' both
    produce random data. This is placeholder sampling; replace with real
    dataset slicing when data is available.

    Args:
        split: Dataset split name ('train'/'val'); presently unused.
        batch_size: Number of sequences per batch (previously hard-coded 16).
        block_size: Tokens per sequence (previously hard-coded 256).
        vocab_size: Exclusive upper bound for sampled token ids (was 1000).
        device: Target device string/object. Defaults to CUDA when available,
            matching the module-level device selection.

    Returns:
        (data, target): two LongTensors of shape (batch_size, block_size)
        with values in [0, vocab_size).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    data = torch.randint(0, vocab_size, (batch_size, block_size), dtype=torch.long).to(device)
    target = torch.randint(0, vocab_size, (batch_size, block_size), dtype=torch.long).to(device)
    return data, target
|
|
|
|
|
def train_model(model, optimizer, epochs=250):
    """Train `model` for `epochs` steps on random batches, then pickle it.

    Performs one optimizer step per epoch on a fresh batch from
    `get_batch('train')`, logging the loss every 50 epochs. When training
    finishes, the entire model object is serialized to 'model-latest.pkl'.

    Args:
        model: Module whose forward pass returns (logits, loss) when called
            with (inputs, targets) — as this loop requires.
        optimizer: Optimizer constructed over `model.parameters()`.
        epochs: Number of training steps (one batch per step).
    """
    model.train()  # enable training-mode behavior (dropout, batchnorm, ...)
    for epoch in range(epochs):
        xb, yb = get_batch('train')
        # set_to_none=True frees gradient tensors instead of zero-filling them;
        # equivalent for the update, slightly cheaper.
        optimizer.zero_grad(set_to_none=True)
        logits, loss = model(xb, yb)
        loss.backward()
        optimizer.step()

        if epoch % 50 == 0:
            print(f"Epoch {epoch}: Loss {loss.item()}")

    # NOTE(review): pickling a whole torch module is fragile — the file is tied
    # to the exact class/module layout and torch version, and unpickling runs
    # arbitrary code. Prefer torch.save(model.state_dict(), ...). Kept as
    # pickle so existing loaders of 'model-latest.pkl' continue to work.
    with open('model-latest.pkl', 'wb') as f:
        pickle.dump(model, f)
    print("Model saved!")
|
|
|
|
|
if __name__ == '__main__':
    # Must match the sampling range used by get_batch, which draws ids
    # from [0, 1000).
    VOCAB_SIZE = 1000
    net = BharatAI(VOCAB_SIZE).to(device)
    opt = optim.AdamW(net.parameters(), lr=3e-4)
    train_model(net, opt)