youngPhilosopher committed on
Commit
b891e61
·
verified ·
1 Parent(s): 381f5f3

Upload folder using huggingface_hub

Browse files
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (158 Bytes). View file
 
src/__pycache__/best_predictions.cpython-311.pyc ADDED
Binary file (8.72 kB). View file
 
src/__pycache__/evaluate.cpython-311.pyc ADDED
Binary file (12.4 kB). View file
 
src/__pycache__/train.cpython-311.pyc ADDED
Binary file (11.4 kB). View file
 
src/best_predictions.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Find best and worst predictions by per-sample IoU and generate showcase figures."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ import matplotlib
7
+ matplotlib.use("Agg")
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from PIL import Image
11
+ from tqdm import tqdm
12
+
13
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
14
+
15
+
16
def iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """Intersection-over-Union between two binary masks (0.0 when both are empty)."""
    overlap = (pred.astype(bool) & gt.astype(bool)).sum()
    combined = (pred.astype(bool) | gt.astype(bool)).sum()
    if combined == 0:
        # Both masks empty: define IoU as 0 rather than dividing by zero.
        return 0.0
    return float(overlap / combined)
20
+
21
+
22
def score_all():
    """Score every test prediction against ground truth. Returns dict of per-class scored lists."""
    # Test split produced by the preprocessing pipeline; each record carries
    # image_path, mask_path and dataset ("taping" or "cracks").
    with open(PROJECT_ROOT / "data" / "splits" / "test.json") as f:
        test_samples = json.load(f)

    masks_dir = PROJECT_ROOT / "outputs" / "masks"
    scores = {"taping": [], "cracks": []}

    for sample in tqdm(test_samples, desc="Scoring predictions"):
        img_stem = Path(sample["image_path"]).stem
        ds = sample["dataset"]

        # Predicted masks are stored as "<image-stem>__<prompt-slug>.png";
        # multiple prompts may have been evaluated per image.
        candidates = list(masks_dir.glob(f"{img_stem}__*.png"))
        if not candidates:
            # No saved prediction for this image — skip it silently.
            continue

        gt = np.array(Image.open(sample["mask_path"]).convert("L"))
        gt_bin = (gt > 127).astype(np.uint8)  # binarize the 0/255 mask

        # Keep whichever prompt's prediction scores highest for this image.
        best_iou = -1
        best_pred_path = None
        best_prompt = None
        for pred_path in candidates:
            # Predictions are at original image resolution; NEAREST keeps the
            # resized mask binary.
            pred = np.array(Image.open(pred_path).convert("L").resize(
                (gt.shape[1], gt.shape[0]), Image.NEAREST))
            pred_bin = (pred > 127).astype(np.uint8)
            score = iou(pred_bin, gt_bin)
            if score > best_iou:
                best_iou = score
                best_pred_path = pred_path
                # Recover the human-readable prompt from the filename slug.
                best_prompt = pred_path.stem.split("__")[1].replace("_", " ")

        scores[ds].append({
            "image_path": sample["image_path"],
            "mask_path": sample["mask_path"],
            "pred_path": str(best_pred_path),
            "prompt": best_prompt,
            "iou": best_iou,
            "dataset": ds,
        })

    return scores
64
+
65
+
66
def pick_ranked(scores, n_per_class=3, best=True):
    """Select the top-N (best=True) or bottom-N (best=False) samples per class by IoU."""
    chosen = []
    label = "best" if best else "worst"
    for ds in ["cracks", "taping"]:
        if best:
            pool = scores[ds]
        else:
            # For the worst list, drop zero-IoU entries (no prediction matched)
            # so only genuine failures remain.
            pool = [s for s in scores[ds] if s["iou"] > 0]
        top = sorted(pool, key=lambda s: s["iou"], reverse=best)[:n_per_class]
        chosen += top

        print(f"\n{ds} {label} {n_per_class}:")
        for r in top:
            print(f" IoU={r['iou']:.4f} {Path(r['image_path']).name} \"{r['prompt']}\"")

    return chosen
82
+
83
+
84
def generate_grid(examples, output_path, title=""):
    """Render an original | ground-truth | prediction comparison grid as a PNG.

    examples: list of dicts with image_path, mask_path, pred_path, prompt,
        iou and dataset keys (as produced by score_all/pick_ranked).
    output_path: destination PNG; parent directories are created if missing.
    title: optional figure suptitle.
    """
    if not examples:
        # Nothing to plot — a zero-row subplot grid would fail.
        print(f"No examples to plot for {output_path}")
        return

    output_path = Path(output_path)
    # Fix: callers pass paths under reports/figures which may not exist yet;
    # without this, plt.savefig raises FileNotFoundError on a fresh checkout.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    n = len(examples)
    fig, axes = plt.subplots(n, 3, figsize=(14, 4.0 * n))
    if n == 1:
        # subplots(1, 3) returns a 1-D axes array; wrap so axes[i][j] works.
        axes = [axes]

    if title:
        fig.suptitle(title, fontsize=16, fontweight="bold", y=0.998)

    for i, ex in enumerate(examples):
        img = Image.open(ex["image_path"]).convert("RGB")
        gt = Image.open(ex["mask_path"]).convert("L")
        # Predictions are stored at original image resolution; resize with
        # NEAREST to the GT size so the panels line up.
        pred = Image.open(ex["pred_path"]).convert("L").resize(
            (gt.size[0], gt.size[1]), Image.NEAREST)

        label = ex["dataset"].capitalize()

        axes[i][0].imshow(img)
        axes[i][0].set_title(f"Input — {label}", fontsize=11, fontweight="bold")
        axes[i][0].axis("off")

        axes[i][1].imshow(gt, cmap="gray", vmin=0, vmax=255)
        axes[i][1].set_title("Ground Truth", fontsize=11)
        axes[i][1].axis("off")

        axes[i][2].imshow(pred, cmap="gray", vmin=0, vmax=255)
        axes[i][2].set_title(f"Predicted — \"{ex['prompt']}\" (IoU {ex['iou']:.2f})", fontsize=11)
        axes[i][2].axis("off")

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved → {output_path}")
118
+
119
+
120
if __name__ == "__main__":
    # Script entry point: score all saved test predictions, then render the
    # best-3 and worst-3 showcase grids per class.
    figures_dir = PROJECT_ROOT / "reports" / "figures"
    scores = score_all()

    # Best predictions (3 per class)
    best = pick_ranked(scores, n_per_class=3, best=True)
    generate_grid(best, figures_dir / "best_predictions.png",
                  title="Best Test-Set Predictions (by IoU)")

    # Worst predictions (3 per class) — only samples where model actually predicted something
    worst = pick_ranked(scores, n_per_class=3, best=False)
    generate_grid(worst, figures_dir / "failure_cases.png",
                  title="Failure Cases — Worst Test-Set Predictions (by IoU)")
src/data/__init__.py ADDED
File without changes
src/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (163 Bytes). View file
 
src/data/__pycache__/dataset.cpython-311.pyc ADDED
Binary file (6.37 kB). View file
 
src/data/__pycache__/download.cpython-311.pyc ADDED
Binary file (3.26 kB). View file
 
src/data/__pycache__/preprocess.cpython-311.pyc ADDED
Binary file (15.8 kB). View file
 
src/data/dataset.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch Dataset for CLIPSeg fine-tuning."""
2
+
3
+ import json
4
+ import random
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import Dataset
11
+ from transformers import CLIPSegProcessor
12
+
13
+
14
class DrywallSegDataset(Dataset):
    """Dataset that yields (image, mask, prompt) tuples for CLIPSeg."""

    def __init__(self, split_json: str, processor: CLIPSegProcessor, image_size: int = 352):
        # split_json: JSON list of records from preprocessing, each with
        # image_path, mask_path, prompts, dataset, width and height.
        with open(split_json) as f:
            self.records = json.load(f)
        self.processor = processor
        # Side length the ground-truth mask is resized to; 352 matches
        # CLIPSeg's output logit resolution.
        self.image_size = image_size

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> dict:
        rec = self.records[idx]

        # Load image
        image = Image.open(rec["image_path"]).convert("RGB")

        # Load mask and resize to CLIPSeg resolution (NEAREST keeps it binary)
        mask = Image.open(rec["mask_path"]).convert("L")
        mask = mask.resize((self.image_size, self.image_size), Image.NEAREST)
        mask_tensor = torch.from_numpy(np.array(mask)).float() / 255.0  # {0.0, 1.0}

        # Random prompt synonym — per-epoch text augmentation.
        prompt = random.choice(rec["prompts"])

        # Process through CLIPSeg processor (tokenizes text, normalizes image)
        inputs = self.processor(
            text=[prompt],
            images=[image],
            return_tensors="pt",
            padding=True,
        )

        # squeeze(0) drops the batch dim the processor adds; collate_fn re-batches.
        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": mask_tensor,
            "dataset": rec["dataset"],
            "image_path": rec["image_path"],
            "mask_path": rec["mask_path"],
            "prompt": prompt,
            "orig_width": rec["width"],
            "orig_height": rec["height"],
        }
60
+
61
+
62
def collate_fn(batch):
    """Collate dataset items into a batch, right-padding text tensors to a common length."""
    longest = max(item["input_ids"].shape[0] for item in batch)

    def _pad(t):
        # Right-pad a 1-D tensor with zeros up to the longest sequence in the batch.
        shortfall = longest - t.shape[0]
        if shortfall <= 0:
            return t
        return torch.cat([t, torch.zeros(shortfall, dtype=t.dtype)])

    return {
        "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
        "input_ids": torch.stack([_pad(item["input_ids"]) for item in batch]),
        "attention_mask": torch.stack([_pad(item["attention_mask"]) for item in batch]),
        "labels": torch.stack([item["labels"] for item in batch]),
        "dataset": [item["dataset"] for item in batch],
        "image_path": [item["image_path"] for item in batch],
        "mask_path": [item["mask_path"] for item in batch],
        "prompt": [item["prompt"] for item in batch],
        "orig_width": [item["orig_width"] for item in batch],
        "orig_height": [item["orig_height"] for item in batch],
    }
src/data/download.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset download instructions.
2
+
3
+ Both datasets must be downloaded manually from Roboflow Universe in COCO format.
4
+ The Roboflow API cannot be used because the cracks dataset (cracks-3ii36) has
5
+ no generated versions — the owner never created an exportable version.
6
+
7
+ Download locations:
8
+ - Taping: https://universe.roboflow.com/objectdetect-pu6rn/drywall-join-detect
9
+ → Export as COCO, place under data/raw/taping/
10
+ - Cracks: https://universe.roboflow.com/fyp-ny1jt/cracks-3ii36
11
+ → Export as COCO, place under data/raw/cracks/
12
+
13
+ Expected structure after download:
14
+ data/raw/
15
+ ├── taping/
16
+ │ ├── train/
17
+ │ │ ├── _annotations.coco.json
18
+ │ │ └── *.jpg
19
+ │ └── valid/
20
+ │ ├── _annotations.coco.json
21
+ │ └── *.jpg
22
+ └── cracks/
23
+ └── train/
24
+ ├── _annotations.coco.json
25
+ └── *.jpg
26
+ """
src/data/preprocess.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inspect annotations, generate masks, create train/val/test splits."""
2
+
3
+ import json
4
+ import random
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ from pycocotools.coco import COCO
10
+ from pycocotools import mask as mask_utils
11
+
12
+ RAW_DIR = Path(__file__).resolve().parents[2] / "data" / "raw"
13
+ PROCESSED_DIR = Path(__file__).resolve().parents[2] / "data" / "processed"
14
+ SPLITS_DIR = Path(__file__).resolve().parents[2] / "data" / "splits"
15
+
16
+
17
def inspect_dataset(coco_json_path: str) -> dict:
    """Summarise a COCO JSON file: counts plus whether it carries usable segmentations."""
    with open(coco_json_path) as fh:
        payload = json.load(fh)

    annotations = payload.get("annotations", [])
    seg_count = 0
    bbox_only_count = 0
    for ann in annotations:
        seg = ann.get("segmentation")
        # A polygon ring needs at least 3 points (6 coords); RLE dicts always count.
        is_polygon = bool(seg) and isinstance(seg, list) and len(seg) > 0 and len(seg[0]) >= 6
        is_rle = bool(seg) and isinstance(seg, dict)
        if is_polygon or is_rle:
            seg_count += 1
        else:
            bbox_only_count += 1

    return {
        "total_annotations": len(annotations),
        "total_images": len(payload.get("images", [])),
        "has_segmentation": seg_count,
        "has_bbox_only": bbox_only_count,
        # Majority vote decides how downstream code renders masks.
        "annotation_type": "segmentation" if seg_count > bbox_only_count else "bbox_only",
        "categories": [cat["name"] for cat in payload.get("categories", [])],
    }
43
+
44
+
45
def render_masks_from_coco(coco_json_path: str, images_dir: str, output_dir: str) -> list[dict]:
    """Render binary masks from COCO polygon/RLE annotations.

    Returns list of {image_path, mask_path, image_id, width, height}.
    Images with no annotations produce no mask and no record.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    coco = COCO(coco_json_path)
    records = []

    for img_id in sorted(coco.getImgIds()):
        img_info = coco.loadImgs(img_id)[0]
        h, w = img_info["height"], img_info["width"]

        ann_ids = coco.getAnnIds(imgIds=img_id)
        anns = coco.loadAnns(ann_ids)

        if not anns:
            continue

        # Merge all annotations into one binary mask
        combined = np.zeros((h, w), dtype=np.uint8)
        for ann in anns:
            seg = ann.get("segmentation")
            # Skip annotations with empty or invalid segmentation
            if not seg:
                continue
            # Nested-list polygon whose first ring has fewer than 3 points (6 coords).
            if isinstance(seg, list) and (len(seg) == 0 or (len(seg) > 0 and isinstance(seg[0], list) and len(seg[0]) < 6)):
                continue
            # Flat coordinate list too short to form a polygon.
            if isinstance(seg, list) and len(seg) > 0 and not isinstance(seg[0], list) and len(seg) < 6:
                continue
            try:
                # annToRLE handles both polygon and RLE segmentations.
                rle = coco.annToRLE(ann)
                m = mask_utils.decode(rle)
                combined = np.maximum(combined, m)
            except (IndexError, ValueError):
                # Fall back to bbox if segmentation decode fails
                if "bbox" in ann:
                    x, y, bw, bh = [int(v) for v in ann["bbox"]]
                    combined[y:y+bh, x:x+bw] = 1

        mask_img = Image.fromarray(combined * 255, mode="L")
        mask_name = Path(img_info["file_name"]).stem + "_mask.png"
        mask_path = output_dir / mask_name
        mask_img.save(mask_path)

        image_path = Path(images_dir) / img_info["file_name"]
        records.append({
            "image_path": str(image_path),
            "mask_path": str(mask_path),
            "image_id": img_id,
            "width": w,
            "height": h,
        })

    return records
102
+
103
+
104
def render_masks_from_bboxes(coco_json_path: str, images_dir: str, output_dir: str) -> list[dict]:
    """Create filled-rectangle masks from bounding boxes (fallback when no segmentation).

    All boxes of an image are merged into one binary mask saved as
    <stem>_mask.png. Returns a list of {image_path, mask_path, image_id,
    width, height} records; images without annotations are skipped.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(coco_json_path) as f:
        data = json.load(f)

    img_lookup = {img["id"]: img for img in data["images"]}
    anns_by_img: dict[int, list] = {}
    for ann in data["annotations"]:
        anns_by_img.setdefault(ann["image_id"], []).append(ann)

    records = []
    for img_id, img_info in sorted(img_lookup.items()):
        anns = anns_by_img.get(img_id, [])
        if not anns:
            continue

        h, w = img_info["height"], img_info["width"]
        combined = np.zeros((h, w), dtype=np.uint8)

        for ann in anns:
            x, y, bw, bh = [int(v) for v in ann["bbox"]]
            # Fix: clamp to image bounds. Exported boxes can be slightly
            # negative or overflow the image; a negative slice start would
            # wrap around and paint the wrong region of the mask.
            x0, y0 = max(x, 0), max(y, 0)
            x1, y1 = min(x + bw, w), min(y + bh, h)
            if x1 > x0 and y1 > y0:
                combined[y0:y1, x0:x1] = 1

        mask_img = Image.fromarray(combined * 255, mode="L")
        mask_name = Path(img_info["file_name"]).stem + "_mask.png"
        mask_path = output_dir / mask_name
        mask_img.save(mask_path)

        image_path = Path(images_dir) / img_info["file_name"]
        records.append({
            "image_path": str(image_path),
            "mask_path": str(mask_path),
            "image_id": img_id,
            "width": w,
            "height": h,
        })

    return records
145
+
146
+
147
+ def find_coco_json(dataset_dir: Path) -> tuple[str, str] | None:
148
+ """Find the COCO JSON and images directory in a Roboflow download."""
149
+ for split in ["train", "valid", "test"]:
150
+ json_path = dataset_dir / split / "_annotations.coco.json"
151
+ if json_path.exists():
152
+ return str(json_path), str(dataset_dir / split)
153
+ # Single-folder layout
154
+ for json_path in dataset_dir.rglob("_annotations.coco.json"):
155
+ return str(json_path), str(json_path.parent)
156
+ return None
157
+
158
+
159
def process_dataset(name: str, dataset_dir: Path, prompt_synonyms: list[str]) -> list[dict]:
    """Process a single dataset: inspect, render masks, return records with prompts."""
    records = []
    mask_dir = PROCESSED_DIR / name / "masks"

    # Process each split folder (train/valid/test from Roboflow)
    for split_dir in sorted(dataset_dir.iterdir()):
        if not split_dir.is_dir():
            continue
        json_path = split_dir / "_annotations.coco.json"
        if not json_path.exists():
            continue

        print(f"\n Processing {name}/{split_dir.name}...")
        info = inspect_dataset(str(json_path))
        print(f" Images: {info['total_images']}, Annotations: {info['total_annotations']}")
        print(f" Type: {info['annotation_type']}, Categories: {info['categories']}")

        split_mask_dir = mask_dir / split_dir.name
        # Prefer real segmentation masks; fall back to filled bounding boxes
        # when the export only carries boxes.
        if info["annotation_type"] == "segmentation":
            split_records = render_masks_from_coco(
                str(json_path), str(split_dir), str(split_mask_dir)
            )
        else:
            print(f" WARNING: bbox-only annotations, using filled rectangles")
            split_records = render_masks_from_bboxes(
                str(json_path), str(split_dir), str(split_mask_dir)
            )

        # Attach the class name and its prompt synonyms to every record.
        for r in split_records:
            r["dataset"] = name
            r["prompts"] = prompt_synonyms
        records.extend(split_records)

    return records
194
+
195
+
196
def create_splits(records: list[dict], ratios: tuple = (0.70, 0.15, 0.15), seed: int = 42):
    """Split records into train/val/test, stratified by dataset.

    ratios: (train, val, test) fractions; test receives the remainder so the
    three splits always cover every record. Writes each split to
    data/splits/<name>.json and returns {"train": ..., "val": ..., "test": ...}.
    """
    random.seed(seed)  # deterministic shuffling for reproducible splits

    # Group by dataset so each class keeps the same train/val/test proportions.
    by_dataset: dict[str, list] = {}
    for r in records:
        by_dataset.setdefault(r["dataset"], []).append(r)

    train, val, test = [], [], []
    for name, recs in by_dataset.items():
        random.shuffle(recs)
        n = len(recs)
        n_train = int(n * ratios[0])
        n_val = int(n * ratios[1])
        train.extend(recs[:n_train])
        val.extend(recs[n_train:n_train + n_val])
        test.extend(recs[n_train + n_val:])

    # Mix datasets within each split so batches are not single-class.
    random.shuffle(train)
    random.shuffle(val)
    random.shuffle(test)

    SPLITS_DIR.mkdir(parents=True, exist_ok=True)
    for split_name, split_data in [("train", train), ("val", val), ("test", test)]:
        path = SPLITS_DIR / f"{split_name}.json"
        with open(path, "w") as f:
            json.dump(split_data, f, indent=2)
        print(f" {split_name}: {len(split_data)} samples -> {path}")

    return {"train": train, "val": val, "test": test}
226
+
227
+
228
def run(config: dict):
    """Run full preprocessing pipeline."""
    # prompt_synonyms maps each dataset name to its list of text prompts;
    # split_ratios is the (train, val, test) tuple.
    synonyms = config["data"]["prompt_synonyms"]
    ratios = tuple(config["data"]["split_ratios"])

    all_records = []
    for name in ["taping", "cracks"]:
        dataset_dir = RAW_DIR / name
        if not dataset_dir.exists():
            # Datasets are downloaded manually (see src/data/download.py).
            print(f"WARNING: {dataset_dir} not found, skipping {name}")
            continue
        print(f"\n{'='*60}")
        print(f"Processing dataset: {name}")
        print(f"{'='*60}")
        records = process_dataset(name, dataset_dir, synonyms[name])
        all_records.extend(records)
        print(f" Total records for {name}: {len(records)}")

    print(f"\n{'='*60}")
    print(f"Creating splits (total: {len(all_records)} records)")
    print(f"{'='*60}")
    splits = create_splits(all_records, ratios=ratios, seed=config["seed"])
    return splits
251
+
252
+
253
if __name__ == "__main__":
    # Script entry point: load the training config and run the full pipeline.
    import yaml
    config_path = Path(__file__).resolve().parents[2] / "configs" / "train_config.yaml"
    with open(config_path) as f:
        config = yaml.safe_load(f)
    run(config)
src/evaluate.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Evaluate trained CLIPSeg model and generate prediction masks + visuals."""
2
+
3
+ import json
4
+ import time
5
+ from pathlib import Path
6
+
7
+ import matplotlib
8
+ matplotlib.use("Agg")
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import torch
12
+ import yaml
13
+ from PIL import Image
14
+ from torch.utils.data import DataLoader
15
+ from tqdm import tqdm
16
+
17
+ from src.data.dataset import DrywallSegDataset, collate_fn
18
+ from src.model.clipseg_wrapper import load_model_and_processor
19
+ from src.train import compute_metrics, get_device
20
+
21
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
22
+
23
+
24
def evaluate(config_path: str | None = None):
    """Evaluate the best checkpoint on the test split.

    Saves one prediction mask per sample to outputs/masks/, writes aggregate
    per-class and overall mIoU/Dice to outputs/logs/test_results.json, and
    renders a visual comparison figure. Returns the results dict.
    """
    config_path = config_path or str(PROJECT_ROOT / "configs" / "train_config.yaml")
    with open(config_path) as f:
        config = yaml.safe_load(f)

    device = get_device()
    threshold = config["evaluation"]["threshold"]

    # Load model with best checkpoint (weights_only avoids arbitrary unpickling)
    model, processor = load_model_and_processor(config["model"]["name"], config["model"]["freeze_backbone"])
    ckpt_path = PROJECT_ROOT / "outputs" / "checkpoints" / "best_model.pt"
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
    model = model.to(device)
    model.eval()

    # Model size in MB (parameters only, excludes buffers)
    model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)

    # Test data
    splits_dir = PROJECT_ROOT / "data" / "splits"
    test_ds = DrywallSegDataset(str(splits_dir / "test.json"), processor, config["data"]["image_size"])
    test_loader = DataLoader(test_ds, batch_size=config["training"]["batch_size"], shuffle=False,
                             collate_fn=collate_fn, num_workers=0)

    # Run evaluation
    masks_dir = PROJECT_ROOT / "outputs" / "masks"
    masks_dir.mkdir(parents=True, exist_ok=True)

    all_metrics = {"taping": {"miou": [], "dice": []}, "cracks": {"miou": [], "dice": []}}
    inference_times = []
    visual_examples = []  # Collect for visualization
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            t0 = time.time()
            outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
            # Per-image latency: wall time divided by batch size.
            inference_times.append((time.time() - t0) / pixel_values.size(0))

            logits = outputs.logits
            preds = (torch.sigmoid(logits) > threshold).cpu().numpy().astype(np.uint8)

            for i in range(pixel_values.size(0)):
                ds_name = batch["dataset"][i]
                # Fix: score each sample individually. The previous code
                # appended the batch-mean metric to every sample's class
                # bucket, so mixed-class batches cross-contaminated the
                # per-class mIoU/Dice numbers.
                sample_metrics = compute_metrics(logits[i:i + 1], labels[i:i + 1], threshold)
                all_metrics[ds_name]["miou"].append(sample_metrics["miou"])
                all_metrics[ds_name]["dice"].append(sample_metrics["dice"])

                # Save prediction mask at original resolution
                orig_w, orig_h = batch["orig_width"][i], batch["orig_height"][i]
                pred_mask = Image.fromarray(preds[i] * 255, mode="L")
                pred_mask = pred_mask.resize((orig_w, orig_h), Image.NEAREST)

                prompt_slug = batch["prompt"][i].replace(" ", "_")
                img_stem = Path(batch["image_path"][i]).stem
                mask_filename = f"{img_stem}__{prompt_slug}.png"
                pred_mask.save(masks_dir / mask_filename)

                total_samples += 1

                # Collect visual examples
                if len(visual_examples) < config["evaluation"]["num_visual_examples"]:
                    visual_examples.append({
                        "image_path": batch["image_path"][i],
                        "mask_path": batch["mask_path"][i],
                        "pred_mask": preds[i],
                        "prompt": batch["prompt"][i],
                        "dataset": ds_name,
                    })

    # Aggregate metrics
    results = {"per_class": {}, "overall": {}}
    all_miou, all_dice = [], []
    for ds_name in ["taping", "cracks"]:
        m = all_metrics[ds_name]
        if m["miou"]:
            results["per_class"][ds_name] = {
                "miou": round(float(np.mean(m["miou"])), 4),
                "dice": round(float(np.mean(m["dice"])), 4),
                "samples": len(m["miou"]),
            }
            all_miou.extend(m["miou"])
            all_dice.extend(m["dice"])

    results["overall"] = {
        "miou": round(float(np.mean(all_miou)), 4) if all_miou else 0,
        "dice": round(float(np.mean(all_dice)), 4) if all_dice else 0,
        "total_samples": total_samples,
    }
    results["runtime"] = {
        "avg_inference_ms": round(float(np.mean(inference_times)) * 1000, 1),
        "model_size_mb": round(model_size_mb, 1),
    }

    # Save results
    log_dir = PROJECT_ROOT / "outputs" / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    with open(log_dir / "test_results.json", "w") as f:
        json.dump(results, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Test Results")
    print(f"{'='*60}")
    for ds_name, m in results["per_class"].items():
        print(f" {ds_name:>10s}: mIoU={m['miou']:.4f} Dice={m['dice']:.4f} (n={m['samples']})")
    print(f" {'overall':>10s}: mIoU={results['overall']['miou']:.4f} Dice={results['overall']['dice']:.4f}")
    print(f" Avg inference: {results['runtime']['avg_inference_ms']:.1f} ms/image")
    print(f" Model size: {results['runtime']['model_size_mb']:.1f} MB")

    # Generate visual comparison figures
    _generate_visuals(visual_examples, PROJECT_ROOT / "reports" / "figures")

    return results
142
+
143
+
144
def _generate_visuals(examples: list[dict], output_dir: Path):
    """Render a row per example with original | ground truth | prediction panels."""
    output_dir.mkdir(parents=True, exist_ok=True)

    if not examples:
        return

    n_rows = len(examples)
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows))
    if n_rows == 1:
        # A single row comes back as a 1-D axes array; wrap for uniform indexing.
        axes = [axes]

    for row, ex in zip(axes, examples):
        panels = [
            (Image.open(ex["image_path"]).convert("RGB"), f"Original ({ex['dataset']})", None),
            (Image.open(ex["mask_path"]).convert("L"), "Ground Truth", "gray"),
            (Image.fromarray(ex["pred_mask"] * 255, mode="L"), f"Prediction: \"{ex['prompt']}\"", "gray"),
        ]
        for ax, (panel_img, caption, cmap) in zip(row, panels):
            if cmap is None:
                ax.imshow(panel_img)
            else:
                ax.imshow(panel_img, cmap=cmap, vmin=0, vmax=255)
            ax.set_title(caption)
            ax.axis("off")

    plt.tight_layout()
    plt.savefig(output_dir / "visual_comparison.png", dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved visual comparison to {output_dir / 'visual_comparison.png'}")
176
+
177
+
178
if __name__ == "__main__":
    # Script entry point: run full test-set evaluation with the default config.
    evaluate()
src/model/__init__.py ADDED
File without changes
src/model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (164 Bytes). View file
 
src/model/__pycache__/clipseg_wrapper.cpython-311.pyc ADDED
Binary file (1.89 kB). View file
 
src/model/__pycache__/losses.cpython-311.pyc ADDED
Binary file (3.34 kB). View file
 
src/model/clipseg_wrapper.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLIPSeg model loading and freezing utilities."""
2
+
3
+ from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor
4
+
5
+
6
def load_model_and_processor(model_name: str = "CIDAS/clipseg-rd64-refined", freeze_backbone: bool = True):
    """Load a CLIPSeg checkpoint and its processor.

    When freeze_backbone is True, only parameters whose name contains
    "decoder" remain trainable; everything else is frozen.
    """
    model = CLIPSegForImageSegmentation.from_pretrained(model_name)
    processor = CLIPSegProcessor.from_pretrained(model_name)

    if not freeze_backbone:
        total = sum(p.numel() for p in model.parameters())
        print(f"Parameters — all trainable: {total:,}")
        return model, processor

    decoder_count = 0
    backbone_count = 0
    for param_name, param in model.named_parameters():
        is_decoder = "decoder" in param_name
        param.requires_grad = is_decoder
        if is_decoder:
            decoder_count += param.numel()
        else:
            backbone_count += param.numel()
    print(f"Parameters — trainable (decoder): {decoder_count:,} | frozen (backbone): {backbone_count:,}")

    return model, processor
src/model/losses.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Custom loss functions for segmentation."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
class DiceLoss(nn.Module):
    """Soft Dice loss computed directly from logits (sigmoid applied internally)."""

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        # Smoothing term keeps the ratio finite when both masks are empty.
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        batch = logits.size(0)
        p = torch.sigmoid(logits).view(batch, -1)
        t = targets.view(targets.size(0), -1)

        overlap = (p * t).sum(dim=1)
        totals = p.sum(dim=1) + t.sum(dim=1)
        per_sample_dice = (2.0 * overlap + self.smooth) / (totals + self.smooth)
        return 1.0 - per_sample_dice.mean()
25
+
26
+
27
class BCEDiceLoss(nn.Module):
    """Weighted blend of BCE-with-logits and soft Dice losses."""

    def __init__(self, bce_weight: float = 0.5, dice_weight: float = 0.5):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Compute each term once, then combine with the configured weights.
        bce_term = self.bce(logits, targets)
        dice_term = self.dice(logits, targets)
        return self.bce_weight * bce_term + self.dice_weight * dice_term
src/predict.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Standalone single-image inference for CLIPSeg."""
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ import yaml
9
+ from PIL import Image
10
+
11
+ from src.model.clipseg_wrapper import load_model_and_processor
12
+ from src.train import get_device
13
+
14
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
15
+
16
+
17
def predict(image_path: str, prompt: str, config_path: str | None = None, output_path: str | None = None):
    """Segment a single image with the fine-tuned CLIPSeg checkpoint.

    image_path: path to the input image.
    prompt: free-text query, e.g. "segment crack".
    config_path: optional config override; defaults to configs/train_config.yaml.
    output_path: where to save the binary mask PNG; defaults to
        outputs/masks/<stem>__<prompt-slug>.png.
    Returns the saved PIL mask (mode "L", 0/255 values, original image size).
    """
    config_path = config_path or str(PROJECT_ROOT / "configs" / "train_config.yaml")
    with open(config_path) as f:
        config = yaml.safe_load(f)

    device = get_device()
    model, processor = load_model_and_processor(config["model"]["name"], config["model"]["freeze_backbone"])
    ckpt = PROJECT_ROOT / "outputs" / "checkpoints" / "best_model.pt"
    # weights_only=True avoids unpickling arbitrary objects from the checkpoint.
    model.load_state_dict(torch.load(ckpt, map_location="cpu", weights_only=True))
    model = model.to(device).eval()

    image = Image.open(image_path).convert("RGB")
    orig_w, orig_h = image.size

    inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Threshold sigmoid probabilities, then upsample back to the input size
    # with NEAREST so the mask stays binary.
    pred = (torch.sigmoid(logits[0]) > config["evaluation"]["threshold"]).cpu().numpy().astype(np.uint8)
    mask = Image.fromarray(pred * 255, mode="L").resize((orig_w, orig_h), Image.NEAREST)

    if output_path is None:
        stem = Path(image_path).stem
        slug = prompt.replace(" ", "_")
        output_path = str(PROJECT_ROOT / "outputs" / "masks" / f"{stem}__{slug}.png")

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    mask.save(output_path)
    print(f"Saved mask to {output_path}")
    return mask
49
+
50
+
51
if __name__ == "__main__":
    # CLI: python -m src.predict <image> "<prompt>" [--output mask.png]
    parser = argparse.ArgumentParser()
    parser.add_argument("image", help="Path to input image")
    parser.add_argument("prompt", help="Text prompt, e.g. 'segment crack'")
    parser.add_argument("--output", help="Output mask path")
    args = parser.parse_args()
    predict(args.image, args.prompt, output_path=args.output)
src/train.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training loop for CLIPSeg fine-tuning."""
2
+
3
+ import json
4
+ import time
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ import yaml
10
+ from torch.optim import AdamW
11
+ from torch.optim.lr_scheduler import CosineAnnealingLR
12
+ from torch.utils.data import DataLoader
13
+ from tqdm import tqdm
14
+
15
+ from src.data.dataset import DrywallSegDataset, collate_fn
16
+ from src.model.clipseg_wrapper import load_model_and_processor
17
+ from src.model.losses import BCEDiceLoss
18
+
19
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
20
+
21
+
22
def compute_metrics(logits: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5):
    """Compute batch-mean IoU and Dice between thresholded logits and binary targets.

    Args:
        logits: Raw model outputs of shape (batch, H, W); sigmoid is applied here.
        targets: Ground-truth masks of the same shape; binarized at 0.5.
        threshold: Probability cutoff for turning sigmoid outputs into a mask.

    Returns:
        Dict with keys "miou" and "dice", each the mean over the batch (float).
    """
    preds = torch.sigmoid(logits).gt(threshold).float()
    targets = targets.gt(0.5).float()

    pred_area = preds.sum(dim=(1, 2))
    target_area = targets.sum(dim=(1, 2))
    overlap = (preds * targets).sum(dim=(1, 2))

    # 1e-6 smoothing keeps empty-mask samples at a perfect score instead of 0/0.
    eps = 1e-6
    iou_per_sample = (overlap + eps) / (pred_area + target_area - overlap + eps)
    dice_per_sample = (2 * overlap + eps) / (pred_area + target_area + eps)

    return {"miou": iou_per_sample.mean().item(), "dice": dice_per_sample.mean().item()}
34
+
35
+
36
def get_device():
    """Return the preferred torch device: MPS, then CUDA, then CPU fallback."""
    preference = (
        (torch.backends.mps.is_available, "mps"),
        (torch.cuda.is_available, "cuda"),
    )
    for is_available, name in preference:
        if is_available():
            return torch.device(name)
    return torch.device("cpu")
43
+
44
+
45
def train(config_path: str | None = None):
    """Fine-tune CLIPSeg on the drywall segmentation splits.

    Reads hyperparameters from a YAML config, trains with a combined
    BCE+Dice loss, validates every epoch, saves the checkpoint with the
    best validation mIoU, and early-stops after `patience` epochs without
    improvement. Per-epoch history and a run summary are written as JSON
    under outputs/logs.

    Args:
        config_path: Path to the YAML training config; defaults to
            configs/train_config.yaml under the project root.

    Returns:
        Tuple of (model, history). NOTE(review): the returned model holds
        the *last-epoch* weights, not necessarily the best checkpoint
        saved to outputs/checkpoints/best_model.pt.
    """
    config_path = config_path or str(PROJECT_ROOT / "configs" / "train_config.yaml")
    with open(config_path) as f:
        config = yaml.safe_load(f)

    # Seed torch and numpy for reproducibility (recorded in the summary below).
    seed = config["seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = get_device()
    print(f"Device: {device}")

    # Model (backbone optionally frozen per config; see load_model_and_processor).
    model, processor = load_model_and_processor(
        config["model"]["name"],
        config["model"]["freeze_backbone"],
    )
    model = model.to(device)

    # Data: train/val splits are JSON files produced elsewhere in the project.
    splits_dir = PROJECT_ROOT / "data" / "splits"
    train_ds = DrywallSegDataset(str(splits_dir / "train.json"), processor, config["data"]["image_size"])
    val_ds = DrywallSegDataset(str(splits_dir / "val.json"), processor, config["data"]["image_size"])

    tc = config["training"]
    train_loader = DataLoader(train_ds, batch_size=tc["batch_size"], shuffle=True,
                              collate_fn=collate_fn, num_workers=tc["num_workers"])
    val_loader = DataLoader(val_ds, batch_size=tc["batch_size"], shuffle=False,
                            collate_fn=collate_fn, num_workers=tc["num_workers"])

    # Loss, optimizer, scheduler. Only parameters with requires_grad are
    # optimized, so a frozen backbone contributes nothing to the step.
    criterion = BCEDiceLoss(tc["bce_weight"], tc["dice_weight"])
    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=tc["lr"],
        weight_decay=tc["weight_decay"],
    )
    # Cosine decay over the full epoch budget (stepped once per epoch below).
    scheduler = CosineAnnealingLR(optimizer, T_max=tc["epochs"])

    # Training state: best-so-far mIoU drives checkpointing and early stopping.
    best_miou = 0.0
    patience_counter = 0
    history = {"train_loss": [], "val_loss": [], "val_miou": [], "val_dice": []}
    ckpt_dir = PROJECT_ROOT / "outputs" / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    log_dir = PROJECT_ROOT / "outputs" / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    start_time = time.time()

    for epoch in range(1, tc["epochs"] + 1):
        # ---- Train ----
        model.train()
        train_losses = []
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{tc['epochs']} [train]", leave=False):
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            logits = outputs.logits
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # LR schedule advances per epoch, not per batch.
        scheduler.step()
        avg_train_loss = np.mean(train_losses)

        # ---- Validate ----
        model.eval()
        val_losses, val_mious, val_dices = [], [], []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{tc['epochs']} [val]", leave=False):
                pixel_values = batch["pixel_values"].to(device)
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = model(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
                logits = outputs.logits
                loss = criterion(logits, labels)
                metrics = compute_metrics(logits, labels)

                val_losses.append(loss.item())
                val_mious.append(metrics["miou"])
                val_dices.append(metrics["dice"])

        # Per-batch metric means are averaged; with a non-divisible last batch
        # this slightly weights small batches — acceptable for monitoring.
        avg_val_loss = np.mean(val_losses)
        avg_val_miou = np.mean(val_mious)
        avg_val_dice = np.mean(val_dices)

        history["train_loss"].append(float(avg_train_loss))
        history["val_loss"].append(float(avg_val_loss))
        history["val_miou"].append(float(avg_val_miou))
        history["val_dice"].append(float(avg_val_dice))

        print(f"Epoch {epoch:3d} | train_loss={avg_train_loss:.4f} | val_loss={avg_val_loss:.4f} | "
              f"val_mIoU={avg_val_miou:.4f} | val_Dice={avg_val_dice:.4f}")

        # Checkpoint only the state_dict (not optimizer/scheduler), so training
        # cannot be resumed mid-run from this file — inference-only checkpoint.
        if avg_val_miou > best_miou:
            best_miou = avg_val_miou
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_dir / "best_model.pt")
            print(f"  -> New best mIoU: {best_miou:.4f}, saved checkpoint")
        else:
            patience_counter += 1
            if patience_counter >= tc["patience"]:
                print(f"  Early stopping at epoch {epoch} (patience={tc['patience']})")
                break

    total_time = time.time() - start_time

    # Save history & summary. NOTE(review): `epoch` below assumes the loop ran
    # at least once — with epochs=0 in the config this raises NameError.
    with open(log_dir / "training_history.json", "w") as f:
        json.dump(history, f, indent=2)

    summary = {
        "total_epochs": epoch,
        "best_val_miou": float(best_miou),
        "total_time_seconds": round(total_time, 1),
        "total_time_minutes": round(total_time / 60, 1),
        "device": str(device),
        "train_samples": len(train_ds),
        "val_samples": len(val_ds),
        "seed": seed,
    }
    with open(log_dir / "training_summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\nTraining complete in {summary['total_time_minutes']} min")
    print(f"Best val mIoU: {best_miou:.4f}")
    return model, history
191
+
192
+
193
+ if __name__ == "__main__":
194
+ train()