Upload folder using huggingface_hub
Browse files- modeling_tx_standalone.py +157 -0
modeling_tx_standalone.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
|
| 2 |
+
"""
|
| 3 |
+
HuggingFace-compatible wrapper for TXModel (Standalone version)
|
| 4 |
+
Only requires: transformers, torch, safetensors
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Optional, Union, Tuple
|
| 8 |
+
import torch
|
| 9 |
+
from transformers import PreTrainedModel
|
| 10 |
+
from transformers.modeling_outputs import BaseModelOutput
|
| 11 |
+
|
| 12 |
+
from configuration_tx import TXConfig
|
| 13 |
+
from model_standalone import TXModel
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TXPreTrainedModel(PreTrainedModel):
|
| 17 |
+
"""
|
| 18 |
+
Base class for TXModel with HuggingFace integration
|
| 19 |
+
"""
|
| 20 |
+
config_class = TXConfig
|
| 21 |
+
base_model_prefix = "tx_model"
|
| 22 |
+
supports_gradient_checkpointing = False
|
| 23 |
+
_no_split_modules = ["TXBlock"]
|
| 24 |
+
|
| 25 |
+
def _init_weights(self, module):
|
| 26 |
+
"""Initialize weights"""
|
| 27 |
+
if isinstance(module, torch.nn.Linear):
|
| 28 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
| 29 |
+
if module.bias is not None:
|
| 30 |
+
module.bias.data.zero_()
|
| 31 |
+
elif isinstance(module, torch.nn.Embedding):
|
| 32 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
| 33 |
+
if module.padding_idx is not None:
|
| 34 |
+
module.weight.data[module.padding_idx].zero_()
|
| 35 |
+
elif isinstance(module, torch.nn.LayerNorm):
|
| 36 |
+
module.bias.data.zero_()
|
| 37 |
+
module.weight.data.fill_(1.0)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class TXModelForHF(TXPreTrainedModel):
|
| 41 |
+
"""
|
| 42 |
+
HuggingFace-compatible TXModel
|
| 43 |
+
|
| 44 |
+
This model can be used directly with HuggingFace's from_pretrained()
|
| 45 |
+
and requires only: transformers, torch, safetensors
|
| 46 |
+
|
| 47 |
+
No dependencies on llmfoundry, composer, or other external libraries.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(self, config: TXConfig):
|
| 51 |
+
super().__init__(config)
|
| 52 |
+
|
| 53 |
+
# Initialize standalone model
|
| 54 |
+
self.tx_model = TXModel(
|
| 55 |
+
vocab_size=config.vocab_size,
|
| 56 |
+
d_model=config.d_model,
|
| 57 |
+
n_layers=config.n_layers,
|
| 58 |
+
n_heads=config.n_heads,
|
| 59 |
+
expansion_ratio=config.expansion_ratio,
|
| 60 |
+
pad_token_id=config.pad_token_id,
|
| 61 |
+
pad_value=config.pad_value,
|
| 62 |
+
num_bins=config.num_bins,
|
| 63 |
+
norm_scheme=config.norm_scheme,
|
| 64 |
+
transformer_activation=config.transformer_activation,
|
| 65 |
+
cell_emb_style=config.cell_emb_style,
|
| 66 |
+
use_chem_token=config.use_chem_token,
|
| 67 |
+
attn_config=config.attn_config,
|
| 68 |
+
norm_config=config.norm_config,
|
| 69 |
+
gene_encoder_config=config.gene_encoder_config,
|
| 70 |
+
expression_encoder_config=config.expression_encoder_config,
|
| 71 |
+
expression_decoder_config=config.expression_decoder_config,
|
| 72 |
+
mvc_config=config.mvc_config,
|
| 73 |
+
chemical_encoder_config=config.chemical_encoder_config,
|
| 74 |
+
use_glu=config.use_glu,
|
| 75 |
+
return_gene_embeddings=config.return_gene_embeddings,
|
| 76 |
+
keep_first_n_tokens=config.keep_first_n_tokens,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# Post init
|
| 80 |
+
self.post_init()
|
| 81 |
+
|
| 82 |
+
def forward(
|
| 83 |
+
self,
|
| 84 |
+
genes: torch.Tensor,
|
| 85 |
+
values: torch.Tensor,
|
| 86 |
+
gen_masks: torch.Tensor,
|
| 87 |
+
key_padding_mask: Optional[torch.Tensor] = None,
|
| 88 |
+
drug_ids: Optional[torch.Tensor] = None,
|
| 89 |
+
skip_decoders: bool = False,
|
| 90 |
+
output_hidden_states: bool = False,
|
| 91 |
+
return_dict: bool = True,
|
| 92 |
+
) -> Union[Tuple, BaseModelOutput]:
|
| 93 |
+
"""
|
| 94 |
+
Forward pass through the model.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
genes: Gene token IDs [batch_size, seq_len]
|
| 98 |
+
values: Expression values [batch_size, seq_len]
|
| 99 |
+
gen_masks: Generation masks [batch_size, seq_len]
|
| 100 |
+
key_padding_mask: Padding mask [batch_size, seq_len]
|
| 101 |
+
drug_ids: Drug IDs [batch_size] (optional)
|
| 102 |
+
skip_decoders: Whether to skip decoder computation
|
| 103 |
+
output_hidden_states: Whether to return hidden states
|
| 104 |
+
return_dict: Whether to return a dict or tuple
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
Model outputs
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
if key_padding_mask is None:
|
| 111 |
+
key_padding_mask = ~genes.eq(self.config.pad_token_id)
|
| 112 |
+
|
| 113 |
+
outputs = self.tx_model(
|
| 114 |
+
genes=genes,
|
| 115 |
+
values=values,
|
| 116 |
+
gen_masks=gen_masks,
|
| 117 |
+
key_padding_mask=key_padding_mask,
|
| 118 |
+
drug_ids=drug_ids,
|
| 119 |
+
skip_decoders=skip_decoders,
|
| 120 |
+
output_hidden_states=output_hidden_states,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
if not return_dict:
|
| 124 |
+
return tuple(v for v in outputs.values())
|
| 125 |
+
|
| 126 |
+
# Convert to HuggingFace output format
|
| 127 |
+
return BaseModelOutput(
|
| 128 |
+
last_hidden_state=outputs.get("cell_emb"),
|
| 129 |
+
hidden_states=outputs.get("hidden_states") if output_hidden_states else None,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
def get_input_embeddings(self):
|
| 133 |
+
"""Get input embeddings"""
|
| 134 |
+
return self.tx_model.gene_encoder.embedding
|
| 135 |
+
|
| 136 |
+
def set_input_embeddings(self, value):
|
| 137 |
+
"""Set input embeddings"""
|
| 138 |
+
self.tx_model.gene_encoder.embedding = value
|
| 139 |
+
|
| 140 |
+
def get_output_embeddings(self):
|
| 141 |
+
"""Get output embeddings (not applicable)"""
|
| 142 |
+
return None
|
| 143 |
+
|
| 144 |
+
@classmethod
|
| 145 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 146 |
+
"""
|
| 147 |
+
Load model from pretrained weights.
|
| 148 |
+
|
| 149 |
+
Works with both local paths and HuggingFace Hub.
|
| 150 |
+
Requires only: transformers, torch, safetensors
|
| 151 |
+
"""
|
| 152 |
+
# Let parent class handle config and weight loading
|
| 153 |
+
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# Alias for easier importing
|
| 157 |
+
TXForCausalLM = TXModelForHF
|