Image Segmentation
English
antoine.carreaud67 commited on
Commit
da8a8ce
Β·
1 Parent(s): a6df640

Update README benchmark numbers

Browse files
Files changed (2) hide show
  1. README.md +3 -3
  2. utils/bench.py +1 -2
README.md CHANGED
@@ -312,9 +312,9 @@ python main.py eval --config configs/config_CRAG_aug.yaml --checkpoint weights/
312
  | Model | mIoU (%) ↑ | mF1 (%) ↑ | mBIoU (%) ↑ | GFLOPs ↓ | FPS ↑ |
313
  |----------------------------------------|-----------:|----------:|------------:|---------:|------:|
314
  | *RGB Baselines (official)* | | | | | |
315
- | Swin-T + UPerNet | 62.01 | 75.27 | – | – | – |
316
- | Swin-S + UPerNet | 61.87 | 75.11 | – | – | – |
317
- | Swin-B + UPerNet | 64.05 | 76.88 | – | – | – |
318
  | Swin-B + UPerNet (retrained) | 64.02 | 76.64 | 32.57 | – | – |
319
  | Swin-L + UPerNet | 63.36 | 76.35 | – | 420 | 27.8 |
320
  | *Dual-branch baselines* | | | | | |
 
312
  | Model | mIoU (%) ↑ | mF1 (%) ↑ | mBIoU (%) ↑ | GFLOPs ↓ | FPS ↑ |
313
  |----------------------------------------|-----------:|----------:|------------:|---------:|------:|
314
  | *RGB Baselines (official)* | | | | | |
315
+ | Swin-T + UPerNet | 62.01 | 75.27 | – | 237 | 69.2 |
316
+ | Swin-S + UPerNet | 61.87 | 75.11 | – | 261 | 41.5 |
317
+ | Swin-B + UPerNet | 64.05 | 76.88 | – | 306 | 36.3 |
318
  | Swin-B + UPerNet (retrained) | 64.02 | 76.64 | 32.57 | – | – |
319
  | Swin-L + UPerNet | 63.36 | 76.35 | – | 420 | 27.8 |
320
  | *Dual-branch baselines* | | | | | |
utils/bench.py CHANGED
@@ -4,8 +4,6 @@ import torch
4
  import torch.nn as nn
5
  from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count_table
6
 
7
- from CASWiT import CASWiT
8
-
9
  try:
10
  from transformers import UperNetForSemanticSegmentation
11
  except ImportError:
@@ -145,6 +143,7 @@ def main():
145
  B, H, W = args.batch, args.h, args.w
146
 
147
  if args.model == "caswit":
 
148
  base_model = CASWiT(num_head_xa=1, num_classes=15, model_name="openmmlab/upernet-swin-base").to(device).eval()
149
  model_for_flops = ForwardForFlops(base_model, which=args.which).to(device).eval()
150
 
 
4
  import torch.nn as nn
5
  from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count_table
6
 
 
 
7
  try:
8
  from transformers import UperNetForSemanticSegmentation
9
  except ImportError:
 
143
  B, H, W = args.batch, args.h, args.w
144
 
145
  if args.model == "caswit":
146
+ from model.CASWiT_upernet import CASWiT
147
  base_model = CASWiT(num_head_xa=1, num_classes=15, model_name="openmmlab/upernet-swin-base").to(device).eval()
148
  model_for_flops = ForwardForFlops(base_model, which=args.which).to(device).eval()
149