|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Definitions for AssembleNet/++ structures.
|
|
|
| This structure is a `list` corresponding to a graph representation of the
|
| network, where a node is a convolutional block and an edge specifies a
|
| connection from one block to another.
|
|
|
Each node itself (in the structure list) is a list with the following format:
[block_level, [list_of_input_blocks], number_filter, temporal_dilation,
spatial_stride]. [list_of_input_blocks] should be the list of node indexes whose
values are less than the index of the node itself. The structure lists in this
file also carry a trailing sixth element (0 in every structure here), which is
ignored when they are converted to `BlockSpec`s. The 'stems' of the network,
which take the raw inputs directly, follow a different node format:
[stem_type, temporal_dilation]. The stem_type is -1 for the RGB stem and -2 for
the optical flow stem. The stem_type -3 is reserved for the object segmentation
input.
|
|
|
| In AssembleNet++lite, instead of passing a single `int` for number_filter, we
|
| pass a list/tuple of three `int`s. They specify the number of channels to be
|
| used for each layer in the inverted bottleneck modules.
|
|
|
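The `*_structure_weights` lists specify the learned connection weights. They
contain one entry per node. Each entry is either empty (no weights, e.g. for
stems) or a one-element list holding one weight per input block of that node.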
|
| """
|
| import dataclasses
|
| from typing import List, Optional, Tuple
|
|
|
| from official.core import config_definitions as cfg
|
| from official.core import exp_factory
|
| from official.modeling import hyperparams
|
| from official.vision.configs import backbones_3d
|
| from official.vision.configs import common
|
| from official.vision.configs import video_classification
|
|
|
|
|
@dataclasses.dataclass
class BlockSpec(hyperparams.Config):
  """Config for a single node (stem or convolutional block) of a structure.

  A stem node is identified by a negative `level` (the stem type, see the
  module docstring). `flat_lists_to_blocks` only sets `level` and
  `temporal_dilation` for stems.
  """
  # Block level for convolutional blocks; the negative stem type for stems.
  level: int = -1
  # Indexes of the earlier nodes that feed into this node.
  input_blocks: Tuple[int, ...] = tuple()
  # Number of filters of the block.
  # NOTE(review): the AssembleNet++ lite structures pass a list of three ints
  # here (see module docstring) despite the `int` annotation -- confirm that
  # hyperparams.Config accepts this.
  num_filters: int = -1
  temporal_dilation: int = 1
  spatial_stride: int = 1
  # One learned connection weight per entry of `input_blocks`; empty if the
  # structure provides no weights for this node.
  input_block_weight: Tuple[float, ...] = tuple()
|
|
|
|
|
def flat_lists_to_blocks(model_structures, model_edge_weights):
  """Transforms the raw list structure configs to a `BlockSpec` tuple.

  Args:
    model_structures: Sequence of nodes in the flat format described in the
      module docstring. Stem nodes (negative first element) are
      `[stem_type, temporal_dilation]`. Other nodes are
      `[level, input_blocks, num_filters, temporal_dilation, spatial_stride,
      ...]`; elements past the fifth are ignored.
    model_edge_weights: Sequence with exactly one entry per node. Each entry is
      either empty (no learned weights) or a one-element list whose only item
      holds one weight per input block of that node.

  Returns:
    A tuple of `BlockSpec`, one per node, in the same order as
    `model_structures`.

  Raises:
    ValueError: If `model_structures` and `model_edge_weights` have different
      lengths.
    AssertionError: If a node's weights do not match its number of input
      blocks.
  """
  model_structures = list(model_structures)
  model_edge_weights = list(model_edge_weights)
  # `zip` would silently drop the trailing nodes (or weights) of the longer
  # list and produce a truncated network, so reject mismatched inputs.
  if len(model_structures) != len(model_edge_weights):
    raise ValueError(
        f'Got {len(model_structures)} structure nodes but '
        f'{len(model_edge_weights)} edge weight entries; expected exactly one '
        'edge weight entry per node.')

  blocks = []
  for node, edge_weights in zip(model_structures, model_edge_weights):
    if node[0] < 0:
      # Stem node: [stem_type, temporal_dilation].
      block = BlockSpec(level=node[0], temporal_dilation=node[1])
    else:
      block = BlockSpec(
          level=node[0],
          input_blocks=node[1],
          num_filters=node[2],
          temporal_dilation=node[3],
          spatial_stride=node[4])
    if edge_weights:
      assert len(edge_weights[0]) == len(block.input_blocks), (
          f'{len(edge_weights[0])} != {len(block.input_blocks)} at block '
          f'{block} weight {edge_weights}')
      block.input_block_weight = tuple(edge_weights[0])
    blocks.append(block)
  return tuple(blocks)
|
|
|
|
|
def blocks_to_flat_lists(blocks: List[BlockSpec]):
  """Converts `BlockSpec`s back into the raw flat-list structure configs.

  This mirrors `flat_lists_to_blocks`. Stems (negative `level`) become
  `[level, temporal_dilation]`. Every other block becomes
  `[level, input_blocks, num_filters, temporal_dilation, spatial_stride, 0]`.

  Args:
    blocks: The `BlockSpec`s to convert. They are iterated twice, once for the
      structure and once for the weights.

  Returns:
    A `(model_structure, model_edge_weights)` pair of lists with one entry per
    block. A weight entry is `[]` when the block has no input weights, and
    otherwise a one-element list holding the weights as a list.
  """

  def _node_as_list(spec):
    # Convolutional blocks carry the full six-element form.
    if spec.level >= 0:
      return [
          spec.level,
          list(spec.input_blocks),
          spec.num_filters,
          spec.temporal_dilation,
          spec.spatial_stride,
          0,
      ]
    return [spec.level, spec.temporal_dilation]

  def _weights_as_list(spec):
    if not spec.input_block_weight:
      return []
    return [list(spec.input_block_weight)]

  structure = [_node_as_list(spec) for spec in blocks]
  edge_weights = [_weights_as_list(spec) for spec in blocks]
  return structure, edge_weights
|
|
|
|
|
|
|
|
|
|
|
# AssembleNet structures in the flat node format described in the module
# docstring. The first four nodes are stems: two RGB stems (-1) and two optical
# flow stems (-2). Every other node is
# [level, input_blocks, num_filters, temporal_dilation, spatial_stride, 0];
# the trailing 0 is not read by `flat_lists_to_blocks`.
# `asn50_structure` and `asn101_structure` share the same connectivity and only
# differ in `num_filters` of their level-2 and level-3 nodes.
asn50_structure = [[-1, 4], [-1, 4], [-2, 1], [-2, 1], [0, [1], 32, 1, 1, 0],
                   [0, [0], 32, 4, 1, 0], [0, [0, 1, 2, 3], 32, 1, 1, 0],
                   [0, [2, 3], 32, 2, 1, 0], [1, [0, 4, 5, 6, 7], 64, 2, 2, 0],
                   [1, [0, 2, 4, 7], 64, 1, 2, 0], [1, [0, 5, 7], 64, 4, 2, 0],
                   [1, [0, 5], 64, 1, 2, 0], [2, [4, 8, 10, 11], 256, 1, 2, 0],
                   [2, [8, 9], 256, 4, 2, 0], [3, [12, 13], 512, 2, 2, 0]]
asn101_structure = [[-1, 4], [-1, 4], [-2, 1], [-2, 1], [0, [1], 32, 1, 1, 0],
                    [0, [0], 32, 4, 1, 0], [0, [0, 1, 2, 3], 32, 1, 1, 0],
                    [0, [2, 3], 32, 2, 1, 0], [1, [0, 4, 5, 6, 7], 64, 2, 2, 0],
                    [1, [0, 2, 4, 7], 64, 1, 2, 0], [1, [0, 5, 7], 64, 4, 2, 0],
                    [1, [0, 5], 64, 1, 2, 0], [2, [4, 8, 10, 11], 192, 1, 2, 0],
                    [2, [8, 9], 192, 4, 2, 0], [3, [12, 13], 256, 2, 2, 0]]
# Learned connection weights usable with both structures above (they share
# connectivity): one entry per node. Entries are [] for the four stems and for
# the single-input nodes 4 and 5. Every other entry is a one-element list with
# one weight per input block of the node.
asn_structure_weights = [
    [], [], [], [], [], [],
    [[
        0.13810564577579498, 0.8465337157249451, 0.3072969317436218,
        0.2867436408996582
    ]], [[0.5846117734909058, 0.6066334843635559]],
    [[
        0.16382087767124176, 0.8852924704551697, 0.4039595425128937,
        0.6823437809944153, 0.5331538319587708
    ]],
    [[
        0.028569204732775688, 0.10333596915006638, 0.7517264485359192,
        0.9260114431381226
    ]], [[0.28832191228866577, 0.7627848982810974, 0.404977947473526]],
    [[0.23474831879138947, 0.7841425538063049]],
    [[
        0.27616503834724426, 0.9514784812927246, 0.6568767428398132,
        0.9547983407974243
    ]], [[0.5047007203102112, 0.8876819610595703]],
    [[0.9892204403877258, 0.8454614877700806]]
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# AssembleNet++ structure with five stems: two RGB stems (-1), two optical flow
# stems (-2) and the object segmentation stem (-3) described in the module
# docstring. Non-stem nodes use the six-element format of the module docstring.
# NOTE(review): because of the -3 stem this presumably requires
# `AssembleNetPlus.use_object_input=True` -- confirm against the backbone.
full_asnp50_structure = [[-1, 2], [-1, 4], [-2, 2], [-2, 1], [-3, 4],
                         [0, [0, 1, 2, 3, 4], 32, 1, 1, 0],
                         [0, [0, 1, 4], 32, 4, 1, 0],
                         [0, [2, 3, 4], 32, 8, 1, 0],
                         [0, [2, 3, 4], 32, 1, 1, 0],
                         [1, [0, 1, 2, 4, 5, 6, 7, 8], 64, 4, 2, 0],
                         [1, [2, 3, 4, 7, 8], 64, 1, 2, 0],
                         [1, [0, 4, 5, 6, 7], 128, 8, 2, 0],
                         [2, [4, 11], 256, 8, 2, 0],
                         [2, [2, 3, 4, 5, 6, 7, 8, 10, 11], 256, 4, 2, 0],
                         [3, [12, 13], 512, 2, 2, 0]]
# Learned connection weights for `full_asnp50_structure`: one entry per node.
# Entries are [] for the five stems; every other entry holds one weight per
# input block.
full_asnp_structure_weights = [[], [], [], [], [], [[0.6143830418586731, 0.7111759185791016, 0.19351491332054138, 0.1701001077890396, 0.7178536653518677]], [[0.5755624771118164, 0.5644599795341492, 0.7128658294677734]], [[0.26563042402267456, 0.3033692538738251, 0.8244096636772156]], [[0.07013848423957825, 0.07905343919992447, 0.8767927885055542]], [[0.5008697509765625, 0.5020178556442261, 0.49819135665893555, 0.5015180706977844, 0.4987695813179016, 0.4990265369415283, 0.499239057302475, 0.4974501430988312]], [[0.47034338116645813, 0.4694305658340454, 0.767791748046875, 0.5539310574531555, 0.4520096182823181]], [[0.2769702076911926, 0.8116549253463745, 0.597356915473938, 0.6585626602172852, 0.5915306210517883]], [[0.501274824142456, 0.5016682147979736]], [[0.0866393893957138, 0.08469288796186447, 0.9739039540290833, 0.058271341025829315, 0.08397126197814941, 0.10285478830337524, 0.18506969511508942, 0.23874442279338837, 0.9188644886016846]], [[0.4174623489379883, 0.5844835638999939]]]
|
|
|
|
|
|
|
|
|
|
|
# AssembleNet++ lite structure: one RGB stem (-1) and one optical flow stem
# (-2). As described in the module docstring, each non-stem node passes a list
# of three ints as its number of filters (channels of the inverted bottleneck
# layers) instead of a single int.
asnp_lite_structure = [[-1, 1], [-2, 1],
                       [0, [0, 1], [27, 27, 12], 1, 2, 0],
                       [0, [0, 1], [27, 27, 12], 4, 2, 0],
                       [1, [0, 1, 2, 3], [54, 54, 24], 2, 2, 0],
                       [1, [0, 1, 2, 3], [54, 54, 24], 1, 2, 0],
                       [1, [0, 1, 2, 3], [54, 54, 24], 4, 2, 0],
                       [1, [0, 1, 2, 3], [54, 54, 24], 1, 2, 0],
                       [2, [0, 1, 2, 3, 4, 5, 6, 7], [152, 152, 68], 1, 2, 0],
                       [2, [0, 1, 2, 3, 4, 5, 6, 7], [152, 152, 68], 4, 2, 0],
                       [3, [2, 3, 4, 5, 6, 7, 8, 9], [432, 432, 192], 2, 2, 0]]
# Learned connection weights for `asnp_lite_structure`: one entry per node.
# Entries are [] for the two stems; every other entry holds one weight per
# input block.
asnp_lite_structure_weights = [[], [], [[0.19914183020591736, 0.9278576374053955]], [[0.010816320776939392, 0.888792097568512]], [[0.9473835825920105, 0.6303419470787048, 0.1704932451248169, 0.05950307101011276]], [[0.9560931324958801, 0.7898273468017578, 0.36138781905174255, 0.07344610244035721]], [[0.9213919043540955, 0.13418640196323395, 0.8371981978416443, 0.07936054468154907]], [[0.9441559910774231, 0.9435100555419922, 0.7253988981246948, 0.13498817384243011]], [[0.9964852333068848, 0.8427878618240356, 0.8895476460456848, 0.11014710366725922, 0.6270533204078674, 0.44782018661499023, 0.61344975233078, 0.44898226857185364]], [[0.9970942735671997, 0.7105681896209717, 0.5078442096710205, 0.0951600968837738, 0.624282717704773, 0.8527252674102783, 0.8105692863464355, 0.7857823967933655]], [[0.6180334091186523, 0.11882413923740387, 0.06102970987558365, 0.04484326392412186, 0.05602221190929413, 0.052324872463941574, 0.9969874024391174, 0.9987731575965881]]]
|
|
|
|
|
@dataclasses.dataclass
class AssembleNet(hyperparams.Config):
  """AssembleNet backbone config."""
  # Model identifier; the experiment factories in this file use '50'.
  model_id: str = '50'
  # Number of input frames. Left at 0 here; the experiment factories set it
  # and assert that it is positive.
  num_frames: int = 0
  # NOTE(review): presumably selects how a block's weighted inputs are
  # combined -- confirm against the AssembleNet backbone implementation.
  combine_method: str = 'sigmoid'
  # Network structure, e.g. as produced by `flat_lists_to_blocks`.
  blocks: Tuple[BlockSpec, ...] = tuple()
|
|
|
|
|
@dataclasses.dataclass
class AssembleNetPlus(hyperparams.Config):
  """AssembleNet++ backbone config."""
  # Model identifier; the experiment factories in this file leave it at '50'.
  model_id: str = '50'
  # Number of input frames. Left at 0 here; the experiment factories set it
  # and assert that it is positive.
  num_frames: int = 0
  # Note: the default is the *string* 'None', not Python `None`.
  attention_mode: str = 'None'
  # Network structure, e.g. as produced by `flat_lists_to_blocks`.
  blocks: Tuple[BlockSpec, ...] = tuple()
  # NOTE(review): presumably enables the object segmentation (-3) stem input
  # described in the module docstring -- confirm against the backbone.
  use_object_input: bool = False
|
|
|
|
|
@dataclasses.dataclass
class Backbone3D(backbones_3d.Backbone3D):
  """Configuration for 3D backbones, adding the AssembleNet variants.

  Attributes:
    type: 'str', type of backbone to be used; one of the fields below
      (presumably also those inherited from `backbones_3d.Backbone3D`).
    assemblenet: AssembleNet backbone config.
    assemblenet_plus: AssembleNetPlus backbone config.
  """
  type: Optional[str] = None
  # default_factory gives every Backbone3D instance its own sub-configs.
  assemblenet: AssembleNet = dataclasses.field(default_factory=AssembleNet)
  assemblenet_plus: AssembleNetPlus = dataclasses.field(
      default_factory=AssembleNetPlus
  )
|
|
|
|
|
@dataclasses.dataclass
class AssembleNetModel(video_classification.VideoClassificationModel):
  """The AssembleNet model config.

  A `VideoClassificationModel` whose backbone defaults to the `assemblenet`
  type. It uses norm momentum 0.99, epsilon 1e-5 and sync BN.
  """
  model_type: str = 'assemblenet'
  backbone: Backbone3D = dataclasses.field(
      default_factory=lambda: Backbone3D(type='assemblenet')
  )
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(
          norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=True
      )
  )
  # Off by default. NOTE(review): presumably max-pools predictions instead of
  # averaging them -- confirm in the model builder.
  max_pool_predictions: bool = False
|
|
|
|
|
@dataclasses.dataclass
class AssembleNetPlusModel(video_classification.VideoClassificationModel):
  """The AssembleNet++ model config.

  A `VideoClassificationModel` whose backbone defaults to the
  `assemblenet_plus` type. It uses norm momentum 0.99, epsilon 1e-5 and sync
  BN.
  """
  model_type: str = 'assemblenet_plus'
  backbone: Backbone3D = dataclasses.field(
      default_factory=lambda: Backbone3D(type='assemblenet_plus')
  )
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(
          norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=True
      )
  )
  # Off by default. NOTE(review): presumably max-pools predictions instead of
  # averaging them -- confirm in the model builder.
  max_pool_predictions: bool = False
|
|
|
|
|
@exp_factory.register_config_factory('assemblenet50_kinetics600')
def assemblenet_kinetics600() -> cfg.ExperimentConfig:
  """Builds the AssembleNet-50 video classification experiment on Kinetics-600.

  Starts from the stock Kinetics-600 video classification experiment. It trains
  on bfloat16 clips of shape (32, 224, 224, 3) with global batch 1024 and
  evaluates on bfloat16 clips of shape (120, 224, 224, 3) with global batch 32.
  The backbone is built from `asn50_structure` and `asn_structure_weights`.

  Returns:
    The experiment config.
  """
  experiment = video_classification.video_classification_kinetics600()

  clip_shape = (32, 224, 224, 3)
  train_cfg = experiment.task.train_data
  eval_cfg = experiment.task.validation_data
  train_cfg.global_batch_size = 1024
  eval_cfg.global_batch_size = 32
  train_cfg.feature_shape = clip_shape
  eval_cfg.feature_shape = (120, 224, 224, 3)
  train_cfg.dtype = 'bfloat16'
  eval_cfg.dtype = 'bfloat16'

  model = AssembleNetModel()
  asn_cfg = model.backbone.assemblenet
  asn_cfg.model_id = '50'
  asn_cfg.blocks = flat_lists_to_blocks(asn50_structure, asn_structure_weights)
  # The backbone's frame count follows the training clip length.
  asn_cfg.num_frames = clip_shape[0]
  experiment.task.model = model

  # Sanity check on the config as stored in the experiment.
  backbone = experiment.task.model.backbone.assemblenet
  assert backbone.num_frames > 0, (f'backbone num_frames '
                                   f'{backbone}')
  return experiment
|
|
|
|
|
@exp_factory.register_config_factory('assemblenet_ucf101')
def assemblenet_ucf101() -> cfg.ExperimentConfig:
  """Builds the AssembleNet video classification experiment on UCF-101.

  Starts from the stock UCF-101 video classification experiment and switches
  both input pipelines to bfloat16. It uses an `AssembleNetModel` whose
  backbone is built from `asn50_structure` and `asn_structure_weights`.

  Returns:
    The experiment config.
  """
  exp = video_classification.video_classification_ucf101()
  exp.task.train_data.dtype = 'bfloat16'
  exp.task.validation_data.dtype = 'bfloat16'
  feature_shape = (32, 224, 224, 3)
  # The backbone's num_frames is taken from `feature_shape`. Pin the training
  # clip shape to the same value so the model and the training input pipeline
  # cannot disagree if the base experiment's default clip shape changes. This
  # mirrors what `assemblenet_kinetics600` does.
  exp.task.train_data.feature_shape = feature_shape
  model = AssembleNetModel()
  model.backbone.assemblenet.blocks = flat_lists_to_blocks(
      asn50_structure, asn_structure_weights)
  model.backbone.assemblenet.num_frames = feature_shape[0]
  exp.task.model = model

  assert exp.task.model.backbone.assemblenet.num_frames > 0, (
      f'backbone num_frames '
      f'{exp.task.model.backbone.assemblenet}')

  return exp
|
|
|
|
|
@exp_factory.register_config_factory('assemblenetplus_ucf101')
def assemblenetplus_ucf101() -> cfg.ExperimentConfig:
  """Builds the AssembleNet++ video classification experiment on UCF-101.

  Starts from the stock UCF-101 video classification experiment and switches
  both input pipelines to bfloat16. It uses an `AssembleNetPlusModel` whose
  backbone is built from `asn50_structure` and `asn_structure_weights`. That
  structure has no object segmentation (-3) stem, matching the default
  `use_object_input=False`.

  Returns:
    The experiment config.
  """
  exp = video_classification.video_classification_ucf101()
  exp.task.train_data.dtype = 'bfloat16'
  exp.task.validation_data.dtype = 'bfloat16'
  feature_shape = (32, 224, 224, 3)
  # The backbone's num_frames is taken from `feature_shape`. Pin the training
  # clip shape to the same value so the model and the training input pipeline
  # cannot disagree if the base experiment's default clip shape changes. This
  # mirrors what `assemblenet_kinetics600` does.
  exp.task.train_data.feature_shape = feature_shape
  model = AssembleNetPlusModel()
  model.backbone.assemblenet_plus.blocks = flat_lists_to_blocks(
      asn50_structure, asn_structure_weights)
  model.backbone.assemblenet_plus.num_frames = feature_shape[0]
  exp.task.model = model

  assert exp.task.model.backbone.assemblenet_plus.num_frames > 0, (
      f'backbone num_frames '
      f'{exp.task.model.backbone.assemblenet_plus}')

  return exp
|
|
|