(debug) to torch operations in bfloat conversions.
Browse files- configs/default.yaml +5 -5
- infer.py +12 -9
- train.py +3 -3
configs/default.yaml
CHANGED
|
@@ -28,8 +28,8 @@ inference:
|
|
| 28 |
stitch: avg_logits
|
| 29 |
|
| 30 |
eval:
|
| 31 |
-
max_samples:
|
| 32 |
-
fine_batch:
|
| 33 |
|
| 34 |
optim:
|
| 35 |
iters: 2000
|
|
@@ -38,14 +38,14 @@ optim:
|
|
| 38 |
weight_decay: 0.01
|
| 39 |
schedule: poly
|
| 40 |
power: 1.0
|
| 41 |
-
precision:
|
| 42 |
|
| 43 |
# training housekeeping
|
| 44 |
seed: 42
|
| 45 |
out_dir: runs/wireseghr
|
| 46 |
-
eval_interval:
|
| 47 |
ckpt_interval: 300
|
| 48 |
-
resume: runs/wireseghr/ckpt_1800.pt # optional
|
| 49 |
|
| 50 |
# dataset paths (placeholders)
|
| 51 |
data:
|
|
|
|
| 28 |
stitch: avg_logits
|
| 29 |
|
| 30 |
eval:
|
| 31 |
+
max_samples: 12
|
| 32 |
+
fine_batch: 16
|
| 33 |
|
| 34 |
optim:
|
| 35 |
iters: 2000
|
|
|
|
| 38 |
weight_decay: 0.01
|
| 39 |
schedule: poly
|
| 40 |
power: 1.0
|
| 41 |
+
precision: bf16 # one of: fp32, fp16, bf16
|
| 42 |
|
| 43 |
# training housekeeping
|
| 44 |
seed: 42
|
| 45 |
out_dir: runs/wireseghr
|
| 46 |
+
eval_interval: 150
|
| 47 |
ckpt_interval: 300
|
| 48 |
+
# resume: runs/wireseghr/ckpt_1800.pt # optional
|
| 49 |
|
| 50 |
# dataset paths (placeholders)
|
| 51 |
data:
|
infer.py
CHANGED
|
@@ -31,7 +31,7 @@ def _coarse_forward(
|
|
| 31 |
device: torch.device,
|
| 32 |
amp_flag: bool,
|
| 33 |
amp_dtype,
|
| 34 |
-
) -> Tuple[
|
| 35 |
# Convert to tensor on device
|
| 36 |
t_img = (
|
| 37 |
torch.from_numpy(np.transpose(img_rgb, (2, 0, 1)))
|
|
@@ -76,8 +76,8 @@ def _coarse_forward(
|
|
| 76 |
F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
|
| 77 |
.detach()
|
| 78 |
.cpu()
|
| 79 |
-
.
|
| 80 |
-
)
|
| 81 |
return prob_up, cond_map, t_img, y_min_full, y_max_full
|
| 82 |
|
| 83 |
|
|
@@ -94,7 +94,7 @@ def _tiled_fine_forward(
|
|
| 94 |
device: torch.device,
|
| 95 |
amp_flag: bool,
|
| 96 |
amp_dtype,
|
| 97 |
-
) ->
|
| 98 |
H = int(t_img.shape[2])
|
| 99 |
W = int(t_img.shape[3])
|
| 100 |
P = patch_size
|
|
@@ -153,8 +153,8 @@ def _tiled_fine_forward(
|
|
| 153 |
prob_sum_t[y0:y1, x0:x1] += prob_f_up[bi]
|
| 154 |
weight_t[y0:y1, x0:x1] += 1.0
|
| 155 |
|
| 156 |
-
prob_full = (prob_sum_t / weight_t).detach().cpu().
|
| 157 |
-
return prob_full
|
| 158 |
|
| 159 |
|
| 160 |
def _build_model_from_cfg(cfg: dict, device: torch.device) -> WireSegHR:
|
|
@@ -216,7 +216,9 @@ def infer_image(
|
|
| 216 |
amp_dtype,
|
| 217 |
)
|
| 218 |
|
| 219 |
-
|
|
|
|
|
|
|
| 220 |
|
| 221 |
if out_dir is not None:
|
| 222 |
os.makedirs(out_dir, exist_ok=True)
|
|
@@ -225,9 +227,10 @@ def infer_image(
|
|
| 225 |
cv2.imwrite(out_mask, pred)
|
| 226 |
if save_prob:
|
| 227 |
out_prob = os.path.join(out_dir, f"{stem}_prob.npy")
|
| 228 |
-
np.save(out_prob, prob_f.
|
| 229 |
|
| 230 |
-
|
|
|
|
| 231 |
|
| 232 |
|
| 233 |
def main():
|
|
|
|
| 31 |
device: torch.device,
|
| 32 |
amp_flag: bool,
|
| 33 |
amp_dtype,
|
| 34 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 35 |
# Convert to tensor on device
|
| 36 |
t_img = (
|
| 37 |
torch.from_numpy(np.transpose(img_rgb, (2, 0, 1)))
|
|
|
|
| 76 |
F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
|
| 77 |
.detach()
|
| 78 |
.cpu()
|
| 79 |
+
.float()
|
| 80 |
+
) # HxW torch.Tensor on CPU
|
| 81 |
return prob_up, cond_map, t_img, y_min_full, y_max_full
|
| 82 |
|
| 83 |
|
|
|
|
| 94 |
device: torch.device,
|
| 95 |
amp_flag: bool,
|
| 96 |
amp_dtype,
|
| 97 |
+
) -> torch.Tensor:
|
| 98 |
H = int(t_img.shape[2])
|
| 99 |
W = int(t_img.shape[3])
|
| 100 |
P = patch_size
|
|
|
|
| 153 |
prob_sum_t[y0:y1, x0:x1] += prob_f_up[bi]
|
| 154 |
weight_t[y0:y1, x0:x1] += 1.0
|
| 155 |
|
| 156 |
+
prob_full = (prob_sum_t / weight_t).detach().cpu().float()
|
| 157 |
+
return prob_full # HxW torch.Tensor on CPU
|
| 158 |
|
| 159 |
|
| 160 |
def _build_model_from_cfg(cfg: dict, device: torch.device) -> WireSegHR:
|
|
|
|
| 216 |
amp_dtype,
|
| 217 |
)
|
| 218 |
|
| 219 |
+
# Threshold with torch on CPU; convert to numpy only for saving/returning
|
| 220 |
+
pred_t = (prob_f > prob_thresh).to(torch.uint8) * 255 # HxW uint8 torch
|
| 221 |
+
pred = pred_t.detach().cpu().numpy()
|
| 222 |
|
| 223 |
if out_dir is not None:
|
| 224 |
os.makedirs(out_dir, exist_ok=True)
|
|
|
|
| 227 |
cv2.imwrite(out_mask, pred)
|
| 228 |
if save_prob:
|
| 229 |
out_prob = os.path.join(out_dir, f"{stem}_prob.npy")
|
| 230 |
+
np.save(out_prob, prob_f.detach().cpu().float().numpy())
|
| 231 |
|
| 232 |
+
# Return numpy arrays for external consumers, computed via torch
|
| 233 |
+
return pred, prob_f.detach().cpu().numpy()
|
| 234 |
|
| 235 |
|
| 236 |
def main():
|
train.py
CHANGED
|
@@ -635,7 +635,7 @@ def validate(
|
|
| 635 |
amp_dtype,
|
| 636 |
)
|
| 637 |
# Coarse metrics
|
| 638 |
-
pred_coarse = (prob_up > prob_thresh).
|
| 639 |
m_c = compute_metrics(pred_coarse, mask)
|
| 640 |
for k in coarse_sum:
|
| 641 |
coarse_sum[k] += m_c[k]
|
|
@@ -664,7 +664,7 @@ def validate(
|
|
| 664 |
if xs[-1] != (W - P):
|
| 665 |
xs.append(W - P)
|
| 666 |
total_tiles += len(ys) * len(xs)
|
| 667 |
-
pred_fine = (prob_full > prob_thresh).
|
| 668 |
m_f = compute_metrics(pred_fine, mask)
|
| 669 |
for k in metrics_sum:
|
| 670 |
metrics_sum[k] += m_f[k]
|
|
@@ -721,7 +721,7 @@ def save_test_visuals(
|
|
| 721 |
bool(amp_flag),
|
| 722 |
None,
|
| 723 |
)
|
| 724 |
-
pred = (prob_up > prob_thresh).
|
| 725 |
# Save input and prediction
|
| 726 |
img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
|
| 727 |
cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)
|
|
|
|
| 635 |
amp_dtype,
|
| 636 |
)
|
| 637 |
# Coarse metrics
|
| 638 |
+
pred_coarse = (prob_up > prob_thresh).to(torch.uint8).cpu().numpy()
|
| 639 |
m_c = compute_metrics(pred_coarse, mask)
|
| 640 |
for k in coarse_sum:
|
| 641 |
coarse_sum[k] += m_c[k]
|
|
|
|
| 664 |
if xs[-1] != (W - P):
|
| 665 |
xs.append(W - P)
|
| 666 |
total_tiles += len(ys) * len(xs)
|
| 667 |
+
pred_fine = (prob_full > prob_thresh).to(torch.uint8).cpu().numpy()
|
| 668 |
m_f = compute_metrics(pred_fine, mask)
|
| 669 |
for k in metrics_sum:
|
| 670 |
metrics_sum[k] += m_f[k]
|
|
|
|
| 721 |
bool(amp_flag),
|
| 722 |
None,
|
| 723 |
)
|
| 724 |
+
pred = ((prob_up > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
|
| 725 |
# Save input and prediction
|
| 726 |
img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
|
| 727 |
cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)
|