Spaces:
Runtime error
Runtime error
HubHop committed on
Commit ·
bcfa144
1
Parent(s): c8ed6d7
update
Browse files- .idea/.gitignore +8 -0
- __pycache__/datasets.cpython-39.pyc +0 -0
- __pycache__/models_v2.cpython-39.pyc +0 -0
- __pycache__/snnet.cpython-39.pyc +0 -0
- __pycache__/utils.cpython-39.pyc +0 -0
- app.py +411 -4
- datasets.py +109 -0
- demo.jpg +0 -0
- flops_gradio_demo.json +136 -0
- gradio_banner.png +0 -0
- gradio_demo.json +33 -0
- models_v2.py +568 -0
- outputs/deit/20240118_171921.log +1 -0
- outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172124.log +2 -0
- outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172140.log +2 -0
- outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172156.log +5 -0
- outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172250.log +5 -0
- outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172309.log +5 -0
- outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172332.log +5 -0
- requirements.txt +3 -0
- snnet.py +473 -0
- snnetv2_deit3_s_l.pth +3 -0
- stitches_res_s_l.txt +134 -0
- utils.py +408 -0
.idea/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Editor-based HTTP Client requests
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
__pycache__/datasets.cpython-39.pyc
ADDED
|
Binary file (2.97 kB). View file
|
|
|
__pycache__/models_v2.cpython-39.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
__pycache__/snnet.cpython-39.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
app.py
CHANGED
|
@@ -1,7 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
return "Hello " + name + "!!"
|
| 5 |
|
| 6 |
-
|
| 7 |
-
iface.launch()
|
|
|
|
| 1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
import argparse
|
| 4 |
+
import datetime
|
| 5 |
+
import numpy as np
|
| 6 |
+
import time
|
| 7 |
+
import torch
|
| 8 |
+
import torch.backends.cudnn as cudnn
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from utils import get_root_logger
|
| 13 |
+
from timm.models import create_model
|
| 14 |
+
import models_v2
|
| 15 |
+
import requests
|
| 16 |
+
|
| 17 |
+
import utils
|
| 18 |
+
import time
|
| 19 |
+
import sys
|
| 20 |
+
import datetime
|
| 21 |
+
import os
|
| 22 |
+
from snnet import SNNet, SNNetv2
|
| 23 |
+
import warnings
|
| 24 |
+
|
| 25 |
+
warnings.filterwarnings("ignore")
|
| 26 |
+
from fvcore.nn import FlopCountAnalysis
|
| 27 |
+
|
| 28 |
+
from PIL import Image
|
| 29 |
import gradio as gr
|
| 30 |
+
import plotly.express as px
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_args_parser():
    """Build the command-line parser shared by the DeiT / SN-Netv2 scripts.

    The parser is created with ``add_help=False`` so that it can be passed as
    a ``parents=[...]`` entry to another ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: parser with every option registered.
    """
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--bce-loss', action='store_true')
    parser.add_argument('--unscale-lr', action='store_true')

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
                        help='Color jitter factor (default: 0.3)')
    # The help text used to be a single garbled literal containing `" + \`.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    parser.add_argument('--ThreeAugment', action='store_true')  # 3augment

    parser.add_argument('--src', action='store_true')  # simple random crop

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                        help='Name of teacher model to train (default: "regnety_160")')
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")

    # * Finetuning params
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--attn-only', action='store_true')

    # Dataset parameters
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='dataset name')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cpu',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # SN-Net specific options. Several of these previously carried help text
    # copy-pasted from unrelated options; the wording below is deliberately
    # conservative where the exact semantics live outside this file.
    parser.add_argument('--exp_name', default='deit', type=str, help='experiment name')
    parser.add_argument('--config', default=None, type=str, help='path to a JSON configuration file')
    parser.add_argument('--scoring', action='store_true', default=False, help='configuration')
    parser.add_argument('--proxy', default='synflow', type=str, help='configuration')
    parser.add_argument('--snnet_name', default='snnetv2', type=str, help='configuration')
    parser.add_argument('--get_flops', action='store_true')
    parser.add_argument('--flops_sampling_k', default=None, type=float, help='FLOPs sampling parameter k')
    parser.add_argument('--low_rank', action='store_true', default=False, help='Enable the low-rank option')
    parser.add_argument('--lora_r', default=64, type=int,
                        help='LoRA rank (default: 64)')
    parser.add_argument('--flops_gap', default=1.0, type=float,
                        help='FLOPs gap (default: 1.0)')

    return parser
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def initialize_model_stitching_layer(model, mixup_fn, data_loader, device):
    """Initialise the model's stitching weights from the first batch of a loader.

    Only the first ``(samples, targets)`` pair produced by ``data_loader`` is
    used; an empty loader leaves the model untouched.

    Args:
        model: object exposing ``initialize_stitching_weights(samples)``.
        mixup_fn: optional callable ``(samples, targets) -> (samples, targets)``
            applied to the batch before initialisation; ``None`` to skip.
        data_loader: iterable yielding ``(samples, targets)`` pairs.
        device: device the batch is moved to.
    """
    exhausted = object()
    first_batch = next(iter(data_loader), exhausted)
    if first_batch is exhausted:
        return

    samples, targets = first_batch
    samples = samples.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)

    if mixup_fn is not None:
        samples, targets = mixup_fn(samples, targets)

    with torch.cuda.amp.autocast():
        model.initialize_stitching_weights(samples)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@torch.no_grad()
def analyse_flops_for_all(model, config_name):
    """Count FLOPs for every stitch configuration and dump them to JSON.

    For each id in ``model.all_cfgs`` the model is switched to that stitch via
    ``model.reset_stitch_id`` and analysed with ``FlopCountAnalysis`` on a
    ``(1, 3, 224, 224)`` random input. Results are written to
    ``./model_flops/flops_<config_name>.json``.

    Args:
        model: stitched network exposing ``all_cfgs`` and ``reset_stitch_id``.
        config_name: suffix used in the output file name.
    """
    # Build the dummy input on the model's own device instead of forcing
    # CUDA, which fails on CPU-only hosts (the script's default device is cpu).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # FLOP counts depend only on the input shape, so one tensor is reused.
    dummy_input = torch.randn(1, 3, 224, 224, device=device)

    stitch_results = {}
    for cfg_id in model.all_cfgs:
        model.reset_stitch_id(cfg_id)
        stitch_results[cfg_id] = FlopCountAnalysis(model, dummy_input).total()

    save_dir = './model_flops'
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, f'flops_{config_name}.json'), 'w') as f:
        json.dump(stitch_results, f, indent=4)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def main(args):
    """Load the stitched SN-Netv2 model and launch the Gradio classification demo.

    Steps: set up logging and seeds, build the evaluation transform, create
    the anchor models listed in ``args.anchors``, wrap them in ``SNNetv2``,
    load the checkpoint at ``args.resume``, read per-stitch accuracy/FLOPs
    from ``stitches_res_s_l.txt``, then build and launch the Gradio UI.

    Args:
        args: parsed namespace (CLI options merged with the JSON config).

    Raises:
        NotImplementedError: if finetuning with distillation is requested.
    """
    utils.init_distributed_mode(args)

    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    logger = get_root_logger(os.path.join(args.output_dir, f'{timestamp}.log'))

    logger.info(str(args))

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    from datasets import build_transform

    transform = build_transform(False, args)

    anchors = []
    for i, anchor_name in enumerate(args.anchors):
        logger.info(f"Creating model: {anchor_name}")
        anchor = create_model(
            anchor_name,
            pretrained=False,
            pretrained_deit=None,
            num_classes=1000,
            drop_path_rate=args.anchor_drop_path[i],
            img_size=args.input_size
        )
        anchors.append(anchor)

    model = SNNetv2(anchors, lora_r=args.lora_r)

    checkpoint = torch.load(args.resume, map_location='cpu')

    logger.info(f"load checkpoint from {args.resume}")
    model.load_state_dict(checkpoint['model'])

    model.to(device)
    model.eval()

    # Per-stitch top-1 accuracy and GFLOPs, one JSON object per line.
    eval_res = {}
    flops_res = {}
    with open('stitches_res_s_l.txt', 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # tolerate blank / trailing empty lines
                continue
            epoch_stat = json.loads(line)
            eval_res[epoch_stat['cfg_id']] = epoch_stat['acc1']
            flops_res[epoch_stat['cfg_id']] = epoch_stat['flops'] / 1e9

    def normalize_stitch_id(stitch_id):
        """Coerce the slider value to int and fold the duplicate id 13 onto 0."""
        stitch_id = int(stitch_id)
        # 13 is equivalent to 0
        return 0 if stitch_id == 13 else stitch_id

    def visualize_stitch_pos(stitch_id):
        """Return a scatter plot of all stitches with the selected one highlighted."""
        stitch_id = normalize_stitch_id(stitch_id)

        names = [f'ID {key}' for key in flops_res]

        # Materialise the dict views; plotly expects list-like inputs.
        fig = px.scatter(x=list(flops_res.values()), y=list(eval_res.values()), hover_name=names)
        fig.update_layout(
            title=f"SN-Netv2 - Stitch ID - {stitch_id}",
            title_x=0.5,
            xaxis_title="GFLOPs",
            # The plotted values are 'acc1' (classification), not mIoU.
            yaxis_title="Top-1 Accuracy (%)",
            font=dict(
                family="Courier New, monospace",
                size=18,
                color="RebeccaPurple"
            ),
            legend=dict(
                yanchor="bottom",
                y=0.99,
                xanchor="left",
                x=0.01),
        )
        fig.update_traces(marker=dict(size=10,
                                      line=dict(width=2)),
                          selector=dict(mode='markers'))

        fig.add_scatter(x=[flops_res[stitch_id]], y=[eval_res[stitch_id]], mode='markers', marker=dict(size=15),
                        name='Current Stitch')
        return fig

    # Download human-readable labels for ImageNet. A timeout keeps start-up
    # from hanging forever; on failure fall back to generic class names so the
    # demo still comes up.
    try:
        response = requests.get("https://git.io/JJkYN", timeout=10)
        response.raise_for_status()
        labels = response.text.split("\n")
    except requests.RequestException as err:
        logger.warning(f"could not download ImageNet labels: {err}")
        labels = []
    # Guarantee 1000 entries so the indexing in process_image cannot fail.
    labels += [f'class {i}' for i in range(len(labels), 1000)]

    def process_image(image, stitch_id):
        """Classify ``image`` with the selected stitch; return (confidences, plot)."""
        if image is None:
            raise gr.Error("Please upload an image first.")
        # Use the same id mapping as the plot so both outputs agree.
        stitch_id = normalize_stitch_id(stitch_id)
        inp = transform(image).unsqueeze(0).to(device)
        model.reset_stitch_id(stitch_id)
        with torch.no_grad():
            prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
        confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
        fig = visualize_stitch_pos(stitch_id)
        return confidences, fig

    with gr.Blocks() as main_page:
        with gr.Column():
            gr.HTML("""
            <h1 align="center" style=" display: flex; flex-direction: row; justify-content: center; font-size: 25pt; ">Stitched ViTs are Flexible Vision Backbones</h1>
            <div align="center"> <img align="center" src='file/gradio_banner.png' width="70%"> </div>
            <h3 align="center" >This is the classification demo page of SN-Netv2, an flexible vision backbone that allows for 100+ runtime speed and performance trade-offs.</h3>
            <h3 align="center" >You can also run this gradio demo on your local GPUs at https://github.com/ziplab/SN-Netv2</h3>
            """)
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type='pil')
                    stitch_slider = gr.Slider(minimum=0, maximum=134, step=1, label="Stitch ID")
                    with gr.Row():
                        clear_button = gr.ClearButton()
                        submit_button = gr.Button()
                with gr.Column():
                    label_output = gr.Label(num_top_classes=5)
                    stitch_plot = gr.Plot(label='Stitch Position')

        submit_button.click(
            fn=process_image,
            inputs=[image_input, stitch_slider],
            outputs=[label_output, stitch_plot],
        )

        stitch_slider.change(
            fn=visualize_stitch_pos,
            inputs=[stitch_slider],
            outputs=[stitch_plot],
            show_progress=False
        )

        clear_button.click(
            lambda: [None, 0, None, None],
            outputs=[image_input, stitch_slider, label_output, stitch_plot],
        )

        gr.Examples(
            [
                ['demo.jpg', 0],
            ],
            inputs=[
                image_input,
                stitch_slider
            ],
            outputs=[
                label_output,
                stitch_plot
            ],
        )

    main_page.launch(allowed_paths=['./'])
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()

    # The demo ships with its own configuration; use it unless the caller
    # explicitly passed --config (previously the flag was silently ignored).
    if args.config is None:
        args.config = 'gradio_demo.json'

    with open(args.config) as config_file:
        config_args = json.load(config_file)

    # Options given on the command line win over the config file. argparse
    # stores '--batch-size' as 'batch_size', and the config uses the
    # underscore form, so hyphens must be normalised before comparing.
    override_keys = {arg[2:].split('=')[0].replace('-', '_') for arg in sys.argv[1:]
                     if arg.startswith('--')}
    for k, v in config_args.items():
        if k not in override_keys:
            setattr(args, k, v)

    output_dir = os.path.join('outputs', args.exp_name)
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Resume from the experiment's own checkpoint when none was specified.
    checkpoint_path = os.path.join(output_dir, 'checkpoint.pth')
    if os.path.exists(checkpoint_path) and not args.resume:
        args.resume = checkpoint_path

    args.output_dir = output_dir

    main(args)
|
|
|
datasets.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
from torchvision import datasets, transforms
|
| 7 |
+
from torchvision.datasets.folder import ImageFolder, default_loader
|
| 8 |
+
|
| 9 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 10 |
+
from timm.data import create_transform
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class INatDataset(ImageFolder):
    """iNaturalist dataset indexed from the JSON annotation files under ``root``.

    Reads ``<train|val><year>.json``, ``categories.json`` and
    ``train<year>.json``. Class indices are assigned in order of first
    appearance of each ``category`` value in the training annotations, so the
    train and val splits share one label mapping. ``self.samples`` holds
    ``(image_path, class_index)`` pairs and ``self.nb_classes`` the number of
    distinct labels.
    """

    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year
        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']

        split = "train" if train else "val"
        with open(os.path.join(root, f'{split}{year}.json')) as split_file:
            split_info = json.load(split_file)

        with open(os.path.join(root, 'categories.json')) as categories_file:
            categories = json.load(categories_file)

        # The label mapping always comes from the training annotations.
        with open(os.path.join(root, f"train{year}.json")) as train_file:
            train_info = json.load(train_file)

        label_to_index = {}
        for annotation in train_info['annotations']:
            label = categories[int(annotation['category_id'])][category]
            # The next free index equals the number of labels seen so far.
            label_to_index.setdefault(label, len(label_to_index))
        self.nb_classes = len(label_to_index)

        self.samples = []
        for image in split_info['images']:
            parts = image['file_name'].split('/')
            category_id = int(parts[2])
            image_path = os.path.join(root, parts[0], parts[2], parts[3])

            target = label_to_index[categories[category_id][category]]
            self.samples.append((image_path, target))

    # __getitem__ and __len__ inherited from ImageFolder
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def build_dataset(is_train, args):
    """Create the dataset selected by ``args.data_set``.

    Args:
        is_train: build the training split (and training transform) if true,
            otherwise the validation split.
        args: namespace providing ``data_set``, ``data_path``,
            ``inat_category`` and the fields read by ``build_transform``.

    Returns:
        tuple: ``(dataset, nb_classes)``.

    Raises:
        ValueError: if ``args.data_set`` is not one of CIFAR, IMNET, INAT,
            INAT19 (previously this surfaced as an ``UnboundLocalError``).
    """
    supported = ('CIFAR', 'IMNET', 'INAT', 'INAT19')
    if args.data_set not in supported:
        raise ValueError(
            f"Unsupported data_set {args.data_set!r}; expected one of {supported}")

    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    else:
        # INAT and INAT19 differ only in the annotation year.
        year = 2018 if args.data_set == 'INAT' else 2019
        dataset = INatDataset(args.data_path, train=is_train, year=year,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes

    return dataset, nb_classes
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def build_transform(is_train, args):
    """Return the image preprocessing pipeline for training or evaluation.

    Training delegates to timm's ``create_transform``; for inputs of 32 px or
    less its first step is swapped for a padded ``RandomCrop``. Evaluation
    resizes to ``input_size / eval_crop_ratio`` and centre-crops (only for
    inputs larger than 32 px), then converts to a tensor and normalises with
    the ImageNet mean/std.

    Args:
        is_train: select the training pipeline if true.
        args: namespace with ``input_size`` plus either the augmentation
            fields (training) or ``eval_crop_ratio`` (evaluation).
    """
    resize_im = args.input_size > 32

    if not is_train:
        eval_steps = []
        if resize_im:
            size = int(args.input_size / args.eval_crop_ratio)
            # to maintain same ratio w.r.t. 224 images
            eval_steps.append(transforms.Resize(size, interpolation=3))
            eval_steps.append(transforms.CenterCrop(args.input_size))
        eval_steps.append(transforms.ToTensor())
        eval_steps.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
        return transforms.Compose(eval_steps)

    # this should always dispatch to transforms_imagenet_train
    train_transform = create_transform(
        input_size=args.input_size,
        is_training=True,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.train_interpolation,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
    )
    if not resize_im:
        # replace RandomResizedCropAndInterpolation with RandomCrop
        train_transform.transforms[0] = transforms.RandomCrop(
            args.input_size, padding=4)
    return train_transform
|
demo.jpg
ADDED
|
flops_gradio_demo.json
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"0": 4608338304,
|
| 3 |
+
"1": 61604135936,
|
| 4 |
+
"2": 56843745792,
|
| 5 |
+
"3": 52102230016,
|
| 6 |
+
"4": 47360714240,
|
| 7 |
+
"5": 42619198464,
|
| 8 |
+
"6": 37877682688,
|
| 9 |
+
"7": 33136166912,
|
| 10 |
+
"8": 28394651136,
|
| 11 |
+
"9": 23653135360,
|
| 12 |
+
"10": 18911619584,
|
| 13 |
+
"11": 14170103808,
|
| 14 |
+
"12": 9428588032,
|
| 15 |
+
"14": 9523655552,
|
| 16 |
+
"15": 14265171328,
|
| 17 |
+
"16": 19006687104,
|
| 18 |
+
"17": 23748202880,
|
| 19 |
+
"18": 28489718656,
|
| 20 |
+
"19": 33231234432,
|
| 21 |
+
"20": 37972750208,
|
| 22 |
+
"21": 42714265984,
|
| 23 |
+
"22": 47455781760,
|
| 24 |
+
"23": 52197297536,
|
| 25 |
+
"24": 56938813312,
|
| 26 |
+
"25": 57017547264,
|
| 27 |
+
"26": 52276031488,
|
| 28 |
+
"27": 47534515712,
|
| 29 |
+
"28": 42792999936,
|
| 30 |
+
"29": 38051484160,
|
| 31 |
+
"30": 33309968384,
|
| 32 |
+
"31": 28568452608,
|
| 33 |
+
"32": 23826936832,
|
| 34 |
+
"33": 19085421056,
|
| 35 |
+
"34": 14343905280,
|
| 36 |
+
"35": 57017547264,
|
| 37 |
+
"36": 52276031488,
|
| 38 |
+
"37": 47534515712,
|
| 39 |
+
"38": 42792999936,
|
| 40 |
+
"39": 38051484160,
|
| 41 |
+
"40": 33309968384,
|
| 42 |
+
"41": 28568452608,
|
| 43 |
+
"42": 23826936832,
|
| 44 |
+
"43": 19085421056,
|
| 45 |
+
"44": 57017547264,
|
| 46 |
+
"45": 52276031488,
|
| 47 |
+
"46": 47534515712,
|
| 48 |
+
"47": 42792999936,
|
| 49 |
+
"48": 38051484160,
|
| 50 |
+
"49": 33309968384,
|
| 51 |
+
"50": 28568452608,
|
| 52 |
+
"51": 23826936832,
|
| 53 |
+
"52": 57017547264,
|
| 54 |
+
"53": 52276031488,
|
| 55 |
+
"54": 47534515712,
|
| 56 |
+
"55": 42792999936,
|
| 57 |
+
"56": 38051484160,
|
| 58 |
+
"57": 33309968384,
|
| 59 |
+
"58": 28568452608,
|
| 60 |
+
"59": 57017547264,
|
| 61 |
+
"60": 52276031488,
|
| 62 |
+
"61": 47534515712,
|
| 63 |
+
"62": 42792999936,
|
| 64 |
+
"63": 38051484160,
|
| 65 |
+
"64": 33309968384,
|
| 66 |
+
"65": 57017547264,
|
| 67 |
+
"66": 52276031488,
|
| 68 |
+
"67": 47534515712,
|
| 69 |
+
"68": 42792999936,
|
| 70 |
+
"69": 38051484160,
|
| 71 |
+
"70": 57017547264,
|
| 72 |
+
"71": 52276031488,
|
| 73 |
+
"72": 47534515712,
|
| 74 |
+
"73": 42792999936,
|
| 75 |
+
"74": 57017547264,
|
| 76 |
+
"75": 52276031488,
|
| 77 |
+
"76": 47534515712,
|
| 78 |
+
"77": 57017547264,
|
| 79 |
+
"78": 52276031488,
|
| 80 |
+
"79": 57017547264,
|
| 81 |
+
"80": 9504781184,
|
| 82 |
+
"81": 14246296960,
|
| 83 |
+
"82": 18987812736,
|
| 84 |
+
"83": 23729328512,
|
| 85 |
+
"84": 28470844288,
|
| 86 |
+
"85": 33212360064,
|
| 87 |
+
"86": 37953875840,
|
| 88 |
+
"87": 42695391616,
|
| 89 |
+
"88": 47436907392,
|
| 90 |
+
"89": 52178423168,
|
| 91 |
+
"90": 9504781184,
|
| 92 |
+
"91": 14246296960,
|
| 93 |
+
"92": 18987812736,
|
| 94 |
+
"93": 23729328512,
|
| 95 |
+
"94": 28470844288,
|
| 96 |
+
"95": 33212360064,
|
| 97 |
+
"96": 37953875840,
|
| 98 |
+
"97": 42695391616,
|
| 99 |
+
"98": 47436907392,
|
| 100 |
+
"99": 9504781184,
|
| 101 |
+
"100": 14246296960,
|
| 102 |
+
"101": 18987812736,
|
| 103 |
+
"102": 23729328512,
|
| 104 |
+
"103": 28470844288,
|
| 105 |
+
"104": 33212360064,
|
| 106 |
+
"105": 37953875840,
|
| 107 |
+
"106": 42695391616,
|
| 108 |
+
"107": 9504781184,
|
| 109 |
+
"108": 14246296960,
|
| 110 |
+
"109": 18987812736,
|
| 111 |
+
"110": 23729328512,
|
| 112 |
+
"111": 28470844288,
|
| 113 |
+
"112": 33212360064,
|
| 114 |
+
"113": 37953875840,
|
| 115 |
+
"114": 9504781184,
|
| 116 |
+
"115": 14246296960,
|
| 117 |
+
"116": 18987812736,
|
| 118 |
+
"117": 23729328512,
|
| 119 |
+
"118": 28470844288,
|
| 120 |
+
"119": 33212360064,
|
| 121 |
+
"120": 9504781184,
|
| 122 |
+
"121": 14246296960,
|
| 123 |
+
"122": 18987812736,
|
| 124 |
+
"123": 23729328512,
|
| 125 |
+
"124": 28470844288,
|
| 126 |
+
"125": 9504781184,
|
| 127 |
+
"126": 14246296960,
|
| 128 |
+
"127": 18987812736,
|
| 129 |
+
"128": 23729328512,
|
| 130 |
+
"129": 9504781184,
|
| 131 |
+
"130": 14246296960,
|
| 132 |
+
"131": 18987812736,
|
| 133 |
+
"132": 9504781184,
|
| 134 |
+
"133": 14246296960,
|
| 135 |
+
"134": 9504781184
|
| 136 |
+
}
|
gradio_banner.png
ADDED
|
gradio_demo.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"anchors": ["deit_small_patch16_LS", "deit_large_patch16_LS"],
|
| 3 |
+
"batch_size": 64,
|
| 4 |
+
"snnet_name": "snnet_v2",
|
| 5 |
+
"data_path": "/data2/datasets/imagenet",
|
| 6 |
+
"data_set": "IMNET",
|
| 7 |
+
"exp_name": "stitch_s_l_v2_lora_r_64_50_ep",
|
| 8 |
+
"input_size": 224,
|
| 9 |
+
"num_workers": 10,
|
| 10 |
+
"lr": 0.00003,
|
| 11 |
+
"warmup_lr": 1e-7,
|
| 12 |
+
"epochs": 50,
|
| 13 |
+
"weight_decay": 0.02,
|
| 14 |
+
"sched": "cosine",
|
| 15 |
+
"eval_crop_ratio": 1.0,
|
| 16 |
+
"reprob": 0.0,
|
| 17 |
+
"smoothing": 0.1,
|
| 18 |
+
"warmup_epochs": 5,
|
| 19 |
+
"drop": 0.0,
|
| 20 |
+
"seed": 0,
|
| 21 |
+
"opt": "fusedlamb",
|
| 22 |
+
"mixup": 0,
|
| 23 |
+
"anchor_drop_path": [0.05, 0.4],
|
| 24 |
+
"cutmix": 1.0,
|
| 25 |
+
"color_jitter": 0.3,
|
| 26 |
+
"unscale_lr": true,
|
| 27 |
+
"no_repeated_aug": true,
|
| 28 |
+
"ThreeAugment": true,
|
| 29 |
+
"src": true,
|
| 30 |
+
"lora_r": 64,
|
| 31 |
+
"pretrained_deit": "../pretrained_weights",
|
| 32 |
+
"resume": "snnetv2_deit3_s_l.pth"
|
| 33 |
+
}
|
models_v2.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
import os.path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from functools import partial
|
| 8 |
+
|
| 9 |
+
from timm.models.vision_transformer import Mlp, PatchEmbed , _cfg
|
| 10 |
+
|
| 11 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 12 |
+
from timm.models.registry import register_model
|
| 13 |
+
# from xformers.ops import memory_efficient_attention
|
| 14 |
+
|
| 15 |
+
class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token sequence.

    Adapted from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # Falls back to 1/sqrt(head_dim) when no (truthy) qk_scale is supplied.
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        # One fused projection produces queries, keys and values.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return the attended tokens, same shape as ``x``."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3*C) -> (3, B, heads, N, head_dim), then split into q / k / v.
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
|
| 45 |
+
|
| 46 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each wrapped in a residual.

    Adapted from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    ``init_values`` is accepted so the constructor matches the LayerScale
    block variants in this file; it is not used by this block.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention,
                 Mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Drop path implements stochastic depth; it is a no-op when the rate is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Apply the attention residual followed by the MLP residual."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)
|
| 65 |
+
|
| 66 |
+
class Layer_scale_init_Block(nn.Module):
    """Pre-norm transformer block whose residual branches are scaled by learnable per-channel gammas.

    Adapted, with slight modifications, from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    ``gamma_1`` scales the attention branch and ``gamma_2`` the MLP branch;
    both start at ``init_values`` in every channel.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention,
                 Mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Drop path implements stochastic depth; it is a no-op when the rate is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        # LayerScale parameters, one scalar per channel for each branch.
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        """Apply the scaled attention residual followed by the scaled MLP residual."""
        attn_branch = self.gamma_1 * self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.gamma_2 * self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)
|
| 88 |
+
|
| 89 |
+
class Layer_scale_init_Block_paralx2(nn.Module):
    """LayerScale transformer block with two parallel attention branches and two parallel MLP branches.

    Adapted, with slight modifications, from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Each branch has its own norm layer and its own learnable per-channel
    gamma (initialised to ``init_values``); the two branches of a stage are
    both computed from the same input and summed into the residual stream.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention,
                 Mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        attn_kwargs = dict(num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                           attn_drop=attn_drop, proj_drop=drop)
        hidden = int(dim * mlp_ratio)
        mlp_kwargs = dict(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)

        # Attention stage: two norms, two attention modules.
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_kwargs)
        self.attn1 = Attention_block(dim, **attn_kwargs)
        # Drop path implements stochastic depth; it is a no-op when the rate is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP stage: two norms, two MLPs.
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        self.mlp = Mlp_block(**mlp_kwargs)
        self.mlp1 = Mlp_block(**mlp_kwargs)

        # One LayerScale vector per branch.
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_1_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        """Apply the parallel attention stage, then the parallel MLP stage."""
        attn_a = self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        attn_b = self.drop_path(self.gamma_1_1 * self.attn1(self.norm11(x)))
        x = x + attn_a + attn_b

        mlp_a = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        mlp_b = self.drop_path(self.gamma_2_1 * self.mlp1(self.norm21(x)))
        return x + mlp_a + mlp_b
|
| 118 |
+
|
| 119 |
+
class Block_paralx2(nn.Module):
    """Transformer block with two parallel attention branches and two parallel MLP branches (no LayerScale).

    Adapted, with slight modifications, from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Each branch has its own norm layer; the two branches of a stage are both
    computed from the same input and summed into the residual stream.
    ``init_values`` is accepted so the constructor matches the LayerScale
    variants in this file; it is not used by this block.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention,
                 Mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        attn_kwargs = dict(num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                           attn_drop=attn_drop, proj_drop=drop)
        hidden = int(dim * mlp_ratio)
        mlp_kwargs = dict(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)

        # Attention stage: two norms, two attention modules.
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_kwargs)
        self.attn1 = Attention_block(dim, **attn_kwargs)
        # Drop path implements stochastic depth; it is a no-op when the rate is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP stage: two norms, two MLPs.
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        self.mlp = Mlp_block(**mlp_kwargs)
        self.mlp1 = Mlp_block(**mlp_kwargs)

    def forward(self, x):
        """Apply the parallel attention stage, then the parallel MLP stage."""
        attn_a = self.drop_path(self.attn(self.norm1(x)))
        attn_b = self.drop_path(self.attn1(self.norm11(x)))
        x = x + attn_a + attn_b

        mlp_a = self.drop_path(self.mlp(self.norm2(x)))
        mlp_b = self.drop_path(self.mlp1(self.norm21(x)))
        return x + mlp_a + mlp_b
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class hMLP_stem(nn.Module):
    """Hierarchical MLP patch stem (https://arxiv.org/pdf/2203.09795.pdf).

    Adapted, with slight modifications, from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Three non-overlapping convolutions (strides 4, 2, 2) each followed by a
    norm layer turn an image into a sequence of ``embed_dim``-wide tokens.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=nn.SyncBatchNorm):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        # NOTE: num_patches is derived from patch_size, but the conv stack below
        # always downsamples by 4 * 2 * 2 = 16 whatever patch_size is.
        cols = self.img_size[1] // self.patch_size[1]
        rows = self.img_size[0] // self.patch_size[0]
        self.num_patches = cols * rows

        quarter = embed_dim // 4
        self.proj = torch.nn.Sequential(
            nn.Conv2d(in_chans, quarter, kernel_size=4, stride=4),
            norm_layer(quarter),
            nn.GELU(),
            nn.Conv2d(quarter, quarter, kernel_size=2, stride=2),
            norm_layer(quarter),
            nn.GELU(),
            nn.Conv2d(quarter, embed_dim, kernel_size=2, stride=2),
            norm_layer(embed_dim),
        )

    def forward(self, x):
        """Project a 4-D image batch to tokens: spatial dims are flattened and moved before channels."""
        # The unpacking is kept only because it rejects inputs that are not 4-D.
        _batch, _chans, _height, _width = x.shape
        feature_map = self.proj(x)
        return feature_map.flatten(2).transpose(1, 2)
|
| 174 |
+
|
| 175 |
+
class vit_models(nn.Module):
    """Vision Transformer with LayerScale (https://arxiv.org/abs/2103.17239) support.

    Adapted, with slight modifications, from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Besides the regular ``forward``, the class exposes helpers that run only
    part of the pipeline (``forward_patch_embed``, ``selective_forward``,
    ``forward_until``, ``forward_from``, ``forward_norm_head``) and
    ``extract_block_features`` to collect every block's output.

    ``global_pool``, ``dpr_constant`` and ``mlp_ratio_clstk`` are accepted but
    not used by this implementation. Every block receives the same
    ``drop_path_rate``, and the per-block ``drop`` is fixed at 0.0;
    ``drop_rate`` is only applied to the pooled feature in ``forward``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, global_pool=None,
                 block_layers=Block,
                 Patch_layer=PatchEmbed, act_layer=nn.GELU,
                 Attention_block=Attention, Mlp_block=Mlp,
                 dpr_constant=True, init_scale=1e-4,
                 mlp_ratio_clstk=4.0):
        super().__init__()

        self.dropout_rate = drop_rate
        self.depth = depth

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = Patch_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # The positional embedding covers the patch tokens only: it is added
        # before the class token is prepended (see forward_patch_embed).
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

        # Constant stochastic-depth rate across all blocks.
        dpr = [drop_path_rate] * depth
        self.blocks = nn.ModuleList([
            block_layers(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=0.0, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                act_layer=act_layer, Attention_block=Attention_block, Mlp_block=Mlp_block,
                init_values=init_scale)
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear weights with a truncated normal and reset LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a fresh one for ``num_classes`` (Identity when 0); ``global_pool`` is unused."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_patch_embed(self, x):
        """Embed the image into patch tokens, add the positional embedding and prepend the class token."""
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)

        x = x + self.pos_embed

        x = torch.cat((cls_tokens, x), dim=1)
        return x

    def extract_block_features(self, x):
        """Run the full block stack and return ``{block_index: detached output of that block}``."""
        x = self.forward_patch_embed(x)

        outs = {}
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            outs[i] = x.detach()
        return outs

    def selective_forward(self, x, begin, end):
        """Apply blocks ``begin`` through ``end`` (both inclusive) to already-embedded tokens."""
        for i, blk in enumerate(self.blocks):
            if i < begin:
                continue
            if i > end:
                break
            x = blk(x)
        return x

    def forward_until(self, x, blk_id):
        """Embed the image and run blocks ``0..blk_id`` inclusive, returning the token sequence.

        If ``blk_id`` matches no block index, every block is applied.
        """
        x = self.forward_patch_embed(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i == blk_id:
                break

        return x

    def forward_from(self, x, blk_id):
        """Run blocks ``blk_id..last`` on already-embedded tokens, then the final norm and head."""
        for i, blk in enumerate(self.blocks):
            if i < blk_id:
                continue
            x = blk(x)

        return self.forward_norm_head(x)

    def forward_norm_head(self, x):
        """Apply the final norm and classify from the first (class) token."""
        x = self.norm(x)
        x = self.head(x[:, 0])
        return x

    def forward_features(self, x):
        """Return the normalised class-token feature for an image batch."""
        x = self.forward_patch_embed(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        """Return the head output for an image batch, with optional dropout on the pooled feature."""
        x = self.forward_features(x)

        if self.dropout_rate:
            # The original called ``F.dropout`` but ``F`` was never imported,
            # which raised NameError whenever drop_rate > 0.
            x = nn.functional.dropout(x, p=float(self.dropout_rate), training=self.training)
        x = self.head(x)

        return x
|
| 340 |
+
|
| 341 |
+
# DeiT III: Revenge of the ViT (https://arxiv.org/abs/2204.07118)
|
| 342 |
+
|
| 343 |
+
@register_model
def deit_tiny_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny LayerScale ViT: patch 16, width 192, 12 blocks, 3 heads.

    No checkpoint is loaded here; ``pretrained``, ``pretrained_21k`` and
    ``pretrained_cfg_overlay`` are accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
@register_model
def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, pretrained_cfg=None, pretrained_deit=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small LayerScale ViT: patch 16, width 384, 12 blocks, 6 heads.

    When ``pretrained`` is true, weights are read from the local file
    ``<pretrained_deit>/deit_3_small_224_21k.pth`` (the upstream URL download
    was replaced by this local load). The file name is fixed, so ``img_size``
    and ``pretrained_21k`` do not influence which checkpoint is used.
    ``pretrained_cfg`` and ``pretrained_cfg_overlay`` are accepted but unused.

    Raises:
        TypeError: if ``pretrained`` is true and ``pretrained_deit`` is None.
    """
    model = vit_models(
        img_size=img_size, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        if pretrained_deit is None:
            # Same exception type os.path.join would raise, with an actionable message.
            raise TypeError("pretrained=True requires `pretrained_deit` (directory holding the DeiT-3 checkpoints)")
        # map_location="cpu" matches the URL-based loaders in this file and lets a
        # GPU-saved checkpoint load on a CPU-only host.
        checkpoint = torch.load(
            os.path.join(pretrained_deit, 'deit_3_small_224_21k.pth'), map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    return model
|
| 373 |
+
|
| 374 |
+
@register_model
def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the medium LayerScale ViT: patch 16, width 512, 12 blocks, 8 heads.

    When ``pretrained`` is true, the checkpoint matching ``img_size`` and
    ``pretrained_21k`` is downloaded from dl.fbaipublicfiles.com and loaded.
    """
    # img_size is forwarded to the model (the original omitted it), so the model
    # resolution agrees with the checkpoint selected by the URL below.
    model = vit_models(
        img_size=img_size, patch_size=16, embed_dim=512, depth=12, num_heads=8, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        name = 'https://dl.fbaipublicfiles.com/deit/deit_3_medium_' + str(img_size) + '_'
        name += '21k.pth' if pretrained_21k else '1k.pth'

        checkpoint = torch.hub.load_state_dict_from_url(
            url=name,
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
|
| 393 |
+
|
| 394 |
+
@register_model
def deit_base_patch16_LS(pretrained=False, pretrained_cfg=None, img_size=224, pretrained_21k=False, pretrained_deit=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base LayerScale ViT: patch 16, width 768, 12 blocks, 12 heads.

    When ``pretrained`` is true, weights are read from the local file
    ``<pretrained_deit>/deit_3_base_224_21k.pth`` (the upstream URL download
    was replaced by this local load). The file name is fixed, so ``img_size``
    and ``pretrained_21k`` do not influence which checkpoint is used.
    ``pretrained_cfg`` and ``pretrained_cfg_overlay`` are accepted but unused.

    Raises:
        TypeError: if ``pretrained`` is true and ``pretrained_deit`` is None.
    """
    model = vit_models(
        img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
    if pretrained:
        if pretrained_deit is None:
            # Same exception type os.path.join would raise, with an actionable message.
            raise TypeError("pretrained=True requires `pretrained_deit` (directory holding the DeiT-3 checkpoints)")
        # map_location="cpu" matches the URL-based loaders in this file and lets a
        # GPU-saved checkpoint load on a CPU-only host.
        checkpoint = torch.load(
            os.path.join(pretrained_deit, 'deit_3_base_224_21k.pth'), map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
|
| 413 |
+
|
| 414 |
+
@register_model
def deit_large_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, pretrained_cfg=None, pretrained_deit=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the large LayerScale ViT: patch 16, width 1024, 24 blocks, 16 heads.

    When ``pretrained`` is true, weights are read from the local file
    ``<pretrained_deit>/deit_3_large_224_21k.pth`` (the upstream URL download
    was replaced by this local load). The file name is fixed, so ``img_size``
    and ``pretrained_21k`` do not influence which checkpoint is used.
    ``pretrained_cfg`` and ``pretrained_cfg_overlay`` are accepted but unused.

    Raises:
        TypeError: if ``pretrained`` is true and ``pretrained_deit`` is None.
    """
    model = vit_models(
        img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
    if pretrained:
        if pretrained_deit is None:
            # Same exception type os.path.join would raise, with an actionable message.
            raise TypeError("pretrained=True requires `pretrained_deit` (directory holding the DeiT-3 checkpoints)")
        # map_location="cpu" matches the URL-based loaders in this file and lets a
        # GPU-saved checkpoint load on a CPU-only host.
        checkpoint = torch.load(
            os.path.join(pretrained_deit, 'deit_3_large_224_21k.pth'), map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
|
| 433 |
+
|
| 434 |
+
@register_model
def deit_huge_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the huge LayerScale ViT: patch 14, width 1280, 32 blocks, 16 heads.

    When ``pretrained`` is true, the checkpoint matching ``img_size`` and
    ``pretrained_21k`` is downloaded from dl.fbaipublicfiles.com and loaded.
    """
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    if not pretrained:
        return model

    suffix = '21k_v1.pth' if pretrained_21k else '1k_v1.pth'
    url = 'https://dl.fbaipublicfiles.com/deit/deit_3_huge_' + str(img_size) + '_' + suffix
    state = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
|
| 452 |
+
|
| 453 |
+
@register_model
def deit_huge_patch14_52_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 52-block huge LayerScale ViT: patch 14, width 1280, 16 heads.

    No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k`` are
    accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=52,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
|
| 460 |
+
|
| 461 |
+
@register_model
def deit_huge_patch14_26x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a huge ViT of 26 parallel-branch LayerScale blocks: patch 14, width 1280, 16 heads.

    No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k`` are
    accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=26,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )
|
| 468 |
+
|
| 469 |
+
@register_model
def deit_Giant_48x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a giant ViT of 48 parallel-branch LayerScale blocks: patch 14, width 1664, 16 heads.

    No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k`` are
    accepted but unused.
    """
    # The original passed ``Block_paral_LS``, a name that is not defined in this
    # file, so every call raised NameError. The parallel ("x2") LayerScale block
    # defined here is Layer_scale_init_Block_paralx2.
    model = vit_models(
        img_size=img_size, patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block_paralx2, **kwargs)

    return model
|
| 476 |
+
|
| 477 |
+
@register_model
def deit_giant_40x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a giant ViT of 40 parallel-branch LayerScale blocks: patch 14, width 1408, 16 heads.

    No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k`` are
    accepted but unused.
    """
    # The original passed ``Block_paral_LS``, a name that is not defined in this
    # file, so every call raised NameError. The parallel ("x2") LayerScale block
    # defined here is Layer_scale_init_Block_paralx2.
    model = vit_models(
        img_size=img_size, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block_paralx2, **kwargs)
    return model
|
| 483 |
+
|
| 484 |
+
@register_model
def deit_Giant_48_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 48-block giant LayerScale ViT: patch 14, width 1664, 16 heads.

    No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k`` are
    accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1664,
        depth=48,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
|
| 490 |
+
|
| 491 |
+
@register_model
def deit_giant_40_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 40-block giant LayerScale ViT: patch 14, width 1408, 16 heads.

    No checkpoint is loaded and no ``default_cfg`` is set here; ``pretrained``
    and ``pretrained_21k`` are accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
|
| 499 |
+
|
| 500 |
+
# Models from Three things everyone should know about Vision Transformers (https://arxiv.org/pdf/2203.09795.pdf)
|
| 501 |
+
|
| 502 |
+
@register_model
def deit_small_patch16_36_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 36-block small LayerScale ViT: patch 16, width 384, 6 heads.

    No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k`` are
    accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
|
| 509 |
+
|
| 510 |
+
@register_model
def deit_small_patch16_36(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 36-block small ViT: patch 16, width 384, 6 heads.

    ``block_layers`` is not passed, so ``vit_models`` uses its default block
    type. No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k``
    are accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
|
| 517 |
+
|
| 518 |
+
@register_model
def deit_small_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a small ViT of 18 parallel-branch LayerScale blocks: patch 16, width 384, 6 heads.

    No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k`` are
    accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )
|
| 525 |
+
|
| 526 |
+
@register_model
def deit_small_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a small ViT of 18 parallel-branch blocks (no LayerScale): patch 16, width 384, 6 heads.

    No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k`` are
    accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Block_paralx2,
        **kwargs,
    )
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
@register_model
def deit_base_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a base ViT of 18 parallel-branch LayerScale blocks: patch 16, width 768, 12 heads.

    No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k`` are
    accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
@register_model
def deit_base_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a base ViT of 18 parallel-branch blocks (no LayerScale): patch 16, width 768, 12 heads.

    No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k`` are
    accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Block_paralx2,
        **kwargs,
    )
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
@register_model
def deit_base_patch16_36x1_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 36-block base LayerScale ViT: patch 16, width 768, 12 heads.

    No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k`` are
    accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
|
| 560 |
+
|
| 561 |
+
@register_model
def deit_base_patch16_36x1(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 36-block base ViT: patch 16, width 768, 12 heads.

    ``block_layers`` is not passed, so ``vit_models`` uses its default block
    type. No checkpoint is loaded here; ``pretrained`` and ``pretrained_21k``
    are accepted but unused.
    """
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
|
| 568 |
+
|
outputs/deit/20240118_171921.log
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
2024-01-18 17:19:21,866 - snnet - INFO - Namespace(batch_size=64, epochs=300, bce_loss=False, unscale_lr=False, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.05, sched='cosine', lr=0.0005, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-06, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=False, src=False, reprob=0.25, remode='pixel', recount=1, resplit=False, mixup=0.8, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/datasets01/imagenet_full_size/061417/', data_set='IMNET', inat_category='name', output_dir='outputs/deit', device='cuda', seed=0, resume='', start_epoch=0, eval=False, eval_crop_ratio=0.875, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='deit', config=None, scoring=False, proxy='synflow', snnet_name='snnetv2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, distributed=False)
|
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172124.log
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-01-18 17:21:24,162 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cuda', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
|
| 2 |
+
2024-01-18 17:21:24,163 - snnet - INFO - Creating model: deit_small_patch16_LS
|
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172140.log
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-01-18 17:21:40,831 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cuda', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
|
| 2 |
+
2024-01-18 17:21:40,832 - snnet - INFO - Creating model: deit_small_patch16_LS
|
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172156.log
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-01-18 17:21:56,859 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cuda', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
|
| 2 |
+
2024-01-18 17:21:56,859 - snnet - INFO - Creating model: deit_small_patch16_LS
|
| 3 |
+
2024-01-18 17:21:57,078 - snnet - INFO - Creating model: deit_large_patch16_LS
|
| 4 |
+
2024-01-18 17:21:59,994 - snnet - INFO - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134]
|
| 5 |
+
2024-01-18 17:22:00,521 - snnet - INFO - load checkpoint from snnetv2_deit3_s_l.pth
|
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172250.log
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-01-18 17:22:50,304 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cpu', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
|
| 2 |
+
2024-01-18 17:22:50,305 - snnet - INFO - Creating model: deit_small_patch16_LS
|
| 3 |
+
2024-01-18 17:22:50,535 - snnet - INFO - Creating model: deit_large_patch16_LS
|
| 4 |
+
2024-01-18 17:22:53,873 - snnet - INFO - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134]
|
| 5 |
+
2024-01-18 17:22:54,392 - snnet - INFO - load checkpoint from snnetv2_deit3_s_l.pth
|
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172309.log
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-01-18 17:23:09,551 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cpu', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
|
| 2 |
+
2024-01-18 17:23:09,553 - snnet - INFO - Creating model: deit_small_patch16_LS
|
| 3 |
+
2024-01-18 17:23:09,778 - snnet - INFO - Creating model: deit_large_patch16_LS
|
| 4 |
+
2024-01-18 17:23:13,077 - snnet - INFO - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134]
|
| 5 |
+
2024-01-18 17:23:13,587 - snnet - INFO - load checkpoint from snnetv2_deit3_s_l.pth
|
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172332.log
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-01-18 17:23:32,357 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cpu', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
|
| 2 |
+
2024-01-18 17:23:32,358 - snnet - INFO - Creating model: deit_small_patch16_LS
|
| 3 |
+
2024-01-18 17:23:32,606 - snnet - INFO - Creating model: deit_large_patch16_LS
|
| 4 |
+
2024-01-18 17:23:35,576 - snnet - INFO - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134]
|
| 5 |
+
2024-01-18 17:23:36,120 - snnet - INFO - load checkpoint from snnetv2_deit3_s_l.pth
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
timm==0.6.12
|
| 3 |
+
fvcore
|
snnet.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import json
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
from utils import get_root_logger
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
def rearrange_activations(activations):
    """Collapse all leading axes so the result is 2-D.

    The size of the last axis is preserved as the second dimension; every
    other axis is folded into the first dimension.
    """
    return activations.reshape(-1, activations.shape[-1])
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def ps_inv(x1, x2):
    """Solve, in the least-squares sense, for an affine map from ``x1`` to ``x2``.

    Both feature maps are flattened to 2-D with ``rearrange_activations``.
    ``x1`` is then copied into a ``torch.ones`` matrix that has one extra
    trailing column, so the bias is solved together with the weight through
    ``torch.linalg.pinv``.

    Returns:
        ``(w, b)`` where ``w`` is every column of the transposed solution
        except the last, and ``b`` is the last column.

    Raises:
        ValueError: If the flattened inputs have a different number of rows.
    """
    src = rearrange_activations(x1)
    dst = rearrange_activations(x2)

    if src.shape[0] != dst.shape[0]:
        raise ValueError(
            'Spatial size of compared neurons must match when '
            'calculating psuedo inverse matrix.'
        )

    # Augmented input: [src | 1]. The ones column carries the bias term.
    n_rows, n_channels = src.shape
    augmented = torch.ones([n_rows, n_channels + 1])
    augmented[:, :-1] = src

    # pinv(augmented) @ dst has one row per input channel plus a final bias
    # row; transposing puts output channels on the first axis.
    solution = torch.linalg.pinv(augmented) @ dst.to(augmented.device)
    affine = solution.T

    return affine[..., :-1], affine[..., -1]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def reset_out_indices(front_depth=12, end_depth=24, out_indices=(9, 14, 19, 23)):
    """Map selected block indices of an ``end_depth`` model onto a ``front_depth`` model.

    The ids ``0 .. front_depth - 1`` are stretched to length ``end_depth`` with
    ``torch.nn.functional.interpolate`` (called with its default mode), which
    gives, for each position in the deeper model, the matching id in the
    shallower one. The ids found at the positions listed in ``out_indices``
    are returned in ascending position order.
    """
    front_ids = torch.tensor(list(range(front_depth))).float()[None, None]
    stretched = torch.nn.functional.interpolate(front_ids, size=end_depth)
    mapping = stretched.squeeze().long().tolist()

    return [front_idx for end_idx, front_idx in enumerate(mapping) if end_idx in out_indices]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_stitch_configs_general_unequal(depths):
    """List the stitching configurations between two anchors of unequal depth.

    ``depths`` is sorted ascending first. The result starts with the
    stand-alone configuration for anchor 1, followed by one ``(0, 1)``
    configuration per block of the shallower anchor. Configuration ``i``
    carries ``stitch_cfgs == (i, (i + 1) * (depths[1] // depths[0]))``.

    Returns:
        ``(total_configs, num_stitches)`` where ``num_stitches`` is the
        smaller of the two depths.
    """
    depths = sorted(depths)
    num_stitches = depths[0]

    # The stand-alone anchor comes first, then every small-to-large stitch.
    anchor_only = [{'comb_id': [1], }]
    stitched = [
        {
            'comb_id': (0, 1),
            'stitch_cfgs': (idx, (idx + 1) * (depths[1] // depths[0])),
        }
        for idx in range(num_stitches)
    ]

    return anchor_only + stitched, num_stitches
|
| 74 |
+
|
| 75 |
+
def get_stitch_configs_bidirection(depths):
    """Enumerate every stitching configuration between two anchors, in both directions.

    ``depths`` is sorted ascending, so index 0 is the shallower ("small")
    anchor and index 1 the deeper ("large") one. Configurations are produced
    in this order: the two stand-alone anchors, small->large (sl),
    large->small (ls), large->small->large (lsl), small->large->small (sls).

    Every non-anchor configuration is a dict with:
        ``comb_id``: sequence of anchor indices that are visited.
        ``stitch_cfgs``: one ``[front, end]`` pair per stitch.
        ``stitch_layer_ids``: one id per stitch, terminated by ``-1``.
        ``route``: one ``[start, stop]`` pair per visited anchor.

    Args:
        depths: Block counts of the two anchors, in any order.

    Returns:
        ``(total_configs, model_combos, [num_sl, num_ls], anchor_ids, sl_ids,
        ls_ids, lsl_ids, sls_ids)``. The ``*_ids`` lists hold positions into
        ``total_configs``.
    """
    depths = sorted(depths)
    small_depth, large_depth = depths[0], depths[1]

    # Stand-alone anchors.
    total_configs = [{'comb_id': [0], }, {'comb_id': [1], }]

    num_stitches = small_depth

    # small --> large: one stitch after each block of the small anchor.
    sl_configs = [
        {
            'comb_id': [0, 1],
            'stitch_cfgs': [[i, (i + 1) * (large_depth // small_depth)]],
            'stitch_layer_ids': [i],
        }
        for i in range(num_stitches)
    ]

    # For each large-anchor block, the id of the small-anchor block that
    # interpolate (default mode) assigns to it.
    block_ids = torch.tensor(list(range(small_depth)))[None, None, :].float()
    end_mapping_ids = torch.nn.functional.interpolate(block_ids, large_depth)
    end_mapping_ids = end_mapping_ids.squeeze().long().tolist()

    # large --> small: never from the last large block; when the depths
    # differ only odd large-block indices are used.
    same_depth = large_depth == small_depth
    ls_configs = [
        {
            'comb_id': [1, 0],
            'stitch_cfgs': [[i, end_mapping_ids[i] + 1]],
            'stitch_layer_ids': [i // (large_depth // small_depth)],
        }
        for i in range(large_depth - 1)
        if same_depth or i % 2 == 1
    ]

    last_small_block = small_depth - 1

    # large --> small --> large
    lsl_configs = []
    for ls_cfg in ls_configs:
        for sl_cfg in sl_configs:
            if sl_cfg['stitch_layer_ids'][0] == last_small_block:
                continue
            # The second stitch must leave the small anchor at or after the
            # block where the first stitch entered it.
            if sl_cfg['stitch_cfgs'][0][0] >= ls_cfg['stitch_cfgs'][0][1]:
                lsl_configs.append({
                    'comb_id': [1, 0, 1],
                    'stitch_cfgs': [ls_cfg['stitch_cfgs'][0], sl_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': ls_cfg['stitch_layer_ids'] + sl_cfg['stitch_layer_ids'],
                })

    # small --> large --> small
    sls_configs = []
    for sl_cfg in sl_configs:
        for ls_cfg in ls_configs:
            if ls_cfg['stitch_cfgs'][0][0] >= sl_cfg['stitch_cfgs'][0][1]:
                sls_configs.append({
                    'comb_id': [0, 1, 0],
                    'stitch_cfgs': [sl_cfg['stitch_cfgs'][0], ls_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': sl_cfg['stitch_layer_ids'] + ls_cfg['stitch_layer_ids'],
                })

    total_configs += sl_configs + ls_configs + lsl_configs + sls_configs

    anchor_ids = []
    sl_ids = []
    ls_ids = []
    lsl_ids = []
    sls_ids = []

    for i, cfg in enumerate(total_configs):
        comb_id = cfg['comb_id']

        if len(comb_id) == 1:
            anchor_ids.append(i)
            continue

        # Route: the first anchor runs from 0 to its stitch's front block,
        # each following anchor starts at the previous stitch's end block,
        # and the last anchor runs to its own depth.
        stitch_cfgs = cfg['stitch_cfgs']
        starts = [0] + [end for _, end in stitch_cfgs]
        stops = [front for front, _ in stitch_cfgs] + [depths[comb_id[-1]]]
        cfg['route'] = [[start, stop] for start, stop in zip(starts, stops)]

        if comb_id == [0, 1]:
            # The stitch out of the last small block stays in total_configs
            # but is left out of sl_ids, mirroring the skip in the lsl
            # enumeration. This used to be a hard-coded 11.
            if stitch_cfgs[0][0] != last_small_block:
                sl_ids.append(i)
        elif comb_id == [1, 0]:
            ls_ids.append(i)
        elif comb_id == [1, 0, 1]:
            lsl_ids.append(i)
        elif comb_id == [0, 1, 0]:
            sls_ids.append(i)

        # -1 terminates the list of stitch-layer ids.
        cfg['stitch_layer_ids'].append(-1)

    model_combos = [(0, 1), (1, 0)]
    return total_configs, model_combos, [len(sl_configs), len(ls_configs)], anchor_ids, sl_ids, ls_ids, lsl_ids, sls_ids
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def format_out_features(outs, with_cls_token, hw_shape):
    """Reshape each token sequence in ``outs`` into a channel-first 2-D feature map.

    Entries are replaced in place and the same ``outs`` object is returned.
    The batch and channel sizes are read from the first entry. When
    ``with_cls_token`` is true, the first token of every sequence is dropped
    before reshaping.

    Args:
        outs: Sequence of 3-D tensors; the first and last sizes of ``outs[0]``
            are used as batch and channel sizes for every entry.
        with_cls_token: Whether each sequence starts with an extra token.
        hw_shape: ``(height, width)`` of the token grid.
    """
    batch, _, channels = outs[0].shape
    height, width = hw_shape[0], hw_shape[1]

    for idx, tokens in enumerate(outs):
        if with_cls_token:
            # Drop the leading class token before restoring the grid.
            tokens = tokens[:, 1:]
        grid = tokens.reshape(batch, height, width, channels)
        outs[idx] = grid.permute(0, 3, 1, 2).contiguous()

    return outs
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class LoRALayer():
    """Holds the LoRA hyper-parameters and merge state shared by LoRA layers."""

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        """Store the settings and start in the unmerged state.

        Args:
            r: Rank of the low-rank update.
            lora_alpha: Numerator of the LoRA scaling factor.
            lora_dropout: Dropout probability; values ``<= 0`` disable dropout.
            merge_weights: Flag stored as ``self.merge_weights``.
        """
        self.r = r
        self.lora_alpha = lora_alpha
        # Dropout is optional: fall back to an identity callable when disabled.
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0. else (lambda x: x)
        # Weights always start out unmerged.
        self.merged = False
        self.merge_weights = merge_weights
|
| 223 |
+
|
| 224 |
+
class Linear(nn.Linear, LoRALayer):
    """``nn.Linear`` with an optional low-rank (LoRA) update.

    When ``r > 0`` two extra parameters are created, ``lora_A`` of shape
    ``(r, in_features)`` and ``lora_B`` of shape ``(out_features, r)``, and
    the base ``weight`` is frozen. The layer then computes the base linear
    output plus ``dropout(x) @ A^T @ B^T * (lora_alpha / r)``. With ``r == 0``
    it behaves like a plain linear layer.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        """Create the base linear layer and, when ``r > 0``, the LoRA factors.

        Args:
            in_features: Input feature size.
            out_features: Output feature size.
            r: LoRA rank; ``0`` disables the low-rank branch.
            lora_alpha: Numerator of the scaling ``lora_alpha / r``.
            lora_dropout: Dropout probability applied to the LoRA branch input.
            fan_in_fan_out: Whether the weight is stored as (fan_in, fan_out).
            merge_weights: Whether switching to eval mode folds the LoRA
                update into ``weight`` (and train mode removes it again).
            **kwargs: Forwarded to ``nn.Linear.__init__``.
        """
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def _maybe_transpose(self, w):
        """Return ``w`` transposed when the layer uses the (fan_in, fan_out) layout."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def reset_parameters(self):
        """Re-initialise the base layer and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before lora_A exists, hence the guard.
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        """Switch train/eval mode, merging or un-merging the LoRA update.

        With ``merge_weights`` set, entering eval mode adds ``B @ A * scaling``
        to ``weight`` once and entering train mode subtracts it again.

        Returns:
            ``self``, matching the ``nn.Module.train`` contract so that
            ``layer.train()`` / ``layer.eval()`` can be chained or assigned.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the base linear map, plus the LoRA branch while it is unmerged."""
        result = F.linear(x, self._maybe_transpose(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            # The low-rank update is only added explicitly while it has not
            # been folded into ``weight``.
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
        return result
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class StitchingLayer(nn.Module):
    """Single ``Linear`` projection that maps features from one anchor to another."""

    def __init__(self, in_features=None, out_features=None, r=0):
        """Create the projection; ``r`` is passed to ``Linear`` as its LoRA rank."""
        super().__init__()
        self.transform = Linear(in_features, out_features, r=r)

    def init_stitch_weights_bias(self, weight, bias):
        """Copy the given ``weight`` and ``bias`` into the projection in place."""
        projection = self.transform
        projection.weight.data.copy_(weight)
        projection.bias.data.copy_(bias)

    def forward(self, x):
        """Project ``x`` through the stitching transform."""
        return self.transform(x)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class SNNet(nn.Module):
    """Two anchors joined by stitching layers, with one active stitch configuration.

    Configurations come from ``get_stitch_configs_general_unequal``; the
    active one is selected through ``stitch_config_id``.
    """

    def __init__(self, anchors=None):
        """Register the anchors and create one stitching layer per stitch position."""
        super().__init__()
        self.anchors = nn.ModuleList(anchors)

        # Number of blocks in each anchor, in the order the anchors were given.
        self.depths = [len(anchor.blocks) for anchor in self.anchors]

        total_configs, num_stitches = get_stitch_configs_general_unequal(self.depths)
        in_dim = self.anchors[0].embed_dim
        out_dim = self.anchors[1].embed_dim
        self.stitch_layers = nn.ModuleList(
            [StitchingLayer(in_dim, out_dim) for _ in range(num_stitches)])

        self.stitch_configs = dict(enumerate(total_configs))
        self.all_cfgs = list(self.stitch_configs.keys())
        self.num_configs = len(self.all_cfgs)
        self.stitch_config_id = 0
        self.is_ranking = False

    def reset_stitch_id(self, stitch_config_id):
        """Select the configuration that ``forward`` will use."""
        self.stitch_config_id = stitch_config_id

    def initialize_stitching_weights(self, x):
        """Initialise every stitching layer from a least-squares fit on ``x``.

        Block features of both anchors are extracted from ``x``; for each
        stitch position, ``ps_inv`` fits the map from the front anchor's block
        output to the matching end anchor's block output.
        """
        logger = get_root_logger()
        front, end = 0, 1
        with torch.no_grad():
            front_features = self.anchors[front].extract_block_features(x)
            end_features = self.anchors[end].extract_block_features(x)

            for layer_idx in range(self.depths[0]):
                end_id = (layer_idx + 1) * (self.depths[1] // self.depths[0])
                # The fit target is the output of the end-anchor block just
                # before the block where the stitched route resumes.
                w, b = ps_inv(front_features[layer_idx], end_features[end_id - 1])
                self.stitch_layers[layer_idx].init_stitch_weights_bias(w, b)
                logger.info(f'Initialized Stitching Model {front} to Model {end}, Layer {layer_idx}')

    def init_weights(self):
        """Call ``init_weights`` on every anchor."""
        for anchor in self.anchors:
            anchor.init_weights()

    def sampling_stitch_config(self):
        """Pick the active configuration uniformly at random."""
        self.stitch_config_id = np.random.choice(self.all_cfgs)

    def forward(self, x):
        """Run ``x`` through the currently selected configuration."""
        active = self.stitch_configs[self.stitch_config_id]
        comb_id = active['comb_id']

        # A single-anchor configuration runs that anchor unchanged.
        if len(comb_id) == 1:
            return self.anchors[comb_id[0]](x)

        stitch = active['stitch_cfgs']
        front_blk = stitch[0]
        end_blk = stitch[1]

        x = self.anchors[comb_id[0]].forward_until(x, blk_id=front_blk)
        x = self.stitch_layers[front_blk](x)
        return self.anchors[comb_id[1]].forward_from(x, blk_id=end_blk)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class SNNetv2(nn.Module):
    """Stitchable network (v2) that routes an input through one or more anchors.

    Each stitch configuration is a dict. Keys read by this class:

    * ``'comb_id'``: sequence of anchor indices visited, in order.
    * ``'route'``: per visited anchor, the pair passed to ``selective_forward``.
    * ``'stitch_layer_ids'``: per visited anchor, the index of the stitching
      layer applied after that anchor's segment.
    * ``'stitch_cfgs'``: used only for initialisation of two-anchor routes.

    Args:
        anchors: Iterable of anchor modules exposing ``blocks`` and ``embed_dim``.
        include_sl, include_ls, include_lsl, include_sls: Whether the matching
            id list returned by ``get_stitch_configs_bidirection`` is added to
            ``all_cfgs`` (presumably small->large, large->small, ... routes;
            verify against that helper).
        lora_r: Rank forwarded to every ``StitchingLayer``.
    """

    def __init__(self, anchors=None, include_sl=True, include_ls=True, include_lsl=True, include_sls=True, lora_r=0):
        super(SNNetv2, self).__init__()
        self.anchors = nn.ModuleList(anchors)

        self.lora_r = lora_r

        self.depths = [len(anc.blocks) for anc in self.anchors]

        (total_configs, model_combos, num_stitches, anchor_ids,
         sl_ids, ls_ids, lsl_ids, sls_ids) = get_stitch_configs_bidirection(self.depths)

        self.stitch_layers = nn.ModuleList()
        self.stitching_map_id = {}

        for (front, end), num_sth in zip(model_combos, num_stitches):
            temp = nn.ModuleList(
                [StitchingLayer(self.anchors[front].embed_dim, self.anchors[end].embed_dim, r=lora_r)
                 for _ in range(num_sth)])
            # Trailing identity: a stitch id pointing past the real layers is a no-op.
            temp.append(nn.Identity())
            self.stitch_layers.append(temp)

        self.stitch_configs = {i: cfg for i, cfg in enumerate(total_configs)}
        self.stitch_init_configs = {i: cfg for i, cfg in enumerate(total_configs) if len(cfg['comb_id']) == 2}

        logger = get_root_logger()
        logger.info(str(list(self.stitch_configs.keys())))

        # Copy so that extending the list below never mutates the helper's return value.
        enabled_cfgs = list(anchor_ids)
        for enabled, ids in ((include_sl, sl_ids), (include_ls, ls_ids),
                             (include_lsl, lsl_ids), (include_sls, sls_ids)):
            if enabled:
                enabled_cfgs += ids
        self.all_cfgs = enabled_cfgs

        self.num_configs = len(self.stitch_configs)
        self.stitch_config_id = 0
        # Defined up front so the attribute exists before set_ranking_mode() is called.
        self.is_ranking = False
        # FLOPs-grouped sampling tables; expected to be assigned by the training
        # script. While unset, sampling falls back to uniform over ``all_cfgs``.
        self.flops_grouped_cfgs = None
        self.flops_sampling_probs = None

    def reset_stitch_id(self, stitch_config_id):
        """Select the configuration used by forward passes outside training mode."""
        self.stitch_config_id = stitch_config_id

    def set_ranking_mode(self, ranking_mode):
        """Store the ranking-mode flag."""
        self.is_ranking = ranking_mode

    def initialize_stitching_weights(self, x):
        """Initialise the stitching layers of all two-anchor routes from features on ``x``."""
        logger = get_root_logger()
        anchor_features = []
        for anchor in self.anchors:
            with torch.no_grad():
                anchor_features.append(anchor.extract_block_features(x))

        for cfg in self.stitch_init_configs.values():
            comb_id = cfg['comb_id']
            if len(comb_id) != 2:
                continue
            front_id, end_id = cfg['stitch_cfgs'][0]
            stitch_layer_id = cfg['stitch_layer_ids'][0]
            front_blk_feat = anchor_features[comb_id[0]][front_id]
            end_blk_feat = anchor_features[comb_id[1]][end_id - 1]
            w, b = ps_inv(front_blk_feat, end_blk_feat)
            self.stitch_layers[comb_id[0]][stitch_layer_id].init_stitch_weights_bias(w, b)
            logger.info(f'Initialized Stitching Layer {cfg}')

    def init_weights(self):
        """Delegate weight initialisation to every anchor."""
        for anc in self.anchors:
            anc.init_weights()

    def sampling_stitch_config(self):
        """Return a randomly sampled configuration id.

        With ``flops_grouped_cfgs`` set, a group is drawn according to
        ``flops_sampling_probs`` and a configuration is drawn uniformly from
        that group. Without it, a configuration is drawn uniformly from
        ``all_cfgs`` instead of failing with ``AttributeError``.
        """
        # getattr keeps instances created before these attributes existed working.
        grouped = getattr(self, 'flops_grouped_cfgs', None)
        if grouped is None:
            return np.random.choice(self.all_cfgs)
        probs = getattr(self, 'flops_sampling_probs', None)
        flops_id = np.random.choice(len(grouped), p=probs)
        return np.random.choice(grouped[flops_id])

    def forward(self, x):
        """Run one configuration: sampled while training, ``stitch_config_id`` otherwise."""
        if self.training:
            stitch_cfg_id = self.sampling_stitch_config()
        else:
            stitch_cfg_id = self.stitch_config_id

        active_cfg = self.stitch_configs[stitch_cfg_id]
        comb_id = active_cfg['comb_id']

        # forward by a single anchor
        if len(comb_id) == 1:
            return self.anchors[comb_id[0]](x)

        # forward among anchors
        route = active_cfg['route']
        stitch_layer_ids = active_cfg['stitch_layer_ids']

        # patch embedding
        x = self.anchors[comb_id[0]].forward_patch_embed(x)

        for i, (model_id, cfg) in enumerate(zip(comb_id, route)):
            x = self.anchors[model_id].selective_forward(x, cfg[0], cfg[1])
            x = self.stitch_layers[model_id][stitch_layer_ids[i]](x)

        x = self.anchors[comb_id[-1]].forward_norm_head(x)
        return x
|
| 473 |
+
|
snnetv2_deit3_s_l.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9d455f17d73f4ed74702076d4cea516194d8c4aa8fbbc63192f85795f79c76b4
|
| 3 |
+
size 1350494458
|
stitches_res_s_l.txt
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"loss": 0.7156664722345092, "acc1": 82.9060024609375, "acc5": 96.73400244140625, "cfg_id": 0, "flops": 4608338304}
|
| 2 |
+
{"loss": 0.5377805712209507, "acc1": 86.97800256835937, "acc5": 98.2540023046875, "cfg_id": 1, "flops": 61604135936}
|
| 3 |
+
{"loss": 0.5598483879796483, "acc1": 86.57800241210937, "acc5": 98.08200240234375, "cfg_id": 2, "flops": 56843745792}
|
| 4 |
+
{"loss": 0.5534007405354218, "acc1": 86.6480025390625, "acc5": 98.1760025390625, "cfg_id": 3, "flops": 52102230016}
|
| 5 |
+
{"loss": 0.5610568577028585, "acc1": 86.49800245117187, "acc5": 98.06600229492187, "cfg_id": 4, "flops": 47360714240}
|
| 6 |
+
{"loss": 0.5747850706067049, "acc1": 86.26000259765625, "acc5": 97.93800240234376, "cfg_id": 5, "flops": 42619198464}
|
| 7 |
+
{"loss": 0.5890085864812136, "acc1": 85.79200244140625, "acc5": 97.80000272460937, "cfg_id": 6, "flops": 37877682688}
|
| 8 |
+
{"loss": 0.6165087098876635, "acc1": 85.08200264648437, "acc5": 97.55600231445312, "cfg_id": 7, "flops": 33136166912}
|
| 9 |
+
{"loss": 0.6652509210574807, "acc1": 83.69200263671875, "acc5": 97.23600259765625, "cfg_id": 8, "flops": 28394651136}
|
| 10 |
+
{"loss": 0.7374675334290122, "acc1": 81.7120026171875, "acc5": 96.53200251953125, "cfg_id": 9, "flops": 23653135360}
|
| 11 |
+
{"loss": 0.7991558508665273, "acc1": 79.50600241210938, "acc5": 95.90200240234375, "cfg_id": 10, "flops": 18911619584}
|
| 12 |
+
{"loss": 0.7554851990531791, "acc1": 80.63600265625, "acc5": 96.09000245117187, "cfg_id": 11, "flops": 14170103808}
|
| 13 |
+
{"loss": 0.7068120487824534, "acc1": 82.25000237304687, "acc5": 96.35600284179688, "cfg_id": 12, "flops": 9428588032}
|
| 14 |
+
{"loss": 0.7329587066038088, "acc1": 82.6600027734375, "acc5": 96.58200264648437, "cfg_id": 14, "flops": 9523655552}
|
| 15 |
+
{"loss": 0.7238117807516546, "acc1": 82.94800252929687, "acc5": 96.68600232421875, "cfg_id": 15, "flops": 14265171328}
|
| 16 |
+
{"loss": 0.7139950410434694, "acc1": 83.0860026953125, "acc5": 96.75800252929687, "cfg_id": 16, "flops": 19006687104}
|
| 17 |
+
{"loss": 0.7004092067028537, "acc1": 83.25400249023437, "acc5": 96.8740026171875, "cfg_id": 17, "flops": 23748202880}
|
| 18 |
+
{"loss": 0.6828147762201049, "acc1": 83.45000244140626, "acc5": 96.9520026171875, "cfg_id": 18, "flops": 28489718656}
|
| 19 |
+
{"loss": 0.6787144099221085, "acc1": 83.56400258789063, "acc5": 97.0600024609375, "cfg_id": 19, "flops": 33231234432}
|
| 20 |
+
{"loss": 0.6765228407175252, "acc1": 83.43400251953125, "acc5": 97.19200266601563, "cfg_id": 20, "flops": 37972750208}
|
| 21 |
+
{"loss": 0.6841061733888857, "acc1": 83.5900022265625, "acc5": 97.20800275390626, "cfg_id": 21, "flops": 42714265984}
|
| 22 |
+
{"loss": 0.6446758140104286, "acc1": 84.8660023828125, "acc5": 97.44400258789062, "cfg_id": 22, "flops": 47455781760}
|
| 23 |
+
{"loss": 0.5939652780917558, "acc1": 86.23000265625, "acc5": 97.69200270507812, "cfg_id": 23, "flops": 52197297536}
|
| 24 |
+
{"loss": 0.5654762382760192, "acc1": 86.43400250976562, "acc5": 97.632002578125, "cfg_id": 24, "flops": 56938813312}
|
| 25 |
+
{"loss": 0.5636055112788172, "acc1": 86.39000270507813, "acc5": 98.04800252929688, "cfg_id": 25, "flops": 57017547264}
|
| 26 |
+
{"loss": 0.5706944450397383, "acc1": 86.234002578125, "acc5": 98.00000237304687, "cfg_id": 26, "flops": 52276031488}
|
| 27 |
+
{"loss": 0.5833309799658529, "acc1": 85.9240025390625, "acc5": 97.9260024609375, "cfg_id": 27, "flops": 47534515712}
|
| 28 |
+
{"loss": 0.5972222860225223, "acc1": 85.57400262695313, "acc5": 97.70800255859375, "cfg_id": 28, "flops": 42792999936}
|
| 29 |
+
{"loss": 0.6253456006560362, "acc1": 84.89800259765624, "acc5": 97.47400255859375, "cfg_id": 29, "flops": 38051484160}
|
| 30 |
+
{"loss": 0.6745385262889393, "acc1": 83.5380026171875, "acc5": 97.07600244140625, "cfg_id": 30, "flops": 33309968384}
|
| 31 |
+
{"loss": 0.7486309014034994, "acc1": 81.42600245117187, "acc5": 96.33600258789062, "cfg_id": 31, "flops": 28568452608}
|
| 32 |
+
{"loss": 0.8134756960877867, "acc1": 79.16000235351562, "acc5": 95.72400271484375, "cfg_id": 32, "flops": 23826936832}
|
| 33 |
+
{"loss": 0.7671100513050051, "acc1": 80.37200240234375, "acc5": 95.98400258789063, "cfg_id": 33, "flops": 19085421056}
|
| 34 |
+
{"loss": 0.7206548866674756, "acc1": 81.91000239257812, "acc5": 96.23800239257812, "cfg_id": 34, "flops": 14343905280}
|
| 35 |
+
{"loss": 0.5626872230998494, "acc1": 86.44600235351562, "acc5": 98.062002421875, "cfg_id": 35, "flops": 57017547264}
|
| 36 |
+
{"loss": 0.5785287711769342, "acc1": 86.06400251953124, "acc5": 97.9420023046875, "cfg_id": 36, "flops": 52276031488}
|
| 37 |
+
{"loss": 0.5930487287202568, "acc1": 85.78400234375, "acc5": 97.79000255859376, "cfg_id": 37, "flops": 47534515712}
|
| 38 |
+
{"loss": 0.6189901619923838, "acc1": 85.10800268554688, "acc5": 97.50400228515625, "cfg_id": 38, "flops": 42792999936}
|
| 39 |
+
{"loss": 0.6674688318462083, "acc1": 83.76600272460938, "acc5": 97.09400264648437, "cfg_id": 39, "flops": 38051484160}
|
| 40 |
+
{"loss": 0.7388352820593299, "acc1": 81.70200266601563, "acc5": 96.47000241210938, "cfg_id": 40, "flops": 33309968384}
|
| 41 |
+
{"loss": 0.803126322613521, "acc1": 79.4560025390625, "acc5": 95.81400245117187, "cfg_id": 41, "flops": 28568452608}
|
| 42 |
+
{"loss": 0.7581946616145697, "acc1": 80.70800243164062, "acc5": 96.08600255859375, "cfg_id": 42, "flops": 23826936832}
|
| 43 |
+
{"loss": 0.7118472667467414, "acc1": 82.22600248046875, "acc5": 96.31000268554688, "cfg_id": 43, "flops": 19085421056}
|
| 44 |
+
{"loss": 0.5727639499713074, "acc1": 86.2180025, "acc5": 97.98200247070312, "cfg_id": 44, "flops": 57017547264}
|
| 45 |
+
{"loss": 0.5866389607615543, "acc1": 85.84400263671876, "acc5": 97.8600024609375, "cfg_id": 45, "flops": 52276031488}
|
| 46 |
+
{"loss": 0.6107792718279542, "acc1": 85.19800255859376, "acc5": 97.61800235351562, "cfg_id": 46, "flops": 47534515712}
|
| 47 |
+
{"loss": 0.6602028349809574, "acc1": 83.92600266601562, "acc5": 97.23800282226563, "cfg_id": 47, "flops": 42792999936}
|
| 48 |
+
{"loss": 0.7285334389431007, "acc1": 82.0040028125, "acc5": 96.52400247070312, "cfg_id": 48, "flops": 38051484160}
|
| 49 |
+
{"loss": 0.7910783413910505, "acc1": 79.69600262695313, "acc5": 95.95800241210938, "cfg_id": 49, "flops": 33309968384}
|
| 50 |
+
{"loss": 0.7478298004152197, "acc1": 80.89400243164063, "acc5": 96.172002421875, "cfg_id": 50, "flops": 28568452608}
|
| 51 |
+
{"loss": 0.7014034044449077, "acc1": 82.45600264648438, "acc5": 96.438002421875, "cfg_id": 51, "flops": 23826936832}
|
| 52 |
+
{"loss": 0.5799332931637764, "acc1": 85.92400239257813, "acc5": 97.94000249023438, "cfg_id": 52, "flops": 57017547264}
|
| 53 |
+
{"loss": 0.6004864230300441, "acc1": 85.43800255859375, "acc5": 97.70800227539063, "cfg_id": 53, "flops": 52276031488}
|
| 54 |
+
{"loss": 0.647012604287628, "acc1": 84.20200264648437, "acc5": 97.30600264648437, "cfg_id": 54, "flops": 47534515712}
|
| 55 |
+
{"loss": 0.7162722434961435, "acc1": 82.29000248046874, "acc5": 96.6640023046875, "cfg_id": 55, "flops": 42792999936}
|
| 56 |
+
{"loss": 0.7757266998065241, "acc1": 79.9760025, "acc5": 96.050002421875, "cfg_id": 56, "flops": 38051484160}
|
| 57 |
+
{"loss": 0.7351311787285588, "acc1": 81.04400232421875, "acc5": 96.2400026953125, "cfg_id": 57, "flops": 33309968384}
|
| 58 |
+
{"loss": 0.6896895027408997, "acc1": 82.6220026171875, "acc5": 96.55400252929688, "cfg_id": 58, "flops": 28568452608}
|
| 59 |
+
{"loss": 0.5911911701727094, "acc1": 85.53000266601562, "acc5": 97.76600241210937, "cfg_id": 59, "flops": 57017547264}
|
| 60 |
+
{"loss": 0.6371258264125297, "acc1": 84.41200249023437, "acc5": 97.44200245117187, "cfg_id": 60, "flops": 52276031488}
|
| 61 |
+
{"loss": 0.7022040815403064, "acc1": 82.49400240234375, "acc5": 96.74400272460937, "cfg_id": 61, "flops": 47534515712}
|
| 62 |
+
{"loss": 0.7612808859257987, "acc1": 80.29200265625, "acc5": 96.15800239257813, "cfg_id": 62, "flops": 42792999936}
|
| 63 |
+
{"loss": 0.7246641330420971, "acc1": 81.20400250976563, "acc5": 96.42400264648437, "cfg_id": 63, "flops": 38051484160}
|
| 64 |
+
{"loss": 0.6782861414619468, "acc1": 82.8040024609375, "acc5": 96.60800270507812, "cfg_id": 64, "flops": 33309968384}
|
| 65 |
+
{"loss": 0.629801401755575, "acc1": 84.65200262695312, "acc5": 97.54200265625, "cfg_id": 65, "flops": 57017547264}
|
| 66 |
+
{"loss": 0.6992729283643492, "acc1": 82.58200259765626, "acc5": 96.85600262695313, "cfg_id": 66, "flops": 52276031488}
|
| 67 |
+
{"loss": 0.7595290538262237, "acc1": 80.35600247070313, "acc5": 96.27000247070312, "cfg_id": 67, "flops": 47534515712}
|
| 68 |
+
{"loss": 0.7238247728709019, "acc1": 81.37600248046876, "acc5": 96.46200267578125, "cfg_id": 68, "flops": 42792999936}
|
| 69 |
+
{"loss": 0.6760879844765771, "acc1": 82.96800264648438, "acc5": 96.69400274414062, "cfg_id": 69, "flops": 38051484160}
|
| 70 |
+
{"loss": 0.68392569430624, "acc1": 83.16200254882813, "acc5": 97.09200258789062, "cfg_id": 70, "flops": 57017547264}
|
| 71 |
+
{"loss": 0.7509645553249301, "acc1": 80.68000260742187, "acc5": 96.3900025, "cfg_id": 71, "flops": 52276031488}
|
| 72 |
+
{"loss": 0.7208586267449639, "acc1": 81.55200274414062, "acc5": 96.62200272460937, "cfg_id": 72, "flops": 47534515712}
|
| 73 |
+
{"loss": 0.6785354860352747, "acc1": 82.86000262695312, "acc5": 96.80000244140625, "cfg_id": 73, "flops": 42792999936}
|
| 74 |
+
{"loss": 0.7184764705598354, "acc1": 81.61200241210938, "acc5": 96.63200258789063, "cfg_id": 74, "flops": 57017547264}
|
| 75 |
+
{"loss": 0.7229886900520686, "acc1": 81.45000249023437, "acc5": 96.62200250976562, "cfg_id": 75, "flops": 52276031488}
|
| 76 |
+
{"loss": 0.6883746685855316, "acc1": 83.01600272460938, "acc5": 96.83200240234375, "cfg_id": 76, "flops": 47534515712}
|
| 77 |
+
{"loss": 0.6293963799535325, "acc1": 83.90800231445313, "acc5": 97.27400245117188, "cfg_id": 77, "flops": 57017547264}
|
| 78 |
+
{"loss": 0.642419446824175, "acc1": 84.31400258789063, "acc5": 97.19200287109375, "cfg_id": 78, "flops": 52276031488}
|
| 79 |
+
{"loss": 0.5880116202275861, "acc1": 85.90600231445312, "acc5": 97.58400263671875, "cfg_id": 79, "flops": 57017547264}
|
| 80 |
+
{"loss": 0.750676692096573, "acc1": 82.14200271484376, "acc5": 96.36600260742188, "cfg_id": 80, "flops": 9504781184}
|
| 81 |
+
{"loss": 0.7431871895537232, "acc1": 82.234002578125, "acc5": 96.39200241210938, "cfg_id": 81, "flops": 14246296960}
|
| 82 |
+
{"loss": 0.7236298957105839, "acc1": 82.62600249023437, "acc5": 96.57600243164063, "cfg_id": 82, "flops": 18987812736}
|
| 83 |
+
{"loss": 0.7074674766397837, "acc1": 82.84600237304687, "acc5": 96.68800247070313, "cfg_id": 83, "flops": 23729328512}
|
| 84 |
+
{"loss": 0.7014015182062532, "acc1": 82.99000265625, "acc5": 96.7740025, "cfg_id": 84, "flops": 28470844288}
|
| 85 |
+
{"loss": 0.6996880258348855, "acc1": 82.98000252929687, "acc5": 96.90200263671875, "cfg_id": 85, "flops": 33212360064}
|
| 86 |
+
{"loss": 0.7077699161953095, "acc1": 82.96200270507812, "acc5": 96.98600258789062, "cfg_id": 86, "flops": 37953875840}
|
| 87 |
+
{"loss": 0.6674120087515224, "acc1": 84.33000274414063, "acc5": 97.2640025, "cfg_id": 87, "flops": 42695391616}
|
| 88 |
+
{"loss": 0.6169534720141779, "acc1": 85.86200280273438, "acc5": 97.49200272460938, "cfg_id": 88, "flops": 47436907392}
|
| 89 |
+
{"loss": 0.5848360503600403, "acc1": 86.02600271484376, "acc5": 97.4480026171875, "cfg_id": 89, "flops": 52178423168}
|
| 90 |
+
{"loss": 0.7346750153510859, "acc1": 82.5540027734375, "acc5": 96.52400241210937, "cfg_id": 90, "flops": 9504781184}
|
| 91 |
+
{"loss": 0.7158081559182117, "acc1": 82.82200255859375, "acc5": 96.65600255859376, "cfg_id": 91, "flops": 14246296960}
|
| 92 |
+
{"loss": 0.6994372372600165, "acc1": 83.03600239257813, "acc5": 96.74600252929687, "cfg_id": 92, "flops": 18987812736}
|
| 93 |
+
{"loss": 0.6947964186582601, "acc1": 83.07400255859375, "acc5": 96.88200241210937, "cfg_id": 93, "flops": 23729328512}
|
| 94 |
+
{"loss": 0.6946824553112189, "acc1": 83.0200026171875, "acc5": 96.99400251953125, "cfg_id": 94, "flops": 28470844288}
|
| 95 |
+
{"loss": 0.7041463901599249, "acc1": 83.21600236328125, "acc5": 97.00400250976563, "cfg_id": 95, "flops": 33212360064}
|
| 96 |
+
{"loss": 0.6699620116163384, "acc1": 84.54600258789063, "acc5": 97.31600244140625, "cfg_id": 96, "flops": 37953875840}
|
| 97 |
+
{"loss": 0.6176637105192199, "acc1": 85.91800228515625, "acc5": 97.57000255859376, "cfg_id": 97, "flops": 42695391616}
|
| 98 |
+
{"loss": 0.5765587539045196, "acc1": 86.1880023046875, "acc5": 97.54000249023437, "cfg_id": 98, "flops": 47436907392}
|
| 99 |
+
{"loss": 0.7319535712401072, "acc1": 82.49200258789062, "acc5": 96.46600271484375, "cfg_id": 99, "flops": 9504781184}
|
| 100 |
+
{"loss": 0.710809505516381, "acc1": 82.7860023046875, "acc5": 96.6260026171875, "cfg_id": 100, "flops": 14246296960}
|
| 101 |
+
{"loss": 0.7044268037107858, "acc1": 82.93000239257813, "acc5": 96.79800258789062, "cfg_id": 101, "flops": 18987812736}
|
| 102 |
+
{"loss": 0.7076575808001287, "acc1": 82.83200243164063, "acc5": 96.8820025390625, "cfg_id": 102, "flops": 23729328512}
|
| 103 |
+
{"loss": 0.7188302328189214, "acc1": 82.88200259765625, "acc5": 96.93600249023437, "cfg_id": 103, "flops": 28470844288}
|
| 104 |
+
{"loss": 0.6856357377361167, "acc1": 84.2520023046875, "acc5": 97.2200026171875, "cfg_id": 104, "flops": 33212360064}
|
| 105 |
+
{"loss": 0.6273381847210906, "acc1": 85.758002734375, "acc5": 97.44800275390625, "cfg_id": 105, "flops": 37953875840}
|
| 106 |
+
{"loss": 0.5830204013847944, "acc1": 86.01000260742188, "acc5": 97.4780026171875, "cfg_id": 106, "flops": 42695391616}
|
| 107 |
+
{"loss": 0.7305513945492831, "acc1": 82.3180023828125, "acc5": 96.47200256835937, "cfg_id": 107, "flops": 9504781184}
|
| 108 |
+
{"loss": 0.7206297208639708, "acc1": 82.37200228515626, "acc5": 96.61600234375, "cfg_id": 108, "flops": 14246296960}
|
| 109 |
+
{"loss": 0.7241401795975186, "acc1": 82.33800244140625, "acc5": 96.74400267578125, "cfg_id": 109, "flops": 18987812736}
|
| 110 |
+
{"loss": 0.7350799917723193, "acc1": 82.57800231445313, "acc5": 96.78000252929688, "cfg_id": 110, "flops": 23729328512}
|
| 111 |
+
{"loss": 0.7013292148935072, "acc1": 83.95000255859375, "acc5": 97.17600254882812, "cfg_id": 111, "flops": 28470844288}
|
| 112 |
+
{"loss": 0.64130035031474, "acc1": 85.54600244140624, "acc5": 97.36800262695313, "cfg_id": 112, "flops": 33212360064}
|
| 113 |
+
{"loss": 0.5961506485826138, "acc1": 85.76800252929688, "acc5": 97.33400268554688, "cfg_id": 113, "flops": 37953875840}
|
| 114 |
+
{"loss": 0.7443677056580782, "acc1": 81.86200244140625, "acc5": 96.29200241210937, "cfg_id": 114, "flops": 9504781184}
|
| 115 |
+
{"loss": 0.7442678388659701, "acc1": 81.82000262695313, "acc5": 96.3900026953125, "cfg_id": 115, "flops": 14246296960}
|
| 116 |
+
{"loss": 0.749958168037913, "acc1": 81.99400229492187, "acc5": 96.49600264648437, "cfg_id": 116, "flops": 18987812736}
|
| 117 |
+
{"loss": 0.7073916116672935, "acc1": 83.43200267578125, "acc5": 96.99200250976563, "cfg_id": 117, "flops": 23729328512}
|
| 118 |
+
{"loss": 0.6501240834706661, "acc1": 85.1480025, "acc5": 97.212002734375, "cfg_id": 118, "flops": 28470844288}
|
| 119 |
+
{"loss": 0.6135486943477934, "acc1": 85.35400266601563, "acc5": 97.16000249023438, "cfg_id": 119, "flops": 33212360064}
|
| 120 |
+
{"loss": 0.7824224890633062, "acc1": 81.16600250976562, "acc5": 96.11200255859374, "cfg_id": 120, "flops": 9504781184}
|
| 121 |
+
{"loss": 0.7932735298844901, "acc1": 80.93600233398438, "acc5": 96.12600254882813, "cfg_id": 121, "flops": 14246296960}
|
| 122 |
+
{"loss": 0.7476008046757091, "acc1": 82.49800259765625, "acc5": 96.65000265625, "cfg_id": 122, "flops": 18987812736}
|
| 123 |
+
{"loss": 0.6842290776019747, "acc1": 84.35800275390625, "acc5": 96.94000270507813, "cfg_id": 123, "flops": 23729328512}
|
| 124 |
+
{"loss": 0.6409159135073423, "acc1": 84.520002578125, "acc5": 96.91800252929687, "cfg_id": 124, "flops": 28470844288}
|
| 125 |
+
{"loss": 0.8394820786109476, "acc1": 79.96600270507813, "acc5": 95.5980024609375, "cfg_id": 125, "flops": 9504781184}
|
| 126 |
+
{"loss": 0.7924092794683847, "acc1": 81.1100024609375, "acc5": 96.12600264648438, "cfg_id": 126, "flops": 14246296960}
|
| 127 |
+
{"loss": 0.724013783997207, "acc1": 83.04000248046874, "acc5": 96.46600255859374, "cfg_id": 127, "flops": 18987812736}
|
| 128 |
+
{"loss": 0.6948224610338608, "acc1": 83.016002421875, "acc5": 96.44000282226563, "cfg_id": 128, "flops": 23729328512}
|
| 129 |
+
{"loss": 0.8688964597655066, "acc1": 79.15800264160156, "acc5": 95.2460025, "cfg_id": 129, "flops": 9504781184}
|
| 130 |
+
{"loss": 0.8095247168658357, "acc1": 80.57800265625, "acc5": 95.63000274414063, "cfg_id": 130, "flops": 14246296960}
|
| 131 |
+
{"loss": 0.7750140779059042, "acc1": 80.9680026171875, "acc5": 95.68200263671875, "cfg_id": 131, "flops": 18987812736}
|
| 132 |
+
{"loss": 0.9017180648039688, "acc1": 77.88400251953125, "acc5": 94.78400234375, "cfg_id": 132, "flops": 9504781184}
|
| 133 |
+
{"loss": 0.8799216277671583, "acc1": 77.84800252929688, "acc5": 94.79800247070312, "cfg_id": 133, "flops": 14246296960}
|
| 134 |
+
{"loss": 0.8987371790589709, "acc1": 78.12000266601562, "acc5": 94.92200256835937, "cfg_id": 134, "flops": 9504781184}
|
utils.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
"""
|
| 4 |
+
Misc functions, including distributed helpers.
|
| 5 |
+
|
| 6 |
+
Mostly copy-paste from torchvision references.
|
| 7 |
+
"""
|
| 8 |
+
import io
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
from collections import defaultdict, deque
|
| 12 |
+
import datetime
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
logger_initialized = {}
|
| 19 |
+
|
| 20 |
+
def group_subnets_by_flops(data, flops_gap=1.0):
    """Bucket configuration ids whose FLOPs lie close together.

    Args:
        data: Mapping from configuration id to raw FLOPs; ids must be
            convertible with ``int``.
        flops_gap: Largest distance, in units of 1e9 FLOPs, between an entry
            and the first entry of its bucket.

    Returns:
        List of buckets in increasing FLOPs order, each a sorted list of ints.
    """
    buckets = []
    current = []
    anchor = 0
    for key, raw_flops in sorted(data.items(), key=lambda item: item[1]):
        gflops = raw_flops / 1e9
        if abs(anchor - gflops) > flops_gap:
            # Too far from the bucket's anchor: close it and open a new one.
            if current:
                buckets.append(sorted(current))
            current = [int(key)]
            anchor = gflops
        else:
            current.append(int(key))

    if current:
        buckets.append(sorted(current))

    return buckets
|
| 39 |
+
|
| 40 |
+
def find_best_candidates(data, flops_gap=1):
    """Keep the highest-scoring configuration from each group of similar FLOPs.

    Entries are visited in increasing ``(flops, score)`` order. An entry whose
    FLOPs differ from the current group's first entry by more than
    ``flops_gap`` starts a new group; otherwise it replaces the group's
    candidate when its score is strictly higher.

    Args:
        data: Mapping ``cfg_id -> (flops, score)``. FLOPs are compared as
            given, with no rescaling (presumably already in GFLOPs; confirm
            with the caller).
        flops_gap: Maximum FLOPs distance within one group. Defaults to 1.

    Returns:
        List with one configuration id per group, in increasing FLOPs order.
    """
    candidate_idx = []
    last_flops = 0
    for cfg_id, (flops, score) in sorted(data.items(), key=lambda item: item[1]):
        # The ``not candidate_idx`` guard covers a first entry whose FLOPs are
        # within ``flops_gap`` of the initial 0, which used to hit
        # ``candidate_idx[-1]`` on an empty list.
        if not candidate_idx or abs(last_flops - flops) > flops_gap:
            candidate_idx.append(cfg_id)
            last_flops = flops
        elif score > data[candidate_idx[-1]][1]:
            candidate_idx[-1] = cfg_id

    return candidate_idx
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def find_top_candidates(data, ratio=0.9, flops_gap=3):
    """Keep the top-scoring fraction of configurations in each FLOPs group.

    Entries are visited in increasing ``(flops, score)`` order and grouped so
    that every member lies within ``flops_gap`` of the group's first entry.
    From each group the ``int(ratio * len(group))`` best-scoring ids are kept,
    and never fewer than one.

    Args:
        data: Mapping ``cfg_id -> (flops, score)``; ids must be convertible
            with ``int``.
        ratio: Fraction of each group to keep.
        flops_gap: Maximum FLOPs distance within one group. Defaults to 3,
            the value that was previously hard-coded.

    Returns:
        Flat list of selected ids as ints, group by group, best score first
        within each group.
    """
    sorted_data = dict(sorted(data.items(), key=lambda item: item[1]))

    grouped_cands = []
    candidate_idx = []
    last_flops = 0
    for cfg_id, (flops, _score) in sorted_data.items():
        if abs(last_flops - flops) > flops_gap:
            if candidate_idx:
                grouped_cands.append(candidate_idx)
            candidate_idx = [cfg_id]
            last_flops = flops
        else:
            candidate_idx.append(cfg_id)

    if candidate_idx:
        grouped_cands.append(candidate_idx)

    final_list = []
    for group in grouped_cands:
        if len(group) == 1:
            final_list += list(map(int, group))
            continue

        scores = torch.tensor([sorted_data[cfg_id][-1] for cfg_id in group])
        indices = torch.argsort(scores, descending=True)
        # Always keep at least one configuration per group.
        num_selected = max(int(ratio * len(group)), 1)

        top_ids = indices[:num_selected].tolist()
        final_list += [int(group[idx]) for idx in top_ids]

    return final_list
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Return the logger called ``name``, configuring it on first use.

    A logger is treated as already configured when ``name`` itself, or any
    registered name that is a prefix of ``name`` (e.g. "a" for "a.b"), is
    present in ``logger_initialized``; it is then returned untouched.
    Otherwise a ``StreamHandler`` is attached and, on rank 0 with ``log_file``
    given, a ``FileHandler`` as well.

    Args:
        name (str): Logger name.
        log_file (str | None): Path of the log file. When given, rank 0 also
            logs to this file.
        log_level (int): Level applied to every handler, and to the logger
            itself on rank 0. Other ranks set the logger to ``ERROR``.
        file_mode (str): Mode used to open ``log_file``. Defaults to 'w'.

    Returns:
        logging.Logger: The requested logger.
    """
    logger = logging.getLogger(name)

    # Covers both the exact name and hierarchical children of a registered name.
    if name in logger_initialized or any(name.startswith(known) for known in logger_initialized):
        return logger

    # Without an initialised process group this process counts as rank 0.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0

    handlers = [logging.StreamHandler()]
    if rank == 0 and log_file is not None:
        # logging.FileHandler defaults to append mode; ``file_mode`` lets the
        # caller choose, and this function defaults it to 'w'.
        handlers.append(logging.FileHandler(log_file, file_mode))

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level if rank == 0 else logging.ERROR)

    logger_initialized[name] = True
    return logger
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_root_logger(log_file=None, log_level=logging.INFO):
    """Return the project-wide logger, creating it on first use.

    This is a thin wrapper around ``get_logger`` that pins the logger name
    to ``'snnet'`` and forwards the remaining arguments unchanged.

    Args:
        log_file (str | None): Forwarded to ``get_logger`` as ``log_file``.
        log_level (int): Forwarded to ``get_logger`` as ``log_level``.
            Defaults to ``logging.INFO``.

    Returns:
        logging.Logger: The logger that ``get_logger`` returns for the
        name ``'snnet'``.
    """
    return get_logger(name='snnet', log_file=log_file, log_level=log_level)
|
| 180 |
+
|
| 181 |
+
class SmoothedValue:
    """Running statistics over a stream of scalar values.

    The most recent ``window_size`` values are kept in a bounded deque and
    back the windowed statistics (``median``, ``avg``, ``max``, ``value``).
    A weighted running sum and count back ``global_avg`` over the whole
    series.
    """

    def __init__(self, window_size=20, fmt=None):
        """Create an empty tracker.

        Args:
            window_size (int): Maximum number of recent values retained.
            fmt (str | None): ``str.format`` template used by ``__str__``.
                It may reference ``median``, ``avg``, ``global_avg``,
                ``max`` and ``value``. Defaults to
                ``"{median:.4f} ({global_avg:.4f})"``.
        """
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value``, weighting it by ``n`` in the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Combine ``count`` and ``total`` across processes.

        Does nothing when ``is_dist_avail_and_initialized()`` is false.

        Warning: the windowed deque is NOT synchronized, so only
        ``global_avg`` reflects the other processes afterwards.
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor(
            [self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        reduced_count, reduced_total = packed.tolist()
        self.count = int(reduced_count)
        self.total = reduced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every value ever recorded."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        """Render the statistics using the ``fmt`` template."""
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class MetricLogger:
    """Collect named ``SmoothedValue`` meters and report them periodically.

    Args:
        delimiter (str): String placed between the fields of a log line
            and between meters in ``__str__``.
        logger: Object exposing ``info(message)`` used for output. When
            ``None`` (the default), messages are written with ``print``.
    """

    def __init__(self, delimiter="\t", logger=None):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.logger = logger

    def _emit(self, message):
        """Send ``message`` to the configured logger, or ``print`` it."""
        if self.logger is None:
            # No logger was supplied: fall back to stdout instead of
            # failing with AttributeError on ``None.info``.
            print(message)
        else:
            self.logger.info(message)

    def update(self, **kwargs):
        """Record one value per keyword into the meter of the same name.

        Tensor values are converted with ``.item()``; anything that is not
        an ``int`` or ``float`` afterwards trips the assertion.
        """
        for name, value in kwargs.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            assert isinstance(value, (float, int))
            self.meters[name].update(value)

    def __getattr__(self, attr):
        """Expose meters as attributes, e.g. ``metric_logger.loss``.

        Raises:
            AttributeError: If ``attr`` is neither a meter nor an instance
                attribute.
        """
        # __getattr__ only runs after normal lookup failed. Read ``meters``
        # through __dict__ so that an instance without it (e.g. while being
        # copied or unpickled) raises AttributeError instead of recursing
        # into __getattr__ forever.
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        """Return ``"name: meter"`` for every meter, joined by the delimiter."""
        return self.delimiter.join(
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Call ``synchronize_between_processes`` on every meter."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while logging progress.

        A progress line is emitted every ``print_freq`` iterations and on
        the last one; a total-time summary is emitted once the iterable is
        exhausted.

        Args:
            iterable: Sized iterable (``len()`` is required).
            print_freq (int): Emit a progress line every this many items.
            header (str | None): Prefix for every emitted line.

        Yields:
            The items of ``iterable``, unchanged and in order.
        """
        header = header or ''
        num_items = len(iterable)
        cuda_available = torch.cuda.is_available()

        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')

        # Pad the iteration index to the width of the total count.
        space_fmt = ':' + str(len(str(num_items))) + 'd'
        fields = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        if cuda_available:
            fields.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0

        start_time = time.time()
        end = time.time()
        for i, obj in enumerate(iterable):
            # Time spent waiting for the iterable to produce this item.
            data_time.update(time.time() - end)
            yield obj
            # Full iteration time, including the caller's work on ``obj``.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_items - 1:
                eta_seconds = iter_time.global_avg * (num_items - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                extra = {}
                if cuda_available:
                    extra['memory'] = torch.cuda.max_memory_allocated() / MB
                self._emit(log_msg.format(
                    i, num_items, eta=eta_string,
                    meters=str(self),
                    time=str(iter_time), data=str(data_time),
                    **extra))
            end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Guard the divisor so an empty iterable does not raise
        # ZeroDivisionError in the summary line.
        self._emit('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_items, 1)))
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _load_checkpoint_for_ema(model_ema, checkpoint):
|
| 327 |
+
"""
|
| 328 |
+
Workaround for ModelEma._load_checkpoint to accept an already-loaded object
|
| 329 |
+
"""
|
| 330 |
+
mem_file = io.BytesIO()
|
| 331 |
+
torch.save({'state_dict_ema':checkpoint}, mem_file)
|
| 332 |
+
mem_file.seek(0)
|
| 333 |
+
model_ema._load_checkpoint(mem_file)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def setup_for_distributed(is_master):
    """Replace the built-in ``print`` with one that is silent off-master.

    After this call, ``print(...)`` produces output only when ``is_master``
    is true or the call passes ``force=True``. The ``force`` keyword is
    consumed by the wrapper and never reaches the real ``print``.

    Args:
        is_master (bool): Whether the current process may print freely.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = print
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def is_dist_avail_and_initialized():
    """Return True only if ``dist`` is both available and initialized.

    ``dist.is_initialized()`` is consulted only when ``dist.is_available()``
    is true.
    """
    return dist.is_available() and dist.is_initialized()
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def get_world_size():
    """Return ``dist.get_world_size()``, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def get_rank():
    """Return ``dist.get_rank()``, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def is_main_process():
    """Return True when this process has rank 0 according to ``get_rank``."""
    rank = get_rank()
    return rank == 0
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def save_on_master(*args, **kwargs):
    """Call ``torch.save(*args, **kwargs)`` on the main process only.

    On every other process this is a no-op.
    """
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def init_distributed_mode(args):
    """Configure ``args`` and torch for distributed execution.

    Rank information is read from the environment:

    * ``RANK`` and ``WORLD_SIZE`` present: ``args.rank``,
      ``args.world_size`` and ``args.gpu`` come from ``RANK``,
      ``WORLD_SIZE`` and ``LOCAL_RANK``.
    * ``SLURM_PROCID`` present: ``args.rank`` comes from it and
      ``args.gpu`` is the rank modulo the visible CUDA device count.
      ``args.world_size`` is not assigned on this path and must already
      be set on ``args``.
    * otherwise: ``args.distributed`` is set to False and the function
      returns without touching torch.

    On the distributed paths the function selects the CUDA device,
    initializes the ``'nccl'`` process group using ``args.dist_url``,
    waits on a barrier, and silences ``print`` on non-zero ranks.

    Args:
        args: Mutable namespace; must provide ``dist_url`` on the
            distributed paths.
    """
    env = os.environ
    if 'RANK' in env and 'WORLD_SIZE' in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    elif 'SLURM_PROCID' in env:
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    # The device must be selected before the process group is created.
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print(f'| distributed init (rank {args.rank}): {args.dist_url}',
          flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
|
| 403 |
+
|
| 404 |
+
import json
|
| 405 |
+
def save_on_master_eval_res(log_stats, output_dir):
    """Append ``log_stats`` as one JSON line, on the main process only.

    Args:
        log_stats: JSON-serializable object to record.
        output_dir: Path passed to ``open`` in append mode. Despite the
            name, it is opened as a file, not a directory.
    """
    if not is_main_process():
        return
    with open(output_dir, 'a') as sink:
        sink.write(json.dumps(log_stats) + "\n")
|