Spaces:
Sleeping
Sleeping
Commit ·
d555eda
1
Parent(s): 301b9b6
Optimize notebook for CPU training: MobileNetV2 encoder, 15 epochs, batch 4
Browse files
train_change_detection_model.ipynb
CHANGED
|
@@ -6,13 +6,12 @@
|
|
| 6 |
"source": [
|
| 7 |
"# Satellite Change Detection — Siamese U-Net Training\n",
|
| 8 |
"\n",
|
| 9 |
-
"This notebook trains a **Siamese U-Net** on the **LEVIR-CD** dataset for pixel-level\n",
|
| 10 |
"satellite image change detection. The exported model plugs directly into the\n",
|
| 11 |
"AI Change Detection web app.\n",
|
| 12 |
"\n",
|
| 13 |
-
"**
|
| 14 |
-
"
|
| 15 |
-
"Training takes ~2-3 hours on a free T4."
|
| 16 |
]
|
| 17 |
},
|
| 18 |
{
|
|
@@ -197,10 +196,10 @@
|
|
| 197 |
"val_ds = LEVIRCDDataset(DATA_ROOT, \"val\", val_transform)\n",
|
| 198 |
"test_ds = LEVIRCDDataset(DATA_ROOT, \"test\", val_transform)\n",
|
| 199 |
"\n",
|
| 200 |
-
"BATCH =
|
| 201 |
-
"train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=
|
| 202 |
-
"val_dl = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=
|
| 203 |
-
"test_dl = DataLoader(test_ds, batch_size=BATCH, shuffle=False, num_workers=
|
| 204 |
"\n",
|
| 205 |
"print(f\"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}\")"
|
| 206 |
],
|
|
@@ -214,7 +213,7 @@
|
|
| 214 |
"## 4. Siamese U-Net Model\n",
|
| 215 |
"\n",
|
| 216 |
"Architecture:\n",
|
| 217 |
-
"- **Shared encoder** (
|
| 218 |
"- Feature maps from both branches are **concatenated** at each decoder level\n",
|
| 219 |
"- Standard U-Net decoder produces a binary change mask"
|
| 220 |
]
|
|
@@ -228,6 +227,9 @@
|
|
| 228 |
"import segmentation_models_pytorch as smp\n",
|
| 229 |
"\n",
|
| 230 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 231 |
"class SiameseUNet(nn.Module):\n",
|
| 232 |
" \"\"\"\n",
|
| 233 |
" Siamese U-Net for change detection.\n",
|
|
@@ -235,9 +237,8 @@
|
|
| 235 |
" concatenated features are decoded into a binary change mask.\n",
|
| 236 |
" \"\"\"\n",
|
| 237 |
"\n",
|
| 238 |
-
" def __init__(self, encoder_name=
|
| 239 |
" super().__init__()\n",
|
| 240 |
-
" # Build a standard U-Net to reuse its encoder and decoder pieces\n",
|
| 241 |
" aux = smp.Unet(\n",
|
| 242 |
" encoder_name=encoder_name,\n",
|
| 243 |
" encoder_weights=\"imagenet\" if pretrained else None,\n",
|
|
@@ -246,8 +247,7 @@
|
|
| 246 |
" )\n",
|
| 247 |
" self.encoder = aux.encoder\n",
|
| 248 |
"\n",
|
| 249 |
-
"
|
| 250 |
-
" encoder_channels = self.encoder.out_channels # e.g. (3,64,64,128,256,512)\n",
|
| 251 |
" doubled = tuple(c * 2 for c in encoder_channels)\n",
|
| 252 |
"\n",
|
| 253 |
" self.decoder = smp.decoders.unet.decoder.UnetDecoder(\n",
|
|
@@ -261,23 +261,21 @@
|
|
| 261 |
" self.head = nn.Conv2d(16, 1, kernel_size=1)\n",
|
| 262 |
"\n",
|
| 263 |
" def forward(self, img_a, img_b):\n",
|
| 264 |
-
" # Shared encoder for both temporal images\n",
|
| 265 |
" feats_a = self.encoder(img_a)\n",
|
| 266 |
" feats_b = self.encoder(img_b)\n",
|
| 267 |
-
"\n",
|
| 268 |
-
" # Concatenate features at every level\n",
|
| 269 |
" feats_cat = [torch.cat([fa, fb], dim=1) for fa, fb in zip(feats_a, feats_b)]\n",
|
| 270 |
-
"\n",
|
| 271 |
" decoded = self.decoder(*feats_cat)\n",
|
| 272 |
" logits = self.head(decoded)\n",
|
| 273 |
" return logits\n",
|
| 274 |
"\n",
|
| 275 |
"\n",
|
| 276 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 277 |
-
"model = SiameseUNet(encoder_name=
|
| 278 |
"\n",
|
| 279 |
"total_params = sum(p.numel() for p in model.parameters()) / 1e6\n",
|
| 280 |
-
"print(f\"Model on {device}, {total_params:.1f}M parameters\")"
|
|
|
|
|
|
|
| 281 |
],
|
| 282 |
"execution_count": null,
|
| 283 |
"outputs": []
|
|
@@ -337,10 +335,11 @@
|
|
| 337 |
"cell_type": "code",
|
| 338 |
"metadata": {},
|
| 339 |
"source": [
|
|
|
|
| 340 |
"from tqdm.auto import tqdm\n",
|
| 341 |
"\n",
|
| 342 |
-
"NUM_EPOCHS =
|
| 343 |
-
"LR =
|
| 344 |
"\n",
|
| 345 |
"criterion = BCEDiceLoss(bce_weight=0.5)\n",
|
| 346 |
"optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)\n",
|
|
@@ -348,8 +347,11 @@
|
|
| 348 |
"\n",
|
| 349 |
"best_f1 = 0.0\n",
|
| 350 |
"history = {\"train_loss\": [], \"val_loss\": [], \"val_f1\": [], \"val_iou\": []}\n",
|
|
|
|
| 351 |
"\n",
|
| 352 |
"for epoch in range(1, NUM_EPOCHS + 1):\n",
|
|
|
|
|
|
|
| 353 |
" # --- Train ---\n",
|
| 354 |
" model.train()\n",
|
| 355 |
" running_loss = 0.0\n",
|
|
@@ -395,13 +397,18 @@
|
|
| 395 |
" history[\"val_f1\"].append(metrics[\"f1\"])\n",
|
| 396 |
" history[\"val_iou\"].append(metrics[\"iou\"])\n",
|
| 397 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
" print(\n",
|
| 399 |
" f\"Epoch {epoch:02d} | \"\n",
|
| 400 |
" f\"train_loss={train_loss:.4f} | \"\n",
|
| 401 |
" f\"val_loss={val_loss:.4f} | \"\n",
|
| 402 |
" f\"F1={metrics['f1']:.4f} | \"\n",
|
| 403 |
" f\"IoU={metrics['iou']:.4f} | \"\n",
|
| 404 |
-
" f\"P={metrics['precision']:.4f} R={metrics['recall']:.4f}\"\n",
|
|
|
|
| 405 |
" )\n",
|
| 406 |
"\n",
|
| 407 |
" if metrics[\"f1\"] > best_f1:\n",
|
|
@@ -409,7 +416,8 @@
|
|
| 409 |
" torch.save(model.state_dict(), \"best_siamese_unet.pth\")\n",
|
| 410 |
" print(f\" >> Saved best model (F1={best_f1:.4f})\")\n",
|
| 411 |
"\n",
|
| 412 |
-
"
|
|
|
|
| 413 |
],
|
| 414 |
"execution_count": null,
|
| 415 |
"outputs": []
|
|
|
|
| 6 |
"source": [
|
| 7 |
"# Satellite Change Detection — Siamese U-Net Training\n",
|
| 8 |
"\n",
|
| 9 |
+
"This notebook trains a **Siamese U-Net** on the **LEVIR-CD+** dataset for pixel-level\n",
|
| 10 |
"satellite image change detection. The exported model plugs directly into the\n",
|
| 11 |
"AI Change Detection web app.\n",
|
| 12 |
"\n",
|
| 13 |
+
"**Optimized for CPU** — uses a lightweight MobileNetV2 encoder, a batch size of 4, and 15 training epochs.\n",
|
| 14 |
+
"Training takes ~3-4 hours on a Colab CPU runtime."
|
|
|
|
| 15 |
]
|
| 16 |
},
|
| 17 |
{
|
|
|
|
| 196 |
"val_ds = LEVIRCDDataset(DATA_ROOT, \"val\", val_transform)\n",
|
| 197 |
"test_ds = LEVIRCDDataset(DATA_ROOT, \"test\", val_transform)\n",
|
| 198 |
"\n",
|
| 199 |
+
"BATCH = 4 # smaller batch for CPU\n",
|
| 200 |
+
"train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=0, pin_memory=False)\n",
|
| 201 |
+
"val_dl = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=False)\n",
|
| 202 |
+
"test_dl = DataLoader(test_ds, batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=False)\n",
|
| 203 |
"\n",
|
| 204 |
"print(f\"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}\")"
|
| 205 |
],
|
|
|
|
| 213 |
"## 4. Siamese U-Net Model\n",
|
| 214 |
"\n",
|
| 215 |
"Architecture:\n",
|
| 216 |
+
"- **Shared encoder** (MobileNetV2, ImageNet-pretrained) applied with the same weights to both images — lightweight and fast on CPU\n",
|
| 217 |
"- Feature maps from both branches are **concatenated** at each decoder level\n",
|
| 218 |
"- Standard U-Net decoder produces a binary change mask"
|
| 219 |
]
|
|
|
|
| 227 |
"import segmentation_models_pytorch as smp\n",
|
| 228 |
"\n",
|
| 229 |
"\n",
|
| 230 |
+
"ENCODER_NAME = \"mobilenet_v2\" # lightweight encoder for CPU training\n",
|
| 231 |
+
"\n",
|
| 232 |
+
"\n",
|
| 233 |
"class SiameseUNet(nn.Module):\n",
|
| 234 |
" \"\"\"\n",
|
| 235 |
" Siamese U-Net for change detection.\n",
|
|
|
|
| 237 |
" concatenated features are decoded into a binary change mask.\n",
|
| 238 |
" \"\"\"\n",
|
| 239 |
"\n",
|
| 240 |
+
" def __init__(self, encoder_name=ENCODER_NAME, pretrained=True):\n",
|
| 241 |
" super().__init__()\n",
|
|
|
|
| 242 |
" aux = smp.Unet(\n",
|
| 243 |
" encoder_name=encoder_name,\n",
|
| 244 |
" encoder_weights=\"imagenet\" if pretrained else None,\n",
|
|
|
|
| 247 |
" )\n",
|
| 248 |
" self.encoder = aux.encoder\n",
|
| 249 |
"\n",
|
| 250 |
+
" encoder_channels = self.encoder.out_channels\n",
|
|
|
|
| 251 |
" doubled = tuple(c * 2 for c in encoder_channels)\n",
|
| 252 |
"\n",
|
| 253 |
" self.decoder = smp.decoders.unet.decoder.UnetDecoder(\n",
|
|
|
|
| 261 |
" self.head = nn.Conv2d(16, 1, kernel_size=1)\n",
|
| 262 |
"\n",
|
| 263 |
" def forward(self, img_a, img_b):\n",
|
|
|
|
| 264 |
" feats_a = self.encoder(img_a)\n",
|
| 265 |
" feats_b = self.encoder(img_b)\n",
|
|
|
|
|
|
|
| 266 |
" feats_cat = [torch.cat([fa, fb], dim=1) for fa, fb in zip(feats_a, feats_b)]\n",
|
|
|
|
| 267 |
" decoded = self.decoder(*feats_cat)\n",
|
| 268 |
" logits = self.head(decoded)\n",
|
| 269 |
" return logits\n",
|
| 270 |
"\n",
|
| 271 |
"\n",
|
| 272 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 273 |
+
"model = SiameseUNet(encoder_name=ENCODER_NAME, pretrained=True).to(device)\n",
|
| 274 |
"\n",
|
| 275 |
"total_params = sum(p.numel() for p in model.parameters()) / 1e6\n",
|
| 276 |
+
"print(f\"Model on {device}, {total_params:.1f}M parameters\")\n",
|
| 277 |
+
"if device.type == \"cpu\":\n",
|
| 278 |
+
" print(\"Running on CPU — training will take ~3-4 hours for 15 epochs\")"
|
| 279 |
],
|
| 280 |
"execution_count": null,
|
| 281 |
"outputs": []
|
|
|
|
| 335 |
"cell_type": "code",
|
| 336 |
"metadata": {},
|
| 337 |
"source": [
|
| 338 |
+
"import time\n",
|
| 339 |
"from tqdm.auto import tqdm\n",
|
| 340 |
"\n",
|
| 341 |
+
"NUM_EPOCHS = 15 # fewer epochs for CPU training\n",
|
| 342 |
+
"LR = 3e-4 # slightly higher LR to converge faster\n",
|
| 343 |
"\n",
|
| 344 |
"criterion = BCEDiceLoss(bce_weight=0.5)\n",
|
| 345 |
"optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)\n",
|
|
|
|
| 347 |
"\n",
|
| 348 |
"best_f1 = 0.0\n",
|
| 349 |
"history = {\"train_loss\": [], \"val_loss\": [], \"val_f1\": [], \"val_iou\": []}\n",
|
| 350 |
+
"train_start = time.time()\n",
|
| 351 |
"\n",
|
| 352 |
"for epoch in range(1, NUM_EPOCHS + 1):\n",
|
| 353 |
+
" epoch_start = time.time()\n",
|
| 354 |
+
"\n",
|
| 355 |
" # --- Train ---\n",
|
| 356 |
" model.train()\n",
|
| 357 |
" running_loss = 0.0\n",
|
|
|
|
| 397 |
" history[\"val_f1\"].append(metrics[\"f1\"])\n",
|
| 398 |
" history[\"val_iou\"].append(metrics[\"iou\"])\n",
|
| 399 |
"\n",
|
| 400 |
+
" elapsed_min = (time.time() - epoch_start) / 60\n",
|
| 401 |
+
" total_min = (time.time() - train_start) / 60\n",
|
| 402 |
+
" eta_min = elapsed_min * (NUM_EPOCHS - epoch)\n",
|
| 403 |
+
"\n",
|
| 404 |
" print(\n",
|
| 405 |
" f\"Epoch {epoch:02d} | \"\n",
|
| 406 |
" f\"train_loss={train_loss:.4f} | \"\n",
|
| 407 |
" f\"val_loss={val_loss:.4f} | \"\n",
|
| 408 |
" f\"F1={metrics['f1']:.4f} | \"\n",
|
| 409 |
" f\"IoU={metrics['iou']:.4f} | \"\n",
|
| 410 |
+
" f\"P={metrics['precision']:.4f} R={metrics['recall']:.4f} | \"\n",
|
| 411 |
+
" f\"{elapsed_min:.1f}min (ETA: {eta_min:.0f}min)\"\n",
|
| 412 |
" )\n",
|
| 413 |
"\n",
|
| 414 |
" if metrics[\"f1\"] > best_f1:\n",
|
|
|
|
| 416 |
" torch.save(model.state_dict(), \"best_siamese_unet.pth\")\n",
|
| 417 |
" print(f\" >> Saved best model (F1={best_f1:.4f})\")\n",
|
| 418 |
"\n",
|
| 419 |
+
"total_time = (time.time() - train_start) / 60\n",
|
| 420 |
+
"print(f\"\\nTraining complete in {total_time:.1f} minutes. Best val F1: {best_f1:.4f}\")"
|
| 421 |
],
|
| 422 |
"execution_count": null,
|
| 423 |
"outputs": []
|