import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar, Callback
from lightning.pytorch.loggers import TensorBoardLogger
from pathlib import Path
from torch.optim.lr_scheduler import OneCycleLR
from torch_lr_finder import LRFinder
import torch
from datamodules.imagenet_datamodule import ImageNetDataModule
from models.classifier import ImageNetClassifier
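# Project-local modules (not shown in this file). As used below, they are
# assumed to expose roughly these interfaces:
#   ImageNetDataModule(batch_size: int, num_workers: int)  -- a LightningDataModule
#     with .setup(stage="fit") and .train_dataloader()
#   ImageNetClassifier(lr: float)  -- a LightningModule that logs
#     'train_loss', 'train_acc', 'val_loss', 'val_acc'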
class NewLineProgressBar(Callback):
    """Prints per-batch metrics on a single updating line, one block per epoch."""

    def on_train_epoch_start(self, trainer, pl_module):
        print(f"\nEpoch {trainer.current_epoch}")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        metrics = trainer.callback_metrics
        train_loss = metrics.get('train_loss', 0)
        train_acc = metrics.get('train_acc', 0)
        print(f"\rTraining - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}", end="")

    def on_validation_epoch_start(self, trainer, pl_module):
        print("\n\nValidation:")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        metrics = trainer.callback_metrics
        val_loss = metrics.get('val_loss', 0)
        val_acc = metrics.get('val_acc', 0)
        print(f"\rValidation - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}", end="")
def find_optimal_lr(model, data_module):
    """Run a torch_lr_finder range test and return a learning rate to train with."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-7)
    criterion = torch.nn.CrossEntropyLoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lr_finder = LRFinder(model, optimizer, criterion, device=device)

    # Run the range test over the training dataloader
    data_module.setup(stage='fit')
    lr_finder.range_test(data_module.train_dataloader(), end_lr=1, num_iter=200, step_mode="exp")

    lrs = lr_finder.history['lr']
    losses = lr_finder.history['loss']

    # Take the learning rate at the minimum of the loss curve...
    optimal_lr = lrs[losses.index(min(losses))]
    # ...then back off to 1/10th of it, a common heuristic to stay in the stable region
    optimal_lr = optimal_lr * 0.1
    print(f"Optimal learning rate: {optimal_lr}")

    lr_finder.plot()   # Display the loss-vs-LR curve (requires matplotlib)
    lr_finder.reset()  # Restore the model and optimizer to their initial state
    return optimal_lr
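
# The min-loss + 1/10 heuristic above is one way to read the LR-finder curve;
# another common choice is the point of steepest descent. A minimal sketch of
# that alternative (a hypothetical helper, not called anywhere in this script):
def find_steepest_descent_lr(lr_finder):
    """Return the LR where the loss curve falls fastest (most negative slope)."""
    import numpy as np
    lrs = np.array(lr_finder.history['lr'])
    losses = np.array(lr_finder.history['loss'])
    # Slope of loss w.r.t. log10(lr); the most negative value marks the steepest drop
    grads = np.gradient(losses, np.log10(lrs))
    return float(lrs[np.argmin(grads)])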
def main(chkpoint_path=None):
    if chkpoint_path is not None:
        # Resume training from an existing checkpoint
        model = ImageNetClassifier(lr=1e-2)
        data_module = ImageNetDataModule(batch_size=256, num_workers=8)
        epochs = 60  # Mirror the fresh-run branch below
        checkpoint_callback = ModelCheckpoint(
            dirpath="logs/checkpoints",
            filename="{epoch}-{val_loss:.2f}",
            monitor="val_loss",
            save_top_k=3
        )
        # Initialize Trainer; in Lightning 2.x the checkpoint is passed to
        # fit() via ckpt_path (resume_from_checkpoint was removed)
        trainer = L.Trainer(
            max_epochs=epochs,
            precision="bf16-mixed",
            callbacks=[
                checkpoint_callback,
                NewLineProgressBar(),
                TQDMProgressBar(refresh_rate=1)
            ],
            accelerator="auto",
            logger=TensorBoardLogger(save_dir="logs", name="image_net_classifications"),
            enable_progress_bar=True,
            enable_model_summary=True,
            log_every_n_steps=1,
            val_check_interval=1.0,
            check_val_every_n_epoch=1
        )
        trainer.fit(model, data_module, ckpt_path=chkpoint_path)
    else:
        # Create output directories
        Path("logs").mkdir(exist_ok=True)
        Path("data").mkdir(exist_ok=True)

        # Initialize DataModule and Model
        data_module = ImageNetDataModule(batch_size=256, num_workers=8)
        model = ImageNetClassifier(lr=1e-2)  # Placeholder lr; replaced after the LR search

        # Find optimal learning rate
        optimal_lr = find_optimal_lr(model, data_module)
        # optimal_lr = 6.28E-02

        # Calculate total steps for OneCycleLR (only consumed by the
        # commented-out manual scheduler setup below)
        epochs = 60
        data_module.setup(stage='fit')
        steps_per_epoch = len(data_module.train_dataloader())
        total_steps = epochs * steps_per_epoch

        # # Initialize optimizer
        # optimizer = torch.optim.Adam(model.parameters(), lr=optimal_lr)
        # # Initialize OneCycleLR scheduler
        # scheduler = OneCycleLR(
        #     optimizer,
        #     max_lr=optimal_lr,
        #     total_steps=total_steps,
        #     pct_start=0.3,          # Spend 30% of the cycle increasing LR
        #     div_factor=25,          # Initial LR will be max_lr/25
        #     final_div_factor=1e4,   # Final LR will be max_lr/10000
        #     three_phase=False,      # Use the one-cycle policy
        #     anneal_strategy='cos'   # Use cosine annealing
        # )

        # Re-initialize the model with the discovered learning rate
        model = ImageNetClassifier(lr=optimal_lr)

        # Initialize callbacks
        checkpoint_callback = ModelCheckpoint(
            dirpath="logs/checkpoints",
            filename="{epoch}-{val_loss:.2f}",
            monitor="val_loss",
            save_top_k=3
        )

        # Initialize Trainer
        trainer = L.Trainer(
            max_epochs=epochs,
            precision="bf16-mixed",
            callbacks=[
                checkpoint_callback,
                NewLineProgressBar(),
                TQDMProgressBar(refresh_rate=1)
            ],
            accelerator="auto",
            logger=TensorBoardLogger(save_dir="logs", name="image_net_classifications"),
            enable_progress_bar=True,
            enable_model_summary=True,
            log_every_n_steps=1,
            val_check_interval=1.0,
            check_val_every_n_epoch=1
        )

        # Train the model
        trainer.fit(model, data_module)
if __name__ == "__main__":
    main(chkpoint_path=None)
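    # To resume from a saved checkpoint instead, pass its path; the filename
    # below is illustrative (the pattern ModelCheckpoint above would produce),
    # not a file this script is known to have written:
    # main(chkpoint_path="logs/checkpoints/epoch=9-val_loss=1.23.ckpt")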