Added precision training option

Files changed:
- configs/default.yaml (+1, -0)
- train.py (+30, -22)
configs/default.yaml
CHANGED
|
@@ -33,6 +33,7 @@ optim:
|
|
| 33 |
weight_decay: 0.01
|
| 34 |
schedule: poly
|
| 35 |
power: 1.0
|
|
|
|
| 36 |
|
| 37 |
# training housekeeping
|
| 38 |
seed: 42
|
|
|
|
| 33 |
weight_decay: 0.01
|
| 34 |
schedule: poly
|
| 35 |
power: 1.0
|
| 36 |
+
precision: fp32 # one of: fp32, fp16, bf16
|
| 37 |
|
| 38 |
# training housekeeping
|
| 39 |
seed: 42
|
train.py
CHANGED
|
@@ -93,13 +93,30 @@ def main():
|
|
| 93 |
base_lr = float(cfg["optim"]["lr"]) # 6e-5
|
| 94 |
weight_decay = float(cfg["optim"]["weight_decay"]) # 0.01
|
| 95 |
power = float(cfg["optim"]["power"]) # 1.0
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
# Housekeeping
|
| 99 |
seed = int(cfg.get("seed", 42))
|
| 100 |
out_dir = cfg.get("out_dir", "runs/wireseghr")
|
| 101 |
-
eval_interval = int(cfg
|
| 102 |
-
ckpt_interval = int(cfg
|
| 103 |
os.makedirs(out_dir, exist_ok=True)
|
| 104 |
set_seed(seed)
|
| 105 |
|
|
@@ -161,7 +178,7 @@ def main():
|
|
| 161 |
|
| 162 |
# Optimizer and loss
|
| 163 |
optim = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
|
| 164 |
-
scaler = GradScaler("cuda", enabled=(device.type == "cuda" and
|
| 165 |
ce = nn.CrossEntropyLoss()
|
| 166 |
|
| 167 |
# Resume
|
|
@@ -190,9 +207,7 @@ def main():
|
|
| 190 |
imgs, masks, coarse_train, patch_size, sampler, minmax, device
|
| 191 |
)
|
| 192 |
|
| 193 |
-
with autocast(
|
| 194 |
-
device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
|
| 195 |
-
):
|
| 196 |
logits_coarse, cond_map = model.forward_coarse(
|
| 197 |
batch["x_coarse"]
|
| 198 |
) # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)
|
|
@@ -200,9 +215,7 @@ def main():
|
|
| 200 |
# Build fine inputs: crop cond from low-res map to patch, concat with patch RGB+MinMax and loc mask
|
| 201 |
B, _, hc4, wc4 = cond_map.shape
|
| 202 |
x_fine = _build_fine_inputs(batch, cond_map, device)
|
| 203 |
-
with autocast(
|
| 204 |
-
device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
|
| 205 |
-
):
|
| 206 |
logits_fine = model.forward_fine(x_fine)
|
| 207 |
|
| 208 |
# Targets
|
|
@@ -241,7 +254,8 @@ def main():
|
|
| 241 |
dset_val,
|
| 242 |
coarse_train,
|
| 243 |
device,
|
| 244 |
-
|
|
|
|
| 245 |
prob_thresh,
|
| 246 |
mm_enable,
|
| 247 |
mm_kernel,
|
|
@@ -284,7 +298,7 @@ def main():
|
|
| 284 |
coarse_train,
|
| 285 |
device,
|
| 286 |
os.path.join(out_dir, f"test_vis_{step}"),
|
| 287 |
-
|
| 288 |
mm_enable,
|
| 289 |
mm_kernel,
|
| 290 |
prob_thresh,
|
|
@@ -549,6 +563,7 @@ def validate(
|
|
| 549 |
coarse_size: int,
|
| 550 |
device: torch.device,
|
| 551 |
amp_flag: bool,
|
|
|
|
| 552 |
prob_thresh: float,
|
| 553 |
minmax_enable: bool,
|
| 554 |
minmax_kernel: int,
|
|
@@ -604,9 +619,7 @@ def validate(
|
|
| 604 |
)[0]
|
| 605 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 606 |
x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
|
| 607 |
-
with autocast(
|
| 608 |
-
device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
|
| 609 |
-
):
|
| 610 |
logits_c, cond_map = model.forward_coarse(x_t)
|
| 611 |
prob = torch.softmax(logits_c, dim=1)[:, 1:2]
|
| 612 |
prob_up = (
|
|
@@ -669,10 +682,7 @@ def validate(
|
|
| 669 |
|
| 670 |
x_f_batch = torch.cat(xs_list, dim=0) # Bx6xPxP
|
| 671 |
|
| 672 |
-
with autocast(
|
| 673 |
-
device_type=device.type,
|
| 674 |
-
enabled=(device.type == "cuda" and amp_flag),
|
| 675 |
-
):
|
| 676 |
logits_f = model.forward_fine(x_f_batch)
|
| 677 |
prob_f = torch.softmax(logits_f, dim=1)[:, 1:2]
|
| 678 |
prob_f_up = F.interpolate(
|
|
@@ -758,9 +768,7 @@ def save_test_visuals(
|
|
| 758 |
)[0]
|
| 759 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 760 |
x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
|
| 761 |
-
with autocast(
|
| 762 |
-
device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
|
| 763 |
-
):
|
| 764 |
logits_c, _ = model.forward_coarse(x_t)
|
| 765 |
prob = torch.softmax(logits_c, dim=1)[:, 1:2]
|
| 766 |
prob_up = (
|
|
|
|
| 93 |
base_lr = float(cfg["optim"]["lr"]) # 6e-5
|
| 94 |
weight_decay = float(cfg["optim"]["weight_decay"]) # 0.01
|
| 95 |
power = float(cfg["optim"]["power"]) # 1.0
|
| 96 |
+
precision = str(cfg["optim"].get("precision", "fp32")).lower()
|
| 97 |
+
assert precision in ("fp32", "fp16", "bf16")
|
| 98 |
+
# Enable AMP only when requested and on CUDA
|
| 99 |
+
amp_enabled = (device.type == "cuda") and (precision in ("fp16", "bf16"))
|
| 100 |
+
# Fail fast on unsupported hardware if mixed precision is requested
|
| 101 |
+
if amp_enabled:
|
| 102 |
+
cc_major, cc_minor = torch.cuda.get_device_capability()
|
| 103 |
+
if precision == "fp16":
|
| 104 |
+
assert (
|
| 105 |
+
cc_major >= 7
|
| 106 |
+
), f"fp16 requires Volta (SM 7.0)+; current SM {cc_major}.{cc_minor}"
|
| 107 |
+
elif precision == "bf16":
|
| 108 |
+
assert (
|
| 109 |
+
cc_major >= 8
|
| 110 |
+
), f"bf16 requires Ampere (SM 8.0)+; current SM {cc_major}.{cc_minor}"
|
| 111 |
+
amp_dtype = (
|
| 112 |
+
torch.float16 if precision == "fp16" else (torch.bfloat16 if precision == "bf16" else None)
|
| 113 |
+
)
|
| 114 |
|
| 115 |
# Housekeeping
|
| 116 |
seed = int(cfg.get("seed", 42))
|
| 117 |
out_dir = cfg.get("out_dir", "runs/wireseghr")
|
| 118 |
+
eval_interval = int(cfg["eval_interval"])
|
| 119 |
+
ckpt_interval = int(cfg["ckpt_interval"])
|
| 120 |
os.makedirs(out_dir, exist_ok=True)
|
| 121 |
set_seed(seed)
|
| 122 |
|
|
|
|
| 178 |
|
| 179 |
# Optimizer and loss
|
| 180 |
optim = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
|
| 181 |
+
scaler = GradScaler("cuda", enabled=(device.type == "cuda" and precision == "fp16"))
|
| 182 |
ce = nn.CrossEntropyLoss()
|
| 183 |
|
| 184 |
# Resume
|
|
|
|
| 207 |
imgs, masks, coarse_train, patch_size, sampler, minmax, device
|
| 208 |
)
|
| 209 |
|
| 210 |
+
with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
|
|
|
|
|
|
|
| 211 |
logits_coarse, cond_map = model.forward_coarse(
|
| 212 |
batch["x_coarse"]
|
| 213 |
) # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)
|
|
|
|
| 215 |
# Build fine inputs: crop cond from low-res map to patch, concat with patch RGB+MinMax and loc mask
|
| 216 |
B, _, hc4, wc4 = cond_map.shape
|
| 217 |
x_fine = _build_fine_inputs(batch, cond_map, device)
|
| 218 |
+
with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
|
|
|
|
|
|
|
| 219 |
logits_fine = model.forward_fine(x_fine)
|
| 220 |
|
| 221 |
# Targets
|
|
|
|
| 254 |
dset_val,
|
| 255 |
coarse_train,
|
| 256 |
device,
|
| 257 |
+
amp_enabled,
|
| 258 |
+
amp_dtype,
|
| 259 |
prob_thresh,
|
| 260 |
mm_enable,
|
| 261 |
mm_kernel,
|
|
|
|
| 298 |
coarse_train,
|
| 299 |
device,
|
| 300 |
os.path.join(out_dir, f"test_vis_{step}"),
|
| 301 |
+
amp_enabled,
|
| 302 |
mm_enable,
|
| 303 |
mm_kernel,
|
| 304 |
prob_thresh,
|
|
|
|
| 563 |
coarse_size: int,
|
| 564 |
device: torch.device,
|
| 565 |
amp_flag: bool,
|
| 566 |
+
amp_dtype,
|
| 567 |
prob_thresh: float,
|
| 568 |
minmax_enable: bool,
|
| 569 |
minmax_kernel: int,
|
|
|
|
| 619 |
)[0]
|
| 620 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 621 |
x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
|
| 622 |
+
with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_flag):
|
|
|
|
|
|
|
| 623 |
logits_c, cond_map = model.forward_coarse(x_t)
|
| 624 |
prob = torch.softmax(logits_c, dim=1)[:, 1:2]
|
| 625 |
prob_up = (
|
|
|
|
| 682 |
|
| 683 |
x_f_batch = torch.cat(xs_list, dim=0) # Bx6xPxP
|
| 684 |
|
| 685 |
+
with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_flag):
|
|
|
|
|
|
|
|
|
|
| 686 |
logits_f = model.forward_fine(x_f_batch)
|
| 687 |
prob_f = torch.softmax(logits_f, dim=1)[:, 1:2]
|
| 688 |
prob_f_up = F.interpolate(
|
|
|
|
| 768 |
)[0]
|
| 769 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 770 |
x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
|
| 771 |
+
with autocast(device_type=device.type, dtype=None, enabled=amp_flag):
|
|
|
|
|
|
|
| 772 |
logits_c, _ = model.forward_coarse(x_t)
|
| 773 |
prob = torch.softmax(logits_c, dim=1)[:, 1:2]
|
| 774 |
prob_up = (
|