# ImageTrust-AI — src/models/train.py
# Author: SiemonCha — initial deployment (commit d581b00)
import os

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

from src.data.loader import get_dataloaders
from src.models.model import build_model
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over *loader*; return (avg_loss, accuracy)."""
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device)
        # BCEWithLogitsLoss expects float targets shaped like the logits: (N, 1).
        labels = labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Model outputs raw logits; threshold the sigmoid at 0.5 for predictions.
        preds = (torch.sigmoid(outputs) >= 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return running_loss / len(loader), correct / total


def _evaluate(model, loader, criterion, device):
    """Compute loss/accuracy over *loader* without gradients; return (avg_loss, accuracy)."""
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.float().unsqueeze(1).to(device)
            outputs = model(images)
            running_loss += criterion(outputs, labels).item()
            preds = (torch.sigmoid(outputs) >= 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return running_loss / len(loader), correct / total


def train(epochs=10, batch_size=32, lr=1e-4):
    """Fine-tune the binary classifier with LR scheduling and early stopping.

    Unfreezes ``layer4`` and ``fc``, trains with BCE-with-logits, steps a
    ReduceLROnPlateau scheduler on validation loss, saves the best checkpoint
    to ``saved_models/best_model.pth``, and stops after 3 epochs without
    validation-loss improvement.

    Args:
        epochs: Maximum number of training epochs.
        batch_size: Batch size passed to the data loaders.
        lr: Learning rate for Adam. BUG FIX: this argument was previously
            ignored (the optimizer hard-coded 1e-4). The default is now 1e-4
            so a no-argument call behaves exactly as before, while explicit
            values are honored.
    """
    # NOTE(review): original selects only MPS/CPU; presumably developed on
    # Apple silicon. CUDA intentionally not added to preserve behavior.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    model = build_model().to(device)
    # Fine-tune only the deepest conv block and the classifier head;
    # everything else keeps whatever requires_grad state build_model() set.
    for name, param in model.named_parameters():
        if "layer4" in name or "fc" in name:
            param.requires_grad = True

    train_loader, val_loader, _ = get_dataloaders(batch_size=batch_size)

    criterion = nn.BCEWithLogitsLoss()
    # Optimize only the unfrozen parameters; `lr` parameter now honored.
    optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, patience=2)

    best_val_loss = float("inf")
    early_stop_patience = 3
    no_improve_count = 0

    # BUG FIX: torch.save would crash if the directory didn't exist.
    os.makedirs("saved_models", exist_ok=True)

    for epoch in range(epochs):
        avg_train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        avg_val_loss, val_acc = _evaluate(model, val_loader, criterion, device)

        # Reduce LR when validation loss plateaus.
        scheduler.step(avg_val_loss)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.4f}")

        # Checkpoint on improvement; otherwise count toward early stopping.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            no_improve_count = 0
            torch.save(model.state_dict(), "saved_models/best_model.pth")
            print(f" -> Best model saved")
        else:
            no_improve_count += 1
            if no_improve_count >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break


if __name__ == "__main__":
    train()