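"""Training loop for an image classifier: epoch and validation passes,
file/console logging, TensorBoard metrics, and best-model checkpointing.
"""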

import logging
import os
import time
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from src import config


def calculate_accuracy(y_pred, y_true):
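    """Return the fraction of correct top-1 predictions in a batch."""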
preds = torch.argmax(y_pred, dim=1)
correct = (preds == y_true).sum().item()
return correct / len(y_true)


def setup_logging(log_dir):
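    """Configure logging to a timestamped file and the console; return the log path."""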
os.makedirs(log_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = os.path.join(log_dir, f"training_{timestamp}.log")
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler()
]
)
return log_file


def train_one_epoch(model, dataloader, criterion, optimizer, device):
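    """Run one training pass over the dataloader; return mean loss and accuracy."""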
model.train()
running_loss, running_acc = 0.0, 0.0
batch_count = len(dataloader)
logging.info(f"Training on {batch_count} batches")
for batch_idx, (images, labels) in enumerate(dataloader):
if batch_idx % 10 == 0:
logging.info(f" Batch {batch_idx}/{batch_count}")
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
acc = calculate_accuracy(outputs, labels)
loss.backward()
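        # Clip the global gradient norm so a batch with unusually large
        # gradients cannot destabilize the update.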
torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=config.GRAD_CLIP_VALUE)
optimizer.step()
running_loss += loss.item()
running_acc += acc
    return running_loss / batch_count, running_acc / batch_count


def train_model(model, train_loader, val_loader, epochs=config.EPOCHS,
                lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY,
                device=config.DEVICE):
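    """Train for `epochs` epochs, validating after each one.

    Uses Adam with weight decay, a ReduceLROnPlateau scheduler keyed on
    validation accuracy, per-epoch TensorBoard scalars, and saves the best
    model (by validation accuracy) to config.MODEL_SAVE_PATH.
    """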
log_file = setup_logging(config.LOG_DIR)
logging.info(f"Training logs will be saved to: {log_file}")
logging.info(f"Training configuration:")
logging.info(f" Epochs: {epochs}")
logging.info(f" Learning rate: {lr}")
logging.info(f" Weight decay: {weight_decay}")
logging.info(f" Device: {device}")
logging.info(f" Batch size: {config.BATCH_SIZE}")
logging.info(f" Image size: {config.IMAGE_SIZE}")
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr,
weight_decay=weight_decay)
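    # Shrink the learning rate by LR_SCHEDULER_FACTOR when validation
    # accuracy stops improving; mode='max' because the monitored metric
    # should increase.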
    # Note: the scheduler's `verbose` flag is deprecated in recent PyTorch
    # releases, so it is omitted here.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=config.LR_SCHEDULER_FACTOR,
        patience=config.LR_SCHEDULER_PATIENCE,
    )
criterion = nn.CrossEntropyLoss()
best_val_acc = 0.0
run_name = time.strftime("run_%Y%m%d-%H%M")
log_dir = f"{config.LOG_DIR}/{run_name}"
writer = SummaryWriter(log_dir=log_dir)
logging.info(f"Training on: {device.upper()}\n")
for epoch in range(epochs):
epoch_start_time = time.time()
logging.info(f"Epoch {epoch+1}/{epochs} started")
train_loss, train_acc = train_one_epoch(
model, train_loader, criterion, optimizer, device)
logging.info("Validating...")
val_loss, val_acc = validate(model, val_loader, criterion, device)
epoch_time = time.time() - epoch_start_time
scheduler.step(val_acc)
logging.info(
f"Epoch {epoch+1}/{epochs} completed in {epoch_time:.2f}s")
logging.info(
f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
logging.info(
f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")
writer.add_scalar("Loss/train", train_loss, epoch)
writer.add_scalar("Loss/val", val_loss, epoch)
writer.add_scalar("Accuracy/train", train_acc, epoch)
writer.add_scalar("Accuracy/val", val_acc, epoch)
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), config.MODEL_SAVE_PATH)
logging.info("Model saved!")
writer.close()
logging.info("Training complete. Best Val Acc: {:.2f}%".format(
best_val_acc * 100))
return best_val_acc


def validate(model, dataloader, criterion, device):
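    """Evaluate the model on a dataloader; return mean loss and accuracy."""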
model.eval()
val_loss, val_acc = 0.0, 0.0
with torch.no_grad():
for images, labels in dataloader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
acc = calculate_accuracy(outputs, labels)
val_loss += loss.item()
val_acc += acc
return val_loss / len(dataloader), val_acc / len(dataloader)
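

if __name__ == "__main__":
    # Minimal usage sketch. `build_model` and `build_dataloaders` are
    # hypothetical helpers, not part of this module; substitute whatever
    # this project actually uses to construct the model and the
    # train/validation loaders.
    from src.model import build_model        # hypothetical import
    from src.data import build_dataloaders   # hypothetical import

    model = build_model()
    train_loader, val_loader = build_dataloaders()
    best_acc = train_model(model, train_loader, val_loader)
    print(f"Best validation accuracy: {best_acc * 100:.2f}%")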