|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import copy |
|
|
from .latency_lookup_table import * |
|
|
|
|
|
|
|
|
class BaseEfficiencyModel:
    """Common base for efficiency estimators built on a dynamic network.

    The class keeps a reference to ``dyn_net`` and offers a helper that
    activates the sub-network described by an architecture dictionary.
    Subclasses implement :meth:`get_efficiency`.
    """

    def __init__(self, dyn_net):
        """Remember the dynamic network that sub-networks are sampled from.

        Args:
            dyn_net: Object providing ``set_active_subnet(**kwargs)`` and
                ``get_active_net_config()``; both are called by
                :meth:`get_active_subnet_config`.
        """
        self.dyn_net = dyn_net

    def get_active_subnet_config(self, arch_dict):
        """Activate the sub-network described by ``arch_dict``.

        Args:
            arch_dict: Mapping that must contain the key ``"image_size"``.
                Every other entry is forwarded as a keyword argument to
                ``dyn_net.set_active_subnet``.

        Returns:
            A tuple ``(active_net_config, image_size)`` where
            ``active_net_config`` is whatever
            ``dyn_net.get_active_net_config()`` returns after activation and
            ``image_size`` is the (copied) value stored under
            ``"image_size"``.

        Raises:
            KeyError: If ``arch_dict`` has no ``"image_size"`` entry.
        """
        # Work on a private deep copy: popping "image_size" below must not
        # mutate the dictionary owned by the caller.
        subnet_kwargs = copy.deepcopy(arch_dict)
        image_size = subnet_kwargs.pop("image_size")

        self.dyn_net.set_active_subnet(**subnet_kwargs)
        return self.dyn_net.get_active_net_config(), image_size

    def get_efficiency(self, arch_dict):
        """Return the efficiency value for ``arch_dict``.

        Raises:
            NotImplementedError: Always; subclasses provide the computation.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
class ProxylessNASFLOPsModel(BaseEfficiencyModel):
    """Efficiency model backed by ``ProxylessNASLatencyTable`` FLOPs counting."""

    def get_efficiency(self, arch_dict):
        """Return the FLOPs count for the sub-network described by ``arch_dict``.

        Args:
            arch_dict: Architecture description accepted by
                ``get_active_subnet_config`` (must include ``"image_size"``).

        Returns:
            The result of
            ``ProxylessNASLatencyTable.count_flops_given_config`` for the
            activated sub-network configuration and its image size.
        """
        net_config, image_size = self.get_active_subnet_config(arch_dict)
        return ProxylessNASLatencyTable.count_flops_given_config(
            net_config, image_size
        )
|
|
|
|
|
|
|
|
class Mbv3FLOPsModel(BaseEfficiencyModel):
    """Efficiency model backed by ``MBv3LatencyTable`` FLOPs counting."""

    def get_efficiency(self, arch_dict):
        """Return the FLOPs count for the sub-network described by ``arch_dict``.

        Args:
            arch_dict: Architecture description accepted by
                ``get_active_subnet_config`` (must include ``"image_size"``).

        Returns:
            The result of ``MBv3LatencyTable.count_flops_given_config`` for
            the activated sub-network configuration and the first element
            of its image size.
        """
        net_config, image_size = self.get_active_subnet_config(arch_dict)
        # NOTE(review): unlike the sibling FLOPs models, only image_size[0]
        # is forwarded, so "image_size" presumably is an indexable sequence
        # here -- confirm against the callers.
        return MBv3LatencyTable.count_flops_given_config(net_config, image_size[0])
|
|
|
|
|
|
|
|
class ResNet50FLOPsModel(BaseEfficiencyModel):
    """Efficiency model backed by ``ResNet50LatencyTable`` FLOPs counting."""

    def get_efficiency(self, arch_dict):
        """Return the FLOPs count for the sub-network described by ``arch_dict``.

        Args:
            arch_dict: Architecture description accepted by
                ``get_active_subnet_config`` (must include ``"image_size"``).

        Returns:
            The result of ``ResNet50LatencyTable.count_flops_given_config``
            for the activated sub-network configuration and its image size.
        """
        net_config, image_size = self.get_active_subnet_config(arch_dict)
        return ResNet50LatencyTable.count_flops_given_config(
            net_config, image_size
        )
|
|
|
|
|
|
|
|
class ProxylessNASLatencyModel(BaseEfficiencyModel):
    """Latency model using one ``ProxylessNASLatencyTable`` per image size."""

    def __init__(
        self, dyn_net, lookup_table_path_dict, local_dir="/tmp/.dyn_latency_tools/"
    ):
        """Create one lookup table for every configured image size.

        Args:
            dyn_net: Dynamic network handed to the base class.
            lookup_table_path_dict: Mapping ``image_size -> path``. For each
                entry the table is created with the URL
                ``os.path.join(path, "<image_size>_lookup_table.yaml")``.
            local_dir: Value passed to each table as ``local_dir``
                (presumably a local cache directory -- confirm against
                ``ProxylessNASLatencyTable``). Defaults to the previously
                hard-coded location.
        """
        super(ProxylessNASLatencyModel, self).__init__(dyn_net)
        self.latency_tables = {}
        for image_size, path in lookup_table_path_dict.items():
            table_url = os.path.join(path, "%d_lookup_table.yaml" % image_size)
            self.latency_tables[image_size] = ProxylessNASLatencyTable(
                local_dir=local_dir,
                url=table_url,
            )

    def get_efficiency(self, arch_dict):
        """Return the predicted latency for the sub-network in ``arch_dict``.

        Args:
            arch_dict: Architecture description accepted by
                ``get_active_subnet_config`` (must include ``"image_size"``).

        Returns:
            The result of ``predict_network_latency_given_config`` from the
            table registered for the requested image size.

        Raises:
            KeyError: If no lookup table was registered for the image size.
        """
        net_config, image_size = self.get_active_subnet_config(arch_dict)
        table = self.latency_tables.get(image_size)
        if table is None:
            # Same exception type as a plain dict lookup, but with enough
            # context to see which image sizes are actually available.
            raise KeyError(
                "no latency lookup table for image_size=%r (available: %r)"
                % (image_size, list(self.latency_tables))
            )
        return table.predict_network_latency_given_config(net_config, image_size)
|
|
|
|
|
|
|
|
class Mbv3LatencyModel(BaseEfficiencyModel):
    """Latency model using one ``MBv3LatencyTable`` per image size."""

    def __init__(
        self, dyn_net, lookup_table_path_dict, local_dir="/tmp/.dyn_latency_tools/"
    ):
        """Create one lookup table for every configured image size.

        Args:
            dyn_net: Dynamic network handed to the base class.
            lookup_table_path_dict: Mapping ``image_size -> path``. For each
                entry the table is created with the URL
                ``os.path.join(path, "<image_size>_lookup_table.yaml")``.
            local_dir: Value passed to each table as ``local_dir``
                (presumably a local cache directory -- confirm against
                ``MBv3LatencyTable``). Defaults to the previously
                hard-coded location.
        """
        super(Mbv3LatencyModel, self).__init__(dyn_net)
        self.latency_tables = {}
        for image_size, path in lookup_table_path_dict.items():
            table_url = os.path.join(path, "%d_lookup_table.yaml" % image_size)
            self.latency_tables[image_size] = MBv3LatencyTable(
                local_dir=local_dir,
                url=table_url,
            )

    def get_efficiency(self, arch_dict):
        """Return the predicted latency for the sub-network in ``arch_dict``.

        Args:
            arch_dict: Architecture description accepted by
                ``get_active_subnet_config`` (must include ``"image_size"``).

        Returns:
            The result of ``predict_network_latency_given_config`` from the
            table registered for the requested image size.

        Raises:
            KeyError: If no lookup table was registered for the image size.
        """
        net_config, image_size = self.get_active_subnet_config(arch_dict)
        table = self.latency_tables.get(image_size)
        if table is None:
            # Same exception type as a plain dict lookup, but with enough
            # context to see which image sizes are actually available.
            raise KeyError(
                "no latency lookup table for image_size=%r (available: %r)"
                % (image_size, list(self.latency_tables))
            )
        return table.predict_network_latency_given_config(net_config, image_size)
|
|
|