import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import numpy as np

from model import Model
from buffer import Buffer

class Trainer:
    def __init__(self, model: Model, buffer: Buffer, base_lr: float = 0.001,
                 weight_decay: float = 1e-4, device: str = 'cpu'):
        self.main_model = model
        self.main_buffer = buffer
        self.global_step = 0
        self.device = device
        # optimizer
        self.optimizer = optim.SGD(
            self.main_model.parameters(),
            lr=base_lr,
            weight_decay=weight_decay,
            momentum=0.9
        )
        # self.scheduler = optim.lr_scheduler.CyclicLR(
        #     self.optimizer,
        #     base_lr=base_lr,
        #     max_lr=0.1
        # )
        # TensorBoard summary writer
        self.writer = SummaryWriter()
    # copy experiences from a worker buffer into the main buffer
    def transfer_buffer(self, buffer) -> None:
        for state, value, policy in zip(buffer.state, buffer.value, buffer.policy):
            self.main_buffer.store_experience(
                state=state,
                value=value,
                policy=policy
            )

    def reset_buffer(self) -> None:
        self.main_buffer.reset()
    # perform one optimization step on a single batch
    def learn(self, state: np.ndarray, value: np.ndarray, policy: np.ndarray) -> float:
        state = torch.tensor(state, dtype=torch.float32, device=self.device)
        value = torch.tensor(value, dtype=torch.float32, device=self.device).unsqueeze(-1)
        policy = torch.tensor(policy, dtype=torch.float32, device=self.device)
        pred_val, pred_policy = self.main_model(state)
        self.optimizer.zero_grad()
        loss = self.main_model.get_loss(pred_val, pred_policy, value, policy)
        loss.backward()
        self.optimizer.step()
        return loss.item()
    # training loop for the model
    def train_model(self, epochs: int, batch_size: int) -> None:
        train_steps = int(np.ceil(len(self.main_buffer) / batch_size))
        # perform the training
        for epoch in range(epochs):
            for state, value, policy in tqdm(self.main_buffer.sample(batch_size),
                                             total=train_steps, desc=f'Epoch:{epoch + 1}'):
                loss = self.learn(state, value, policy)
                self.writer.add_scalar("loss", loss, self.global_step)
                self.global_step += 1
            self.writer.flush()
    # close the writer
    def close_writer(self) -> None:
        self.writer.close()

    # save the model and optimizer state
    def save_model(self, step: int) -> None:
        torch.save(self.main_model.state_dict(), f'TargetModel_{step}.pt')
        torch.save(self.optimizer.state_dict(), f'Optimizer_{step}.pt')
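

# --- Minimal usage sketch (illustrative, not part of the module above) ---
# Shows how the Trainer might be driven end to end. It assumes Model() and
# Buffer() can be constructed without arguments and that Buffer.sample()
# yields (state, value, policy) batches, as train_model() expects; adjust
# to the real constructor signatures in model.py and buffer.py.
if __name__ == '__main__':
    model = Model()
    buffer = Buffer()
    trainer = Trainer(model, buffer, base_lr=0.001, device='cpu')
    # ... fill the buffer with experience here, e.g. via
    # trainer.transfer_buffer(worker_buffer) from self-play workers ...
    trainer.train_model(epochs=10, batch_size=64)
    trainer.save_model(step=trainer.global_step)
    trainer.close_writer()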