MRiabov committed on
Commit
8743e1d
·
1 Parent(s): efdb6a6

Added precision training option

Browse files
Files changed (2) hide show
  1. configs/default.yaml +1 -0
  2. train.py +30 -22
configs/default.yaml CHANGED
@@ -33,6 +33,7 @@ optim:
33
  weight_decay: 0.01
34
  schedule: poly
35
  power: 1.0
 
36
 
37
  # training housekeeping
38
  seed: 42
 
33
  weight_decay: 0.01
34
  schedule: poly
35
  power: 1.0
36
+ precision: fp32 # one of: fp32, fp16, bf16
37
 
38
  # training housekeeping
39
  seed: 42
train.py CHANGED
@@ -93,13 +93,30 @@ def main():
93
  base_lr = float(cfg["optim"]["lr"]) # 6e-5
94
  weight_decay = float(cfg["optim"]["weight_decay"]) # 0.01
95
  power = float(cfg["optim"]["power"]) # 1.0
96
- amp_flag = bool(cfg["optim"].get("amp", True))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  # Housekeeping
99
  seed = int(cfg.get("seed", 42))
100
  out_dir = cfg.get("out_dir", "runs/wireseghr")
101
- eval_interval = int(cfg.get("eval_interval", 500))
102
- ckpt_interval = int(cfg.get("ckpt_interval", 1000))
103
  os.makedirs(out_dir, exist_ok=True)
104
  set_seed(seed)
105
 
@@ -161,7 +178,7 @@ def main():
161
 
162
  # Optimizer and loss
163
  optim = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
164
- scaler = GradScaler("cuda", enabled=(device.type == "cuda" and amp_flag))
165
  ce = nn.CrossEntropyLoss()
166
 
167
  # Resume
@@ -190,9 +207,7 @@ def main():
190
  imgs, masks, coarse_train, patch_size, sampler, minmax, device
191
  )
192
 
193
- with autocast(
194
- device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
195
- ):
196
  logits_coarse, cond_map = model.forward_coarse(
197
  batch["x_coarse"]
198
  ) # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)
@@ -200,9 +215,7 @@ def main():
200
  # Build fine inputs: crop cond from low-res map to patch, concat with patch RGB+MinMax and loc mask
201
  B, _, hc4, wc4 = cond_map.shape
202
  x_fine = _build_fine_inputs(batch, cond_map, device)
203
- with autocast(
204
- device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
205
- ):
206
  logits_fine = model.forward_fine(x_fine)
207
 
208
  # Targets
@@ -241,7 +254,8 @@ def main():
241
  dset_val,
242
  coarse_train,
243
  device,
244
- amp_flag,
 
245
  prob_thresh,
246
  mm_enable,
247
  mm_kernel,
@@ -284,7 +298,7 @@ def main():
284
  coarse_train,
285
  device,
286
  os.path.join(out_dir, f"test_vis_{step}"),
287
- amp_flag,
288
  mm_enable,
289
  mm_kernel,
290
  prob_thresh,
@@ -549,6 +563,7 @@ def validate(
549
  coarse_size: int,
550
  device: torch.device,
551
  amp_flag: bool,
 
552
  prob_thresh: float,
553
  minmax_enable: bool,
554
  minmax_kernel: int,
@@ -604,9 +619,7 @@ def validate(
604
  )[0]
605
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
606
  x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
607
- with autocast(
608
- device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
609
- ):
610
  logits_c, cond_map = model.forward_coarse(x_t)
611
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
612
  prob_up = (
@@ -669,10 +682,7 @@ def validate(
669
 
670
  x_f_batch = torch.cat(xs_list, dim=0) # Bx6xPxP
671
 
672
- with autocast(
673
- device_type=device.type,
674
- enabled=(device.type == "cuda" and amp_flag),
675
- ):
676
  logits_f = model.forward_fine(x_f_batch)
677
  prob_f = torch.softmax(logits_f, dim=1)[:, 1:2]
678
  prob_f_up = F.interpolate(
@@ -758,9 +768,7 @@ def save_test_visuals(
758
  )[0]
759
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
760
  x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
761
- with autocast(
762
- device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
763
- ):
764
  logits_c, _ = model.forward_coarse(x_t)
765
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
766
  prob_up = (
 
93
  base_lr = float(cfg["optim"]["lr"]) # 6e-5
94
  weight_decay = float(cfg["optim"]["weight_decay"]) # 0.01
95
  power = float(cfg["optim"]["power"]) # 1.0
96
+ precision = str(cfg["optim"].get("precision", "fp32")).lower()
97
+ assert precision in ("fp32", "fp16", "bf16")
98
+ # Enable AMP only when requested and on CUDA
99
+ amp_enabled = (device.type == "cuda") and (precision in ("fp16", "bf16"))
100
+ # Fail fast on unsupported hardware if mixed precision is requested
101
+ if amp_enabled:
102
+ cc_major, cc_minor = torch.cuda.get_device_capability()
103
+ if precision == "fp16":
104
+ assert (
105
+ cc_major >= 7
106
+ ), f"fp16 requires Volta (SM 7.0)+; current SM {cc_major}.{cc_minor}"
107
+ elif precision == "bf16":
108
+ assert (
109
+ cc_major >= 8
110
+ ), f"bf16 requires Ampere (SM 8.0)+; current SM {cc_major}.{cc_minor}"
111
+ amp_dtype = (
112
+ torch.float16 if precision == "fp16" else (torch.bfloat16 if precision == "bf16" else None)
113
+ )
114
 
115
  # Housekeeping
116
  seed = int(cfg.get("seed", 42))
117
  out_dir = cfg.get("out_dir", "runs/wireseghr")
118
+ eval_interval = int(cfg["eval_interval"])
119
+ ckpt_interval = int(cfg["ckpt_interval"])
120
  os.makedirs(out_dir, exist_ok=True)
121
  set_seed(seed)
122
 
 
178
 
179
  # Optimizer and loss
180
  optim = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
181
+ scaler = GradScaler("cuda", enabled=(device.type == "cuda" and precision == "fp16"))
182
  ce = nn.CrossEntropyLoss()
183
 
184
  # Resume
 
207
  imgs, masks, coarse_train, patch_size, sampler, minmax, device
208
  )
209
 
210
+ with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
 
 
211
  logits_coarse, cond_map = model.forward_coarse(
212
  batch["x_coarse"]
213
  ) # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)
 
215
  # Build fine inputs: crop cond from low-res map to patch, concat with patch RGB+MinMax and loc mask
216
  B, _, hc4, wc4 = cond_map.shape
217
  x_fine = _build_fine_inputs(batch, cond_map, device)
218
+ with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
 
 
219
  logits_fine = model.forward_fine(x_fine)
220
 
221
  # Targets
 
254
  dset_val,
255
  coarse_train,
256
  device,
257
+ amp_enabled,
258
+ amp_dtype,
259
  prob_thresh,
260
  mm_enable,
261
  mm_kernel,
 
298
  coarse_train,
299
  device,
300
  os.path.join(out_dir, f"test_vis_{step}"),
301
+ amp_enabled,
302
  mm_enable,
303
  mm_kernel,
304
  prob_thresh,
 
563
  coarse_size: int,
564
  device: torch.device,
565
  amp_flag: bool,
566
+ amp_dtype,
567
  prob_thresh: float,
568
  minmax_enable: bool,
569
  minmax_kernel: int,
 
619
  )[0]
620
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
621
  x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
622
+ with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_flag):
 
 
623
  logits_c, cond_map = model.forward_coarse(x_t)
624
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
625
  prob_up = (
 
682
 
683
  x_f_batch = torch.cat(xs_list, dim=0) # Bx6xPxP
684
 
685
+ with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_flag):
 
 
 
686
  logits_f = model.forward_fine(x_f_batch)
687
  prob_f = torch.softmax(logits_f, dim=1)[:, 1:2]
688
  prob_f_up = F.interpolate(
 
768
  )[0]
769
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
770
  x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
771
+ with autocast(device_type=device.type, dtype=None, enabled=amp_flag):
 
 
772
  logits_c, _ = model.forward_coarse(x_t)
773
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
774
  prob_up = (