"""Train a small U-Net for binary nail segmentation.

Reads images from ``dataset/images`` (``.jpg`` / ``.png``) and the matching
masks from ``dataset/masks`` (``<image base name>.png``), trains for
``num_epochs`` epochs and writes the weights to
``model/nail_segmentation_unet.pt``.
"""
import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

data_dir = "dataset"
batch_size = 4
num_epochs = 25
lr = 1e-4
img_size = 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class NailDataset(Dataset):
    """Paired image / binary-mask dataset read from two directories.

    Args:
        img_dir: Directory containing the ``.jpg`` / ``.png`` input images.
        mask_dir: Directory containing one ``<base name>.png`` mask per image.
        transform: Optional callable applied to the RGB image only. It must
            not contain random *geometric* augmentations (flips, crops,
            rotations): the mask would not receive them and the pair would
            no longer line up. Use ``hflip_prob`` for flipping instead.
        hflip_prob: Probability of horizontally flipping the image and its
            mask together. Defaults to ``0.0`` (no flipping).
    """

    def __init__(self, img_dir, mask_dir, transform=None, hflip_prob=0.0):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.hflip_prob = hflip_prob
        # os.listdir() order is arbitrary; sort so indices are reproducible.
        self.images = sorted(
            f for f in os.listdir(img_dir) if f.endswith((".jpg", ".png"))
        )
        # Built once instead of on every __getitem__ call.
        self._mask_to_tensor = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        ``image`` is whatever ``transform`` returns (the RGB PIL image when
        no transform is set); ``mask`` is a float tensor holding only 0.0
        and 1.0, resized to ``img_size`` x ``img_size``.

        Raises:
            FileNotFoundError: If the mask file for the image is missing.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.img_dir, img_name)
        base_name = os.path.splitext(img_name)[0]
        mask_path = os.path.join(self.mask_dir, base_name + ".png")
        if not os.path.exists(mask_path):
            raise FileNotFoundError(f"Mask not found for {img_name} → {mask_path}")

        # convert() returns a fully loaded copy, so the file handles can be
        # released right away instead of waiting for garbage collection.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        # One random draw shared by image and mask keeps them aligned.
        if self.hflip_prob > 0 and torch.rand(1).item() < self.hflip_prob:
            image = transforms.functional.hflip(image)
            mask = transforms.functional.hflip(mask)

        if self.transform:
            image = self.transform(image)

        # Resizing interpolates, so re-binarize the mask afterwards.
        mask = self._mask_to_tensor(mask)
        mask = (mask > 0.5).float()
        return image, mask


# Image-only pipeline. The horizontal flip is deliberately NOT in here: it is
# applied jointly to image and mask by NailDataset (hflip_prob below).
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.ToTensor(),
])

dataset = NailDataset(
    os.path.join(data_dir, "images"),
    os.path.join(data_dir, "masks"),
    transform,
    hflip_prob=0.5,
)
if len(dataset) == 0:
    # Fail early with a clear message rather than inside the sampler or with
    # a ZeroDivisionError when averaging the epoch loss.
    raise ValueError(
        f"No .jpg/.png images found in {os.path.join(data_dir, 'images')}"
    )
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


class UNet(nn.Module):
    """Compact U-Net: four encoder stages, three decoder stages, 1-channel output.

    The input is max-pooled by 2 three times and upsampled by 2 three times,
    so height and width must be divisible by 8 for the skip concatenations
    to match in size. ``forward`` returns sigmoid probabilities with the same
    spatial size as the input.
    """

    def __init__(self):
        super().__init__()

        def CBR(in_channels, out_channels):
            # Conv(3x3, same padding) -> BatchNorm -> ReLU
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        self.enc1 = CBR(3, 64)
        self.enc2 = CBR(64, 128)
        self.enc3 = CBR(128, 256)
        self.enc4 = CBR(256, 512)
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        # Decoder inputs are the upsampled features concatenated with the
        # encoder skip connection, hence the summed channel counts.
        self.dec3 = CBR(512 + 256, 256)
        self.dec2 = CBR(256 + 128, 128)
        self.dec1 = CBR(128 + 64, 64)
        self.final = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        # Encoder: keep each stage's output for the skip connections.
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Decoder: upsample, concatenate the skip on the channel axis, convolve.
        d3 = self.dec3(torch.cat([self.up(e4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up(d2), e1], dim=1))

        # Sigmoid output pairs with nn.BCELoss below.
        return torch.sigmoid(self.final(d1))


model = UNet().to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]")
    for imgs, masks in loop:
        imgs, masks = imgs.to(device), masks.to(device)
        preds = model(imgs)
        loss = criterion(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() forces a device sync; read it once per step.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        loop.set_postfix(loss=batch_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(train_loader):.4f}")

os.makedirs("model", exist_ok=True)
torch.save(model.state_dict(), "model/nail_segmentation_unet.pt")
print("✅ Model saved as model/nail_segmentation_unet.pt")