# Image Segmentation
# English
# CASWiT/utils/bench.py
# antoine.carreaud67
# Update README benchmark numbers
# da8a8ce
import time
import argparse
import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count_table
try:
from transformers import UperNetForSemanticSegmentation
except ImportError:
UperNetForSemanticSegmentation = None
class ForwardForFlops(torch.nn.Module):
    """Adapter that returns a single entry of a two-input model's output.

    The wrapped model is called as ``model(x_hr, x_lr)`` and its result is
    indexed with ``which``; only that entry is returned. In this file the
    wrapper is what gets handed to the FLOP counter and the speed benchmark.
    """

    def __init__(self, model: torch.nn.Module, which: str = "logits_hr"):
        """Keep a reference to ``model`` and remember which output key to select."""
        super().__init__()
        self.which = which
        self.model = model

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on both inputs and return ``output[self.which]``."""
        return self.model(x_hr, x_lr)[self.which]
class ForwardForFlopsSingle(torch.nn.Module):
    """Adapter that reduces a single-input model's output to one value.

    Non-dict outputs are passed through untouched. For dict outputs the
    ``"logits"`` entry is preferred; otherwise the first tensor-valued entry
    (in the dict's iteration order) is returned.
    """

    def __init__(self, model: torch.nn.Module):
        """Keep a reference to the wrapped ``model``."""
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Call the wrapped model on ``x`` and pick the output to return.

        Raises:
            RuntimeError: if the model returns a dict that has no ``"logits"``
                key and contains no tensor value.
        """
        out = self.model(x)
        if not isinstance(out, dict):
            return out
        if "logits" in out:
            # Returned as-is, without checking that it is a tensor.
            return out["logits"]
        for value in out.values():
            if torch.is_tensor(value):
                return value
        raise RuntimeError("Dict output without tensor.")
class TinySegNet(nn.Module):
    """Small convolutional segmentation network.

    Three 3x3 conv + BatchNorm + ReLU stages (32, 64 and 128 channels; the
    last two use stride 2), a 1x1 conv producing ``num_classes`` channels,
    and a x4 bilinear upsample.
    """

    def __init__(self, num_classes: int = 19):
        """Build the layer stack; ``num_classes`` sets the output channel count."""
        super().__init__()
        layers = []
        in_channels = 3
        # (out_channels, stride) for each conv/BN/ReLU stage, in order.
        for out_channels, stride in ((32, 1), (64, 2), (128, 2)):
            layers += [
                nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            in_channels = out_channels
        layers.append(nn.Conv2d(in_channels, num_classes, 1))
        layers.append(nn.Upsample(scale_factor=4, mode="bilinear", align_corners=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the sequential stack to ``x``."""
        return self.net(x)
class UperNetSwinBaseOnly(nn.Module):
    """Wrapper around a pretrained ``UperNetForSemanticSegmentation`` that returns only logits.

    NOTE(review): the class name says "SwinBase" but the default checkpoint is
    ``openmmlab/upernet-swin-large`` - confirm which one is intended.
    """

    def __init__(self, model_name: str = "openmmlab/upernet-swin-large"):
        """Load the checkpoint ``model_name`` via ``from_pretrained``.

        Raises:
            ImportError: if ``transformers`` could not be imported, in which
                case the module-level class reference is ``None``.
        """
        super().__init__()
        upernet_cls = UperNetForSemanticSegmentation
        if upernet_cls is None:
            # Runtime message is in French: "transformers is not installed".
            raise ImportError("transformers n'est pas installé. pip install transformers")
        self.m = upernet_cls.from_pretrained(model_name)

    def forward(self, x):
        """Feed ``x`` as ``pixel_values`` and return the ``logits`` attribute of the result."""
        return self.m(pixel_values=x).logits
def _forward_once(model, inputs, autocast_dtype):
    """Run one forward pass and discard the output; autocast on CUDA when a dtype is given."""
    args = inputs if isinstance(inputs, tuple) else (inputs,)
    if autocast_dtype is None:
        model(*args)
        return
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
    with torch.autocast(device_type="cuda", dtype=autocast_dtype):
        model(*args)


@torch.no_grad()
def benchmark_fps(model, make_inputs_fn, iters=100, warmup=20, amp="off"):
    """Time repeated forward passes of ``model`` and report iterations per second.

    Args:
        model: module to benchmark. It is switched to eval mode (side effect).
            The device is taken from its first parameter; a model without
            parameters is treated as living on the CPU.
        make_inputs_fn: zero-argument callable returning either a tuple (which
            is unpacked into positional arguments) or a single input. It is
            called once per warmup and once per timed iteration, so the cost
            of building inputs is part of the measured time.
        iters: number of timed forward passes.
        warmup: number of untimed forward passes run first.
        amp: "off" | "fp16" | "bf16". Autocast is only applied when the model
            is on a CUDA device; any other value behaves like "off".

    Returns:
        ``(fps, elapsed)`` where ``elapsed`` is the wall-clock time in seconds
        of the timed loop and ``fps`` is ``iters / elapsed`` (forward passes
        per second, not scaled by batch size).
    """
    # Fall back to CPU instead of raising StopIteration on parameter-less models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    is_cuda = device.type == "cuda"

    amp_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(amp)
    autocast_dtype = amp_dtype if is_cuda else None

    model.eval()

    for _ in range(warmup):
        _forward_once(model, make_inputs_fn(), autocast_dtype)

    # CUDA kernels run asynchronously: drain the model's device before and
    # after the timed loop so the clock covers exactly the timed work.
    if is_cuda:
        torch.cuda.synchronize(device)
    t0 = time.perf_counter()
    for _ in range(iters):
        _forward_once(model, make_inputs_fn(), autocast_dtype)
    if is_cuda:
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - t0

    if elapsed > 0:
        fps = iters / elapsed
    else:
        # Timer resolution too coarse to measure anything; avoid ZeroDivisionError.
        fps = float("inf") if iters > 0 else 0.0
    return fps, elapsed
def main():
    """Command-line entry point: report parameters, FLOPs and speed of a model.

    Builds the model selected by ``--model`` ("caswit", "upernet" or "tiny"),
    prints its parameter table, its total FLOPs/GFLOPs and a per-module FLOP
    table computed with fvcore, then times forward passes with
    ``benchmark_fps`` and prints the result. Falls back to CPU when CUDA is
    requested but unavailable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="caswit", choices=["caswit", "upernet", "tiny"])
    parser.add_argument("--upernet_name", type=str, default="openmmlab/upernet-swin-large")
    parser.add_argument("--which", type=str, default="logits_hr", choices=["logits_hr", "logits_lr"])
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--h", type=int, default=512)
    parser.add_argument("--w", type=int, default=512)
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--iters", type=int, default=100)
    parser.add_argument("--warmup", type=int, default=100)
    parser.add_argument("--amp", type=str, default="off", choices=["off", "fp16", "bf16"])
    parser.add_argument("--max_depth", type=int, default=4)
    args = parser.parse_args()

    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available -> CPU")
        device = "cpu"
    if device == "cuda":
        torch.backends.cudnn.benchmark = True

    B, H, W = args.batch, args.h, args.w

    def random_image():
        """Return one random (B, 3, H, W) batch on the selected device."""
        return torch.randn(B, 3, H, W, device=device)

    if args.model == "caswit":
        # Imported lazily so the other models can run without the CASWiT package.
        from model.CASWiT_upernet import CASWiT

        base_model = CASWiT(num_head_xa=1, num_classes=15, model_name="openmmlab/upernet-swin-base").to(device).eval()
        model_for_flops = ForwardForFlops(base_model, which=args.which).to(device).eval()

        def make_inputs():
            # Both inputs are generated at the same H x W resolution.
            return (random_image(), random_image())

        model_name = f"CASWiT ({args.which})"
    elif args.model == "upernet":
        model_for_flops = UperNetSwinBaseOnly(model_name=args.upernet_name).to(device).eval()
        make_inputs = random_image
        model_name = args.upernet_name
    else:  # tiny
        model_for_flops = TinySegNet(num_classes=19).to(device).eval()
        make_inputs = random_image
        model_name = "TinySegNet"

    inputs = make_inputs()

    # ---- Params ----
    print(f"\nModel: {model_name}")
    print(parameter_count_table(model_for_flops))

    # ---- FLOPs/GFLOPs via fvcore ----
    # inputs must be a tuple for FlopCountAnalysis
    flops = FlopCountAnalysis(model_for_flops, inputs if isinstance(inputs, tuple) else (inputs,))
    total_flops = flops.total()
    gflops = total_flops / 1e9
    print(f"\nTotal FLOPs: {total_flops:.3e}")
    print(f"Total GFLOPs (@B={B}): {gflops:.3f}")

    # ---- details per modules ----
    print("\n" + flop_count_table(flops, max_depth=args.max_depth))

    # ---- FPS benchmark ----
    # Pass make_inputs directly: the previous lambda called it twice per
    # iteration (once just for an isinstance check), doubling the input
    # allocation cost inside the timed loop.
    # NOTE(review): the reported "FPS" is forward passes per second and is not
    # multiplied by the batch size - confirm this is intended for --batch > 1.
    with torch.inference_mode():
        fps, elapsed = benchmark_fps(
            model_for_flops,
            make_inputs_fn=make_inputs,
            iters=args.iters,
            warmup=args.warmup,
            amp=args.amp,
        )
    print(f"\nSpeed: {fps:.2f} FPS (iters={args.iters}, warmup={args.warmup}, amp={args.amp}, device={device})")
    print(f"Total timed: {elapsed:.3f} s\n")
# Script entry point: run the benchmark only when executed directly, not on import.
if __name__ == "__main__":
    main()