Yuto2007 committed
Commit e9d18f4 · verified · 1 Parent(s): bc12f7a

Upload folder using huggingface_hub

Files changed (1):
  1. modeling_tx_standalone.py +157 -0
modeling_tx_standalone.py ADDED
@@ -0,0 +1,157 @@
# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
"""
HuggingFace-compatible wrapper for TXModel (Standalone version)
Only requires: transformers, torch, safetensors
"""

from typing import Optional, Union, Tuple
import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

from configuration_tx import TXConfig
from model_standalone import TXModel


class TXPreTrainedModel(PreTrainedModel):
    """
    Base class for TXModel with HuggingFace integration
    """

    config_class = TXConfig
    base_model_prefix = "tx_model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["TXBlock"]

    def _init_weights(self, module):
        """Initialize weights"""
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

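
# A minimal AutoModel registration sketch, assuming TXConfig.model_type is
# "tx" (the real value lives in configuration_tx). These are standard
# transformers hooks; with them, AutoModel.from_pretrained can resolve this
# architecture (TXModelForHF, defined below) without trust_remote_code:
#
#     from transformers import AutoConfig, AutoModel
#     AutoConfig.register("tx", TXConfig)          # model_type -> config class
#     AutoModel.register(TXConfig, TXModelForHF)   # config class -> model class
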
class TXModelForHF(TXPreTrainedModel):
    """
    HuggingFace-compatible TXModel

    This model can be used directly with HuggingFace's from_pretrained()
    and requires only: transformers, torch, safetensors

    No dependencies on llmfoundry, composer, or other external libraries.
    """

    def __init__(self, config: TXConfig):
        super().__init__(config)

        # Initialize the standalone model from the HF config fields
        self.tx_model = TXModel(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            expansion_ratio=config.expansion_ratio,
            pad_token_id=config.pad_token_id,
            pad_value=config.pad_value,
            num_bins=config.num_bins,
            norm_scheme=config.norm_scheme,
            transformer_activation=config.transformer_activation,
            cell_emb_style=config.cell_emb_style,
            use_chem_token=config.use_chem_token,
            attn_config=config.attn_config,
            norm_config=config.norm_config,
            gene_encoder_config=config.gene_encoder_config,
            expression_encoder_config=config.expression_encoder_config,
            expression_decoder_config=config.expression_decoder_config,
            mvc_config=config.mvc_config,
            chemical_encoder_config=config.chemical_encoder_config,
            use_glu=config.use_glu,
            return_gene_embeddings=config.return_gene_embeddings,
            keep_first_n_tokens=config.keep_first_n_tokens,
        )

        # Standard HF post-initialization (runs _init_weights, etc.)
        self.post_init()

    def forward(
        self,
        genes: torch.Tensor,
        values: torch.Tensor,
        gen_masks: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        drug_ids: Optional[torch.Tensor] = None,
        skip_decoders: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Forward pass through the model.

        Args:
            genes: Gene token IDs [batch_size, seq_len]
            values: Expression values [batch_size, seq_len]
            gen_masks: Generation masks [batch_size, seq_len]
            key_padding_mask: Boolean padding mask [batch_size, seq_len],
                True for non-padding positions. Derived from pad_token_id
                when not provided.
            drug_ids: Drug IDs [batch_size] (optional)
            skip_decoders: Whether to skip decoder computation
            output_hidden_states: Whether to return hidden states
            return_dict: Whether to return a BaseModelOutput or a plain tuple

        Returns:
            A BaseModelOutput whose last_hidden_state is the cell embedding,
            or a tuple of the raw model outputs when return_dict=False.
        """
        if key_padding_mask is None:
            # True for real tokens, False for padding
            key_padding_mask = ~genes.eq(self.config.pad_token_id)

        outputs = self.tx_model(
            genes=genes,
            values=values,
            gen_masks=gen_masks,
            key_padding_mask=key_padding_mask,
            drug_ids=drug_ids,
            skip_decoders=skip_decoders,
            output_hidden_states=output_hidden_states,
        )

        if not return_dict:
            return tuple(v for v in outputs.values())

        # Convert to HuggingFace output format
        return BaseModelOutput(
            last_hidden_state=outputs.get("cell_emb"),
            hidden_states=outputs.get("hidden_states") if output_hidden_states else None,
        )

    def get_input_embeddings(self):
        """Get the gene token embedding table"""
        return self.tx_model.gene_encoder.embedding

    def set_input_embeddings(self, value):
        """Set the gene token embedding table"""
        self.tx_model.gene_encoder.embedding = value

    def get_output_embeddings(self):
        """Get output embeddings (not applicable: this wrapper has no LM head)"""
        return None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load model from pretrained weights.

        Works with both local paths and the HuggingFace Hub.
        Requires only: transformers, torch, safetensors
        """
        # Let the parent class handle config and weight loading
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


# Alias for easier importing (note: the wrapper exposes no causal LM head)
TXForCausalLM = TXModelForHF
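

# Usage sketch: a minimal, hedged example of loading the wrapper and running a
# forward pass. The checkpoint path is a placeholder, and the dummy inputs
# assume the shapes/dtypes documented in forward(): long gene token IDs, float
# expression values, and a boolean generation mask.
if __name__ == "__main__":
    model = TXModelForHF.from_pretrained("path/to/checkpoint")  # placeholder
    model.eval()

    batch_size, seq_len = 2, 16
    genes = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))
    values = torch.rand(batch_size, seq_len)
    gen_masks = torch.zeros(batch_size, seq_len, dtype=torch.bool)

    with torch.no_grad():
        out = model(genes=genes, values=values, gen_masks=gen_masks)

    # last_hidden_state carries the cell embedding (see forward above)
    print(out.last_hidden_state.shape)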