File size: 3,054 Bytes
188f311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import copy
from .latency_lookup_table import *


class BaseEfficiencyModel:
    """Common base for efficiency estimators that wrap a dynamic network.

    The wrapped ``dyn_net`` must provide ``set_active_subnet(**kwargs)`` and
    ``get_active_net_config()``. Subclasses implement :meth:`get_efficiency`.
    """

    def __init__(self, dyn_net):
        # Network object used to activate sub-networks and read their config.
        self.dyn_net = dyn_net

    def get_active_subnet_config(self, arch_dict):
        """Activate the sub-network described by ``arch_dict`` and fetch its config.

        ``arch_dict`` must hold an ``"image_size"`` entry; all remaining entries
        are forwarded as keyword arguments to ``dyn_net.set_active_subnet``.
        The work is done on a deep copy, so the caller's dictionary is not
        modified.

        Returns:
            A ``(active_net_config, image_size)`` tuple, where the config is
            whatever ``dyn_net.get_active_net_config()`` returns.

        Raises:
            KeyError: if ``arch_dict`` has no ``"image_size"`` key.
        """
        subnet_kwargs = copy.deepcopy(arch_dict)
        # "image_size" is not a sub-network setting; split it off before forwarding.
        resolution = subnet_kwargs.pop("image_size")
        self.dyn_net.set_active_subnet(**subnet_kwargs)
        return self.dyn_net.get_active_net_config(), resolution

    def get_efficiency(self, arch_dict):
        """Return the efficiency metric for ``arch_dict``; subclasses must override."""
        raise NotImplementedError


class ProxylessNASFLOPsModel(BaseEfficiencyModel):
    """FLOPs estimator that delegates to ``ProxylessNASLatencyTable``."""

    def get_efficiency(self, arch_dict):
        """Return the FLOPs figure for the sub-network described by ``arch_dict``.

        The sub-network is activated first (see ``get_active_subnet_config``);
        its config and the ``"image_size"`` entry are then handed to
        ``ProxylessNASLatencyTable.count_flops_given_config``, whose result is
        returned unchanged.
        """
        subnet_config, resolution = self.get_active_subnet_config(arch_dict)
        flops = ProxylessNASLatencyTable.count_flops_given_config(
            subnet_config, resolution
        )
        return flops


class Mbv3FLOPsModel(BaseEfficiencyModel):
    """FLOPs estimator that delegates to ``MBv3LatencyTable``."""

    def get_efficiency(self, arch_dict):
        """Return the FLOPs figure for the sub-network described by ``arch_dict``.

        The ``"image_size"`` entry may be either a subscriptable value, in which
        case its first element is used (the original behavior), or a plain
        non-subscriptable scalar, which is passed through as-is. The scalar form
        matches what the sibling FLOPs models and ``Mbv3LatencyModel`` accept,
        so one ``arch_dict`` can be shared between them.

        Returns:
            Whatever ``MBv3LatencyTable.count_flops_given_config`` returns for
            the active sub-network config and the resolved image size.
        """
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        try:
            resolution = image_size[0]
        except TypeError:
            # Not subscriptable (e.g. a plain int): use the scalar directly
            # instead of failing, consistent with the other efficiency models.
            resolution = image_size
        return MBv3LatencyTable.count_flops_given_config(active_net_config, resolution)


class ResNet50FLOPsModel(BaseEfficiencyModel):
    """FLOPs estimator that delegates to ``ResNet50LatencyTable``."""

    def get_efficiency(self, arch_dict):
        """Return the FLOPs figure for the sub-network described by ``arch_dict``.

        Activates the sub-network, then forwards its config together with the
        ``"image_size"`` entry to
        ``ResNet50LatencyTable.count_flops_given_config`` and returns the result.
        """
        config_and_size = self.get_active_subnet_config(arch_dict)
        net_config, input_size = config_and_size
        return ResNet50LatencyTable.count_flops_given_config(net_config, input_size)


class ProxylessNASLatencyModel(BaseEfficiencyModel):
    """Latency estimator backed by one ``ProxylessNASLatencyTable`` per image size."""

    def __init__(
        self, dyn_net, lookup_table_path_dict, local_dir="/tmp/.dyn_latency_tools/"
    ):
        """Build one latency table for every entry of ``lookup_table_path_dict``.

        Args:
            dyn_net: dynamic network handed to ``BaseEfficiencyModel``.
            lookup_table_path_dict: mapping ``image_size -> base path``. Each
                table is loaded from ``<base path>/<image_size>_lookup_table.yaml``;
                ``image_size`` must be numeric because it is formatted with ``%d``.
            local_dir: forwarded as ``local_dir`` to every
                ``ProxylessNASLatencyTable``. Defaults to the previously
                hard-coded ``"/tmp/.dyn_latency_tools/"``; override it on
                platforms without ``/tmp`` or to isolate runs.
                NOTE(review): presumably the table's local cache directory;
                confirm against ``latency_lookup_table``.
        """
        super().__init__(dyn_net)
        # Keyed by image size so get_efficiency can select the matching table.
        self.latency_tables = {
            image_size: ProxylessNASLatencyTable(
                local_dir=local_dir,
                url=os.path.join(path, "%d_lookup_table.yaml" % image_size),
            )
            for image_size, path in lookup_table_path_dict.items()
        }

    def get_efficiency(self, arch_dict):
        """Return the predicted latency for the sub-network described by ``arch_dict``.

        Raises:
            KeyError: if ``arch_dict`` has no ``"image_size"`` key, or if no
                table was built for that image size.
        """
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        table = self.latency_tables[image_size]
        return table.predict_network_latency_given_config(active_net_config, image_size)


class Mbv3LatencyModel(BaseEfficiencyModel):
    """Latency estimator backed by one ``MBv3LatencyTable`` per image size."""

    def __init__(
        self, dyn_net, lookup_table_path_dict, local_dir="/tmp/.dyn_latency_tools/"
    ):
        """Build one latency table for every entry of ``lookup_table_path_dict``.

        Args:
            dyn_net: dynamic network handed to ``BaseEfficiencyModel``.
            lookup_table_path_dict: mapping ``image_size -> base path``. Each
                table is loaded from ``<base path>/<image_size>_lookup_table.yaml``;
                ``image_size`` must be numeric because it is formatted with ``%d``.
            local_dir: forwarded as ``local_dir`` to every ``MBv3LatencyTable``.
                Defaults to the previously hard-coded
                ``"/tmp/.dyn_latency_tools/"``; override it on platforms
                without ``/tmp`` or to isolate runs.
                NOTE(review): presumably the table's local cache directory;
                confirm against ``latency_lookup_table``.
        """
        super().__init__(dyn_net)
        # Keyed by image size so get_efficiency can select the matching table.
        self.latency_tables = {
            image_size: MBv3LatencyTable(
                local_dir=local_dir,
                url=os.path.join(path, "%d_lookup_table.yaml" % image_size),
            )
            for image_size, path in lookup_table_path_dict.items()
        }

    def get_efficiency(self, arch_dict):
        """Return the predicted latency for the sub-network described by ``arch_dict``.

        Raises:
            KeyError: if ``arch_dict`` has no ``"image_size"`` key, or if no
                table was built for that image size.
        """
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        table = self.latency_tables[image_size]
        return table.predict_network_latency_given_config(active_net_config, image_size)