update to torch2
Browse files- app.py +35 -0
- config.py +4 -1
- data/build.py +6 -15
- data/dataset_fg.py +52 -9
- inference.py +106 -29
- lr_scheduler.py +0 -1
- main.py +74 -19
- models/MetaFG.py +2 -1
app.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from inference import Inference
|
| 2 |
+
import argparse
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import glob
|
| 5 |
+
|
| 6 |
+
def parse_option():
|
| 7 |
+
parser = argparse.ArgumentParser('MetaFG Inference script', add_help=False)
|
| 8 |
+
parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', default="configs/MetaFG_2_224.yaml")
|
| 9 |
+
# easy config modification
|
| 10 |
+
parser.add_argument('--model-path', type=str, help="path to model data", default="./ckpt_4_mf2.pth")
|
| 11 |
+
parser.add_argument('--img-size', type=int, default=384, help='path to image')
|
| 12 |
+
parser.add_argument('--meta-path', default="meta.txt", type=str, help='path to meta data')
|
| 13 |
+
parser.add_argument('--names-path', default="names_mf2.txt", type=str, help='path to meta data')
|
| 14 |
+
args = parser.parse_args()
|
| 15 |
+
return args
|
| 16 |
+
|
| 17 |
+
if __name__ == '__main__':
|
| 18 |
+
args = parse_option()
|
| 19 |
+
|
| 20 |
+
model = Inference(config_path=args.cfg,
|
| 21 |
+
model_path=args.model_path,
|
| 22 |
+
names_path=args.names_path)
|
| 23 |
+
|
| 24 |
+
def classify(image):
|
| 25 |
+
preds = model.infer(img_path=image, meta_data_path="meta.txt").squeeze()
|
| 26 |
+
print(len(model.classes))
|
| 27 |
+
print(model.classes)
|
| 28 |
+
confidences = {c: float(preds[i]) for i,c in enumerate(model.classes)}
|
| 29 |
+
|
| 30 |
+
return confidences
|
| 31 |
+
|
| 32 |
+
gr.Interface(pfn=classify,
|
| 33 |
+
inputs=gr.Image(shape=(args.img_size, args.img_size), type="pil"),
|
| 34 |
+
outputs=gr.Label(num_top_classes=10),
|
| 35 |
+
examples=glob.glob("./example_images/*")).launch()
|
config.py
CHANGED
|
@@ -24,6 +24,8 @@ _C.DATA.BATCH_SIZE = 32
|
|
| 24 |
_C.DATA.DATA_PATH = ''
|
| 25 |
# Dataset name
|
| 26 |
_C.DATA.DATASET = 'imagenet'
|
|
|
|
|
|
|
| 27 |
# Input image size
|
| 28 |
_C.DATA.IMG_SIZE = 224
|
| 29 |
# Interpolation to resize image (random, bilinear, bicubic)
|
|
@@ -74,6 +76,7 @@ _C.MODEL.LABEL_SMOOTHING = 0.1
|
|
| 74 |
_C.MODEL.PRETRAINED = None
|
| 75 |
_C.MODEL.DORP_HEAD = True
|
| 76 |
_C.MODEL.DORP_META = True
|
|
|
|
| 77 |
|
| 78 |
_C.MODEL.ONLY_LAST_CLS = False
|
| 79 |
_C.MODEL.EXTRA_TOKEN_NUM = 1
|
|
@@ -255,7 +258,7 @@ def update_config(config, args):
|
|
| 255 |
config.MODEL.PRETRAINED = args.pretrain
|
| 256 |
|
| 257 |
# set local rank for distributed training
|
| 258 |
-
config.LOCAL_RANK =
|
| 259 |
|
| 260 |
# output folder
|
| 261 |
config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
|
|
|
|
| 24 |
_C.DATA.DATA_PATH = ''
|
| 25 |
# Dataset name
|
| 26 |
_C.DATA.DATASET = 'imagenet'
|
| 27 |
+
# Dataset root folder
|
| 28 |
+
_C.DATA.DATASET_ROOT = None
|
| 29 |
# Input image size
|
| 30 |
_C.DATA.IMG_SIZE = 224
|
| 31 |
# Interpolation to resize image (random, bilinear, bicubic)
|
|
|
|
| 76 |
_C.MODEL.PRETRAINED = None
|
| 77 |
_C.MODEL.DORP_HEAD = True
|
| 78 |
_C.MODEL.DORP_META = True
|
| 79 |
+
_C.MODEL.FREEZE_BACKBONE = True
|
| 80 |
|
| 81 |
_C.MODEL.ONLY_LAST_CLS = False
|
| 82 |
_C.MODEL.EXTRA_TOKEN_NUM = 1
|
|
|
|
| 258 |
config.MODEL.PRETRAINED = args.pretrain
|
| 259 |
|
| 260 |
# set local rank for distributed training
|
| 261 |
+
config.LOCAL_RANK = os.environ['LOCAL_RANK']
|
| 262 |
|
| 263 |
# output folder
|
| 264 |
config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
|
data/build.py
CHANGED
|
@@ -13,7 +13,7 @@ from torchvision import datasets, transforms
|
|
| 13 |
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 14 |
from timm.data import Mixup
|
| 15 |
from timm.data import create_transform
|
| 16 |
-
from timm.data.transforms import
|
| 17 |
|
| 18 |
from .cached_image_folder import CachedImageFolder
|
| 19 |
from .samplers import SubsetRandomSampler
|
|
@@ -81,50 +81,41 @@ def build_dataset(is_train, config):
|
|
| 81 |
# root = os.path.join(config.DATA.DATA_PATH, prefix)
|
| 82 |
root = './datasets/imagenet'
|
| 83 |
dataset = datasets.ImageFolder(root, transform=transform)
|
| 84 |
-
nb_classes = 1000
|
| 85 |
elif config.DATA.DATASET == 'inaturelist2021':
|
| 86 |
root = './datasets/inaturelist2021'
|
| 87 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
| 88 |
-
nb_classes = 10000
|
| 89 |
elif config.DATA.DATASET == 'inaturelist2021_mini':
|
| 90 |
root = './datasets/inaturelist2021_mini'
|
| 91 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
| 92 |
-
nb_classes = 10000
|
| 93 |
elif config.DATA.DATASET == 'inaturelist2017':
|
| 94 |
root = './datasets/inaturelist2017'
|
| 95 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
| 96 |
-
nb_classes = 5089
|
| 97 |
elif config.DATA.DATASET == 'inaturelist2018':
|
| 98 |
root = './datasets/inaturelist2018'
|
| 99 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
| 100 |
-
nb_classes = 8142
|
| 101 |
elif config.DATA.DATASET == 'cub-200':
|
| 102 |
root = './datasets/cub-200'
|
| 103 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
| 104 |
-
nb_classes = 200
|
| 105 |
elif config.DATA.DATASET == 'stanfordcars':
|
| 106 |
root = './datasets/stanfordcars'
|
| 107 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
| 108 |
-
nb_classes = 196
|
| 109 |
elif config.DATA.DATASET == 'oxfordflower':
|
| 110 |
root = './datasets/oxfordflower'
|
| 111 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
| 112 |
-
nb_classes = 102
|
| 113 |
elif config.DATA.DATASET == 'stanforddogs':
|
| 114 |
root = './datasets/stanforddogs'
|
| 115 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
| 116 |
-
nb_classes = 120
|
| 117 |
elif config.DATA.DATASET == 'nabirds':
|
| 118 |
root = './datasets/nabirds'
|
| 119 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
| 120 |
-
nb_classes = 555
|
| 121 |
elif config.DATA.DATASET == 'aircraft':
|
| 122 |
root = './datasets/aircraft'
|
| 123 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
| 124 |
-
nb_classes = 100
|
| 125 |
else:
|
| 126 |
-
|
|
|
|
| 127 |
|
|
|
|
| 128 |
return dataset, nb_classes
|
| 129 |
|
| 130 |
|
|
@@ -153,14 +144,14 @@ def build_transform(is_train, config):
|
|
| 153 |
if config.TEST.CROP:
|
| 154 |
size = int((256 / 224) * config.DATA.IMG_SIZE)
|
| 155 |
t.append(
|
| 156 |
-
transforms.Resize(size, interpolation=
|
| 157 |
# to maintain same ratio w.r.t. 224 images
|
| 158 |
)
|
| 159 |
t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
|
| 160 |
else:
|
| 161 |
t.append(
|
| 162 |
transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
|
| 163 |
-
interpolation=
|
| 164 |
)
|
| 165 |
|
| 166 |
t.append(transforms.ToTensor())
|
|
|
|
| 13 |
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 14 |
from timm.data import Mixup
|
| 15 |
from timm.data import create_transform
|
| 16 |
+
from timm.data.transforms import str_to_interp_mode
|
| 17 |
|
| 18 |
from .cached_image_folder import CachedImageFolder
|
| 19 |
from .samplers import SubsetRandomSampler
|
|
|
|
| 81 |
# root = os.path.join(config.DATA.DATA_PATH, prefix)
|
| 82 |
root = './datasets/imagenet'
|
| 83 |
dataset = datasets.ImageFolder(root, transform=transform)
|
|
|
|
| 84 |
elif config.DATA.DATASET == 'inaturelist2021':
|
| 85 |
root = './datasets/inaturelist2021'
|
| 86 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
|
|
|
| 87 |
elif config.DATA.DATASET == 'inaturelist2021_mini':
|
| 88 |
root = './datasets/inaturelist2021_mini'
|
| 89 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
|
|
|
| 90 |
elif config.DATA.DATASET == 'inaturelist2017':
|
| 91 |
root = './datasets/inaturelist2017'
|
| 92 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
|
|
|
| 93 |
elif config.DATA.DATASET == 'inaturelist2018':
|
| 94 |
root = './datasets/inaturelist2018'
|
| 95 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
|
|
|
| 96 |
elif config.DATA.DATASET == 'cub-200':
|
| 97 |
root = './datasets/cub-200'
|
| 98 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
|
|
|
| 99 |
elif config.DATA.DATASET == 'stanfordcars':
|
| 100 |
root = './datasets/stanfordcars'
|
| 101 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
|
|
|
| 102 |
elif config.DATA.DATASET == 'oxfordflower':
|
| 103 |
root = './datasets/oxfordflower'
|
| 104 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
|
|
|
| 105 |
elif config.DATA.DATASET == 'stanforddogs':
|
| 106 |
root = './datasets/stanforddogs'
|
| 107 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
|
|
|
| 108 |
elif config.DATA.DATASET == 'nabirds':
|
| 109 |
root = './datasets/nabirds'
|
| 110 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
|
|
|
| 111 |
elif config.DATA.DATASET == 'aircraft':
|
| 112 |
root = './datasets/aircraft'
|
| 113 |
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
|
|
|
| 114 |
else:
|
| 115 |
+
root = config.DATA.DATASET_ROOT
|
| 116 |
+
dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
|
| 117 |
|
| 118 |
+
nb_classes = len(dataset.class_to_idx)
|
| 119 |
return dataset, nb_classes
|
| 120 |
|
| 121 |
|
|
|
|
| 144 |
if config.TEST.CROP:
|
| 145 |
size = int((256 / 224) * config.DATA.IMG_SIZE)
|
| 146 |
t.append(
|
| 147 |
+
transforms.Resize(size, interpolation=str_to_interp_mode(config.DATA.INTERPOLATION)),
|
| 148 |
# to maintain same ratio w.r.t. 224 images
|
| 149 |
)
|
| 150 |
t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
|
| 151 |
else:
|
| 152 |
t.append(
|
| 153 |
transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
|
| 154 |
+
interpolation=str_to_interp_mode(config.DATA.INTERPOLATION))
|
| 155 |
)
|
| 156 |
|
| 157 |
t.append(transforms.ToTensor())
|
data/dataset_fg.py
CHANGED
|
@@ -10,6 +10,7 @@ import pickle
|
|
| 10 |
import numpy as np
|
| 11 |
import pandas as pd
|
| 12 |
import random
|
|
|
|
| 13 |
random.seed(2021)
|
| 14 |
from PIL import Image
|
| 15 |
from scipy import io as scio
|
|
@@ -335,7 +336,7 @@ def find_images_and_targets_2017_2018(root,dataset,istrain=False,aux_info=False)
|
|
| 335 |
else:
|
| 336 |
images_and_targets.append((file_path,target))
|
| 337 |
return images_and_targets,class_to_idx,images_info
|
| 338 |
-
def find_images_and_targets(root,istrain=False,aux_info=False):
|
| 339 |
if os.path.exists(os.path.join(root,'train.json')):
|
| 340 |
with open(os.path.join(root,'train.json'),'r') as f:
|
| 341 |
train_class_info = json.load(f)
|
|
@@ -343,24 +344,59 @@ def find_images_and_targets(root,istrain=False,aux_info=False):
|
|
| 343 |
with open(os.path.join(root,'train_mini.json'),'r') as f:
|
| 344 |
train_class_info = json.load(f)
|
| 345 |
else:
|
| 346 |
-
raise ValueError(f'
|
|
|
|
| 347 |
with open(os.path.join(root,'val.json'),'r') as f:
|
| 348 |
val_class_info = json.load(f)
|
| 349 |
-
|
| 350 |
-
|
|
|
|
| 351 |
id2label = dict()
|
| 352 |
for categorie in train_class_info['categories']:
|
| 353 |
id2label[int(categorie['id'])] = categorie['name'].strip().lower()
|
| 354 |
class_info = train_class_info if istrain else val_class_info
|
| 355 |
-
|
| 356 |
images_and_targets = []
|
| 357 |
images_info = []
|
| 358 |
if aux_info:
|
| 359 |
temporal_info = []
|
| 360 |
spatial_info = []
|
| 361 |
|
| 362 |
-
|
| 363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
id_name = id2label[int(annotation['category_id'])]
|
| 365 |
target = class_to_idx[id_name]
|
| 366 |
date = image['date']
|
|
@@ -389,13 +425,15 @@ class DatasetMeta(data.Dataset):
|
|
| 389 |
transform=None,
|
| 390 |
train=False,
|
| 391 |
aux_info=False,
|
| 392 |
-
dataset='
|
| 393 |
class_ratio=1.0,
|
| 394 |
per_sample=1.0):
|
| 395 |
self.aux_info = aux_info
|
| 396 |
self.dataset = dataset
|
| 397 |
if dataset in ['inaturelist2021','inaturelist2021_mini']:
|
| 398 |
images, class_to_idx,images_info = find_images_and_targets(root,train,aux_info)
|
|
|
|
|
|
|
| 399 |
elif dataset in ['inaturelist2017','inaturelist2018']:
|
| 400 |
images, class_to_idx,images_info = find_images_and_targets_2017_2018(root,dataset,train,aux_info)
|
| 401 |
elif dataset == 'cub-200':
|
|
@@ -427,7 +465,12 @@ class DatasetMeta(data.Dataset):
|
|
| 427 |
path, target,aux_info = self.samples[index]
|
| 428 |
else:
|
| 429 |
path, target = self.samples[index]
|
| 430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
if self.transform is not None:
|
| 432 |
img = self.transform(img)
|
| 433 |
if self.aux_info:
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
import pandas as pd
|
| 12 |
import random
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
random.seed(2021)
|
| 15 |
from PIL import Image
|
| 16 |
from scipy import io as scio
|
|
|
|
| 336 |
else:
|
| 337 |
images_and_targets.append((file_path,target))
|
| 338 |
return images_and_targets,class_to_idx,images_info
|
| 339 |
+
def find_images_and_targets(root,istrain=False,aux_info=False, integrity_check=False):
|
| 340 |
if os.path.exists(os.path.join(root,'train.json')):
|
| 341 |
with open(os.path.join(root,'train.json'),'r') as f:
|
| 342 |
train_class_info = json.load(f)
|
|
|
|
| 344 |
with open(os.path.join(root,'train_mini.json'),'r') as f:
|
| 345 |
train_class_info = json.load(f)
|
| 346 |
else:
|
| 347 |
+
raise ValueError(f'{root}/train.json or {root}/train_mini.json doesn\'t exist')
|
| 348 |
+
|
| 349 |
with open(os.path.join(root,'val.json'),'r') as f:
|
| 350 |
val_class_info = json.load(f)
|
| 351 |
+
|
| 352 |
+
categories = [x['name'].strip().lower() for x in val_class_info['categories']]
|
| 353 |
+
class_to_idx = {c: idx for idx, c in enumerate(categories)}
|
| 354 |
id2label = dict()
|
| 355 |
for categorie in train_class_info['categories']:
|
| 356 |
id2label[int(categorie['id'])] = categorie['name'].strip().lower()
|
| 357 |
class_info = train_class_info if istrain else val_class_info
|
| 358 |
+
image_subdir = "train" if istrain else "val"
|
| 359 |
images_and_targets = []
|
| 360 |
images_info = []
|
| 361 |
if aux_info:
|
| 362 |
temporal_info = []
|
| 363 |
spatial_info = []
|
| 364 |
|
| 365 |
+
ann2im = {}
|
| 366 |
+
for ann in class_info['annotations']:
|
| 367 |
+
ann2im[ann['id']] = ann['image_id']
|
| 368 |
+
|
| 369 |
+
ims = {}
|
| 370 |
+
for image in class_info['images']:
|
| 371 |
+
ims[image['id']] = image
|
| 372 |
+
|
| 373 |
+
print("Found", len(train_class_info['categories']))
|
| 374 |
+
print("Loading images and targets, checking image integrity")
|
| 375 |
+
|
| 376 |
+
for annotation in tqdm(class_info['annotations']):
|
| 377 |
+
|
| 378 |
+
image = ims[annotation['image_id']]
|
| 379 |
+
dir = train_class_info['categories'][annotation['category_id']]['image_dir_name']
|
| 380 |
+
|
| 381 |
+
file_path = os.path.join(root,image_subdir,dir,image['file_name'])
|
| 382 |
+
|
| 383 |
+
if not os.path.exists(file_path):
|
| 384 |
+
|
| 385 |
+
continue
|
| 386 |
+
|
| 387 |
+
print(f"Download {file_path}")
|
| 388 |
+
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
| 389 |
+
import requests
|
| 390 |
+
with open(file_path, 'wb') as fp:
|
| 391 |
+
fp.write(requests.get(image['inaturalist_url']).content)
|
| 392 |
+
|
| 393 |
+
if integrity_check:
|
| 394 |
+
try:
|
| 395 |
+
_ = np.array(Image.open(file_path))
|
| 396 |
+
except:
|
| 397 |
+
print(f"Failed to open {file_path}")
|
| 398 |
+
continue
|
| 399 |
+
|
| 400 |
id_name = id2label[int(annotation['category_id'])]
|
| 401 |
target = class_to_idx[id_name]
|
| 402 |
date = image['date']
|
|
|
|
| 425 |
transform=None,
|
| 426 |
train=False,
|
| 427 |
aux_info=False,
|
| 428 |
+
dataset='coco_generic',
|
| 429 |
class_ratio=1.0,
|
| 430 |
per_sample=1.0):
|
| 431 |
self.aux_info = aux_info
|
| 432 |
self.dataset = dataset
|
| 433 |
if dataset in ['inaturelist2021','inaturelist2021_mini']:
|
| 434 |
images, class_to_idx,images_info = find_images_and_targets(root,train,aux_info)
|
| 435 |
+
elif dataset in ['coco_generic']:
|
| 436 |
+
images, class_to_idx,images_info = find_images_and_targets(root,train,aux_info)
|
| 437 |
elif dataset in ['inaturelist2017','inaturelist2018']:
|
| 438 |
images, class_to_idx,images_info = find_images_and_targets_2017_2018(root,dataset,train,aux_info)
|
| 439 |
elif dataset == 'cub-200':
|
|
|
|
| 465 |
path, target,aux_info = self.samples[index]
|
| 466 |
else:
|
| 467 |
path, target = self.samples[index]
|
| 468 |
+
|
| 469 |
+
try:
|
| 470 |
+
img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
|
| 471 |
+
except:
|
| 472 |
+
img = Image.fromarray(np.zeros((224,224,3), dtype=np.uint8))
|
| 473 |
+
|
| 474 |
if self.transform is not None:
|
| 475 |
img = self.transform(img)
|
| 476 |
if self.aux_info:
|
inference.py
CHANGED
|
@@ -7,6 +7,10 @@ from torch.autograd import Variable
|
|
| 7 |
from torchvision.transforms import transforms
|
| 8 |
import numpy as np
|
| 9 |
import argparse
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
try:
|
| 12 |
from apex import amp
|
|
@@ -34,24 +38,32 @@ def read_class_names(file_path):
|
|
| 34 |
class_list = []
|
| 35 |
|
| 36 |
for l in lines:
|
| 37 |
-
line = l.strip()
|
| 38 |
# class_list.append(line[0])
|
| 39 |
-
class_list.append(line
|
| 40 |
|
| 41 |
classes = tuple(class_list)
|
| 42 |
return classes
|
| 43 |
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
| 50 |
self.model = AutoModel.from_pretrained("bert-base-uncased")
|
| 51 |
|
| 52 |
-
def generate(self):
|
| 53 |
text_list = []
|
| 54 |
-
with open(
|
| 55 |
for line in f_text:
|
| 56 |
line = line.encode(encoding='UTF-8', errors='strict')
|
| 57 |
line = line.replace(b'\xef\xbf\xbd\xef\xbf\xbd', b' ')
|
|
@@ -69,57 +81,122 @@ class GenerateEmbedding:
|
|
| 69 |
|
| 70 |
|
| 71 |
class Inference:
|
| 72 |
-
def __init__(self, config_path, model_path):
|
|
|
|
| 73 |
self.config_path = config_path
|
| 74 |
self.model_path = model_path
|
| 75 |
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 76 |
-
|
| 77 |
-
self.classes = read_class_names(r"D:\dataset\CUB_200_2011\CUB_200_2011\classes_custom.txt")
|
| 78 |
|
| 79 |
self.config = model_config(self.config_path)
|
|
|
|
| 80 |
self.model = build_model(self.config)
|
| 81 |
self.checkpoint = torch.load(self.model_path, map_location='cpu')
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
self.model.eval()
|
| 84 |
-
self.model.
|
|
|
|
|
|
|
| 85 |
|
| 86 |
self.transform_img = transforms.Compose([
|
| 87 |
-
transforms.Resize((
|
| 88 |
transforms.ToTensor(), # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
| 89 |
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
| 90 |
])
|
| 91 |
|
| 92 |
-
def infer(self, img_path, meta_data_path):
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
img = self.transform_img(img)
|
| 97 |
img.unsqueeze_(0)
|
| 98 |
-
img = img.
|
| 99 |
img = Variable(img).to(self.device)
|
| 100 |
out = self.model(img, meta)
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
|
| 108 |
def parse_option():
|
| 109 |
parser = argparse.ArgumentParser('MetaFG Inference script', add_help=False)
|
| 110 |
-
parser.add_argument('--cfg', type=str,
|
| 111 |
# easy config modification
|
| 112 |
-
parser.add_argument('--model-path',
|
| 113 |
-
parser.add_argument('--img-path',
|
| 114 |
-
parser.add_argument('--
|
|
|
|
|
|
|
| 115 |
args = parser.parse_args()
|
| 116 |
return args
|
| 117 |
|
| 118 |
|
| 119 |
if __name__ == '__main__':
|
| 120 |
args = parse_option()
|
| 121 |
-
|
| 122 |
-
model_path=args.model_path
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
# Usage: python inference.py --cfg 'path/to/cfg' --model_path 'path/to/model' --img-path 'path/to/img' --meta-path 'path/to/meta'
|
|
|
|
| 7 |
from torchvision.transforms import transforms
|
| 8 |
import numpy as np
|
| 9 |
import argparse
|
| 10 |
+
from pycocotools.coco import COCO
|
| 11 |
+
import requests
|
| 12 |
+
import os
|
| 13 |
+
from tqdm.auto import tqdm
|
| 14 |
|
| 15 |
try:
|
| 16 |
from apex import amp
|
|
|
|
| 38 |
class_list = []
|
| 39 |
|
| 40 |
for l in lines:
|
| 41 |
+
line = l.strip()
|
| 42 |
# class_list.append(line[0])
|
| 43 |
+
class_list.append(line)
|
| 44 |
|
| 45 |
classes = tuple(class_list)
|
| 46 |
return classes
|
| 47 |
|
| 48 |
|
| 49 |
+
def read_class_names_coco(file_path):
|
| 50 |
+
dataset = COCO(file_path)
|
| 51 |
+
classes = [dataset.cats[k]['name'] for k in sorted(dataset.cats.keys())]
|
| 52 |
+
|
| 53 |
+
with open("names.txt", 'w') as fp:
|
| 54 |
+
for c in classes:
|
| 55 |
+
fp.write(f"{c}\n")
|
| 56 |
|
| 57 |
+
return classes
|
| 58 |
+
|
| 59 |
+
class GenerateEmbedding:
|
| 60 |
+
def __init__(self):
|
| 61 |
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
| 62 |
self.model = AutoModel.from_pretrained("bert-base-uncased")
|
| 63 |
|
| 64 |
+
def generate(self, text_file):
|
| 65 |
text_list = []
|
| 66 |
+
with open(text_file, 'r') as f_text:
|
| 67 |
for line in f_text:
|
| 68 |
line = line.encode(encoding='UTF-8', errors='strict')
|
| 69 |
line = line.replace(b'\xef\xbf\xbd\xef\xbf\xbd', b' ')
|
|
|
|
| 81 |
|
| 82 |
|
| 83 |
class Inference:
|
| 84 |
+
def __init__(self, config_path, model_path, names_path):
|
| 85 |
+
|
| 86 |
self.config_path = config_path
|
| 87 |
self.model_path = model_path
|
| 88 |
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 89 |
+
self.classes = read_class_names(names_path)
|
|
|
|
| 90 |
|
| 91 |
self.config = model_config(self.config_path)
|
| 92 |
+
|
| 93 |
self.model = build_model(self.config)
|
| 94 |
self.checkpoint = torch.load(self.model_path, map_location='cpu')
|
| 95 |
+
|
| 96 |
+
if 'model' in self.checkpoint:
|
| 97 |
+
self.model.load_state_dict(self.checkpoint['model'], strict=False)
|
| 98 |
+
else:
|
| 99 |
+
self.model.load_state_dict(self.checkpoint, strict=False)
|
| 100 |
+
|
| 101 |
self.model.eval()
|
| 102 |
+
self.model.to(self.device)
|
| 103 |
+
self.topk = 10
|
| 104 |
+
self.embedding_gen = GenerateEmbedding()
|
| 105 |
|
| 106 |
self.transform_img = transforms.Compose([
|
| 107 |
+
transforms.Resize((self.config.DATA.IMG_SIZE, self.config.DATA.IMG_SIZE), interpolation=Image.BILINEAR),
|
| 108 |
transforms.ToTensor(), # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
| 109 |
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
| 110 |
])
|
| 111 |
|
| 112 |
+
def infer(self, img_path, meta_data_path, topk=None):
|
| 113 |
+
|
| 114 |
+
if isinstance(img_path, str):
|
| 115 |
+
if img_path.startswith("http"):
|
| 116 |
+
img = Image.open(requests.get(img_path, stream=True).raw).convert('RGB')
|
| 117 |
+
else:
|
| 118 |
+
img = Image.open(img_path).convert('RGB')
|
| 119 |
+
else:
|
| 120 |
+
img = img_path
|
| 121 |
+
|
| 122 |
+
"""
|
| 123 |
+
_, _, meta = self.embedding_gen(meta_data_path)
|
| 124 |
+
meta = meta.to(self.device)
|
| 125 |
+
"""
|
| 126 |
+
meta = None
|
| 127 |
+
|
| 128 |
img = self.transform_img(img)
|
| 129 |
img.unsqueeze_(0)
|
| 130 |
+
img = img.to(self.device)
|
| 131 |
img = Variable(img).to(self.device)
|
| 132 |
out = self.model(img, meta)
|
| 133 |
|
| 134 |
+
f = torch.nn.Softmax(dim=1)
|
| 135 |
+
y_pred = f(out)
|
| 136 |
+
indices = reversed(torch.argsort(y_pred, dim=1).squeeze().tolist())
|
| 137 |
+
|
| 138 |
+
if topk is not None:
|
| 139 |
+
predict = [{self.classes[idx] : y_pred.squeeze()[idx].cpu().item() for idx in indices[:topk]}]
|
| 140 |
+
return predict
|
| 141 |
+
else:
|
| 142 |
+
return {self.classes[idx] : y_pred.squeeze()[idx].cpu().item() for idx in indices}
|
| 143 |
|
| 144 |
|
| 145 |
def parse_option():
|
| 146 |
parser = argparse.ArgumentParser('MetaFG Inference script', add_help=False)
|
| 147 |
+
parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', default="configs/MetaFG_2_224.yaml")
|
| 148 |
# easy config modification
|
| 149 |
+
parser.add_argument('--model-path', type=str, help="path to model data", default="ckpt_epoch_12.pth")
|
| 150 |
+
parser.add_argument('--img-path', type=str, help='path to image')
|
| 151 |
+
parser.add_argument('--img-folder', type=str, help='path to image')
|
| 152 |
+
parser.add_argument('--meta-path', default="meta.txt", type=str, help='path to meta data')
|
| 153 |
+
parser.add_argument('--names-path', default="names_mf2.txt", type=str, help='path to meta data')
|
| 154 |
args = parser.parse_args()
|
| 155 |
return args
|
| 156 |
|
| 157 |
|
| 158 |
if __name__ == '__main__':
|
| 159 |
args = parse_option()
|
| 160 |
+
model = Inference(config_path=args.cfg,
|
| 161 |
+
model_path=args.model_path,
|
| 162 |
+
names_path=args.names_path)
|
| 163 |
+
|
| 164 |
+
from glob import glob
|
| 165 |
+
glob_imgs = glob(os.path.join(args.img_folder, "*.jpg"))
|
| 166 |
+
out_dir = f"results_{os.path.splitext(os.path.basename(args.model_path))[0]}"
|
| 167 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 168 |
+
|
| 169 |
+
for img in tqdm(glob_imgs):
|
| 170 |
+
try:
|
| 171 |
+
res = model.infer(img_path=img, meta_data_path=args.meta_path)
|
| 172 |
+
except KeyboardInterrupt:
|
| 173 |
+
break
|
| 174 |
+
except Exception as e:
|
| 175 |
+
print(e)
|
| 176 |
+
continue
|
| 177 |
+
|
| 178 |
+
out = {}
|
| 179 |
+
out['preds'] = res
|
| 180 |
+
|
| 181 |
+
"""
|
| 182 |
+
# Out is a list of (class, score). Return true/false if the top1 class is correct
|
| 183 |
+
out['top1_correct'] = '_'.join(res[0][1].split(' ')).lower() in os.path.basename(img).lower()
|
| 184 |
+
|
| 185 |
+
out['top5_correct'] = False
|
| 186 |
+
print(os.path.basename(img).lower())
|
| 187 |
+
for i in range(5):
|
| 188 |
+
out['top5_correct'] |= '_'.join(res[i][1].split(' ')).lower() in os.path.basename(img).lower()
|
| 189 |
+
print('_'.join(res[i][1].split(' ')).lower())
|
| 190 |
+
|
| 191 |
+
out['top10_correct'] = False
|
| 192 |
+
for i in range(10):
|
| 193 |
+
out['top10_correct'] |= '_'.join(res[i][1].split(' ')).lower() in os.path.basename(img).lower()
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
# output json with inference results, use image basename
|
| 197 |
+
# as filename
|
| 198 |
+
import json
|
| 199 |
+
with open(os.path.join(out_dir, os.path.splitext(os.path.basename(img))[0]+".json"), 'w') as fp:
|
| 200 |
+
json.dump(out, fp, indent=1)
|
| 201 |
|
| 202 |
# Usage: python inference.py --cfg 'path/to/cfg' --model_path 'path/to/model' --img-path 'path/to/img' --meta-path 'path/to/meta'
|
lr_scheduler.py
CHANGED
|
@@ -21,7 +21,6 @@ def build_scheduler(config, optimizer, n_iter_per_epoch):
|
|
| 21 |
lr_scheduler = CosineLRScheduler(
|
| 22 |
optimizer,
|
| 23 |
t_initial=num_steps,
|
| 24 |
-
t_mul=1.,
|
| 25 |
lr_min=config.TRAIN.MIN_LR,
|
| 26 |
warmup_lr_init=config.TRAIN.WARMUP_LR,
|
| 27 |
warmup_t=warmup_steps,
|
|
|
|
| 21 |
lr_scheduler = CosineLRScheduler(
|
| 22 |
optimizer,
|
| 23 |
t_initial=num_steps,
|
|
|
|
| 24 |
lr_min=config.TRAIN.MIN_LR,
|
| 25 |
warmup_lr_init=config.TRAIN.WARMUP_LR,
|
| 26 |
warmup_t=warmup_steps,
|
main.py
CHANGED
|
@@ -2,7 +2,9 @@ import os
|
|
| 2 |
import time
|
| 3 |
import argparse
|
| 4 |
import datetime
|
|
|
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.backends.cudnn as cudnn
|
|
@@ -18,13 +20,23 @@ from lr_scheduler import build_scheduler
|
|
| 18 |
from optimizer import build_optimizer
|
| 19 |
from logger import create_logger
|
| 20 |
from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor,load_pretained
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
try:
|
| 23 |
# noinspection PyUnresolvedReferences
|
| 24 |
from apex import amp
|
| 25 |
except ImportError:
|
| 26 |
amp = None
|
| 27 |
|
|
|
|
|
|
|
| 28 |
|
| 29 |
def parse_option():
|
| 30 |
parser = argparse.ArgumentParser('MetaFG training and evaluation script', add_help=False)
|
|
@@ -77,20 +89,19 @@ def parse_option():
|
|
| 77 |
help='dataset')
|
| 78 |
parser.add_argument('--lr-scheduler-name', type=str,
|
| 79 |
help='lr scheduler name,cosin linear,step')
|
| 80 |
-
|
| 81 |
parser.add_argument('--pretrain', type=str,
|
| 82 |
help='pretrain')
|
| 83 |
|
| 84 |
-
parser.add_argument('--
|
| 85 |
-
|
| 86 |
|
| 87 |
-
# distributed training
|
| 88 |
-
parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
|
| 89 |
-
|
| 90 |
args, unparsed = parser.parse_known_args()
|
| 91 |
|
| 92 |
config = get_config(args)
|
| 93 |
|
|
|
|
|
|
|
|
|
|
| 94 |
return args, config
|
| 95 |
|
| 96 |
|
|
@@ -98,14 +109,20 @@ def main(config):
|
|
| 98 |
dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
|
| 99 |
logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
|
| 100 |
model = build_model(config)
|
|
|
|
|
|
|
| 101 |
model.cuda()
|
| 102 |
logger.info(str(model))
|
| 103 |
|
| 104 |
optimizer = build_optimizer(config, model)
|
| 105 |
if config.AMP_OPT_LEVEL != "O0":
|
| 106 |
model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL)
|
| 107 |
-
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
|
| 108 |
model_without_ddp = model.module
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 110 |
logger.info(f"number of params: {n_parameters}")
|
| 111 |
if hasattr(model_without_ddp, 'flops'):
|
|
@@ -123,10 +140,15 @@ def main(config):
|
|
| 123 |
max_accuracy = 0.0
|
| 124 |
if config.MODEL.PRETRAINED:
|
| 125 |
load_pretained(config,model_without_ddp,logger)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
if config.TRAIN.AUTO_RESUME:
|
| 132 |
resume_file = auto_resume_helper(config.OUTPUT)
|
|
@@ -143,11 +165,11 @@ def main(config):
|
|
| 143 |
if config.MODEL.RESUME:
|
| 144 |
logger.info(f"**********normal test***********")
|
| 145 |
max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger)
|
| 146 |
-
acc1, acc5, loss = validate(config, data_loader_val, model)
|
| 147 |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
|
| 148 |
if config.DATA.ADD_META:
|
| 149 |
logger.info(f"**********mask meta test***********")
|
| 150 |
-
acc1, acc5, loss = validate(config, data_loader_val, model,mask_meta=True)
|
| 151 |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
|
| 152 |
if config.EVAL_MODE:
|
| 153 |
return
|
|
@@ -165,18 +187,37 @@ def main(config):
|
|
| 165 |
save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger)
|
| 166 |
|
| 167 |
logger.info(f"**********normal test***********")
|
| 168 |
-
acc1, acc5, loss = validate(config, data_loader_val, model)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
|
| 170 |
max_accuracy = max(max_accuracy, acc1)
|
| 171 |
logger.info(f'Max accuracy: {max_accuracy:.2f}%')
|
| 172 |
if config.DATA.ADD_META:
|
| 173 |
logger.info(f"**********mask meta test***********")
|
| 174 |
-
acc1, acc5, loss = validate(config, data_loader_val, model,mask_meta=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
|
| 176 |
# data_loader_train.terminate()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
total_time = time.time() - start_time
|
| 178 |
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 179 |
logger.info('Training time {}'.format(total_time_str))
|
|
|
|
| 180 |
def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler,tb_logger=None):
|
| 181 |
model.train()
|
| 182 |
if hasattr(model.module,'cur_epoch'):
|
|
@@ -261,6 +302,8 @@ def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer,
|
|
| 261 |
lr = optimizer.param_groups[0]['lr']
|
| 262 |
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
|
| 263 |
etas = batch_time.avg * (num_steps - idx)
|
|
|
|
|
|
|
| 264 |
logger.info(
|
| 265 |
f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
|
| 266 |
f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
|
|
@@ -271,7 +314,7 @@ def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer,
|
|
| 271 |
epoch_time = time.time() - start
|
| 272 |
logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
|
| 273 |
@torch.no_grad()
|
| 274 |
-
def validate(config, data_loader, model, mask_meta=False):
|
| 275 |
criterion = torch.nn.CrossEntropyLoss()
|
| 276 |
model.eval()
|
| 277 |
|
|
@@ -280,8 +323,16 @@ def validate(config, data_loader, model, mask_meta=False):
|
|
| 280 |
acc1_meter = AverageMeter()
|
| 281 |
acc5_meter = AverageMeter()
|
| 282 |
|
|
|
|
|
|
|
| 283 |
end = time.time()
|
|
|
|
| 284 |
for idx, data in enumerate(data_loader):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
if config.DATA.ADD_META:
|
| 286 |
images,target,meta = data
|
| 287 |
meta = [m.float() for m in meta]
|
|
@@ -314,6 +365,9 @@ def validate(config, data_loader, model, mask_meta=False):
|
|
| 314 |
acc1_meter.update(acc1.item(), target.size(0))
|
| 315 |
acc5_meter.update(acc5.item(), target.size(0))
|
| 316 |
|
|
|
|
|
|
|
|
|
|
| 317 |
# measure elapsed time
|
| 318 |
batch_time.update(time.time() - end)
|
| 319 |
end = time.time()
|
|
@@ -328,7 +382,8 @@ def validate(config, data_loader, model, mask_meta=False):
|
|
| 328 |
f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
|
| 329 |
f'Mem {memory_used:.0f}MB')
|
| 330 |
logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
|
| 331 |
-
|
|
|
|
| 332 |
|
| 333 |
|
| 334 |
@torch.no_grad()
|
|
@@ -364,7 +419,7 @@ if __name__ == '__main__':
|
|
| 364 |
else:
|
| 365 |
rank = -1
|
| 366 |
world_size = -1
|
| 367 |
-
torch.cuda.set_device(config.LOCAL_RANK)
|
| 368 |
torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
|
| 369 |
torch.distributed.barrier()
|
| 370 |
|
|
|
|
| 2 |
import time
|
| 3 |
import argparse
|
| 4 |
import datetime
|
| 5 |
+
import json
|
| 6 |
import numpy as np
|
| 7 |
+
from collections import defaultdict
|
| 8 |
|
| 9 |
import torch
|
| 10 |
import torch.backends.cudnn as cudnn
|
|
|
|
| 20 |
from optimizer import build_optimizer
|
| 21 |
from logger import create_logger
|
| 22 |
from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor,load_pretained
|
| 23 |
+
|
| 24 |
+
have_wandb = False
|
| 25 |
+
try:
|
| 26 |
+
import wandb
|
| 27 |
+
have_wandb = True
|
| 28 |
+
except:
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
# TODO use torch.amp
|
| 32 |
try:
|
| 33 |
# noinspection PyUnresolvedReferences
|
| 34 |
from apex import amp
|
| 35 |
except ImportError:
|
| 36 |
amp = None
|
| 37 |
|
| 38 |
+
import logging
|
| 39 |
+
logging.basicConfig(level=logging.INFO)
|
| 40 |
|
| 41 |
def parse_option():
|
| 42 |
parser = argparse.ArgumentParser('MetaFG training and evaluation script', add_help=False)
|
|
|
|
| 89 |
help='dataset')
|
| 90 |
parser.add_argument('--lr-scheduler-name', type=str,
|
| 91 |
help='lr scheduler name,cosin linear,step')
|
| 92 |
+
|
| 93 |
parser.add_argument('--pretrain', type=str,
|
| 94 |
help='pretrain')
|
| 95 |
|
| 96 |
+
parser.add_argument('--wandb_job', type=str)
|
|
|
|
| 97 |
|
|
|
|
|
|
|
|
|
|
| 98 |
args, unparsed = parser.parse_known_args()
|
| 99 |
|
| 100 |
config = get_config(args)
|
| 101 |
|
| 102 |
+
if have_wandb and int(config.LOCAL_RANK) == 0:
|
| 103 |
+
wandb.init(name = args.wandb_job, config=args)
|
| 104 |
+
|
| 105 |
return args, config
|
| 106 |
|
| 107 |
|
|
|
|
| 109 |
dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
|
| 110 |
logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
|
| 111 |
model = build_model(config)
|
| 112 |
+
if have_wandb and int(config.LOCAL_RANK) == 0:
|
| 113 |
+
wandb.config['model_config'] = config
|
| 114 |
model.cuda()
|
| 115 |
logger.info(str(model))
|
| 116 |
|
| 117 |
optimizer = build_optimizer(config, model)
|
| 118 |
if config.AMP_OPT_LEVEL != "O0":
|
| 119 |
model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL)
|
| 120 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[int(config.LOCAL_RANK)], broadcast_buffers=False)
|
| 121 |
model_without_ddp = model.module
|
| 122 |
+
|
| 123 |
+
if have_wandb and int(config.LOCAL_RANK) == 0:
|
| 124 |
+
wandb.watch(model, log_freq=100)
|
| 125 |
+
|
| 126 |
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 127 |
logger.info(f"number of params: {n_parameters}")
|
| 128 |
if hasattr(model_without_ddp, 'flops'):
|
|
|
|
| 140 |
max_accuracy = 0.0
|
| 141 |
if config.MODEL.PRETRAINED:
|
| 142 |
load_pretained(config,model_without_ddp,logger)
|
| 143 |
+
|
| 144 |
+
# Run initial validation
|
| 145 |
+
logger.info("Start validation (on init)")
|
| 146 |
+
acc1, acc5, loss, stats = validate(config, data_loader_val, model, limit=10)
|
| 147 |
+
|
| 148 |
+
with open(os.path.join(config.OUTPUT, f'val_init.json'), 'w') as fp:
|
| 149 |
+
json.dump(stats, fp, indent=1)
|
| 150 |
+
|
| 151 |
+
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
|
| 152 |
|
| 153 |
if config.TRAIN.AUTO_RESUME:
|
| 154 |
resume_file = auto_resume_helper(config.OUTPUT)
|
|
|
|
| 165 |
if config.MODEL.RESUME:
|
| 166 |
logger.info(f"**********normal test***********")
|
| 167 |
max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger)
|
| 168 |
+
acc1, acc5, loss, stats = validate(config, data_loader_val, model)
|
| 169 |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
|
| 170 |
if config.DATA.ADD_META:
|
| 171 |
logger.info(f"**********mask meta test***********")
|
| 172 |
+
acc1, acc5, loss, stats = validate(config, data_loader_val, model,mask_meta=True)
|
| 173 |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
|
| 174 |
if config.EVAL_MODE:
|
| 175 |
return
|
|
|
|
| 187 |
save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger)
|
| 188 |
|
| 189 |
logger.info(f"**********normal test***********")
|
| 190 |
+
acc1, acc5, loss, stats = validate(config, data_loader_val, model)
|
| 191 |
+
|
| 192 |
+
with open(os.path.join(config.OUTPUT, f'val_{epoch}.json'), 'w') as fp:
|
| 193 |
+
json.dump(stats, fp, indent=1)
|
| 194 |
+
|
| 195 |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
|
| 196 |
max_accuracy = max(max_accuracy, acc1)
|
| 197 |
logger.info(f'Max accuracy: {max_accuracy:.2f}%')
|
| 198 |
if config.DATA.ADD_META:
|
| 199 |
logger.info(f"**********mask meta test***********")
|
| 200 |
+
acc1, acc5, loss, stats = validate(config, data_loader_val, model,mask_meta=True)
|
| 201 |
+
|
| 202 |
+
with open(os.path.join(config.OUTPUT, f'val_{epoch}_meta.json'), 'w') as fp:
|
| 203 |
+
json.dump(stats, fp, indent=1)
|
| 204 |
+
|
| 205 |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
|
| 206 |
# data_loader_train.terminate()
|
| 207 |
+
|
| 208 |
+
if have_wandb and int(config.LOCAL_RANK) == 0:
|
| 209 |
+
wandb.run.summary["acc_top_1"] = acc1
|
| 210 |
+
wandb.run.summary["acc_top_5"] = acc5
|
| 211 |
+
wandb.run.summary["val_loss"] = loss
|
| 212 |
+
|
| 213 |
+
wandb.log({'val/acc1': acc1})
|
| 214 |
+
wandb.log({'val/acc5': acc5})
|
| 215 |
+
wandb.log({'val/loss': acc5})
|
| 216 |
+
|
| 217 |
total_time = time.time() - start_time
|
| 218 |
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 219 |
logger.info('Training time {}'.format(total_time_str))
|
| 220 |
+
|
| 221 |
def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler,tb_logger=None):
|
| 222 |
model.train()
|
| 223 |
if hasattr(model.module,'cur_epoch'):
|
|
|
|
| 302 |
lr = optimizer.param_groups[0]['lr']
|
| 303 |
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
|
| 304 |
etas = batch_time.avg * (num_steps - idx)
|
| 305 |
+
if have_wandb and int(config.LOCAL_RANK) == 0 and idx % 100 == 0:
|
| 306 |
+
wandb.log({"train/loss": loss_meter.val})
|
| 307 |
logger.info(
|
| 308 |
f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
|
| 309 |
f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
|
|
|
|
| 314 |
epoch_time = time.time() - start
|
| 315 |
logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
|
| 316 |
@torch.no_grad()
|
| 317 |
+
def validate(config, data_loader, model, mask_meta=False, limit=None):
|
| 318 |
criterion = torch.nn.CrossEntropyLoss()
|
| 319 |
model.eval()
|
| 320 |
|
|
|
|
| 323 |
acc1_meter = AverageMeter()
|
| 324 |
acc5_meter = AverageMeter()
|
| 325 |
|
| 326 |
+
stats = defaultdict(list)
|
| 327 |
+
|
| 328 |
end = time.time()
|
| 329 |
+
|
| 330 |
for idx, data in enumerate(data_loader):
|
| 331 |
+
|
| 332 |
+
if limit:
|
| 333 |
+
if idx > limit:
|
| 334 |
+
break
|
| 335 |
+
|
| 336 |
if config.DATA.ADD_META:
|
| 337 |
images,target,meta = data
|
| 338 |
meta = [m.float() for m in meta]
|
|
|
|
| 365 |
acc1_meter.update(acc1.item(), target.size(0))
|
| 366 |
acc5_meter.update(acc5.item(), target.size(0))
|
| 367 |
|
| 368 |
+
for t in target:
|
| 369 |
+
stats[int(t.item())].append((acc1.item(), acc5.item(), loss.item()))
|
| 370 |
+
|
| 371 |
# measure elapsed time
|
| 372 |
batch_time.update(time.time() - end)
|
| 373 |
end = time.time()
|
|
|
|
| 382 |
f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
|
| 383 |
f'Mem {memory_used:.0f}MB')
|
| 384 |
logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
|
| 385 |
+
|
| 386 |
+
return acc1_meter.avg, acc5_meter.avg, loss_meter.avg, stats
|
| 387 |
|
| 388 |
|
| 389 |
@torch.no_grad()
|
|
|
|
| 419 |
else:
|
| 420 |
rank = -1
|
| 421 |
world_size = -1
|
| 422 |
+
torch.cuda.set_device(f'cuda:{config.LOCAL_RANK}')
|
| 423 |
torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
|
| 424 |
torch.distributed.barrier()
|
| 425 |
|
models/MetaFG.py
CHANGED
|
@@ -54,7 +54,8 @@ class MetaFG(nn.Module):
|
|
| 54 |
qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,drop_path_rate=0.,
|
| 55 |
meta_dims=[],
|
| 56 |
only_last_cls=False,
|
| 57 |
-
use_checkpoint=False
|
|
|
|
| 58 |
super().__init__()
|
| 59 |
self.only_last_cls = only_last_cls
|
| 60 |
self.img_size = img_size
|
|
|
|
| 54 |
qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,drop_path_rate=0.,
|
| 55 |
meta_dims=[],
|
| 56 |
only_last_cls=False,
|
| 57 |
+
use_checkpoint=False,
|
| 58 |
+
**kwargs):
|
| 59 |
super().__init__()
|
| 60 |
self.only_last_cls = only_last_cls
|
| 61 |
self.img_size = img_size
|