File size: 4,805 Bytes
1e9aaa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel


def _get_activation(activation: str) -> nn.Module:
    """Build a new activation module from its name.

    A fresh instance is constructed on every call, so modules with learnable
    parameters (``nn.PReLU``) are never shared between layers.

    Args:
        activation: one of ``"prelu"``, ``"relu"``, ``"gelu"``, ``"tanh"``.

    Returns:
        A newly constructed ``nn.Module``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """
    supported = (
        ("prelu", nn.PReLU),
        ("relu", nn.ReLU),
        ("gelu", nn.GELU),
        ("tanh", nn.Tanh),
    )
    for name, module_cls in supported:
        if activation == name:
            return module_cls()
    raise ValueError(f"Unsupported activation: {activation!r}")


class MLP(nn.Module):
    """Feed-forward encoder built from Linear -> BatchNorm1d -> activation blocks.

    Layout of ``self.network``. It is kept stable so pretrained state-dict keys
    such as ``network.0.0.weight`` keep loading:

    * index 0: ``Sequential(Linear(input_dim, h[0]), BatchNorm1d(h[0]), act)``
    * index i, 1 <= i < len(h):
      ``Sequential(Dropout, Linear(h[i-1], h[i]), BatchNorm1d(h[i]), act)``
    * last index: ``Linear(h[-1], output_dim)``, with no norm or activation.

    If ``hidden_dim`` is an empty sequence, the network is a single
    ``Linear(input_dim, output_dim)`` projection.

    Args:
        input_dim: number of input features.
        output_dim: width of the produced embedding (stored as ``latent_dim``).
        hidden_dim: widths of the hidden blocks; defaults to ``[512, 512]``.
        dropout: dropout probability applied before every hidden Linear
            except the first one.
        residual: if True, add a skip connection around every hidden block
            except the first (which changes width) and the final projection.
        activation: activation name understood by ``_get_activation``.

    Raises:
        ValueError: if ``residual`` is True and the hidden widths differ.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 512,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        activation: str = "prelu",
    ):
        super().__init__()
        # Copy so a caller-owned list (e.g. config.hidden_dim) is never aliased.
        hidden_dim = [512, 512] if hidden_dim is None else list(hidden_dim)
        self.latent_dim = output_dim
        self.residual = residual
        self.network = nn.ModuleList()

        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        # A skip connection `layer(x) + x` needs input and output widths to match.
        if residual and len(set(hidden_dim)) > 1:
            raise ValueError("Residual connections require all hidden dims to be equal")

        in_features = input_dim
        for i, width in enumerate(hidden_dim):
            # Module creation order matches the original layout, so parameter
            # names and seeded initialisation are unchanged.
            layers = [] if i == 0 else [nn.Dropout(p=dropout)]
            layers += [
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                _get_activation(activation),
            ]
            self.network.append(nn.Sequential(*layers))
            in_features = width
        # With no hidden blocks, in_features is still input_dim: a plain projection.
        self.network.append(nn.Linear(in_features, output_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every block and return the final projection.

        ``x`` is presumably ``(batch, input_dim)``. BatchNorm1d is applied to
        the Linear outputs; confirm the shape against callers.
        """
        last = len(self.network) - 1
        for i, layer in enumerate(self.network):
            out = layer(x)
            # Skip connections only around interior hidden blocks: block 0
            # changes width (input_dim -> hidden), and so does the final Linear.
            x = out + x if self.residual and 0 < i < last else out
        return x


class VirtualCellDistilConfig(PretrainedConfig):
    """Configuration shared by the VirtualCellDistil encoder and classifier.

    Args:
        n_genes: number of input features fed to the encoder.
        output_dim: width of the embedding produced by the encoder.
        hidden_dim: hidden-layer widths of the encoder; defaults to ``[512, 512]``.
        dropout: dropout probability used inside the encoder MLP.
        residual: whether the encoder adds skip connections between hidden blocks.
        activation: activation name: ``"prelu"``, ``"relu"``, ``"gelu"`` or ``"tanh"``.
        num_labels: number of classifier outputs. If omitted, ``PretrainedConfig``
            derives it from ``id2label`` when one is supplied (e.g. from a saved
            ``config.json``), and otherwise uses its own default of 2.
        classifier_dropout: dropout applied to embeddings before the classifier.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """
    model_type = "virtual_cell_distil"

    def __init__(
        self,
        n_genes: int = 18301,
        output_dim: int = 512,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        activation: str = "prelu",
        num_labels: Optional[int] = None,
        classifier_dropout: float = 0.1,
        **kwargs,
    ):
        # Let PretrainedConfig reconcile num_labels with id2label/label2id.
        # Assigning `self.num_labels = 2` after super().__init__ triggers the
        # num_labels setter. That setter regenerates id2label whenever the
        # lengths differ, which silently turns a saved N-class config into a
        # 2-class one on reload.
        if num_labels is not None:
            kwargs["num_labels"] = num_labels
        super().__init__(**kwargs)
        self.n_genes = n_genes
        self.output_dim = output_dim
        # Copy so the config never aliases a caller-owned list.
        self.hidden_dim = list(hidden_dim) if hidden_dim is not None else [512, 512]
        self.dropout = dropout
        self.residual = residual
        self.activation = activation
        self.classifier_dropout = classifier_dropout


class VirtualCellDistilModel(PreTrainedModel):
    """Bare encoder mapping bulk expression profiles to patient embeddings.

    The embedding width is ``config.output_dim`` (512 with the default
    config). No task head is attached.
    """
    config_class = VirtualCellDistilConfig

    def __init__(self, config: VirtualCellDistilConfig):
        super().__init__(config)
        encoder_kwargs = dict(
            input_dim=config.n_genes,
            output_dim=config.output_dim,
            hidden_dim=config.hidden_dim,
            dropout=config.dropout,
            residual=config.residual,
            activation=config.activation,
        )
        self.encoder = MLP(**encoder_kwargs)

    def forward(self, input_ids: torch.Tensor, **kwargs) -> dict:
        """Encode ``input_ids`` and return ``{"embeddings": tensor}``.

        Extra keyword arguments are accepted and ignored.
        """
        embeddings = self.encoder(input_ids)
        return {"embeddings": embeddings}


class VirtualCellDistilForSequenceClassification(PreTrainedModel):
    """
    Encoder + dropout + linear classification head.

    The encoder is initialised from pretrained distilled weights.
    The classification head is randomly initialised and trained on your labels.
    Use ignore_mismatched_sizes=True when loading from the pretrained checkpoint.

    Loss selection follows ``config.problem_type`` when it is set:
    ``"regression"`` uses MSE, ``"multi_label_classification"`` uses BCE with
    logits, and ``"single_label_classification"`` uses cross-entropy. When it
    is unset, ``num_labels == 1`` uses MSE and any other label count uses
    cross-entropy.
    """
    config_class = VirtualCellDistilConfig

    def __init__(self, config: VirtualCellDistilConfig):
        super().__init__(config)
        self.encoder = MLP(
            input_dim=config.n_genes,
            output_dim=config.output_dim,
            hidden_dim=config.hidden_dim,
            dropout=config.dropout,
            residual=config.residual,
            activation=config.activation,
        )
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.output_dim, config.num_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """Encode ``input_ids``, apply the head, and optionally compute a loss.

        Args:
            input_ids: encoder input (expression values), passed straight to the MLP.
            labels: optional targets; see the class docstring for how the loss
                is chosen.
            **kwargs: accepted and ignored.

        Returns:
            dict with ``"loss"`` (None when ``labels`` is None), ``"logits"``
            of width ``config.num_labels``, and ``"embeddings"``.
        """
        embeddings = self.encoder(input_ids)
        logits = self.classifier(self.dropout(embeddings))
        loss = None if labels is None else self._compute_loss(logits, labels)
        return {"loss": loss, "logits": logits, "embeddings": embeddings}

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the loss for ``logits`` against ``labels`` per the problem type."""
        problem_type = getattr(self.config, "problem_type", None)
        if problem_type is None:
            # With a single output unit, softmax cross-entropy is identically
            # zero (nothing to learn), so treat num_labels == 1 as regression.
            problem_type = (
                "regression" if self.config.num_labels == 1 else "single_label_classification"
            )
        if problem_type == "regression":
            # Flatten both sides so (batch,) and (batch, 1) targets line up
            # without accidental broadcasting.
            return F.mse_loss(logits.view(-1), labels.view(-1).to(logits.dtype))
        if problem_type == "multi_label_classification":
            return F.binary_cross_entropy_with_logits(logits, labels.to(logits.dtype))
        # Single-label: class indices must be int64 for cross_entropy.
        # Floating-point targets (class probabilities) pass through unchanged.
        if not labels.is_floating_point():
            labels = labels.long()
        return F.cross_entropy(logits, labels)