| from typing import List |
|
|
| from einops import rearrange |
| from enformer_pytorch import Enformer |
| from transformers import PretrainedConfig, PreTrainedModel |
|
|
| from genomics_research.segmentnt.layers.torch.segmentation_head import TorchUNetHead |
|
|
# Feature names used as the default value of `SegmentEnformerConfig.features`.
# The list order is preserved exactly as declared here.
FEATURES = [
    "protein_coding_gene",
    "lncRNA",
    "exon",
    "intron",
    "splice_donor",
    "splice_acceptor",
    "5UTR",
    "3UTR",
    "CTCF-bound",
    "polyA_signal",
    "enhancer_Tissue_specific",
    "enhancer_Tissue_invariant",
    "promoter_Tissue_specific",
    "promoter_Tissue_invariant",
]
|
|
|
|
class SegmentEnformerConfig(PretrainedConfig):
    """Configuration object for ``SegmentEnformer``.

    Holds the three values that ``SegmentEnformer`` passes to its
    segmentation head, plus whatever ``PretrainedConfig`` itself stores.

    Args:
        features: Feature names handed to the segmentation head. Defaults to
            the module-level ``FEATURES`` list.
        embed_dim: Forwarded to the head as ``embed_dimension``.
        dim_divisible_by: Forwarded to the head as ``nucl_per_token``.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "segment_enformer"

    def __init__(
        self,
        features: List[str] = FEATURES,
        embed_dim: int = 1536,
        dim_divisible_by: int = 128,
        **kwargs,
    ):
        # Store a copy: the default is the shared module-level FEATURES list,
        # so keeping a reference would let an in-place edit of
        # `config.features` leak into FEATURES and every later default config.
        self.features = list(features)
        self.embed_dim = embed_dim
        self.dim_divisible_by = dim_divisible_by

        super().__init__(**kwargs)
|
|
|
|
class SegmentEnformer(PreTrainedModel):
    """Pretrained Enformer trunk followed by a ``TorchUNetHead``.

    The stem, convolutional tower and transformer are taken from the
    ``EleutherAI/enformer-official-rough`` checkpoint; the U-Net head is
    built from the values in the supplied config.
    """

    config_class = SegmentEnformerConfig

    def __init__(self, config: SegmentEnformerConfig):
        """Build the model from ``config``.

        Args:
            config: Supplies ``features``, ``embed_dim`` and
                ``dim_divisible_by`` for the segmentation head.
        """
        super().__init__(config=config)

        # NOTE: the pretrained Enformer is loaded on every construction;
        # only three of its sub-modules are kept on this model.
        enformer = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

        self.stem = enformer.stem
        self.conv_tower = enformer.conv_tower
        self.transformer = enformer.transformer

        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )

    def forward(self, x):
        """Run the trunk and the segmentation head on ``x``.

        Implemented as ``forward`` rather than ``__call__`` so that
        ``model(x)`` goes through ``nn.Module.__call__`` and its hooks.

        Args:
            x: 3-D tensor whose axes the rearrange patterns label
                ``(b, n, d)`` — presumably (batch, sequence, channels);
                TODO confirm against the caller.

        Returns:
            Whatever ``TorchUNetHead`` returns for the transformer output.
        """
        # The stem and conv tower are fed the (b, d, n) layout.
        x = rearrange(x, "b n d -> b d n")
        x = self.stem(x)
        x = self.conv_tower(x)

        # The transformer is fed the (b, n, d) layout.
        x = rearrange(x, "b d n -> b n d")
        x = self.transformer(x)

        # Back to (b, d, n) before handing off to the U-Net head.
        x = rearrange(x, "b n d -> b d n")
        return self.unet_head(x)
|
|