| | |
| | |
| |
|
| |
|
| |
|
| |
|
| |
|
| | import abc |
| |
|
| |
|
| | class SamplingTrainableMixin(metaclass=abc.ABCMeta): |
| |
|
| | def __init__(self, *args, **kwargs): |
| | super(SamplingTrainableMixin, self).__init__(*args, **kwargs) |
| | self._train_param_blobs = None |
| | self._train_param_blobs_frozen = False |
| |
|
| | @property |
| | @abc.abstractmethod |
| | def param_blobs(self): |
| | """ |
| | List of parameter blobs for prediction net |
| | """ |
| | pass |
| |
|
| | @property |
| | def train_param_blobs(self): |
| | """ |
| | If train_param_blobs is not set before used, default to param_blobs |
| | """ |
| | if self._train_param_blobs is None: |
| | self.train_param_blobs = self.param_blobs |
| | return self._train_param_blobs |
| |
|
| | @train_param_blobs.setter |
| | def train_param_blobs(self, blobs): |
| | assert not self._train_param_blobs_frozen |
| | assert blobs is not None |
| | self._train_param_blobs_frozen = True |
| | self._train_param_blobs = blobs |
| |
|
| | @abc.abstractmethod |
| | def _add_ops(self, net, param_blobs): |
| | """ |
| | Add ops to the given net, using the given param_blobs |
| | """ |
| | pass |
| |
|
| | def add_ops(self, net): |
| | self._add_ops(net, self.param_blobs) |
| |
|
| | def add_train_ops(self, net): |
| | self._add_ops(net, self.train_param_blobs) |
| |
|