Upload folder using huggingface_hub
Browse files- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/__pycache__/best_predictions.cpython-311.pyc +0 -0
- src/__pycache__/evaluate.cpython-311.pyc +0 -0
- src/__pycache__/train.cpython-311.pyc +0 -0
- src/best_predictions.py +132 -0
- src/data/__init__.py +0 -0
- src/data/__pycache__/__init__.cpython-311.pyc +0 -0
- src/data/__pycache__/dataset.cpython-311.pyc +0 -0
- src/data/__pycache__/download.cpython-311.pyc +0 -0
- src/data/__pycache__/preprocess.cpython-311.pyc +0 -0
- src/data/dataset.py +92 -0
- src/data/download.py +26 -0
- src/data/preprocess.py +258 -0
- src/evaluate.py +179 -0
- src/model/__init__.py +0 -0
- src/model/__pycache__/__init__.cpython-311.pyc +0 -0
- src/model/__pycache__/clipseg_wrapper.cpython-311.pyc +0 -0
- src/model/__pycache__/losses.cpython-311.pyc +0 -0
- src/model/clipseg_wrapper.py +25 -0
- src/model/losses.py +38 -0
- src/predict.py +57 -0
- src/train.py +194 -0
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (158 Bytes). View file
|
|
|
src/__pycache__/best_predictions.cpython-311.pyc
ADDED
|
Binary file (8.72 kB). View file
|
|
|
src/__pycache__/evaluate.cpython-311.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
src/__pycache__/train.cpython-311.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
src/best_predictions.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Find best and worst predictions by per-sample IoU and generate showcase figures."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import matplotlib
|
| 7 |
+
matplotlib.use("Agg")
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """Return the intersection-over-union of two binary masks.

    Both inputs are treated as boolean foreground masks. Returns 0.0 when the
    union is empty (both masks all-background) to avoid division by zero.
    """
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 0.0
    overlap = np.logical_and(pred, gt).sum()
    return float(overlap / union)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def score_all():
    """Score every test prediction against ground truth. Returns dict of per-class scored lists."""
    # Test split produced by src/data/preprocess.py (list of sample records).
    with open(PROJECT_ROOT / "data" / "splits" / "test.json") as f:
        test_samples = json.load(f)

    masks_dir = PROJECT_ROOT / "outputs" / "masks"
    scores = {"taping": [], "cracks": []}

    for sample in tqdm(test_samples, desc="Scoring predictions"):
        img_stem = Path(sample["image_path"]).stem
        ds = sample["dataset"]

        # Prediction masks are written by evaluate.py as
        # "<image_stem>__<prompt_slug>.png"; one image may have several
        # predictions (one per prompt variant).
        candidates = list(masks_dir.glob(f"{img_stem}__*.png"))
        if not candidates:
            # No prediction file for this sample — silently excluded from scoring.
            continue

        gt = np.array(Image.open(sample["mask_path"]).convert("L"))
        gt_bin = (gt > 127).astype(np.uint8)  # binarize the 0/255 ground-truth mask

        # Keep only the best-scoring prediction (prompt) per image.
        best_iou = -1
        best_pred_path = None
        best_prompt = None
        for pred_path in candidates:
            # Predictions may be stored at a different resolution than the GT
            # mask; NEAREST resize keeps the mask binary.
            pred = np.array(Image.open(pred_path).convert("L").resize(
                (gt.shape[1], gt.shape[0]), Image.NEAREST))
            pred_bin = (pred > 127).astype(np.uint8)
            score = iou(pred_bin, gt_bin)
            if score > best_iou:
                best_iou = score
                best_pred_path = pred_path
                # Recover the human-readable prompt from the filename slug.
                # NOTE(review): assumes neither the image stem nor the prompt
                # slug contains "__" — confirm against evaluate.py's naming.
                best_prompt = pred_path.stem.split("__")[1].replace("_", " ")

        scores[ds].append({
            "image_path": sample["image_path"],
            "mask_path": sample["mask_path"],
            "pred_path": str(best_pred_path),
            "prompt": best_prompt,
            "iou": best_iou,
            "dataset": ds,
        })

    return scores
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def pick_ranked(scores, n_per_class=3, best=True):
    """Select the top-N (best=True) or bottom-N (best=False) samples per class by IoU.

    *scores* is the dict produced by score_all(). Prints a short summary per
    class and returns the selected records as a flat list.
    """
    label = "best" if best else "worst"
    picked = []
    for ds in ["cracks", "taping"]:
        entries = scores[ds]
        if not best:
            # For failure cases, drop zero-IoU samples (no usable prediction)
            # so we showcase genuine mistakes rather than missing outputs.
            entries = [e for e in entries if e["iou"] > 0]
        chosen = sorted(entries, key=lambda e: e["iou"], reverse=best)[:n_per_class]
        picked += chosen

        print(f"\n{ds} {label} {n_per_class}:")
        for r in chosen:
            print(f" IoU={r['iou']:.4f} {Path(r['image_path']).name} \"{r['prompt']}\"")

    return picked
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def generate_grid(examples, output_path, title=""):
    """Generate original | ground truth | prediction comparison grid.

    Each element of *examples* is a dict with image_path, mask_path, pred_path,
    prompt, iou and dataset keys (as produced by score_all()). The figure is
    written to *output_path*; nothing is returned.
    """
    n = len(examples)
    fig, axes = plt.subplots(n, 3, figsize=(14, 4.0 * n))
    if n == 1:
        # subplots() returns a 1-D array for a single row; wrap it so the
        # axes[i][j] indexing below works uniformly.
        axes = [axes]

    if title:
        fig.suptitle(title, fontsize=16, fontweight="bold", y=0.998)

    for i, ex in enumerate(examples):
        img = Image.open(ex["image_path"]).convert("RGB")
        gt = Image.open(ex["mask_path"]).convert("L")
        # Prediction may be stored at a different resolution; NEAREST keeps it binary.
        pred = Image.open(ex["pred_path"]).convert("L").resize(
            (gt.size[0], gt.size[1]), Image.NEAREST)

        label = ex["dataset"].capitalize()

        axes[i][0].imshow(img)
        axes[i][0].set_title(f"Input — {label}", fontsize=11, fontweight="bold")
        axes[i][0].axis("off")

        axes[i][1].imshow(gt, cmap="gray", vmin=0, vmax=255)
        axes[i][1].set_title("Ground Truth", fontsize=11)
        axes[i][1].axis("off")

        axes[i][2].imshow(pred, cmap="gray", vmin=0, vmax=255)
        axes[i][2].set_title(f"Predicted — \"{ex['prompt']}\" (IoU {ex['iou']:.2f})", fontsize=11)
        axes[i][2].axis("off")

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved → {output_path}")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
if __name__ == "__main__":
    figures_dir = PROJECT_ROOT / "reports" / "figures"
    # Fix: create the output directory up front — matplotlib's savefig (called
    # inside generate_grid) does not create missing parent directories and
    # would raise FileNotFoundError on a fresh checkout.
    figures_dir.mkdir(parents=True, exist_ok=True)
    scores = score_all()

    # Best predictions (3 per class)
    best = pick_ranked(scores, n_per_class=3, best=True)
    generate_grid(best, figures_dir / "best_predictions.png",
                  title="Best Test-Set Predictions (by IoU)")

    # Worst predictions (3 per class) — only samples where model actually predicted something
    worst = pick_ranked(scores, n_per_class=3, best=False)
    generate_grid(worst, figures_dir / "failure_cases.png",
                  title="Failure Cases — Worst Test-Set Predictions (by IoU)")
|
src/data/__init__.py
ADDED
|
File without changes
|
src/data/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (163 Bytes). View file
|
|
|
src/data/__pycache__/dataset.cpython-311.pyc
ADDED
|
Binary file (6.37 kB). View file
|
|
|
src/data/__pycache__/download.cpython-311.pyc
ADDED
|
Binary file (3.26 kB). View file
|
|
|
src/data/__pycache__/preprocess.cpython-311.pyc
ADDED
|
Binary file (15.8 kB). View file
|
|
|
src/data/dataset.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch Dataset for CLIPSeg fine-tuning."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
from transformers import CLIPSegProcessor
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DrywallSegDataset(Dataset):
    """Dataset that yields (image, mask, prompt) tuples for CLIPSeg.

    Each record in the split JSON must provide image_path, mask_path, prompts
    (list of synonym strings), dataset name, and original width/height.
    """

    def __init__(self, split_json: str, processor: CLIPSegProcessor, image_size: int = 352):
        with open(split_json) as f:
            self.records = json.load(f)
        self.processor = processor
        # CLIPSeg's native input/output resolution (352 for the base model).
        self.image_size = image_size

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]

        # Load image
        image = Image.open(rec["image_path"]).convert("RGB")

        # Load mask and resize to CLIPSeg resolution
        mask = Image.open(rec["mask_path"]).convert("L")
        mask = mask.resize((self.image_size, self.image_size), Image.NEAREST)
        # Masks are written as 0/255 by preprocessing, so dividing by 255 yields
        # {0.0, 1.0} (NEAREST resize introduces no intermediate values).
        mask_tensor = torch.from_numpy(np.array(mask)).float() / 255.0  # {0.0, 1.0}

        # Random prompt synonym — acts as text-side augmentation across epochs.
        prompt = random.choice(rec["prompts"])

        # Process through CLIPSeg processor
        inputs = self.processor(
            text=[prompt],
            images=[image],
            return_tensors="pt",
            padding=True,
        )

        # squeeze(0) drops the batch dim added by the processor; per-item text
        # lengths may still differ, so collate_fn pads them at batch time.
        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": mask_tensor,
            "dataset": rec["dataset"],
            "image_path": rec["image_path"],
            "mask_path": rec["mask_path"],
            "prompt": prompt,
            "orig_width": rec["width"],
            "orig_height": rec["height"],
        }
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def collate_fn(batch):
    """Collate dataset items into a batch, zero-padding text tensors.

    input_ids and attention_mask are right-padded with zeros to the longest
    sequence in the batch; image tensors and labels are stacked directly, and
    per-sample metadata is carried through as plain lists.
    """
    target_len = max(item["input_ids"].shape[0] for item in batch)

    def _pad_to_target(tensor):
        # Right-pad a 1-D tensor with zeros up to target_len (no-op if already there).
        shortfall = target_len - tensor.shape[0]
        if shortfall <= 0:
            return tensor
        return torch.cat([tensor, torch.zeros(shortfall, dtype=tensor.dtype)])

    collated = {
        "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
        "input_ids": torch.stack([_pad_to_target(item["input_ids"]) for item in batch]),
        "attention_mask": torch.stack([_pad_to_target(item["attention_mask"]) for item in batch]),
        "labels": torch.stack([item["labels"] for item in batch]),
    }
    # Non-tensor metadata stays as ordinary Python lists.
    for key in ("dataset", "image_path", "mask_path", "prompt", "orig_width", "orig_height"):
        collated[key] = [item[key] for item in batch]
    return collated
|
src/data/download.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset download instructions.
|
| 2 |
+
|
| 3 |
+
Both datasets must be downloaded manually from Roboflow Universe in COCO format.
|
| 4 |
+
The Roboflow API cannot be used because the cracks dataset (cracks-3ii36) has
|
| 5 |
+
no generated versions — the owner never created an exportable version.
|
| 6 |
+
|
| 7 |
+
Download locations:
|
| 8 |
+
- Taping: https://universe.roboflow.com/objectdetect-pu6rn/drywall-join-detect
|
| 9 |
+
→ Export as COCO, place under data/raw/taping/
|
| 10 |
+
- Cracks: https://universe.roboflow.com/fyp-ny1jt/cracks-3ii36
|
| 11 |
+
→ Export as COCO, place under data/raw/cracks/
|
| 12 |
+
|
| 13 |
+
Expected structure after download:
|
| 14 |
+
data/raw/
|
| 15 |
+
├── taping/
|
| 16 |
+
│ ├── train/
|
| 17 |
+
│ │ ├── _annotations.coco.json
|
| 18 |
+
│ │ └── *.jpg
|
| 19 |
+
│ └── valid/
|
| 20 |
+
│ ├── _annotations.coco.json
|
| 21 |
+
│ └── *.jpg
|
| 22 |
+
└── cracks/
|
| 23 |
+
└── train/
|
| 24 |
+
├── _annotations.coco.json
|
| 25 |
+
└── *.jpg
|
| 26 |
+
"""
|
src/data/preprocess.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inspect annotations, generate masks, create train/val/test splits."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from pycocotools.coco import COCO
|
| 10 |
+
from pycocotools import mask as mask_utils
|
| 11 |
+
|
| 12 |
+
# Repository data directories, resolved relative to this file
# (src/data/preprocess.py → project root is two levels up).
_DATA_DIR = Path(__file__).resolve().parents[2] / "data"
RAW_DIR = _DATA_DIR / "raw"
PROCESSED_DIR = _DATA_DIR / "processed"
SPLITS_DIR = _DATA_DIR / "splits"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def inspect_dataset(coco_json_path: str) -> dict:
    """Summarise the annotation types present in a COCO JSON file.

    Counts annotations carrying a usable segmentation (polygon with at least
    3 points, or an RLE dict) versus bbox-only annotations, and reports which
    kind dominates, plus image count and category names.
    """
    with open(coco_json_path) as f:
        payload = json.load(f)

    annotations = payload.get("annotations", [])
    seg_count = 0
    bbox_only_count = 0
    for ann in annotations:
        seg = ann.get("segmentation")
        # Polygon form: non-empty list whose first polygon has >= 6 coords.
        is_polygon = bool(seg) and isinstance(seg, list) and len(seg) > 0 and len(seg[0]) >= 6
        # RLE form: a dict (typically {"counts": ..., "size": ...}).
        is_rle = bool(seg) and isinstance(seg, dict)
        if is_polygon or is_rle:
            seg_count += 1
        else:
            bbox_only_count += 1

    return {
        "total_annotations": len(annotations),
        "total_images": len(payload.get("images", [])),
        "has_segmentation": seg_count,
        "has_bbox_only": bbox_only_count,
        "annotation_type": "segmentation" if seg_count > bbox_only_count else "bbox_only",
        "categories": [c["name"] for c in payload.get("categories", [])],
    }
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def render_masks_from_coco(coco_json_path: str, images_dir: str, output_dir: str) -> list[dict]:
    """Render binary masks from COCO polygon/RLE annotations.

    Returns list of {image_path, mask_path, image_id, width, height}.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    coco = COCO(coco_json_path)
    records = []

    for img_id in sorted(coco.getImgIds()):
        img_info = coco.loadImgs(img_id)[0]
        h, w = img_info["height"], img_info["width"]

        ann_ids = coco.getAnnIds(imgIds=img_id)
        anns = coco.loadAnns(ann_ids)

        # Images without annotations produce no mask (and no record).
        if not anns:
            continue

        # Merge all annotations into one binary mask
        combined = np.zeros((h, w), dtype=np.uint8)
        for ann in anns:
            seg = ann.get("segmentation")
            # Skip annotations with empty or invalid segmentation
            if not seg:
                continue
            # Nested-list polygon with fewer than 3 points (6 coords) is degenerate.
            if isinstance(seg, list) and (len(seg) == 0 or (len(seg) > 0 and isinstance(seg[0], list) and len(seg[0]) < 6)):
                continue
            # Flat coordinate list shorter than 6 values is likewise degenerate.
            if isinstance(seg, list) and len(seg) > 0 and not isinstance(seg[0], list) and len(seg) < 6:
                continue
            try:
                # annToRLE handles both polygon and RLE encodings uniformly.
                rle = coco.annToRLE(ann)
                m = mask_utils.decode(rle)
                combined = np.maximum(combined, m)
            except (IndexError, ValueError):
                # Fall back to bbox if segmentation decode fails
                if "bbox" in ann:
                    x, y, bw, bh = [int(v) for v in ann["bbox"]]
                    # NOTE(review): negative bbox coordinates would wrap with
                    # numpy slicing — assumed non-negative in these exports; verify.
                    combined[y:y+bh, x:x+bw] = 1

        # Save the merged mask as a 0/255 grayscale PNG with a stable name.
        mask_img = Image.fromarray(combined * 255, mode="L")
        mask_name = Path(img_info["file_name"]).stem + "_mask.png"
        mask_path = output_dir / mask_name
        mask_img.save(mask_path)

        image_path = Path(images_dir) / img_info["file_name"]
        records.append({
            "image_path": str(image_path),
            "mask_path": str(mask_path),
            "image_id": img_id,
            "width": w,
            "height": h,
        })

    return records
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def render_masks_from_bboxes(coco_json_path: str, images_dir: str, output_dir: str) -> list[dict]:
    """Create filled-rectangle masks from bounding boxes (fallback when no segmentation)."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(coco_json_path) as f:
        data = json.load(f)

    # Index images by id and group annotations per image for one-pass rendering.
    img_lookup = {img["id"]: img for img in data["images"]}
    anns_by_img: dict[int, list] = {}
    for ann in data["annotations"]:
        anns_by_img.setdefault(ann["image_id"], []).append(ann)

    records = []
    for img_id, img_info in sorted(img_lookup.items()):
        anns = anns_by_img.get(img_id, [])
        # Images without annotations produce no mask (and no record).
        if not anns:
            continue

        h, w = img_info["height"], img_info["width"]
        combined = np.zeros((h, w), dtype=np.uint8)

        # Foreground is the union of all bbox rectangles.
        for ann in anns:
            x, y, bw, bh = [int(v) for v in ann["bbox"]]
            combined[y:y+bh, x:x+bw] = 1

        # Save as a 0/255 grayscale PNG mirroring the segmentation renderer's naming.
        mask_img = Image.fromarray(combined * 255, mode="L")
        mask_name = Path(img_info["file_name"]).stem + "_mask.png"
        mask_path = output_dir / mask_name
        mask_img.save(mask_path)

        image_path = Path(images_dir) / img_info["file_name"]
        records.append({
            "image_path": str(image_path),
            "mask_path": str(mask_path),
            "image_id": img_id,
            "width": w,
            "height": h,
        })

    return records
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def find_coco_json(dataset_dir: Path) -> tuple[str, str] | None:
|
| 148 |
+
"""Find the COCO JSON and images directory in a Roboflow download."""
|
| 149 |
+
for split in ["train", "valid", "test"]:
|
| 150 |
+
json_path = dataset_dir / split / "_annotations.coco.json"
|
| 151 |
+
if json_path.exists():
|
| 152 |
+
return str(json_path), str(dataset_dir / split)
|
| 153 |
+
# Single-folder layout
|
| 154 |
+
for json_path in dataset_dir.rglob("_annotations.coco.json"):
|
| 155 |
+
return str(json_path), str(json_path.parent)
|
| 156 |
+
return None
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def process_dataset(name: str, dataset_dir: Path, prompt_synonyms: list[str]) -> list[dict]:
    """Process a single dataset: inspect, render masks, return records with prompts."""
    records = []
    mask_dir = PROCESSED_DIR / name / "masks"

    # Process each split folder (train/valid/test from Roboflow)
    for split_dir in sorted(dataset_dir.iterdir()):
        if not split_dir.is_dir():
            continue
        json_path = split_dir / "_annotations.coco.json"
        if not json_path.exists():
            continue

        print(f"\n Processing {name}/{split_dir.name}...")
        info = inspect_dataset(str(json_path))
        print(f" Images: {info['total_images']}, Annotations: {info['total_annotations']}")
        print(f" Type: {info['annotation_type']}, Categories: {info['categories']}")

        split_mask_dir = mask_dir / split_dir.name
        # Prefer true segmentation masks; fall back to filled bbox rectangles
        # when the export carries no polygons/RLE.
        if info["annotation_type"] == "segmentation":
            split_records = render_masks_from_coco(
                str(json_path), str(split_dir), str(split_mask_dir)
            )
        else:
            print(f" WARNING: bbox-only annotations, using filled rectangles")
            split_records = render_masks_from_bboxes(
                str(json_path), str(split_dir), str(split_mask_dir)
            )

        for r in split_records:
            r["dataset"] = name
            # NOTE(review): every record shares the same prompts list object —
            # fine for JSON dumping, but mutating it later would affect all records.
            r["prompts"] = prompt_synonyms
        records.extend(split_records)

    return records
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def create_splits(records: list[dict], ratios: tuple = (0.70, 0.15, 0.15), seed: int = 42):
    """Split records into train/val/test, stratified by dataset.

    Writes each split as JSON into SPLITS_DIR and returns the split dict.
    """
    random.seed(seed)  # deterministic shuffles for reproducible splits

    # Group records per source dataset so each gets the same split ratios.
    by_dataset: dict[str, list] = {}
    for r in records:
        by_dataset.setdefault(r["dataset"], []).append(r)

    train, val, test = [], [], []
    for name, recs in by_dataset.items():
        random.shuffle(recs)
        n = len(recs)
        n_train = int(n * ratios[0])
        n_val = int(n * ratios[1])
        train.extend(recs[:n_train])
        val.extend(recs[n_train:n_train + n_val])
        # Remainder goes to test, so integer rounding never drops a sample.
        test.extend(recs[n_train + n_val:])

    # Shuffle again so datasets are interleaved rather than concatenated.
    random.shuffle(train)
    random.shuffle(val)
    random.shuffle(test)

    SPLITS_DIR.mkdir(parents=True, exist_ok=True)
    for split_name, split_data in [("train", train), ("val", val), ("test", test)]:
        path = SPLITS_DIR / f"{split_name}.json"
        with open(path, "w") as f:
            json.dump(split_data, f, indent=2)
        print(f" {split_name}: {len(split_data)} samples -> {path}")

    return {"train": train, "val": val, "test": test}
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def run(config: dict):
    """Run full preprocessing pipeline.

    Expects config["data"]["prompt_synonyms"] (mapping of dataset name to a
    list of prompt strings), config["data"]["split_ratios"], and config["seed"].
    Returns the {"train": ..., "val": ..., "test": ...} split dict.
    """
    synonyms = config["data"]["prompt_synonyms"]
    ratios = tuple(config["data"]["split_ratios"])

    all_records = []
    for name in ["taping", "cracks"]:
        dataset_dir = RAW_DIR / name
        if not dataset_dir.exists():
            # Missing downloads are tolerated — see src/data/download.py for
            # the manual download instructions.
            print(f"WARNING: {dataset_dir} not found, skipping {name}")
            continue
        print(f"\n{'='*60}")
        print(f"Processing dataset: {name}")
        print(f"{'='*60}")
        records = process_dataset(name, dataset_dir, synonyms[name])
        all_records.extend(records)
        print(f" Total records for {name}: {len(records)}")

    print(f"\n{'='*60}")
    print(f"Creating splits (total: {len(all_records)} records)")
    print(f"{'='*60}")
    splits = create_splits(all_records, ratios=ratios, seed=config["seed"])
    return splits
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
    # Standalone entry point: load the training config and run the full pipeline.
    import yaml
    config_path = Path(__file__).resolve().parents[2] / "configs" / "train_config.yaml"
    with open(config_path) as f:
        config = yaml.safe_load(f)
    run(config)
|
src/evaluate.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluate trained CLIPSeg model and generate prediction masks + visuals."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import matplotlib
|
| 8 |
+
matplotlib.use("Agg")
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import yaml
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
from src.data.dataset import DrywallSegDataset, collate_fn
|
| 18 |
+
from src.model.clipseg_wrapper import load_model_and_processor
|
| 19 |
+
from src.train import compute_metrics, get_device
|
| 20 |
+
|
| 21 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def evaluate(config_path: str | None = None):
    """Evaluate the best checkpoint on the test split.

    Writes per-image prediction masks to outputs/masks/, aggregated metrics to
    outputs/logs/test_results.json, and a visual comparison figure to
    reports/figures/. Returns the results dict.
    """
    config_path = config_path or str(PROJECT_ROOT / "configs" / "train_config.yaml")
    with open(config_path) as f:
        config = yaml.safe_load(f)

    device = get_device()
    threshold = config["evaluation"]["threshold"]  # sigmoid cutoff for binarizing logits

    # Load model with best checkpoint
    model, processor = load_model_and_processor(config["model"]["name"], config["model"]["freeze_backbone"])
    ckpt_path = PROJECT_ROOT / "outputs" / "checkpoints" / "best_model.pt"
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
    model = model.to(device)
    model.eval()

    # Model size (parameter storage only, in MB)
    model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)

    # Test data
    splits_dir = PROJECT_ROOT / "data" / "splits"
    test_ds = DrywallSegDataset(str(splits_dir / "test.json"), processor, config["data"]["image_size"])
    test_loader = DataLoader(test_ds, batch_size=config["training"]["batch_size"], shuffle=False,
                             collate_fn=collate_fn, num_workers=0)

    # Run evaluation
    masks_dir = PROJECT_ROOT / "outputs" / "masks"
    masks_dir.mkdir(parents=True, exist_ok=True)

    all_metrics = {"taping": {"miou": [], "dice": []}, "cracks": {"miou": [], "dice": []}}
    inference_times = []
    visual_examples = []  # Collect for visualization
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # NOTE(review): wall-clock timing around a forward pass under-counts
            # on CUDA (kernels launch asynchronously) — confirm this is
            # acceptable for the reported inference numbers.
            t0 = time.time()
            outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
            inference_times.append((time.time() - t0) / pixel_values.size(0))

            logits = outputs.logits
            metrics = compute_metrics(logits, labels, threshold)
            preds = (torch.sigmoid(logits) > threshold).cpu().numpy().astype(np.uint8)

            for i in range(pixel_values.size(0)):
                ds_name = batch["dataset"][i]
                # NOTE(review): compute_metrics() runs once per batch, and the
                # same batch-level miou/dice is appended for every sample in the
                # batch regardless of its class — per-class averages are thus
                # approximations whenever a batch mixes classes; verify intent.
                all_metrics[ds_name]["miou"].append(metrics["miou"])
                all_metrics[ds_name]["dice"].append(metrics["dice"])

                # Save prediction mask at original resolution
                orig_w, orig_h = batch["orig_width"][i], batch["orig_height"][i]
                pred_mask = Image.fromarray(preds[i] * 255, mode="L")
                pred_mask = pred_mask.resize((orig_w, orig_h), Image.NEAREST)

                # Filename encodes image stem + prompt so best_predictions.py
                # can recover the prompt later.
                prompt_slug = batch["prompt"][i].replace(" ", "_")
                img_stem = Path(batch["image_path"][i]).stem
                mask_filename = f"{img_stem}__{prompt_slug}.png"
                pred_mask.save(masks_dir / mask_filename)

                total_samples += 1

                # Collect visual examples
                if len(visual_examples) < config["evaluation"]["num_visual_examples"]:
                    visual_examples.append({
                        "image_path": batch["image_path"][i],
                        "mask_path": batch["mask_path"][i],
                        "pred_mask": preds[i],
                        "prompt": batch["prompt"][i],
                        "dataset": ds_name,
                    })

    # Aggregate metrics
    results = {"per_class": {}, "overall": {}}
    all_miou, all_dice = [], []
    for ds_name in ["taping", "cracks"]:
        m = all_metrics[ds_name]
        if m["miou"]:
            results["per_class"][ds_name] = {
                "miou": round(float(np.mean(m["miou"])), 4),
                "dice": round(float(np.mean(m["dice"])), 4),
                "samples": len(m["miou"]),
            }
            all_miou.extend(m["miou"])
            all_dice.extend(m["dice"])

    results["overall"] = {
        "miou": round(float(np.mean(all_miou)), 4) if all_miou else 0,
        "dice": round(float(np.mean(all_dice)), 4) if all_dice else 0,
        "total_samples": total_samples,
    }
    results["runtime"] = {
        "avg_inference_ms": round(float(np.mean(inference_times)) * 1000, 1),
        "model_size_mb": round(model_size_mb, 1),
    }

    # Save results
    log_dir = PROJECT_ROOT / "outputs" / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    with open(log_dir / "test_results.json", "w") as f:
        json.dump(results, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Test Results")
    print(f"{'='*60}")
    for ds_name, m in results["per_class"].items():
        print(f" {ds_name:>10s}: mIoU={m['miou']:.4f} Dice={m['dice']:.4f} (n={m['samples']})")
    print(f" {'overall':>10s}: mIoU={results['overall']['miou']:.4f} Dice={results['overall']['dice']:.4f}")
    print(f" Avg inference: {results['runtime']['avg_inference_ms']:.1f} ms/image")
    print(f" Model size: {results['runtime']['model_size_mb']:.1f} MB")

    # Generate visual comparison figures
    _generate_visuals(visual_examples, PROJECT_ROOT / "reports" / "figures")

    return results
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _generate_visuals(examples: list[dict], output_dir: Path):
    """Render a side-by-side original | GT | prediction panel per example."""
    output_dir.mkdir(parents=True, exist_ok=True)

    if not examples:
        return

    n_rows = len(examples)
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows))
    if n_rows == 1:
        axes = [axes]  # keep row-style indexing when subplots returns a flat row

    for row, ex in zip(axes, examples):
        original = Image.open(ex["image_path"]).convert("RGB")
        ground_truth = Image.open(ex["mask_path"]).convert("L")
        # pred_mask is assumed binary {0, 1}; scale to full 8-bit range — TODO confirm
        prediction = Image.fromarray(ex["pred_mask"] * 255, mode="L")

        panels = (
            (original, f"Original ({ex['dataset']})", {}),
            (ground_truth, "Ground Truth", {"cmap": "gray", "vmin": 0, "vmax": 255}),
            (prediction, f"Prediction: \"{ex['prompt']}\"", {"cmap": "gray", "vmin": 0, "vmax": 255}),
        )
        for ax, (panel, title, imshow_kwargs) in zip(row, panels):
            ax.imshow(panel, **imshow_kwargs)
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.savefig(output_dir / "visual_comparison.png", dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved visual comparison to {output_dir / 'visual_comparison.png'}")
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
if __name__ == "__main__":
    # Script entry point: run the full test-set evaluation defined above.
    evaluate()
|
src/model/__init__.py
ADDED
|
File without changes
|
src/model/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (164 Bytes). View file
|
|
|
src/model/__pycache__/clipseg_wrapper.cpython-311.pyc
ADDED
|
Binary file (1.89 kB). View file
|
|
|
src/model/__pycache__/losses.cpython-311.pyc
ADDED
|
Binary file (3.34 kB). View file
|
|
|
src/model/clipseg_wrapper.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLIPSeg model loading and freezing utilities."""
|
| 2 |
+
|
| 3 |
+
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def load_model_and_processor(model_name: str = "CIDAS/clipseg-rd64-refined", freeze_backbone: bool = True):
    """Load CLIPSeg model and processor, optionally freezing the backbone.

    When freeze_backbone is True, only parameters whose name contains
    "decoder" keep gradients; everything else is frozen. Parameter counts
    are printed either way.
    """
    model = CLIPSegForImageSegmentation.from_pretrained(model_name)
    processor = CLIPSegProcessor.from_pretrained(model_name)

    if freeze_backbone:
        n_trainable = 0
        n_frozen = 0
        for param_name, param in model.named_parameters():
            is_decoder = "decoder" in param_name
            param.requires_grad = is_decoder
            if is_decoder:
                n_trainable += param.numel()
            else:
                n_frozen += param.numel()
        print(f"Parameters — trainable (decoder): {n_trainable:,} | frozen (backbone): {n_frozen:,}")
    else:
        total = sum(p.numel() for p in model.parameters())
        print(f"Parameters — all trainable: {total:,}")

    return model, processor
|
src/model/losses.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Custom loss functions for segmentation."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DiceLoss(nn.Module):
    """Soft Dice loss operating on logits.

    Returns ``1 - mean(dice)`` where the Dice coefficient is computed
    per-sample on sigmoid probabilities, then averaged over the batch.

    Args:
        smooth: Additive smoothing term; keeps the ratio finite when both
            the prediction and the target are empty.
    """

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the soft Dice loss.

        Args:
            logits: Raw model outputs, shape (B, ...); sigmoid is applied here.
            targets: Masks of the same shape with values in [0, 1].

        Returns:
            Scalar tensor: 1 minus the batch-mean Dice coefficient.
        """
        probs = torch.sigmoid(logits)
        # flatten() instead of view(): view() raises on non-contiguous
        # tensors (e.g. slices or permuted inputs); flatten() handles both.
        probs_flat = probs.flatten(start_dim=1)
        targets_flat = targets.flatten(start_dim=1)

        intersection = (probs_flat * targets_flat).sum(dim=1)
        union = probs_flat.sum(dim=1) + targets_flat.sum(dim=1)

        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - dice.mean()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class BCEDiceLoss(nn.Module):
    """Weighted sum of binary cross-entropy (on logits) and soft Dice loss."""

    def __init__(self, bce_weight: float = 0.5, dice_weight: float = 0.5):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return bce_weight * BCE + dice_weight * Dice for the batch."""
        bce_term = self.bce(logits, targets)
        dice_term = self.dice(logits, targets)
        return self.bce_weight * bce_term + self.dice_weight * dice_term
|
src/predict.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Standalone single-image inference for CLIPSeg."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import yaml
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
from src.model.clipseg_wrapper import load_model_and_processor
|
| 12 |
+
from src.train import get_device
|
| 13 |
+
|
| 14 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def predict(image_path: str, prompt: str, config_path: str | None = None, output_path: str | None = None):
    """Run single-image CLIPSeg inference and save the thresholded mask as a PNG.

    Loads the training config and best checkpoint, segments the image for the
    given text prompt, resizes the binary mask back to the input resolution,
    writes it to *output_path* (auto-derived when omitted) and returns it.
    """
    if not config_path:
        config_path = str(PROJECT_ROOT / "configs" / "train_config.yaml")
    with open(config_path) as fh:
        cfg = yaml.safe_load(fh)

    # Restore the fine-tuned weights onto the best available device.
    device = get_device()
    model, processor = load_model_and_processor(cfg["model"]["name"], cfg["model"]["freeze_backbone"])
    checkpoint = PROJECT_ROOT / "outputs" / "checkpoints" / "best_model.pt"
    model.load_state_dict(torch.load(checkpoint, map_location="cpu", weights_only=True))
    model = model.to(device).eval()

    source = Image.open(image_path).convert("RGB")
    width, height = source.size

    batch = processor(text=[prompt], images=[source], return_tensors="pt", padding=True)
    batch = {key: tensor.to(device) for key, tensor in batch.items()}

    with torch.no_grad():
        logits = model(**batch).logits

    # Threshold sigmoid probabilities into {0, 1}, then restore the original
    # resolution with nearest-neighbour so the mask stays binary.
    binary = (torch.sigmoid(logits[0]) > cfg["evaluation"]["threshold"]).cpu().numpy().astype(np.uint8)
    mask = Image.fromarray(binary * 255, mode="L").resize((width, height), Image.NEAREST)

    if output_path is None:
        stem = Path(image_path).stem
        slug = prompt.replace(" ", "_")
        output_path = str(PROJECT_ROOT / "outputs" / "masks" / f"{stem}__{slug}.png")

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    mask.save(output_path)
    print(f"Saved mask to {output_path}")
    return mask
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if __name__ == "__main__":
    # CLI entry point: python -m src.predict <image> "<prompt>" [--output PATH]
    parser = argparse.ArgumentParser()
    parser.add_argument("image", help="Path to input image")
    parser.add_argument("prompt", help="Text prompt, e.g. 'segment crack'")
    parser.add_argument("--output", help="Output mask path")
    args = parser.parse_args()
    predict(args.image, args.prompt, output_path=args.output)
|
src/train.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training loop for CLIPSeg fine-tuning."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import yaml
|
| 10 |
+
from torch.optim import AdamW
|
| 11 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from src.data.dataset import DrywallSegDataset, collate_fn
|
| 16 |
+
from src.model.clipseg_wrapper import load_model_and_processor
|
| 17 |
+
from src.model.losses import BCEDiceLoss
|
| 18 |
+
|
| 19 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def compute_metrics(logits: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5):
    """Compute batch-mean IoU and Dice from raw logits and binary targets.

    Both inputs are expected to be (B, H, W); logits are thresholded after
    sigmoid, targets are binarized at 0.5.
    """
    eps = 1e-6
    spatial_dims = (1, 2)

    pred_mask = (torch.sigmoid(logits) > threshold).float()
    true_mask = (targets > 0.5).float()

    overlap = (pred_mask * true_mask).sum(dim=spatial_dims)
    pred_area = pred_mask.sum(dim=spatial_dims)
    true_area = true_mask.sum(dim=spatial_dims)

    # eps guards the all-empty case in both numerator and denominator.
    iou = (overlap + eps) / (pred_area + true_area - overlap + eps)
    dice = (2 * overlap + eps) / (pred_area + true_area + eps)

    return {"miou": iou.mean().item(), "dice": dice.mean().item()}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_device():
    """Pick the best available backend, preferring MPS, then CUDA, then CPU."""
    for name, available in (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    ):
        if available():
            return torch.device(name)
    return torch.device("cpu")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def train(config_path: str | None = None):
    """Fine-tune CLIPSeg on the drywall segmentation splits.

    Loads a YAML config, trains with BCE+Dice loss, validates each epoch,
    checkpoints the best val-mIoU model, and early-stops on patience.
    Writes training_history.json and training_summary.json under
    outputs/logs. Returns (model, history).
    """
    config_path = config_path or str(PROJECT_ROOT / "configs" / "train_config.yaml")
    with open(config_path) as f:
        config = yaml.safe_load(f)

    # Seed torch and numpy for reproducibility (recorded in the summary).
    seed = config["seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = get_device()
    print(f"Device: {device}")

    # Model — with freeze_backbone=True only decoder params keep gradients
    # (see load_model_and_processor).
    model, processor = load_model_and_processor(
        config["model"]["name"],
        config["model"]["freeze_backbone"],
    )
    model = model.to(device)

    # Data — JSON split manifests; DrywallSegDataset handles loading/resizing.
    splits_dir = PROJECT_ROOT / "data" / "splits"
    train_ds = DrywallSegDataset(str(splits_dir / "train.json"), processor, config["data"]["image_size"])
    val_ds = DrywallSegDataset(str(splits_dir / "val.json"), processor, config["data"]["image_size"])

    tc = config["training"]
    train_loader = DataLoader(train_ds, batch_size=tc["batch_size"], shuffle=True,
                              collate_fn=collate_fn, num_workers=tc["num_workers"])
    val_loader = DataLoader(val_ds, batch_size=tc["batch_size"], shuffle=False,
                            collate_fn=collate_fn, num_workers=tc["num_workers"])

    # Loss, optimizer, scheduler — optimizer only sees trainable params,
    # so frozen backbone weights carry no optimizer state.
    criterion = BCEDiceLoss(tc["bce_weight"], tc["dice_weight"])
    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=tc["lr"],
        weight_decay=tc["weight_decay"],
    )
    # Cosine anneal over the full configured epoch budget (stepped per epoch).
    scheduler = CosineAnnealingLR(optimizer, T_max=tc["epochs"])

    # Training state
    best_miou = 0.0
    patience_counter = 0
    history = {"train_loss": [], "val_loss": [], "val_miou": [], "val_dice": []}
    ckpt_dir = PROJECT_ROOT / "outputs" / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    log_dir = PROJECT_ROOT / "outputs" / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    start_time = time.time()

    for epoch in range(1, tc["epochs"] + 1):
        # ---- Train ----
        model.train()
        train_losses = []
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{tc['epochs']} [train]", leave=False):
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            logits = outputs.logits
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        scheduler.step()
        avg_train_loss = np.mean(train_losses)

        # ---- Validate ----
        model.eval()
        val_losses, val_mious, val_dices = [], [], []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{tc['epochs']} [val]", leave=False):
                pixel_values = batch["pixel_values"].to(device)
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = model(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
                logits = outputs.logits
                loss = criterion(logits, labels)
                metrics = compute_metrics(logits, labels)

                val_losses.append(loss.item())
                val_mious.append(metrics["miou"])
                val_dices.append(metrics["dice"])

        avg_val_loss = np.mean(val_losses)
        avg_val_miou = np.mean(val_mious)
        avg_val_dice = np.mean(val_dices)

        history["train_loss"].append(float(avg_train_loss))
        history["val_loss"].append(float(avg_val_loss))
        history["val_miou"].append(float(avg_val_miou))
        history["val_dice"].append(float(avg_val_dice))

        print(f"Epoch {epoch:3d} | train_loss={avg_train_loss:.4f} | val_loss={avg_val_loss:.4f} | "
              f"val_mIoU={avg_val_miou:.4f} | val_Dice={avg_val_dice:.4f}")

        # Checkpoint on improvement; otherwise count toward early stopping.
        if avg_val_miou > best_miou:
            best_miou = avg_val_miou
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_dir / "best_model.pt")
            print(f" -> New best mIoU: {best_miou:.4f}, saved checkpoint")
        else:
            patience_counter += 1
            if patience_counter >= tc["patience"]:
                print(f" Early stopping at epoch {epoch} (patience={tc['patience']})")
                break

    total_time = time.time() - start_time

    # Save history & summary. NOTE: `epoch` is the loop variable after the
    # loop ends, so total_epochs reflects early stopping correctly.
    with open(log_dir / "training_history.json", "w") as f:
        json.dump(history, f, indent=2)

    summary = {
        "total_epochs": epoch,
        "best_val_miou": float(best_miou),
        "total_time_seconds": round(total_time, 1),
        "total_time_minutes": round(total_time / 60, 1),
        "device": str(device),
        "train_samples": len(train_ds),
        "val_samples": len(val_ds),
        "seed": seed,
    }
    with open(log_dir / "training_summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\nTraining complete in {summary['total_time_minutes']} min")
    print(f"Best val mIoU: {best_miou:.4f}")
    return model, history
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
if __name__ == "__main__":
    # Script entry point: run training with the default config.
    train()
|