Spaces:
Sleeping
Sleeping
File size: 3,008 Bytes
d581b00 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 | import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from src.models.model import build_efficientnet
from src.data.loader import get_dataloaders
def train(epochs=20, batch_size=32, lr=1e-4):
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
model = build_efficientnet().to(device)
train_loader, val_loader, _ = get_dataloaders(batch_size=batch_size)
# Unfreeze last conv block + classifier
for name, param in model.named_parameters():
if "features.8" in name or "classifier" in name:
param.requires_grad = True
criterion = nn.BCEWithLogitsLoss()
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
scheduler = ReduceLROnPlateau(optimizer, patience=2)
best_val_loss = float("inf")
early_stop_patience = 3
no_improve_count = 0
for epoch in range(epochs):
model.train()
train_loss, correct, total = 0, 0, 0
for images, labels in train_loader:
images = images.to(device)
labels = labels.float().unsqueeze(1).to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
preds = (torch.sigmoid(outputs) >= 0.5).float()
correct += (preds == labels).sum().item()
total += labels.size(0)
train_acc = correct / total
avg_train_loss = train_loss / len(train_loader)
model.eval()
val_loss, val_correct, val_total = 0, 0, 0
with torch.no_grad():
for images, labels in val_loader:
images = images.to(device)
labels = labels.float().unsqueeze(1).to(device)
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item()
preds = (torch.sigmoid(outputs) >= 0.5).float()
val_correct += (preds == labels).sum().item()
val_total += labels.size(0)
val_acc = val_correct / val_total
avg_val_loss = val_loss / len(val_loader)
scheduler.step(avg_val_loss)
print(f"Epoch {epoch+1}/{epochs} | "
f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} | "
f"Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.4f}")
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
no_improve_count = 0
torch.save(model.state_dict(), "saved_models/efficientnet_best.pth")
print(f" -> Best model saved")
else:
no_improve_count += 1
if no_improve_count >= early_stop_patience:
print(f"Early stopping at epoch {epoch+1}")
break
if __name__ == "__main__":
train() |