File size: 3,811 Bytes
7e56915 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
data_dir = "dataset"   # root folder; "images" and "masks" sub-folders are joined onto it later
batch_size = 4         # samples per optimisation step
num_epochs = 25        # full passes over the training set
lr = 1e-4              # learning rate handed to the optimiser
img_size = 256         # images and masks are resized to img_size x img_size

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class NailDataset(Dataset):
    """Paired image / binary-mask dataset for nail segmentation.

    Each image in ``img_dir`` is expected to have a mask in ``mask_dir`` with
    the same base name and a ``.png`` extension
    (``hand_01.jpg`` -> ``hand_01.png``).

    Args:
        img_dir: Directory holding the input images. Files ending in ``.jpg``
            or ``.png`` (matched case-insensitively) are used.
        mask_dir: Directory holding the masks.
        transform: Optional callable applied to the RGB PIL image only. The
            mask never passes through it, so it must not contain random
            geometric operations (flips, crops, rotations); those would move
            the image without moving its mask.
        hflip_prob: Probability of horizontally flipping the image and its
            mask together, before ``transform`` runs. The default of ``0.0``
            disables the flip.

    Attributes:
        images: Sorted list of image file names found in ``img_dir``.
    """

    # Accepted image extensions, compared against the lower-cased file name.
    IMAGE_EXTENSIONS = (".jpg", ".png")

    def __init__(self, img_dir, mask_dir, transform=None, hflip_prob=0.0):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.hflip_prob = hflip_prob
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible. Lower-casing the name keeps
        # files such as "IMG_01.JPG" from being silently skipped.
        self.images = sorted(
            name
            for name in os.listdir(img_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of images found in ``img_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load one ``(image, mask)`` pair.

        Returns:
            A tuple ``(image, mask)``. ``image`` is whatever ``transform``
            returns (the RGB PIL image itself when no transform is set).
            ``mask`` is a float tensor of 0.0 / 1.0 values resized to
            ``img_size x img_size`` (``img_size`` is the module-level
            setting).

        Raises:
            FileNotFoundError: If the ``.png`` mask for the image is missing.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.img_dir, img_name)
        base_name = os.path.splitext(img_name)[0]
        mask_path = os.path.join(self.mask_dir, base_name + ".png")
        if not os.path.exists(mask_path):
            raise FileNotFoundError(f"Mask not found for {img_name} → {mask_path}")

        # convert() returns a fully loaded image, so the underlying files can
        # be closed as soon as the conversion is done.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        # Geometric augmentation has to hit image and mask identically, so it
        # is done here on the pair rather than inside the image-only
        # transform. The RNG is only consulted when the flip is enabled.
        if self.hflip_prob > 0 and torch.rand(1).item() < self.hflip_prob:
            image = transforms.functional.hflip(image)
            mask = transforms.functional.hflip(mask)

        if self.transform is not None:
            image = self.transform(image)

        mask = transforms.Resize((img_size, img_size))(mask)
        mask = transforms.ToTensor()(mask)
        # Resizing interpolates, producing grey values along edges; threshold
        # back to a strictly binary mask.
        mask = (mask > 0.5).float()
        return image, mask
# Image-only preprocessing / augmentation pipeline.
#
# Only photometric augmentation (ColorJitter) belongs here. NailDataset
# applies this pipeline to the image but never to the mask, so the
# RandomHorizontalFlip that used to sit in this list mirrored roughly half of
# the images while leaving their masks unmirrored, i.e. the network was
# trained on misaligned image/mask pairs. Any geometric augmentation has to
# be applied to image and mask together instead.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ColorJitter(0.2, 0.2, 0.2),  # brightness, contrast, saturation
    transforms.ToTensor(),
])

# Images live in <data_dir>/images, their masks in <data_dir>/masks.
dataset = NailDataset(
    os.path.join(data_dir, "images"),
    os.path.join(data_dir, "masks"),
    transform,
)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
class UNet(nn.Module):
    """Compact U-Net style encoder/decoder producing a 1-channel mask.

    The encoder has four Conv-BatchNorm-ReLU stages (64, 128, 256, 512
    channels) separated by 2x2 max pooling. The decoder upsamples
    bilinearly, concatenates the matching encoder feature map (skip
    connection) and applies another Conv-BatchNorm-ReLU stage. A final 1x1
    convolution followed by a sigmoid yields per-pixel probabilities.
    """

    def __init__(self):
        super().__init__()

        def CBR(in_channels, out_channels):
            """Build a 3x3 Conv -> BatchNorm -> ReLU stage (size preserving)."""
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        # Encoder.
        self.enc1 = CBR(3, 64)
        self.enc2 = CBR(64, 128)
        self.enc3 = CBR(128, 256)
        self.enc4 = CBR(256, 512)
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        # Decoder input channels = upsampled channels + skip channels.
        self.dec3 = CBR(512 + 256, 256)
        self.dec2 = CBR(256 + 128, 128)
        self.dec1 = CBR(128 + 64, 64)
        self.final = nn.Conv2d(64, 1, 1)

    def _upsample_to(self, x, skip):
        """Upsample ``x`` so its height/width match those of ``skip``.

        MaxPool2d floors odd sizes, so a plain 2x upsample can come out one
        pixel short of the encoder feature map and make the concatenation
        fail. When the 2x result already matches, the fixed ``self.up`` layer
        is used; otherwise ``x`` is interpolated directly to ``skip``'s size.
        """
        target_h, target_w = skip.shape[-2:]
        if x.shape[-2] * 2 == target_h and x.shape[-1] * 2 == target_w:
            return self.up(x)
        return nn.functional.interpolate(
            x, size=(target_h, target_w), mode="bilinear", align_corners=True
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: Batch of 3-channel images, shape ``(N, 3, H, W)``. ``H`` and
                ``W`` must be at least 8 so that three poolings leave at
                least one pixel; they need not be multiples of 8.

        Returns:
            Tensor of shape ``(N, 1, H, W)`` with sigmoid outputs in [0, 1].
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        d3 = self.dec3(torch.cat([self._upsample_to(e4, e3), e3], dim=1))
        d2 = self.dec2(torch.cat([self._upsample_to(d3, e2), e2], dim=1))
        d1 = self.dec1(torch.cat([self._upsample_to(d2, e1), e1], dim=1))

        out = torch.sigmoid(self.final(d1))
        return out
# Network, loss and optimiser. UNet ends in a sigmoid, hence plain BCELoss.
model = UNet().to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
for epoch in range(num_epochs):
    epoch_label = f"{epoch + 1}/{num_epochs}"
    model.train()
    batch_losses = []

    loop = tqdm(train_loader, desc=f"Epoch [{epoch_label}]")
    for imgs, masks in loop:
        imgs = imgs.to(device)
        masks = masks.to(device)

        # Forward pass and loss.
        loss = criterion(model(imgs), masks)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; reuse it for the running total and the bar.
        batch_loss = loss.item()
        batch_losses.append(batch_loss)
        loop.set_postfix(loss=batch_loss)

    # Average over the number of batches in the loader.
    mean_loss = sum(batch_losses) / len(train_loader)
    print(f"Epoch {epoch_label}, Loss: {mean_loss:.4f}")

# ---------------------------------------------------------------------------
# Persist the trained weights (state_dict only)
# ---------------------------------------------------------------------------
os.makedirs("model", exist_ok=True)
torch.save(model.state_dict(), "model/nail_segmentation_unet.pt")
print("✅ Model saved as model/nail_segmentation_unet.pt")
|