to mit_b2 as per paper
Browse files- SEGMENTATION_PLAN.md +2 -2
- configs/default.yaml +2 -2
- src/wireseghr/model/encoder.py +3 -3
- src/wireseghr/model/model.py +1 -1
- tests/test_model_forward.py +1 -1
- train.py +112 -11
SEGMENTATION_PLAN.md
CHANGED
|
@@ -21,7 +21,7 @@ Focus: segmentation only (no dataset collection or inpainting).
|
|
| 21 |
|
| 22 |
## Project Structure
|
| 23 |
- `configs/`
|
| 24 |
-
- `default.yaml` (backbone=
|
| 25 |
- `src/wireseghr/`
|
| 26 |
- `model/`
|
| 27 |
- `encoder.py` (SegFormer MiT-B3, N_in channels expansion)
|
|
@@ -106,7 +106,7 @@ Focus: segmentation only (no dataset collection or inpainting).
|
|
| 106 |
- Ablations: MinMax on/off, MaxPool on/off, conditioning variant (Table `tables/logit.tex`).
|
| 107 |
|
| 108 |
## Configuration Surface (key)
|
| 109 |
-
- Backbone/weights: `
|
| 110 |
- Sizes: `p=768`, `coarse_train=512`, `coarse_test=1024`, `overlap=128`.
|
| 111 |
- Conditioning: `use_binary_location=true`, `cond_from='coarse_logits_1x1'`, `cond_crop='patch'`.
|
| 112 |
- MinMax: `enable=true`, `kernel=6`.
|
|
|
|
| 21 |
|
| 22 |
## Project Structure
|
| 23 |
- `configs/`
|
| 24 |
+
- `default.yaml` (backbone=mit_b2, p=768, coarse_train=512, coarse_test=1024, alpha=0.01, minmax=true, kernel=6, maxpool_label=true, cond_variant=global+binary_mask)
|
| 25 |
- `src/wireseghr/`
|
| 26 |
- `model/`
|
| 27 |
- `encoder.py` (SegFormer MiT-B2, N_in channels expansion)
|
|
|
|
| 106 |
- Ablations: MinMax on/off, MaxPool on/off, conditioning variant (Table `tables/logit.tex`).
|
| 107 |
|
| 108 |
## Configuration Surface (key)
|
| 109 |
+
- Backbone/weights: `mit_b2` (pretrained ImageNet-1K).
|
| 110 |
- Sizes: `p=768`, `coarse_train=512`, `coarse_test=1024`, `overlap=128`.
|
| 111 |
- Conditioning: `use_binary_location=true`, `cond_from='coarse_logits_1x1'`, `cond_crop='patch'`.
|
| 112 |
- MinMax: `enable=true`, `kernel=6`.
|
configs/default.yaml
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# Default configuration for WireSegHR (segmentation-only)
|
| 2 |
-
backbone:
|
| 3 |
pretrained: true # Uses HF SegFormer weights if available; else timm or tiny fallback
|
| 4 |
|
| 5 |
coarse:
|
|
@@ -24,7 +24,7 @@ label:
|
|
| 24 |
|
| 25 |
inference:
|
| 26 |
alpha: 0.01
|
| 27 |
-
prob_threshold: 0.5
|
| 28 |
stitch: avg_logits
|
| 29 |
|
| 30 |
optim:
|
|
|
|
| 1 |
# Default configuration for WireSegHR (segmentation-only)
|
| 2 |
+
backbone: mit_b2
|
| 3 |
pretrained: true # Uses HF SegFormer weights if available; else timm or tiny fallback
|
| 4 |
|
| 5 |
coarse:
|
|
|
|
| 24 |
|
| 25 |
inference:
|
| 26 |
alpha: 0.01
|
| 27 |
+
prob_threshold: 0.3 # changed from 0.5; the paper does not specify a threshold, chosen empirically.
|
| 28 |
stitch: avg_logits
|
| 29 |
|
| 30 |
optim:
|
src/wireseghr/model/encoder.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""SegFormer MiT encoder wrapper with adjustable input channels.
|
| 2 |
|
| 3 |
-
Uses timm to instantiate MiT (e.g.,
|
| 4 |
features [C1, C2, C3, C4].
|
| 5 |
"""
|
| 6 |
|
|
@@ -14,7 +14,7 @@ import timm
|
|
| 14 |
class SegFormerEncoder(nn.Module):
|
| 15 |
def __init__(
|
| 16 |
self,
|
| 17 |
-
backbone: str = "
|
| 18 |
in_channels: int = 7,
|
| 19 |
pretrained: bool = True,
|
| 20 |
out_indices: Tuple[int, int, int, int] = (0, 1, 2, 3),
|
|
@@ -125,7 +125,7 @@ class _HFEncoderWrapper(nn.Module):
|
|
| 125 |
"mit_b0": "nvidia/mit-b0",
|
| 126 |
"mit_b1": "nvidia/mit-b1",
|
| 127 |
"mit_b2": "nvidia/mit-b2",
|
| 128 |
-
"
|
| 129 |
"mit_b4": "nvidia/mit-b4",
|
| 130 |
"mit_b5": "nvidia/mit-b5",
|
| 131 |
}
|
|
|
|
| 1 |
"""SegFormer MiT encoder wrapper with adjustable input channels.
|
| 2 |
|
| 3 |
+
Uses timm to instantiate MiT (e.g., mit_b2) and returns a list of multi-scale
|
| 4 |
features [C1, C2, C3, C4].
|
| 5 |
"""
|
| 6 |
|
|
|
|
| 14 |
class SegFormerEncoder(nn.Module):
|
| 15 |
def __init__(
|
| 16 |
self,
|
| 17 |
+
backbone: str = "mit_b2",
|
| 18 |
in_channels: int = 7,
|
| 19 |
pretrained: bool = True,
|
| 20 |
out_indices: Tuple[int, int, int, int] = (0, 1, 2, 3),
|
|
|
|
| 125 |
"mit_b0": "nvidia/mit-b0",
|
| 126 |
"mit_b1": "nvidia/mit-b1",
|
| 127 |
"mit_b2": "nvidia/mit-b2",
|
| 128 |
+
"mit_b2": "nvidia/mit-b3",
|
| 129 |
"mit_b4": "nvidia/mit-b4",
|
| 130 |
"mit_b5": "nvidia/mit-b5",
|
| 131 |
}
|
src/wireseghr/model/model.py
CHANGED
|
@@ -20,7 +20,7 @@ class WireSegHR(nn.Module):
|
|
| 20 |
"""
|
| 21 |
|
| 22 |
def __init__(
|
| 23 |
-
self, backbone: str = "
|
| 24 |
):
|
| 25 |
super().__init__()
|
| 26 |
self.encoder = SegFormerEncoder(
|
|
|
|
| 20 |
"""
|
| 21 |
|
| 22 |
def __init__(
|
| 23 |
+
self, backbone: str = "mit_b2", in_channels: int = 7, pretrained: bool = True
|
| 24 |
):
|
| 25 |
super().__init__()
|
| 26 |
self.encoder = SegFormerEncoder(
|
tests/test_model_forward.py
CHANGED
|
@@ -5,7 +5,7 @@ from wireseghr.model import WireSegHR
|
|
| 5 |
|
| 6 |
def test_wireseghr_forward_shapes():
|
| 7 |
# Use small input to keep test light and avoid downloading weights
|
| 8 |
-
model = WireSegHR(backbone="
|
| 9 |
|
| 10 |
x = torch.randn(1, 3, 64, 64)
|
| 11 |
logits_coarse, cond = model.forward_coarse(x)
|
|
|
|
| 5 |
|
| 6 |
def test_wireseghr_forward_shapes():
|
| 7 |
# Use small input to keep test light and avoid downloading weights
|
| 8 |
+
model = WireSegHR(backbone="mit_b2", in_channels=3, pretrained=False)
|
| 9 |
|
| 10 |
x = torch.randn(1, 3, 64, 64)
|
| 11 |
logits_coarse, cond = model.forward_coarse(x)
|
train.py
CHANGED
|
@@ -8,7 +8,7 @@ import numpy as np
|
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
| 10 |
import torch.nn.functional as F
|
| 11 |
-
from torch.
|
| 12 |
from torch.amp import GradScaler
|
| 13 |
from tqdm import tqdm
|
| 14 |
import random
|
|
@@ -46,6 +46,7 @@ def main():
|
|
| 46 |
# Config
|
| 47 |
coarse_train = int(cfg["coarse"]["train_size"]) # 512
|
| 48 |
patch_size = int(cfg["fine"]["patch_size"]) # 768
|
|
|
|
| 49 |
iters = int(cfg["optim"]["iters"]) # 40000
|
| 50 |
batch_size = int(cfg["optim"]["batch_size"]) # 8
|
| 51 |
base_lr = float(cfg["optim"]["lr"]) # 6e-5
|
|
@@ -126,7 +127,9 @@ def main():
|
|
| 126 |
imgs, masks, coarse_train, patch_size, sampler, minmax, device
|
| 127 |
)
|
| 128 |
|
| 129 |
-
with autocast(
|
|
|
|
|
|
|
| 130 |
logits_coarse, cond_map = model.forward_coarse(
|
| 131 |
batch["x_coarse"]
|
| 132 |
) # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)
|
|
@@ -134,7 +137,9 @@ def main():
|
|
| 134 |
# Build fine inputs: crop cond from low-res map to patch, concat with patch RGB+MinMax and loc mask
|
| 135 |
B, _, hc4, wc4 = cond_map.shape
|
| 136 |
x_fine = _build_fine_inputs(batch, cond_map, device)
|
| 137 |
-
with autocast(
|
|
|
|
|
|
|
| 138 |
logits_fine = model.forward_fine(x_fine)
|
| 139 |
|
| 140 |
# Targets
|
|
@@ -174,9 +179,14 @@ def main():
|
|
| 174 |
prob_thresh,
|
| 175 |
mm_enable,
|
| 176 |
mm_kernel,
|
|
|
|
|
|
|
| 177 |
)
|
| 178 |
print(
|
| 179 |
-
f"[Val @ {step}]
|
|
|
|
|
|
|
|
|
|
| 180 |
)
|
| 181 |
# Save best
|
| 182 |
if val_stats["f1"] > best_f1:
|
|
@@ -481,10 +491,13 @@ def validate(
|
|
| 481 |
prob_thresh: float,
|
| 482 |
minmax_enable: bool,
|
| 483 |
minmax_kernel: int,
|
|
|
|
|
|
|
| 484 |
) -> Dict[str, float]:
|
| 485 |
# Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
|
| 486 |
model = model.to(device)
|
| 487 |
metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
|
|
|
|
| 488 |
n = 0
|
| 489 |
for i in range(len(dset_val)):
|
| 490 |
item = dset_val[i]
|
|
@@ -529,8 +542,10 @@ def validate(
|
|
| 529 |
)[0]
|
| 530 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 531 |
x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
|
| 532 |
-
with autocast(
|
| 533 |
-
|
|
|
|
|
|
|
| 534 |
prob = torch.softmax(logits_c, dim=1)[:, 1:2]
|
| 535 |
prob_up = (
|
| 536 |
F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
|
|
@@ -538,14 +553,98 @@ def validate(
|
|
| 538 |
.cpu()
|
| 539 |
.numpy()
|
| 540 |
)
|
| 541 |
-
|
| 542 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
for k in metrics_sum:
|
| 544 |
-
metrics_sum[k] +=
|
| 545 |
n += 1
|
| 546 |
if n == 0:
|
| 547 |
return {k: 0.0 for k in metrics_sum}
|
| 548 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 549 |
|
| 550 |
|
| 551 |
@torch.no_grad()
|
|
@@ -602,7 +701,9 @@ def save_test_visuals(
|
|
| 602 |
)[0]
|
| 603 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 604 |
x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
|
| 605 |
-
with autocast(
|
|
|
|
|
|
|
| 606 |
logits_c, _ = model.forward_coarse(x_t)
|
| 607 |
prob = torch.softmax(logits_c, dim=1)[:, 1:2]
|
| 608 |
prob_up = (
|
|
|
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
| 10 |
import torch.nn.functional as F
|
| 11 |
+
from torch.amp import autocast
|
| 12 |
from torch.amp import GradScaler
|
| 13 |
from tqdm import tqdm
|
| 14 |
import random
|
|
|
|
| 46 |
# Config
|
| 47 |
coarse_train = int(cfg["coarse"]["train_size"]) # 512
|
| 48 |
patch_size = int(cfg["fine"]["patch_size"]) # 768
|
| 49 |
+
overlap = int(cfg["fine"]["overlap"]) # e.g., 128
|
| 50 |
iters = int(cfg["optim"]["iters"]) # 40000
|
| 51 |
batch_size = int(cfg["optim"]["batch_size"]) # 8
|
| 52 |
base_lr = float(cfg["optim"]["lr"]) # 6e-5
|
|
|
|
| 127 |
imgs, masks, coarse_train, patch_size, sampler, minmax, device
|
| 128 |
)
|
| 129 |
|
| 130 |
+
with autocast(
|
| 131 |
+
device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
|
| 132 |
+
):
|
| 133 |
logits_coarse, cond_map = model.forward_coarse(
|
| 134 |
batch["x_coarse"]
|
| 135 |
) # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)
|
|
|
|
| 137 |
# Build fine inputs: crop cond from low-res map to patch, concat with patch RGB+MinMax and loc mask
|
| 138 |
B, _, hc4, wc4 = cond_map.shape
|
| 139 |
x_fine = _build_fine_inputs(batch, cond_map, device)
|
| 140 |
+
with autocast(
|
| 141 |
+
device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
|
| 142 |
+
):
|
| 143 |
logits_fine = model.forward_fine(x_fine)
|
| 144 |
|
| 145 |
# Targets
|
|
|
|
| 179 |
prob_thresh,
|
| 180 |
mm_enable,
|
| 181 |
mm_kernel,
|
| 182 |
+
patch_size,
|
| 183 |
+
overlap,
|
| 184 |
)
|
| 185 |
print(
|
| 186 |
+
f"[Val @ {step}][Fine] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
|
| 187 |
+
)
|
| 188 |
+
print(
|
| 189 |
+
f"[Val @ {step}][Coarse] IoU={val_stats['iou_coarse']:.4f} F1={val_stats['f1_coarse']:.4f} P={val_stats['precision_coarse']:.4f} R={val_stats['recall_coarse']:.4f}"
|
| 190 |
)
|
| 191 |
# Save best
|
| 192 |
if val_stats["f1"] > best_f1:
|
|
|
|
| 491 |
prob_thresh: float,
|
| 492 |
minmax_enable: bool,
|
| 493 |
minmax_kernel: int,
|
| 494 |
+
fine_patch_size: int,
|
| 495 |
+
fine_overlap: int,
|
| 496 |
) -> Dict[str, float]:
|
| 497 |
# Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
|
| 498 |
model = model.to(device)
|
| 499 |
metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
|
| 500 |
+
coarse_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
|
| 501 |
n = 0
|
| 502 |
for i in range(len(dset_val)):
|
| 503 |
item = dset_val[i]
|
|
|
|
| 542 |
)[0]
|
| 543 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 544 |
x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
|
| 545 |
+
with autocast(
|
| 546 |
+
device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
|
| 547 |
+
):
|
| 548 |
+
logits_c, cond_map = model.forward_coarse(x_t)
|
| 549 |
prob = torch.softmax(logits_c, dim=1)[:, 1:2]
|
| 550 |
prob_up = (
|
| 551 |
F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
|
|
|
|
| 553 |
.cpu()
|
| 554 |
.numpy()
|
| 555 |
)
|
| 556 |
+
# Coarse metrics
|
| 557 |
+
pred_coarse = (prob_up > prob_thresh).astype(np.uint8)
|
| 558 |
+
m_c = compute_metrics(pred_coarse, mask)
|
| 559 |
+
for k in coarse_sum:
|
| 560 |
+
coarse_sum[k] += m_c[k]
|
| 561 |
+
|
| 562 |
+
# Fine-stage tiled inference and stitching
|
| 563 |
+
P = fine_patch_size
|
| 564 |
+
stride = P - fine_overlap
|
| 565 |
+
assert stride > 0
|
| 566 |
+
assert H >= P and W >= P
|
| 567 |
+
prob_sum = np.zeros((H, W), dtype=np.float32)
|
| 568 |
+
weight = np.zeros((H, W), dtype=np.float32)
|
| 569 |
+
|
| 570 |
+
# Prepare min/max on full-res (already computed above as y_min_full/y_max_full)
|
| 571 |
+
# y_min_full, y_max_full exist in this scope from above branch
|
| 572 |
+
hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
|
| 573 |
+
|
| 574 |
+
ys = list(range(0, max(H - P, 0) + 1, stride))
|
| 575 |
+
if ys[-1] != (H - P):
|
| 576 |
+
ys.append(H - P)
|
| 577 |
+
xs = list(range(0, max(W - P, 0) + 1, stride))
|
| 578 |
+
if xs[-1] != (W - P):
|
| 579 |
+
xs.append(W - P)
|
| 580 |
+
|
| 581 |
+
for y0 in ys:
|
| 582 |
+
for x0 in xs:
|
| 583 |
+
y1, x1 = y0 + P, x0 + P
|
| 584 |
+
# Build fine input 1x7xP x P
|
| 585 |
+
patch_rgb = img[y0:y1, x0:x1, :]
|
| 586 |
+
ymin_patch = y_min_full[0, 0, y0:y1, x0:x1].detach().cpu().numpy()
|
| 587 |
+
ymax_patch = y_max_full[0, 0, y0:y1, x0:x1].detach().cpu().numpy()
|
| 588 |
+
loc = np.ones((P, P), dtype=np.float32)
|
| 589 |
+
|
| 590 |
+
# Cond crop mapping (same as training _build_fine_inputs)
|
| 591 |
+
y0c = (y0 * hc4) // H
|
| 592 |
+
y1c = ((y1 * hc4) + H - 1) // H
|
| 593 |
+
x0c = (x0 * wc4) // W
|
| 594 |
+
x1c = ((x1 * wc4) + W - 1) // W
|
| 595 |
+
cond_sub = cond_map[:, :, y0c:y1c, x0c:x1c].float()
|
| 596 |
+
cond_patch = F.interpolate(
|
| 597 |
+
cond_sub, size=(P, P), mode="bilinear", align_corners=False
|
| 598 |
+
).squeeze(1) # 1xPxP
|
| 599 |
+
|
| 600 |
+
rgb_t = (
|
| 601 |
+
torch.from_numpy(np.transpose(patch_rgb, (2, 0, 1)))
|
| 602 |
+
.to(device)
|
| 603 |
+
.float()
|
| 604 |
+
) # 3xPxP
|
| 605 |
+
ymin_t = torch.from_numpy(ymin_patch)[None, ...].to(device).float()
|
| 606 |
+
ymax_t = torch.from_numpy(ymax_patch)[None, ...].to(device).float()
|
| 607 |
+
loc_t = torch.from_numpy(loc)[None, ...].to(device).float()
|
| 608 |
+
x_f = torch.cat(
|
| 609 |
+
[rgb_t, ymin_t, ymax_t, cond_patch, loc_t], dim=0
|
| 610 |
+
).unsqueeze(0)
|
| 611 |
+
|
| 612 |
+
with autocast(
|
| 613 |
+
device_type=device.type,
|
| 614 |
+
enabled=(device.type == "cuda" and amp_flag),
|
| 615 |
+
):
|
| 616 |
+
logits_f = model.forward_fine(x_f)
|
| 617 |
+
prob_f = torch.softmax(logits_f, dim=1)[:, 1:2]
|
| 618 |
+
prob_f_up = (
|
| 619 |
+
F.interpolate(
|
| 620 |
+
prob_f, size=(P, P), mode="bilinear", align_corners=False
|
| 621 |
+
)[0, 0]
|
| 622 |
+
.detach()
|
| 623 |
+
.cpu()
|
| 624 |
+
.numpy()
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
prob_sum[y0:y1, x0:x1] += prob_f_up
|
| 628 |
+
weight[y0:y1, x0:x1] += 1.0
|
| 629 |
+
|
| 630 |
+
prob_full = prob_sum / weight
|
| 631 |
+
pred_fine = (prob_full > prob_thresh).astype(np.uint8)
|
| 632 |
+
m_f = compute_metrics(pred_fine, mask)
|
| 633 |
for k in metrics_sum:
|
| 634 |
+
metrics_sum[k] += m_f[k]
|
| 635 |
n += 1
|
| 636 |
if n == 0:
|
| 637 |
return {k: 0.0 for k in metrics_sum}
|
| 638 |
+
out = {k: v / float(n) for k, v in metrics_sum.items()}
|
| 639 |
+
out.update(
|
| 640 |
+
{
|
| 641 |
+
"iou_coarse": coarse_sum["iou"] / float(n),
|
| 642 |
+
"f1_coarse": coarse_sum["f1"] / float(n),
|
| 643 |
+
"precision_coarse": coarse_sum["precision"] / float(n),
|
| 644 |
+
"recall_coarse": coarse_sum["recall"] / float(n),
|
| 645 |
+
}
|
| 646 |
+
)
|
| 647 |
+
return out
|
| 648 |
|
| 649 |
|
| 650 |
@torch.no_grad()
|
|
|
|
| 701 |
)[0]
|
| 702 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 703 |
x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
|
| 704 |
+
with autocast(
|
| 705 |
+
device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
|
| 706 |
+
):
|
| 707 |
logits_c, _ = model.forward_coarse(x_t)
|
| 708 |
prob = torch.softmax(logits_c, dim=1)[:, 1:2]
|
| 709 |
prob_up = (
|