import torch
from torch import nn


class Finetune(nn.Module):
    """Fine-tuning baseline: a backbone followed by a single linear classifier.

    ``observe`` returns a cross-entropy loss for the caller to optimise and
    ``inference`` evaluates a batch. The ``before_task``/``after_task`` hooks
    are no-ops, so this model applies no extra mechanism between tasks.

    Args:
        backbone: Module called as ``backbone(x)`` that returns a mapping with
            a ``'features'`` entry. The features are fed to the classifier, so
            presumably their last dimension is ``feat_dim`` (TODO confirm).
        feat_dim: Input size of the linear classifier.
        num_class: Number of classifier outputs.
        **kwargs: Must contain ``'device'`` (where each batch is moved). The
            full dict is kept on ``self.kwargs``.

    Raises:
        KeyError: If ``kwargs`` has no ``'device'`` entry.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__()
        self.backbone = backbone
        self.feat_dim = feat_dim
        self.num_class = num_class
        self.classifier = nn.Linear(feat_dim, num_class)
        self.loss_fn = nn.CrossEntropyLoss(reduction='mean')
        self.device = kwargs['device']
        self.kwargs = kwargs

    def _logits(self, x):
        """Return classifier logits for ``x`` (backbone features -> linear head)."""
        return self.classifier(self.backbone(x)['features'])

    def _batch_to_device(self, data):
        """Return ``(data['image'], data['label'])`` moved to ``self.device``."""
        return data['image'].to(self.device), data['label'].to(self.device)

    @staticmethod
    def _predict(logit, y, batch_size):
        """Return ``(argmax predictions, correct count / batch_size)``."""
        pred = torch.argmax(logit, dim=1)
        correct = torch.sum(pred == y).item()
        return pred, correct / batch_size

    def observe(self, data):
        """Compute predictions, batch accuracy and the training loss.

        The loss is returned without calling ``backward``. Presumably the
        training loop backpropagates it and steps the optimiser (verify
        against the caller).

        Args:
            data: Mapping with ``'image'`` and ``'label'`` tensors. Labels are
                compared against argmax predictions, so presumably they are
                integer class indices (TODO confirm).

        Returns:
            ``(pred, acc, loss)``: argmax predictions, the fraction of correct
            predictions in the batch, and the mean cross-entropy loss tensor.
        """
        x, y = self._batch_to_device(data)
        logit = self._logits(x)
        loss = self.loss_fn(logit, y)
        pred, acc = self._predict(logit, y, x.size(0))
        return pred, acc, loss

    @torch.no_grad()
    def inference(self, data):
        """Predict a batch and report its accuracy.

        Runs under ``torch.no_grad()`` because nothing computed here is
        backpropagated. Tracking gradients would only keep every activation
        alive for a graph that is immediately discarded.

        Args:
            data: Mapping with ``'image'`` and ``'label'`` tensors.

        Returns:
            ``(pred, acc)``: argmax predictions and the fraction of correct
            predictions in the batch.
        """
        x, y = self._batch_to_device(data)
        logit = self._logits(x)
        return self._predict(logit, y, x.size(0))

    def forward(self, x):
        """Return classifier logits for an input batch ``x``."""
        return self._logits(x)

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Hook called before a task; intentionally does nothing here."""
        pass

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Hook called after a task; intentionally does nothing here."""
        pass

    def get_parameters(self, config):
        """Return optimiser parameter groups: one for the backbone, one for the classifier.

        ``config`` is accepted for interface compatibility but is not used.
        """
        return [
            {"params": self.backbone.parameters()},
            {"params": self.classifier.parameters()},
        ]