# Spaces: Sleeping  (Hugging Face Spaces page-status residue from the scrape;
# kept as a comment so the file remains valid Python)
import os

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

from src.data.loader import get_dataloaders
from src.models.model import build_model
def train(epochs=10, batch_size=32, lr=1e-3):
    """Fine-tune the model on a binary-classification task.

    Unfreezes the ``layer4`` and ``fc`` parameter groups, trains with Adam
    plus a ``ReduceLROnPlateau`` scheduler, checkpoints the weights with the
    best validation loss, and early-stops after 3 epochs without improvement.

    Args:
        epochs: Maximum number of training epochs.
        batch_size: Batch size forwarded to ``get_dataloaders``.
        lr: Learning rate for the Adam optimizer.
    """
    # NOTE(review): only MPS/CPU are probed here — presumably this targets
    # Apple-silicon machines; confirm CUDA is intentionally unsupported.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    model = build_model().to(device)
    # Unfreeze layer4 and fc for better learning; everything else stays frozen
    # (assumes build_model() returns a model with requires_grad pre-disabled).
    for name, param in model.named_parameters():
        if "layer4" in name or "fc" in name:
            param.requires_grad = True

    train_loader, val_loader, _ = get_dataloaders(batch_size=batch_size)

    criterion = nn.BCEWithLogitsLoss()
    # BUG FIX: the lr argument was previously ignored (hard-coded 1e-4 here),
    # so callers could not actually tune the learning rate. Note this raises
    # the effective default from 1e-4 to the advertised 1e-3.
    optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, patience=2)

    best_val_loss = float("inf")
    early_stop_patience = 3
    no_improve_count = 0

    for epoch in range(epochs):
        # ---- Training ----
        model.train()
        train_loss, correct, total = 0, 0, 0
        for images, labels in train_loader:
            images = images.to(device)
            # BCEWithLogitsLoss expects float targets shaped (N, 1).
            labels = labels.float().unsqueeze(1).to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            # Model emits logits; threshold the sigmoid at 0.5 for accuracy.
            preds = (torch.sigmoid(outputs) >= 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
        train_acc = correct / total
        avg_train_loss = train_loss / len(train_loader)

        # ---- Validation ----
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.float().unsqueeze(1).to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                preds = (torch.sigmoid(outputs) >= 0.5).float()
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)
        val_acc = val_correct / val_total
        avg_val_loss = val_loss / len(val_loader)

        # Scheduler reduces the LR when validation loss plateaus.
        scheduler.step(avg_val_loss)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.4f}")

        # Save best model; early-stop after `early_stop_patience` stale epochs.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            no_improve_count = 0
            # ROBUSTNESS: create the checkpoint directory so torch.save
            # does not fail on a fresh clone.
            os.makedirs("saved_models", exist_ok=True)
            torch.save(model.state_dict(), "saved_models/best_model.pth")
            print(f" -> Best model saved")
        else:
            no_improve_count += 1
            if no_improve_count >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
# Script entry point: run training with the default hyperparameters.
if __name__ == "__main__":
    train()