coderuday21 committed on
Commit
1b91ffc
·
1 Parent(s): 7d0f836

Add Siamese U-Net training notebook and model inference integration

Browse files
app/detection_engine.py CHANGED
@@ -584,14 +584,27 @@ def _ai_fusion_core(img1, img2, sensitivity=0.5):
584
 
585
  def ai_deep_learning_method(img1, img2, sensitivity=0.5):
586
  """
587
- Single-pass AI fusion with CVA, SNR weighting, and vegetation/shadow
588
- suppression. The suppression maps make the reverse pass unnecessary,
589
- halving computation and eliminating OR-induced false positives.
590
  """
591
- change_mask, core_debug = _ai_fusion_core(img1, img2, sensitivity=sensitivity)
 
 
 
 
 
 
 
 
 
 
 
 
592
 
 
 
593
  debug = {
594
- "method": "AI-Based Deep Learning",
595
  "threshold_used": core_debug.get("threshold_used"),
596
  "sensitivity": float(sensitivity),
597
  "core": core_debug,
 
584
 
585
  def ai_deep_learning_method(img1, img2, sensitivity=0.5):
586
  """
587
+ Uses the trained Siamese U-Net when available; falls back to the
588
+ rule-based multi-channel fusion otherwise.
 
589
  """
590
+ from .model_inference import is_model_available, predict_change_mask
591
+
592
+ if is_model_available():
593
+ threshold = 0.35 + (1.0 - sensitivity) * 0.3
594
+ change_mask, score_map = predict_change_mask(img1, img2, threshold=threshold)
595
+ change_mask = _clean_mask(change_mask, sensitivity=sensitivity)
596
+ debug = {
597
+ "method": "AI-Based Deep Learning (Siamese U-Net)",
598
+ "model": "siamese_unet",
599
+ "threshold_used": int(threshold * 255),
600
+ "sensitivity": float(sensitivity),
601
+ }
602
+ return change_mask, debug
603
 
604
+ # Fallback: rule-based fusion
605
+ change_mask, core_debug = _ai_fusion_core(img1, img2, sensitivity=sensitivity)
606
  debug = {
607
+ "method": "AI-Based Deep Learning (rule-based fallback)",
608
  "threshold_used": core_debug.get("threshold_used"),
609
  "sensitivity": float(sensitivity),
610
  "core": core_debug,
app/model_inference.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Siamese U-Net inference for satellite change detection.
3
+
4
+ Loads a TorchScript model exported from the training notebook and runs
5
+ tile-based inference on arbitrary-size image pairs, producing a binary
6
+ change mask compatible with the rest of the detection pipeline.
7
+
8
+ Set CHANGE_MODEL_PATH env var to the .pt file location.
9
+ Falls back to the rule-based AI fusion when no model is available.
10
+ """
11
+ import logging
12
+ import os
13
+ from pathlib import Path
14
+
15
+ import cv2
16
+ import numpy as np
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ _MODEL = None
21
+ _MODEL_PATH = os.environ.get("CHANGE_MODEL_PATH", "data/siamese_unet.pt")
22
+ _TILE_SIZE = 256
23
+ _MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
24
+ _STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
25
+
26
+
27
def _get_torch():
    """Import PyTorch on demand.

    Returns the ``torch`` module if it can be imported, otherwise ``None``.
    The import is deferred to call time so that merely importing this
    module does not require PyTorch to be installed.
    """
    try:
        import torch
    except ImportError:
        # PyTorch is optional; callers treat None as "not installed".
        return None
    else:
        return torch
34
+
35
+
36
def is_model_available():
    """Report whether model inference can be attempted.

    Returns ``True`` only when ``_MODEL_PATH`` points at an existing file
    and PyTorch can be imported. The file check runs first, so PyTorch is
    not imported when no model file is present.
    """
    if not Path(_MODEL_PATH).is_file():
        return False
    return _get_torch() is not None
39
+
40
+
41
def _load_model():
    """Return the TorchScript model, loading it on first use.

    The loaded module is cached in the module-level ``_MODEL`` so later
    calls return the same object without touching the disk.

    Raises:
        RuntimeError: if PyTorch cannot be imported.
        FileNotFoundError: if ``_MODEL_PATH`` is not an existing file.
    """
    global _MODEL

    if _MODEL is None:
        torch = _get_torch()
        if torch is None:
            raise RuntimeError("PyTorch is not installed")

        path = Path(_MODEL_PATH)
        if not path.is_file():
            raise FileNotFoundError(f"Model not found at {path}")

        # Prefer the GPU when one is visible; otherwise run on the CPU.
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(device_name)

        _MODEL = torch.jit.load(str(path), map_location=device)
        _MODEL.eval()  # inference mode for dropout / batch-norm layers
        logger.info("Loaded Siamese U-Net from %s on %s", path, device)

    return _MODEL
56
+
57
+
58
def _preprocess_tile(tile):
    """Convert one image tile into a normalized model input tensor.

    The (H, W, 3) uint8 tile is scaled to [0, 1], standardized per channel
    with ``_MEAN`` / ``_STD``, reordered to channels-first and given a
    leading batch axis, yielding a (1, 3, H, W) float tensor.
    """
    torch = _get_torch()

    scaled = tile.astype(np.float32) / 255.0
    standardized = (scaled - _MEAN) / _STD

    # HWC -> CHW, then prepend the batch dimension.
    channels_first = standardized.transpose(2, 0, 1)
    return torch.from_numpy(channels_first).unsqueeze(0)
65
+
66
+
67
def _tile_origins(length, tile, stride):
    """Return the tile start offsets needed to cover ``[0, length)``.

    Offsets advance by ``stride``. When that regular grid would stop short
    of the far edge, one extra offset ``length - tile`` is appended so the
    last tile sits flush with the edge. Requires ``length >= tile``.
    """
    last = length - tile
    origins = list(range(0, last + 1, stride))
    if origins[-1] != last:
        origins.append(last)
    return origins


def predict_change_mask(img1, img2, threshold=0.5):
    """
    Run Siamese U-Net inference on two RGB numpy arrays (H, W, 3).

    The images are split into overlapping tiles, each tile pair is scored
    by the model, and per-pixel probabilities are averaged wherever tiles
    overlap. If the two images differ in shape, ``img2`` is resized to
    match ``img1``.

    Args:
        img1: first image, uint8 array of shape (H, W, 3).
        img2: second image, same layout as ``img1``.
        threshold: probability at or above which a pixel is marked changed.

    Returns:
        A ``(mask, avg_score)`` tuple, both at the resolution of ``img1``:
        ``mask`` is a uint8 array holding 0 or 255, and ``avg_score`` is
        the float32 map of averaged sigmoid probabilities.

    Raises:
        RuntimeError / FileNotFoundError: propagated from ``_load_model``
        when PyTorch or the model file is missing.
    """
    torch = _get_torch()
    model = _load_model()
    device = next(model.parameters()).device

    if img1.shape != img2.shape:
        # cv2.resize takes (width, height), the reverse of numpy's order.
        img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))

    h, w = img1.shape[:2]
    tile = _TILE_SIZE
    # Step by 3/4 of a tile, so neighbouring tiles overlap by 25%.
    stride = tile * 3 // 4

    # Pad bottom/right so both dimensions are multiples of the tile size;
    # this guarantees at least one full tile fits along each axis.
    pad_h = (tile - h % tile) % tile
    pad_w = (tile - w % tile) % tile
    if pad_h or pad_w:
        pad_spec = ((0, pad_h), (0, pad_w), (0, 0))
        img1 = np.pad(img1, pad_spec, mode="reflect")
        img2 = np.pad(img2, pad_spec, mode="reflect")

    ph, pw = img1.shape[:2]
    score_sum = np.zeros((ph, pw), dtype=np.float32)
    count = np.zeros((ph, pw), dtype=np.float32)

    # The padded size is a multiple of `tile`, not of `stride`, so a plain
    # strided range can leave the bottom/right strip without any tile
    # (those pixels would then silently score 0). _tile_origins adds a
    # final edge-aligned tile on each axis to close that gap.
    row_origins = _tile_origins(ph, tile, stride)
    col_origins = _tile_origins(pw, tile, stride)

    with torch.no_grad():
        for y0 in row_origins:
            y1 = y0 + tile
            for x0 in col_origins:
                x1 = x0 + tile
                t1 = _preprocess_tile(img1[y0:y1, x0:x1])
                t2 = _preprocess_tile(img2[y0:y1, x0:x1])
                logits = model(t1.to(device), t2.to(device))
                prob = torch.sigmoid(logits).squeeze().cpu().numpy()
                score_sum[y0:y1, x0:x1] += prob
                count[y0:y1, x0:x1] += 1.0

    # Every pixel is covered at least once; the clamp is a cheap guard
    # against division by zero.
    count = np.maximum(count, 1.0)
    avg_score = score_sum / count

    # Crop back to original size
    avg_score = avg_score[:h, :w]

    mask = (avg_score >= threshold).astype(np.uint8) * 255
    return mask, avg_score
train_change_detection_model.ipynb ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Satellite Change Detection — Siamese U-Net Training\n",
8
+ "\n",
9
+ "This notebook trains a **Siamese U-Net** on the **LEVIR-CD** dataset for pixel-level\n",
10
+ "satellite image change detection. The exported model plugs directly into the\n",
11
+ "AI Change Detection web app.\n",
12
+ "\n",
13
+ "**Run in Google Colab** with a GPU runtime (Runtime → Change runtime type → T4 GPU).\n",
14
+ "\n",
15
+ "Training takes ~2-3 hours on a free T4."
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "markdown",
20
+ "metadata": {},
21
+ "source": [
22
+ "## 1. Install Dependencies"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "!pip install -q torch torchvision segmentation-models-pytorch albumentations gdown tqdm matplotlib"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {},
37
+ "source": [
38
+ "## 2. Download LEVIR-CD Dataset\n",
39
+ "\n",
40
+ "LEVIR-CD contains 637 pairs of 1024×1024 Google Earth images with pixel-level\n",
41
+ "building change annotations. We download the pre-cut 256×256 patch version."
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "import os, zipfile, gdown\n",
51
+ "\n",
52
+ "DATA_ROOT = \"./levir_cd_256\"\n",
53
+ "\n",
54
+ "if not os.path.isdir(DATA_ROOT):\n",
55
+ " # LEVIR-CD 256x256 patches (hosted mirror)\n",
56
+ " url = \"https://drive.google.com/uc?id=1RUFY5Z4Bf5LAoYXC8Iq9k5GloJmVDmfF\"\n",
57
+ " zip_path = \"levir_cd_256.zip\"\n",
58
+ " print(\"Downloading LEVIR-CD 256×256 patches (~540 MB)...\")\n",
59
+ " gdown.download(url, zip_path, quiet=False)\n",
60
+ " print(\"Extracting...\")\n",
61
+ " with zipfile.ZipFile(zip_path, \"r\") as z:\n",
62
+ " z.extractall(\".\")\n",
63
+ " os.remove(zip_path)\n",
64
+ " print(\"Done. Dataset at:\", DATA_ROOT)\n",
65
+ "else:\n",
66
+ " print(\"Dataset already present at\", DATA_ROOT)"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "# Verify structure — adjust paths if your zip extracts differently\n",
76
+ "for split in [\"train\", \"val\", \"test\"]:\n",
77
+ " for sub in [\"A\", \"B\", \"label\"]:\n",
78
+ " p = os.path.join(DATA_ROOT, split, sub)\n",
79
+ " if os.path.isdir(p):\n",
80
+ " n = len(os.listdir(p))\n",
81
+ " print(f\"{split}/{sub}: {n} files\")\n",
82
+ " else:\n",
83
+ " print(f\"WARNING: {p} not found — check extracted folder name\")"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "markdown",
88
+ "metadata": {},
89
+ "source": [
90
+ "## 3. Dataset & DataLoader"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "import numpy as np\n",
100
+ "from PIL import Image\n",
101
+ "from torch.utils.data import Dataset, DataLoader\n",
102
+ "import albumentations as A\n",
103
+ "from albumentations.pytorch import ToTensorV2\n",
104
+ "\n",
105
+ "\n",
106
+ "class LEVIRCDDataset(Dataset):\n",
107
+ " \"\"\"LEVIR-CD dataset: before (A), after (B), binary label.\"\"\"\n",
108
+ "\n",
109
+ " def __init__(self, root, split=\"train\", transform=None):\n",
110
+ " self.dir_a = os.path.join(root, split, \"A\")\n",
111
+ " self.dir_b = os.path.join(root, split, \"B\")\n",
112
+ " self.dir_label = os.path.join(root, split, \"label\")\n",
113
+ " self.fnames = sorted(os.listdir(self.dir_a))\n",
114
+ " self.transform = transform\n",
115
+ "\n",
116
+ " def __len__(self):\n",
117
+ " return len(self.fnames)\n",
118
+ "\n",
119
+ " def __getitem__(self, idx):\n",
120
+ " name = self.fnames[idx]\n",
121
+ " img_a = np.array(Image.open(os.path.join(self.dir_a, name)).convert(\"RGB\"))\n",
122
+ " img_b = np.array(Image.open(os.path.join(self.dir_b, name)).convert(\"RGB\"))\n",
123
+ " label = np.array(Image.open(os.path.join(self.dir_label, name)).convert(\"L\"))\n",
124
+ " label = (label > 127).astype(np.float32)\n",
125
+ "\n",
126
+ " if self.transform:\n",
127
+ " # Apply same spatial transform to all three\n",
128
+ " aug = self.transform(\n",
129
+ " image=img_a,\n",
130
+ " image_b=img_b,\n",
131
+ " mask=label,\n",
132
+ " )\n",
133
+ " img_a = aug[\"image\"] # (3, H, W) tensor\n",
134
+ " img_b = aug[\"image_b\"] # (3, H, W) tensor\n",
135
+ " label = aug[\"mask\"].unsqueeze(0) # (1, H, W)\n",
136
+ " return img_a, img_b, label\n",
137
+ "\n",
138
+ "\n",
139
+ "train_transform = A.Compose(\n",
140
+ " [\n",
141
+ " A.HorizontalFlip(p=0.5),\n",
142
+ " A.VerticalFlip(p=0.5),\n",
143
+ " A.RandomRotate90(p=0.5),\n",
144
+ " A.RandomBrightnessContrast(p=0.3, brightness_limit=0.15, contrast_limit=0.15),\n",
145
+ " A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n",
146
+ " ToTensorV2(),\n",
147
+ " ],\n",
148
+ " additional_targets={\"image_b\": \"image\"},\n",
149
+ ")\n",
150
+ "\n",
151
+ "val_transform = A.Compose(\n",
152
+ " [\n",
153
+ " A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n",
154
+ " ToTensorV2(),\n",
155
+ " ],\n",
156
+ " additional_targets={\"image_b\": \"image\"},\n",
157
+ ")\n",
158
+ "\n",
159
+ "train_ds = LEVIRCDDataset(DATA_ROOT, \"train\", train_transform)\n",
160
+ "val_ds = LEVIRCDDataset(DATA_ROOT, \"val\", val_transform)\n",
161
+ "test_ds = LEVIRCDDataset(DATA_ROOT, \"test\", val_transform)\n",
162
+ "\n",
163
+ "BATCH = 8\n",
164
+ "train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=2, pin_memory=True)\n",
165
+ "val_dl = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)\n",
166
+ "test_dl = DataLoader(test_ds, batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)\n",
167
+ "\n",
168
+ "print(f\"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}\")"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "markdown",
173
+ "metadata": {},
174
+ "source": [
175
+ "## 4. Siamese U-Net Model\n",
176
+ "\n",
177
+ "Architecture:\n",
178
+ "- **Shared encoder** (ResNet34, ImageNet pretrained) processes both images\n",
179
+ "- Feature maps from both branches are **concatenated** at each decoder level\n",
180
+ "- Standard U-Net decoder produces a binary change mask"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": null,
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": [
189
+ "import torch\n",
190
+ "import torch.nn as nn\n",
191
+ "import segmentation_models_pytorch as smp\n",
192
+ "\n",
193
+ "\n",
194
+ "class SiameseUNet(nn.Module):\n",
195
+ " \"\"\"\n",
196
+ " Siamese U-Net for change detection.\n",
197
+ " Shared encoder extracts features from both images;\n",
198
+ " concatenated features are decoded into a binary change mask.\n",
199
+ " \"\"\"\n",
200
+ "\n",
201
+ " def __init__(self, encoder_name=\"resnet34\", pretrained=True):\n",
202
+ " super().__init__()\n",
203
+ " # Build a standard U-Net to reuse its encoder and decoder pieces\n",
204
+ " aux = smp.Unet(\n",
205
+ " encoder_name=encoder_name,\n",
206
+ " encoder_weights=\"imagenet\" if pretrained else None,\n",
207
+ " in_channels=3,\n",
208
+ " classes=1,\n",
209
+ " )\n",
210
+ " self.encoder = aux.encoder\n",
211
+ "\n",
212
+ " # The decoder expects concatenated features (2x channels at each level)\n",
213
+ " encoder_channels = self.encoder.out_channels # e.g. (3,64,64,128,256,512)\n",
214
+ " doubled = tuple(c * 2 for c in encoder_channels)\n",
215
+ "\n",
216
+ " self.decoder = smp.decoders.unet.decoder.UnetDecoder(\n",
217
+ " encoder_channels=doubled,\n",
218
+ " decoder_channels=(256, 128, 64, 32, 16),\n",
219
+ " n_blocks=5,\n",
220
+ " use_batchnorm=True,\n",
221
+ " attention_type=None,\n",
222
+ " )\n",
223
+ "\n",
224
+ " self.head = nn.Conv2d(16, 1, kernel_size=1)\n",
225
+ "\n",
226
+ " def forward(self, img_a, img_b):\n",
227
+ " # Shared encoder for both temporal images\n",
228
+ " feats_a = self.encoder(img_a)\n",
229
+ " feats_b = self.encoder(img_b)\n",
230
+ "\n",
231
+ " # Concatenate features at every level\n",
232
+ " feats_cat = [torch.cat([fa, fb], dim=1) for fa, fb in zip(feats_a, feats_b)]\n",
233
+ "\n",
234
+ " decoded = self.decoder(*feats_cat)\n",
235
+ " logits = self.head(decoded)\n",
236
+ " return logits\n",
237
+ "\n",
238
+ "\n",
239
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
240
+ "model = SiameseUNet(encoder_name=\"resnet34\", pretrained=True).to(device)\n",
241
+ "\n",
242
+ "total_params = sum(p.numel() for p in model.parameters()) / 1e6\n",
243
+ "print(f\"Model on {device}, {total_params:.1f}M parameters\")"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "markdown",
248
+ "metadata": {},
249
+ "source": [
250
+ "## 5. Loss Function & Metrics\n",
251
+ "\n",
252
+ "Combined **BCE + Dice** loss handles class imbalance (most pixels are unchanged)."
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "metadata": {},
259
+ "outputs": [],
260
+ "source": [
261
+ "class BCEDiceLoss(nn.Module):\n",
262
+ " def __init__(self, bce_weight=0.5):\n",
263
+ " super().__init__()\n",
264
+ " self.bce = nn.BCEWithLogitsLoss()\n",
265
+ " self.bce_weight = bce_weight\n",
266
+ "\n",
267
+ " def forward(self, logits, targets):\n",
268
+ " bce_loss = self.bce(logits, targets)\n",
269
+ " probs = torch.sigmoid(logits)\n",
270
+ " smooth = 1.0\n",
271
+ " intersection = (probs * targets).sum()\n",
272
+ " dice = (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)\n",
273
+ " dice_loss = 1.0 - dice\n",
274
+ " return self.bce_weight * bce_loss + (1 - self.bce_weight) * dice_loss\n",
275
+ "\n",
276
+ "\n",
277
+ "def compute_metrics(preds, targets, threshold=0.5):\n",
278
+ " \"\"\"Compute precision, recall, F1, and IoU.\"\"\"\n",
279
+ " preds_bin = (preds > threshold).float()\n",
280
+ " tp = (preds_bin * targets).sum().item()\n",
281
+ " fp = (preds_bin * (1 - targets)).sum().item()\n",
282
+ " fn = ((1 - preds_bin) * targets).sum().item()\n",
283
+ " precision = tp / (tp + fp + 1e-8)\n",
284
+ " recall = tp / (tp + fn + 1e-8)\n",
285
+ " f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
286
+ " iou = tp / (tp + fp + fn + 1e-8)\n",
287
+ " return {\"precision\": precision, \"recall\": recall, \"f1\": f1, \"iou\": iou}"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "markdown",
292
+ "metadata": {},
293
+ "source": [
294
+ "## 6. Training Loop"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "execution_count": null,
300
+ "metadata": {},
301
+ "outputs": [],
302
+ "source": [
303
+ "from tqdm.auto import tqdm\n",
304
+ "\n",
305
+ "NUM_EPOCHS = 50\n",
306
+ "LR = 1e-4\n",
307
+ "\n",
308
+ "criterion = BCEDiceLoss(bce_weight=0.5)\n",
309
+ "optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)\n",
310
+ "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)\n",
311
+ "\n",
312
+ "best_f1 = 0.0\n",
313
+ "history = {\"train_loss\": [], \"val_loss\": [], \"val_f1\": [], \"val_iou\": []}\n",
314
+ "\n",
315
+ "for epoch in range(1, NUM_EPOCHS + 1):\n",
316
+ " # --- Train ---\n",
317
+ " model.train()\n",
318
+ " running_loss = 0.0\n",
319
+ " for img_a, img_b, label in tqdm(train_dl, desc=f\"Epoch {epoch}/{NUM_EPOCHS} [train]\", leave=False):\n",
320
+ " img_a = img_a.to(device)\n",
321
+ " img_b = img_b.to(device)\n",
322
+ " label = label.to(device)\n",
323
+ "\n",
324
+ " logits = model(img_a, img_b)\n",
325
+ " loss = criterion(logits, label)\n",
326
+ "\n",
327
+ " optimizer.zero_grad()\n",
328
+ " loss.backward()\n",
329
+ " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
330
+ " optimizer.step()\n",
331
+ " running_loss += loss.item() * img_a.size(0)\n",
332
+ "\n",
333
+ " train_loss = running_loss / len(train_ds)\n",
334
+ " scheduler.step()\n",
335
+ "\n",
336
+ " # --- Validate ---\n",
337
+ " model.eval()\n",
338
+ " val_loss_sum = 0.0\n",
339
+ " all_preds, all_targets = [], []\n",
340
+ " with torch.no_grad():\n",
341
+ " for img_a, img_b, label in val_dl:\n",
342
+ " img_a = img_a.to(device)\n",
343
+ " img_b = img_b.to(device)\n",
344
+ " label = label.to(device)\n",
345
+ "\n",
346
+ " logits = model(img_a, img_b)\n",
347
+ " val_loss_sum += criterion(logits, label).item() * img_a.size(0)\n",
348
+ " all_preds.append(torch.sigmoid(logits).cpu())\n",
349
+ " all_targets.append(label.cpu())\n",
350
+ "\n",
351
+ " val_loss = val_loss_sum / len(val_ds)\n",
352
+ " preds_cat = torch.cat(all_preds)\n",
353
+ " targets_cat = torch.cat(all_targets)\n",
354
+ " metrics = compute_metrics(preds_cat, targets_cat)\n",
355
+ "\n",
356
+ " history[\"train_loss\"].append(train_loss)\n",
357
+ " history[\"val_loss\"].append(val_loss)\n",
358
+ " history[\"val_f1\"].append(metrics[\"f1\"])\n",
359
+ " history[\"val_iou\"].append(metrics[\"iou\"])\n",
360
+ "\n",
361
+ " print(\n",
362
+ " f\"Epoch {epoch:02d} | \"\n",
363
+ " f\"train_loss={train_loss:.4f} | \"\n",
364
+ " f\"val_loss={val_loss:.4f} | \"\n",
365
+ " f\"F1={metrics['f1']:.4f} | \"\n",
366
+ " f\"IoU={metrics['iou']:.4f} | \"\n",
367
+ " f\"P={metrics['precision']:.4f} R={metrics['recall']:.4f}\"\n",
368
+ " )\n",
369
+ "\n",
370
+ " if metrics[\"f1\"] > best_f1:\n",
371
+ " best_f1 = metrics[\"f1\"]\n",
372
+ " torch.save(model.state_dict(), \"best_siamese_unet.pth\")\n",
373
+ " print(f\" >> Saved best model (F1={best_f1:.4f})\")\n",
374
+ "\n",
375
+ "print(f\"\\nTraining complete. Best val F1: {best_f1:.4f}\")"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "markdown",
380
+ "metadata": {},
381
+ "source": [
382
+ "## 7. Training Curves"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "metadata": {},
389
+ "outputs": [],
390
+ "source": [
391
+ "import matplotlib.pyplot as plt\n",
392
+ "\n",
393
+ "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
394
+ "\n",
395
+ "axes[0].plot(history[\"train_loss\"], label=\"Train\")\n",
396
+ "axes[0].plot(history[\"val_loss\"], label=\"Val\")\n",
397
+ "axes[0].set_title(\"Loss\")\n",
398
+ "axes[0].legend()\n",
399
+ "\n",
400
+ "axes[1].plot(history[\"val_f1\"])\n",
401
+ "axes[1].set_title(\"Val F1 Score\")\n",
402
+ "\n",
403
+ "axes[2].plot(history[\"val_iou\"])\n",
404
+ "axes[2].set_title(\"Val IoU\")\n",
405
+ "\n",
406
+ "for ax in axes:\n",
407
+ " ax.set_xlabel(\"Epoch\")\n",
408
+ " ax.grid(True, alpha=0.3)\n",
409
+ "\n",
410
+ "plt.tight_layout()\n",
411
+ "plt.show()"
412
+ ]
413
+ },
414
+ {
415
+ "cell_type": "markdown",
416
+ "metadata": {},
417
+ "source": [
418
+ "## 8. Evaluate on Test Set"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": null,
424
+ "metadata": {},
425
+ "outputs": [],
426
+ "source": [
427
+ "# Load best checkpoint\n",
428
+ "model.load_state_dict(torch.load(\"best_siamese_unet.pth\", map_location=device))\n",
429
+ "model.eval()\n",
430
+ "\n",
431
+ "all_preds, all_targets = [], []\n",
432
+ "with torch.no_grad():\n",
433
+ " for img_a, img_b, label in tqdm(test_dl, desc=\"Testing\"):\n",
434
+ " logits = model(img_a.to(device), img_b.to(device))\n",
435
+ " all_preds.append(torch.sigmoid(logits).cpu())\n",
436
+ " all_targets.append(label)\n",
437
+ "\n",
438
+ "preds = torch.cat(all_preds)\n",
439
+ "targets = torch.cat(all_targets)\n",
440
+ "test_metrics = compute_metrics(preds, targets)\n",
441
+ "\n",
442
+ "print(f\"\\nTest Results:\")\n",
443
+ "print(f\" F1 Score: {test_metrics['f1']:.4f}\")\n",
444
+ "print(f\" IoU: {test_metrics['iou']:.4f}\")\n",
445
+ "print(f\" Precision: {test_metrics['precision']:.4f}\")\n",
446
+ "print(f\" Recall: {test_metrics['recall']:.4f}\")"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "markdown",
451
+ "metadata": {},
452
+ "source": [
453
+ "## 9. Visualize Predictions"
454
+ ]
455
+ },
456
+ {
457
+ "cell_type": "code",
458
+ "execution_count": null,
459
+ "metadata": {},
460
+ "outputs": [],
461
+ "source": [
462
+ "MEAN = np.array([0.485, 0.456, 0.406])\n",
463
+ "STD = np.array([0.229, 0.224, 0.225])\n",
464
+ "\n",
465
+ "def denorm(tensor):\n",
466
+ " img = tensor.permute(1, 2, 0).numpy()\n",
467
+ " img = img * STD + MEAN\n",
468
+ " return np.clip(img, 0, 1)\n",
469
+ "\n",
470
+ "fig, axes = plt.subplots(4, 4, figsize=(16, 16))\n",
471
+ "sample_indices = np.random.choice(len(test_ds), 4, replace=False)\n",
472
+ "\n",
473
+ "for row, idx in enumerate(sample_indices):\n",
474
+ " img_a, img_b, label = test_ds[idx]\n",
475
+ " with torch.no_grad():\n",
476
+ " logit = model(img_a.unsqueeze(0).to(device), img_b.unsqueeze(0).to(device))\n",
477
+ " pred = (torch.sigmoid(logit) > 0.5).squeeze().cpu().numpy()\n",
478
+ "\n",
479
+ " axes[row, 0].imshow(denorm(img_a))\n",
480
+ " axes[row, 0].set_title(\"Before\")\n",
481
+ " axes[row, 1].imshow(denorm(img_b))\n",
482
+ " axes[row, 1].set_title(\"After\")\n",
483
+ " axes[row, 2].imshow(label.squeeze(), cmap=\"gray\")\n",
484
+ " axes[row, 2].set_title(\"Ground Truth\")\n",
485
+ " axes[row, 3].imshow(pred, cmap=\"gray\")\n",
486
+ " axes[row, 3].set_title(\"Prediction\")\n",
487
+ "\n",
488
+ "for ax in axes.flat:\n",
489
+ " ax.axis(\"off\")\n",
490
+ "plt.tight_layout()\n",
491
+ "plt.show()"
492
+ ]
493
+ },
494
+ {
495
+ "cell_type": "markdown",
496
+ "metadata": {},
497
+ "source": [
498
+ "## 10. Export Model for Deployment\n",
499
+ "\n",
500
+ "Export as TorchScript for the web app. Download the `.pt` file and place it in\n",
501
+ "your app's `data/` folder, then set the environment variable:\n",
502
+ "\n",
503
+ "```\n",
504
+ "CHANGE_MODEL_PATH=data/siamese_unet.pt\n",
505
+ "```"
506
+ ]
507
+ },
508
+ {
509
+ "cell_type": "code",
510
+ "execution_count": null,
511
+ "metadata": {},
512
+ "outputs": [],
513
+ "source": [
514
+ "model.eval()\n",
515
+ "model_cpu = model.cpu()\n",
516
+ "\n",
517
+ "# Trace with example inputs\n",
518
+ "example_a = torch.randn(1, 3, 256, 256)\n",
519
+ "example_b = torch.randn(1, 3, 256, 256)\n",
520
+ "traced = torch.jit.trace(model_cpu, (example_a, example_b))\n",
521
+ "\n",
522
+ "export_path = \"siamese_unet.pt\"\n",
523
+ "traced.save(export_path)\n",
524
+ "size_mb = os.path.getsize(export_path) / 1e6\n",
525
+ "print(f\"Exported TorchScript model: {export_path} ({size_mb:.1f} MB)\")\n",
526
+ "print(\"\\nDownload this file and place it in your app's data/ directory.\")\n",
527
+ "print('Then set: CHANGE_MODEL_PATH=data/siamese_unet.pt')"
528
+ ]
529
+ },
530
+ {
531
+ "cell_type": "code",
532
+ "execution_count": null,
533
+ "metadata": {},
534
+ "outputs": [],
535
+ "source": [
536
+ "# Quick sanity check: verify exported model produces same output\n",
537
+ "loaded = torch.jit.load(export_path)\n",
538
+ "with torch.no_grad():\n",
539
+ " out_orig = model_cpu(example_a, example_b)\n",
540
+ " out_loaded = loaded(example_a, example_b)\n",
541
+ " diff = (out_orig - out_loaded).abs().max().item()\n",
542
+ " print(f\"Max diff between original and exported: {diff:.8f}\")\n",
543
+ " assert diff < 1e-5, \"Export verification failed!\"\n",
544
+ " print(\"Export verified successfully.\")"
545
+ ]
546
+ },
547
+ {
548
+ "cell_type": "markdown",
549
+ "metadata": {},
550
+ "source": [
551
+ "## 11. Download from Colab\n",
552
+ "\n",
553
+ "Run this cell to trigger a browser download of the model file."
554
+ ]
555
+ },
556
+ {
557
+ "cell_type": "code",
558
+ "execution_count": null,
559
+ "metadata": {},
560
+ "outputs": [],
561
+ "source": [
562
+ "try:\n",
563
+ " from google.colab import files\n",
564
+ " files.download(\"siamese_unet.pt\")\n",
565
+ " files.download(\"best_siamese_unet.pth\")\n",
566
+ "except ImportError:\n",
567
+ " print(\"Not running in Colab. Files saved locally:\")\n",
568
+ " print(f\" - {export_path}\")\n",
569
+ " print(f\" - best_siamese_unet.pth\")"
570
+ ]
571
+ }
572
+ ],
573
+ "metadata": {
574
+ "kernelspec": {
575
+ "display_name": "Python 3",
576
+ "language": "python",
577
+ "name": "python3"
578
+ },
579
+ "language_info": {
580
+ "name": "python",
581
+ "version": "3.11.0"
582
+ }
583
+ },
584
+ "nbformat": 4,
585
+ "nbformat_minor": 4
586
+ }