Upload Flowformer
- configuration_flowformer.py +33 -1
- model_flowformer.py +37 -11
configuration_flowformer.py
CHANGED
@@ -1,8 +1,40 @@
 from transformers import PretrainedConfig
 
 class FlowformerConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Flowformer`]. It is used to instantiate a
+    Flowformer model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of our model for ALL data (https://arxiv.org/abs/2108.10072).
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+    Args:
+        dim_hidden (`int`, *optional*, defaults to 32):
+            The dimensionality of the hidden states. dim_hidden must be divisible by num_heads, i.e. dim_hidden % num_heads == 0.
+        num_heads (`int`, *optional*, defaults to 4):
+            The number of attention heads.
+        num_inds (`int`, *optional*, defaults to 16):
+            The number of inducing points.
+        hidden_layers (`int`, *optional*, defaults to 3):
+            The number of hidden layers.
+        layer_norm (`bool`, *optional*, defaults to `True`):
+            Whether to apply layer normalization.
+        dim_input (`int`, *optional*, defaults to 11):
+            The dimensionality of the input.
+        markers (`list`, *optional*, defaults to `["TIME", "FSC-A", "FSC-W", "SSC-A", "CD20", "CD10", "CD45", "CD34", "CD19", "CD38", "SY41"]`):
+            The list of markers.
+    Example:
+    ```python
+    >>> from transformers import FlowformerConfig, FlowformerModel
+    >>> # Initializing a Flowformer configuration
+    >>> configuration = FlowformerConfig()
+    >>> # Initializing a model (with random weights) from the Flowformer configuration
+    >>> model = FlowformerModel(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
     def __init__(self,
-                 dim_hidden: int=32,
+                 dim_hidden: int=32,
                  num_heads: int=4,
                  num_inds: int=16,
                  hidden_layers: int=3,
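For context, a minimal usage sketch of the configuration class added above. This is illustrative only and not part of the commit: it assumes configuration_flowformer.py is importable as a local module and that `__init__` accepts the documented keyword arguments.

```python
# Hypothetical usage sketch for the FlowformerConfig documented above.
# Assumes configuration_flowformer.py is on the Python path.
from configuration_flowformer import FlowformerConfig

# Override a few defaults; dim_hidden must remain divisible by num_heads.
config = FlowformerConfig(
    dim_hidden=64,
    num_heads=8,
    num_inds=16,
    hidden_layers=3,
)

config.save_pretrained("./flowformer-config")   # standard PretrainedConfig API, writes config.json
reloaded = FlowformerConfig.from_pretrained("./flowformer-config")
```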
model_flowformer.py
CHANGED
@@ -4,6 +4,7 @@ import torch.nn.functional as F
 from torch.nn.functional import binary_cross_entropy_with_logits
 import math
 from transformers import PreTrainedModel
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from .configuration_flowformer import FlowformerConfig
 
 
@@ -11,7 +12,7 @@ class MAB(nn.Module):
     """
     Multihead attention Block (MAB) from https://arxiv.org/abs/1810.00825.
     """
-    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
+    def __init__(self, dim_Q: int, dim_K: int, dim_V: int, num_heads: int, ln: bool=False):
         super(MAB, self).__init__()
 
         self.dim_V = dim_V
@@ -47,7 +48,7 @@ class ISAB(nn.Module):
     """
     The Induced Set Attention Block (ISAB) from https://arxiv.org/abs/1810.00825.
     """
-    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
+    def __init__(self, dim_in: int, dim_out: int, num_heads: int, num_inds: int, ln: bool=False):
         super(ISAB, self).__init__()
 
         self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
@@ -60,8 +61,30 @@ class ISAB(nn.Module):
 
         return self.mab1(X, H)
 
+FLOWFORMER_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+    Parameters:
+        config ([`FlowformerConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+FLOWFORMER_INPUTS_DOCSTRING = r"""
+    Args:
+        tensor (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_markers)`):
+            The sample used as a basis for the prediction.
+        labels (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Optional ground truth labels for computing the loss.
+        markers (`list` of length `num_markers`):
+            The list of markers in the same order as the last dimension of the input tensor.
+"""
+
+
+@add_start_docstrings(FLOWFORMER_START_DOCSTRING)
 class Flowformer(PreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: FlowformerConfig):
         super().__init__(config)
 
         # Load config
@@ -72,7 +95,7 @@ class Flowformer(PreTrainedModel):
         hidden_layers = config.hidden_layers
         layer_norm = config.layer_norm
         dim_output = 1
-        self.
+        self._markers = config.markers
 
         # Define encoder
         enc_layers = [ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=layer_norm)]
@@ -85,17 +108,18 @@ class Flowformer(PreTrainedModel):
         dec_layers = [nn.Linear(dim_input, dim_output)]
         self.dec = nn.Sequential(*dec_layers)
 
-    def
+    def markers(self):
         return self._pretrained_markers
 
-
+    @add_start_docstrings_to_model_forward(FLOWFORMER_INPUTS_DOCSTRING)
+    def forward(self, tensor: torch.Tensor, labels: torch.Tensor=None, markers: list=None):
         B, L, M = tensor.shape
         if markers is not None:
             assert len(markers) == M, "Number of markers in x and markers must be identical"
 
-            zeros = torch.zeros((B, L, len(self.
-            valid_markers = [m for m in markers if m in set(self.
-            idx = [self.
+            zeros = torch.zeros((B, L, len(self.markers())), device=tensor.device)
+            valid_markers = [m for m in markers if m in set(self.markers()).intersection(markers)]
+            idx = [self.markers().index(m) for m in valid_markers]
             zeros[:, :, idx] = tensor  # select only the markers that are in the pretrained model
             tensor = zeros
 
@@ -105,10 +129,12 @@ class Flowformer(PreTrainedModel):
         if labels is not None:
             return {
                 'loss': binary_cross_entropy_with_logits(output, labels),
-                'logits': output
+                'logits': output,
+                'prediction': torch.where(output > 0, 1, 0)
             }
         else:
             return {
-                'logits': output
+                'logits': output,
+                'prediction': torch.where(output > 0, 1, 0)
             }
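A hedged usage sketch of the updated forward interface (the `tensor`/`labels`/`markers` arguments and the `logits`/`prediction` output keys introduced in this diff). It is illustrative only: it assumes both files are importable as local modules and that `markers()` resolves to the marker list stored in the config.

```python
# Hypothetical end-to-end sketch; class and file names follow the diff above.
import torch
from configuration_flowformer import FlowformerConfig
from model_flowformer import Flowformer

config = FlowformerConfig()            # defaults: 11 markers, dim_hidden=32, num_heads=4
model = Flowformer(config)             # randomly initialized weights
model.eval()

# Two samples of 1000 events each, measured on a subset of the pretraining markers.
markers = ["TIME", "FSC-A", "SSC-A", "CD19", "CD45"]
x = torch.randn(2, 1000, len(markers))

with torch.no_grad():
    out = model(x, markers=markers)    # missing markers are zero-filled and columns reordered

print(out["logits"].shape)             # per-event logits
print(out["prediction"].unique())      # hard 0/1 labels obtained from logits > 0
```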