File size: 5,734 Bytes
0c717d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar, Callback
from lightning.pytorch.loggers import TensorBoardLogger
from pathlib import Path
from torch.optim.lr_scheduler import OneCycleLR
from torch_lr_finder import LRFinder
import torch

from datamodules.imagenet_datamodule import ImageNetDataModule
from models.classifier import ImageNetClassifier

class NewLineProgressBar(Callback):
    """Minimal console reporter: prints an epoch header, then overwrites a
    single line with running loss/accuracy via carriage returns.

    Metrics are read from ``trainer.callback_metrics``; a missing metric
    falls back to 0 so formatting never fails early in training.
    """

    def on_train_epoch_start(self, trainer, pl_module):
        print(f"\nEpoch {trainer.current_epoch}")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        logged = trainer.callback_metrics
        loss = logged.get('train_loss', 0)
        acc = logged.get('train_acc', 0)
        # "\r" rewinds to the line start so each batch overwrites the last.
        print(f"\rTraining - Loss: {loss:.4f}, Acc: {acc:.4f}", end="")

    def on_validation_epoch_start(self, trainer, pl_module):
        print("\n\nValidation:")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        logged = trainer.callback_metrics
        loss = logged.get('val_loss', 0)
        acc = logged.get('val_acc', 0)
        print(f"\rValidation - Loss: {loss:.4f}, Acc: {acc:.4f}", end="")

def find_optimal_lr(model, data_module):
    """Run an exponential LR range test and return a suggested learning rate.

    Picks the LR at the minimum recorded loss, then backs off by a factor
    of 10 (common heuristic). Side effects: sets up *data_module* for the
    'fit' stage, plots the LR curve, and resets model/optimizer state via
    ``LRFinder.reset()``.
    """
    probe_optimizer = torch.optim.Adam(model.parameters(), lr=1e-7)
    loss_fn = torch.nn.CrossEntropyLoss()
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    finder = LRFinder(model, probe_optimizer, loss_fn, device=run_device)

    # Sweep learning rates exponentially from 1e-7 up to 1 over 200 steps.
    data_module.setup(stage='fit')
    finder.range_test(data_module.train_dataloader(), end_lr=1, num_iter=200, step_mode="exp")

    history_lrs = finder.history['lr']
    history_losses = finder.history['loss']

    # LR at the lowest observed loss...
    best_idx = history_losses.index(min(history_losses))
    # ...scaled down: 1/10th of the minimum-loss LR is a common choice.
    optimal_lr = history_lrs[best_idx] * 0.1

    print(f"Optimal learning rate: {optimal_lr}")

    finder.plot()   # persist the LR-vs-loss curve for inspection
    finder.reset()  # restore model and optimizer to their pre-test state

    return optimal_lr

def _build_checkpoint_callback():
    """Keep the three best checkpoints, ranked by validation loss."""
    return ModelCheckpoint(
        dirpath="logs/checkpoints",
        filename="{epoch}-{val_loss:.2f}",
        monitor="val_loss",
        save_top_k=3
    )

def _build_trainer(epochs):
    """Construct the Trainer configuration shared by fresh and resumed runs."""
    return L.Trainer(
        max_epochs=epochs,
        precision="bf16-mixed",
        callbacks=[
            _build_checkpoint_callback(),
            NewLineProgressBar(),
            TQDMProgressBar(refresh_rate=1)
        ],
        accelerator="auto",
        logger=TensorBoardLogger(save_dir="logs", name="image_net_classifications"),
        enable_progress_bar=True,
        enable_model_summary=True,
        log_every_n_steps=1,
        val_check_interval=1.0,
        check_val_every_n_epoch=1
    )

def main(chkpoint_path=None):
    """Train the ImageNet classifier.

    Args:
        chkpoint_path: Optional path to a Lightning checkpoint. When given,
            training resumes from it; when None, a fresh run is started with
            the learning rate chosen by an LR range test.
    """
    # Bug fix: `epochs` was previously only defined in the fresh-run branch,
    # so resuming raised a NameError. Define it once for both paths.
    epochs = 60

    if chkpoint_path is not None:
        model = ImageNetClassifier(lr=1e-2)
        data_module = ImageNetDataModule(batch_size=256, num_workers=8)
        trainer = _build_trainer(epochs)
        # Lightning >= 2.0 resumes via fit(ckpt_path=...); the old
        # Trainer(resume_from_checkpoint=...) argument was removed.
        trainer.fit(model, data_module, ckpt_path=chkpoint_path)
    else:
        # Ensure output directories exist before any logging/checkpointing.
        Path("logs").mkdir(exist_ok=True)
        Path("data").mkdir(exist_ok=True)

        data_module = ImageNetDataModule(batch_size=256, num_workers=8)
        # Probe model used only for the LR range test; replaced below.
        model = ImageNetClassifier(lr=1e-2)
        optimal_lr = find_optimal_lr(model, data_module)

        # Steps bookkeeping, kept for wiring up a OneCycleLR schedule
        # inside the model's configure_optimizers if desired.
        data_module.setup(stage='fit')
        steps_per_epoch = len(data_module.train_dataloader())
        total_steps = epochs * steps_per_epoch

        # Fresh model initialized with the discovered learning rate.
        model = ImageNetClassifier(lr=optimal_lr)
        trainer = _build_trainer(epochs)
        trainer.fit(model, data_module)

if __name__ == "__main__":
    # Fresh training run; pass chkpoint_path to resume from a checkpoint.
    main()