Yuto2007 committed on
Commit
bc12f7a
·
verified ·
1 Parent(s): e093a4b

Delete modeling_tx_standalone.py

Browse files
Files changed (1) hide show
  1. modeling_tx_standalone.py +0 -157
modeling_tx_standalone.py DELETED
@@ -1,157 +0,0 @@
1
- # Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
2
- """
3
- HuggingFace-compatible wrapper for TXModel (Standalone version)
4
- Only requires: transformers, torch, safetensors
5
- """
6
-
7
- from typing import Optional, Union, Tuple
8
- import torch
9
- from transformers import PreTrainedModel
10
- from transformers.modeling_outputs import BaseModelOutput
11
-
12
- from configuration_tx import TXConfig
13
- from model_standalone import TXModel
14
-
15
-
16
class TXPreTrainedModel(PreTrainedModel):
    """
    Base class wiring TXModel into the HuggingFace ecosystem.

    Supplies the config class, the `base_model_prefix` used for weight
    (de)serialization, and the weight-initialization scheme expected by
    `PreTrainedModel`.
    """

    config_class = TXConfig
    base_model_prefix = "tx_model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["TXBlock"]

    def _init_weights(self, module):
        """Initialize a submodule: normal(0, 0.02) for weights, zeros for
        biases and padding rows, ones for LayerNorm scales."""
        init_std = 0.02
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            # Keep the padding row at exactly zero after the normal init.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            # Identity-like start: unit scale, zero shift.
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
38
-
39
-
40
class TXModelForHF(TXPreTrainedModel):
    """
    HuggingFace-compatible TXModel.

    This model can be used directly with HuggingFace's ``from_pretrained()``
    (inherited from `PreTrainedModel` — the previous pure pass-through
    override was removed as redundant) and requires only: transformers,
    torch, safetensors.

    No dependencies on llmfoundry, composer, or other external libraries.
    """

    def __init__(self, config: TXConfig):
        """Build the standalone TXModel backbone from the HF config.

        Args:
            config: TXConfig holding every architecture hyperparameter
                forwarded below to `TXModel`.
        """
        super().__init__(config)

        # Initialize the standalone model; every field is taken verbatim
        # from the config so serialized configs fully determine the model.
        self.tx_model = TXModel(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            expansion_ratio=config.expansion_ratio,
            pad_token_id=config.pad_token_id,
            pad_value=config.pad_value,
            num_bins=config.num_bins,
            norm_scheme=config.norm_scheme,
            transformer_activation=config.transformer_activation,
            cell_emb_style=config.cell_emb_style,
            use_chem_token=config.use_chem_token,
            attn_config=config.attn_config,
            norm_config=config.norm_config,
            gene_encoder_config=config.gene_encoder_config,
            expression_encoder_config=config.expression_encoder_config,
            expression_decoder_config=config.expression_decoder_config,
            mvc_config=config.mvc_config,
            chemical_encoder_config=config.chemical_encoder_config,
            use_glu=config.use_glu,
            return_gene_embeddings=config.return_gene_embeddings,
            keep_first_n_tokens=config.keep_first_n_tokens,
        )

        # HF hook: runs _init_weights and any final registration.
        self.post_init()

    def forward(
        self,
        genes: torch.Tensor,
        values: torch.Tensor,
        gen_masks: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        drug_ids: Optional[torch.Tensor] = None,
        skip_decoders: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Forward pass through the model.

        Args:
            genes: Gene token IDs [batch_size, seq_len]
            values: Expression values [batch_size, seq_len]
            gen_masks: Generation masks [batch_size, seq_len]
            key_padding_mask: Padding mask [batch_size, seq_len]
            drug_ids: Drug IDs [batch_size] (optional)
            skip_decoders: Whether to skip decoder computation
            output_hidden_states: Whether to return hidden states
            return_dict: Whether to return a dict or tuple

        Returns:
            Model outputs: a `BaseModelOutput` with `last_hidden_state` set
            to the cell embedding, or a plain tuple when ``return_dict=False``.
        """
        if key_padding_mask is None:
            # True for real (non-pad) tokens. Equivalent to the original
            # ~genes.eq(pad) but stated directly with ne().
            # NOTE(review): assumes TXModel treats True == "attend" — confirm
            # against model_standalone before changing this convention.
            key_padding_mask = genes.ne(self.config.pad_token_id)

        outputs = self.tx_model(
            genes=genes,
            values=values,
            gen_masks=gen_masks,
            key_padding_mask=key_padding_mask,
            drug_ids=drug_ids,
            skip_decoders=skip_decoders,
            output_hidden_states=output_hidden_states,
        )

        if not return_dict:
            # Preserve the model's output-dict insertion order.
            return tuple(outputs.values())

        # Convert to HuggingFace output format.
        return BaseModelOutput(
            last_hidden_state=outputs.get("cell_emb"),
            hidden_states=outputs.get("hidden_states") if output_hidden_states else None,
        )

    def get_input_embeddings(self):
        """Return the gene-token embedding table (HF resize/tie hook)."""
        return self.tx_model.gene_encoder.embedding

    def set_input_embeddings(self, value):
        """Replace the gene-token embedding table (HF resize/tie hook)."""
        self.tx_model.gene_encoder.embedding = value

    def get_output_embeddings(self):
        """No tied output head exists; returning None makes HF skip weight tying."""
        return None
154
-
155
-
156
# Alias for easier importing. Kept for backward compatibility with callers
# that import a *ForCausalLM name.
# NOTE(review): TXModelForHF does not expose a causal-LM head (its
# get_output_embeddings returns None) — confirm downstream code only relies
# on the name, not on LM-head behavior, before removing.
TXForCausalLM = TXModelForHF