pantoniadis commited on
Commit
44aac51
·
verified ·
1 Parent(s): fd77b62

Update model files

Browse files
Files changed (3) hide show
  1. config.json +51 -0
  2. hourglass_transformer.py +201 -0
  3. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"model_type": "hourglass_transformer",
2
+ "auto_map": {
3
+ "AutoConfig": "hourglass_transformer.HourglassTransformerConfig",
4
+ "AutoModel": "hourglass_transformer.HourglassTransformerForMaskedLM"
5
+ },
6
+ "activation_function": "gelu",
7
+ "architectures": [
8
+ "HourglassTransformerForMaskedLM"
9
+ ],
10
+ "attn_resampling": false,
11
+ "bias": false,
12
+ "depth": [
13
+ 4,
14
+ [
15
+ 4,
16
+ 4,
17
+ 4
18
+ ],
19
+ 4
20
+ ],
21
+ "dim": 768,
22
+ "dim_head": 64,
23
+ "heads": 8,
24
+ "inference": false,
25
+ "metadata_dim": 3072,
26
+ "model_type": "hourglass_transformer",
27
+ "norm_out": false,
28
+ "predict_expression_mode": false,
29
+ "predict_seq": true,
30
+ "predict_taxonomy": false,
31
+ "predict_tracks": true,
32
+ "rotary_emb_dim": 32,
33
+ "seq_vocab_size": 11,
34
+ "shorten_factor": [
35
+ 8,
36
+ 8
37
+ ],
38
+ "sliding_window": [
39
+ 512,
40
+ 512,
41
+ -1
42
+ ],
43
+ "taxonomy_vocab_size": 2604,
44
+ "torch_dtype": "float32",
45
+ "track_activation_fn": null,
46
+ "track_output_dim": 4,
47
+ "transformers_version": "4.44.2",
48
+ "updown_sample_type": "linear",
49
+ "use_metadata": true,
50
+ "use_taxonomy": false
51
+ }
hourglass_transformer.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace model wrapper for HourglassTransformerLM.
3
+ This allows the model to be saved and loaded in HuggingFace format.
4
+ """
5
+
6
+ from typing import Optional, Union
7
+ from dataclasses import dataclass
8
+ import torch
9
+ from transformers import PreTrainedModel, PretrainedConfig
10
+ from transformers.modeling_outputs import MaskedLMOutput
11
+ from rnalm.utils.hydra_utils import to_tuple_recursive
12
+ from rnalm.models.networks.hourglass_transformer import HourglassTransformerLM
13
+
14
+
15
+ @dataclass
16
+ class HourglassTransformerOutput(MaskedLMOutput):
17
+ # Standard MaskedLMOutput fields (inherited)
18
+ # loss: Optional[torch.FloatTensor] = None
19
+ # logits: torch.FloatTensor = None
20
+ # hidden_states: Optional[tuple] = None
21
+ # attentions: Optional[tuple] = None
22
+
23
+ # Custom multi-task fields
24
+ seq_logits: Optional[torch.FloatTensor] = None
25
+ tax_logits: Optional[torch.FloatTensor] = None
26
+ track_yhat: Optional[torch.FloatTensor] = None
27
+ expression_mode: Optional[torch.FloatTensor] = None
28
+ last_hidden_state: Optional[torch.FloatTensor] = None
29
+ last_hidden_state_track: Optional[torch.FloatTensor] = None
30
+
31
+ def __post_init__(self):
32
+ """Sync standard and custom field names for compatibility."""
33
+ # Call parent __post_init__ if it exists
34
+ if hasattr(super(), "__post_init__"):
35
+ super().__post_init__()
36
+
37
+ # Map seq_logits to logits if logits is None
38
+ if self.logits is None and self.seq_logits is not None:
39
+ object.__setattr__(self, "logits", self.seq_logits)
40
+
41
+
42
+ class HourglassTransformerConfig(PretrainedConfig):
43
+ model_type = "hourglass_transformer"
44
+
45
+ def __init__(
46
+ self,
47
+ seq_vocab_size: int = 11,
48
+ taxonomy_vocab_size: int = 2604,
49
+ dim: int = 128,
50
+ depth: tuple = (2, 2, 2),
51
+ shorten_factor: Union[int, tuple] = 4,
52
+ sliding_window: tuple = (512, 512),
53
+ attn_resampling: bool = False,
54
+ updown_sample_type: str = "linear",
55
+ heads: int = 8,
56
+ dim_head: int = 64,
57
+ norm_out: bool = False,
58
+ bias: bool = True,
59
+ activation_function: str = "gelu",
60
+ rotary_emb_dim: int = 32,
61
+ use_taxonomy: bool = False,
62
+ use_metadata: bool = False,
63
+ predict_taxonomy: bool = False,
64
+ predict_tracks: bool = False,
65
+ predict_seq: bool = True,
66
+ track_activation_fn: Optional[str] = None,
67
+ track_output_dim: int = 4,
68
+ predict_expression_mode: bool = False,
69
+ inference: bool = False,
70
+ metadata_dim: int = 3072,
71
+ **kwargs,
72
+ ):
73
+ super().__init__(**kwargs)
74
+ self.seq_vocab_size = seq_vocab_size
75
+ self.taxonomy_vocab_size = taxonomy_vocab_size
76
+ self.dim = dim
77
+ if isinstance(depth, tuple):
78
+ self.depth = depth
79
+ elif isinstance(depth, list):
80
+ self.depth = tuple(depth)
81
+ else:
82
+ self.depth = depth
83
+ if isinstance(sliding_window, tuple):
84
+ self.sliding_window = sliding_window
85
+ elif isinstance(sliding_window, list):
86
+ self.sliding_window = tuple(sliding_window)
87
+ else:
88
+ self.sliding_window = sliding_window
89
+ self.rotary_emb_dim = rotary_emb_dim
90
+ self.shorten_factor = shorten_factor
91
+ self.attn_resampling = attn_resampling
92
+ self.updown_sample_type = updown_sample_type
93
+ self.heads = heads
94
+ self.dim_head = dim_head
95
+ self.norm_out = norm_out
96
+ self.bias = bias
97
+ self.activation_function = activation_function
98
+ self.use_taxonomy = use_taxonomy
99
+ self.use_metadata = use_metadata
100
+ self.metadata_dim = metadata_dim
101
+ self.predict_taxonomy = predict_taxonomy
102
+ self.predict_tracks = predict_tracks
103
+ self.predict_seq = predict_seq
104
+ self.track_activation_fn = track_activation_fn
105
+ self.track_output_dim = track_output_dim
106
+ self.predict_expression_mode = predict_expression_mode
107
+ self.inference = inference
108
+
109
+
110
+ class HourglassTransformerForMaskedLM(PreTrainedModel):
111
+
112
+ config_class = HourglassTransformerConfig
113
+
114
+ def __init__(self, config: HourglassTransformerConfig):
115
+ super().__init__(config)
116
+
117
+ # Convert config to dict for model initialization
118
+ model_kwargs = {
119
+ "seq_vocab_size": config.seq_vocab_size,
120
+ "taxonomy_vocab_size": config.taxonomy_vocab_size,
121
+ "dim": config.dim,
122
+ "depth": to_tuple_recursive(config.depth),
123
+ "sliding_window": config.sliding_window,
124
+ "rotary_emb_dim": config.rotary_emb_dim,
125
+ "shorten_factor": config.shorten_factor,
126
+ "attn_resampling": config.attn_resampling,
127
+ "updown_sample_type": config.updown_sample_type,
128
+ "heads": config.heads,
129
+ "dim_head": config.dim_head,
130
+ "norm_out": config.norm_out,
131
+ "bias": config.bias,
132
+ "activation_function": config.activation_function,
133
+ "use_taxonomy": config.use_taxonomy,
134
+ "use_metadata": config.use_metadata,
135
+ "metadata_dim": config.metadata_dim,
136
+ "predict_taxonomy": config.predict_taxonomy,
137
+ "predict_tracks": config.predict_tracks,
138
+ "predict_seq": config.predict_seq,
139
+ "track_activation_fn": config.track_activation_fn,
140
+ "track_output_dim": config.track_output_dim,
141
+ "predict_expression_mode": config.predict_expression_mode,
142
+ "inference": config.inference,
143
+ }
144
+
145
+ self.model = HourglassTransformerLM(**model_kwargs)
146
+
147
+ def forward(
148
+ self,
149
+ input_ids: Optional[torch.Tensor] = None,
150
+ masked_taxonomy: Optional[torch.Tensor] = None,
151
+ metadata: Optional[torch.Tensor] = None,
152
+ attention_mask: Optional[torch.Tensor] = None,
153
+ labels: Optional[torch.Tensor] = None,
154
+ output_attentions: Optional[bool] = None,
155
+ output_hidden_states: Optional[bool] = None,
156
+ **kwargs,
157
+ ) -> HourglassTransformerOutput:
158
+ """
159
+ Forward pass of the model.
160
+
161
+ Args:
162
+ input_ids: Tokenized input sequences (batch_size, seq_len)
163
+ masked_taxonomy: Optional taxonomy tokens (batch_size, 8)
164
+ metadata: Optional metadata embeddings
165
+ attention_mask: Optional attention mask (batch_size, seq_len)
166
+ labels: Optional labels for computing loss (batch_size, seq_len)
167
+ output_attentions: Whether to return attentions (not supported)
168
+ output_hidden_states: Whether to return hidden states
169
+
170
+ Returns:
171
+ HourglassTransformerOutput containing all model outputs
172
+ """
173
+ # Get the base model output
174
+ outputs = self.model(
175
+ masked_seq=input_ids,
176
+ masked_taxonomy=masked_taxonomy,
177
+ metadata=metadata,
178
+ mask=attention_mask,
179
+ )
180
+
181
+ # Convert to HourglassTransformerOutput
182
+ # This extends MaskedLMOutput for HuggingFace compatibility
183
+ hf_output = HourglassTransformerOutput(
184
+ loss=None, # Loss should be computed externally if labels provided
185
+ logits=outputs.seq_logits, # Standard HuggingFace field
186
+ hidden_states=(
187
+ (outputs.last_hidden_state,)
188
+ if (output_hidden_states and outputs.last_hidden_state is not None)
189
+ else None
190
+ ),
191
+ attentions=None, # Not currently supported
192
+ # Custom fields
193
+ seq_logits=outputs.seq_logits,
194
+ tax_logits=outputs.tax_logits,
195
+ track_yhat=outputs.track_yhat,
196
+ expression_mode=outputs.expression_mode,
197
+ last_hidden_state=outputs.last_hidden_state,
198
+ last_hidden_state_track=outputs.last_hidden_state_track,
199
+ )
200
+
201
+ return hf_output
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:423ac098fcb8ee69c96e99c310f19eb55d225804478abf3cd650e66471043301
3
+ size 593330620