Upload Flowformer
- config.json +30 -0
- configuration_flowformer.py +23 -0
- model_flowformer.py +114 -0
- pytorch_model.bin +3 -0
config.json
ADDED
@@ -0,0 +1,30 @@
+{
+  "architectures": [
+    "Flowformer"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_flowformer.FlowformerConfig",
+    "AutoModel": "model_flowformer.Flowformer"
+  },
+  "dim_hidden": 32,
+  "dim_input": 11,
+  "hidden_layers": 3,
+  "layer_norm": true,
+  "markers": [
+    "TIME",
+    "FSC-A",
+    "FSC-W",
+    "SSC-A",
+    "CD20",
+    "CD10",
+    "CD45",
+    "CD34",
+    "CD19",
+    "CD38",
+    "SY41"
+  ],
+  "num_heads": 4,
+  "num_inds": 16,
+  "torch_dtype": "float32",
+  "transformers_version": "4.28.1"
+}
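Because config.json maps AutoConfig and AutoModel to the custom classes added in this upload, the checkpoint can be loaded through the Auto API with trust_remote_code=True. A minimal sketch; the repo id below is a placeholder for wherever these files are hosted:

    from transformers import AutoConfig, AutoModel

    # "flowformer" is a placeholder repo id / local path pointing at this upload.
    config = AutoConfig.from_pretrained("flowformer", trust_remote_code=True)
    model = AutoModel.from_pretrained("flowformer", trust_remote_code=True)

    print(config.markers)     # the 11 marker names listed in config.json
    print(config.dim_hidden)  # 32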
configuration_flowformer.py
ADDED
@@ -0,0 +1,23 @@
+from transformers import PretrainedConfig
+
+class FlowformerConfig(PretrainedConfig):
+    def __init__(self,
+        dim_hidden: int=32,  # dim_hidden must be divisible by num_heads i.e. dim_hidden % num_heads == 0
+        num_heads: int=4,
+        num_inds: int=16,
+        hidden_layers: int=3,
+        layer_norm: bool=True,
+        dim_input: int=11,
+        markers: list=["TIME", "FSC-A", "FSC-W", "SSC-A", "CD20", "CD10", "CD45", "CD34", "CD19", "CD38", "SY41"],
+        **kwargs
+    ):
+        assert dim_input == len(markers), "dim_input must be equal to the number of markers"
+
+        self.dim_hidden = dim_hidden
+        self.num_heads = num_heads
+        self.num_inds = num_inds
+        self.hidden_layers = hidden_layers
+        self.layer_norm = layer_norm
+        self.dim_input = dim_input
+        self.markers = markers
+        super().__init__(**kwargs)
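A short sketch of the constraints this config enforces (dim_input must equal len(markers); dim_hidden should stay divisible by num_heads for the attention blocks). It assumes configuration_flowformer.py is importable from the working directory:

    from configuration_flowformer import FlowformerConfig  # assumes the file is on the Python path

    config = FlowformerConfig()                        # defaults match config.json above
    assert config.dim_hidden % config.num_heads == 0   # 32 % 4 == 0, needed by the multihead split
    assert config.dim_input == len(config.markers)     # checked in __init__

    print(config.to_json_string())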
model_flowformer.py
ADDED
@@ -0,0 +1,114 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.functional import binary_cross_entropy_with_logits
+import math
+from transformers import PreTrainedModel
+from .configuration_flowformer import FlowformerConfig
+
+
+class MAB(nn.Module):
+    """
+    Multihead attention Block (MAB) from https://arxiv.org/abs/1810.00825.
+    """
+    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
+        super(MAB, self).__init__()
+
+        self.dim_V = dim_V
+        self.num_heads = num_heads
+        self.fc_q = nn.Linear(dim_Q, dim_V)
+        self.fc_k = nn.Linear(dim_K, dim_V)
+        self.fc_v = nn.Linear(dim_K, dim_V)
+
+        if ln:
+            self.ln0 = nn.LayerNorm(dim_V)
+            self.ln1 = nn.LayerNorm(dim_V)
+        self.fc_o = nn.Linear(dim_V, dim_V)
+
+    def forward(self, Q, K):
+        Q = self.fc_q(Q)
+        K, V = self.fc_k(K), self.fc_v(K)
+
+        dim_split = self.dim_V // self.num_heads
+        Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
+        K_ = torch.cat(K.split(dim_split, 2), dim=0)
+        V_ = torch.cat(V.split(dim_split, 2), dim=0)
+
+        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
+        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
+        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
+        O = O + F.relu(self.fc_o(O))
+        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
+
+        return O
+
+
+class ISAB(nn.Module):
+    """
+    The Induced Set Attention Block (ISAB) from https://arxiv.org/abs/1810.00825.
+    """
+    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
+        super(ISAB, self).__init__()
+
+        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
+        nn.init.xavier_uniform_(self.I)
+        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
+        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)
+
+    def forward(self, X):
+        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
+
+        return self.mab1(X, H)
+
+class Flowformer(PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        # Load config
+        dim_input = config.dim_input
+        dim_hidden = config.dim_hidden
+        num_heads = config.num_heads
+        num_inds = config.num_inds
+        hidden_layers = config.hidden_layers
+        layer_norm = config.layer_norm
+        dim_output = 1
+        self._pretrained_markers = config.markers or ["TIME", "FSC-A", "FSC-W", "SSC-A", "CD20", "CD10", "CD45", "CD34", "CD19", "CD38", "SY41"]
+
+        # Define encoder
+        enc_layers = [ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=layer_norm)]
+        for _ in range(1, hidden_layers):
+            enc_layers.append(ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=layer_norm))
+        enc_layers.append(ISAB(dim_hidden, dim_input, 1, num_inds, ln=layer_norm))  # num_heads == 1 because dim_input can be a prime number
+        self.enc = nn.Sequential(*enc_layers)
+
+        # Define decoder
+        dec_layers = [nn.Linear(dim_input, dim_output)]
+        self.dec = nn.Sequential(*dec_layers)
+
+    def pretrained_markers(self):
+        return self._pretrained_markers
+
+    def forward(self, tensor, labels=None, markers: list=None):
+        B, L, M = tensor.shape
+        if markers is not None:
+            assert len(markers) == M, "Number of markers in x and markers must be identical"
+
+            zeros = torch.zeros((B, L, len(self._pretrained_markers)), device=tensor.device)
+            valid_markers = [m for m in markers if m in set(self._pretrained_markers).intersection(markers)]
+            idx = [self._pretrained_markers.index(m) for m in valid_markers]
+            zeros[:, :, idx] = tensor  # select only the markers that are in the pretrained model
+            tensor = zeros
+
+        enc_out = self.enc(tensor)
+        output = self.dec(enc_out)[:, :, 0]
+
+        if labels is not None:
+            return {
+                'loss': binary_cross_entropy_with_logits(output, labels),
+                'logits': output
+            }
+        else:
+            return {
+                'logits': output
+            }
+
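A minimal forward-pass sketch to make the expected shapes and the markers argument concrete. It assumes model is the Flowformer instance loaded above via AutoModel; the reduced panel is purely illustrative, and markers missing from it are zero-filled inside forward:

    import torch

    # `model` is the Flowformer loaded earlier (placeholder repo id).
    B, L = 2, 500                                # 2 samples, 500 cells (events) each
    panel = ["FSC-A", "SSC-A", "CD19", "CD45"]   # hypothetical subset of the pretrained markers
    x = torch.randn(B, L, len(panel))

    out = model(x, markers=panel)                # remaining pretrained markers are zero-filled
    print(out["logits"].shape)                   # torch.Size([2, 500]) -- one logit per cell

    labels = torch.randint(0, 2, (B, L)).float()
    out = model(x, markers=panel, labels=labels)
    print(out["loss"])                           # binary cross-entropy with logits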
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:055b27977924a2b82a5842c34673de48fa8478eb110374b6066508469b2c9c35
+size 139813