| import os
|
| from typing import Optional, Tuple, Union
|
| import torch
|
| from torch import nn
|
| from torch.nn import functional as F
|
| from transformers import PreTrainedModel
|
| from transformers.modeling_outputs import SequenceClassifierOutput
|
| from .modeling_caduceus_moe import CaduceusPh, CaduceusPhPreTrainedModel
|
| from .configuration_caduceus_ph import CaduceusPhConfig
|
|
|
class DilatedConvLayer(nn.Module):
    """One dilated 1-D convolution followed by SiLU and channel dropout.

    A kernel-size-3 ``Conv1d`` with the given padding/dilation/groups, a SiLU
    non-linearity, and ``Dropout1d(p=0.25)``, applied in that order.
    """

    def __init__(self, in_channels, out_channels, padding, dilation, groups=1):
        super().__init__()
        conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        # Conv -> activation -> dropout, kept in one Sequential under the
        # name "dilated_conv" so state_dict keys stay unchanged.
        self.dilated_conv = nn.Sequential(conv, nn.SiLU(), nn.Dropout1d(p=0.25))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the conv/activation/dropout stack to ``x`` of shape (B, C, L)."""
        return self.dilated_conv(x)
|
|
|
class NormLayer(nn.Module):
    """LayerNorm over the channel dimension of a (B, C, L) tensor.

    ``nn.LayerNorm`` normalizes the trailing dimension, so the input is moved
    to channels-last (B, L, C), normalized, and moved back to (B, C, L).
    """

    def __init__(self, norm_shape):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=norm_shape)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels_last = x.transpose(1, 2)
        normed = self.layer_norm(channels_last)
        return normed.transpose(1, 2)
|
|
|
class DilatedConvBlock(nn.Module):
    """An ungrouped DilatedConvLayer followed by channel LayerNorm.

    Used as the first (and last) resolution stage of the U-Net head, where no
    channel shifting is applied.
    """

    def __init__(self, in_channels, out_channels, padding, dilation):
        super().__init__()
        self.dilated_layer = DilatedConvLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            padding=padding,
            dilation=dilation,
            groups=1,
        )
        self.norm = NormLayer(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` (B, C, L) and normalize over channels."""
        return self.norm(self.dilated_layer(x))
|
|
|
class ShiftedConvBlock(nn.Module):
    """Stack of grouped dilated conv layers applied to a channel-shifted input.

    The input's channel groups are rolled along the sequence axis by different
    offsets (see ``shift``), so each conv group sees a shifted view of the
    sequence; the stack ends with a channel LayerNorm.
    """

    def __init__(self, n_layer, in_channels, out_channels, padding, dilation, shift_steps=2, shift_stride=1):
        super().__init__()
        self.shift_steps = shift_steps          # number of channel chunks to roll; also conv `groups`
        self.hf_shift_steps = shift_steps // 2  # half-window used to derive per-chunk offsets
        self.shift_stride = shift_stride        # sequence positions rolled per unit of offset
        self.pad = (shift_steps*shift_stride) // 2  # edge padding so torch.roll wrap-around lands in padding, not data

        # First layer maps in_channels -> out_channels; the remaining
        # n_layer-1 layers keep out_channels. All use groups=shift_steps so
        # each shifted channel chunk is convolved independently.
        first_layer = [
            DilatedConvLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                padding=padding,
                dilation=dilation,
                groups=shift_steps
            )
        ]
        next_layer = [
            DilatedConvLayer(
                in_channels=out_channels,
                out_channels=out_channels,
                padding=padding,
                dilation=dilation,
                groups=shift_steps,
            )
            for i in range(n_layer-1)
        ]

        self.shifted_layers = nn.ModuleList(first_layer+next_layer)
        self.norm = NormLayer(out_channels)

    def shift(self, x: torch.Tensor):
        """Roll each of ``shift_steps`` channel chunks by a different offset.

        Zero-pads both ends of the sequence axis, splits the channels into
        ``shift_steps`` chunks, rolls chunk k by
        ``(k - hf_shift_steps) * shift_stride`` positions, re-concatenates,
        and crops the sequence back to its original length L.

        NOTE(review): the offset range contains ``2*(shift_steps//2) + 1``
        values, so for even ``shift_steps`` the ``zip`` drops the last offset
        and the applied offsets are asymmetric (e.g. -1, 0 for
        shift_steps=2). Presumably intentional — trained weights depend on
        it; confirm before changing.
        """
        bsz, d, L = x.shape
        xn = F.pad(x, (self.pad ,self.pad), "constant", 0)
        xs = torch.chunk(xn, self.shift_steps, 1)
        x_shift = [torch.roll(x_c, shift*self.shift_stride, 2) for x_c, shift in zip(xs, range(-self.hf_shift_steps, self.hf_shift_steps+1))]
        x_cat = torch.cat(x_shift, 1)
        # Drop the padding again: keep the central window of length L.
        x_cat = torch.narrow(x_cat, 2, self.pad, L)

        return x_cat

    def forward(self, x: torch.Tensor):
        """Shift channel chunks, run the conv stack, normalize over channels."""
        y = self.shift(x)
        for layer in self.shifted_layers:
            y = layer(y)
        y = self.norm(y)

        return y
|
|
|
class ShiftedUNetHead(nn.Module):
    """U-Net-style refinement head over (batch, channels, length) features.

    Encoder: a dilated conv block plus three shifted conv blocks, each of the
    latter preceded by 2x average pooling. Decoder: three grouped transposed
    convs (2x upsampling) whose outputs are summed with the matching encoder
    stage (additive skip connections), followed by conv blocks and a final
    channel LayerNorm.

    Note: ``embd_dim[1:]`` must be divisible by 128 (the hard-coded ``groups``
    of the transposed convs) and by the corresponding ``shift_steps`` entry.

    Args:
        embd_dim: channel widths per stage; ``embd_dim[0]`` is the input width.
        dilation: dilation (also used as padding) per resolution level.
        shift_steps: channel-shift group counts for the shifted blocks.
    """

    def __init__(
        self,
        # Tuples instead of lists: avoids the mutable-default-argument pitfall
        # (callers may still pass lists; only indexing is used).
        embd_dim=(512, 1024, 1536, 2560, 4096),
        dilation=(1, 4, 8, 16),
        shift_steps=(2, 4, 4),
    ):
        super().__init__()
        # --- encoder (submodule creation order preserved for RNG/init parity) ---
        self.down_conv1 = DilatedConvBlock(embd_dim[0], embd_dim[1], padding=dilation[0], dilation=dilation[0])

        self.down_conv2 = ShiftedConvBlock(
            n_layer=2, in_channels=embd_dim[1], out_channels=embd_dim[2],
            padding=dilation[1], dilation=dilation[1], shift_steps=shift_steps[0]
        )
        self.down_conv3 = ShiftedConvBlock(
            n_layer=2, in_channels=embd_dim[2], out_channels=embd_dim[3],
            padding=dilation[2], dilation=dilation[2], shift_steps=shift_steps[1]
        )
        self.down_conv4 = ShiftedConvBlock(
            n_layer=3, in_channels=embd_dim[3], out_channels=embd_dim[4],
            padding=dilation[3], dilation=dilation[3], shift_steps=shift_steps[2]
        )

        # --- decoder: grouped transposed convs double the sequence length ---
        self.up_trans1 = nn.ConvTranspose1d(embd_dim[4], embd_dim[3], kernel_size=2, stride=2, groups=128)
        self.up_trans2 = nn.ConvTranspose1d(embd_dim[3], embd_dim[2], kernel_size=2, stride=2, groups=128)
        self.up_trans3 = nn.ConvTranspose1d(embd_dim[2], embd_dim[1], kernel_size=2, stride=2, groups=128)

        self.up_conv1 = ShiftedConvBlock(
            n_layer=2, in_channels=embd_dim[3], out_channels=embd_dim[3],
            padding=dilation[2], dilation=dilation[2], shift_steps=shift_steps[1]
        )
        self.up_conv2 = ShiftedConvBlock(
            n_layer=2, in_channels=embd_dim[2], out_channels=embd_dim[2],
            padding=dilation[1], dilation=dilation[1], shift_steps=shift_steps[0]
        )
        self.up_conv3 = DilatedConvBlock(embd_dim[1], embd_dim[1], padding=dilation[0], dilation=dilation[0])

        self.norm_f = NormLayer(embd_dim[1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder/decoder; input and output are (B, C, L).

        The input length must be divisible by 8 (three 2x poolings).
        """
        x = self.down_conv1(x)
        skip1 = x

        x = F.avg_pool1d(x, kernel_size=2, stride=2)
        x = self.down_conv2(x)
        skip2 = x

        x = F.avg_pool1d(x, kernel_size=2, stride=2)
        x = self.down_conv3(x)
        skip3 = x

        x = F.avg_pool1d(x, kernel_size=2, stride=2)
        x = self.down_conv4(x)

        # Upsample, add the encoder skip at the same resolution, refine.
        x = self.up_conv1(self.up_trans1(x) + skip3)
        x = self.up_conv2(self.up_trans2(x) + skip2)
        x = self.up_conv3(self.up_trans3(x) + skip1)

        return self.norm_f(x)
|
|
|
class SegmentCaduceus(CaduceusPhPreTrainedModel):
    """SegmentCaduceusModel for sequence segmentation.

    Runs the CaduceusPh backbone, strips the two edge special tokens, refines
    the hidden states with a shifted U-Net head, and projects each position to
    ``num_features`` sigmoid probabilities.
    """

    def __init__(self, config: CaduceusPhConfig, device=None, dtype=None, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.num_features = config.num_features
        self.training_features = None
        factory_kwargs = {"device": device, "dtype": dtype}
        self.caduceus_ph = CaduceusPh(config, **factory_kwargs, **kwargs)

        self.shift_unet_head = ShiftedUNetHead()
        # 1x1 convs: concat width (2 * d_model) -> d_model -> num_features.
        self.final_head = nn.Sequential(
            nn.Conv1d(2 * config.d_model, config.d_model, kernel_size=1, padding=0),
            nn.SiLU(),
            nn.Conv1d(config.d_model, self.num_features, kernel_size=1, padding=0),
        )

        self.post_init()

    def get_feature_logits(self, feature: str, strand: str, logits: torch.FloatTensor):
        """Select the logit channel for a (feature, strand) pair.

        Args:
            feature: one of "gene", "CDS", "exon".
            strand: "+" or "-".
            logits: tensor whose last dimension indexes feature channels.

        Returns:
            ``logits`` with the last dimension reduced to the selected channel.

        Raises:
            ValueError: on an unknown feature or strand.
        """
        if feature not in {"gene", "CDS", "exon"}:
            raise ValueError("Input features must be in 'gene', 'CDS' or 'exon'.")
        if strand not in {"+", "-"}:
            raise ValueError("Input strand must be '+' or '-'.")

        # NOTE(review): channels 2-5 are unmapped here — presumably reserved
        # for other features in the label layout; confirm against training.
        feature_dict = {
            "gene+": 0,
            "gene-": 1,
            "exon+": 6,
            "exon-": 7,
            "CDS+": 8,
            "CDS-": 9,
        }

        return logits[..., feature_dict[feature + strand]]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Compute per-position segmentation probabilities.

        Args:
            input_ids: token ids of shape (batch, seq_len); the first and last
                positions are treated as special tokens and stripped.
            inputs_embeds: precomputed embeddings, used instead of input_ids.
            labels: accepted for API compatibility; loss is NOT computed here.
            output_hidden_states / return_dict: forwarded to the backbone,
                defaulting to the config values.

        Returns:
            SequenceClassifierOutput with ``logits`` of shape
            (batch, seq_len - 2, num_features) after sigmoid, ``loss=None``.

        Raises:
            ValueError: if the sequence length minus the 2 special tokens is
                not divisible by 8 (required by the U-Net's three poolings).
        """
        # Guard against None so inputs_embeds-only calls don't crash; the
        # head avg-pools three times, hence the divisible-by-8 constraint on
        # the length after the two edge tokens are removed.
        seq_source = input_ids if input_ids is not None else inputs_embeds
        if seq_source is not None and (seq_source.shape[1] - 2) % 8 != 0:
            raise ValueError("Input sequence length minus the 2 special tokens must be divisible by 8.")

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.caduceus_ph(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Strip the edge special tokens and move to (B, C, L) for the convs.
        hidden_states = self.shift_unet_head(outputs[0][:, 1:-1, :].transpose(1, 2))
        logits = self.final_head(hidden_states)
        logits = torch.sigmoid(torch.transpose(logits, 1, 2))

        return SequenceClassifierOutput(
            loss=None,  # NOTE(review): `labels` is accepted but no loss is computed.
            logits=logits,
            hidden_states=torch.transpose(hidden_states, 1, 2),
        )
|
|
|