| |
| """Generate UNet notebook.""" |
|
|
| import nbformat as nbf |
|
|
# Empty v4 notebook plus the kernel / language metadata stored in the
# generated .ipynb file.
nb = nbf.v4.new_notebook()
nb.metadata = dict(
    kernelspec=dict(display_name="Python 3", language="python", name="python3"),
    language_info=dict(name="python", version="3.12.0"),
)
|
|
# Cells are collected here in document order and attached to the notebook
# at the end of the script.
cells = []


def md(s):
    """Append a markdown cell whose source is *s* to ``cells``."""
    markdown_cell = nbf.v4.new_markdown_cell(s)
    cells.append(markdown_cell)


def code(s):
    """Append a code cell whose source is *s* to ``cells``."""
    code_cell = nbf.v4.new_code_cell(s)
    cells.append(code_cell)
|
|
# Title cell: notebook heading and a one-line summary of the architecture.
md("""\
# UNet: Semantic Segmentation

Encoder-decoder with skip connections for pixel-wise classification.
""")


# Background cell (Chinese): explains the skip-connection idea and names the
# dataset (Oxford-IIIT Pet, per-pixel labels with 3 classes).
md("""\
## 背景

UNet 最初为医学图像分割设计,核心创新是 **跳跃连接(skip connections)**:
编码器逐步下采样提取高层语义,解码器逐步上采样恢复空间分辨率,
跳跃连接将编码器每层的特征直接拼接到对应解码器层,保留细节信息。

数据集:Oxford-IIIT Pet — 图像 + 逐像素标注(3 类:前景、背景、轮廓)。
""")


# Math cell: per-pixel argmax prediction, the cross-entropy loss, and a text
# sketch of the encoder/decoder layout.
# Backslashes are doubled because this is a normal (non-raw) string literal;
# the notebook receives single-backslash LaTeX commands.
# NOTE(review): the column alignment of the architecture sketch may have been
# lost in this copy of the file — check how it renders in the notebook.
md("""\
## 数学原理

### 逐像素分类

每个像素被分类为 $C$ 类之一:

$$\\hat{y}_{i,j} = \\arg\\max_c \\, \\text{logits}_{i,j,c}$$

### 损失函数

$$\\mathcal{L} = -\\frac{1}{N} \\sum_{i,j} \\sum_c y_{i,j,c} \\log(\\hat{y}_{i,j,c})$$

忽略未标注像素(`ignore_index=0`)。

### 架构

```
Input → Conv+ReLU → ← skip ─── UpConv → Conv+ReLU → Conv1×1 → C classes
↓ MaxPool ↑
Conv+ReLU → ← skip ─── UpConv → Conv+ReLU
↓ MaxPool ↑
Conv+ReLU → ← skip ─── UpConv → Conv+ReLU
↓ MaxPool ↑
Conv+ReLU → ← skip ─── UpConv → Conv+ReLU
↓ MaxPool ↑
Conv+ReLU × 2 (bottleneck)
```
""")
|
|
# Imports cell: pulls in torch / torchvision / datasets, the project-local
# UNet model and get_device() helper, then selects and prints the device.
# NOTE(review): `transforms` is imported here but no cell generated by this
# script references it (only the functional API `TF` is used).
code("""\
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import functional as TF
from datasets import load_dataset
import random

from cv.unet.model import UNet
from utils.device import get_device

device = get_device()
print(f"Device: {device}")
""")
|
|
# Dataset cell: defines PetDataset (images and masks held in memory, resized
# to a square, optionally flipped/rotated for the train split) and builds the
# train/test DataLoaders.
#
# The resize calls pass an explicit [H, W]. With a bare int, torchvision only
# matches the shorter edge and keeps the aspect ratio, so non-square inputs
# would break the later .view(image_size, image_size) and batch collation.
code("""\
class PetDataset(torch.utils.data.Dataset):
    def __init__(self, split="train", image_size=128, augment=False):
        self.image_size = image_size
        self.augment = augment and split == "train"
        ds = load_dataset("tchevrou/oxford-iiit-pet", split=split)
        self.images = [item["image"].convert("RGB") for item in ds]
        self.masks = [item["label"] for item in ds]
        del ds

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]

        # Explicit (H, W): an int size would keep the aspect ratio and only
        # fix the shorter edge, giving non-square outputs.
        size = [self.image_size, self.image_size]
        image = TF.resize(image, size, TF.InterpolationMode.BILINEAR)
        mask = TF.resize(mask, size, TF.InterpolationMode.NEAREST)

        if self.augment:
            if random.random() > 0.5:
                image = TF.hflip(image)
                mask = TF.hflip(mask)
            angle = random.uniform(-10, 10)
            image = TF.rotate(image, angle, TF.InterpolationMode.BILINEAR)
            mask = TF.rotate(mask, angle, TF.InterpolationMode.NEAREST)

        image = TF.to_tensor(image)
        image = TF.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        mask = torch.tensor(list(mask.getdata()), dtype=torch.long).view(self.image_size, self.image_size)
        return image, mask

train_dataset = PetDataset(split="train", image_size=128, augment=True)
test_dataset = PetDataset(split="test", image_size=128, augment=False)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)
print(f"Train: {len(train_dataset):,} Test: {len(test_dataset):,}")
""")
|
|
# Model cell: builds the project UNet with 3 input channels and 3 output
# classes, moves it to `device`, and prints model.num_params().
# NOTE(review): the training cell uses ignore_index=0 and the visualisation
# cell's colour table has four entries (ids 0-3). CrossEntropyLoss targets
# must lie in [0, C) apart from the ignored id, so confirm the mask label ids
# really fit in num_classes=3 — TODO confirm against the dataset.
code("""\
model = UNet(in_channels=3, num_classes=3).to(device)
print(f"Parameters: {model.num_params():,}")
""")
|
|
# Training heading (Chinese): gives a rough runtime estimate and suggests
# lowering NUM_EPOCHS to 10 for a quick look at the trend.
md("""\
## 训练

> ⏱ 预估耗时:**30 epoch × ~40s/epoch ≈ 20 分钟**(M4 Max, batch_size=16)
> 如果太久,把下面 `NUM_EPOCHS` 改小到 10 先看趋势。
""")


# Training cell: Adam + CrossEntropyLoss(ignore_index=0). Each epoch runs one
# pass over train_loader, then evaluates on test_loader under no_grad.
# Losses are averaged per batch; pixel accuracy is computed only over pixels
# whose mask value is non-zero, matching the ignored label id.
# The three *_hist lists are filled here and plotted by the next code cell.
code("""\
NUM_EPOCHS = 30
LR = 1e-3

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=LR)

train_loss_hist, val_loss_hist, pixel_acc_hist = [], [], []

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    train_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    model.eval()
    val_loss = 0.0
    correct = total = 0
    with torch.no_grad():
        for images, masks in test_loader:
            images, masks = images.to(device), masks.to(device)
            logits = model(images)
            loss = criterion(logits, masks)
            val_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            valid = masks != 0
            correct += (preds[valid] == masks[valid]).sum().item()
            total += valid.sum().item()

    avg_train = train_loss / len(train_loader)
    avg_val = val_loss / len(test_loader)
    acc = correct / total * 100
    train_loss_hist.append(avg_train)
    val_loss_hist.append(avg_val)
    pixel_acc_hist.append(acc)
    print(f"Epoch [{epoch:2d}/{NUM_EPOCHS}] Train: {avg_train:.4f} Val: {avg_val:.4f} Pixel Acc: {acc:.2f}%")
""")
|
|
# Section heading for the loss / pixel-accuracy plots.
md("""## Loss 曲线 & 像素准确率""")


# Plot cell: left axis shows train and val loss per epoch, right axis shows
# pixel accuracy. It reads the history lists filled by the training cell.
code("""\
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(train_loss_hist, label='train', marker='o')
ax1.plot(val_loss_hist, label='val', marker='o')
ax1.set_xlabel("Epoch"); ax1.set_ylabel("Loss"); ax1.legend(); ax1.grid(True)

ax2.plot(pixel_acc_hist, marker='o', color='green')
ax2.set_xlabel("Epoch"); ax2.set_ylabel("Pixel Acc (%)"); ax2.grid(True)

plt.tight_layout(); plt.show()
""")
|
|
# Section heading for the qualitative results.
md("""## 分割效果可視化""".replace("視", "视"))


# Visualisation cell: takes the first test batch, predicts masks, and shows
# input / ground truth / prediction side by side for the first NUM_SHOW
# samples. Changes from the earlier version of this cell:
#   * the unused `get_device` re-import is gone (`device` already exists from
#     the imports cell);
#   * the row count lives in one constant instead of being repeated in
#     plt.subplots() and range();
#   * the ImageNet mean/std used to undo TF.normalize are named arrays.
code("""\
import matplotlib.pyplot as plt
import numpy as np

NUM_SHOW = 4  # rows in the figure; must not exceed the test batch size
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

model.eval()
images, masks = next(iter(test_loader))
with torch.no_grad():
    preds = torch.argmax(model(images.to(device)), dim=1).cpu()

CLASS_CMAP = np.array([[0,0,0], [0,180,60], [180,60,0], [60,0,180]], dtype=np.uint8)

fig, axes = plt.subplots(NUM_SHOW, 3, figsize=(12, 4 * NUM_SHOW))
for i in range(NUM_SHOW):
    # Undo the normalisation applied in the dataset so colours look natural.
    img = images[i].permute(1,2,0).numpy()
    img = np.clip(img * STD + MEAN, 0, 1)
    axes[i,0].imshow(img); axes[i,0].set_title("Input"); axes[i,0].axis("off")

    mask_true = CLASS_CMAP[masks[i].numpy()]
    axes[i,1].imshow(mask_true); axes[i,1].set_title("Ground Truth"); axes[i,1].axis("off")

    mask_pred = CLASS_CMAP[preds[i].numpy()]
    axes[i,2].imshow(mask_pred); axes[i,2].set_title("Prediction"); axes[i,2].axis("off")

plt.tight_layout(); plt.show()
""")
|
|
# Closing cell (Chinese): follow-up questions for the reader.
# Question 4 used to ask the reader to observe IoU, but no cell generated by
# this script computes IoU. It now points at the metrics that are tracked
# (val loss, pixel accuracy) and keeps IoU as an explicit exercise.
md("""\
## 思考题

1. 跳跃连接(skip connection)在 UNet 中起到什么作用?没有它会怎样?
2. 上采样为什么用转置卷积而不用双线性插值?各自的优缺点是什么?
3. 如果使用 Dice Loss 代替 CrossEntropy,分割效果会有什么变化?
4. 试试把 `NUM_EPOCHS` 加到 100,观察 val loss 和 pixel accuracy 的变化;再自己实现 mean IoU 指标对比一下(本 notebook 只统计了 pixel accuracy)。
""")
|
|
# Attach the collected cells and serialise the notebook to disk.
nb.cells = cells
out = "cv/unet/unet.ipynb"
# The notebook contains Chinese markdown and nbformat writes non-ASCII text
# verbatim, so pin UTF-8 rather than relying on the platform default encoding
# (which is not UTF-8 on every system and would raise UnicodeEncodeError).
with open(out, "w", encoding="utf-8") as f:
    nbf.write(nb, f)
print(f"Generated {out}")
|
|