# -*- coding: utf-8 -*-
import torch
import numpy as np
import datetime
from tqdm import tqdm
from visualize import save_ratemaps
import os
class Trainer(object):
    """Train a path-integration model on simulated trajectories.

    Owns the Adam optimizer, restores/saves checkpoints under
    ``options.save_dir/options.run_ID``, and records per-step loss/error.
    """

    # File names used inside the checkpoint directory.
    _MODEL_FILE = "most_recent_model.pth"
    _PC_FILE = "place_cell_centers.npy"

    def __init__(self, options, model, trajectory_generator, restore=True):
        """Build the optimizer and optionally restore the latest checkpoint.

        Args:
            options: Config object; this class reads ``learning_rate``,
                ``save_dir``, ``run_ID`` and ``device`` from it.
            model: ``torch.nn.Module`` exposing ``compute_loss`` and a
                ``place_cells`` object whose ``us`` attribute is a tensor.
            trajectory_generator: Object exposing ``get_generator()``; also
                forwarded to ``save_ratemaps``.
            restore: If True and ``most_recent_model.pth`` exists in the
                checkpoint directory, load the model weights (and the saved
                place-cell centers, if present) from it.
        """
        self.options = options
        self.model = model
        self.trajectory_generator = trajectory_generator
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.options.learning_rate
        )
        # Per-step training history, appended to by ``train``.
        self.loss = []
        self.err = []

        # Set up checkpoints
        self.ckpt_dir = os.path.join(options.save_dir, options.run_ID)
        ckpt_path = os.path.join(self.ckpt_dir, self._MODEL_FILE)
        pc_path = os.path.join(self.ckpt_dir, self._PC_FILE)

        if restore and os.path.isfile(ckpt_path):
            # map_location lets a checkpoint written on one device (e.g. GPU)
            # be restored on another (e.g. a CPU-only machine).
            state = torch.load(ckpt_path, map_location=options.device)
            self.model.load_state_dict(state)
            print("Restored trained model from {}".format(ckpt_path))
            # Load place cell positions: the restored weights are only
            # meaningful together with the centers they were trained against.
            if os.path.isfile(pc_path):
                us = np.load(pc_path)
                self.model.place_cells.us = torch.tensor(us).to(options.device)
                print("Restored place cell centers from {}".format(pc_path))
            else:
                print("Warning: place_cell_centers.npy not found! Model may not work correctly.")
        else:
            os.makedirs(self.ckpt_dir, exist_ok=True)
            print("Initializing new model from scratch.")
            print("Saving to: {}".format(self.ckpt_dir))

    def train_step(self, inputs, pc_outputs, pos):
        """Train on one batch of trajectories.

        Args:
            inputs: Batch of 2d velocity inputs with shape
                [batch_size, sequence_length, 2].
            pc_outputs: Ground truth place cell activations with shape
                [batch_size, sequence_length, Np].
            pos: Ground truth 2d position with shape
                [batch_size, sequence_length, 2].

        Returns:
            (loss, err) as Python floats: the batch loss and the decoded
            position error returned by ``model.compute_loss``. ``train``
            reports ``100 * err`` as cm, so ``err`` is presumably in meters
            (NOTE(review): confirm against compute_loss).
        """
        self.model.zero_grad()
        loss, err = self.model.compute_loss(inputs, pc_outputs, pos)
        loss.backward()
        self.optimizer.step()
        return loss.item(), err.item()

    @staticmethod
    def _should_save(epoch_idx, n_epochs, save_every):
        """Return True if a checkpoint should be written after ``epoch_idx``.

        Saves after the first epoch, after every ``save_every``-th epoch
        (disabled when ``save_every <= 0``), and always after the final
        epoch so the end-of-training weights are never lost.
        """
        epoch_num = epoch_idx + 1
        return (
            epoch_idx == 0
            or epoch_num == n_epochs
            or (save_every > 0 and epoch_num % save_every == 0)
        )

    def _save_checkpoint(self):
        """Write model weights and place-cell centers into ``self.ckpt_dir``."""
        torch.save(
            self.model.state_dict(),
            os.path.join(self.ckpt_dir, self._MODEL_FILE),
        )
        # Save place cell positions next to the weights; __init__ reloads
        # them when restoring this checkpoint.
        np.save(
            os.path.join(self.ckpt_dir, self._PC_FILE),
            self.model.place_cells.us.detach().cpu().numpy(),
        )

    def train(self, n_epochs: int = 1, n_steps=10, save=True, *, save_every: int = 10):
        """Train the model on simulated trajectories.

        Args:
            n_epochs: Number of epochs to run.
            n_steps: Number of training steps (batches) per epoch.
            save: If True, write a checkpoint and a rate-map image after the
                first epoch, every ``save_every`` epochs, and the last epoch.
            save_every: Checkpoint interval in epochs. The default of 10
                matches the previous fixed cadence.
        """
        # Construct generator
        gen = self.trajectory_generator.get_generator()
        # Defaults keep the epoch summary well-defined when n_steps == 0.
        loss = err = float("nan")
        for epoch_idx in range(n_epochs):
            tbar = tqdm(range(n_steps), leave=False)
            for _ in tbar:
                inputs, pc_outputs, pos = next(gen)
                loss, err = self.train_step(inputs, pc_outputs, pos)
                self.loss.append(loss)
                self.err.append(err)
                # Log error rate to progress bar
                tbar.set_description('Error = ' + str(int(100 * err)) + 'cm')

            if save and self._should_save(epoch_idx, n_epochs, save_every):
                self._save_checkpoint()
                # Save a picture of rate maps
                save_ratemaps(
                    self.model,
                    self.trajectory_generator,
                    self.options,
                    step=epoch_idx + 1,
                )

            # strftime instead of str(now)[:-7]: str() drops the fractional
            # part when microsecond == 0, which made the slice cut into the
            # seconds and minutes.
            print(
                "Epoch: {}/{}. Date: {}. Loss: {}. Err: {}cm".format(
                    epoch_idx + 1,
                    n_epochs,
                    datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    np.round(loss, 2),
                    np.round(100 * err, 2),
                )
            )
|