coderuday21 committed on
Commit
d555eda
·
1 Parent(s): 301b9b6

Optimize notebook for CPU training: MobileNetV2 encoder, 15 epochs, batch 4

Browse files
Files changed (1) hide show
  1. train_change_detection_model.ipynb +31 -23
train_change_detection_model.ipynb CHANGED
@@ -6,13 +6,12 @@
6
  "source": [
7
  "# Satellite Change Detection — Siamese U-Net Training\n",
8
  "\n",
9
- "This notebook trains a **Siamese U-Net** on the **LEVIR-CD** dataset for pixel-level\n",
10
  "satellite image change detection. The exported model plugs directly into the\n",
11
  "AI Change Detection web app.\n",
12
  "\n",
13
- "**Run in Google Colab** with a GPU runtime (Runtime → Change runtime type → T4 GPU).\n",
14
- "\n",
15
- "Training takes ~2-3 hours on a free T4."
16
  ]
17
  },
18
  {
@@ -197,10 +196,10 @@
197
  "val_ds = LEVIRCDDataset(DATA_ROOT, \"val\", val_transform)\n",
198
  "test_ds = LEVIRCDDataset(DATA_ROOT, \"test\", val_transform)\n",
199
  "\n",
200
- "BATCH = 8\n",
201
- "train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=2, pin_memory=True)\n",
202
- "val_dl = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)\n",
203
- "test_dl = DataLoader(test_ds, batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)\n",
204
  "\n",
205
  "print(f\"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}\")"
206
  ],
@@ -214,7 +213,7 @@
214
  "## 4. Siamese U-Net Model\n",
215
  "\n",
216
  "Architecture:\n",
217
- "- **Shared encoder** (ResNet34, ImageNet pretrained) processes both images\n",
218
  "- Feature maps from both branches are **concatenated** at each decoder level\n",
219
  "- Standard U-Net decoder produces a binary change mask"
220
  ]
@@ -228,6 +227,9 @@
228
  "import segmentation_models_pytorch as smp\n",
229
  "\n",
230
  "\n",
 
 
 
231
  "class SiameseUNet(nn.Module):\n",
232
  " \"\"\"\n",
233
  " Siamese U-Net for change detection.\n",
@@ -235,9 +237,8 @@
235
  " concatenated features are decoded into a binary change mask.\n",
236
  " \"\"\"\n",
237
  "\n",
238
- " def __init__(self, encoder_name=\"resnet34\", pretrained=True):\n",
239
  " super().__init__()\n",
240
- " # Build a standard U-Net to reuse its encoder and decoder pieces\n",
241
  " aux = smp.Unet(\n",
242
  " encoder_name=encoder_name,\n",
243
  " encoder_weights=\"imagenet\" if pretrained else None,\n",
@@ -246,8 +247,7 @@
246
  " )\n",
247
  " self.encoder = aux.encoder\n",
248
  "\n",
249
- " # The decoder expects concatenated features (2x channels at each level)\n",
250
- " encoder_channels = self.encoder.out_channels # e.g. (3,64,64,128,256,512)\n",
251
  " doubled = tuple(c * 2 for c in encoder_channels)\n",
252
  "\n",
253
  " self.decoder = smp.decoders.unet.decoder.UnetDecoder(\n",
@@ -261,23 +261,21 @@
261
  " self.head = nn.Conv2d(16, 1, kernel_size=1)\n",
262
  "\n",
263
  " def forward(self, img_a, img_b):\n",
264
- " # Shared encoder for both temporal images\n",
265
  " feats_a = self.encoder(img_a)\n",
266
  " feats_b = self.encoder(img_b)\n",
267
- "\n",
268
- " # Concatenate features at every level\n",
269
  " feats_cat = [torch.cat([fa, fb], dim=1) for fa, fb in zip(feats_a, feats_b)]\n",
270
- "\n",
271
  " decoded = self.decoder(*feats_cat)\n",
272
  " logits = self.head(decoded)\n",
273
  " return logits\n",
274
  "\n",
275
  "\n",
276
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
277
- "model = SiameseUNet(encoder_name=\"resnet34\", pretrained=True).to(device)\n",
278
  "\n",
279
  "total_params = sum(p.numel() for p in model.parameters()) / 1e6\n",
280
- "print(f\"Model on {device}, {total_params:.1f}M parameters\")"
 
 
281
  ],
282
  "execution_count": null,
283
  "outputs": []
@@ -337,10 +335,11 @@
337
  "cell_type": "code",
338
  "metadata": {},
339
  "source": [
 
340
  "from tqdm.auto import tqdm\n",
341
  "\n",
342
- "NUM_EPOCHS = 50\n",
343
- "LR = 1e-4\n",
344
  "\n",
345
  "criterion = BCEDiceLoss(bce_weight=0.5)\n",
346
  "optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)\n",
@@ -348,8 +347,11 @@
348
  "\n",
349
  "best_f1 = 0.0\n",
350
  "history = {\"train_loss\": [], \"val_loss\": [], \"val_f1\": [], \"val_iou\": []}\n",
 
351
  "\n",
352
  "for epoch in range(1, NUM_EPOCHS + 1):\n",
 
 
353
  " # --- Train ---\n",
354
  " model.train()\n",
355
  " running_loss = 0.0\n",
@@ -395,13 +397,18 @@
395
  " history[\"val_f1\"].append(metrics[\"f1\"])\n",
396
  " history[\"val_iou\"].append(metrics[\"iou\"])\n",
397
  "\n",
 
 
 
 
398
  " print(\n",
399
  " f\"Epoch {epoch:02d} | \"\n",
400
  " f\"train_loss={train_loss:.4f} | \"\n",
401
  " f\"val_loss={val_loss:.4f} | \"\n",
402
  " f\"F1={metrics['f1']:.4f} | \"\n",
403
  " f\"IoU={metrics['iou']:.4f} | \"\n",
404
- " f\"P={metrics['precision']:.4f} R={metrics['recall']:.4f}\"\n",
 
405
  " )\n",
406
  "\n",
407
  " if metrics[\"f1\"] > best_f1:\n",
@@ -409,7 +416,8 @@
409
  " torch.save(model.state_dict(), \"best_siamese_unet.pth\")\n",
410
  " print(f\" >> Saved best model (F1={best_f1:.4f})\")\n",
411
  "\n",
412
- "print(f\"\\nTraining complete. Best val F1: {best_f1:.4f}\")"
 
413
  ],
414
  "execution_count": null,
415
  "outputs": []
 
6
  "source": [
7
  "# Satellite Change Detection — Siamese U-Net Training\n",
8
  "\n",
9
+ "This notebook trains a **Siamese U-Net** on the **LEVIR-CD+** dataset for pixel-level\n",
10
  "satellite image change detection. The exported model plugs directly into the\n",
11
  "AI Change Detection web app.\n",
12
  "\n",
13
+ "**Optimized for CPU** — uses a lightweight MobileNetV2 encoder and 15 epochs.\n",
14
+ "Training takes ~3-4 hours on a Colab CPU runtime."
 
15
  ]
16
  },
17
  {
 
196
  "val_ds = LEVIRCDDataset(DATA_ROOT, \"val\", val_transform)\n",
197
  "test_ds = LEVIRCDDataset(DATA_ROOT, \"test\", val_transform)\n",
198
  "\n",
199
+ "BATCH = 4 # smaller batch for CPU\n",
200
+ "train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=0, pin_memory=False)\n",
201
+ "val_dl = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=False)\n",
202
+ "test_dl = DataLoader(test_ds, batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=False)\n",
203
  "\n",
204
  "print(f\"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}\")"
205
  ],
 
213
  "## 4. Siamese U-Net Model\n",
214
  "\n",
215
  "Architecture:\n",
216
+ "- **Shared encoder** (MobileNetV2, ImageNet pretrained) — lightweight and fast on CPU\n",
217
  "- Feature maps from both branches are **concatenated** at each decoder level\n",
218
  "- Standard U-Net decoder produces a binary change mask"
219
  ]
 
227
  "import segmentation_models_pytorch as smp\n",
228
  "\n",
229
  "\n",
230
+ "ENCODER_NAME = \"mobilenet_v2\" # lightweight encoder for CPU training\n",
231
+ "\n",
232
+ "\n",
233
  "class SiameseUNet(nn.Module):\n",
234
  " \"\"\"\n",
235
  " Siamese U-Net for change detection.\n",
 
237
  " concatenated features are decoded into a binary change mask.\n",
238
  " \"\"\"\n",
239
  "\n",
240
+ " def __init__(self, encoder_name=ENCODER_NAME, pretrained=True):\n",
241
  " super().__init__()\n",
 
242
  " aux = smp.Unet(\n",
243
  " encoder_name=encoder_name,\n",
244
  " encoder_weights=\"imagenet\" if pretrained else None,\n",
 
247
  " )\n",
248
  " self.encoder = aux.encoder\n",
249
  "\n",
250
+ " encoder_channels = self.encoder.out_channels\n",
 
251
  " doubled = tuple(c * 2 for c in encoder_channels)\n",
252
  "\n",
253
  " self.decoder = smp.decoders.unet.decoder.UnetDecoder(\n",
 
261
  " self.head = nn.Conv2d(16, 1, kernel_size=1)\n",
262
  "\n",
263
  " def forward(self, img_a, img_b):\n",
 
264
  " feats_a = self.encoder(img_a)\n",
265
  " feats_b = self.encoder(img_b)\n",
 
 
266
  " feats_cat = [torch.cat([fa, fb], dim=1) for fa, fb in zip(feats_a, feats_b)]\n",
 
267
  " decoded = self.decoder(*feats_cat)\n",
268
  " logits = self.head(decoded)\n",
269
  " return logits\n",
270
  "\n",
271
  "\n",
272
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
273
+ "model = SiameseUNet(encoder_name=ENCODER_NAME, pretrained=True).to(device)\n",
274
  "\n",
275
  "total_params = sum(p.numel() for p in model.parameters()) / 1e6\n",
276
+ "print(f\"Model on {device}, {total_params:.1f}M parameters\")\n",
277
+ "if device.type == \"cpu\":\n",
278
+ " print(\"Running on CPU — training will take ~3-4 hours for 15 epochs\")"
279
  ],
280
  "execution_count": null,
281
  "outputs": []
 
335
  "cell_type": "code",
336
  "metadata": {},
337
  "source": [
338
+ "import time\n",
339
  "from tqdm.auto import tqdm\n",
340
  "\n",
341
+ "NUM_EPOCHS = 15 # fewer epochs for CPU training\n",
342
+ "LR = 3e-4 # slightly higher LR to converge faster\n",
343
  "\n",
344
  "criterion = BCEDiceLoss(bce_weight=0.5)\n",
345
  "optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)\n",
 
347
  "\n",
348
  "best_f1 = 0.0\n",
349
  "history = {\"train_loss\": [], \"val_loss\": [], \"val_f1\": [], \"val_iou\": []}\n",
350
+ "train_start = time.time()\n",
351
  "\n",
352
  "for epoch in range(1, NUM_EPOCHS + 1):\n",
353
+ " epoch_start = time.time()\n",
354
+ "\n",
355
  " # --- Train ---\n",
356
  " model.train()\n",
357
  " running_loss = 0.0\n",
 
397
  " history[\"val_f1\"].append(metrics[\"f1\"])\n",
398
  " history[\"val_iou\"].append(metrics[\"iou\"])\n",
399
  "\n",
400
+ " elapsed_min = (time.time() - epoch_start) / 60\n",
401
+ " total_min = (time.time() - train_start) / 60\n",
402
+ " eta_min = elapsed_min * (NUM_EPOCHS - epoch)\n",
403
+ "\n",
404
  " print(\n",
405
  " f\"Epoch {epoch:02d} | \"\n",
406
  " f\"train_loss={train_loss:.4f} | \"\n",
407
  " f\"val_loss={val_loss:.4f} | \"\n",
408
  " f\"F1={metrics['f1']:.4f} | \"\n",
409
  " f\"IoU={metrics['iou']:.4f} | \"\n",
410
+ " f\"P={metrics['precision']:.4f} R={metrics['recall']:.4f} | \"\n",
411
+ " f\"{elapsed_min:.1f}min (ETA: {eta_min:.0f}min)\"\n",
412
  " )\n",
413
  "\n",
414
  " if metrics[\"f1\"] > best_f1:\n",
 
416
  " torch.save(model.state_dict(), \"best_siamese_unet.pth\")\n",
417
  " print(f\" >> Saved best model (F1={best_f1:.4f})\")\n",
418
  "\n",
419
+ "total_time = (time.time() - train_start) / 60\n",
420
+ "print(f\"\\nTraining complete in {total_time:.1f} minutes. Best val F1: {best_f1:.4f}\")"
421
  ],
422
  "execution_count": null,
423
  "outputs": []