# Uploaded by smi08 using huggingface_hub (commit 188f311, verified).
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import copy
from .latency_lookup_table import *
class BaseEfficiencyModel:
    """Common base for efficiency estimators built around a dynamic network.

    The base class only stores the network and offers a helper that turns an
    architecture dict into the network's active configuration. Subclasses
    supply the actual metric by overriding :meth:`get_efficiency`.
    """

    def __init__(self, dyn_net):
        # The network must expose ``set_active_subnet(**kwargs)`` and
        # ``get_active_net_config()``; both are called below.
        self.dyn_net = dyn_net

    def get_active_subnet_config(self, arch_dict):
        """Select the sub-network described by ``arch_dict`` and return its config.

        Args:
            arch_dict: Mapping that contains an ``"image_size"`` entry; every
                other entry is forwarded as a keyword argument to
                ``dyn_net.set_active_subnet``.

        Returns:
            A ``(active_net_config, image_size)`` tuple, where the first item
            is whatever ``dyn_net.get_active_net_config()`` returns.

        Raises:
            KeyError: If ``arch_dict`` has no ``"image_size"`` entry.
        """
        # Work on a deep copy so the caller's dict is never modified.
        subnet_kwargs = copy.deepcopy(arch_dict)
        resolution = subnet_kwargs.pop("image_size")
        self.dyn_net.set_active_subnet(**subnet_kwargs)
        return self.dyn_net.get_active_net_config(), resolution

    def get_efficiency(self, arch_dict):
        """Return the efficiency metric for ``arch_dict``; subclasses must override."""
        raise NotImplementedError
class ProxylessNASFLOPsModel(BaseEfficiencyModel):
    """Efficiency model backed by ``ProxylessNASLatencyTable.count_flops_given_config``."""

    def get_efficiency(self, arch_dict):
        """Return the table's FLOPs count for the sub-network in ``arch_dict``.

        ``arch_dict`` must contain ``"image_size"``; the remaining entries
        select the active sub-network on ``self.dyn_net``.
        """
        net_config, resolution = self.get_active_subnet_config(arch_dict)
        # Called on the class itself: no lookup table instance is needed here.
        flops = ProxylessNASLatencyTable.count_flops_given_config(net_config, resolution)
        return flops
class Mbv3FLOPsModel(BaseEfficiencyModel):
    """Efficiency model backed by ``MBv3LatencyTable.count_flops_given_config``."""

    def get_efficiency(self, arch_dict):
        """Return the table's FLOPs count for the sub-network in ``arch_dict``.

        Args:
            arch_dict: Mapping with an ``"image_size"`` entry plus the
                keyword arguments that select the active sub-network.
                ``"image_size"`` may be a sequence, in which case its first
                element is used, or a plain scalar.

        Returns:
            Whatever ``MBv3LatencyTable.count_flops_given_config`` returns for
            the active configuration and the resolved image size.
        """
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        # Sequence inputs keep the original behavior (first element). A
        # non-subscriptable scalar, which the sibling models in this file
        # expect, used to raise TypeError here and is now passed through.
        try:
            resolution = image_size[0]
        except TypeError:
            resolution = image_size
        return MBv3LatencyTable.count_flops_given_config(active_net_config, resolution)
class ResNet50FLOPsModel(BaseEfficiencyModel):
    """Efficiency model backed by ``ResNet50LatencyTable.count_flops_given_config``."""

    def get_efficiency(self, arch_dict):
        """Return the table's FLOPs count for the sub-network in ``arch_dict``.

        ``arch_dict`` must contain ``"image_size"``; the remaining entries
        select the active sub-network on ``self.dyn_net``.
        """
        net_config, resolution = self.get_active_subnet_config(arch_dict)
        # Called on the class itself: no lookup table instance is needed here.
        flops = ResNet50LatencyTable.count_flops_given_config(net_config, resolution)
        return flops
class ProxylessNASLatencyModel(BaseEfficiencyModel):
    """Latency estimator that keeps one ``ProxylessNASLatencyTable`` per image size."""

    def __init__(self, dyn_net, lookup_table_path_dict, local_dir="/tmp/.dyn_latency_tools/"):
        """Build one lookup table per entry of ``lookup_table_path_dict``.

        Args:
            dyn_net: Dynamic network; see ``BaseEfficiencyModel``.
            lookup_table_path_dict: Mapping ``image_size -> base path``. For
                each entry the table is created from
                ``<base path>/<image_size>_lookup_table.yaml``.
            local_dir: Value forwarded as ``local_dir`` to every table
                (presumably the directory the tables are cached in; confirm
                against ``ProxylessNASLatencyTable``). Defaults to the
                previously hard-coded location.
        """
        super(ProxylessNASLatencyModel, self).__init__(dyn_net)
        self.latency_tables = {}
        for image_size, path in lookup_table_path_dict.items():
            # NOTE(review): os.path.join uses the platform separator; if
            # ``path`` is an HTTP URL without a trailing slash this yields a
            # backslash on Windows. Confirm what callers pass.
            table_url = os.path.join(path, "%d_lookup_table.yaml" % image_size)
            self.latency_tables[image_size] = ProxylessNASLatencyTable(
                local_dir=local_dir,
                url=table_url,
            )

    def get_efficiency(self, arch_dict):
        """Return the predicted latency of the sub-network in ``arch_dict``.

        Raises:
            KeyError: If ``arch_dict`` lacks ``"image_size"`` or no table was
                registered for that image size.
        """
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        table = self.latency_tables[image_size]
        return table.predict_network_latency_given_config(active_net_config, image_size)
class Mbv3LatencyModel(BaseEfficiencyModel):
    """Latency estimator that keeps one ``MBv3LatencyTable`` per image size."""

    def __init__(self, dyn_net, lookup_table_path_dict, local_dir="/tmp/.dyn_latency_tools/"):
        """Build one lookup table per entry of ``lookup_table_path_dict``.

        Args:
            dyn_net: Dynamic network; see ``BaseEfficiencyModel``.
            lookup_table_path_dict: Mapping ``image_size -> base path``. For
                each entry the table is created from
                ``<base path>/<image_size>_lookup_table.yaml``.
            local_dir: Value forwarded as ``local_dir`` to every table
                (presumably the directory the tables are cached in; confirm
                against ``MBv3LatencyTable``). Defaults to the previously
                hard-coded location.
        """
        super(Mbv3LatencyModel, self).__init__(dyn_net)
        self.latency_tables = {}
        for image_size, path in lookup_table_path_dict.items():
            # NOTE(review): os.path.join uses the platform separator; if
            # ``path`` is an HTTP URL without a trailing slash this yields a
            # backslash on Windows. Confirm what callers pass.
            table_url = os.path.join(path, "%d_lookup_table.yaml" % image_size)
            self.latency_tables[image_size] = MBv3LatencyTable(
                local_dir=local_dir,
                url=table_url,
            )

    def get_efficiency(self, arch_dict):
        """Return the predicted latency of the sub-network in ``arch_dict``.

        Raises:
            KeyError: If ``arch_dict`` lacks ``"image_size"`` or no table was
                registered for that image size.
        """
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        table = self.latency_tables[image_size]
        return table.predict_network_latency_given_config(active_net_config, image_size)