"""Benchmark segmentation models: parameter count, FLOPs (fvcore) and FPS.

Supports three models selected via --model:
  * "caswit"  -- project-local CASWiT (dual HR/LR input, dict output)
  * "upernet" -- HuggingFace UperNetForSemanticSegmentation
  * "tiny"    -- a small built-in conv net, useful as a sanity check
"""

import argparse
import time

import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count_table

try:
    from transformers import UperNetForSemanticSegmentation
except ImportError:
    # transformers is optional: only required for --model upernet.
    UperNetForSemanticSegmentation = None


class ForwardForFlops(torch.nn.Module):
    """Adapt a two-input, dict-output model for fvcore.

    FlopCountAnalysis traces a module whose forward should return a
    Tensor; this wrapper selects the entry ``which`` from the wrapped
    model's dict output.

    Args:
        model: module called as ``model(x_hr, x_lr)`` returning a dict.
        which: key of the tensor to return (e.g. "logits_hr").
    """

    def __init__(self, model: torch.nn.Module, which: str = "logits_hr"):
        super().__init__()
        self.model = model
        self.which = which

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        out = self.model(x_hr, x_lr)
        return out[self.which]  # Tensor


class ForwardForFlopsSingle(torch.nn.Module):
    """Adapt a single-input model so forward always returns a Tensor.

    If the wrapped model returns a dict, prefer the "logits" entry,
    otherwise fall back to the first tensor value found.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.model(x)
        if isinstance(out, dict):
            if "logits" in out:
                return out["logits"]
            for v in out.values():
                if torch.is_tensor(v):
                    return v
            raise RuntimeError("Dict output without tensor.")
        return out


class TinySegNet(nn.Module):
    """Minimal conv segmentation net (stride-4 encoder + 4x upsample).

    Used as a cheap reference point for the FLOPs/FPS pipeline.
    """

    def __init__(self, num_classes: int = 19):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, num_classes, 1),
            nn.Upsample(scale_factor=4, mode="bilinear", align_corners=False),
        )

    def forward(self, x):
        return self.net(x)


class UperNetSwinBaseOnly(nn.Module):
    """Thin wrapper around a pretrained HuggingFace UperNet model.

    Returns the raw ``logits`` tensor so it can be traced by fvcore.
    """

    def __init__(self, model_name: str = "openmmlab/upernet-swin-large"):
        super().__init__()
        if UperNetForSemanticSegmentation is None:
            # FIX: this message was a broken (multi-line) string literal.
            raise ImportError(
                "transformers n'est pas installé. pip install transformers"
            )
        self.m = UperNetForSemanticSegmentation.from_pretrained(model_name)

    def forward(self, x):
        out = self.m(pixel_values=x)
        return out.logits


def _forward_once(model, inputs):
    """Run one forward pass, unpacking a tuple of positional inputs."""
    return model(*inputs) if isinstance(inputs, tuple) else model(inputs)


@torch.no_grad()
def benchmark_fps(model, make_inputs_fn, iters=100, warmup=20, amp="off"):
    """Measure inference throughput in frames (iterations) per second.

    Args:
        model: module to benchmark; its device decides CUDA vs CPU timing.
        make_inputs_fn: zero-arg callable returning either a Tensor or a
            tuple of Tensors to feed the model.
        iters: number of timed iterations.
        warmup: untimed iterations run first (kernel autotuning, caches).
        amp: "off" | "fp16" | "bf16" -- mixed-precision autocast mode
            (only applied on CUDA).

    Returns:
        (fps, elapsed_seconds) for the timed loop.
    """
    device = next(model.parameters()).device
    is_cuda = (device.type == "cuda")

    if amp == "fp16":
        amp_dtype = torch.float16
    elif amp == "bf16":
        amp_dtype = torch.bfloat16
    else:
        amp_dtype = None

    model.eval()

    # Warmup (not timed).
    for _ in range(warmup):
        inputs = make_inputs_fn()
        if is_cuda and amp_dtype is not None:
            # FIX: torch.cuda.amp.autocast is deprecated in favor of
            # torch.autocast(device_type=..., dtype=...).
            with torch.autocast(device_type="cuda", dtype=amp_dtype):
                _ = _forward_once(model, inputs)
        else:
            _ = _forward_once(model, inputs)

    # CUDA kernels launch asynchronously; synchronize so the timer
    # brackets actual GPU work.
    if is_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()

    for _ in range(iters):
        inputs = make_inputs_fn()
        if is_cuda and amp_dtype is not None:
            with torch.autocast(device_type="cuda", dtype=amp_dtype):
                _ = _forward_once(model, inputs)
        else:
            _ = _forward_once(model, inputs)

    if is_cuda:
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    elapsed = t1 - t0
    fps = iters / elapsed
    return fps, elapsed


def main():
    """Parse CLI args, build the chosen model, report params/FLOPs/FPS."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="caswit",
                        choices=["caswit", "upernet", "tiny"])
    parser.add_argument("--upernet_name", type=str,
                        default="openmmlab/upernet-swin-large")
    parser.add_argument("--which", type=str, default="logits_hr",
                        choices=["logits_hr", "logits_lr"])
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--h", type=int, default=512)
    parser.add_argument("--w", type=int, default=512)
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "cpu"])
    parser.add_argument("--iters", type=int, default=100)
    parser.add_argument("--warmup", type=int, default=100)
    parser.add_argument("--amp", type=str, default="off",
                        choices=["off", "fp16", "bf16"])
    parser.add_argument("--max_depth", type=int, default=4)
    args = parser.parse_args()

    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available -> CPU")
        device = "cpu"

    if device == "cuda":
        # Let cuDNN autotune conv algorithms for the fixed input shape.
        torch.backends.cudnn.benchmark = True

    B, H, W = args.batch, args.h, args.w

    if args.model == "caswit":
        from model.CASWiT_upernet import CASWiT
        base_model = CASWiT(
            num_head_xa=1,
            num_classes=15,
            model_name="openmmlab/upernet-swin-base",
        ).to(device).eval()
        model_for_flops = ForwardForFlops(
            base_model, which=args.which
        ).to(device).eval()

        def make_inputs():
            x_hr = torch.randn(B, 3, H, W, device=device)
            x_lr = torch.randn(B, 3, H, W, device=device)
            return (x_hr, x_lr)

        inputs = make_inputs()
        model_name = f"CASWiT ({args.which})"

    elif args.model == "upernet":
        base_model = UperNetSwinBaseOnly(
            model_name=args.upernet_name
        ).to(device).eval()
        model_for_flops = base_model

        def make_inputs():
            x = torch.randn(B, 3, H, W, device=device)
            return x

        inputs = make_inputs()
        model_name = args.upernet_name

    else:  # tiny
        base_model = TinySegNet(num_classes=19).to(device).eval()
        model_for_flops = base_model

        def make_inputs():
            x = torch.randn(B, 3, H, W, device=device)
            return x

        inputs = make_inputs()
        model_name = "TinySegNet"

    # ---- Params ----
    print(f"\nModel: {model_name}")
    print(parameter_count_table(model_for_flops))

    # ---- FLOPs/GFLOPs via fvcore ----
    # inputs must be a tuple for FlopCountAnalysis
    flops = FlopCountAnalysis(
        model_for_flops,
        inputs if isinstance(inputs, tuple) else (inputs,),
    )
    total_flops = flops.total()
    gflops = total_flops / 1e9
    print(f"\nTotal FLOPs: {total_flops:.3e}")
    print(f"Total GFLOPs (@B={B}): {gflops:.3f}")

    # ---- details per modules ----
    print("\n" + flop_count_table(flops, max_depth=args.max_depth))

    # ---- FPS benchmark ----
    with torch.inference_mode():
        # FIX: the previous lambda called make_inputs() up to twice per
        # iteration only to re-check isinstance; pass the factory directly
        # (benchmark_fps already handles tuple vs single-tensor inputs).
        fps, elapsed = benchmark_fps(
            model_for_flops,
            make_inputs_fn=make_inputs,
            iters=args.iters,
            warmup=args.warmup,
            amp=args.amp,
        )

    print(f"\nSpeed: {fps:.2f} FPS (iters={args.iters}, "
          f"warmup={args.warmup}, amp={args.amp}, device={device})")
    print(f"Total timed: {elapsed:.3f} s\n")


if __name__ == "__main__":
    main()