import torch
from torch import nn
class Finetune(nn.Module):
    """Plain finetuning baseline: a backbone feature extractor plus a linear
    classification head, trained with cross-entropy.

    The backbone is expected to return a dict containing a ``'features'``
    tensor of shape ``(batch, feat_dim)`` when called on an image batch
    (this is how every method here consumes it — confirm against the
    backbone implementations used by callers).

    ``before_task`` / ``after_task`` are no-op hooks kept for interface
    compatibility with continual-learning trainers that call them.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        """
        Args:
            backbone: feature extractor module; ``backbone(x)`` must return a
                dict with a ``'features'`` entry.
            feat_dim: dimensionality of the backbone's feature output.
            num_class: number of target classes for the linear head.
            **kwargs: must contain ``'device'`` (raises KeyError otherwise);
                the full dict is retained on ``self.kwargs``.
        """
        super().__init__()
        self.backbone = backbone
        self.feat_dim = feat_dim
        self.num_class = num_class
        self.classifier = nn.Linear(feat_dim, num_class)
        self.loss_fn = nn.CrossEntropyLoss(reduction='mean')
        # Required key: fail fast here rather than on first forward pass.
        self.device = kwargs['device']
        self.kwargs = kwargs

    def _predict(self, data):
        """Move a batch to the device and run the forward pass.

        Returns (logit, label, pred, acc) where acc is a Python float in
        [0, 1] — shared by observe() and inference() to avoid the previous
        code duplication.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)
        logit = self.forward(x)
        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item() / x.size(0)
        return logit, y, pred, acc

    def observe(self, data):
        """Training-time step on one batch.

        Args:
            data: dict with 'image' and 'label' tensors.
        Returns:
            (pred, acc, loss): predicted class indices, batch accuracy as a
            float, and the differentiable cross-entropy loss (the caller is
            responsible for backward/step).
        """
        logit, y, pred, acc = self._predict(data)
        loss = self.loss_fn(logit, y)
        return pred, acc, loss

    def inference(self, data):
        """Evaluation-time step on one batch.

        Same as observe() but without computing the loss.

        Returns:
            (pred, acc): predicted class indices and batch accuracy.
        """
        _, _, pred, acc = self._predict(data)
        return pred, acc

    def forward(self, x):
        """Classify a batch of images already on the correct device."""
        return self.classifier(self.backbone(x)['features'])

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        # Hook for task-boundary setup; intentionally a no-op here.
        pass

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        # Hook for task-boundary teardown; intentionally a no-op here.
        pass

    def get_parameters(self, config):
        """Return optimizer parameter groups: backbone and classifier.

        Args:
            config: unused; kept for interface compatibility with callers.
        """
        return [
            {"params": self.backbone.parameters()},
            {"params": self.classifier.parameters()},
        ]