File size: 1,947 Bytes
a65c9ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, AutoModel


class SkinCancerConfig(PretrainedConfig):
    model_type = "vit_tabular_skin_cancer"
    
    def __init__(self,
                 vision_model_checkpoint="google/vit-base-patch16-224-in21k",
                 tabular_dim=0,
                 num_labels=7,
                 id2label=None, 
                 label2id=None,
                 age_min=0.0, 
                 age_max=100.0, 
                 age_mean=50.0,
                 **kwargs):
        super().__init__(**kwargs)
        self.vision_model_checkpoint = vision_model_checkpoint
        self.tabular_dim = tabular_dim
        self.num_labels = num_labels
        self.id2label = id2label
        self.label2id = label2id
        self.age_min = age_min
        self.age_max = age_max
        self.age_mean = age_mean


class SkinCancerViT(PreTrainedModel):
    config_class = SkinCancerConfig
    
    def __init__(self, config):
        super().__init__(config)
        self.vision = AutoModel.from_pretrained(config.vision_model_checkpoint)
        hdim = self.vision.config.hidden_size

        self.tabular = nn.Sequential(
            nn.Linear(config.tabular_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        self.classifier = nn.Linear(hdim + 64, config.num_labels)
        self.post_init()

    def forward(self, pixel_values, tabular_features):
        vout = self.vision(pixel_values=pixel_values, output_hidden_states=False, return_dict=True)
        if getattr(vout, "pooler_output", None) is not None:
            vfeat = vout.pooler_output
        else:
            vfeat = vout.last_hidden_state[:, 0, :]  # CLS
        tfeat = self.tabular(tabular_features.float())
        feats = torch.cat([vfeat, tfeat], dim=-1)
        logits = self.classifier(feats)
        return logits