antoine.carreaud67 commited on
Commit Β·
da8a8ce
1
Parent(s): a6df640
Update README benchmark numbers
Browse files- README.md +3 -3
- utils/bench.py +1 -2
README.md
CHANGED
|
@@ -312,9 +312,9 @@ python main.py eval --config configs/config_CRAG_aug.yaml --checkpoint weights/
|
|
| 312 |
| Model | mIoU (%) β | mF1 (%) β | mBIoU (%) β | GFLOPs β | FPS β |
|
| 313 |
|----------------------------------------|-----------:|----------:|------------:|---------:|------:|
|
| 314 |
| *RGB Baselines (official)* | | | | | |
|
| 315 |
-
| Swin-T + UPerNet | 62.01 | 75.27 | β |
|
| 316 |
-
| Swin-S + UPerNet | 61.87 | 75.11 | β |
|
| 317 |
-
| Swin-B + UPerNet | 64.05 | 76.88 | β |
|
| 318 |
| Swin-B + UPerNet (retrained) | 64.02 | 76.64 | 32.57 | β | β |
|
| 319 |
| Swin-L + UPerNet | 63.36 | 76.35 | β | 420 | 27.8 |
|
| 320 |
| *Dual-branch baselines* | | | | | |
|
|
|
|
| 312 |
| Model | mIoU (%) β | mF1 (%) β | mBIoU (%) β | GFLOPs β | FPS β |
|
| 313 |
|----------------------------------------|-----------:|----------:|------------:|---------:|------:|
|
| 314 |
| *RGB Baselines (official)* | | | | | |
|
| 315 |
+
| Swin-T + UPerNet | 62.01 | 75.27 | β | 237 | 69.2 |
|
| 316 |
+
| Swin-S + UPerNet | 61.87 | 75.11 | β | 261 | 41.5 |
|
| 317 |
+
| Swin-B + UPerNet | 64.05 | 76.88 | β | 306 | 36.3 |
|
| 318 |
| Swin-B + UPerNet (retrained) | 64.02 | 76.64 | 32.57 | β | β |
|
| 319 |
| Swin-L + UPerNet | 63.36 | 76.35 | β | 420 | 27.8 |
|
| 320 |
| *Dual-branch baselines* | | | | | |
|
utils/bench.py
CHANGED
|
@@ -4,8 +4,6 @@ import torch
|
|
| 4 |
import torch.nn as nn
|
| 5 |
from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count_table
|
| 6 |
|
| 7 |
-
from CASWiT import CASWiT
|
| 8 |
-
|
| 9 |
try:
|
| 10 |
from transformers import UperNetForSemanticSegmentation
|
| 11 |
except ImportError:
|
|
@@ -145,6 +143,7 @@ def main():
|
|
| 145 |
B, H, W = args.batch, args.h, args.w
|
| 146 |
|
| 147 |
if args.model == "caswit":
|
|
|
|
| 148 |
base_model = CASWiT(num_head_xa=1, num_classes=15, model_name="openmmlab/upernet-swin-base").to(device).eval()
|
| 149 |
model_for_flops = ForwardForFlops(base_model, which=args.which).to(device).eval()
|
| 150 |
|
|
|
|
| 4 |
import torch.nn as nn
|
| 5 |
from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count_table
|
| 6 |
|
|
|
|
|
|
|
| 7 |
try:
|
| 8 |
from transformers import UperNetForSemanticSegmentation
|
| 9 |
except ImportError:
|
|
|
|
| 143 |
B, H, W = args.batch, args.h, args.w
|
| 144 |
|
| 145 |
if args.model == "caswit":
|
| 146 |
+
from model.CASWiT_upernet import CASWiT
|
| 147 |
base_model = CASWiT(num_head_xa=1, num_classes=15, model_name="openmmlab/upernet-swin-base").to(device).eval()
|
| 148 |
model_for_flops = ForwardForFlops(base_model, which=args.which).to(device).eval()
|
| 149 |
|