| | |
| | |
| | """ |
| | Implement functions for controlling execution of nets and steps, including |
| | Do |
| | DoParallel |
| | For-loop |
| | While-loop |
| | Do-While-loop |
| | Switch |
| | If |
| | """ |
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | from caffe2.python import core |
| | from future.utils import viewitems |
| |
|
| |
|
| | |
| | |
| | _current_idx = 1 |
| | _used_step_names = set() |
| |
|
| |
|
| | def _get_next_step_name(control_name, base_name): |
| | global _current_idx, _used_step_names |
| | concat_name = '%s/%s' % (base_name, control_name) |
| | next_name = concat_name |
| | while next_name in _used_step_names: |
| | next_name = '%s_%d' % (concat_name, _current_idx) |
| | _current_idx += 1 |
| | _used_step_names.add(next_name) |
| | return next_name |
| |
|
| |
|
| | def _MakeList(input): |
| | """ input is a tuple. |
| | Example: |
| | (a, b, c) --> [a, b, c] |
| | (a) --> [a] |
| | ([a, b, c]) --> [a, b, c] |
| | """ |
| | if len(input) == 0: |
| | raise ValueError( |
| | 'input cannot be empty.') |
| | elif len(input) == 1: |
| | output = input[0] |
| | if not isinstance(output, list): |
| | output = [output] |
| | else: |
| | output = list(input) |
| | return output |
| |
|
| |
|
| | def _IsNets(nets_or_steps): |
| | if isinstance(nets_or_steps, list): |
| | return all(isinstance(n, core.Net) for n in nets_or_steps) |
| | else: |
| | return isinstance(nets_or_steps, core.Net) |
| |
|
| |
|
def _PrependNets(nets_or_steps, *nets):
    """Return a new list with ``nets`` placed in front of ``nets_or_steps``.

    When ``nets_or_steps`` consists only of nets, the two lists are simply
    concatenated. Otherwise ``nets`` are first wrapped into a single step
    built by ``Do('prepend', ...)``.
    """
    body = _MakeList((nets_or_steps,))
    head = _MakeList(nets)
    if not _IsNets(body):
        head = [Do('prepend', head)]
    return head + body
| |
|
| |
|
def _AppendNets(nets_or_steps, *nets):
    """Return a new list with ``nets`` placed after ``nets_or_steps``.

    When ``nets_or_steps`` consists only of nets, the two lists are simply
    concatenated. Otherwise ``nets`` are first wrapped into a single step
    built by ``Do('append', ...)``.
    """
    body = _MakeList((nets_or_steps,))
    tail = _MakeList(nets)
    if not _IsNets(body):
        tail = [Do('append', tail)]
    return body + tail
| |
|
| |
|
def GetConditionBlobFromNet(condition_net):
    """
    Return the condition blob of ``condition_net``.

    The condition blob is the last external_output of the net; it must be a
    single bool.

    Raises:
        AssertionError: if the net declares no external output.
    """
    external_outputs = condition_net.Proto().external_output
    # Proto is a method: the net name must be read from Proto(), not from the
    # bound method itself. Reading Proto.name would raise AttributeError
    # instead of producing this assertion message.
    assert len(external_outputs) > 0, (
        "Condition net %s must has at least one external output" %
        condition_net.Proto().name)
    return core.BlobReference(external_outputs[-1])
| |
|
| |
|
def BoolNet(*blobs_with_bool_value):
    """Build a net that assigns constant bool values to blobs.

    Mainly used to initialize condition blobs. For example, in multi-task
    learning the reader_done blobs may be accessed before the reader net has
    run, so they must be initialized first.

    Args:
        blobs_with_bool_value: one or more (blob, bool_value) pairs. For each
            pair the net fills ``blob`` with the scalar ``bool_value``.

    Returns:
        bool_net: a net (created as core.Net('bool_net')) performing the
            assignments; every filled blob is added as an external output.

    Examples:
        - BoolNet((blob_1, bool_value_1), ..., (blob_n, bool_value_n))
        - BoolNet([(blob_1, bool_value_1), ..., (blob_n, bool_value_n)])
        - BoolNet((cond_1, bool_value_1))
    """
    pairs = _MakeList(blobs_with_bool_value)
    bool_net = core.Net('bool_net')
    for target, flag in pairs:
        filled = bool_net.ConstantFill(
            [], [target], shape=[], value=flag, dtype=core.DataType.BOOL)
        bool_net.AddExternalOutput(filled)
    return bool_net
| |
|
| |
|
def NotNet(condition_blob_or_net):
    """Build a net computing the logical NOT of a condition.

    Args:
        condition_blob_or_net: either a blob, or a Net whose last
            external_output (a single bool) is taken as the condition.

    Returns:
        not_net: a new net applying Not to the condition blob.
        out_blob: the output blob of not_net (also added as its external
            output).
    """
    condition_blob = (
        GetConditionBlobFromNet(condition_blob_or_net)
        if isinstance(condition_blob_or_net, core.Net)
        else condition_blob_or_net)
    not_net = core.Net('not_net')
    out_blob = not_net.Not(condition_blob)
    not_net.AddExternalOutput(out_blob)
    return not_net, out_blob
| |
|
| |
|
def _CopyConditionBlobNet(condition_blob):
    """Build a condition net that copies ``condition_blob``.

    Args:
        condition_blob: the blob to copy; expected to hold a single bool.

    Returns:
        condition_net: a new net that copies condition_blob.
        out_blob: the copy produced by condition_net (also added as its
            external output).
    """
    copy_net = core.Net('copy_condition_blob_net')
    copied = copy_net.Copy(condition_blob)
    copy_net.AddExternalOutput(copied)
    return copy_net, copied
| |
|
| |
|
def MergeConditionNets(name, condition_nets, relation):
    """
    Merge multiple condition nets into a single condition net.

    The ops and external inputs of every net in ``condition_nets`` are copied
    into a new net, and their condition blobs are folded together left to
    right with the ``relation`` operator.

    Args:
        name: name of the new condition net.
        condition_nets: a list of condition nets. The last external_output
            of each condition net must be single bool value. A non-list
            argument is returned unchanged.
        relation: can be 'And' or 'Or'.

    Returns:
        - A new condition net. Its last external output is relation of all
          condition_nets. For a one-element list the sole net itself is
          returned; for an empty list, None.
    """
    if not isinstance(condition_nets, list):
        return condition_nets
    if len(condition_nets) <= 1:
        return condition_nets[0] if condition_nets else None

    merged_net = core.Net(name)
    for i, condition_net in enumerate(condition_nets):
        net_proto = condition_net.Proto()
        # All merged nets must share the merged net's device and net type.
        assert net_proto.device_option == merged_net.Proto().device_option
        assert net_proto.type == merged_net.Proto().type
        merged_net.Proto().op.extend(net_proto.op)
        # NOTE(review): external inputs are appended verbatim, so a blob
        # shared by several condition nets may be listed more than once.
        merged_net.Proto().external_input.extend(net_proto.external_input)

        curr_cond = GetConditionBlobFromNet(condition_net)
        if i == 0:
            last_cond = curr_cond
        else:
            # relation names a binary operator on the net, e.g. 'And' ->
            # merged_net.And([...]).
            last_cond = getattr(merged_net, relation)([last_cond, curr_cond])

        for k, v in viewitems(condition_net._attr_dict):
            merged_net._attr_dict[k] += v

    merged_net.AddExternalOutput(last_cond)
    return merged_net
| |
|
| |
|
def CombineConditions(name, condition_nets, relation):
    """
    Combine conditions of multi nets into a single condition nets. Unlike
    MergeConditionNets, the actual body of condition_nets is not copied into
    the combine condition net.

    One example is about multi readers. Each reader net has a reader_done
    condition. When we want to check whether all readers are done, we can
    use this function to build a new net.

    Args:
        name: name of the new condition net (not used when condition_nets
            has a single element; that case returns a copy net instead).
        condition_nets: a list of condition nets. The last external_output
            of each condition net must be single bool value.
        relation: can be 'And' or 'Or'.

    Returns:
        - A new condition net. Its last external output is relation of all
          condition_nets. None if condition_nets is empty/falsy.

    Raises:
        ValueError: if condition_nets is non-empty but not a list.
    """
    if not condition_nets:
        return None
    if not isinstance(condition_nets, list):
        raise ValueError('condition_nets must be a list of nets.')

    if len(condition_nets) == 1:
        # Single condition: just copy its condition blob into a fresh net.
        condition_blob = GetConditionBlobFromNet(condition_nets[0])
        condition_net, _ = _CopyConditionBlobNet(condition_blob)
        return condition_net

    combined_net = core.Net(name)
    for i, condition_net in enumerate(condition_nets):
        curr_cond = GetConditionBlobFromNet(condition_net)
        if i == 0:
            last_cond = curr_cond
        else:
            # relation names a binary operator on the net, e.g. 'Or' ->
            # combined_net.Or([...]).
            last_cond = getattr(combined_net, relation)(
                [last_cond, curr_cond])

    combined_net.AddExternalOutput(last_cond)
    return combined_net
| |
|
| |
|
def Do(name, *nets_or_steps):
    """
    Run the given nets or steps once, in sequence.

    A lone ExecutionStep is returned as-is rather than wrapped in a new step.

    Examples:
    - Do('myDo', net1, net2, ..., net_n)
    - Do('myDo', list_of_nets)
    - Do('myDo', step1, step2, ..., step_n)
    - Do('myDo', list_of_steps)
    """
    items = _MakeList(nets_or_steps)
    is_single_step = (
        len(items) == 1 and isinstance(items[0], core.ExecutionStep))
    if is_single_step:
        return items[0]
    return core.scoped_execution_step(_get_next_step_name('Do', name), items)
| |
|
| |
|
def DoParallel(name, *nets_or_steps):
    """
    Run the given nets or steps as concurrent substeps, waiting for all of
    them to finish.

    A lone ExecutionStep is returned as-is rather than wrapped in a new step.

    Examples:
    - DoParallel('pDo', net1, net2, ..., net_n)
    - DoParallel('pDo', list_of_nets)
    - DoParallel('pDo', step1, step2, ..., step_n)
    - DoParallel('pDo', list_of_steps)
    """
    items = _MakeList(nets_or_steps)
    if len(items) == 1 and isinstance(items[0], core.ExecutionStep):
        return items[0]
    step_name = _get_next_step_name('DoParallel', name)
    return core.scoped_execution_step(
        step_name, items, concurrent_substeps=True)
| |
|
| |
|
def _RunOnceIf(name, condition_blob_or_net, nets_or_steps):
    """
    Execute nets_or_steps once if condition_blob_or_net evaluates as true.

    If condition_blob_or_net is Net, the condition is its last external_output
    that must be a single bool. And this net will be executed before
    nets_or_steps so as to get the condition.

    The NOT of the condition is computed first and used as the step's
    should_stop_blob; the step is built with only_once=True.

    Returns:
        An ExecutionStep.
    """
    # stop_blob holds NOT(condition).
    condition_not_net, stop_blob = NotNet(condition_blob_or_net)
    if isinstance(condition_blob_or_net, core.Net):
        # Run the condition net itself, then its negation, before the body.
        nets_or_steps = _PrependNets(
            nets_or_steps, condition_blob_or_net, condition_not_net)
    else:
        nets_or_steps = _PrependNets(nets_or_steps, condition_not_net)

    def if_step(control_name):
        # The unique step name is allocated only when this helper is called,
        # so the call order below affects the generated names.
        return core.scoped_execution_step(
            _get_next_step_name(control_name, name),
            nets_or_steps,
            should_stop_blob=stop_blob,
            only_once=True,
        )

    if _IsNets(nets_or_steps):
        # NOTE(review): for a nets-only body, stop_blob is first set to False
        # by bool_net -- presumably because the executor may read
        # should_stop_blob before not_net has produced it. Confirm against
        # the ExecutionStep should_stop_blob semantics.
        bool_net = BoolNet((stop_blob, False))
        return Do(name + '/_RunOnceIf',
                  bool_net, if_step('_RunOnceIf-inner'))
    else:
        return if_step('_RunOnceIf')
| |
|
| |
|
def _RunOnceIfNot(name, condition_blob_or_net, nets_or_steps):
    """
    Like _RunOnceIf(), but execute nets_or_steps once only when
    condition_blob_or_net evaluates as false.

    The condition blob itself is the step's should_stop_blob. A net that
    produces it is placed in front of nets_or_steps: the condition net
    itself, or, for a bare blob, a net that copies the blob.
    """
    if not isinstance(condition_blob_or_net, core.Net):
        leading_net, condition_blob = _CopyConditionBlobNet(
            condition_blob_or_net)
    else:
        leading_net = condition_blob_or_net
        condition_blob = GetConditionBlobFromNet(condition_blob_or_net)
    steps = _PrependNets(nets_or_steps, leading_net)
    return core.scoped_execution_step(
        _get_next_step_name('_RunOnceIfNot', name),
        steps,
        should_stop_blob=condition_blob,
        only_once=True,
    )
| |
|
| |
|
def For(name, nets_or_steps, iter_num):
    """
    Execute nets_or_steps iter_num times.

    An init net creates a counter starting at iter_num. Each iteration first
    runs a CountDown on that counter, and the CountDown output is the step's
    should_stop_blob.

    Args:
        nets_or_steps: an ExecutionStep, a Net, a list of ExecutionSteps or
            a list of Nets.
        iter_num: the number of times to execute nets_or_steps.

    Returns:
        An ExecutionStep instance.
    """
    init_net = core.Net('init-net')
    counter = init_net.CreateCounter([], init_count=iter_num)
    countdown_net = core.Net('For-iter')
    done_blob = countdown_net.CountDown([counter])

    # Allocate the inner step name before building the body, keeping the
    # order in which unique names are generated.
    inner_name = _get_next_step_name('For-inner', name)
    loop_body = _PrependNets(nets_or_steps, countdown_net)
    loop_step = core.scoped_execution_step(
        inner_name, loop_body, should_stop_blob=done_blob)
    init_step = Do(name + '/For-init-net', init_net)
    return Do(name + '/For', init_step, loop_step)
| |
|
| |
|
def While(name, condition_blob_or_net, nets_or_steps):
    """
    Execute nets_or_steps when condition_blob_or_net returns true.

    The condition net (if any) and a NOT of the condition are prepended to
    nets_or_steps; the NOT output is used as the step's should_stop_blob.

    Args:
        condition_blob_or_net: If it is an instance of Net, its last
            external_output must be a single bool.
        nets_or_steps: a ExecutionStep or a Net or a list of ExecutionSteps or
            a list nets.

    Returns:
        A ExecutionStep instance.
    """
    # stop_blob holds NOT(condition).
    condition_not_net, stop_blob = NotNet(condition_blob_or_net)
    if isinstance(condition_blob_or_net, core.Net):
        # Re-run the condition net, then its negation, at the front of the
        # loop body.
        nets_or_steps = _PrependNets(
            nets_or_steps, condition_blob_or_net, condition_not_net)
    else:
        nets_or_steps = _PrependNets(nets_or_steps, condition_not_net)

    def while_step(control_name):
        # The unique step name is allocated only when this helper is called.
        return core.scoped_execution_step(
            _get_next_step_name(control_name, name),
            nets_or_steps,
            should_stop_blob=stop_blob,
        )

    if _IsNets(nets_or_steps):
        # NOTE(review): for a nets-only body, stop_blob is first set to False
        # by bool_net -- presumably because the executor may read
        # should_stop_blob before not_net has produced it in the first
        # iteration. Confirm against the ExecutionStep should_stop_blob
        # semantics.
        bool_net = BoolNet((stop_blob, False))
        return Do(name + '/While', bool_net, while_step('While-inner'))
    else:
        return while_step('While')
| |
|
| |
|
def Until(name, condition_blob_or_net, nets_or_steps):
    """
    Like While(), but keep executing nets_or_steps while
    condition_blob_or_net returns false.

    The condition blob is used directly as the step's should_stop_blob. When
    the condition is a Net, it is prepended to nets_or_steps so that it is
    evaluated inside the loop body.
    """
    if not isinstance(condition_blob_or_net, core.Net):
        stop_blob = core.BlobReference(str(condition_blob_or_net))
    else:
        stop_blob = GetConditionBlobFromNet(condition_blob_or_net)
        nets_or_steps = _PrependNets(nets_or_steps, condition_blob_or_net)
    return core.scoped_execution_step(
        _get_next_step_name('Until', name),
        nets_or_steps,
        should_stop_blob=stop_blob)
| |
|
| |
|
def DoWhile(name, condition_blob_or_net, nets_or_steps):
    """
    Execute nets_or_steps when condition_blob_or_net returns true. It will
    execute nets_or_steps before evaluating condition_blob_or_net.

    Args:
        condition_blob_or_net: if it is an instance of Net, its last
            external_output must be a single bool.
        nets_or_steps: a ExecutionStep or a Net or a list of ExecutionSteps or
            a list nets.

    Returns:
        A ExecutionStep instance.
    """
    # stop_blob holds NOT(condition). The condition net (if any) and its
    # negation are appended so they run after the body.
    condition_not_net, stop_blob = NotNet(condition_blob_or_net)
    if isinstance(condition_blob_or_net, core.Net):
        nets_or_steps = _AppendNets(
            nets_or_steps, condition_blob_or_net, condition_not_net)
    else:
        nets_or_steps = _AppendNets(nets_or_steps, condition_not_net)

    # NOTE(review): stop_blob is set to False by bool_net before the loop --
    # presumably because the executor may read should_stop_blob before the
    # appended not_net has produced it. Confirm against the ExecutionStep
    # should_stop_blob semantics.
    bool_net = BoolNet((stop_blob, False))
    return Do(name + '/DoWhile', bool_net, core.scoped_execution_step(
        _get_next_step_name('DoWhile-inner', name),
        nets_or_steps,
        should_stop_blob=stop_blob,
    ))
| |
|
| |
|
def DoUntil(name, condition_blob_or_net, nets_or_steps):
    """
    Similar to DoWhile() but execute nets_or_steps when
    condition_blob_or_net returns false. It will execute
    nets_or_steps before evaluating condition_blob_or_net.

    Special case: if condition_blob_or_net is a blob and is pre-set to
    true, then only the first net/step of nets_or_steps will be executed and
    loop is exited. So you need to be careful about the initial value the
    condition blob when using DoUntil(), esp when DoUntil() is called twice.
    """
    if not isinstance(condition_blob_or_net, core.Net):
        # Bare blob: use it directly as the stop blob. No initialization and
        # no extra nets are added, which is why the special case described
        # in the docstring applies.
        stop_blob = core.BlobReference(condition_blob_or_net)
        return core.scoped_execution_step(
            _get_next_step_name('DoUntil', name),
            nets_or_steps,
            should_stop_blob=stop_blob)

    # Net condition: append it so it runs after the body; its output
    # (last external_output) is the stop blob.
    nets_or_steps = _AppendNets(nets_or_steps, condition_blob_or_net)
    stop_blob = GetConditionBlobFromNet(condition_blob_or_net)

    # NOTE(review): stop_blob is set to False by bool_net before the loop --
    # presumably so it exists (and is false) before the executor first reads
    # should_stop_blob. Confirm against the ExecutionStep semantics.
    bool_net = BoolNet((stop_blob, False))
    return Do(name + '/DoUntil', bool_net, core.scoped_execution_step(
        _get_next_step_name('DoUntil-inner', name),
        nets_or_steps,
        should_stop_blob=stop_blob,
    ))
| |
|
| |
|
def Switch(name, *conditions):
    """
    Build a step that runs every branch whose condition is true.

    Each condition is a tuple (condition_blob_or_net, nets_or_steps), and each
    branch is built with _RunOnceIf. Note:
    1. Several branches may run if several conditions are true.
    2. Every condition that is a Net is executed once.

    Examples:
    - Switch('name', (cond_1, net_1), (cond_2, net_2), ..., (cond_n, net_n))
    - Switch('name', [(cond_1, net1), (cond_2, net_2), ..., (cond_n, net_n)])
    - Switch('name', (cond_1, net_1))
    """
    cases = _MakeList(conditions)
    # Allocate the outer step name before the branches are built, keeping the
    # order in which unique names are generated.
    switch_name = _get_next_step_name('Switch', name)
    branch_name = name + '/Switch'
    branches = [_RunOnceIf(branch_name, cond, body) for cond, body in cases]
    return core.scoped_execution_step(switch_name, branches)
| |
|
| |
|
def SwitchNot(name, *conditions):
    """
    Like Switch(), but run the branches whose condition is False (each branch
    is built with _RunOnceIfNot).
    """
    cases = _MakeList(conditions)
    # Outer step name first, then the branches (same order as before).
    switch_name = _get_next_step_name('SwitchNot', name)
    branch_name = name + '/SwitchNot'
    branches = [
        _RunOnceIfNot(branch_name, cond, body) for cond, body in cases]
    return core.scoped_execution_step(switch_name, branches)
| |
|
| |
|
def If(name, condition_blob_or_net,
       true_nets_or_steps, false_nets_or_steps=None):
    """
    Evaluate (or execute) condition_blob_or_net first. If the condition is
    true, true_nets_or_steps is executed; otherwise false_nets_or_steps is.

    If condition_blob_or_net is a Net, the condition is its last
    external_output, which must be a single bool, and the Net is executed
    before both branches to produce the condition. A falsy
    false_nets_or_steps means there is no else-branch.
    """
    if not false_nets_or_steps:
        return _RunOnceIf(name + '/If', condition_blob_or_net,
                          true_nets_or_steps)

    condition_blob = (
        GetConditionBlobFromNet(condition_blob_or_net)
        if isinstance(condition_blob_or_net, core.Net)
        else condition_blob_or_net)
    # The branches are built in this order (true, then false) so that step
    # names are allocated in the same sequence.
    # NOTE(review): the false branch reads condition_blob after the true
    # branch has run; if the true branch writes that blob, both branches may
    # execute. Confirm callers never do that.
    true_branch = _RunOnceIf(
        name + '/If-true', condition_blob_or_net, true_nets_or_steps)
    false_branch = _RunOnceIfNot(
        name + '/If-false', condition_blob, false_nets_or_steps)
    return Do(name + '/If', true_branch, false_branch)
| |
|
| |
|
def IfNot(name, condition_blob_or_net,
          true_nets_or_steps, false_nets_or_steps=None):
    """
    If condition_blob_or_net evaluates to false, execute true_nets_or_steps;
    otherwise execute false_nets_or_steps. A falsy false_nets_or_steps means
    there is no else-branch.
    """
    if not false_nets_or_steps:
        return _RunOnceIfNot(name + '/IfNot', condition_blob_or_net,
                             true_nets_or_steps)

    condition_blob = (
        GetConditionBlobFromNet(condition_blob_or_net)
        if isinstance(condition_blob_or_net, core.Net)
        else condition_blob_or_net)
    # Build the "true" branch first, then the "false" one, preserving the
    # order in which step names are allocated.
    true_branch = _RunOnceIfNot(
        name + '/IfNot-true', condition_blob_or_net, true_nets_or_steps)
    false_branch = _RunOnceIf(
        name + '/IfNot-false', condition_blob, false_nets_or_steps)
    return Do(name + '/IfNot', true_branch, false_branch)
| |
|