Spaces:
Sleeping
Sleeping
File size: 4,691 Bytes
57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 84d0c9e 57d41d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
from datetime import datetime
import os
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from src import config
import time
from torch.utils.tensorboard import SummaryWriter
def calculate_accuracy(y_pred, y_true):
    """Fraction of samples whose argmax class matches the label.

    Args:
        y_pred: logits or probabilities of shape (batch, num_classes).
        y_true: integer class labels of shape (batch,).

    Returns:
        float accuracy in [0, 1].
    """
    predicted_classes = torch.argmax(y_pred, dim=1)
    n_correct = int((predicted_classes == y_true).sum())
    return n_correct / len(y_true)
def setup_logging(log_dir):
    """Configure root logging to a timestamped file plus the console.

    Creates ``log_dir`` if needed and installs a FileHandler/StreamHandler
    pair on the root logger at INFO level.

    Args:
        log_dir: directory in which the log file is created.

    Returns:
        Path of the newly created log file.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{timestamp}.log")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ],
        # force=True replaces handlers installed by a previous call;
        # without it basicConfig is a no-op the second time, so a second
        # training run would keep writing to the first run's log file.
        force=True,
    )
    return log_file
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one full optimization pass over the training dataloader.

    Performs forward/backward/step per batch with gradient-norm clipping,
    logging a progress line every 10 batches.

    Returns:
        (mean_loss, mean_accuracy) averaged over all batches.
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    n_batches = len(dataloader)
    logging.info(f"Training on {n_batches} batches")
    for step, (images, labels) in enumerate(dataloader):
        # Progress heartbeat so long epochs are visibly alive.
        if step % 10 == 0:
            logging.info(f" Batch {step}/{n_batches}")
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        batch_acc = calculate_accuracy(outputs, labels)
        loss.backward()
        # Clip the global gradient norm to stabilize training.
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=config.GRAD_CLIP_VALUE)
        optimizer.step()
        total_loss += loss.item()
        total_acc += batch_acc
    return total_loss / n_batches, total_acc / n_batches
def train_model(model, train_loader, val_loader, epochs=config.EPOCHS, lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY, device=config.DEVICE):
    """Train `model`, validate each epoch, and checkpoint the best weights.

    Uses Adam with weight decay, ReduceLROnPlateau on validation accuracy,
    and logs scalars (loss/accuracy) to TensorBoard. The model state dict
    with the best validation accuracy is saved to config.MODEL_SAVE_PATH.

    Args:
        model: the network to train (moved to `device` here).
        train_loader: training DataLoader.
        val_loader: validation DataLoader.
        epochs: number of epochs to run.
        lr: initial learning rate for Adam.
        weight_decay: L2 penalty for Adam.
        device: device string or torch.device to train on.

    Returns:
        Best validation accuracy observed (float in [0, 1]).
    """
    log_file = setup_logging(config.LOG_DIR)
    logging.info(f"Training logs will be saved to: {log_file}")
    logging.info("Training configuration:")
    logging.info(f"  Epochs: {epochs}")
    logging.info(f"  Learning rate: {lr}")
    logging.info(f"  Weight decay: {weight_decay}")
    logging.info(f"  Device: {device}")
    logging.info(f"  Batch size: {config.BATCH_SIZE}")
    logging.info(f"  Image size: {config.IMAGE_SIZE}")
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)
    # NOTE: `verbose=True` was removed here — it is deprecated since torch 2.2
    # and raises TypeError on torch >= 2.4. LR changes are logged manually below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',  # we step on validation *accuracy*, so higher is better
        factor=config.LR_SCHEDULER_FACTOR,
        patience=config.LR_SCHEDULER_PATIENCE,
    )
    criterion = nn.CrossEntropyLoss()
    best_val_acc = 0.0
    run_name = time.strftime("run_%Y%m%d-%H%M")
    writer = SummaryWriter(log_dir=os.path.join(config.LOG_DIR, run_name))
    # str() so this also works when `device` is a torch.device, not a string.
    logging.info(f"Training on: {str(device).upper()}\n")
    for epoch in range(epochs):
        epoch_start_time = time.time()
        logging.info(f"Epoch {epoch+1}/{epochs} started")
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device)
        logging.info("Validating...")
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        epoch_time = time.time() - epoch_start_time
        # Step the plateau scheduler and report any LR reduction ourselves
        # (replacement for the deprecated verbose=True behavior).
        prev_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_acc)
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr != prev_lr:
            logging.info(f"Learning rate reduced: {prev_lr:.2e} -> {new_lr:.2e}")
        logging.info(
            f"Epoch {epoch+1}/{epochs} completed in {epoch_time:.2f}s")
        logging.info(
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
        logging.info(
            f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/val", val_loss, epoch)
        writer.add_scalar("Accuracy/train", train_acc, epoch)
        writer.add_scalar("Accuracy/val", val_acc, epoch)
        # Checkpoint only when validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), config.MODEL_SAVE_PATH)
            logging.info("Model saved!")
    writer.close()
    logging.info("Training complete. Best Val Acc: {:.2f}%".format(
        best_val_acc * 100))
    return best_val_acc
def validate(model, dataloader, criterion, device):
    """Evaluate the model over `dataloader` without gradient tracking.

    Returns:
        (mean_loss, mean_accuracy) averaged over all batches.
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            # Accuracy = fraction of argmax predictions matching labels.
            predictions = torch.argmax(outputs, dim=1)
            total_acc += (predictions == labels).sum().item() / len(labels)
    n_batches = len(dataloader)
    return total_loss / n_batches, total_acc / n_batches
|