Spaces:
Build error
Build error
| import torch | |
| import numpy.typing as npt | |
| import torch.nn.functional as F | |
| from app.configs import DEVICE | |
| from app.mobile_sam import SamPredictor | |
| from .model import point_selection, MaskWeights | |
| from .loss import calculate_dice_loss, calculate_sigmoid_focal_loss | |
def train(
    predictor: SamPredictor,
    ref_images: list[npt.NDArray],
    ref_masks: list[npt.NDArray],
    lr: float = 1e-3,
    epochs: int = 200,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fit the mask-fusion weights on a set of reference image/mask pairs.

    For each reference pair the image is encoded with ``predictor``, a target
    feature is built from the masked part of the feature map (average of the
    per-channel max and mean, L2-normalised), and the location with the
    highest cosine similarity to that feature is kept as a point prompt.
    A ``MaskWeights`` module is then optimised with AdamW so that the weighted
    sum of the decoder's multi-mask logits matches each reference mask under
    dice + sigmoid focal loss.

    Args:
        predictor: Predictor used to encode the images and run the decoder.
            Its ``features``, ``input_size`` and ``original_size`` attributes
            are overwritten while training; on return they hold the state of
            the last reference image.
        ref_images: Reference images, one per entry of ``ref_masks``.
        ref_masks: Reference masks; entries ``> 0`` are treated as foreground.
            Presumably 2-D ``(H, W)`` arrays -- they are indexed with
            ``[:, :, None]`` before being handed to ``predictor.get_mask``.
        lr: Learning rate for AdamW.
        epochs: Number of passes over the reference pairs; also used as
            ``T_max`` of the cosine learning-rate schedule.

    Returns:
        ``(weights, target_feat)`` where ``weights`` is the concatenation
        along dim 0 of ``1 - sum(learned weights)`` and the learned weights,
        and ``target_feat`` is the mean over all references of the normalised
        per-reference target features.

    Raises:
        ValueError: If ``ref_images`` and ``ref_masks`` differ in length or
            are empty.
    """
    if len(ref_images) != len(ref_masks):
        raise ValueError(
            "ref_images and ref_masks must have the same length, "
            f"got {len(ref_images)} and {len(ref_masks)}"
        )
    if len(ref_images) == 0:
        raise ValueError("at least one reference image/mask pair is required")

    gt_masks = []
    points = []
    target_feats = []
    # Encoder state per reference, so the decoder can later be run against
    # the image that each ground-truth mask actually belongs to.
    image_states = []
    for ref_image, ref_mask in zip(ref_images, ref_masks):
        gt_mask = torch.from_numpy(ref_mask) > 0
        gt_mask = gt_mask.float().unsqueeze(0).flatten(1).to(DEVICE)
        gt_masks.append(gt_mask)

        # Image features encoding
        predictor.set_image(ref_image)
        ref_mask = predictor.get_mask(ref_mask[:, :, None])
        # Snapshot what `predict` / `postprocess_masks` depend on.  The clone
        # guards against an encoder that reuses its output buffer.
        # NOTE(review): assumes these three attributes are the complete
        # per-image state of this SamPredictor -- confirm against app.mobile_sam.
        image_states.append(
            (
                predictor.features.clone(),
                predictor.input_size,
                predictor.original_size,
            )
        )
        ref_feat = predictor.features.squeeze().permute(1, 2, 0)
        ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0:2], mode="bilinear")
        ref_mask = ref_mask.squeeze()

        # Target feature extraction: blend of per-channel max and mean over
        # the feature vectors that fall inside the mask.
        target_feat = ref_feat[ref_mask > 0]
        target_feat_mean = target_feat.mean(0)
        target_feat_max = torch.max(target_feat, dim=0)[0]
        target_feat = (target_feat_max / 2 + target_feat_mean / 2).unsqueeze(0)

        # Cosine similarity between the target feature and every location
        h, w, C = ref_feat.shape
        target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True)
        target_feats.append(target_feat)
        ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True)
        ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w)
        sim = target_feat @ ref_feat
        sim = sim.reshape(1, 1, h, w)
        sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
        sim = predictor.model.postprocess_masks(
            sim, input_size=predictor.input_size, original_size=predictor.original_size
        ).squeeze()

        # Positive location prior
        topk_xy, topk_label = point_selection(sim, topk=1)
        points.append((topk_xy, topk_label))

    target_feat = torch.cat(target_feats, dim=0).mean(dim=0)

    # Learnable mask weights
    mask_weights = MaskWeights().to(DEVICE)
    mask_weights.train()
    optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=lr, eps=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    for _ in range(epochs):
        for gt_mask, (topk_xy, topk_label), state in zip(
            gt_masks, points, image_states
        ):
            # Point the predictor back at this reference's encoding; without
            # this every sample would be decoded from the last image that was
            # passed to `set_image`.
            predictor.features, predictor.input_size, predictor.original_size = state

            # Run the decoder
            logits_high, _, _ = predictor.predict(
                point_coords=topk_xy,
                point_labels=topk_label,
                multimask_output=True,
                return_logits=True,
                return_numpy=False,
            )
            logits_high = logits_high.flatten(1)

            # Weighted sum of the multi-mask outputs.  The first weight is
            # defined as 1 - sum(learned weights), so the weights sum to one.
            weights = torch.cat(
                (1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights),
                dim=0,
            )
            logits_high = logits_high * weights
            logits_high = logits_high.sum(0).unsqueeze(0)

            dice_loss = calculate_dice_loss(logits_high, gt_mask)
            focal_loss = calculate_sigmoid_focal_loss(logits_high, gt_mask, alpha=1.0)
            loss = dice_loss + focal_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # One scheduler step per epoch: the cosine schedule was created with
        # T_max=epochs, so stepping per sample would overrun it whenever there
        # is more than one reference.
        scheduler.step()

    mask_weights.eval()
    weights = torch.cat(
        (1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0
    )
    return weights, target_feat