| import os | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from PIL import Image | |
| from tqdm import tqdm | |
| import numpy as np | |
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
data_dir = "dataset"  # root folder; "images" and "masks" subfolders are joined onto it
batch_size = 4        # samples per optimisation step
num_epochs = 25       # full passes over the training set
lr = 1e-4             # Adam learning rate
img_size = 256        # images are resized to img_size x img_size

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class NailDataset(Dataset):
    """Paired image / binary-mask dataset read from two directories.

    Every ``.jpg`` / ``.jpeg`` / ``.png`` file in ``img_dir`` is one sample.
    Its mask is looked up in ``mask_dir`` under the same base name with a
    ``.png`` extension.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the mask PNGs.
        transform: Optional callable applied to the RGB PIL image only (the
            mask never goes through it). When omitted, the image is simply
            converted to a tensor so that samples can still be batched.
    """

    # Extensions accepted as input images (matched case-insensitively).
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so that index ->
        # sample is stable across runs and machines.
        self.images = sorted(
            name
            for name in os.listdir(img_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of images found in ``img_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        Returns:
            image: Output of ``transform`` (or a plain tensor conversion when
                no transform was given).
            mask: Float tensor holding only 0.0 / 1.0, resized to the spatial
                size of ``image``.

        Raises:
            FileNotFoundError: If no ``<base name>.png`` exists in ``mask_dir``.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.img_dir, img_name)
        base_name = os.path.splitext(img_name)[0]
        mask_path = os.path.join(self.mask_dir, base_name + ".png")
        # Check before opening anything so the error names both files.
        if not os.path.exists(mask_path):
            raise FileNotFoundError(f"Mask not found for {img_name} -> {mask_path}")

        # convert() returns a fully loaded copy, so the underlying file can be
        # closed right away instead of lingering until garbage collection.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)

        # Size the mask from the image actually produced rather than from a
        # global constant, so the pair always lines up spatially.
        if isinstance(image, torch.Tensor):
            target_hw = tuple(image.shape[-2:])
        else:
            # The transform returned a PIL image; PIL reports (width, height).
            target_hw = image.size[::-1]

        mask = transforms.Resize(target_hw)(mask)
        mask = transforms.ToTensor()(mask)
        # ToTensor scales 8-bit values to [0, 1]; binarise at the midpoint.
        mask = (mask > 0.5).float()
        return image, mask
class _PairedHorizontalFlip(Dataset):
    """Wrap an ``(image, mask)`` dataset and mirror both tensors together.

    A random flip placed in the image-only transform mirrors the image but
    not its mask, which corrupts the labels of every flipped sample. Doing
    the flip here, after both tensors exist, keeps the pair aligned.

    Args:
        base: Dataset whose items are ``(image, mask)`` tensors with width on
            the last axis (the channel-first layout produced by ``ToTensor``).
        p: Probability of flipping a given sample.
    """

    def __init__(self, base, p=0.5):
        self.base = base
        self.p = p

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, mask = self.base[idx]
        # One random draw decides the flip for both tensors.
        if torch.rand(1).item() < self.p:
            image = torch.flip(image, dims=[-1])
            mask = torch.flip(mask, dims=[-1])
        return image, mask


# Image-only preprocessing. Only photometric augmentation belongs here:
# NailDataset applies this transform to the image and never to the mask, so
# geometric augmentation is handled by _PairedHorizontalFlip instead.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.ToTensor(),
])

dataset = _PairedHorizontalFlip(
    NailDataset(
        os.path.join(data_dir, "images"),
        os.path.join(data_dir, "masks"),
        transform,
    ),
    p=0.5,  # same probability RandomHorizontalFlip uses by default
)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
class UNet(nn.Module):
    """Small U-Net style encoder/decoder for per-pixel segmentation.

    The encoder has four stages (64, 128, 256, 512 channels) separated by
    2x2 max-pooling; each stage is a single 3x3 conv + batch-norm + ReLU.
    The decoder upsamples bilinearly, concatenates the matching encoder
    feature map and applies one conv block per stage. A final 1x1 conv
    followed by a sigmoid yields values in [0, 1].

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of output maps.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        self.enc1 = self._conv_block(in_channels, 64)
        self.enc2 = self._conv_block(64, 128)
        self.enc3 = self._conv_block(128, 256)
        self.enc4 = self._conv_block(256, 512)
        self.pool = nn.MaxPool2d(2)
        # Decoder inputs are the upsampled features concatenated with the skip.
        self.dec3 = self._conv_block(512 + 256, 256)
        self.dec2 = self._conv_block(256 + 128, 128)
        self.dec1 = self._conv_block(128 + 64, 64)
        self.final = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return a 3x3 conv -> batch-norm -> ReLU block (spatial size kept)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _upsample_like(x, reference):
        """Bilinearly resize ``x`` to the spatial size of ``reference``.

        Max-pooling floors odd sizes, so a fixed 2x upsample does not always
        reproduce the skip tensor's size. Targeting that size directly lets
        the concatenation work for any input resolution; for sizes divisible
        by 8 it is the same as a 2x upsample.
        """
        return nn.functional.interpolate(
            x, size=reference.shape[-2:], mode="bilinear", align_corners=True
        )

    def forward(self, x):
        """Map an ``(N, in_channels, H, W)`` batch to ``(N, out_channels, H, W)``."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        d3 = self.dec3(torch.cat([self._upsample_like(e4, e3), e3], dim=1))
        d2 = self.dec2(torch.cat([self._upsample_like(d3, e2), e2], dim=1))
        d1 = self.dec1(torch.cat([self._upsample_like(d2, e1), e1], dim=1))

        return torch.sigmoid(self.final(d1))
# Model, loss and optimiser. The network ends in a sigmoid, so plain BCELoss
# (which expects probabilities) is the matching criterion.
model = UNet().to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    samples_seen = 0
    loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]")
    for imgs, masks in loop:
        imgs, masks = imgs.to(device), masks.to(device)
        preds = model(imgs)
        loss = criterion(preds, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call copies the value to the host.
        batch_loss = loss.item()
        # Weight by batch size so a short final batch does not skew the
        # epoch average.
        batch_count = imgs.size(0)
        epoch_loss += batch_loss * batch_count
        samples_seen += batch_count
        loop.set_postfix(loss=batch_loss)
    # max(..., 1) keeps the report well-defined if the loader yields nothing.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss / max(samples_seen, 1):.4f}")

# Persist only the weights (state_dict), not the whole pickled module.
save_path = "model/nail_segmentation_unet.pt"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)
print(f"Model saved as {save_path}")