|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from xtuner.registry import BUILDER |
|
|
from mmengine.model import BaseModel |
|
|
from xtuner.model.utils import guess_load_checkpoint |
|
|
from .utils import compute_mask_IoU |
|
|
|
|
|
|
|
|
class FrozenLlava(BaseModel):
    """Segmentation model built on a frozen LLaVA backbone.

    The LLaVA model is built from config and fully frozen; only the mask
    head and a learnable per-layer weighting of the text decoder's hidden
    states are trainable.  Attention maps from every (decoder layer,
    attention head) pair are stacked channel-wise and decoded into masks
    by ``mask_head``.

    Args:
        model: Config for the LLaVA model (built via ``BUILDER``).
        mask_head: Config dict for the mask head; its ``in_channels`` is
            filled in automatically from the text config.
        merge (str): How to merge attention over matched text tokens;
            one of ``'mean'`` or ``'max'``.
        loss_mask: Config for the per-pixel mask loss.  Required in
            practice despite the ``None`` default — ``BUILDER.build(None)``
            would fail.
        loss_dice: Config for the dice loss (required, same caveat).
        pretrained (str | None): Optional checkpoint path, loaded
            non-strictly so a partial (trainable-parts-only) checkpoint
            is accepted.
    """

    def __init__(self,
                 model,
                 mask_head,
                 merge='mean',
                 loss_mask=None,
                 loss_dice=None,
                 pretrained=None,
                 **kwargs):
        super().__init__()
        # Validate eagerly with a real exception: a bare `assert` would be
        # stripped under `python -O`, and failing *before* the expensive
        # BUILDER.build calls gives a faster, clearer error.
        if merge not in ('mean', 'max'):
            raise ValueError(f"merge must be 'mean' or 'max', got {merge!r}")
        self.llava = BUILDER.build(model)
        self.llava.requires_grad_(False)  # backbone is fully frozen
        # One mask-head input channel per (decoder layer, attention head).
        in_channels = (self.llava.config.text_config.num_attention_heads *
                       self.llava.config.text_config.num_hidden_layers)
        mask_head.update(in_channels=in_channels)
        self.mask_head = BUILDER.build(mask_head)
        self.patch_size = self.llava.config.vision_config.patch_size
        self.merge = merge

        self.loss_mask = BUILDER.build(loss_mask)
        self.loss_dice = BUILDER.build(loss_dice)

        # Learnable logits for mixing the decoder layers' hidden states;
        # normalized with softmax in `get_text_layer_weights`.
        self.text_layer_weights = nn.Parameter(
            torch.ones(self.llava.config.text_config.num_hidden_layers))

        if pretrained is not None:
            # Non-strict load: the checkpoint may cover only a subset of
            # the parameters (e.g. just the trainable modules).
            _ = self.load_state_dict(
                guess_load_checkpoint(pretrained), strict=False)

    def get_text_layer_weights(self):
        """Return softmax-normalized mixing weights over decoder layers."""
        return torch.softmax(self.text_layer_weights, dim=0)

    def apply_merge(self, x, dim=1):
        """Reduce ``x`` along ``dim`` with the configured merge strategy."""
        if self.merge == 'mean':
            return x.mean(dim=dim)
        elif self.merge == 'max':
            return x.max(dim=dim).values
        else:
            raise NotImplementedError(f'Unsupported merge mode: {self.merge}')

    def init_weights(self):
        # Intentionally a no-op: weights come from the built sub-modules and
        # the optional `pretrained` checkpoint, so mmengine's default weight
        # initialization must be skipped.
        pass

    def train(self, mode=True):
        """Switch train/eval mode while keeping the LLaVA backbone in eval.

        The frozen backbone is forced back to eval mode after the regular
        recursive ``train`` call so that e.g. dropout stays disabled in it.
        """
        super().train(mode=mode)
        self.llava.train(mode=False)
        self.training = mode
        return self

    def forward(self, data, data_samples=None, mode='loss'):
        """Dispatch to loss computation, prediction, or raw tensor forward."""
        if mode == 'loss':
            return self.compute_loss(data)
        elif mode == 'predict':
            return self.predict(data)
        elif mode == 'tensor':
            return self._forward(data)
        else:
            raise NotImplementedError(f'Unsupported mode: {mode}')

    def _compute(self, pred_masks, gt_masks):
        """Compute losses and metrics for a batch of predicted masks.

        Args:
            pred_masks: Mask logits, shape ``(N, H, W)``.
            gt_masks: Ground-truth masks of the same shape (binary values).

        Returns:
            Tuple ``(loss_dice, loss_mask, accuracy, aiou)``; accuracy and
            aiou are detached metrics, not losses.
        """
        mask_cnt = pred_masks.shape[0]
        loss_dice = self.loss_dice(
            pred_masks.view(mask_cnt, -1), gt_masks.view(mask_cnt, -1),
            avg_factor=mask_cnt)
        loss_mask = self.loss_mask(
            pred_masks.view(-1),
            gt_masks.view(-1),
            avg_factor=pred_masks.numel())
        # Binarize predictions once (original computed this twice) at the
        # standard 0.5 probability threshold; detached so metrics do not
        # contribute gradients.
        binary_pred = (pred_masks.detach().sigmoid() > 0.5).to(gt_masks)
        accuracy = torch.eq(binary_pred, gt_masks).to(gt_masks).mean()
        # Mean IoU over the batch — presumably compute_mask_IoU returns a
        # per-mask IoU vector; confirm against .utils.
        aiou = compute_mask_IoU(binary_pred.view(mask_cnt, -1),
                                gt_masks.view(mask_cnt, -1)).mean()

        return loss_dice, loss_mask, accuracy, aiou
|
|
|
|
|
|
|
|
class FrozenLlavaSAM(FrozenLlava):
    """FrozenLlava variant that refines the coarse masks with a SAM module.

    In addition to the base pipeline, LLaVA hidden states are projected into
    SAM's prompt-embedding space and SAM produces a second, refined set of
    masks; losses are computed on both.
    """

    def __init__(self, sam, *args, **kwargs):
        # Pop `pretrained` so the base class does NOT load the checkpoint
        # before `self.sam` / `self.text_proj` exist; it is loaded below
        # once all sub-modules are registered.
        pretrained = kwargs.pop('pretrained', None)
        super().__init__(*args, **kwargs)
        self.sam = BUILDER.build(sam)
        # Projects LLaVA text hidden states into SAM's prompt-embedding dim.
        self.text_proj = nn.Linear(self.llava.config.text_config.hidden_size,
                                   self.sam.model.prompt_encoder.embed_dim)

        if pretrained is not None:
            # Non-strict: checkpoint may cover only part of the model.
            _ = self.load_state_dict(
                guess_load_checkpoint(pretrained), strict=False)

    def _forward(self, data_sample):
        """Run one sample through LLaVA, the mask head, and SAM.

        Args:
            data_sample (dict): Single sample with keys ``input_ids``,
                ``mask_ids``, ``pixel_values``, ``labels``, ``masks``,
                ``image`` and ``meta_data`` (padding / shape info).

        Returns:
            dict with ``pred_masks`` (mask-head output, cropped to the
            unpadded image region), ``sam_pred_masks`` (SAM-refined masks),
            ``labels``, ``mask_ids`` and the layer-mixed ``hidden_states``.
        """
        text_layer_weights = self.get_text_layer_weights()
        # Add a leading batch dim of 1 and move everything onto the
        # backbone's device/dtype.
        inputs = dict(input_ids=data_sample['input_ids'][None].to(self.llava.device),
                      mask_ids=data_sample['mask_ids'][None].to(
                          self.llava.device),
                      pixel_values=data_sample['pixel_values'][None].to(device=self.llava.device,
                                                                        dtype=self.llava.dtype),
                      labels=data_sample['labels'][None].to(self.llava.device)
                      )
        attention_mask = torch.ones(inputs['input_ids'].shape, device=self.llava.device,
                                    dtype=torch.bool)
        meta_data = data_sample['meta_data']
        # The backbone is frozen, so run it without building a graph.
        with torch.no_grad():
            outputs = self.llava(**inputs,
                                 attention_mask=attention_mask,
                                 output_hidden_states=True,
                                 output_attentions=True)
        mask_ids = outputs['mask_ids'][0]
        # Keep only attention onto the image-token positions
        # (`image_to_overwrite` marks them — presumably a boolean mask over
        # the sequence; confirm against the llava wrapper).
        attentions = [attn[0, ..., outputs['image_to_overwrite'][0]]
                      for attn in outputs.attentions]
        # Drop the embedding-layer output, keep one state per decoder layer.
        hidden_states = outputs.hidden_states[-self.llava.config.text_config.num_hidden_layers:]

        labels = outputs.labels[0]

        # (num_layers, seq_len, hidden) after removing the batch dim.
        hidden_states = torch.stack([hs[0] for hs in hidden_states])

        # Mix layers with the learned softmax weights -> (seq_len, hidden).
        hidden_states = (
            hidden_states * text_layer_weights.view(-1, 1, 1)).sum(0)

        # Free the large transformer outputs before the mask-head work.
        del outputs

        padded_h, padded_w = meta_data['padded_shape']['height'], meta_data['padded_shape']['width']
        # Image-token grid size in ViT patches.
        llava_h, llava_w = padded_h // self.patch_size, padded_w // self.patch_size

        # Unflatten the image-token axis back into a 2-D patch grid.
        attentions = [attn.view(*attn.shape[:-1], llava_h, llava_w)
                      for attn in attentions]
        masks = data_sample['masks']
        mask_attentions = []
        text_embeds = []
        for mask_id in range(len(masks)):
            # Text tokens assigned to this mask; every mask must be
            # referenced by at least one token.
            matched = mask_ids == mask_id
            assert matched.sum() > 0
            # Merge over matched tokens, then concat across layers so the
            # channel dim is num_layers * num_heads (the mask head's input).
            mask_attentions.append(torch.cat(
                [self.apply_merge(attn[:, matched], dim=1) for attn in attentions]))
            # Matched hidden states become SAM prompt embeddings.
            text_embeds.append(self.text_proj(hidden_states[matched]))

        del attentions
        mask_attentions = torch.stack(mask_attentions).to(self.mask_head.dtype)

        # Mask head emits a single-channel logit map per mask.
        pred_masks = self.mask_head(mask_attentions)[:, 0]

        padded_mask_h, padded_mask_w = pred_masks.shape[-2:]

        # Scale the padding offsets from padded-image space into the mask
        # head's output resolution.
        before_height = int(
            meta_data['padding']['before_height'] * padded_mask_h / padded_h)
        before_width = int(
            meta_data['padding']['before_width'] * padded_mask_w / padded_w)

        # Rounded size of the valid (unpadded) image region at mask scale.
        mask_h = int(meta_data['image_shape']['height']
                     * padded_mask_h / padded_h + 0.5)
        mask_w = int(meta_data['image_shape']['width']
                     * padded_mask_w / padded_w + 0.5)
        # Crop the padding away before handing the masks to SAM.
        pred_masks \
            = pred_masks[:, before_height:before_height + mask_h, before_width:before_width + mask_w].contiguous()
        sam_pred_masks = self.sam(
            data_sample['image'], pred_masks, text_embeds)

        output = dict(pred_masks=pred_masks, sam_pred_masks=sam_pred_masks,
                      labels=labels, mask_ids=mask_ids, hidden_states=hidden_states)

        return output

    @torch.no_grad()
    def predict(self, data_sample):
        """Inference entry point: return only the SAM-refined masks."""
        return self._forward(data_sample)['sam_pred_masks']

    def compute_loss(self, data):
        """Accumulate dice/mask losses and metrics over a list of samples.

        Both the coarse mask-head predictions and the SAM-refined masks are
        supervised; per-sample terms are weighted by their mask count and
        normalized by the total number of masks.

        Args:
            data (list[dict]): Samples as consumed by ``_forward``.

        Returns:
            dict: Losses (``loss_mask``, ``loss_dice``, ``sam_loss_mask``,
            ``sam_loss_dice``) and metrics (``accuracy``, ``aiou`` and their
            ``sam_`` counterparts).
        """
        mask_cnts = 0

        loss_dice = 0
        loss_mask = 0
        accuracy = 0
        aiou = 0

        sam_loss_dice = 0
        sam_loss_mask = 0
        sam_accuracy = 0
        sam_aiou = 0

        for data_sample in data:
            forward_output = self._forward(data_sample)
            pred_masks, sam_pred_masks = forward_output['pred_masks'], forward_output['sam_pred_masks']
            masks = data_sample['masks'].to(self.llava.device)
            # Resize GT to each prediction's resolution for the loss.
            gt_masks = F.interpolate(masks[None].float(),
                                     size=pred_masks.shape[-2:])[0].to(pred_masks)
            sam_gt_masks = F.interpolate(masks[None].float(),
                                         size=sam_pred_masks.shape[-2:])[0].to(sam_pred_masks)

            mask_cnt = pred_masks.shape[0]
            assert pred_masks.shape == gt_masks.shape
            mask_cnts += mask_cnt

            # Weight per-sample averages back by mask count so the final
            # division by `mask_cnts` yields a per-mask average.
            loss_dice_, loss_mask_, accuracy_, aiou_ = self._compute(
                pred_masks, gt_masks)
            loss_dice += loss_dice_ * mask_cnt
            loss_mask += loss_mask_ * mask_cnt
            accuracy += accuracy_ * mask_cnt
            aiou += aiou_ * mask_cnt

            sam_loss_dice_, sam_loss_mask_, sam_accuracy_, sam_aiou_ = self._compute(
                sam_pred_masks, sam_gt_masks)
            sam_loss_dice += sam_loss_dice_ * mask_cnt
            sam_loss_mask += sam_loss_mask_ * mask_cnt
            sam_accuracy += sam_accuracy_ * mask_cnt
            sam_aiou += sam_aiou_ * mask_cnt

        assert mask_cnts > 0

        loss_dict = {'loss_mask': loss_mask / mask_cnts,
                     'loss_dice': loss_dice / mask_cnts,
                     'accuracy': accuracy / mask_cnts,
                     'aiou': aiou / mask_cnts,
                     'sam_loss_mask': sam_loss_mask / mask_cnts,
                     'sam_loss_dice': sam_loss_dice / mask_cnts,
                     'sam_accuracy': sam_accuracy / mask_cnts,
                     'sam_aiou': sam_aiou / mask_cnts,
                     }

        return loss_dict
|
|
|