smi08's picture
Upload folder using huggingface_hub
188f311 verified
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import random
import numpy as np
from proard.classification.networks import ResNets
# Public API of this module: the two architecture encoders defined below.
__all__ = ["MobileNetArchEncoder", "ResNetArchEncoder"]
class MobileNetArchEncoder:
    """Encode MobileNet-style architecture dicts as one-hot feature vectors.

    An architecture dict has the keys:

    * ``"ks"``: kernel size for each of the ``max_n_blocks`` block positions.
    * ``"e"``: expand ratio for each block position.
    * ``"d"``: active depth for each of the ``n_stage`` stages.
    * ``"image_size"``: input resolution. ``random_sample_arch`` stores it as a
      one-element list while ``feature2arch`` returns a bare int; both forms
      are accepted by ``arch2feature`` and preserved by ``mutate_resolution``.

    Feature layout (length ``n_dim``): one slot per candidate resolution, then
    one slot per kernel size for every block position, then one slot per
    expand ratio for every block position. Block position ``i`` belongs to
    stage ``i // max(depth_list)``; positions beyond their stage's active
    depth leave their kernel/expand slots all zero.
    """

    SPACE_TYPE = "mbv3"

    def __init__(
        self,
        image_size_list=None,
        ks_list=None,
        expand_list=None,
        depth_list=None,
        n_stage=None,
    ):
        """Store the search space and build the one-hot index tables.

        Args:
            image_size_list: candidate resolutions (default ``[224]``).
            ks_list: candidate kernel sizes (default ``[3, 5, 7]``).
            expand_list: candidate expand ratios, each cast to ``int``
                (default ``[3, 4, 6]``).
            depth_list: candidate per-stage depths (default ``[2, 3, 4]``).
            n_stage: number of stages; when omitted it is 6 for ``"mbv2"``
                and 5 for ``"mbv3"``, chosen by ``SPACE_TYPE``.

        Raises:
            NotImplementedError: if ``n_stage`` is omitted and ``SPACE_TYPE``
                is not a known search space.
        """
        self.image_size_list = [224] if image_size_list is None else image_size_list
        self.ks_list = [3, 5, 7] if ks_list is None else ks_list
        self.expand_list = (
            [3, 4, 6]
            if expand_list is None
            else [int(expand) for expand in expand_list]
        )
        self.depth_list = [2, 3, 4] if depth_list is None else depth_list
        if n_stage is not None:
            self.n_stage = n_stage
        elif self.SPACE_TYPE == "mbv2":
            self.n_stage = 6
        elif self.SPACE_TYPE == "mbv3":
            self.n_stage = 5
        else:
            raise NotImplementedError

        # Build the info dicts; slot order is resolution, kernel, expand.
        self.n_dim = 0
        self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
        self._build_info_dict(target="r")

        self.k_info = dict(id2val=[], val2id=[], L=[], R=[])
        self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
        self._build_info_dict(target="k")
        self._build_info_dict(target="e")

    @property
    def max_n_blocks(self):
        """Number of block positions encoded (depends on ``SPACE_TYPE``)."""
        if self.SPACE_TYPE == "mbv3":
            return self.n_stage * max(self.depth_list)
        elif self.SPACE_TYPE == "mbv2":
            return (self.n_stage - 1) * max(self.depth_list) + 1
        else:
            raise NotImplementedError

    @staticmethod
    def _unwrap_image_size(image_size):
        """Return the scalar resolution from an int or a one-element list/tuple."""
        if isinstance(image_size, (list, tuple)):
            return image_size[0]
        return image_size

    def _build_info_dict(self, target):
        """Allocate feature slots for ``target`` and record their index maps.

        ``target`` is ``"r"`` (resolution), ``"k"`` (kernel size) or ``"e"``
        (expand ratio). Each info dict holds ``val2id`` / ``id2val`` maps
        between choice values and absolute slot indices, plus ``L`` / ``R``:
        the start (inclusive) and end (exclusive) slot of each segment.
        Kernel-size and expand-ratio maps are kept per block position.

        Raises:
            NotImplementedError: for an unknown ``target``.
        """
        if target == "r":
            target_dict = self.r_info
            target_dict["L"].append(self.n_dim)
            for img_size in self.image_size_list:
                target_dict["val2id"][img_size] = self.n_dim
                target_dict["id2val"][self.n_dim] = img_size
                self.n_dim += 1
            target_dict["R"].append(self.n_dim)
            return

        if target == "k":
            target_dict, choices = self.k_info, self.ks_list
        elif target == "e":
            target_dict, choices = self.e_info, self.expand_list
        else:
            raise NotImplementedError
        for i in range(self.max_n_blocks):
            target_dict["val2id"].append({})
            target_dict["id2val"].append({})
            target_dict["L"].append(self.n_dim)
            for k in choices:
                target_dict["val2id"][i][k] = self.n_dim
                target_dict["id2val"][i][self.n_dim] = k
                self.n_dim += 1
            target_dict["R"].append(self.n_dim)

    def arch2feature(self, arch_dict):
        """Return the one-hot ``np.ndarray`` of length ``n_dim`` for ``arch_dict``.

        ``arch_dict["image_size"]`` may be an int or a one-element list/tuple,
        so dicts from ``random_sample_arch``, ``feature2arch`` and
        ``mutate_resolution`` all encode correctly.
        """
        ks, e, d, r = (
            arch_dict["ks"],
            arch_dict["e"],
            arch_dict["d"],
            arch_dict["image_size"],
        )
        r = self._unwrap_image_size(r)

        feature = np.zeros(self.n_dim)
        max_depth = max(self.depth_list)
        for i in range(self.max_n_blocks):
            nowd = i % max_depth
            stg = i // max_depth
            # Only block positions within the stage's active depth are encoded.
            if nowd < d[stg]:
                feature[self.k_info["val2id"][i][ks[i]]] = 1
                feature[self.e_info["val2id"][i][e[i]]] = 1
        feature[self.r_info["val2id"][r]] = 1
        return feature

    def feature2arch(self, feature):
        """Decode a feature vector back into an architecture dict.

        The resolution is the argmax of the resolution segment and is returned
        as a bare int. Kernel sizes and expand ratios are read from slots equal
        to 1; block positions with no active slot get ``ks == e == 0`` and are
        not counted toward their stage's depth.

        Raises:
            AssertionError: if the decoded resolution is not a candidate, or an
                expand slot is set for a position with no kernel slot set.
        """
        img_sz = self.r_info["id2val"][
            int(np.argmax(feature[self.r_info["L"][0] : self.r_info["R"][0]]))
            + self.r_info["L"][0]
        ]
        assert img_sz in self.image_size_list
        arch_dict = {"ks": [], "e": [], "d": [], "image_size": img_sz}

        d = 0
        max_depth = max(self.depth_list)
        for i in range(self.max_n_blocks):
            skip = True
            for j in range(self.k_info["L"][i], self.k_info["R"][i]):
                if feature[j] == 1:
                    arch_dict["ks"].append(self.k_info["id2val"][i][j])
                    skip = False
                    break

            for j in range(self.e_info["L"][i], self.e_info["R"][i]):
                if feature[j] == 1:
                    arch_dict["e"].append(self.e_info["id2val"][i][j])
                    # A position with an expand ratio must also have a kernel size.
                    assert not skip
                    skip = False
                    break

            if skip:
                arch_dict["e"].append(0)
                arch_dict["ks"].append(0)
            else:
                d += 1

            # Close the current stage after max_depth positions (or at the end).
            if (i + 1) % max_depth == 0 or (i + 1) == self.max_n_blocks:
                arch_dict["d"].append(d)
                d = 0
        return arch_dict

    def random_sample_arch(self):
        """Sample each entry uniformly at random; ``image_size`` is a one-element list."""
        return {
            "ks": random.choices(self.ks_list, k=self.max_n_blocks),
            "e": random.choices(self.expand_list, k=self.max_n_blocks),
            "d": random.choices(self.depth_list, k=self.n_stage),
            "image_size": [random.choice(self.image_size_list)],
        }

    def mutate_resolution(self, arch_dict, mutate_prob):
        """With probability ``mutate_prob``, resample the resolution in place.

        The new value keeps the container form already in ``arch_dict``: a
        list/tuple stays a one-element list, and an int (or a missing key)
        becomes an int. Returns ``arch_dict``.
        """
        if random.random() < mutate_prob:
            new_size = random.choice(self.image_size_list)
            if isinstance(arch_dict.get("image_size"), (list, tuple)):
                new_size = [new_size]
            arch_dict["image_size"] = new_size
        return arch_dict

    def mutate_arch(self, arch_dict, mutate_prob):
        """Randomly resample architecture entries in place and return ``arch_dict``.

        For each block position, with probability ``mutate_prob``, both its
        kernel size and expand ratio are resampled. Then, for each stage, with
        probability ``mutate_prob``, its depth is resampled.
        """
        for i in range(self.max_n_blocks):
            if random.random() < mutate_prob:
                arch_dict["ks"][i] = random.choice(self.ks_list)
                arch_dict["e"][i] = random.choice(self.expand_list)

        for i in range(self.n_stage):
            if random.random() < mutate_prob:
                arch_dict["d"][i] = random.choice(self.depth_list)
        return arch_dict
class ResNetArchEncoder:
    """Encode ResNet-style architecture dicts as one-hot feature vectors.

    An architecture dict has the keys:

    * ``"d"``: ``d[0]`` is the input-stem depth flag (``random_sample_arch``
      draws it from ``{0, 2}``; any positive value is encoded as "on"),
      followed by one extra-depth value per stage drawn from ``depth_list``.
    * ``"e"``: expand ratio for each of the ``max_n_blocks`` block positions.
    * ``"w"``: index into ``width_mult_list`` for each of ``n_stage + 2``
      positions.
    * ``"image_size"``: input resolution (an int; a one-element list/tuple is
      also accepted by ``arch2feature``).

    Feature layout (length ``n_dim``): resolution slots, two input-stem
    slots, width-mult-index slots per width position, then expand-ratio slots
    per block position. Stage ``i`` reserves
    ``base_depth_list[i] + max(depth_list)`` consecutive block positions; the
    first ``base_depth_list[i] + d[i + 1]`` of them are active.
    """

    def __init__(
        self,
        image_size_list=None,
        depth_list=None,
        expand_list=None,
        width_mult_list=None,
        base_depth_list=None,
    ):
        """Store the search space and build the one-hot index tables.

        Args:
            image_size_list: candidate resolutions (default ``[224]``).
            depth_list: candidate extra depths per stage (default ``[0, 1, 2]``).
            expand_list: candidate expand ratios (default ``[0.2, 0.25, 0.35]``).
            width_mult_list: candidate width multipliers
                (default ``[0.65, 0.8, 1.0]``); the encoding stores indices
                into this list.
            base_depth_list: base depth per stage
                (default ``ResNets.BASE_DEPTH_LIST``).
        """
        self.image_size_list = [224] if image_size_list is None else image_size_list
        self.expand_list = [0.2, 0.25, 0.35] if expand_list is None else expand_list
        self.depth_list = [0, 1, 2] if depth_list is None else depth_list
        self.width_mult_list = (
            [0.65, 0.8, 1.0] if width_mult_list is None else width_mult_list
        )
        self.base_depth_list = (
            ResNets.BASE_DEPTH_LIST if base_depth_list is None else base_depth_list
        )

        # Build the info dicts; slot order is resolution, input stem, width, expand.
        self.n_dim = 0
        # resolution
        self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
        self._build_info_dict(target="r")
        # input stem skip
        self.input_stem_d_info = dict(id2val={}, val2id={}, L=[], R=[])
        self._build_info_dict(target="input_stem_d")
        # width_mult
        self.width_mult_info = dict(id2val=[], val2id=[], L=[], R=[])
        self._build_info_dict(target="width_mult")
        # expand ratio
        self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
        self._build_info_dict(target="e")

    @property
    def n_stage(self):
        """Number of stages (one per entry of ``base_depth_list``)."""
        return len(self.base_depth_list)

    @property
    def max_n_blocks(self):
        """Total block positions: every stage at base depth plus ``max(depth_list)``."""
        return sum(self.base_depth_list) + self.n_stage * max(self.depth_list)

    def _append_flat_segment(self, target_dict, choices):
        """Append one one-hot segment (flat ``val2id``/``id2val`` dicts) for ``choices``."""
        target_dict["L"].append(self.n_dim)
        for value in choices:
            target_dict["val2id"][value] = self.n_dim
            target_dict["id2val"][self.n_dim] = value
            self.n_dim += 1
        target_dict["R"].append(self.n_dim)

    def _append_segments(self, target_dict, n_segments, choices):
        """Append ``n_segments`` one-hot segments (per-position maps) for ``choices``."""
        for i in range(n_segments):
            target_dict["val2id"].append({})
            target_dict["id2val"].append({})
            target_dict["L"].append(self.n_dim)
            for value in choices:
                target_dict["val2id"][i][value] = self.n_dim
                target_dict["id2val"][i][self.n_dim] = value
                self.n_dim += 1
            target_dict["R"].append(self.n_dim)

    def _build_info_dict(self, target):
        """Allocate feature slots for ``target`` and record their index maps.

        ``target`` is one of ``"r"``, ``"input_stem_d"``, ``"e"`` or
        ``"width_mult"``. ``L`` / ``R`` hold the start (inclusive) and end
        (exclusive) slot of each segment.

        Raises:
            NotImplementedError: for an unknown ``target``.
        """
        if target == "r":
            self._append_flat_segment(self.r_info, self.image_size_list)
        elif target == "input_stem_d":
            self._append_flat_segment(self.input_stem_d_info, [0, 1])
        elif target == "e":
            self._append_segments(self.e_info, self.max_n_blocks, self.expand_list)
        elif target == "width_mult":
            self._append_segments(
                self.width_mult_info,
                self.n_stage + 2,
                list(range(len(self.width_mult_list))),
            )
        else:
            raise NotImplementedError

    @staticmethod
    def _argmax_slot(feature, lo, hi):
        """Return the absolute index of the largest entry in ``feature[lo:hi]``."""
        return int(np.argmax(feature[lo:hi])) + lo

    def arch2feature(self, arch_dict):
        """Return the one-hot ``np.ndarray`` of length ``n_dim`` for ``arch_dict``.

        ``arch_dict["image_size"]`` may be an int or a one-element list/tuple.
        """
        d, e, w, r = (
            arch_dict["d"],
            arch_dict["e"],
            arch_dict["w"],
            arch_dict["image_size"],
        )
        if isinstance(r, (list, tuple)):
            r = r[0]
        input_stem_skip = 1 if d[0] > 0 else 0
        d = d[1:]

        feature = np.zeros(self.n_dim)
        feature[self.r_info["val2id"][r]] = 1
        feature[self.input_stem_d_info["val2id"][input_stem_skip]] = 1
        for i in range(self.n_stage + 2):
            feature[self.width_mult_info["val2id"][i][w[i]]] = 1

        start_pt = 0
        max_depth = max(self.depth_list)
        for i, base_depth in enumerate(self.base_depth_list):
            depth = base_depth + d[i]
            for j in range(start_pt, start_pt + depth):
                feature[self.e_info["val2id"][j][e[j]]] = 1
            # Each stage reserves base_depth + max_depth positions.
            start_pt += max_depth + base_depth
        return feature

    def feature2arch(self, feature):
        """Decode a feature vector back into an architecture dict.

        The resolution, input-stem flag and width indices come from the argmax
        of their segments; the input-stem slot (0/1) is scaled by 2 so that
        ``d[0]`` is 0 or 2. Expand ratios are read from slots equal to 1, and
        block positions with no active slot get ``e == 0``. Each stage's depth
        entry is its number of active positions minus its base depth.

        Raises:
            AssertionError: if the decoded resolution is not a candidate.
        """
        img_sz = self.r_info["id2val"][
            self._argmax_slot(feature, self.r_info["L"][0], self.r_info["R"][0])
        ]
        stem_info = self.input_stem_d_info
        input_stem_skip = (
            stem_info["id2val"][
                self._argmax_slot(feature, stem_info["L"][0], stem_info["R"][0])
            ]
            * 2
        )
        assert img_sz in self.image_size_list
        arch_dict = {"d": [input_stem_skip], "e": [], "w": [], "image_size": img_sz}

        w_info = self.width_mult_info
        for i in range(self.n_stage + 2):
            slot = self._argmax_slot(feature, w_info["L"][i], w_info["R"][i])
            arch_dict["w"].append(w_info["id2val"][i][slot])

        d = 0
        skipped = 0
        stage_id = 0
        max_depth = max(self.depth_list)
        for i in range(self.max_n_blocks):
            skip = True
            for j in range(self.e_info["L"][i], self.e_info["R"][i]):
                if feature[j] == 1:
                    arch_dict["e"].append(self.e_info["id2val"][i][j])
                    skip = False
                    break
            if skip:
                arch_dict["e"].append(0)
                skipped += 1
            else:
                d += 1
            # A stage ends after base_depth + max_depth positions (or at the last one).
            if (
                i + 1 == self.max_n_blocks
                or (skipped + d) % (max_depth + self.base_depth_list[stage_id]) == 0
            ):
                arch_dict["d"].append(d - self.base_depth_list[stage_id])
                d, skipped = 0, 0
                stage_id += 1
        return arch_dict

    def random_sample_arch(self):
        """Sample each entry uniformly at random; ``d[0]`` is drawn from ``{0, 2}``."""
        return {
            "d": [random.choice([0, 2])]
            + random.choices(self.depth_list, k=self.n_stage),
            "e": random.choices(self.expand_list, k=self.max_n_blocks),
            "w": random.choices(
                list(range(len(self.width_mult_list))), k=self.n_stage + 2
            ),
            "image_size": random.choice(self.image_size_list),
        }

    def mutate_resolution(self, arch_dict, mutate_prob):
        """With probability ``mutate_prob``, resample the resolution in place.

        A list/tuple ``image_size`` stays a one-element list; otherwise an int
        is stored. Returns ``arch_dict``.
        """
        if random.random() < mutate_prob:
            new_size = random.choice(self.image_size_list)
            if isinstance(arch_dict.get("image_size"), (list, tuple)):
                new_size = [new_size]
            arch_dict["image_size"] = new_size
        return arch_dict

    def mutate_arch(self, arch_dict, mutate_prob):
        """Randomly resample entries of ``arch_dict`` in place and return it.

        Each of the input-stem flag, per-stage depths, width indices and expand
        ratios is independently resampled with probability ``mutate_prob``.
        """
        # input stem skip
        if random.random() < mutate_prob:
            arch_dict["d"][0] = random.choice([0, 2])
        # depth
        for i in range(1, len(arch_dict["d"])):
            if random.random() < mutate_prob:
                arch_dict["d"][i] = random.choice(self.depth_list)
        # width_mult
        for i in range(len(arch_dict["w"])):
            if random.random() < mutate_prob:
                arch_dict["w"][i] = random.choice(
                    list(range(len(self.width_mult_list)))
                )
        # expand ratio
        for i in range(len(arch_dict["e"])):
            if random.random() < mutate_prob:
                arch_dict["e"][i] = random.choice(self.expand_list)
        # Return the dict for parity with MobileNetArchEncoder.mutate_arch.
        return arch_dict