dl-from-scratch / scripts /gen_unet_notebook.py
yusiwen's picture
refactor: restructure into domain-grounded directories (ml/ cv/ gen/ graph/ rl/ nlp/)
73526b1 unverified
Raw
History Blame Contribute Delete
7.54 kB
#!/usr/bin/env python3
"""Generate UNet notebook."""
import nbformat as nbf
# Start from an empty v4 notebook and declare the kernel / language the
# generated file targets.
nb = nbf.v4.new_notebook()
nb.metadata = dict(
    kernelspec=dict(display_name="Python 3", language="python", name="python3"),
    language_info=dict(name="python", version="3.12.0"),
)

# Cells are accumulated here in document order by md() / code().
cells = []
def md(s):
    """Append a Markdown cell whose source is *s* to the module-level ``cells`` list."""
    markdown_cell = nbf.v4.new_markdown_cell(s)
    cells.append(markdown_cell)
def code(s):
    """Append a code cell whose source is *s* to the module-level ``cells`` list."""
    code_cell = nbf.v4.new_code_cell(s)
    cells.append(code_cell)
# Title cell: notebook heading plus a one-line summary of the model.
md("""\
# UNet: Semantic Segmentation
Encoder-decoder with skip connections for pixel-wise classification.
""")
# Background cell (Chinese prose): introduces UNet's skip connections and
# names the Oxford-IIIT Pet dataset with its 3-class per-pixel labels.
md("""\
## 背景
UNet 最初为医学图像分割设计,核心创新是 **跳跃连接(skip connections)**:
编码器逐步下采样提取高层语义,解码器逐步上采样恢复空间分辨率,
跳跃连接将编码器每层的特征直接拼接到对应解码器层,保留细节信息。
数据集:Oxford-IIIT Pet — 图像 + 逐像素标注(3 类:前景、背景、轮廓)。
""")
# Math cell: per-pixel argmax prediction, the cross-entropy loss (with the
# `ignore_index=0` remark), and an ASCII sketch of the encoder/decoder layout.
# Backslashes are doubled because the literal is not a raw string; the
# notebook receives single backslashes for the LaTeX commands.
md("""\
## 数学原理
### 逐像素分类
每个像素被分类为 $C$ 类之一:
$$\\hat{y}_{i,j} = \\arg\\max_c \\, \\text{logits}_{i,j,c}$$
### 损失函数
$$\\mathcal{L} = -\\frac{1}{N} \\sum_{i,j} \\sum_c y_{i,j,c} \\log(\\hat{y}_{i,j,c})$$
忽略未标注像素(`ignore_index=0`)。
### 架构
```
Input → Conv+ReLU → ← skip ─── UpConv → Conv+ReLU → Conv1×1 → C classes
↓ MaxPool ↑
Conv+ReLU → ← skip ─── UpConv → Conv+ReLU
↓ MaxPool ↑
Conv+ReLU → ← skip ─── UpConv → Conv+ReLU
↓ MaxPool ↑
Conv+ReLU → ← skip ─── UpConv → Conv+ReLU
↓ MaxPool ↑
Conv+ReLU × 2 (bottleneck)
```
""")
# Setup cell: imports PyTorch, torchvision transforms, Hugging Face
# `datasets`, and the project-local UNet / get_device helpers, then selects
# the device via get_device() and prints it.
code("""\
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import functional as TF
from datasets import load_dataset
import random
from cv.unet.model import UNet
from utils.device import get_device
device = get_device()
print(f"Device: {device}")
""")
# Data cell: defines PetDataset (images and per-pixel masks loaded through
# Hugging Face `datasets`, held in memory as PIL images) and builds the
# train / test DataLoaders.
#
# Image and mask are resized to an explicit [H, W] pair. Passing a bare int
# to torchvision's resize only matches the *shorter* edge and keeps the
# aspect ratio, so non-square photos would stay non-square and both the final
# `.view(image_size, image_size)` and batch collation would fail.
#
# Masks use NEAREST interpolation everywhere so that resizing / rotating never
# invents label values that are not in the original mask.
#
# NOTE(review): `ignore_index=0` in the training cell together with rotate()'s
# default zero fill suggests mask value 0 means "unlabelled" - confirm this
# against the dataset's actual label encoding.
code("""\
class PetDataset(torch.utils.data.Dataset):
    def __init__(self, split="train", image_size=128, augment=False):
        self.image_size = image_size
        self.augment = augment and split == "train"
        ds = load_dataset("tchevrou/oxford-iiit-pet", split=split)
        self.images = [item["image"].convert("RGB") for item in ds]
        self.masks = [item["label"] for item in ds]
        del ds

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]
        size = [self.image_size, self.image_size]
        image = TF.resize(image, size, TF.InterpolationMode.BILINEAR)
        mask = TF.resize(mask, size, TF.InterpolationMode.NEAREST)
        if self.augment:
            if random.random() > 0.5:
                image = TF.hflip(image); mask = TF.hflip(mask)
            angle = random.uniform(-10, 10)
            image = TF.rotate(image, angle, TF.InterpolationMode.BILINEAR)
            mask = TF.rotate(mask, angle, TF.InterpolationMode.NEAREST)
        image = TF.to_tensor(image)
        image = TF.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        mask = torch.tensor(list(mask.getdata()), dtype=torch.long).view(self.image_size, self.image_size)
        return image, mask

train_dataset = PetDataset(split="train", image_size=128, augment=True)
test_dataset = PetDataset(split="test", image_size=128, augment=False)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)
print(f"Train: {len(train_dataset):,} Test: {len(test_dataset):,}")
""")
# Model cell: instantiates the project's UNet for 3-channel input and 3 output
# classes, moves it to `device`, and prints its parameter count.
code("""\
model = UNet(in_channels=3, num_classes=3).to(device)
print(f"Parameters: {model.num_params():,}")
""")
# Training header cell (Chinese prose): gives a rough runtime estimate and
# suggests lowering NUM_EPOCHS to 10 for a quicker first look.
md("""\
## 训练
> ⏱ 预估耗时:**30 epoch × ~40s/epoch ≈ 20 分钟**(M4 Max, batch_size=16)
> 如果太久,把下面 `NUM_EPOCHS` 改小到 10 先看趋势。
""")
# Training cell: Adam + cross-entropy for NUM_EPOCHS epochs. After every
# epoch the test split is evaluated without gradients, and the mean train
# loss, mean validation loss and pixel accuracy are appended to the *_hist
# lists that the plotting cell reads.
#
# Pixels whose target is 0 are excluded from both the loss (ignore_index=0)
# and the accuracy (the `valid` mask), so the two metrics cover the same set.
# Accuracy falls back to 0.0 when an epoch sees no labelled pixel at all,
# instead of raising ZeroDivisionError.
#
# NOTE(review): the model is built with num_classes=3, so targets must stay in
# 0..2; the visualization cell's 4-entry CLASS_CMAP hints that label 3 may
# occur - confirm the mask value range against the dataset.
code("""\
NUM_EPOCHS = 30
LR = 1e-3
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=LR)
train_loss_hist, val_loss_hist, pixel_acc_hist = [], [], []
for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    train_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    model.eval()
    val_loss = 0.0
    correct = total = 0
    with torch.no_grad():
        for images, masks in test_loader:
            images, masks = images.to(device), masks.to(device)
            logits = model(images)
            loss = criterion(logits, masks)
            val_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            valid = masks != 0
            correct += (preds[valid] == masks[valid]).sum().item()
            total += valid.sum().item()
    avg_train = train_loss / len(train_loader)
    avg_val = val_loss / len(test_loader)
    acc = correct / total * 100 if total else 0.0
    train_loss_hist.append(avg_train)
    val_loss_hist.append(avg_val)
    pixel_acc_hist.append(acc)
    print(f"Epoch [{epoch:2d}/{NUM_EPOCHS}] Train: {avg_train:.4f} Val: {avg_val:.4f} Pixel Acc: {acc:.2f}%")
""")
# Section heading followed by the plotting cell: left axis shows train / val
# loss per epoch, right axis shows pixel accuracy per epoch, both read from
# the history lists filled by the training cell.
md("""## Loss 曲线 & 像素准确率""")
code("""\
import matplotlib.pyplot as plt
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(train_loss_hist, label='train', marker='o')
ax1.plot(val_loss_hist, label='val', marker='o')
ax1.set_xlabel("Epoch"); ax1.set_ylabel("Loss"); ax1.legend(); ax1.grid(True)
ax2.plot(pixel_acc_hist, marker='o', color='green')
ax2.set_xlabel("Epoch"); ax2.set_ylabel("Pixel Acc (%)"); ax2.grid(True)
plt.tight_layout(); plt.show()
""")
# Section heading followed by the visualization cell: takes the first test
# batch, predicts with the model in eval mode, and shows input / ground truth /
# prediction side by side for the first four samples.
#
# The input image is de-normalized with the same mean / std used by the
# dataset cell before display; masks are coloured by indexing CLASS_CMAP with
# the integer label map.
#
# The bodies of the `with` and `for` statements are indented so that the
# generated cell is valid Python.
md("""## 分割效果可视化""")
code("""\
import matplotlib.pyplot as plt
import numpy as np
from utils.device import get_device
model.eval()
images, masks = next(iter(test_loader))
with torch.no_grad():
    preds = torch.argmax(model(images.to(device)), dim=1).cpu()
CLASS_CMAP = np.array([[0,0,0], [0,180,60], [180,60,0], [60,0,180]], dtype=np.uint8)
fig, axes = plt.subplots(4, 3, figsize=(12, 16))
for i in range(4):
    img = images[i].permute(1,2,0).numpy()
    img = img * [0.229,0.224,0.225] + [0.485,0.456,0.406]
    img = np.clip(img, 0, 1)
    axes[i,0].imshow(img); axes[i,0].set_title("Input"); axes[i,0].axis("off")
    mask_true = CLASS_CMAP[masks[i].numpy()]
    axes[i,1].imshow(mask_true); axes[i,1].set_title("Ground Truth"); axes[i,1].axis("off")
    mask_pred = CLASS_CMAP[preds[i].numpy()]
    axes[i,2].imshow(mask_pred); axes[i,2].set_title("Prediction"); axes[i,2].axis("off")
plt.tight_layout(); plt.show()
""")
# Closing cell (Chinese prose): four follow-up questions for the reader.
# NOTE(review): question 4 asks the reader to watch IoU, but no cell generated
# by this script computes IoU - consider adding the metric or rewording.
md("""\
## 思考题
1. 跳跃连接(skip connection)在 UNet 中起到什么作用?没有它会怎样?
2. 上采样为什么用转置卷积而不用双线性插值?各自的优缺点是什么?
3. 如果使用 Dice Loss 代替 CrossEntropy,分割效果会有什么变化?
4. 试试把 `NUM_EPOCHS` 加到 100,观察 IoU 和 pixel accuracy 的变化。
""")
# Attach the collected cells and serialise the notebook to disk.
nb.cells = cells
out = "cv/unet/unet.ipynb"
# The cells contain Chinese text and an emoji. Open the file explicitly as
# UTF-8 so the write does not depend on the platform's locale encoding
# (which can raise UnicodeEncodeError or produce a non-UTF-8 .ipynb).
with open(out, "w", encoding="utf-8") as f:
    nbf.write(nb, f)
print(f"Generated {out}")