| |
| import torch |
| import numpy as np |
| import datetime |
| from tqdm import tqdm |
|
|
| from visualize import save_ratemaps |
| import os |
|
|
|
|
class Trainer:
    """Train a model on simulated trajectories and checkpoint it to disk.

    Checkpoints are written to ``os.path.join(options.save_dir, options.run_ID)``
    as ``most_recent_model.pth`` (the model ``state_dict``) and
    ``place_cell_centers.npy`` (the array held in ``model.place_cells.us``).
    """

    def __init__(self, options, model, trajectory_generator, restore=True):
        """Set up the optimizer and optionally restore a previous checkpoint.

        Args:
            options: Config object. Reads ``learning_rate``, ``save_dir``,
                ``run_ID`` and ``device``.
            model: ``torch.nn.Module`` exposing ``compute_loss(inputs,
                pc_outputs, pos)`` and a ``place_cells.us`` tensor.
            trajectory_generator: Object whose ``get_generator()`` yields
                ``(inputs, pc_outputs, pos)`` batches; also handed to
                ``save_ratemaps`` when checkpointing.
            restore: If True and ``most_recent_model.pth`` exists in the
                checkpoint directory, load it (plus ``place_cell_centers.npy``
                when present) into ``model``. Otherwise the checkpoint
                directory is created and training starts from scratch.
        """
        self.options = options
        self.model = model
        self.trajectory_generator = trajectory_generator
        lr = self.options.learning_rate
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        # Per-step history of training loss and decoded position error.
        self.loss = []
        self.err = []

        self.ckpt_dir = os.path.join(options.save_dir, options.run_ID)
        ckpt_path = os.path.join(self.ckpt_dir, "most_recent_model.pth")
        pc_path = os.path.join(self.ckpt_dir, "place_cell_centers.npy")

        # isfile() already implies the directory exists.
        if restore and os.path.isfile(ckpt_path):
            # map_location lets a checkpoint written on one device (e.g. a GPU)
            # be restored on another (e.g. a CPU-only machine).
            state_dict = torch.load(ckpt_path, map_location=options.device)
            self.model.load_state_dict(state_dict)
            print("Restored trained model from {}".format(ckpt_path))

            # The centers are not part of the state_dict, so they are
            # persisted and restored separately.
            if os.path.isfile(pc_path):
                us = np.load(pc_path)
                self.model.place_cells.us = torch.tensor(us).to(options.device)
                print("Restored place cell centers from {}".format(pc_path))
            else:
                print("Warning: place_cell_centers.npy not found! Model may not work correctly.")
        else:
            os.makedirs(self.ckpt_dir, exist_ok=True)
            print("Initializing new model from scratch.")
            print("Saving to: {}".format(self.ckpt_dir))

    def train_step(self, inputs, pc_outputs, pos):
        """
        Train on one batch of trajectories.

        Args:
            inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].
            pc_outputs: Ground truth place cell activations with shape
                [batch_size, sequence_length, Np].
            pos: Ground truth 2d position with shape [batch_size, sequence_length, 2].

        Returns:
            loss: Avg. loss for this training batch.
            err: Avg. decoded position error in cm.
        """
        # Clear gradients left over from the previous step before backprop.
        self.model.zero_grad()

        loss, err = self.model.compute_loss(inputs, pc_outputs, pos)

        loss.backward()
        self.optimizer.step()

        return loss.item(), err.item()

    def _save_checkpoint(self, step):
        """Write model weights, place cell centers and ratemaps for ``step``."""
        torch.save(
            self.model.state_dict(),
            os.path.join(self.ckpt_dir, "most_recent_model.pth"),
        )
        np.save(
            os.path.join(self.ckpt_dir, "place_cell_centers.npy"),
            self.model.place_cells.us.cpu().numpy(),
        )
        save_ratemaps(
            self.model,
            self.trajectory_generator,
            self.options,
            step=step,
        )

    def train(self, n_epochs: int = 1, n_steps=10, save=True, save_every: int = 10):
        """
        Train model on simulated trajectories.

        Args:
            n_epochs: Number of epochs to run.
            n_steps: Number of training steps (batches) per epoch.
            save: If true, write a checkpoint (weights, place cell centers and
                ratemaps) after the first epoch, after every ``save_every``-th
                epoch, and after the final epoch.
            save_every: Checkpoint interval in epochs; must be >= 1.

        Raises:
            ValueError: If ``save_every`` is less than 1.
        """
        if save_every < 1:
            raise ValueError("save_every must be >= 1, got {}".format(save_every))

        gen = self.trajectory_generator.get_generator()

        # Defaults keep the per-epoch report well-defined when n_steps == 0.
        loss = err = float("nan")

        for epoch_idx in range(n_epochs):
            epoch = epoch_idx + 1
            tbar = tqdm(range(n_steps), leave=False)
            for _ in tbar:
                inputs, pc_outputs, pos = next(gen)
                loss, err = self.train_step(inputs, pc_outputs, pos)
                self.loss.append(loss)
                self.err.append(err)

                tbar.set_description('Error = ' + str(int(100*err)) + 'cm')

            # Always checkpoint the final epoch so that the saved
            # "most_recent_model.pth" really is the most recent model.
            is_checkpoint_epoch = (
                epoch == 1 or epoch % save_every == 0 or epoch == n_epochs
            )
            if save and is_checkpoint_epoch:
                self._save_checkpoint(step=epoch)

            # strftime avoids the str()[:-7] slice, which corrupts the
            # timestamp whenever microsecond == 0 (no fractional part).
            print(
                "Epoch: {}/{}. Date: {}. Loss: {}. Err: {}cm".format(
                    epoch,
                    n_epochs,
                    datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    np.round(loss, 2),
                    np.round(100 * err, 2),
                )
            )
|
|