pantoniadis committed on
Commit
beb1052
·
verified ·
1 Parent(s): f26d58d

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "hourglass_transformer",
3
+ "auto_map": {
4
+ "AutoConfig": "hourglass_transformer.HourglassTransformerConfig",
5
+ "AutoModel": "hourglass_transformer.HourglassTransformerForMaskedLM"
6
+ },
7
+ "activation_function": "gelu",
8
+ "architectures": [
9
+ "HourglassTransformerForMaskedLM"
10
+ ],
11
+ "attn_resampling": false,
12
+ "bias": false,
13
+ "depth": [
14
+ 4,
15
+ [
16
+ 4,
17
+ 4,
18
+ 4
19
+ ],
20
+ 4
21
+ ],
22
+ "dim": 768,
23
+ "dim_head": 64,
24
+ "heads": 8,
25
+ "inference": false,
26
+ "metadata_dim": 3072,
27
+ "model_type": "hourglass_transformer",
28
+ "norm_out": false,
29
+ "predict_expression_mode": false,
30
+ "predict_seq": true,
31
+ "predict_taxonomy": false,
32
+ "predict_tracks": true,
33
+ "rotary_emb_dim": 32,
34
+ "seq_vocab_size": 11,
35
+ "shorten_factor": [
36
+ 8,
37
+ 8
38
+ ],
39
+ "sliding_window": [
40
+ 512,
41
+ 512,
42
+ -1
43
+ ],
44
+ "taxonomy_vocab_size": 2604,
45
+ "torch_dtype": "float32",
46
+ "track_activation_fn": null,
47
+ "track_output_dim": 4,
48
+ "transformers_version": "4.44.2",
49
+ "updown_sample_type": "linear",
50
+ "use_metadata": true,
51
+ "use_taxonomy": false
52
+ }
hourglass_transformer.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace model wrapper for HourglassTransformerLM.
3
+ This allows the model to be saved and loaded in HuggingFace format.
4
+ """
5
+
6
+ from typing import Optional, Union
7
+ from dataclasses import dataclass
8
+ import torch
9
+ from transformers import PreTrainedModel, PretrainedConfig
10
+ from transformers.modeling_outputs import MaskedLMOutput
11
+ from rnalm.utils.hydra_utils import to_tuple_recursive
12
+ from rnalm.models.networks.hourglass_transformer import HourglassTransformerLM
13
+
14
+
15
@dataclass
class HourglassTransformerOutput(MaskedLMOutput):
    """Multi-task output container extending HuggingFace's ``MaskedLMOutput``.

    Adds the hourglass model's task-specific outputs on top of the standard
    fields. ``__post_init__`` mirrors ``seq_logits`` into the standard
    ``logits`` field so generic HF code that reads ``output.logits`` works.
    """

    # Standard MaskedLMOutput fields (inherited)
    # loss: Optional[torch.FloatTensor] = None
    # logits: torch.FloatTensor = None
    # hidden_states: Optional[tuple] = None
    # attentions: Optional[tuple] = None

    # Custom multi-task fields
    seq_logits: Optional[torch.FloatTensor] = None  # sequence-head logits (duplicated into `logits`)
    tax_logits: Optional[torch.FloatTensor] = None  # presumably taxonomy-head logits — confirm against HourglassTransformerLM
    track_yhat: Optional[torch.FloatTensor] = None  # presumably track-head predictions — confirm against HourglassTransformerLM
    expression_mode: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    last_hidden_state_track: Optional[torch.FloatTensor] = None

    def __post_init__(self):
        """Sync standard and custom field names for compatibility."""
        # Call parent __post_init__ if it exists (ModelOutput uses it to
        # register fields in its dict/tuple interface).
        if hasattr(super(), "__post_init__"):
            super().__post_init__()

        # Map seq_logits to logits if logits is None.
        # object.__setattr__ is used because ModelOutput restricts normal
        # attribute assignment after initialization.
        if self.logits is None and self.seq_logits is not None:
            object.__setattr__(self, "logits", self.seq_logits)
41
+
42
class HourglassTransformerConfig(PretrainedConfig):
    """HuggingFace configuration for the hourglass transformer.

    Stores every constructor argument as an attribute so it round-trips
    through ``config.json``. ``depth`` and ``sliding_window`` may arrive as
    lists after JSON deserialization and are normalized to tuples.
    """

    model_type = "hourglass_transformer"

    def __init__(
        self,
        seq_vocab_size: int = 11,
        taxonomy_vocab_size: int = 2604,
        dim: int = 128,
        depth: tuple = (2, 2, 2),
        shorten_factor: Union[int, tuple] = 4,
        sliding_window: tuple = (512, 512),
        attn_resampling: bool = False,
        updown_sample_type: str = "linear",
        heads: int = 8,
        dim_head: int = 64,
        norm_out: bool = False,
        bias: bool = True,
        activation_function: str = "gelu",
        rotary_emb_dim: int = 32,
        use_taxonomy: bool = False,
        use_metadata: bool = False,
        predict_taxonomy: bool = False,
        predict_tracks: bool = False,
        predict_seq: bool = True,
        track_activation_fn: Optional[str] = None,
        track_output_dim: int = 4,
        predict_expression_mode: bool = False,
        inference: bool = False,
        metadata_dim: int = 3072,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Vocabulary / embedding sizes
        self.seq_vocab_size = seq_vocab_size
        self.taxonomy_vocab_size = taxonomy_vocab_size
        self.dim = dim

        # Lists (the JSON round-trip form) are converted to tuples; tuples
        # and scalars pass through unchanged. Note the conversion is shallow:
        # nested depth lists are handled later via to_tuple_recursive.
        self.depth = tuple(depth) if isinstance(depth, list) else depth
        self.sliding_window = (
            tuple(sliding_window)
            if isinstance(sliding_window, list)
            else sliding_window
        )

        # Architecture hyperparameters
        self.rotary_emb_dim = rotary_emb_dim
        self.shorten_factor = shorten_factor
        self.attn_resampling = attn_resampling
        self.updown_sample_type = updown_sample_type
        self.heads = heads
        self.dim_head = dim_head
        self.norm_out = norm_out
        self.bias = bias
        self.activation_function = activation_function

        # Auxiliary inputs
        self.use_taxonomy = use_taxonomy
        self.use_metadata = use_metadata
        self.metadata_dim = metadata_dim

        # Prediction heads
        self.predict_taxonomy = predict_taxonomy
        self.predict_tracks = predict_tracks
        self.predict_seq = predict_seq
        self.track_activation_fn = track_activation_fn
        self.track_output_dim = track_output_dim
        self.predict_expression_mode = predict_expression_mode

        # Runtime mode
        self.inference = inference
108
+
109
+
110
class HourglassTransformerForMaskedLM(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around ``HourglassTransformerLM``.

    Forwards config attributes verbatim to the underlying model and adapts
    its output into a :class:`HourglassTransformerOutput`.
    """

    config_class = HourglassTransformerConfig

    # Config attributes passed through unchanged to HourglassTransformerLM.
    # ``depth`` is handled separately because it may be arbitrarily nested.
    _FORWARDED_FIELDS = (
        "seq_vocab_size",
        "taxonomy_vocab_size",
        "dim",
        "sliding_window",
        "rotary_emb_dim",
        "shorten_factor",
        "attn_resampling",
        "updown_sample_type",
        "heads",
        "dim_head",
        "norm_out",
        "bias",
        "activation_function",
        "use_taxonomy",
        "use_metadata",
        "metadata_dim",
        "predict_taxonomy",
        "predict_tracks",
        "predict_seq",
        "track_activation_fn",
        "track_output_dim",
        "predict_expression_mode",
        "inference",
    )

    def __init__(self, config: HourglassTransformerConfig):
        super().__init__(config)

        # Collect the constructor kwargs from the config.
        model_kwargs = {name: getattr(config, name) for name in self._FORWARDED_FIELDS}
        # depth may contain nested lists after JSON deserialization; the base
        # model expects nested tuples.
        model_kwargs["depth"] = to_tuple_recursive(config.depth)

        self.model = HourglassTransformerLM(**model_kwargs)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        masked_taxonomy: Optional[torch.Tensor] = None,
        metadata: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> HourglassTransformerOutput:
        """
        Forward pass of the model.

        Args:
            input_ids: Tokenized input sequences (batch_size, seq_len)
            masked_taxonomy: Optional taxonomy tokens (batch_size, 8)
            metadata: Optional metadata embeddings
            attention_mask: Optional attention mask (batch_size, seq_len)
            labels: Accepted for HF-API compatibility only; no loss is
                computed here — compute it externally.
            output_attentions: Whether to return attentions (not supported)
            output_hidden_states: Whether to return hidden states

        Returns:
            HourglassTransformerOutput containing all model outputs
        """
        # Run the wrapped base model.
        raw = self.model(
            masked_seq=input_ids,
            masked_taxonomy=masked_taxonomy,
            metadata=metadata,
            mask=attention_mask,
        )

        # Only expose hidden states when requested and actually produced.
        hidden_states = None
        if output_hidden_states and raw.last_hidden_state is not None:
            hidden_states = (raw.last_hidden_state,)

        # Repackage into the HF-compatible output type; seq_logits doubles
        # as the standard `logits` field.
        return HourglassTransformerOutput(
            loss=None,  # loss should be computed externally if labels provided
            logits=raw.seq_logits,
            hidden_states=hidden_states,
            attentions=None,  # attention weights are not currently exposed
            seq_logits=raw.seq_logits,
            tax_logits=raw.tax_logits,
            track_yhat=raw.track_yhat,
            expression_mode=raw.expression_mode,
            last_hidden_state=raw.last_hidden_state,
            last_hidden_state_track=raw.last_hidden_state_track,
        )
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:423ac098fcb8ee69c96e99c310f19eb55d225804478abf3cd650e66471043301
3
+ size 593330620
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "[BOS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "[CLS]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "[EOS]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "[MASK]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "[PAD]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "[SEP]",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "[UNK]",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
tokenizer.json ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "1.0",
3
+ "truncation": null,
4
+ "padding": null,
5
+ "added_tokens": [
6
+ {
7
+ "id": 0,
8
+ "content": "[PAD]",
9
+ "single_word": false,
10
+ "lstrip": false,
11
+ "rstrip": false,
12
+ "normalized": false,
13
+ "special": true
14
+ },
15
+ {
16
+ "id": 1,
17
+ "content": "[MASK]",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false,
22
+ "special": true
23
+ },
24
+ {
25
+ "id": 2,
26
+ "content": "[UNK]",
27
+ "single_word": false,
28
+ "lstrip": false,
29
+ "rstrip": false,
30
+ "normalized": false,
31
+ "special": true
32
+ },
33
+ {
34
+ "id": 3,
35
+ "content": "[SEP]",
36
+ "single_word": false,
37
+ "lstrip": false,
38
+ "rstrip": false,
39
+ "normalized": false,
40
+ "special": true
41
+ },
42
+ {
43
+ "id": 4,
44
+ "content": "[BOS]",
45
+ "single_word": false,
46
+ "lstrip": false,
47
+ "rstrip": false,
48
+ "normalized": false,
49
+ "special": true
50
+ },
51
+ {
52
+ "id": 5,
53
+ "content": "[EOS]",
54
+ "single_word": false,
55
+ "lstrip": false,
56
+ "rstrip": false,
57
+ "normalized": false,
58
+ "special": true
59
+ },
60
+ {
61
+ "id": 6,
62
+ "content": "[CLS]",
63
+ "single_word": false,
64
+ "lstrip": false,
65
+ "rstrip": false,
66
+ "normalized": false,
67
+ "special": true
68
+ }
69
+ ],
70
+ "normalizer": {
71
+ "type": "Lowercase"
72
+ },
73
+ "pre_tokenizer": null,
74
+ "post_processor": null,
75
+ "decoder": null,
76
+ "model": {
77
+ "type": "BPE",
78
+ "dropout": null,
79
+ "unk_token": "[UNK]",
80
+ "continuing_subword_prefix": null,
81
+ "end_of_word_suffix": null,
82
+ "fuse_unk": false,
83
+ "byte_fallback": false,
84
+ "ignore_merges": false,
85
+ "vocab": {
86
+ "[PAD]": 0,
87
+ "[MASK]": 1,
88
+ "[UNK]": 2,
89
+ "[SEP]": 3,
90
+ "[BOS]": 4,
91
+ "[EOS]": 5,
92
+ "[CLS]": 6,
93
+ "a": 7,
94
+ "c": 8,
95
+ "g": 9,
96
+ "t": 10
97
+ },
98
+ "merges": []
99
+ }
100
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[MASK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[UNK]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "[BOS]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "5": {
44
+ "content": "[EOS]",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "6": {
52
+ "content": "[CLS]",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ }
59
+ },
60
+ "bos_token": "[BOS]",
61
+ "clean_up_tokenization_spaces": true,
62
+ "cls_token": "[CLS]",
63
+ "eos_token": "[EOS]",
64
+ "mask_token": "[MASK]",
65
+ "model_max_length": 1000000000000000019884624838656,
66
+ "pad_token": "[PAD]",
67
+ "sep_token": "[SEP]",
68
+ "split_special_tokens": false,
69
+ "tokenizer_class": "PreTrainedTokenizerFast",
70
+ "unk_token": "[UNK]"
71
+ }