File size: 1,703 Bytes
ac91785
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import torch
import src.data_setup as data_setup
import src.engine as engine
import src.utils as utils
from src.logger import global_logger as logger
from torchvision import transforms
import src.model as model_module

def main(
    *,
    num_epochs=20,
    batch_size=32,
    learning_rate=0.001,
    # Built with os.path.join instead of hard-coded "\\" separators so the
    # default dataset paths resolve on Linux/macOS as well as Windows.
    train_dir=os.path.join("data", "retinal_oct", "train"),
    test_dir=os.path.join("data", "retinal_oct", "test"),
    model_dir="models",
    model_name="model.pth",
):
    """Train the ResNet classifier on the retinal OCT image folders and save it.

    All parameters are keyword-only and default to the values this script has
    always used, so calling ``main()`` with no arguments keeps the original
    behavior.

    Args:
        num_epochs: Number of epochs passed to ``engine.train``.
        batch_size: Batch size passed to ``data_setup.create_dataloaders``.
        learning_rate: Learning rate for the Adam optimizer.
        train_dir: Directory holding the training images.
        test_dir: Directory holding the evaluation images.
        model_dir: Directory passed to ``utils.save_model`` as ``target_dir``.
        model_name: File name passed to ``utils.save_model``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resize to 224x224 and normalize with the per-channel ImageNet mean/std
    # that torchvision's ResNet weights expect.
    # NOTE(review): the normalization only matters if resnet_model loads
    # ImageNet-pretrained weights — confirm in src/model.py.
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
        train_dir=train_dir,
        test_dir=test_dir,
        transform=data_transform,
        batch_size=batch_size,
    )
    logger.info("Data transformed successfully.")

    # resnet_model returns a pair; only the model itself is used here.
    model, _ = model_module.resnet_model(num_classes=len(class_names))
    model = model.to(device)

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    engine.train(
        model=model,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        epochs=num_epochs,
        device=device,
    )
    # Log completion of training before saving, so a failure while saving
    # does not hide that training itself finished.
    logger.info("Model trained successfully.")

    utils.save_model(
        model=model,
        target_dir=model_dir,
        model_name=model_name,
    )
    logger.info(f"Model saved to {model_dir} folder.")

if __name__ == "__main__":
    # Start training only when this file is run as a script, not when imported.
    main()