# PlantGenoANN — modeling_segment_caduceus.py
# (renamed from PlantGenoANN/modeling_segment_caduceus.py; commit d591c3d)
import os
from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from .modeling_caduceus_moe import CaduceusPh, CaduceusPhPreTrainedModel
from .configuration_caduceus_ph import CaduceusPhConfig
class DilatedConvLayer(nn.Module):
    """Kernel-3 Conv1d with a given dilation, followed by SiLU and Dropout1d(0.25).

    Operates on (batch, channels, length) tensors.
    """

    def __init__(self, in_channels, out_channels, padding, dilation, groups=1):
        super().__init__()
        conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        # Keep the Sequential container so state-dict keys stay "dilated_conv.0" etc.
        self.dilated_conv = nn.Sequential(conv, nn.SiLU(), nn.Dropout1d(p=0.25))

    def forward(self, x: torch.Tensor):
        """Apply conv -> SiLU -> channel dropout; length is preserved when padding matches dilation."""
        return self.dilated_conv(x)
class NormLayer(nn.Module):
    """LayerNorm over the channel dimension of a (batch, channels, length) tensor."""

    def __init__(self, norm_shape):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=norm_shape)

    def forward(self, x: torch.Tensor):
        # nn.LayerNorm normalizes the trailing dim, so move channels last and back.
        y = x.transpose(1, 2)
        y = self.layer_norm(y)
        return y.transpose(1, 2)
class DilatedConvBlock(nn.Module):
    """A single (ungrouped) DilatedConvLayer followed by channel-wise LayerNorm."""

    def __init__(self, in_channels, out_channels, padding, dilation):
        super().__init__()
        self.dilated_layer = DilatedConvLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            padding=padding,
            dilation=dilation,
            groups=1,
        )
        self.norm = NormLayer(out_channels)

    def forward(self, x: torch.Tensor):
        """Run conv + activation, then normalize; shapes are (batch, channels, length)."""
        return self.norm(self.dilated_layer(x))
class ShiftedConvBlock(nn.Module):
    """Stack of grouped dilated conv layers applied to a channel-shifted input.

    The channels are split into ``shift_steps`` groups and each group is rolled
    along the sequence axis by a different offset before the (grouped)
    convolutions, so every channel group sees a differently shifted view of the
    sequence. Output is normalized with a channel-wise LayerNorm.
    """

    def __init__(self, n_layer, in_channels, out_channels, padding, dilation, shift_steps=2, shift_stride=1):
        super().__init__()
        self.shift_steps = shift_steps
        # Half the number of groups; also the magnitude of the largest shift offset.
        self.hf_shift_steps = shift_steps // 2
        self.shift_stride = shift_stride
        # Zero-padding added on both ends so torch.roll wraps padding, not real data.
        self.pad = (shift_steps*shift_stride) // 2
        # First layer changes the channel count; the remaining n_layer-1 keep it.
        first_layer = [
            DilatedConvLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                padding=padding,
                dilation=dilation,
                groups=shift_steps
            )
        ]
        next_layer = [
            DilatedConvLayer(
                in_channels=out_channels,
                out_channels=out_channels,
                padding=padding,
                dilation=dilation,
                groups=shift_steps,
            )
            for i in range(n_layer-1)
        ]
        self.shifted_layers = nn.ModuleList(first_layer+next_layer)
        self.norm = NormLayer(out_channels)

    def shift(self, x: torch.Tensor):
        """Roll each of ``shift_steps`` channel groups by a distinct sequence offset.

        Input and output are (batch, channels, length); the length is restored
        by narrowing after the padded roll.
        """
        bsz, d, L = x.shape
        xn = F.pad(x, (self.pad ,self.pad), "constant", 0)
        # Split channels into shift_steps equal groups.
        xs = torch.chunk(xn, self.shift_steps, 1)
        # NOTE(review): range(-hf, hf+1) yields shift_steps+1 offsets for even
        # shift_steps, so zip drops the final +hf offset — the applied shifts
        # are asymmetric (e.g. [-1, 0] for shift_steps=2). Looks intentional
        # but worth confirming against the training code.
        x_shift = [torch.roll(x_c, shift*self.shift_stride, 2) for x_c, shift in zip(xs, range(-self.hf_shift_steps, self.hf_shift_steps+1))]
        x_cat = torch.cat(x_shift, 1)
        # Crop back to the original length, discarding the padded borders.
        x_cat = torch.narrow(x_cat, 2, self.pad, L)
        return x_cat

    def forward(self, x: torch.Tensor):
        """Shift channel groups once, run all conv layers, then normalize."""
        y = self.shift(x)
        for layer in self.shifted_layers:
            y = layer(y)
        y = self.norm(y)
        return y
class ShiftedUNetHead(nn.Module):
    """1D U-Net-style head built from dilated / shifted conv blocks.

    Three avg-pool-by-2 downsampling levels are mirrored by three grouped
    transposed-conv upsampling levels with additive skip connections, so the
    input length must be divisible by 8. Input/output are (batch, channels,
    length); the output has ``embd_dim[1]`` channels.
    """

    def __init__(
        self,
        embd_dim=(512, 1024, 1536, 2560, 4096),
        dilation=(1, 4, 8, 16),
        shift_steps=(2, 4, 4),
        up_groups=128,
    ):
        """
        Args:
            embd_dim: channel width per level (5 values, narrowest first).
            dilation: conv dilation per level (4 values).
            shift_steps: channel-shift group counts for the shifted blocks (3 values).
            up_groups: group count of the transposed upsampling convs (was a
                hard-coded 128); each adjacent pair of ``embd_dim`` values must
                be divisible by it.
        """
        super().__init__()
        # Defaults are tuples, not lists, to avoid the shared-mutable-default pitfall.
        self.down_conv1 = DilatedConvBlock(embd_dim[0], embd_dim[1], padding=dilation[0], dilation=dilation[0])
        self.down_conv2 = ShiftedConvBlock(
            n_layer=2, in_channels=embd_dim[1], out_channels=embd_dim[2],
            padding=dilation[1], dilation=dilation[1], shift_steps=shift_steps[0]
        )
        self.down_conv3 = ShiftedConvBlock(
            n_layer=2, in_channels=embd_dim[2], out_channels=embd_dim[3],
            padding=dilation[2], dilation=dilation[2], shift_steps=shift_steps[1]
        )
        self.down_conv4 = ShiftedConvBlock(
            n_layer=3, in_channels=embd_dim[3], out_channels=embd_dim[4],
            padding=dilation[3], dilation=dilation[3], shift_steps=shift_steps[2]
        )
        # Transposed convs double the sequence length (kernel 2, stride 2).
        self.up_trans1 = nn.ConvTranspose1d(embd_dim[4], embd_dim[3], kernel_size=2, stride=2, groups=up_groups)
        self.up_trans2 = nn.ConvTranspose1d(embd_dim[3], embd_dim[2], kernel_size=2, stride=2, groups=up_groups)
        self.up_trans3 = nn.ConvTranspose1d(embd_dim[2], embd_dim[1], kernel_size=2, stride=2, groups=up_groups)
        self.up_conv1 = ShiftedConvBlock(
            n_layer=2, in_channels=embd_dim[3], out_channels=embd_dim[3],
            padding=dilation[2], dilation=dilation[2], shift_steps=shift_steps[1]
        )
        self.up_conv2 = ShiftedConvBlock(
            n_layer=2, in_channels=embd_dim[2], out_channels=embd_dim[2],
            padding=dilation[1], dilation=dilation[1], shift_steps=shift_steps[0]
        )
        self.up_conv3 = DilatedConvBlock(embd_dim[1], embd_dim[1], padding=dilation[0], dilation=dilation[0])
        self.norm_f = NormLayer(embd_dim[1])

    def forward(self, x: torch.Tensor):
        """Encode with three pooled levels, decode with additive skips, normalize.

        Args:
            x: (batch, embd_dim[0], length) tensor; length must be divisible by 8.

        Returns:
            (batch, embd_dim[1], length) tensor.
        """
        x = self.down_conv1(x)
        skip1 = x
        x = F.avg_pool1d(x, kernel_size=2, stride=2)
        x = self.down_conv2(x)
        skip2 = x
        x = F.avg_pool1d(x, kernel_size=2, stride=2)
        x = self.down_conv3(x)
        skip3 = x
        x = F.avg_pool1d(x, kernel_size=2, stride=2)
        x = self.down_conv4(x)
        # Decoder: upsample, add the matching encoder activation, refine.
        x = self.up_conv1(self.up_trans1(x) + skip3)
        x = self.up_conv2(self.up_trans2(x) + skip2)
        x = self.up_conv3(self.up_trans3(x) + skip1)
        return self.norm_f(x)
class SegmentCaduceus(CaduceusPhPreTrainedModel):
    """Caduceus backbone plus a shifted U-Net head for per-base sequence segmentation.

    Produces independent per-channel sigmoid scores with ``config.num_features``
    feature tracks per position.
    """

    def __init__(self, config: CaduceusPhConfig, device=None, dtype=None, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.num_features = config.num_features
        self.training_features = None
        factory_kwargs = {"device": device, "dtype": dtype}
        self.caduceus_ph = CaduceusPh(config, **factory_kwargs, **kwargs)
        self.shift_unet_head = ShiftedUNetHead()
        # Two 1x1 convs project the 2*d_model U-Net channels down to per-feature logits.
        self.final_head = nn.Sequential(
            nn.Conv1d(2 * config.d_model, config.d_model, kernel_size=1, padding=0),
            nn.SiLU(),
            nn.Conv1d(config.d_model, self.num_features, kernel_size=1, padding=0),
        )
        self.post_init()

    def get_feature_logits(self, feature: str, strand: str, logits: torch.FloatTensor):
        """Select the logit channel for a (feature, strand) pair.

        Args:
            feature: one of "gene", "CDS", "exon".
            strand: "+" or "-".
            logits: tensor whose last dimension indexes feature channels.

        Returns:
            ``logits`` sliced to the requested feature/strand channel.

        Raises:
            ValueError: if ``feature`` or ``strand`` is not recognized.
        """
        if feature not in {"gene", "CDS", "exon"}:
            raise ValueError("Input features must be in 'gene', 'CDS' or 'exon'.")
        if strand not in {"+", "-"}:
            raise ValueError("Input strand must be '+' or '-'.")
        # Channel layout of the head output; indices 2-5 are presumably other
        # feature tracks not exposed here — TODO confirm against training code.
        feature_dict = {
            "gene+": 0,
            "gene-": 1,
            "exon+": 6,
            "exon-": 7,
            "CDS+": 8,
            "CDS-": 9,
        }
        feature_index = feature_dict[feature + strand]
        return logits[..., feature_index]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,  # accepted for API compatibility; loss is not computed
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the backbone and segmentation head.

        Args:
            input_ids: (batch, seq_len) token ids; seq_len - 2 must be divisible by 8.
            inputs_embeds: optional precomputed embeddings, used when input_ids is None.
            labels: unused; kept so Trainer-style callers can pass it.
            output_hidden_states / return_dict: standard HF overrides of the config.

        Returns:
            SequenceClassifierOutput with sigmoid ``logits`` of shape
            (batch, seq_len - 2, num_features) and the head's hidden states.

        Raises:
            ValueError: if neither input is given or the length constraint fails.
        """
        # Derive the sequence length from whichever input is present; the
        # previous version dereferenced input_ids unconditionally and crashed
        # on embeds-only calls.
        if input_ids is not None:
            seq_len = input_ids.shape[1]
        elif inputs_embeds is not None:
            seq_len = inputs_embeds.shape[1]
        else:
            raise ValueError("Either input_ids or inputs_embeds must be provided.")
        # The U-Net head pools by 2 three times after the first and last
        # positions are stripped, hence the (seq_len - 2) % 8 constraint.
        if (seq_len - 2) % 8 != 0:
            raise ValueError("Input sequence length minus 2 must be divisible by 8.")
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.caduceus_ph(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Drop the first/last position and move to (batch, channels, length) for the convs.
        hidden_states = self.shift_unet_head(outputs[0][:, 1:-1, :].transpose(1, 2))
        logits = self.final_head(hidden_states)
        # Independent sigmoid per feature track (multi-label, not softmax).
        logits = torch.sigmoid(torch.transpose(logits, 1, 2))
        return SequenceClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=torch.transpose(hidden_states, 1, 2),
        )