Spaces:
Sleeping
Sleeping
| import logging | |
| import torch.nn as nn | |
| from torch.optim import AdamW | |
| from src.config import DEVICE, EPOCHS, NUM_CLASSES | |
| from src.models.resnet_model import CarClassifierResNet | |
| from src.data.dataset import create_resnet_dataloaders | |
| from src.training.trainer import train_single_input_model | |
logger = logging.getLogger(__name__)


def run_resnet_training(
    *,
    epochs: "int | None" = None,
    patience: int = 7,
    backbone_lr: float = 1e-5,
    head_lr: float = 1e-4,
    checkpoint_model_name: str = "best_resnet_model",
):
    """Build the ResNet car classifier and run the training loop on it.

    Creates the train/eval dataloaders via ``create_resnet_dataloaders``.
    It then instantiates ``CarClassifierResNet`` on ``DEVICE`` and builds an
    AdamW optimizer over only the ``layer3``, ``layer4`` and ``fc``
    sub-modules of the wrapped ``model.model``. The two backbone stages share
    one learning rate and the head gets its own. Training is delegated to
    ``train_single_input_model`` with a cross-entropy loss.

    All keyword arguments are optional. Their defaults reproduce the original
    hard-coded configuration, so calling ``run_resnet_training()`` behaves
    exactly as before.

    Args:
        epochs: Number of epochs forwarded to the trainer. ``None`` (the
            default) means ``EPOCHS`` from ``src.config``, read at call time.
        patience: Value forwarded to the trainer's ``patience`` argument
            (presumably early-stopping patience -- confirm in the trainer).
        backbone_lr: Learning rate for the ``layer3`` and ``layer4`` groups.
        head_lr: Learning rate for the ``fc`` group.
        checkpoint_model_name: Forwarded to the trainer's
            ``checkpoint_model_name`` argument.

    Returns:
        The two values returned by ``train_single_input_model``, as
        ``(all_preds, all_labels)``.
    """
    if epochs is None:
        # Resolve at call time so later changes to the config module are honored,
        # matching the original behavior.
        epochs = EPOCHS

    logger.info("Initializing ResNet training pipeline...")
    train_loader, eval_loader = create_resnet_dataloaders()

    model = CarClassifierResNet(num_classes=NUM_CLASSES).to(DEVICE)
    criterion = nn.CrossEntropyLoss()

    # Only these three parameter groups are handed to the optimizer. Any other
    # backbone parameters are never stepped here (presumably they are frozen
    # inside CarClassifierResNet -- TODO confirm).
    backbone = model.model
    optimizer = AdamW([
        {"params": backbone.layer3.parameters(), "lr": backbone_lr},
        {"params": backbone.layer4.parameters(), "lr": backbone_lr},
        {"params": backbone.fc.parameters(), "lr": head_lr},
    ])

    logger.info("Starting ResNet training...")
    all_preds, all_labels = train_single_input_model(
        model=model,
        train_loader=train_loader,
        eval_loader=eval_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=DEVICE,
        epochs=epochs,
        checkpoint_model_name=checkpoint_model_name,
        patience=patience,
    )
    logger.info("ResNet training completed.")
    return all_preds, all_labels
if __name__ == "__main__":
    # Script entry point: configure root logging at INFO level, run the full
    # training pipeline, then echo the first ten predictions and labels.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )

    predictions, targets = run_resnet_training()

    print("\nTraining completed successfully.")
    for caption, values in (
        ("Prediction samples:", predictions),
        ("Label samples:", targets),
    ):
        print(caption, values[:10])