import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig


# 1. Configuration, defined first so the model class can reference it.
class AlphaDepthConfig(PretrainedConfig):
    """Configuration for ``AlphaUNet``.

    Args:
        input_size: Input shape as ``[channels, height, width]``. It is stored on
            the config only; ``AlphaUNet`` in this file does not read it.
            Defaults to ``[3, 128, 128]``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "alpha-depth"

    def __init__(self, input_size=None, **kwargs):
        # Build a fresh list per instance. A list literal used as the default
        # argument would be one object shared by every default-constructed config.
        self.input_size = [3, 128, 128] if input_size is None else input_size
        super().__init__(**kwargs)


# 2. Helper building blocks.
class ConvBlock(nn.Module):
    """Two ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    Maps ``in_c`` channels to ``out_c`` channels. With a 3x3 kernel, stride 1 and
    padding 1, the spatial size is unchanged.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
            nn.Conv2d(out_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.conv(x)


# 3. The model itself (subclasses transformers.PreTrainedModel).
class AlphaUNet(PreTrainedModel):
    """Small three-level U-Net that outputs a single-channel, sigmoid-activated map.

    Structure:
        Encoder: ``ConvBlock`` stages with 3 -> 32 -> 64 -> 128 channels,
            separated by 2x2 max-pooling.
        Decoder: each level bilinearly upsamples, concatenates the matching
            encoder feature map (skip connection) along the channel axis, and
            refines the result with a ``ConvBlock``.
        Head: a 1x1 convolution to one channel, followed by a sigmoid.
    """

    config_class = AlphaDepthConfig

    def __init__(self, config):
        super().__init__(config)
        # Encoder
        self.enc1 = ConvBlock(3, 32)
        self.pool = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(32, 64)
        self.enc3 = ConvBlock(64, 128)
        # Decoder. The Upsample modules have no parameters, so they do not appear
        # in the state_dict. They are kept as attributes for compatibility, but
        # forward() resizes to the exact skip-connection size via _upsample_to.
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec2 = ConvBlock(128 + 64, 64)
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec1 = ConvBlock(64 + 32, 32)
        self.final = nn.Conv2d(32, 1, 1)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _upsample_to(x, ref):
        """Bilinearly resize ``x`` to the spatial (last two) dimensions of ``ref``."""
        # When ref is exactly 2x the size of x, this matches
        # Upsample(scale_factor=2, align_corners=True): with align_corners=True
        # the sampling grid depends only on the input and output sizes.
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        """Run the U-Net on a batch of images.

        Args:
            x: A 4-D tensor ``(batch, 3, H, W)``. ``enc1`` expects 3 channels, and
                the skip concatenation is along ``dim=1``, which assumes a batch
                axis. H and W do not need to be divisible by 4.

        Returns:
            Tensor of shape ``(batch, 1, H, W)``, the output of the final sigmoid.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        # MaxPool2d floors odd sizes, so a fixed 2x upsample can end up one pixel
        # smaller than the skip tensor, and torch.cat would then fail. Resize each
        # decoder input to the skip tensor's exact size instead.
        d2 = self.dec2(torch.cat([self._upsample_to(e3, e2), e2], dim=1))
        d1 = self.dec1(torch.cat([self._upsample_to(d2, e1), e1], dim=1))
        return self.sigmoid(self.final(d1))