# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import copy
from .latency_lookup_table import *
class BaseEfficiencyModel:
    """Common base class for efficiency estimators of sub-networks.

    Holds a dynamic network (``dyn_net``) that must provide
    ``set_active_subnet(**kwargs)`` and ``get_active_net_config()``.
    Subclasses implement :meth:`get_efficiency`.
    """

    def __init__(self, dyn_net):
        # Network whose active sub-network is (re)configured on every query.
        self.dyn_net = dyn_net

    def get_active_subnet_config(self, arch_dict):
        """Activate the sub-network described by ``arch_dict`` and return its config.

        Args:
            arch_dict: architecture description. It must contain an
                ``"image_size"`` entry. Every other entry is forwarded as a
                keyword argument to ``dyn_net.set_active_subnet``. The caller's
                dict is left untouched because a deep copy is consumed instead.

        Returns:
            A tuple ``(active_net_config, image_size)``. ``active_net_config``
            is whatever ``dyn_net.get_active_net_config()`` returns.

        Raises:
            KeyError: if ``arch_dict`` has no ``"image_size"`` entry.

        Side effect: changes the active sub-network of ``self.dyn_net``.
        """
        subnet_kwargs = copy.deepcopy(arch_dict)
        # Resolution is not a set_active_subnet argument, so strip it out first.
        resolution = subnet_kwargs.pop("image_size")
        self.dyn_net.set_active_subnet(**subnet_kwargs)
        return self.dyn_net.get_active_net_config(), resolution

    def get_efficiency(self, arch_dict):
        """Return the efficiency metric for ``arch_dict``; subclasses must override."""
        raise NotImplementedError
class ProxylessNASFLOPsModel(BaseEfficiencyModel):
    """Scores a ProxylessNAS sub-network by its FLOP count."""

    def get_efficiency(self, arch_dict):
        """Return the FLOPs of the sub-network described by ``arch_dict``.

        The count comes from ``ProxylessNASLatencyTable.count_flops_given_config``,
        called with the active sub-network config and ``arch_dict["image_size"]``.
        """
        net_config, resolution = self.get_active_subnet_config(arch_dict)
        flops_counter = ProxylessNASLatencyTable.count_flops_given_config
        return flops_counter(net_config, resolution)
class Mbv3FLOPsModel(BaseEfficiencyModel):
    """Scores a MobileNetV3 sub-network by its FLOP count."""

    def get_efficiency(self, arch_dict):
        """Return the FLOPs of the sub-network described by ``arch_dict``.

        The count comes from ``MBv3LatencyTable.count_flops_given_config``,
        called with the active sub-network config and the input resolution.

        ``arch_dict["image_size"]`` is expected to be a scalar resolution. The
        other FLOPs models pass it through unchanged, and the latency models
        use it as a dict key and format it with ``%d``. For backward
        compatibility a list/tuple is still accepted, and its first element is
        used.
        """
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        # Previously this always did ``image_size[0]``, which raised TypeError
        # for the scalar resolutions used everywhere else in this module.
        if isinstance(image_size, (list, tuple)):
            image_size = image_size[0]
        return MBv3LatencyTable.count_flops_given_config(active_net_config, image_size)
class ResNet50FLOPsModel(BaseEfficiencyModel):
    """Scores a ResNet-50 sub-network by its FLOP count."""

    def get_efficiency(self, arch_dict):
        """Return the FLOPs of the sub-network described by ``arch_dict``.

        The count comes from ``ResNet50LatencyTable.count_flops_given_config``,
        called with the active sub-network config and ``arch_dict["image_size"]``.
        """
        subnet_config, input_size = self.get_active_subnet_config(arch_dict)
        flops_counter = ResNet50LatencyTable.count_flops_given_config
        return flops_counter(subnet_config, input_size)
class ProxylessNASLatencyModel(BaseEfficiencyModel):
    """Predicts ProxylessNAS sub-network latency from per-resolution lookup tables.

    One ``ProxylessNASLatencyTable`` is built for each supported input
    resolution, from ``<path>/<image_size>_lookup_table.yaml``.
    """

    def __init__(
        self, dyn_net, lookup_table_path_dict, local_dir="/tmp/.dyn_latency_tools/"
    ):
        """Build one latency lookup table per image size.

        Args:
            dyn_net: dynamic network, forwarded to ``BaseEfficiencyModel``.
            lookup_table_path_dict: mapping ``image_size -> path``. ``path`` is
                the directory/URL prefix that holds
                ``"<image_size>_lookup_table.yaml"``. Keys are formatted with
                ``%d``, so they are expected to be integers.
            local_dir: passed to every table as ``local_dir`` (presumably the
                download/cache location for remote tables — TODO confirm).
                Defaults to the location that used to be hard-coded here.
        """
        super(ProxylessNASLatencyModel, self).__init__(dyn_net)
        self.latency_tables = {}
        for image_size, path in lookup_table_path_dict.items():
            self.latency_tables[image_size] = ProxylessNASLatencyTable(
                local_dir=local_dir,
                url=os.path.join(path, "%d_lookup_table.yaml" % image_size),
            )

    def get_efficiency(self, arch_dict):
        """Return the predicted latency of the sub-network described by ``arch_dict``.

        Raises:
            KeyError: if no lookup table was loaded for
                ``arch_dict["image_size"]``. The message lists the
                resolutions that are available.
        """
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        latency_table = self.latency_tables.get(image_size)
        if latency_table is None:
            # Same exception type as the plain dict lookup, but actionable.
            raise KeyError(
                "no latency lookup table for image_size=%r (available: %r)"
                % (image_size, list(self.latency_tables))
            )
        return latency_table.predict_network_latency_given_config(
            active_net_config, image_size
        )
class Mbv3LatencyModel(BaseEfficiencyModel):
    """Predicts MobileNetV3 sub-network latency from per-resolution lookup tables.

    One ``MBv3LatencyTable`` is built for each supported input resolution,
    from ``<path>/<image_size>_lookup_table.yaml``.
    """

    def __init__(
        self, dyn_net, lookup_table_path_dict, local_dir="/tmp/.dyn_latency_tools/"
    ):
        """Build one latency lookup table per image size.

        Args:
            dyn_net: dynamic network, forwarded to ``BaseEfficiencyModel``.
            lookup_table_path_dict: mapping ``image_size -> path``. ``path`` is
                the directory/URL prefix that holds
                ``"<image_size>_lookup_table.yaml"``. Keys are formatted with
                ``%d``, so they are expected to be integers.
            local_dir: passed to every table as ``local_dir`` (presumably the
                download/cache location for remote tables — TODO confirm).
                Defaults to the location that used to be hard-coded here.
        """
        super(Mbv3LatencyModel, self).__init__(dyn_net)
        self.latency_tables = {}
        for image_size, path in lookup_table_path_dict.items():
            self.latency_tables[image_size] = MBv3LatencyTable(
                local_dir=local_dir,
                url=os.path.join(path, "%d_lookup_table.yaml" % image_size),
            )

    def get_efficiency(self, arch_dict):
        """Return the predicted latency of the sub-network described by ``arch_dict``.

        Raises:
            KeyError: if no lookup table was loaded for
                ``arch_dict["image_size"]``. The message lists the
                resolutions that are available.
        """
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        latency_table = self.latency_tables.get(image_size)
        if latency_table is None:
            # Same exception type as the plain dict lookup, but actionable.
            raise KeyError(
                "no latency lookup table for image_size=%r (available: %r)"
                % (image_size, list(self.latency_tables))
            )
        return latency_table.predict_network_latency_given_config(
            active_net_config, image_size
        )
|