| from typing import Any, Dict, List |
|
|
| import borzoi_pytorch |
| import torch |
| import torch.nn as nn |
| from einops import rearrange |
| from torch import einsum |
| from transformers import PretrainedConfig, PreTrainedModel |
|
|
| from genomics_research.segmentnt.porting_to_pytorch.layers.segmentation_head import ( |
| TorchUNetHead, |
| ) |
|
|
# Genomic annotation classes handled by the segmentation head (passed as
# ``features`` to ``TorchUNetHead`` via ``SegmentBorzoiConfig``).
# NOTE(review): the order presumably has to match the head's output tracks /
# the ported checkpoint — do not reorder without confirming.
FEATURES = [
    # Gene bodies.
    "protein_coding_gene", "lncRNA",
    # Transcript structure.
    "exon", "intron",
    "splice_donor", "splice_acceptor",
    "5UTR", "3UTR",
    # Other elements.
    "CTCF-bound", "polyA_signal",
    # Regulatory regions, tissue-specific vs tissue-invariant.
    "enhancer_Tissue_specific", "enhancer_Tissue_invariant",
    "promoter_Tissue_specific", "promoter_Tissue_invariant",
]
|
|
|
|
class SegmentBorzoiConfig(PretrainedConfig):
    """Configuration for ``SegmentBorzoi``.

    Attributes:
        features: Annotation class names; passed as ``features`` to the
            ``TorchUNetHead`` segmentation head.
        embed_dim: Feature size; used as the head's ``embed_dimension`` and as
            ``dim`` of every ``BorzoiAttentionLayer``.
        dim_divisible_by: Passed to the head as ``nucl_per_token``.
        attention_dim_key: Per-head query/key size of the attention layers.
        num_attention_heads: Number of attention heads; the per-head value size
            is ``embed_dim // num_attention_heads``.
        num_rel_pos_features: Size of the relative positional features; must be
            even (``get_positional_embed_borzoi`` raises otherwise).
    """

    model_type = "segment_borzoi"

    def __init__(
        self,
        features: List[str] = FEATURES,
        embed_dim: int = 1536,
        dim_divisible_by: int = 32,
        attention_dim_key: int = 64,
        num_attention_heads: int = 8,
        num_rel_pos_features: int = 32,
        **kwargs: Any,
    ):
        # Copy the list: the default is the shared module-level FEATURES object,
        # and storing it by reference would let an edit to one config's
        # ``features`` silently change the default for every other config.
        self.features = list(features) if features is not None else None
        self.embed_dim = embed_dim
        self.dim_divisible_by = dim_divisible_by
        self.attention_dim_key = attention_dim_key
        self.num_attention_heads = num_attention_heads
        self.num_rel_pos_features = num_rel_pos_features

        # Remaining keyword arguments are handled by the Hugging Face base config.
        super().__init__(**kwargs)
|
|
|
|
class SegmentBorzoi(PreTrainedModel):
    """Segmentation model: a pretrained Borzoi trunk followed by ``TorchUNetHead``.

    The trunk sub-modules are taken from ``borzoi_pytorch.Borzoi`` loaded from
    ``"johahi/borzoi-replicate-0"`` and re-registered on this model. Afterwards
    the transformer attention modules are replaced by ``BorzoiAttentionLayer``
    instances, the head's first convolution is replaced, and two separable-block
    biases are removed (see ``__init__``).
    """

    config_class = SegmentBorzoiConfig

    def __init__(self, config: SegmentBorzoiConfig):
        super().__init__(config=config)
        # NOTE(review): this loads the Borzoi checkpoint on every construction,
        # including when this model itself is restored with ``from_pretrained``;
        # presumably it is fetched from the Hugging Face Hub / local cache — confirm.
        borzoi = borzoi_pytorch.Borzoi.from_pretrained("johahi/borzoi-replicate-0")

        # Stem (Borzoi's ``conv_dna``), applied first in ``forward``.
        self.stem = borzoi.conv_dna

        # Encoder tower. ``forward`` keeps the outputs of ``res_tower`` and
        # ``unet1`` as skip connections before max-pooling.
        self.res_tower = borzoi.res_tower
        self.unet1 = borzoi.unet1
        self._max_pool = borzoi._max_pool

        # Transformer trunk; its attention modules are replaced further below.
        self.transformer = borzoi.transformer

        # Decoder: projections of the two skip connections, upsampling stages,
        # and the separable blocks applied after each skip addition.
        self.horizontal_conv1 = borzoi.horizontal_conv1
        self.horizontal_conv0 = borzoi.horizontal_conv0
        self.upsampling_unet1 = borzoi.upsampling_unet1
        self.upsampling_unet0 = borzoi.upsampling_unet0
        self.separable1 = borzoi.separable1
        self.separable0 = borzoi.separable0

        # Cropping module (applied on the transposed layout in ``forward``).
        self.crop = borzoi.crop

        # Final Borzoi convolutions, whose output feeds the segmentation head.
        self.final_joined_convs = borzoi.final_joined_convs

        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )

        # Replace the attention module (``layer[0].fn[1]``) of every transformer
        # block with a freshly initialised BorzoiAttentionLayer. The pretrained
        # Borzoi attention weights are therefore NOT kept here.
        # NOTE(review): presumably ported weights are loaded afterwards — confirm.
        for layer in self.transformer:
            layer[0].fn[1] = BorzoiAttentionLayer(
                config.embed_dim,
                heads=config.num_attention_heads,
                dim_key=config.attention_dim_key,
                dim_value=config.embed_dim // config.num_attention_heads,
                dropout=0.05,
                pos_dropout=0.01,
                num_rel_pos_features=config.num_rel_pos_features,
            )

        # Replace the head's first convolution with a 1920 -> 1536 channel conv
        # (kernel 3, padding 1). 1920 is presumably the channel count produced by
        # ``final_joined_convs`` — confirm.
        # NOTE(review): out_channels is hard-coded and only equals
        # ``config.embed_dim`` at its default value (1536).
        self.unet_head.unet.downsample_blocks[0].conv_layers[0] = nn.Conv1d(
            in_channels=1920, out_channels=1536, kernel_size=3, stride=1, padding=1
        )

        # Remove the bias of ``conv_layer[1]`` in both separable blocks
        # (presumably to match the parameter set of the ported checkpoint — confirm).
        self.separable1.conv_layer[1].bias = None
        self.separable0.conv_layer[1].bias = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the Borzoi trunk and the segmentation head.

        Args:
            x: Input batch. Dims 1 and 2 are swapped before the stem, so it is
                presumably ``(batch, seq_len, channels)`` (e.g. one-hot DNA) — confirm.

        Returns:
            The output of ``self.unet_head`` for the trunk features.
        """
        # Swap dims 1 and 2 (presumably to channels-first for the stem).
        x = x.transpose(1, 2)
        x = self.stem(x)

        # Encoder; keep both intermediate outputs as skip connections.
        x_unet0 = self.res_tower(x)
        x_unet1 = self.unet1(x_unet0)
        x = self._max_pool(x_unet1)

        # The attention layers treat dim -2 as the sequence axis, so swap the last
        # two axes around the transformer and back afterwards.
        x = x.permute(0, 2, 1)
        x = self.transformer(x)
        x = x.permute(0, 2, 1)

        # Project the skip connections.
        x_unet1 = self.horizontal_conv1(x_unet1)
        x_unet0 = self.horizontal_conv0(x_unet0)

        # Decoder: upsample, add the matching skip connection (in place),
        # then apply the separable block.
        x = self.upsampling_unet1(x)
        x += x_unet1
        x = self.separable1(x)
        x = self.upsampling_unet0(x)
        x += x_unet0
        x = self.separable0(x)

        # Crop on the transposed layout, then transpose back.
        x = self.crop(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)

        x = self.final_joined_convs(x)

        x = self.unet_head(x)

        return x
|
|
|
|
| |
| |
| def _prepend_dims(tensor: torch.Tensor, num_dims: int) -> torch.Tensor: |
| """Prepends dimensions to match the required shape.""" |
| for _ in range(num_dims - tensor.dim()): |
| tensor = tensor.unsqueeze(0) |
| return tensor |
|
|
|
|
def get_positional_features_central_mask_borzoi(
    positions: torch.Tensor, feature_size: int, seq_length: int
) -> torch.Tensor:
    """Positional features using a central mask (allow only central features).

    Feature ``k`` (``k = 1 .. feature_size``) is 1.0 where
    ``|position| < (seq_length + 1) ** (k / feature_size) - 1`` and 0.0 elsewhere.
    Successive features therefore switch on for geometrically wider windows
    around position 0; the widest window is approximately ``seq_length``.

    Args:
        positions: Relative positions, any shape.
        feature_size: Number of mask features (window widths) to produce.
        seq_length: Sequence length used to scale the window widths.

    Returns:
        Float32 tensor of shape ``positions.shape + (feature_size,)``, holding
        0.0/1.0 values and located on ``positions.device``.
    """
    # Build helper tensors on the same device as ``positions`` so the function
    # also works for GPU inputs (previously they were always on CPU).
    device = positions.device
    pow_rate = torch.exp(
        torch.log(torch.tensor(seq_length + 1.0, device=device)) / feature_size
    )
    exponents = torch.arange(1, feature_size + 1, device=device).float()
    center_widths = torch.pow(pow_rate, exponents) - 1
    # (feature_size,) broadcasts against (..., 1) to give (..., feature_size), so
    # no explicit leading singleton dimensions are needed.
    return (center_widths > torch.abs(positions).unsqueeze(-1)).float()
|
|
|
|
def get_positional_embed_borzoi(seq_len: int, feature_size: int) -> torch.Tensor:
    """
    Compute positional embedding for Borzoi. Note that it is different than the one
    used in Enformer.

    For every relative offset ``d`` in ``-(seq_len - 1) .. seq_len - 1`` the
    embedding is ``[m(d), sign(d) * m(d)]``. Here ``m`` denotes the
    ``feature_size // 2`` central-mask features from
    ``get_positional_features_central_mask_borzoi``.

    Returns:
        Tensor of shape ``(2 * seq_len - 1, feature_size)``.

    Raises:
        ValueError: If ``feature_size`` is not divisible by 2.
    """
    offsets = torch.arange(-seq_len + 1, seq_len)

    # Two halves: the symmetric mask features and their sign-weighted copy.
    num_components = 2
    if feature_size % num_components != 0:
        raise ValueError(
            f"feature size is not divisible by number of components ({num_components})"
        )

    symmetric = get_positional_features_central_mask_borzoi(
        offsets, feature_size // num_components, seq_len
    )
    antisymmetric = torch.sign(offsets)[:, None] * symmetric
    return torch.cat((symmetric, antisymmetric), dim=-1)
|
|
|
|
def relative_shift(x: torch.Tensor) -> torch.Tensor:
    """Re-index relative-position logits so each column is a key position.

    ``x`` must be 4-D with shape ``(batch, heads, n, 2n - 1)``, where the last
    axis indexes relative offsets ``-(n-1) .. n-1``. The result has shape
    ``(batch, heads, n, n)`` with ``out[..., i, j] == x[..., i, j - i + n - 1]``.
    This is done with the pad / reshape / drop-row trick rather than a gather.
    """
    # One zero column on the left of the last axis.
    padded = nn.functional.pad(x, (1, 0))
    _, heads, rows, cols = padded.shape
    # Reinterpreting the buffer with swapped trailing sizes and dropping the first
    # row shifts each successive query row one step further.
    shifted = padded.reshape(-1, heads, cols, rows)[:, :, 1:, :]
    shifted = shifted.reshape(-1, heads, rows, cols - 1)
    return shifted[..., : (cols + 1) // 2]
|
|
|
|
class BorzoiAttentionLayer(nn.Module):
    """Multi-head self-attention with Borzoi-style relative positional logits.

    The attention logits are the sum of two terms:

    * content logits ``(q + rel_content_bias) . k``;
    * positional logits ``(q + rel_pos_bias) . rel_k``. Here ``rel_k`` is a
      linear projection of the Borzoi relative positional features, re-aligned
      per query position with ``relative_shift``.

    The output projection is zero-initialised, so a freshly constructed layer
    returns zeros until it is trained or weights are loaded.
    """

    def __init__(
        self,
        dim: int,
        *,
        num_rel_pos_features: int,
        heads: int = 8,
        dim_key: int = 64,
        dim_value: int = 64,
        dropout: float = 0.0,
        pos_dropout: float = 0.0,
    ) -> None:
        """
        Args:
            dim: Input/output feature size (last axis of ``x``).
            num_rel_pos_features: Size of the relative positional features; must
                be even (``get_positional_embed_borzoi`` raises otherwise).
            heads: Number of attention heads.
            dim_key: Per-head query/key size.
            dim_value: Per-head value size.
            dropout: Dropout probability applied to the attention weights.
            pos_dropout: Dropout probability applied to the positional features.
        """
        super().__init__()
        self.scale = dim_key**-0.5
        self.heads = heads

        self.to_q = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_value * heads, bias=False)

        # Zero-initialised output projection (layer outputs zeros until trained/loaded).
        self.to_out = nn.Linear(dim_value * heads, dim)
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

        self.num_rel_pos_features = num_rel_pos_features

        # Projection of the positional features to per-head keys, plus learned
        # per-head biases added to the scaled queries for each logit term.
        self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False)
        self.rel_content_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
        self.rel_pos_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))

        self.pos_dropout = nn.Dropout(pos_dropout)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Self-attend over the sequence axis of ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        n, h = x.shape[-2], self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # Split heads: (b, n, h*d) -> (b, h, n, d).
        q, k, v = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=h),
            (q, k, v),
        )

        q = q * self.scale

        content_logits = einsum(
            "b h i d, b h j d -> b h i j", q + self.rel_content_bias, k
        )

        # The positional features are created on CPU as float32. Move them to x's
        # device and dtype; otherwise ``to_rel_k`` fails with a device/dtype
        # mismatch when the model runs on GPU or in reduced precision.
        positions = get_positional_embed_borzoi(n, self.num_rel_pos_features).to(
            device=x.device, dtype=x.dtype
        )
        positions = self.pos_dropout(positions)
        rel_k = self.to_rel_k(positions)

        # (2n-1, h*d) -> (h, 2n-1, d); logits over relative offsets, then re-aligned
        # so the last axis indexes key positions.
        rel_k = rearrange(rel_k, "n (h d) -> h n d", h=h)
        rel_logits = einsum("b h i d, h j d -> b h i j", q + self.rel_pos_bias, rel_k)
        rel_logits = relative_shift(rel_logits)

        logits = content_logits + rel_logits
        attn = logits.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
|
|