|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ppdet.core.workspace import create |
|
|
from ppdet.utils.logger import setup_logger |
|
|
# Module logger, obtained from ppdet's setup_logger under the 'ppdet.engine' name.
# NOTE(review): not referenced by the code visible in this section of the file.
logger = setup_logger('ppdet.engine')
|
|
|
|
|
from . import Trainer |
|
|
# Public API of this module: `from <module> import *` exports only the co-tuning trainer.
__all__ = ['TrainerCot']
|
|
|
|
|
class TrainerCot(Trainer):
    """Trainer for label co-tuning.

    Behaves like ``Trainer``, but at construction time it also estimates the
    relationship between the base classes and the novel classes (the number
    of novel classes is read from ``cfg['num_classes']``). It then uses that
    relationship to initialise the model's co-tuning head.
    """

    def __init__(self, cfg, mode='train'):
        """Build the base trainer, then run the co-tuning setup step.

        Args:
            cfg: experiment configuration, forwarded unchanged to ``Trainer``.
            mode (str): trainer mode, forwarded unchanged to ``Trainer``.
        """
        super().__init__(cfg, mode)
        self.cotuning_init()

    def cotuning_init(self):
        """Prepare the model for co-tuning.

        Steps, in order:
          1. load the weights named by ``cfg.pretrain_weights``;
          2. put the model in eval mode and have it learn the base/novel
             class relationship by iterating ``self.loader``;
          3. initialise the co-tuning head from that relationship;
          4. rebuild ``self.optimizer`` via the ``'OptimizerBuilder'`` factory.
        """
        novel_class_count = self.cfg['num_classes']
        self.load_weights(self.cfg.pretrain_weights)

        model = self.model
        # Relationship learning runs with the model in eval mode.
        # NOTE(review): the model is not switched back to train mode here;
        # presumably Trainer's training loop does that -- confirm.
        model.eval()
        class_relationship = model.relationship_learning(
            self.loader, novel_class_count)
        model.init_cot_head(class_relationship)

        # Rebuilt after the head change, presumably so the optimizer tracks
        # the new head's parameters -- confirm against the model code.
        build_optimizer = create('OptimizerBuilder')
        self.optimizer = build_optimizer(self.lr, model)
|
|
|
|
|
|
|
|
|