Spaces:
Sleeping
Sleeping
| from torchvision import transforms | |
| from torch.utils.data import DataLoader | |
| from lightning.pytorch.loggers.wandb import WandbLogger | |
| from lightning.pytorch.callbacks import ModelCheckpoint | |
| import lightning as pl | |
| import wandb | |
| from src.dataset import ClassifierDataset, CustomDataset | |
| from src.classifier import Classifier | |
| from src.models import CycleGAN | |
| from src.config import CFG | |
def train_classifier(image_size,
                     batch_size,
                     epochs,
                     resume_ckpt_path,
                     train_dir,
                     val_dir,
                     checkpoint_dir,
                     project,
                     job_name):
    """Train the classifier with Lightning and log the run to Weights & Biases.

    Builds the train/validation ``ClassifierDataset`` loaders, fits a
    ``Classifier(transfer=True)`` while keeping the three checkpoints with
    the lowest ``val_loss``, and closes the W&B run when fitting ends,
    whether it succeeded or raised.

    Args:
        image_size: Side length in pixels; every image is resized to
            ``(image_size, image_size)``.
        batch_size: Batch size used by both data loaders.
        epochs: Upper bound on training epochs (``max_epochs``).
        resume_ckpt_path: Forwarded to ``Trainer.fit`` as ``ckpt_path``.
        train_dir: Root directory handed to the training ``ClassifierDataset``.
        val_dir: Root directory handed to the validation ``ClassifierDataset``.
        checkpoint_dir: Directory in which ``ModelCheckpoint`` writes files.
        project: W&B project name.
        job_name: W&B run name.

    Raises:
        Whatever ``Trainer.fit`` raises; the W&B run is marked as failed
        before the exception is propagated.
    """
    clf_wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")

    transform = transforms.Compose([
        # Resize every image to image_size x image_size
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # Single mean/std pair: assumes ClassifierDataset yields 1-channel
        # images -- TODO confirm against src.dataset
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ])

    # Create datasets
    train_dataset = ClassifierDataset(root_dir=train_dir, transform=transform)
    val_dataset = ClassifierDataset(root_dir=val_dir, transform=transform)
    print("Total Training Images: ",len(train_dataset))
    print("Total Validation Images: ",len(val_dataset))

    # Create data loaders (only the training data is shuffled)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              pin_memory=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            pin_memory=True, num_workers=4)

    # Instantiate the classifier model
    clf = Classifier(transfer=True)

    # Keep the three checkpoints with the lowest validation loss.
    # auto_insert_metric_name=False because the template already spells out
    # the "epoch" / "val_loss" labels itself.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=checkpoint_dir,
        filename='efficientnet_b2-epoch{epoch:02d}-val_loss{val_loss:.2f}',
        auto_insert_metric_name=False,
        save_weights_only=False,
        save_top_k=3,
        mode='min',
    )

    # Lightning chooses accelerator and devices; 16-bit mixed precision;
    # gradients are accumulated over 10 batches per optimizer step.
    trainer = pl.Trainer(
        devices="auto",
        precision="16-mixed",
        accelerator="auto",
        max_epochs=epochs,
        accumulate_grad_batches=10,
        log_every_n_steps=1,
        check_val_every_n_epoch=1,
        benchmark=True,
        logger=clf_wandb_logger,
        callbacks=[checkpoint_callback],
    )

    # Train the classifier. The W&B run must be closed on failure as well:
    # otherwise it stays active in a long-lived process (e.g. a notebook)
    # and the next WandbLogger would attach to it.
    try:
        trainer.fit(clf, train_loader, val_loader, ckpt_path=resume_ckpt_path)
    except BaseException:
        # Record the run as failed, then let the original error propagate.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()
def train_cyclegan(image_size,
                   batch_size,
                   epochs,
                   classifier_path,
                   resume_ckpt_path,
                   train_dir,
                   val_dir,
                   test_dir,
                   checkpoint_dir,
                   project,
                   job_name,
                   ):
    """Train the CycleGAN with Lightning and log the run to Weights & Biases.

    Builds a test ``CustomDataset`` loader, hands it to the ``CycleGAN``
    module together with the train/val directories and the classifier
    checkpoint path, fits the model while keeping the three checkpoints with
    the lowest ``val_generator_loss`` (plus the last one), and closes the W&B
    run when fitting ends, whether it succeeded or raised.

    Args:
        image_size: Side length in pixels; passed to ``CustomDataset`` as
            ``img_res=(image_size, image_size)``.
        batch_size: Batch size of the test data loader.
        epochs: Upper bound on training epochs (``max_epochs``).
        classifier_path: Forwarded to ``CycleGAN`` as ``classifier_path``.
        resume_ckpt_path: Forwarded to ``Trainer.fit`` as ``ckpt_path``.
        train_dir: Forwarded to ``CycleGAN`` as ``train_dir``.
        val_dir: Forwarded to ``CycleGAN`` as ``val_dir``.
        test_dir: Root directory handed to the test ``CustomDataset``.
        checkpoint_dir: Directory used by ``ModelCheckpoint``; also forwarded
            to ``CycleGAN``.
        project: W&B project name.
        job_name: W&B run name.

    Raises:
        Whatever ``Trainer.fit`` raises; the W&B run is marked as failed
        before the exception is propagated.
    """
    # Presumably the names of the two class sub-folders under test_dir --
    # TODO confirm against src.dataset.CustomDataset
    train_N = "0"
    train_P = "1"
    img_res = (image_size, image_size)

    test_dataset = CustomDataset(root_dir=test_dir, train_N=train_N,
                                 train_P=train_P, img_res=img_res)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                                 num_workers=4, pin_memory=True)

    wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")
    print(classifier_path)

    # Only the test loader is built here; train_dir / val_dir are handed to
    # the module as-is.
    cyclegan = CycleGAN(
        train_dir=train_dir,
        val_dir=val_dir,
        test_dataloader=test_dataloader,
        classifier_path=classifier_path,
        checkpoint_dir=checkpoint_dir,
        gf=CFG.GAN_FILTERS,
        df=CFG.DIS_FILTERS,
    )

    # Keep the three checkpoints with the lowest validation generator loss,
    # plus the most recent one (save_last=True).
    # NOTE(review): auto_insert_metric_name is left at its default (True), so
    # the metric names are inserted again into this template (e.g.
    # "epoch_epoch=3"). Left unchanged to keep existing checkpoint file names
    # stable -- confirm whether that is intended.
    gan_checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='cyclegan-epoch_{epoch}-vloss_{val_generator_loss:.2f}',
        monitor='val_generator_loss',
        save_top_k=3,
        save_last=True,
        save_weights_only=False,
        verbose=True,
        mode='min',
    )

    # Create the trainer: Lightning chooses accelerator and devices,
    # 16-bit mixed precision.
    trainer = pl.Trainer(
        accelerator="auto",
        precision="16-mixed",
        max_epochs=epochs,
        log_every_n_steps=1,
        benchmark=True,
        devices="auto",
        logger=wandb_logger,
        callbacks=[gan_checkpoint_callback],
    )

    # Train the CycleGAN model. The W&B run is closed afterwards, matching
    # train_classifier, so that a later WandbLogger in the same process
    # starts a fresh run instead of attaching to this one.
    try:
        trainer.fit(cyclegan, ckpt_path=resume_ckpt_path)
    except BaseException:
        # Record the run as failed, then let the original error propagate.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()