File size: 2,720 Bytes
1fab54b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

import tqdm
import numpy as np

from model import Model
from buffer import Buffer


class Trainer:
    """Trains a Model with SGD on experiences accumulated in a Buffer.

    Losses are logged to TensorBoard via a SummaryWriter; model/optimizer
    checkpoints can be saved with ``save_model``.
    """

    def __init__(self, model: Model, buffer: Buffer, base_lr: float = 0.001,
                 weight_decay: float = 1e-4, device: str = 'cpu'):
        """
        Args:
            model: network producing (value, policy) predictions and a
                ``get_loss`` method combining both heads.
            buffer: replay buffer exposing ``store_experience``, ``reset``,
                ``sample`` and ``__len__``.
            base_lr: SGD learning rate.
            weight_decay: L2 regularization coefficient.
            device: torch device string used when tensorizing batches.
        """
        self.main_model = model
        self.main_buffer = buffer
        self.global_step = 0  # monotonically increasing step for TensorBoard
        self.device = device

        # SGD with momentum; weight decay acts as L2 regularization.
        self.optimizer = optim.SGD(
            self.main_model.parameters(),
            lr=base_lr,
            weight_decay=weight_decay,
            momentum=0.9
        )

        # NOTE(review): a cyclic LR schedule was considered but is disabled:
        # self.scheduler = optim.lr_scheduler.CyclicLR(
        #     self.optimizer,
        #     base_lr = base_lr,
        #     max_lr = 0.1
        # )

        # TensorBoard summary writer (default ./runs log dir)
        self.writer = SummaryWriter()

    def transfer_buffer(self, buffer) -> None:
        """Copy every (state, value, policy) triple from `buffer` into the
        trainer's main buffer."""
        for state, value, policy in zip(buffer.state, buffer.value, buffer.policy):
            self.main_buffer.store_experience(
                state=state,
                value=value,
                policy=policy
            )

    def reset_buffer(self) -> None:
        """Discard all experiences stored in the main buffer."""
        self.main_buffer.reset()

    def learn(self, state: np.ndarray, value: np.ndarray, policy: np.ndarray) -> float:
        """Run one optimization step on a single batch.

        Args:
            state: batch of input states.
            value: batch of value targets; gains a trailing singleton dim to
                match the model's value-head output shape.
            policy: batch of policy targets.

        Returns:
            The scalar training loss for this batch as a Python float.
        """
        state = torch.tensor(state, dtype=torch.float32, device=self.device)
        value = torch.tensor(value, dtype=torch.float32, device=self.device).unsqueeze(-1)
        policy = torch.tensor(policy, dtype=torch.float32, device=self.device)

        pred_val, pred_policy = self.main_model(state)

        self.optimizer.zero_grad()
        loss = self.main_model.get_loss(pred_val, pred_policy, value, policy)
        loss.backward()
        self.optimizer.step()

        # BUGFIX: .item() returns a true Python float, matching the annotated
        # return type (the old .detach().cpu().numpy() produced a 0-d ndarray).
        return loss.item()

    def train_model(self, epochs: int, batch_size: int):
        """Train for `epochs` full passes over the buffer, logging each
        batch loss to TensorBoard."""
        # Number of batches per epoch (ceil so a partial final batch counts).
        train_steps = int(np.ceil(len(self.main_buffer) / batch_size))

        for epoch in range(epochs):
            # BUGFIX: the file does `import tqdm`, so the progress bar class
            # is tqdm.tqdm — calling the bare module raised a TypeError.
            for state, value, policy in tqdm.tqdm(self.main_buffer.sample(batch_size),
                                                  total=train_steps,
                                                  desc=f'Epoch:{epoch+1}'):
                loss = self.learn(state, value, policy)
                self.writer.add_scalar("loss", loss, self.global_step)
                self.global_step += 1

        self.writer.flush()

    def close_writer(self):
        """Flush-and-close the TensorBoard writer; call when training is done."""
        self.writer.close()

    def save_model(self, step: int):
        """Checkpoint model and optimizer state dicts to the working directory,
        tagged with `step`."""
        torch.save(self.main_model.state_dict(), f'TargetModel_{step}.pt')
        torch.save(self.optimizer.state_dict(), f'Optimizer_{step}.pt')