# Captured from a Spaces file viewer (status: Sleeping); revision ac91785,
# file size 1,703 bytes.
import os
import torch
import src.data_setup as data_setup
import src.engine as engine
import src.utils as utils
from src.logger import global_logger as logger
from torchvision import transforms
import src.model as model_module
def main():
    """Train the ResNet50-based classifier on the retinal OCT data and save it.

    Steps:
      1. Build train/test dataloaders from ``data/retinal_oct/train`` and
         ``data/retinal_oct/test`` via ``src.data_setup.create_dataloaders``.
      2. Create the model with ``src.model.resnet_model``, using one output
         per class name returned by the dataloader factory.
      3. Train with cross-entropy loss and Adam via ``src.engine.train``.
      4. Save the trained model to ``models/model.pth`` via
         ``src.utils.save_model``.

    Runs on CUDA when available, otherwise on CPU.
    """
    # Training hyperparameters.
    NUM_EPOCHS = 20
    BATCH_SIZE = 32
    LEARNING_RATE = 0.001

    # Build paths with the platform's separator. The previous hard-coded
    # "data\\retinal_oct\\train" only works on Windows. On POSIX systems it is
    # a single directory name containing literal backslashes, so loading
    # failed there.
    data_root = os.path.join("data", "retinal_oct")
    train_dir = os.path.join(data_root, "train")
    test_dir = os.path.join(data_root, "test")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Preprocessing for ResNet50: resize to 224x224, convert to a tensor, and
    # normalise with the standard ImageNet per-channel mean/std.
    # NOTE(review): the 3-channel Normalize assumes the dataset yields RGB
    # images. Confirm that data_setup converts OCT scans to RGB.
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
        train_dir=train_dir,
        test_dir=test_dir,
        transform=data_transform,
        batch_size=BATCH_SIZE,
    )
    # Marks dataloader creation. The transform is presumably applied lazily
    # per sample inside the dataset, not at this point.
    logger.info("Data transformed successfully.")

    # resnet_model returns a pair. Only the model is used here.
    # NOTE(review): the second element is discarded; confirm it is not needed.
    model, _ = model_module.resnet_model(num_classes=len(class_names))
    model = model.to(device)

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    engine.train(
        model=model,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        epochs=NUM_EPOCHS,
        device=device,
    )

    utils.save_model(
        model=model,
        target_dir="models",
        model_name="model.pth",
    )
    logger.info("Model trained successfully.")
    logger.info("Model saved to models folder.")
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
|