# %%
# Training script for the nose-point regressor.
# Cell 1 builds the train/val split, data loaders, model, loss and optimizer.
# Cell 2 runs the training loop, checkpoints the weights with the lowest
# validation loss, and writes a train/val loss plot.
import torch
import torch.nn as nn
import torch.optim as optim

from model import *  # provides ResNetNoseRegressor (and the NosePointRegressor alternative)
from dataset import NosePointDataset

# ===== Hyperparameters =====
image_size = (64, 64)
batch_size = 32
num_epochs = 1000
lr = 1e-3
val_split = 0.2
# Seed for the train/val split. Without a fixed generator, random_split draws
# a different validation set on every run, so best-checkpoint losses from
# different runs are not comparable.
split_seed = 42

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = NosePointDataset(image_size=image_size)

n_train = int(len(dataset) * (1 - val_split))
n_val = len(dataset) - n_train
# Fail fast: an empty split would otherwise only show up as a
# ZeroDivisionError when averaging the loss, after a full training epoch.
if n_train == 0 or n_val == 0:
    raise ValueError(
        f"Dataset of size {len(dataset)} with val_split={val_split} yields "
        f"{n_train} train / {n_val} val samples; both splits must be non-empty."
    )

train, val = torch.utils.data.random_split(
    dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(split_seed),
)

train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size, shuffle=False)

# Alternative model: NosePointRegressor(input_channels=3).to(device)
model = ResNetNoseRegressor(pretrained=True).to(device)

# Alternative loss: nn.MSELoss()
criterion = nn.SmoothL1Loss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# %%
import matplotlib.pyplot as plt
from tqdm import tqdm

save_path = "best_model.pth"
plot_path = "loss_plot.png"

train_losses = []
val_losses = []
best_val_loss = float("inf")

# ===== Training Loop =====
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for images, targets in tqdm(train_loader):
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # The criterion is built with its default (mean) reduction, so scale
        # by the batch size so the epoch average is per-sample even when the
        # last batch is smaller.
        train_loss += loss.item() * images.size(0)
    train_loss /= len(train_loader.dataset)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(device), targets.to(device)
            outputs = model(images)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * images.size(0)
    val_loss /= len(val_loader.dataset)

    # Logging
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    print(f"[Epoch {epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Save best model. Only the best weights go to disk; the in-memory
    # `model` keeps training and ends at the last epoch's weights.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
        print("✅ Saved best model.")

# Save plot
plt.figure(figsize=(6, 4))
plt.plot(range(1, len(train_losses) + 1), train_losses, label="Train Loss")
plt.plot(range(1, len(val_losses) + 1), val_losses, label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training vs Validation Loss")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(plot_path)
plt.close()

print("✅ Training complete.")

# %%