| import time |
| import argparse |
| import torch |
| import torch.nn as nn |
| from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count_table |
|
|
# Optional dependency: only the "upernet" (and CASWiT) paths need transformers.
# When it is missing we record None and raise a helpful error lazily at model
# construction time instead of failing at import.
try:
    from transformers import UperNetForSemanticSegmentation
except ImportError:
    UperNetForSemanticSegmentation = None
|
|
class ForwardForFlops(torch.nn.Module):
    """Adapter that turns a dual-input, dict-output model into a tensor-output one.

    FLOP counters and the FPS benchmark expect a plain tensor result, so the
    wrapper forwards both inputs and returns only the entry named by ``which``.
    """

    def __init__(self, model: torch.nn.Module, which: str = "logits_hr"):
        super().__init__()
        self.model = model
        self.which = which

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on (x_hr, x_lr) and select the ``which`` output."""
        outputs = self.model(x_hr, x_lr)
        selected = outputs[self.which]
        return selected
|
|
|
|
class ForwardForFlopsSingle(torch.nn.Module):
    """Adapter for single-input models whose output may be a dict or a tensor.

    Returns ``out["logits"]`` when present, otherwise the first tensor value in
    the dict (insertion order), otherwise raises. Non-dict outputs pass through.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model and reduce its output to a single tensor.

        Raises:
            RuntimeError: if the model returns a dict containing no tensor.
        """
        out = self.model(x)
        if isinstance(out, dict):
            if "logits" in out:
                return out["logits"]
            # Fall back to the first tensor-valued entry, in insertion order.
            for v in out.values():
                if torch.is_tensor(v):
                    return v
            # Fixed typo in the original message ("withotu" -> "without").
            raise RuntimeError("Dict output without tensor.")
        return out
|
|
class TinySegNet(nn.Module):
    """Tiny fully-convolutional segmentation baseline.

    Two stride-2 convolutions downsample the input by 4x; a 1x1 conv produces
    per-class logits and a bilinear upsample restores the input resolution.
    """

    def __init__(self, num_classes: int = 19):
        super().__init__()
        stages = []
        in_ch = 3
        # (out_channels, stride) for each conv-bn-relu stage.
        for out_ch, stride in ((32, 1), (64, 2), (128, 2)):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
            in_ch = out_ch
        stages.append(nn.Conv2d(in_ch, num_classes, 1))
        stages.append(nn.Upsample(scale_factor=4, mode="bilinear", align_corners=False))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return per-pixel class logits at the same spatial size as ``x``."""
        return self.net(x)
|
|
|
|
class UperNetSwinBaseOnly(nn.Module):
    """Thin wrapper over a pretrained HuggingFace UperNet (Swin backbone).

    Used as a single-input baseline in the benchmark; forward returns the raw
    segmentation logits tensor rather than the full model output object.
    """

    def __init__(self, model_name: str = "openmmlab/upernet-swin-large"):
        super().__init__()
        # transformers is optional at import time; fail loudly only when this
        # model is actually requested.
        if UperNetForSemanticSegmentation is None:
            raise ImportError("transformers n'est pas installé. pip install transformers")
        self.m = UperNetForSemanticSegmentation.from_pretrained(model_name)

    def forward(self, x):
        """Return segmentation logits for a batch of images ``x``."""
        return self.m(pixel_values=x).logits
|
|
|
|
@torch.no_grad()
def benchmark_fps(model, make_inputs_fn, iters=100, warmup=20, amp="off"):
    """Measure forward-pass throughput of ``model``.

    Args:
        model: module to benchmark; must have at least one parameter (its
            device decides whether CUDA sync/autocast is used).
        make_inputs_fn: zero-arg callable returning either a single tensor or
            a tuple of tensors passed positionally to ``model``.
        iters: number of timed forward passes.
        warmup: untimed forward passes run first (cudnn autotune, allocator).
        amp: "off" | "fp16" | "bf16" — autocast precision on CUDA; ignored on
            CPU and for any other value.

    Returns:
        Tuple ``(fps, elapsed)``: passes per second and total timed seconds.
    """
    device = next(model.parameters()).device
    is_cuda = device.type == "cuda"

    # Unknown amp strings (and "off") disable autocast, as in the original.
    amp_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(amp)

    model.eval()

    def _forward_once():
        # Shared by warmup and timed loops; the original duplicated this logic.
        inputs = make_inputs_fn()
        args = inputs if isinstance(inputs, tuple) else (inputs,)
        if is_cuda and amp_dtype is not None:
            with torch.cuda.amp.autocast(dtype=amp_dtype):
                model(*args)
        else:
            model(*args)

    for _ in range(warmup):
        _forward_once()

    if is_cuda:
        # Drain queued kernels so the timer starts from an idle device.
        torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        _forward_once()
    if is_cuda:
        # Wait for the last kernels before stopping the clock.
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0

    return iters / elapsed, elapsed
|
|
|
|
def main():
    """CLI entry point: build the chosen model, report params/FLOPs, then time FPS."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="caswit", choices=["caswit", "upernet", "tiny"])
    parser.add_argument("--upernet_name", type=str, default="openmmlab/upernet-swin-large")
    parser.add_argument("--which", type=str, default="logits_hr", choices=["logits_hr", "logits_lr"])
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--h", type=int, default=512)
    parser.add_argument("--w", type=int, default=512)
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--iters", type=int, default=100)
    parser.add_argument("--warmup", type=int, default=100)
    parser.add_argument("--amp", type=str, default="off", choices=["off", "fp16", "bf16"])
    parser.add_argument("--max_depth", type=int, default=4)
    args = parser.parse_args()

    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available -> CPU")
        device = "cpu"

    if device == "cuda":
        # Fixed input shapes: let cudnn autotune conv algorithms.
        torch.backends.cudnn.benchmark = True

    B, H, W = args.batch, args.h, args.w

    if args.model == "caswit":
        # Local import so the project package is only required for this path.
        from model.CASWiT_upernet import CASWiT
        base_model = CASWiT(num_head_xa=1, num_classes=15, model_name="openmmlab/upernet-swin-base").to(device).eval()
        model_for_flops = ForwardForFlops(base_model, which=args.which).to(device).eval()

        def make_inputs():
            # CASWiT takes a high-res and a low-res view (same size here).
            x_hr = torch.randn(B, 3, H, W, device=device)
            x_lr = torch.randn(B, 3, H, W, device=device)
            return (x_hr, x_lr)

        inputs = make_inputs()
        model_name = f"CASWiT ({args.which})"

    elif args.model == "upernet":
        base_model = UperNetSwinBaseOnly(model_name=args.upernet_name).to(device).eval()
        model_for_flops = base_model

        def make_inputs():
            return torch.randn(B, 3, H, W, device=device)

        inputs = make_inputs()
        model_name = args.upernet_name

    else:
        base_model = TinySegNet(num_classes=19).to(device).eval()
        model_for_flops = base_model

        def make_inputs():
            return torch.randn(B, 3, H, W, device=device)

        inputs = make_inputs()
        model_name = "TinySegNet"

    print(f"\nModel: {model_name}")
    print(parameter_count_table(model_for_flops))

    # FlopCountAnalysis expects positional inputs as a tuple.
    flops = FlopCountAnalysis(model_for_flops, inputs if isinstance(inputs, tuple) else (inputs,))
    total_flops = flops.total()
    gflops = total_flops / 1e9
    print(f"\nTotal FLOPs: {total_flops:.3e}")
    print(f"Total GFLOPs (@B={B}): {gflops:.3f}")

    print("\n" + flop_count_table(flops, max_depth=args.max_depth))

    with torch.inference_mode():
        # Fix: the original wrapped make_inputs in a lambda that called it twice
        # per iteration and whose if/else branches were identical; pass it directly.
        fps, elapsed = benchmark_fps(
            model_for_flops,
            make_inputs_fn=make_inputs,
            iters=args.iters,
            warmup=args.warmup,
            amp=args.amp,
        )
    print(f"\nSpeed: {fps:.2f} FPS (iters={args.iters}, warmup={args.warmup}, amp={args.amp}, device={device})")
    print(f"Total timed: {elapsed:.3f} s\n")
|
|
|
|
# Run the benchmark CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|