import logging
from functools import partial
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models import register
from .iou_loss import IOU
from .mmseg.models.sam import (
    ImageEncoderViT,
    MaskDecoder,
    PromptEncoder,
    TwoWayTransformer,
)

logger = logging.getLogger(__name__)
def init_weights(layer):
    """Initialize Conv2d/Linear/BatchNorm2d layers with normal(·, 0.02) weights."""
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(layer.weight, mean=0.0, std=0.02)
        if layer.bias is not None:  # guard: conv/linear layers may be bias-free
            nn.init.constant_(layer.bias, 0.0)
    elif isinstance(layer, nn.BatchNorm2d):
        nn.init.normal_(layer.weight, mean=1.0, std=0.02)
        nn.init.constant_(layer.bias, 0.0)
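
# Usage sketch (assumption, not from the original file): init_weights is meant
# to be applied recursively to a module tree, e.g. `decoder.apply(init_weights)`.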
class BBCEWithLogitLoss(nn.Module):
    """
    Balanced BCEWithLogitsLoss: up-weights positive pixels by the
    negative/positive pixel ratio, then rescales by the positive fraction.
    """

    def __init__(self):
        super(BBCEWithLogitLoss, self).__init__()

    def forward(self, pred, gt):
        eps = 1e-10
        count_pos = torch.sum(gt) + eps
        count_neg = torch.sum(1.0 - gt)
        ratio = count_neg / count_pos                 # pos_weight for BCE
        w_neg = count_pos / (count_pos + count_neg)   # positive pixel fraction
        bce = nn.BCEWithLogitsLoss(pos_weight=ratio)
        return w_neg * bce(pred, gt)
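
def _demo_bbce_loss():
    # Hedged usage sketch (helper name is hypothetical, not from the original
    # pipeline): BBCEWithLogitLoss takes raw logits and a same-shaped binary
    # ground-truth mask, and up-weights the (typically rare) positive pixels.
    pred = torch.randn(2, 1, 64, 64)               # logits
    gt = (torch.rand(2, 1, 64, 64) > 0.9).float()  # sparse binary mask
    return BBCEWithLogitLoss()(pred, gt)           # scalar loss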
def _iou_loss(pred, target):
    """Soft IoU loss on logits: 1 - IoU, averaged over batch and channels."""
    pred = torch.sigmoid(pred)
    inter = (pred * target).sum(dim=(2, 3))
    union = (pred + target).sum(dim=(2, 3)) - inter
    iou = 1 - (inter / union)
    return iou.mean()
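
def _demo_iou_loss():
    # Hedged usage sketch (hypothetical helper): _iou_loss returns 1 - soft-IoU
    # averaged over batch and channels, so a perfect prediction approaches 0.0.
    pred = torch.full((2, 1, 32, 32), 10.0)  # confidently-foreground logits
    target = torch.ones(2, 1, 32, 32)        # all-foreground target
    return _iou_loss(pred, target)           # ~= 0.0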
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: int) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size, size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w
        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W
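
def _demo_position_embedding():
    # Hedged usage sketch (hypothetical helper): with num_pos_feats=128 the
    # encoding has 2 * 128 = 256 channels, returned as C x H x W.
    pe_layer = PositionEmbeddingRandom(num_pos_feats=128)
    return pe_layer(64).shape  # torch.Size([256, 64, 64])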
@register('sam')
class SAM(nn.Module):
    def __init__(self, inp_size=None, encoder_mode=None, loss=None):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.embed_dim = encoder_mode['embed_dim']
        self.image_encoder = ImageEncoderViT(
            img_size=inp_size,
            patch_size=encoder_mode['patch_size'],
            in_chans=3,
            embed_dim=encoder_mode['embed_dim'],
            depth=encoder_mode['depth'],
            num_heads=encoder_mode['num_heads'],
            mlp_ratio=encoder_mode['mlp_ratio'],
            out_chans=encoder_mode['out_chans'],
            qkv_bias=encoder_mode['qkv_bias'],
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            use_rel_pos=encoder_mode['use_rel_pos'],
            rel_pos_zero_init=True,
            window_size=encoder_mode['window_size'],
            global_attn_indexes=encoder_mode['global_attn_indexes'],
        )
        self.prompt_embed_dim = encoder_mode['prompt_embed_dim']  # typically 256
        image_embedding_size = inp_size // encoder_mode['patch_size']  # e.g. 1024 // 16
        self.prompt_encoder = PromptEncoder(
            embed_dim=self.prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(inp_size, inp_size),
            mask_in_chans=16,
        )
        # Two decoder heads; num_multimask_outputs is the dataset-specific
        # class count (earlier experiments used 3/5/15/25/26 classes).
        self.mask_decoder = MaskDecoder(
            num_multimask_outputs=14,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        )
        self.mask_decoder_diwu = MaskDecoder(
            num_multimask_outputs=12,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        )

        if 'evp' in encoder_mode['name']:
            # Freeze everything except prompt-related and decoder parameters.
            for k, p in self.image_encoder.named_parameters():
                if "prompt" not in k and "mask_decoder" not in k and "prompt_encoder" not in k:
                    p.requires_grad = False

        self.loss_mode = loss
        if self.loss_mode == 'bce':
            self.criterionBCE = torch.nn.BCEWithLogitsLoss()
        elif self.loss_mode == 'bbce':
            self.criterionBCE = BBCEWithLogitLoss()
        elif self.loss_mode == 'iou':
            self.criterionBCE = torch.nn.BCEWithLogitsLoss()
            self.criterionIOU = IOU()
        elif self.loss_mode == 'cr':
            # The background class (index 25) does not contribute to the loss.
            self.criterionCR = torch.nn.CrossEntropyLoss(ignore_index=25, reduction='mean')
            self.criterionIOU = IOU()

        self.pe_layer = PositionEmbeddingRandom(self.prompt_embed_dim // 2)
        self.inp_size = inp_size
        self.image_embedding_size = image_embedding_size
        self.no_mask_embed = nn.Embedding(1, self.prompt_embed_dim)
    def set_input(self, input, gt_mask):
        self.input = input.to(self.device)
        self.gt_mask = gt_mask.to(self.device)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
            torch.Tensor: Positional encoding with shape
                1 x (embed_dim) x (embedding_h) x (embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
    def forward(self):
        bs = 1
        # Embed prompts: an empty sparse prompt (no points or boxes) and the
        # learned "no mask" dense embedding broadcast over the embedding grid.
        sparse_embeddings = torch.empty((bs, 0, self.prompt_embed_dim), device=self.input.device)
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size, self.image_embedding_size
        )

        # Extract the image embedding (last-layer output of the ViT encoder).
        self.features = self.image_encoder(self.input)

        # Predict masks with both decoder heads; each returns B x C x H x W logits.
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )
        low_res_masks_2, iou_predictions_2 = self.mask_decoder_diwu(
            image_embeddings=self.features,
            image_pe=self.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )
        # Concatenate the two heads' mask channels.
        low_res_masks = torch.cat((low_res_masks, low_res_masks_2), 1)

        # Upscale the masks to the original image resolution.
        masks = self.postprocess_masks(low_res_masks, self.inp_size, self.inp_size)
        self.pred_mask = masks
    def infer(self, input):
        bs = 1
        # Embed prompts (same empty sparse / "no mask" dense prompts as forward).
        sparse_embeddings = torch.empty((bs, 0, self.prompt_embed_dim), device=input.device)
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size, self.image_embedding_size
        )
        self.features = self.image_encoder(input)

        # Predict masks with both decoder heads (B x C x 256 x 256 logits each).
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )
        low_res_masks_2, iou_predictions_2 = self.mask_decoder_diwu(
            image_embeddings=self.features,
            image_pe=self.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )
        low_res_masks = torch.cat((low_res_masks, low_res_masks_2), 1)

        # Upsample the masks to the original image size. Raw logits are
        # returned; apply sigmoid/softmax downstream as needed.
        masks = self.postprocess_masks(low_res_masks, self.inp_size, self.inp_size)
        return masks
    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: int,
        original_size: int,
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
            masks (torch.Tensor): Batched masks from the mask_decoder,
                in BxCxHxW format.
            input_size (int): The side length of the (square) image input to
                the model. Used to remove padding.
            original_size (int): The side length of the (square) original
                image before resizing for input to the model.

        Returns:
            (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
                is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., :input_size, :input_size]
        masks = F.interpolate(
            masks, (original_size, original_size), mode="bilinear", align_corners=False
        )
        return masks
    def backward_G(self):
        """Compute the segmentation loss and backpropagate."""
        # Only the cross-entropy ('cr') path is active here; the earlier
        # BCE/IoU variants were disabled in favor of it.
        self.loss_G = self.criterionCR(self.pred_mask, self.gt_mask.squeeze(1).long())
        self.loss_G.backward()

    def _backward_(self, pred_mask, gt_mask):
        self.loss_G = self.criterionCR(pred_mask, gt_mask.squeeze(1).long())
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()
        self.optimizer.zero_grad()  # set G's gradients to zero
        self.backward_G()           # calculate gradients for G
        self.optimizer.step()       # update G's weights
    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input.

        Note: assumes self.pixel_mean and self.pixel_std have been set
        (e.g. as registered buffers); they are not defined in __init__ here.
        """
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad to a square of the encoder's input size
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations.

        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
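
def _demo_sam_inference():
    # Hedged end-to-end sketch (hypothetical helper; the encoder_mode values
    # below are assumptions listing exactly the keys __init__ reads -- real
    # configs come from the project's config files).
    encoder_mode = {
        'name': 'sam',
        'embed_dim': 768,
        'patch_size': 16,
        'depth': 12,
        'num_heads': 12,
        'mlp_ratio': 4,
        'out_chans': 256,
        'qkv_bias': True,
        'use_rel_pos': True,
        'window_size': 14,
        'global_attn_indexes': [2, 5, 8, 11],
        'prompt_embed_dim': 256,
    }
    model = SAM(inp_size=256, encoder_mode=encoder_mode, loss='cr')
    with torch.no_grad():
        masks = model.infer(torch.randn(1, 3, 256, 256))
    # masks: B x C x 256 x 256 logits, where C concatenates both decoder heads.
    return masks.shape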