File size: 4,461 Bytes
00c2650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc91ac1
00c2650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*- coding: utf-8 -*-
import torch
import numpy as np
import datetime
from tqdm import tqdm

from visualize import save_ratemaps
import os


class Trainer(object):
    """Train a model on simulated trajectories and checkpoint its progress.

    Checkpoints are written to ``<options.save_dir>/<options.run_ID>``. Each one
    consists of the model ``state_dict`` (``most_recent_model.pth``) and the
    model's place cell centers ``model.place_cells.us``
    (``place_cell_centers.npy``), which are saved separately from the state dict.
    """

    _MODEL_FILENAME = "most_recent_model.pth"
    _PLACE_CELL_FILENAME = "place_cell_centers.npy"

    def __init__(self, options, model, trajectory_generator, restore=True):
        """Set up the optimizer and checkpoint directory, optionally restoring.

        Args:
            options: Config object. Reads ``learning_rate``, ``save_dir`` and
                ``run_ID``, and ``device`` when a checkpoint is restored.
            model: Model exposing ``parameters()``, ``compute_loss(...)``,
                ``state_dict()``/``load_state_dict()`` and ``place_cells.us``.
            trajectory_generator: Object whose ``get_generator()`` is used by
                ``train()`` to draw ``(inputs, pc_outputs, pos)`` batches.
            restore: If True and ``most_recent_model.pth`` exists in the run
                directory, load it (and the place cell centers, if present)
                into ``model``. Otherwise the run directory is created.
        """
        self.options = options
        self.model = model
        self.trajectory_generator = trajectory_generator
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.options.learning_rate
        )

        # Per-step training history; train() appends one entry per step.
        self.loss = []
        self.err = []

        # Set up checkpoints
        self.ckpt_dir = os.path.join(options.save_dir, options.run_ID)
        ckpt_path = os.path.join(self.ckpt_dir, self._MODEL_FILENAME)
        pc_path = os.path.join(self.ckpt_dir, self._PLACE_CELL_FILENAME)
        if restore and os.path.isfile(ckpt_path):
            # map_location lets a checkpoint written on one device (e.g. a GPU)
            # be restored on another (e.g. a CPU-only machine).
            state = torch.load(ckpt_path, map_location=options.device)
            self.model.load_state_dict(state)
            print("Restored trained model from {}".format(ckpt_path))
            # Load the place cell positions, which are not part of the state dict.
            if os.path.isfile(pc_path):
                us = np.load(pc_path)
                self.model.place_cells.us = torch.tensor(us).to(options.device)
                print("Restored place cell centers from {}".format(pc_path))
            else:
                print("Warning: place_cell_centers.npy not found! Model may not work correctly.")
        else:
            os.makedirs(self.ckpt_dir, exist_ok=True)
            print("Initializing new model from scratch.")
            print("Saving to: {}".format(self.ckpt_dir))

    def train_step(self, inputs, pc_outputs, pos):
        """
        Train on one batch of trajectories.

        Args:
            inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].
            pc_outputs: Ground truth place cell activations with shape
                [batch_size, sequence_length, Np].
            pos: Ground truth 2d position with shape [batch_size, sequence_length, 2].

        Returns:
            loss: Avg. loss for this training batch.
            err: Avg. decoded position error in cm.
        """
        self.model.zero_grad()

        loss, err = self.model.compute_loss(inputs, pc_outputs, pos)

        loss.backward()
        self.optimizer.step()

        return loss.item(), err.item()

    def _save_checkpoint(self):
        """Write the model state dict and place cell centers to ``ckpt_dir``."""
        torch.save(
            self.model.state_dict(),
            os.path.join(self.ckpt_dir, self._MODEL_FILENAME),
        )
        # Save the place cell positions. detach() keeps the export working even
        # if the centers tensor tracks gradients.
        np.save(
            os.path.join(self.ckpt_dir, self._PLACE_CELL_FILENAME),
            self.model.place_cells.us.detach().cpu().numpy(),
        )

    def train(self, n_epochs: int = 1, n_steps=10, save=True, save_every: int = 10):
        """
        Train model on simulated trajectories.

        Args:
            n_epochs: Number of epochs to run.
            n_steps: Number of training steps (batches) per epoch.
            save: If true, checkpoint the model, save a rate-map figure and
                print a summary after the first epoch, after every
                ``save_every``-th epoch, and after the final epoch.
            save_every: Epoch interval between periodic checkpoints. A value
                <= 0 disables the periodic saves (first and last epochs are
                still saved).
        """

        # Construct generator
        gen = self.trajectory_generator.get_generator()

        # Pre-bind so the epoch summary cannot hit an unbound name when
        # n_steps == 0 (no training step ran).
        loss = err = float("nan")

        for epoch_idx in range(n_epochs):
            tbar = tqdm(range(n_steps), leave=False)
            for _ in tbar:
                inputs, pc_outputs, pos = next(gen)
                loss, err = self.train_step(inputs, pc_outputs, pos)
                self.loss.append(loss)
                self.err.append(err)

                # Log error rate to progress bar
                tbar.set_description('Error = ' + str(int(100*err)) + 'cm')

            epoch = epoch_idx + 1
            is_checkpoint_epoch = (
                epoch_idx == 0
                or (save_every > 0 and epoch % save_every == 0)
                # Always persist the final epoch so trailing progress is not lost.
                or epoch == n_epochs
            )
            if save and is_checkpoint_epoch:
                self._save_checkpoint()

                # Save a picture of rate maps
                save_ratemaps(
                    self.model,
                    self.trajectory_generator,
                    self.options,
                    step=epoch,
                )

                print(
                    "Epoch: {}/{}. Date: {}. Loss: {}. Err: {}cm".format(
                        epoch,
                        n_epochs,
                        str(datetime.datetime.now())[:-7],
                        np.round(loss, 2),
                        np.round(100 * err, 2),
                    )
                )