import torch
import torch.nn as nn
import torch.nn.functional as F
from xtuner.registry import BUILDER
from mmengine.model import BaseModel
from xtuner.model.utils import guess_load_checkpoint
from .utils import compute_mask_IoU


class FrozenLlava(BaseModel):
    """A frozen LLaVA model with a trainable mask head that decodes
    segmentation masks from the VLM's attention maps."""

    def __init__(self,
                 model,
                 mask_head,
                 merge='mean',
                 loss_mask=None,
                 loss_dice=None,
                 pretrained=None,
                 **kwargs):
        super().__init__()
        self.llava = BUILDER.build(model)
        self.llava.requires_grad_(False)   # keep the VLM frozen
        # The mask head takes one input channel per (layer, head) pair.
        in_channels = (self.llava.config.text_config.num_attention_heads *
                       self.llava.config.text_config.num_hidden_layers)
        mask_head.update(in_channels=in_channels)
        self.mask_head = BUILDER.build(mask_head)
        self.patch_size = self.llava.config.vision_config.patch_size
        assert merge in ['mean', 'max']
        self.merge = merge
        self.loss_mask = BUILDER.build(loss_mask)
        self.loss_dice = BUILDER.build(loss_dice)
        # Learnable per-layer weights for mixing text hidden states.
        self.text_layer_weights = nn.Parameter(
            torch.ones(self.llava.config.text_config.num_hidden_layers))
        if pretrained is not None:
            _ = self.load_state_dict(
                guess_load_checkpoint(pretrained), strict=False)

    def get_text_layer_weights(self):
        return torch.softmax(self.text_layer_weights, dim=0)

    def apply_merge(self, x, dim=1):
        if self.merge == 'mean':
            return x.mean(dim=dim)
        elif self.merge == 'max':
            return x.max(dim=dim).values
        else:
            raise NotImplementedError

    def init_weights(self):
        pass

    def train(self, mode=True):
        super().train(mode=mode)
        self.llava.train(mode=False)   # the frozen VLM always stays in eval mode
        self.training = mode
        return self

    def forward(self, data, data_samples=None, mode='loss'):
        if mode == 'loss':
            return self.compute_loss(data)
        elif mode == 'predict':
            return self.predict(data)
        elif mode == 'tensor':
            return self._forward(data)
        else:
            raise NotImplementedError

    def _compute(self, pred_masks, gt_masks):
        """Compute dice and mask losses plus pixel accuracy and mean IoU for
        a batch of predicted mask logits against binary ground-truth masks."""
        mask_cnt = pred_masks.shape[0]
        loss_dice = self.loss_dice(
            pred_masks.view(mask_cnt, -1), gt_masks.view(mask_cnt, -1),
            avg_factor=mask_cnt)
        loss_mask = self.loss_mask(
            pred_masks.view(-1), gt_masks.view(-1),
            avg_factor=pred_masks.numel())
        accuracy = torch.eq((pred_masks.detach().sigmoid() > 0.5).to(gt_masks),
                            gt_masks).to(gt_masks).mean()
        aiou = compute_mask_IoU(
            (pred_masks.detach().sigmoid() > 0.5).to(gt_masks).view(mask_cnt, -1),
            gt_masks.view(mask_cnt, -1)).mean()
        return loss_dice, loss_mask, accuracy, aiou


class FrozenLlavaSAM(FrozenLlava):
    """FrozenLlava variant that refines the mask-head predictions with a
    SAM decoder prompted by projected text embeddings."""

    def __init__(self, sam, *args, **kwargs):
        pretrained = kwargs.pop('pretrained', None)   # load only after SAM is built
        super().__init__(*args, **kwargs)
        self.sam = BUILDER.build(sam)
        self.text_proj = nn.Linear(self.llava.config.text_config.hidden_size,
                                   self.sam.model.prompt_encoder.embed_dim)
        if pretrained is not None:
            _ = self.load_state_dict(
                guess_load_checkpoint(pretrained), strict=False)

    def _forward(self, data_sample):
        text_layer_weights = self.get_text_layer_weights()
        inputs = dict(
            input_ids=data_sample['input_ids'][None].to(self.llava.device),
            mask_ids=data_sample['mask_ids'][None].to(self.llava.device),
            pixel_values=data_sample['pixel_values'][None].to(
                device=self.llava.device, dtype=self.llava.dtype),
            labels=data_sample['labels'][None].to(self.llava.device))
        attention_mask = torch.ones(inputs['input_ids'].shape,
                                    device=self.llava.device,
                                    dtype=torch.bool)
        meta_data = data_sample['meta_data']
        with torch.no_grad():
            outputs = self.llava(**inputs,
                                 attention_mask=attention_mask,
                                 output_hidden_states=True,
                                 output_attentions=True)
        mask_ids = outputs['mask_ids'][0]
        # Keep only the attention weights directed at image tokens.
        attentions = [attn[0, ..., outputs['image_to_overwrite'][0]]
                      for attn in outputs.attentions]
        hidden_states = outputs.hidden_states[
            -self.llava.config.text_config.num_hidden_layers:]
        labels = outputs.labels[0]
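        # Mix the text hidden states across layers with the learned softmax
        # weights; this mixture is the only trainable signal taken from the
        # otherwise frozen VLM and yields one embedding per text token.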
        # num_layers, seq_len, dim
        hidden_states = torch.stack([hs[0] for hs in hidden_states])
        # seq_len, dim
        hidden_states = (
            hidden_states * text_layer_weights.view(-1, 1, 1)).sum(0)
        del outputs
        padded_h, padded_w = (meta_data['padded_shape']['height'],
                              meta_data['padded_shape']['width'])
        llava_h, llava_w = padded_h // self.patch_size, padded_w // self.patch_size

        # Reshape the flat image-token axis back into a 2D patch grid.
        attentions = [attn.view(*attn.shape[:-1], llava_h, llava_w)
                      for attn in attentions]
        masks = data_sample['masks']
        mask_attentions = []
        text_embeds = []
        for mask_id in range(len(masks)):
            matched = mask_ids == mask_id
            assert matched.sum() > 0
            # Merge the attention maps of all text tokens belonging to this
            # mask, then concatenate over layers (one channel per head).
            mask_attentions.append(torch.cat(
                [self.apply_merge(attn[:, matched], dim=1)
                 for attn in attentions]))
            # Project the matched token embeddings into SAM's prompt space.
            text_embeds.append(self.text_proj(hidden_states[matched]))
        del attentions
        mask_attentions = torch.stack(mask_attentions).to(self.mask_head.dtype)
        # if self.training:
        #     mask_attentions.requires_grad = True
        pred_masks = self.mask_head(mask_attentions)[:, 0]
        # Un-pad pred_masks back to the original image's aspect ratio.
        padded_mask_h, padded_mask_w = pred_masks.shape[-2:]
        before_height = int(
            meta_data['padding']['before_height'] * padded_mask_h / padded_h)
        before_width = int(
            meta_data['padding']['before_width'] * padded_mask_w / padded_w)
        mask_h = int(
            meta_data['image_shape']['height'] * padded_mask_h / padded_h + 0.5)
        mask_w = int(
            meta_data['image_shape']['width'] * padded_mask_w / padded_w + 0.5)
        pred_masks = pred_masks[:, before_height:before_height + mask_h,
                                before_width:before_width + mask_w].contiguous()
        # Refine the coarse masks with SAM, prompted by the text embeddings.
        sam_pred_masks = self.sam(data_sample['image'], pred_masks, text_embeds)
        output = dict(pred_masks=pred_masks,
                      sam_pred_masks=sam_pred_masks,
                      labels=labels,
                      mask_ids=mask_ids,
                      hidden_states=hidden_states)
        return output

    @torch.no_grad()
    def predict(self, data_sample):
        return self._forward(data_sample)['sam_pred_masks']

    def compute_loss(self, data):
        mask_cnts = 0

        loss_dice = 0
        loss_mask = 0
        accuracy = 0
        aiou = 0

        sam_loss_dice = 0
        sam_loss_mask = 0
        sam_accuracy = 0
        sam_aiou = 0

        for data_sample in data:
            forward_output = self._forward(data_sample)
            pred_masks = forward_output['pred_masks']
            sam_pred_masks = forward_output['sam_pred_masks']
            masks = data_sample['masks'].to(self.llava.device)
            # Resize ground-truth masks to each prediction's resolution.
            gt_masks = F.interpolate(
                masks[None].float(),
                size=pred_masks.shape[-2:])[0].to(pred_masks)
            sam_gt_masks = F.interpolate(
                masks[None].float(),
                size=sam_pred_masks.shape[-2:])[0].to(sam_pred_masks)
            mask_cnt = pred_masks.shape[0]
            assert pred_masks.shape == gt_masks.shape
            mask_cnts += mask_cnt

            # Weight per-sample results by mask count so the totals below
            # average over masks rather than over samples.
            loss_dice_, loss_mask_, accuracy_, aiou_ = self._compute(
                pred_masks, gt_masks)
            loss_dice += loss_dice_ * mask_cnt
            loss_mask += loss_mask_ * mask_cnt
            accuracy += accuracy_ * mask_cnt
            aiou += aiou_ * mask_cnt

            sam_loss_dice_, sam_loss_mask_, sam_accuracy_, sam_aiou_ = \
                self._compute(sam_pred_masks, sam_gt_masks)
            sam_loss_dice += sam_loss_dice_ * mask_cnt
            sam_loss_mask += sam_loss_mask_ * mask_cnt
            sam_accuracy += sam_accuracy_ * mask_cnt
            sam_aiou += sam_aiou_ * mask_cnt

        assert mask_cnts > 0
        loss_dict = {'loss_mask': loss_mask / mask_cnts,
                     'loss_dice': loss_dice / mask_cnts,
                     'accuracy': accuracy / mask_cnts,
                     'aiou': aiou / mask_cnts,
                     'sam_loss_mask': sam_loss_mask / mask_cnts,
                     'sam_loss_dice': sam_loss_dice / mask_cnts,
                     'sam_accuracy': sam_accuracy / mask_cnts,
                     'sam_aiou': sam_aiou / mask_cnts}
        return loss_dict
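
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the inner config contents are hypothetical
# placeholders, not APIs defined in this module). With xtuner, `model`,
# `mask_head`, `sam`, `loss_mask` and `loss_dice` arrive as config dicts that
# BUILDER resolves into modules:
#
#     model_cfg = dict(
#         type=FrozenLlavaSAM,
#         model=dict(...),        # frozen LLaVA config (HuggingFace-style)
#         mask_head=dict(...),    # mask-head config; `in_channels` is filled
#                                 # in automatically by FrozenLlava.__init__
#         sam=dict(...),          # SAM wrapper exposing .model.prompt_encoder
#         loss_mask=dict(...),    # e.g. a sigmoid BCE loss module
#         loss_dice=dict(...),    # e.g. a dice loss module
#     )
#     frozen_llava_sam = BUILDER.build(model_cfg)
#     loss_dict = frozen_llava_sam(data, mode='loss')   # data: list of samples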