# -*- coding: utf-8 -*-
import datetime
import os

import numpy as np
import torch
from tqdm import tqdm

from visualize import save_ratemaps


class Trainer(object):
    """Train a grid-cell model on simulated trajectories.

    Handles the optimization loop, periodic checkpointing of the model
    weights and place-cell centers, and restoring both from a previous run.
    """

    def __init__(self, options, model, trajectory_generator, restore=True):
        """
        Args:
            options: Config object; must provide `learning_rate`, `save_dir`,
                `run_ID`, and `device`.
            model: Model exposing `parameters()`, `state_dict()`,
                `load_state_dict()`, `compute_loss()`, and a `place_cells`
                attribute with a `us` tensor of place-cell centers.
            trajectory_generator: Object whose `get_generator()` yields
                (inputs, pc_outputs, pos) batches.
            restore: If True and a checkpoint exists in the run directory,
                restore model weights (and place-cell centers if saved).
        """
        self.options = options
        self.model = model
        self.trajectory_generator = trajectory_generator
        lr = self.options.learning_rate
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        # Running history of per-step loss / decoded position error.
        self.loss = []
        self.err = []

        # Set up checkpoints
        self.ckpt_dir = os.path.join(options.save_dir, options.run_ID)
        ckpt_path = os.path.join(self.ckpt_dir, "most_recent_model.pth")
        pc_path = os.path.join(self.ckpt_dir, "place_cell_centers.npy")

        if restore and os.path.isdir(self.ckpt_dir) and os.path.isfile(ckpt_path):
            # map_location lets a checkpoint saved on GPU load on CPU-only
            # machines (and vice versa) instead of raising.
            self.model.load_state_dict(
                torch.load(ckpt_path, map_location=options.device)
            )
            print("Restored trained model from {}".format(ckpt_path))

            # Restore place cell centers; without them the restored weights
            # would be paired with freshly-randomized place cells.
            if os.path.isfile(pc_path):
                us = np.load(pc_path)
                self.model.place_cells.us = torch.tensor(us).to(options.device)
                print("Restored place cell centers from {}".format(pc_path))
            else:
                print("Warning: place_cell_centers.npy not found! Model may not work correctly.")
        else:
            if not os.path.isdir(self.ckpt_dir):
                os.makedirs(self.ckpt_dir, exist_ok=True)
            print("Initializing new model from scratch.")
            print("Saving to: {}".format(self.ckpt_dir))

    def train_step(self, inputs, pc_outputs, pos):
        """
        Train on one batch of trajectories.

        Args:
            inputs: Batch of 2d velocity inputs with shape
                [batch_size, sequence_length, 2].
            pc_outputs: Ground truth place cell activations with shape
                [batch_size, sequence_length, Np].
            pos: Ground truth 2d position with shape
                [batch_size, sequence_length, 2].

        Returns:
            loss: Avg. loss for this training batch.
            err: Avg. decoded position error in cm.
        """
        self.model.zero_grad()
        loss, err = self.model.compute_loss(inputs, pc_outputs, pos)
        loss.backward()
        self.optimizer.step()

        return loss.item(), err.item()

    def train(self, n_epochs: int = 1, n_steps=10, save=True):
        """
        Train model on simulated trajectories.

        Args:
            n_epochs: Number of training epochs.
            n_steps: Number of training steps per epoch.
            save: If true, save a checkpoint (weights + place-cell centers)
                and rate-map figures after the first epoch and every 10th
                epoch thereafter.
        """
        # Construct generator
        gen = self.trajectory_generator.get_generator()

        for epoch_idx in range(n_epochs):
            tbar = tqdm(range(n_steps), leave=False)
            for step_idx in tbar:
                inputs, pc_outputs, pos = next(gen)
                loss, err = self.train_step(inputs, pc_outputs, pos)
                self.loss.append(loss)
                self.err.append(err)

                # Log error rate to progress bar
                tbar.set_description('Error = ' + str(int(100*err)) + 'cm')

            # Checkpoint on the first epoch and every 10th epoch after.
            if save and ((epoch_idx + 1) % 10 == 0 or epoch_idx == 0):
                torch.save(
                    self.model.state_dict(),
                    os.path.join(self.ckpt_dir, "most_recent_model.pth"),
                )
                # Save place cell centers alongside the weights so a restore
                # reproduces the exact training targets.
                np.save(
                    os.path.join(self.ckpt_dir, "place_cell_centers.npy"),
                    self.model.place_cells.us.cpu().numpy(),
                )

                # Save a picture of rate maps
                save_ratemaps(
                    self.model,
                    self.trajectory_generator,
                    self.options,
                    step=epoch_idx + 1,
                )

            print(
                "Epoch: {}/{}. Date: {}. Loss: {}. Err: {}cm".format(
                    epoch_idx + 1,
                    n_epochs,
                    str(datetime.datetime.now())[:-7],
                    np.round(loss, 2),
                    np.round(100 * err, 2),
                )
            )