import os import torch import src.data_setup as data_setup import src.engine as engine import src.utils as utils from src.logger import global_logger as logger from torchvision import transforms import src.model as model_module def main(): NUM_EPOCHS = 20 BATCH_SIZE = 32 LEARNING_RATE = 0.001 train_dir = "data\\retinal_oct\\train" test_dir = "data\\retinal_oct\\test" device = "cuda" if torch.cuda.is_available() else "cpu" # Use the transformations required by ResNet50 data_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders( train_dir=train_dir, test_dir=test_dir, transform=data_transform, batch_size=BATCH_SIZE ) logger.info("Data transformed successfully.") # Initialize the ResNet50 model model, _ = model_module.resnet_model(num_classes=len(class_names)) model = model.to(device) loss_fn = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) engine.train( model=model, train_dataloader=train_dataloader, test_dataloader=test_dataloader, loss_fn=loss_fn, optimizer=optimizer, epochs=NUM_EPOCHS, device=device ) utils.save_model( model=model, target_dir="models", model_name="model.pth" ) logger.info("Model trained successfully.") logger.info("Model saved to models folder.") if __name__ == '__main__': main()