import os
import sys
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange

from pretrain.track.model import build_track_model


class Convblock(nn.Module):
    """Residual 2D convolution block with dilation and a symmetrisation step."""

    def __init__(self, in_channel, kernel_size, dilate_size, dropout=0.1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channel, in_channel,
                kernel_size, padding=self.pad(kernel_size, 1)),
            nn.GroupNorm(16, in_channel),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv2d(
                in_channel, in_channel,
                kernel_size, padding=self.pad(kernel_size, dilate_size),
                dilation=dilate_size),
        )

    def pad(self, kernel_size, dilate_size):
        # "Same" padding for the given kernel size and dilation rate.
        return (kernel_size - 1) * dilate_size // 2

    def symmetric(self, x):
        # Average the map with its transpose so the two bin dimensions stay symmetric.
        return (x + x.permute(0, 1, 3, 2)) / 2

    def forward(self, x):
        identity = x
        out = self.conv(x)
        x = out + identity
        x = self.symmetric(x)
        return F.relu(x)

class dilated_tower(nn.Module):
    """Stack of Convblocks with exponentially increasing dilation rates."""

    def __init__(self, embed_dim, in_channel=64, kernel_size=7, dilate_rate=5):
        super().__init__()
        dilate_convs = [
            Convblock(in_channel, kernel_size=kernel_size, dilate_size=2 ** i)
            for i in range(dilate_rate + 1)
        ]
        self.cnn = nn.Sequential(
            Rearrange('b l n d -> b d l n'),
            nn.Conv2d(embed_dim, in_channel, kernel_size=1),
            *dilate_convs,
            nn.Conv2d(in_channel, in_channel, kernel_size=1),
            Rearrange('b d l n -> b l n d'),
        )

    def forward(self, x, crop):
        x = self.cnn(x)
        # Trim `crop` bins from each edge of the 2D map.
        x = x[:, crop:-crop, crop:-crop, :]
        return x

class Downstream_microc_model(nn.Module):
    def __init__(
        self,
        pretrain_model,
        embed_dim,
        hidden_dim=256,
        in_dim=64,
        crop=10,
    ):
        super().__init__()
        self.project = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(),
            nn.Linear(512, hidden_dim),
        )
        self.pretrain_model = pretrain_model
        self.dilate_tower = dilated_tower(embed_dim=hidden_dim, in_channel=in_dim, dilate_rate=5)
        self.prediction_head = nn.Linear(in_dim, 1)
        self.crop = crop

    def output_head(self, x):
        # Turn per-bin embeddings (b, bins, d) into a pairwise map (b, bins, bins, d)
        # by combining the mean and the scaled element-wise product of each bin pair.
        bins = x.shape[1]
        x1 = torch.tile(x.unsqueeze(1), (1, bins, 1, 1))
        x2 = x1.permute(0, 2, 1, 3)
        mean_out = (x1 + x2) / 2
        dot_out = (x1 * x2) / math.sqrt(x.shape[-1])
        return mean_out + dot_out

    def upper_tri(self, x, bins):
        # Keep only the upper-triangular entries of the flattened bins x bins map.
        triu_tup = np.triu_indices(bins)
        d = triu_tup[1] + bins * triu_tup[0]
        return x[:, d, :]

    def forward(self, x):
        x = self.pretrain_model(x)           # (b, bins, embed_dim) per-bin embeddings
        x = self.project(x)                  # (b, bins, hidden_dim)
        x = self.output_head(x)              # (b, bins, bins, hidden_dim) pairwise map
        x = self.dilate_tower(x, self.crop)  # (b, bins - 2*crop, bins - 2*crop, in_dim)
        bins = x.shape[1]
        x = rearrange(x, 'b l n d -> b (l n) d')
        x = self.upper_tri(x, bins)
        x = self.prediction_head(x)          # one value per upper-triangular bin pair
        return x

def build_microc_model(args):
    pretrain_model = build_track_model(args)
    model = Downstream_microc_model(
        pretrain_model=pretrain_model,
        embed_dim=args.embed_dim,
        crop=args.crop,
    )
    return model