# Provenance: downloaded from a Hugging Face repository (uploader "AI-Cyber",
# commit "Upload 123 files", revision 8d7921b, 13.9 kB). The bare viewer-chrome
# lines ("raw", "history", "blame") were page furniture, not Python, and have
# been commented out so the module can be imported.
import logging
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import register
from .mmseg.models.sam import ImageEncoderViT, MaskDecoder, TwoWayTransformer
logger = logging.getLogger(__name__)
from .iou_loss import IOU
from typing import Any, Optional, Tuple
from .mmseg.models.sam import PromptEncoder
def init_weights(layer):
    """Initialize a layer in-place with a N(mean, 0.02) scheme.

    Conv2d / Linear weights ~ N(0, 0.02); BatchNorm2d weights ~ N(1, 0.02);
    all present biases are zeroed.  Layers of any other type are left
    untouched, so this is safe to use with ``module.apply(init_weights)``.

    Args:
        layer (nn.Module): the layer to initialize.
    """
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(layer.weight, mean=0.0, std=0.02)
        # bias is None when the layer was constructed with bias=False
        if layer.bias is not None:
            nn.init.constant_(layer.bias, 0.0)
    elif isinstance(layer, nn.BatchNorm2d):
        # weight/bias are None when affine=False
        if layer.weight is not None:
            nn.init.normal_(layer.weight, mean=1.0, std=0.02)
        if layer.bias is not None:
            nn.init.constant_(layer.bias, 0.0)
class BBCEWithLogitLoss(nn.Module):
    '''
    Balanced BCEWithLogitLoss.

    Weights the positive term of binary cross-entropy by the
    negative/positive pixel-count ratio of the ground truth, then rescales
    the whole loss by the positive fraction, so sparse foreground masks do
    not get drowned out by the background.
    '''
    def __init__(self):
        super(BBCEWithLogitLoss, self).__init__()

    def forward(self, pred, gt):
        """Compute the balanced BCE loss.

        Args:
            pred (torch.Tensor): raw logits.
            gt (torch.Tensor): binary ground truth, same shape as ``pred``.

        Returns:
            torch.Tensor: scalar loss.
        """
        eps = 1e-10  # guards against division by zero when gt has no positives
        count_pos = torch.sum(gt) + eps
        count_neg = torch.sum(1. - gt)
        ratio = count_neg / count_pos          # pos_weight for the BCE term
        w_neg = count_pos / (count_pos + count_neg)
        # Functional form: avoids constructing a fresh BCEWithLogitsLoss
        # module on every forward call (the original allocated one per call).
        loss = w_neg * F.binary_cross_entropy_with_logits(pred, gt, pos_weight=ratio)
        return loss
def _iou_loss(pred, target):
print('*****&&&', pred.shape, target.shape)
pred = torch.sigmoid(pred)
inter = (pred * target).sum(dim=(2, 3))
union = (pred + target).sum(dim=(2, 3)) - inter
iou = 1 - (inter / union)
return iou.mean()
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.

    A fixed random Gaussian projection maps 2-D coordinates to
    ``2 * num_pos_feats`` sin/cos features (random Fourier features).
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        # Buffer (not parameter): fixed at init, follows .to(device) and
        # is saved in the state dict.
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # Map [0,1] -> [-1,1], project through the random matrix, and scale
        # by 2*pi; input shape d_1 x ... x d_n x 2 becomes d_1 x ... x d_n x C.
        projected = (2 * coords - 1) @ self.positional_encoding_gaussian_matrix
        projected = projected * (2 * np.pi)
        return torch.cat((projected.sin(), projected.cos()), dim=-1)

    def forward(self, size: int) -> torch.Tensor:
        """Generate positional encoding for a square grid of side ``size``."""
        device: Any = self.positional_encoding_gaussian_matrix.device
        ones = torch.ones((size, size), device=device, dtype=torch.float32)
        # Pixel-center coordinates normalized to (0, 1): cumsum gives 1..size,
        # the -0.5 shifts to pixel centers.
        ys = (ones.cumsum(dim=0) - 0.5) / size
        xs = (ones.cumsum(dim=1) - 0.5) / size
        encoded = self._pe_encoding(torch.stack((xs, ys), dim=-1))
        return encoded.permute(2, 0, 1)  # C x H x W
@register('sam')
class SAM(nn.Module):
    """SAM-style segmentation model with two parallel mask decoders.

    The input image is encoded once by a ViT image encoder; two independent
    ``MaskDecoder`` heads (14 and 12 multimask outputs respectively) decode
    the same embedding with empty prompts, and their low-resolution masks are
    concatenated along the channel dimension before being upscaled to the
    input resolution.
    """

    def __init__(self, inp_size=None, encoder_mode=None, loss=None):
        """
        Args:
            inp_size (int): side length of the (square) input image.
            encoder_mode (dict): ViT / prompt-embedding hyper-parameters.
            loss (str): one of 'bce', 'bbce', 'iou', 'cr'.
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.embed_dim = encoder_mode['embed_dim']
        self.image_encoder = ImageEncoderViT(
            img_size=inp_size,
            patch_size=encoder_mode['patch_size'],
            in_chans=3,
            embed_dim=encoder_mode['embed_dim'],
            depth=encoder_mode['depth'],
            num_heads=encoder_mode['num_heads'],
            mlp_ratio=encoder_mode['mlp_ratio'],
            out_chans=encoder_mode['out_chans'],
            qkv_bias=encoder_mode['qkv_bias'],
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            use_rel_pos=encoder_mode['use_rel_pos'],
            rel_pos_zero_init=True,
            window_size=encoder_mode['window_size'],
            global_attn_indexes=encoder_mode['global_attn_indexes'],
        )
        self.prompt_embed_dim = encoder_mode['prompt_embed_dim']  # typically 256
        # NOTE(review): the prompt encoder is built with a hard-coded 256-dim
        # embedding while both decoders use encoder_mode['prompt_embed_dim'];
        # these must agree — confirm the config always sets 256.
        prompt_embed_dim = 256
        image_embedding_size = inp_size / 16  # ViT patch stride of 16
        self.prompt_encoder = PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(int(image_embedding_size), int(image_embedding_size)),
            input_image_size=(inp_size, inp_size),
            mask_in_chans=16,
        )
        # First decoder head: 14 multimask outputs (dataset-specific; earlier
        # experiments used 3/5/15/25/26).
        self.mask_decoder = MaskDecoder(
            num_multimask_outputs=14,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        )
        # Second decoder head: 12 multimask outputs.
        self.mask_decoder_diwu = MaskDecoder(
            num_multimask_outputs=12,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        )
        if 'evp' in encoder_mode['name']:
            # Freeze everything except prompt modules and the decoders.
            # Fixed: the original iterated `self.encoder`, an attribute that
            # is never defined (AttributeError); the key checks for
            # "mask_decoder"/"prompt_encoder" show the whole model was meant.
            for k, p in self.named_parameters():
                if "prompt" not in k and "mask_decoder" not in k and "prompt_encoder" not in k:
                    p.requires_grad = False

        self.loss_mode = loss
        if self.loss_mode == 'bce':
            self.criterionBCE = torch.nn.BCEWithLogitsLoss()
        elif self.loss_mode == 'bbce':
            self.criterionBCE = BBCEWithLogitLoss()
        elif self.loss_mode == 'iou':
            self.criterionBCE = torch.nn.BCEWithLogitsLoss()
            self.criterionIOU = IOU()
        elif self.loss_mode == 'cr':
            # Label 25 is the background class and is excluded from the loss
            # (an earlier variant ignored 255 instead).
            self.criterionCR = torch.nn.CrossEntropyLoss(ignore_index=25, reduction='mean')
            self.criterionIOU = IOU()

        self.pe_layer = PositionEmbeddingRandom(encoder_mode['prompt_embed_dim'] // 2)
        self.inp_size = inp_size
        self.image_embedding_size = inp_size // encoder_mode['patch_size']  # e.g. 1024 // 16
        # Learned dense embedding used when no mask prompt is supplied.
        self.no_mask_embed = nn.Embedding(1, encoder_mode['prompt_embed_dim'])

    def set_input(self, input, gt_mask):
        """Store one batch (image, ground-truth mask) on the model's device."""
        self.input = input.to(self.device)
        self.gt_mask = gt_mask.to(self.device)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.
        Returns:
            torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _empty_prompts(self, bs, device):
        """Build the empty sparse prompt and the no-mask dense prompt."""
        sparse_embeddings = torch.empty((bs, 0, self.prompt_embed_dim), device=device)
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size, self.image_embedding_size
        )
        return sparse_embeddings, dense_embeddings

    def _decode_masks(self, sparse_embeddings, dense_embeddings):
        """Run both decoder heads on self.features and concatenate the masks.

        Returns low-resolution masks of shape B x (14+1 + 12+1) x h x w;
        the per-head IoU predictions are discarded, as in the original code.
        """
        low_res_masks, _ = self.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )
        low_res_masks_2, _ = self.mask_decoder_diwu(
            image_embeddings=self.features,
            image_pe=self.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )
        # Concatenate both heads' class channels.
        return torch.cat((low_res_masks, low_res_masks_2), 1)

    def forward(self):
        """Encode self.input, decode with empty prompts, store self.pred_mask."""
        bs = 1
        sparse_embeddings, dense_embeddings = self._empty_prompts(bs, self.input.device)
        # Extract the image embedding (last encoder layer output).
        self.features = self.image_encoder(self.input)
        low_res_masks = self._decode_masks(sparse_embeddings, dense_embeddings)
        # Upscale the masks to the original image resolution.
        self.pred_mask = self.postprocess_masks(low_res_masks, self.inp_size, self.inp_size)

    def infer(self, input):
        """Run a forward pass on ``input`` and return full-resolution masks."""
        bs = 1
        sparse_embeddings, dense_embeddings = self._empty_prompts(bs, input.device)
        self.features = self.image_encoder(input)
        low_res_masks = self._decode_masks(sparse_embeddings, dense_embeddings)
        # Upsample back to the original image size; caller applies any
        # sigmoid/argmax itself.
        return self.postprocess_masks(low_res_masks, self.inp_size, self.inp_size)

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.
        Arguments:
            masks (torch.Tensor): Batched masks from the mask_decoder,
                in BxCxHxW format.
            input_size (tuple(int, int)): The size of the image input to the
                model, in (H, W) format. Used to remove padding.  NOTE: this
                class passes a single int (square images).
            original_size (tuple(int, int)): The original size of the image
                before resizing for input to the model, in (H, W) format.
        Returns:
            (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        # Crop away the padding added in preprocess().
        masks = masks[..., : input_size, : input_size]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def backward_G(self):
        """Compute the cross-entropy loss against gt_mask and backpropagate.

        (Fixed docstring: the original claimed "GAN and L1 loss", but only
        criterionCR is used here.)
        """
        self.loss_G = self.criterionCR(self.pred_mask, self.gt_mask.squeeze(1).long())
        self.loss_G.backward()

    def _backward_(self, pred_mask, gt_mask):
        """Cross-entropy loss + backward for an explicit (pred, gt) pair."""
        self.loss_G = self.criterionCR(pred_mask, gt_mask.squeeze(1).long())
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: forward, zero grads, backward, optimizer step.

        NOTE(review): assumes ``self.optimizer`` has been attached externally;
        it is not created in __init__.
        """
        self.forward()
        self.optimizer.zero_grad()  # set gradients to zero
        self.backward_G()           # calculate gradients
        self.optimizer.step()       # update weights

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input.

        NOTE(review): assumes ``self.pixel_mean`` / ``self.pixel_std`` are
        attached externally; they are not created in __init__.
        """
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std
        # Pad bottom/right up to the encoder's square input size.
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list) -- a list of networks
            requires_grad (bool) -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad