Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import lightning as pl | |
| from torchinfo import summary | |
| from lightning.pytorch import loggers as pl_loggers | |
| from functorch.compile import compiled_function,draw_graph | |
| from lightning.pytorch.profilers import PyTorchProfiler | |
| from lightning.pytorch.callbacks import ( | |
| DeviceStatsMonitor, | |
| EarlyStopping, | |
| LearningRateMonitor, | |
| ModelCheckpoint, | |
| ModelPruning | |
| ) | |
| from lightning.pytorch.callbacks.progress import TQDMProgressBar | |
| from data import LitMNISTDataModule | |
| from config import CONFIG | |
| from model import LitMNISTModel | |
| from utils import TRAIN_TRANSFORMS, TEST_TRANSFORMS | |
# Auxiliary utils: global torch runtime configuration.
# Allow TF32 tensor cores for matmuls (faster on Ampere+, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
# NOTE(review): the original called torch.cuda.amp.autocast(enabled=True,
# dtype=torch.float32) as a bare statement. That only constructs a context
# manager and discards it, so it had no effect at all (and float32 is not a
# valid CUDA autocast dtype). Wrap forward passes in `with torch.autocast(...)`
# if mixed precision is actually wanted; removed here as dead code.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.set_default_device(device=device)
# Release any cached CUDA allocations before training starts.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# pl.seed_everything(123, workers=True)
## Loggers
# TensorBoard logger; log_graph=True records the computational graph
# (requires model.example_input_array to be set, done further below).
logger = pl_loggers.TensorBoardLogger(
    save_dir='logs/',
    name="lightning_logs",
    log_graph=True,
)
## CallBacks
# Keep only the best checkpoint (lowest validation loss).
checkpoint_cb = ModelCheckpoint(
    monitor="val/loss",
    dirpath=os.path.join('logs', 'chkpoints'),
    filename="{epoch:02d}",
    save_top_k=1,
)
call_backs = [
    TQDMProgressBar(refresh_rate=10),
    checkpoint_cb,
    DeviceStatsMonitor(cpu_stats=True),
    # EarlyStopping(monitor="val/loss",mode='min'),
    LearningRateMonitor(logging_interval='step'),
]
## Profilers
# Directory receiving both the profiler text report and the TensorBoard traces.
perf_dir = os.path.join(os.path.dirname(__file__), 'logs', 'profiler')
# Always profile CPU ops; add CUDA kernels when a GPU is present.
_activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    _activities.append(torch.profiler.ProfilerActivity.CUDA)
perf_profiler = PyTorchProfiler(
    dirpath=perf_dir,
    filename="perf_logs_pytorch",
    group_by_input_shapes=True,
    # NVTX ranges are only useful under nvprof/Nsight on a CUDA machine.
    emit_nvtx=torch.cuda.is_available(),
    activities=_activities,
    # Profile windows of 5 active steps after 1 wait + 1 warmup, 3 cycles.
    # skip_first expects an int step count, not a bool (original passed True,
    # which only worked via bool->int coercion to 1).
    schedule=torch.profiler.schedule(
        wait=1, warmup=1, active=5, repeat=3, skip_first=1
    ),
    profile_memory=True,
    with_stack=True,
    with_flops=True,
    with_modules=True,
    # os.path.join already returns str; no str() wrapper needed.
    on_trace_ready=torch.profiler.tensorboard_trace_handler(
        os.path.join(perf_dir, 'trace')
    ),
)
## MNISTDataModule
# NOTE(review): CONFIG.get('batch_size') / CONFIG.get('num_workers') return
# None when the keys are absent — presumably the data module supplies its own
# defaults; verify against LitMNISTDataModule.
data_cfg = CONFIG['data']
dm = LitMNISTDataModule(
    data_dir=data_cfg.get('dir_path', '.'),
    batch_size=CONFIG.get('batch_size'),
    num_workers=CONFIG.get('num_workers'),
    train_transform=TRAIN_TRANSFORMS,
    test_transform=TEST_TRANSFORMS,
)
# Download/prepare once, then build the splits eagerly so a batch can be
# drawn below for the graph trace and model summary.
dm.prepare_data()
dm.setup()
## MNISTModel
model = LitMNISTModel()
# To resume from a checkpoint instead:
# model = LitMNISTModel.load_from_checkpoint(<path to .ckpt under logs/chkpoints>)
# Single BATCH — used for the graph trace, the CPU profile, and the summary.
batch = next(iter(dm.train_dataloader()))
# Computational graph: Lightning's TensorBoard logger traces this input.
model.example_input_array = batch[0]
# CPU Stats: profile one forward pass and dump the per-op table to disk.
profiler_dir = os.path.join(os.path.dirname(__file__), 'logs', 'profiler')
with torch.autograd.profiler.profile() as prof:
    output = model.to(device)(batch[0].to(device))
os.makedirs(name=profiler_dir, exist_ok=True)
report = prof.key_averages().table(
    sort_by='self_cpu_time_total', top_level_events_only=False
)
with open(os.path.join(profiler_dir, "cpu_throttle.txt"), "w") as text_file:
    text_file.write(report)
# Model Summary — torchinfo layer-by-layer report on the example batch shape.
summary_columns = (
    "input_size",
    "output_size",
    "num_params",
    "kernel_size",
    "mult_adds",
)
summary(
    model,
    input_size=batch[0].shape,
    depth=5,
    verbose=2,
    col_width=16,
    col_names=list(summary_columns),
    row_settings=["var_names"],
)
## Trainer
trainer = pl.Trainer(
    max_epochs=CONFIG['training'].get('num_epochs', 15),
    logger=logger,
    # Use the fully-configured PyTorchProfiler built above. The original
    # passed the string 'pytorch', which silently discarded perf_profiler
    # (its schedule, memory profiling, and TensorBoard trace handler) and
    # used a default-configured profiler instead.
    profiler=perf_profiler,
    callbacks=call_backs,
    precision=32,
    enable_model_summary=False,  # torchinfo summary is already printed above
    enable_progress_bar=True,
)
## Training
trainer.fit(model=model, datamodule=dm)
## Validation — report validation metrics with the trained weights.
trainer.validate(model, datamodule=dm)