File size: 555 Bytes
a47307e
 
 
 
 
 
0ddbb98
 
 
a47307e
0ddbb98
 
a47307e
 
 
0ddbb98
a47307e
 
 
0ddbb98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from PIL import Image

class LinearClassifier(torch.nn.Module):
    """Feed-forward classification head mapping features to class logits.

    Despite the name, with the default ``hidden_dims`` this is a small MLP,
    ``Linear -> ReLU -> Linear -> ReLU -> Linear``, not a single linear layer.
    Pass ``hidden_dims=()`` to get a purely linear head.

    Args:
        input_dim: Size of the last dimension of the input features.
        output_dim: Number of output logits (e.g. number of classes).
        hidden_dims: Widths of the hidden layers. Each one is followed by a
            ReLU. Defaults to ``(512, 256)``.
    """

    def __init__(self, input_dim, output_dim, *, hidden_dims=(512, 256)):
        super().__init__()

        dims = [input_dim, *hidden_dims]
        layers = []
        # Build layers in input-to-output order so that, for the default
        # hidden_dims, parameter creation order matches the original fixed
        # layout. Seeded initialization therefore yields identical weights.
        for in_features, out_features in zip(dims, dims[1:]):
            layers.append(nn.Linear(in_features, out_features, bias=True))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(dims[-1], output_dim, bias=True))

        # Keep the attribute name and Sequential indices (0, 2, 4 for the
        # default) so existing state_dict checkpoints still load.
        self.linear_head = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape ``(*, output_dim)`` for input ``(*, input_dim)``."""
        return self.linear_head(x)