import logging
import torch.nn as nn
from torch.optim import AdamW

from src.config import DEVICE, EPOCHS, NUM_CLASSES
from src.models.resnet_model import CarClassifierResNet
from src.data.dataset import create_resnet_dataloaders
from src.training.trainer import train_single_input_model

logger = logging.getLogger(__name__)


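def run_resnet_training():
    """Build the dataloaders, model, loss and optimizer, then run training.

    Returns the predictions and labels produced by train_single_input_model.
    """
    logger.info("Initializing ResNet training pipeline...")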

    train_loader, eval_loader = create_resnet_dataloaders()

    model = CarClassifierResNet(
        num_classes=NUM_CLASSES
    ).to(DEVICE)

    criterion = nn.CrossEntropyLoss()

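    # Per-group learning rates: only layer3, layer4 and the fc head are handed
    # to the optimizer, so it never updates any other parameters of the model.
    # The backbone blocks use a smaller learning rate than the head.
    # Assumes model.model exposes torchvision-style layer3/layer4/fc attributes.
    optimizer = AdamW([
        {
            "params": model.model.layer3.parameters(),
            "lr": 1e-5
        },
        {
            "params": model.model.layer4.parameters(),
            "lr": 1e-5
        },
        {
            "params": model.model.fc.parameters(),
            "lr": 1e-4
        }
    ])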

    logger.info("Starting ResNet training...")

    all_preds, all_labels = train_single_input_model(
        model=model,
        train_loader=train_loader,
        eval_loader=eval_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=DEVICE,
        epochs=EPOCHS,
        checkpoint_model_name="best_resnet_model",
        patience=7
    )

    logger.info("ResNet training completed.")

    return all_preds, all_labels


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s"
    )

    preds, labels = run_resnet_training()

    print("\nTraining completed successfully.")
    print("Prediction samples:", preds[:10])
    print("Label samples:", labels[:10])