Spaces:
Sleeping
Sleeping
File size: 1,766 Bytes
eef8873 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | import logging
import torch.nn as nn
from torch.optim import AdamW
from src.config import DEVICE, EPOCHS, NUM_CLASSES
from src.models.resnet_model import CarClassifierResNet
from src.data.dataset import create_resnet_dataloaders
from src.training.trainer import train_single_input_model
# Module-level logger, named after this module so its records can be
# attributed to it in the log output.
logger = logging.getLogger(__name__)
def run_resnet_training(
    *,
    backbone_lr: float = 1e-5,
    head_lr: float = 1e-4,
    patience: int = 7,
) -> tuple:
    """Run the ResNet training pipeline and return the trainer's results.

    Builds the dataloaders with ``create_resnet_dataloaders``, instantiates
    ``CarClassifierResNet`` with ``NUM_CLASSES`` outputs on ``DEVICE``, and
    hands everything to ``train_single_input_model`` together with a
    cross-entropy loss and an AdamW optimizer.

    Only the parameters of ``model.model.layer3``, ``model.model.layer4`` and
    ``model.model.fc`` are registered with the optimizer, each as its own
    parameter group, so no other parameters are updated by it.

    Args:
        backbone_lr: Learning rate for the ``layer3`` and ``layer4`` groups.
        head_lr: Learning rate for the ``fc`` group.
        patience: Forwarded unchanged to ``train_single_input_model``;
            presumably the early-stopping patience -- confirm against the
            trainer.

    Returns:
        The ``(all_preds, all_labels)`` pair produced by
        ``train_single_input_model``.
    """
    logger.info("Initializing ResNet training pipeline...")

    train_loader, eval_loader = create_resnet_dataloaders()
    model = CarClassifierResNet(num_classes=NUM_CLASSES).to(DEVICE)
    criterion = nn.CrossEntropyLoss()

    # Three separate groups (not one merged backbone group) so that the
    # optimizer's param_groups layout stays exactly as before.
    backbone = model.model
    optimizer = AdamW(
        [
            {"params": backbone.layer3.parameters(), "lr": backbone_lr},
            {"params": backbone.layer4.parameters(), "lr": backbone_lr},
            {"params": backbone.fc.parameters(), "lr": head_lr},
        ]
    )

    logger.info("Starting ResNet training...")
    all_preds, all_labels = train_single_input_model(
        model=model,
        train_loader=train_loader,
        eval_loader=eval_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=DEVICE,
        epochs=EPOCHS,
        checkpoint_model_name="best_resnet_model",
        patience=patience,
    )
    logger.info("ResNet training completed.")

    return all_preds, all_labels
if __name__ == "__main__":
    # Configure logging only when executed as a script, so that importing
    # this module leaves the host application's logging setup alone.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    preds, labels = run_resnet_training()

    print("\nTraining completed successfully.")
    # Print only the first ten entries of each result as a quick spot check.
    print("Prediction samples:", preds[:10])
    print("Label samples:", labels[:10])