Upload 13 files
Browse files
- .gitattributes +1 -0
- LICENSE +21 -0
- README.md +47 -3
- docs/MCP-MedSAM.png +3 -0
- infer.py +738 -0
- modality_npz_dataset.py +317 -0
- models/__init__.py +4 -0
- models/common.py +44 -0
- models/lite_medsam.py +54 -0
- models/mask_decoder.py +465 -0
- models/prompt_encoder.py +306 -0
- models/tiny_vit.py +645 -0
- models/transformer.py +243 -0
- train.py +502 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+docs/MCP-MedSAM.png filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Leo-Lyu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,3 +1,47 @@
# MCP-MedSAM

PyTorch implementation of the paper:
"[MCP-MedSAM: A Powerful Lightweight Medical Segment Anything Model Trained with a Single GPU in Just One Day](https://arxiv.org/abs/2412.05888)"



## 📄 Overview

This work proposes a lightweight variant of MedSAM by integrating:

- A **pre-trained Tiny ViT** as the vision backbone
- Two novel prompt types:
  - **Modality Prompt**
  - **Content Prompt**
- A **modified mask decoder** adapted to these prompts

To further improve performance across imaging modalities, we introduce a **modality-aware data sampling strategy** that ensures better balance and generalization.

With these enhancements, our model achieves strong multi-modality segmentation performance and can be trained in approximately **1 day on a single A100 (40GB)** GPU.

<!--
We are currently releasing the inference code along with the model weights. You can download them from [here](https://drive.google.com/drive/folders/1NW4aSNhk-dtiK-dicTAUp0g0eR2fryNi?usp=sharing).

The training code has also been released, so you can train your own model. -->

## Requirements

* Python==3.10.14
* torch==2.0.0
* torchvision==0.15.0
* transformers==4.49.0

## Training and Inference

Training and inference can be done by running train.py and infer.py. Additionally, we also release the model weights for inference, which can be downloaded from [here](https://drive.google.com/drive/folders/1NW4aSNhk-dtiK-dicTAUp0g0eR2fryNi?usp=sharing).
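
As a quick illustration inferred from the inference script in this commit (not an official example), the snippet below prepares one 2D case in the `.npz` layout that `infer.py` expects: the file name encodes the modality as its second underscore-separated token, the image is stored under `imgs` as `(H, W, 3)` uint8, and prompt boxes under `boxes`. Directory names and the checkpoint path are placeholders.

```python
# Hypothetical example: build a single 2D X-ray input case for infer.py.
import os
import numpy as np

os.makedirs("inputs", exist_ok=True)
image = np.random.randint(0, 256, size=(512, 512, 3), dtype=np.uint8)  # (H, W, 3), values in [0, 255]
boxes = np.array([[100, 120, 300, 340]])  # one [x_min, y_min, x_max, y_max] box per target
np.savez_compressed("inputs/2D_XRay_demo.npz", imgs=image, boxes=boxes)

# Run inference on the folder (paths are placeholders); predictions are saved
# as <case>.npz files containing a `segs` array, plus an efficiency.csv log:
#   python infer.py -i inputs -o outputs \
#       -lite_medsam_checkpoint_path work_dir/medsam_lite.pth \
#       --save_overlay -png_save_dir overlays
```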

## Citation

```bibtex
@article{lyu2024mcp,
  title={MCP-MedSAM: A Powerful Lightweight Medical Segment Anything Model Trained with a Single GPU in Just One Day},
  author={Lyu, Donghang and Gao, Ruochen and Staring, Marius},
  journal={arXiv preprint arXiv:2412.05888},
  year={2024}
}
```
docs/MCP-MedSAM.png
ADDED
Git LFS Details
infer.py
ADDED
@@ -0,0 +1,738 @@
from os import makedirs
from os.path import join, basename
from glob import glob
from tqdm import tqdm
from time import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from models import PromptEncoder, TwoWayTransformer, TinyViT, MaskDecoder_F4
from matplotlib import pyplot as plt
import cv2
import argparse
from collections import OrderedDict
import pandas as pd
from datetime import datetime
from transformers import CLIPModel, CLIPTokenizer

torch.set_float32_matmul_precision('high')
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)

parser = argparse.ArgumentParser()

parser.add_argument(
    '-i',
    '--input_dir',
    type=str,
    default='',
    # required=True,
    help='root directory of the data',
)
parser.add_argument(
    '-o',
    '--output_dir',
    type=str,
    default='',
    help='directory to save the prediction',
)
parser.add_argument(
    '-lite_medsam_checkpoint_path',
    type=str,
    default="",
    help='path to the checkpoint of MedSAM-Lite',
)
parser.add_argument(
    '-device',
    type=str,
    default="cuda:0",
    help='device to run the inference',
)
parser.add_argument(
    '-num_workers',
    type=int,
    default=4,
    help='number of workers for inference with multiprocessing',
)
parser.add_argument(
    '--save_overlay',
    default=False,
    action='store_true',
    help='whether to save the overlay image'
)

parser.add_argument(
    '-png_save_dir',
    type=str,
    default=None,
    help='directory to save the overlay image'
)

args = parser.parse_args()

data_root = args.input_dir
pred_save_dir = args.output_dir
save_overlay = args.save_overlay
num_workers = args.num_workers

if save_overlay:
    assert args.png_save_dir is not None, "Please specify the directory to save the overlay image"
    png_save_dir = args.png_save_dir
    makedirs(png_save_dir, exist_ok=True)

lite_medsam_checkpoint_path = args.lite_medsam_checkpoint_path
makedirs(pred_save_dir, exist_ok=True)
device = torch.device(args.device)
image_size = 256
model1 = CLIPModel.from_pretrained("flaviagiammarino/pubmed-clip-vit-base-patch32", resume_download=True)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16", resume_download=True)
model1.requires_grad_(False)


def resize_longest_side(image, target_length=256):
    """
    Resize image to target_length while keeping the aspect ratio
    Expects a numpy array with shape HxWxC in uint8 format.
    """
    oldh, oldw = image.shape[0], image.shape[1]
    scale = target_length * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    neww, newh = int(neww + 0.5), int(newh + 0.5)
    target_size = (neww, newh)

    return cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)

def pad_image(image, target_size=256):
    """
    Pad image to target_size
    Expects a numpy array with shape HxWxC in uint8 format.
    """
    # Pad
    h, w = image.shape[0], image.shape[1]
    padh = target_size - h
    padw = target_size - w
    if len(image.shape) == 3: ## Pad image
        image_padded = np.pad(image, ((0, padh), (0, padw), (0, 0)))
    else: ## Pad gt mask
        image_padded = np.pad(image, ((0, padh), (0, padw)))

    return image_padded

class MedSAM_Lite(nn.Module):
    def __init__(
        self,
        image_encoder,
        mask_decoder,
        prompt_encoder
    ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder

    def forward(self, image, points, boxes, masks, features, crops, text_features, category_idx):
        image_embedding = self.image_encoder(image)
        with torch.no_grad():
            boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device)
            if len(boxes.shape) == 2:
                boxes = boxes[:, None, :] # (B, 1, 4)

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=masks,
            features=features,
            crops=crops,
            text_features = text_features,
            category_idx=category_idx
        )
        low_res_masks, iou_predictions, category_predictions, clip_vec, img_vec = self.mask_decoder(
            image_embeddings=image_embedding, # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
            multimask_output=False,
        ) # (B, 1, 256, 256)

        return low_res_masks

    @torch.no_grad()
    def postprocess_masks(self, masks, new_size, original_size):
        """
        Do cropping and resizing

        Parameters
        ----------
        masks : torch.Tensor
            masks predicted by the model
        new_size : tuple
            the shape of the image after resizing to the longest side of 256
        original_size : tuple
            the original shape of the image

        Returns
        -------
        torch.Tensor
            the upsampled mask to the original size
        """
        # Crop
        masks = masks[..., :new_size[0], :new_size[1]]
        # Resize
        masks = F.interpolate(
            masks,
            size=(original_size[0], original_size[1]),
            mode="bilinear",
            align_corners=False,
        )

        return masks


def show_mask(mask, ax, mask_color=None, alpha=0.5):
    """
    show mask on the image

    Parameters
    ----------
    mask : numpy.ndarray
        mask of the image
    ax : matplotlib.axes.Axes
        axes to plot the mask
    mask_color : numpy.ndarray
        color of the mask
    alpha : float
        transparency of the mask
    """
    if mask_color is not None:
        color = np.concatenate([mask_color, np.array([alpha])], axis=0)
    else:
        color = np.array([251/255, 252/255, 30/255, alpha])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax, edgecolor='blue'):
    """
    show bounding box on the image

    Parameters
    ----------
    box : numpy.ndarray
        bounding box coordinates in the original image
    ax : matplotlib.axes.Axes
        axes to plot the bounding box
    edgecolor : str
        color of the bounding box
    """
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=edgecolor, facecolor=(0,0,0,0), lw=2))

def show_points(points, ax):
    points = points.numpy()
    for i, (x, y) in enumerate(points):
        ax.scatter(x, y, color='yellow', s=15)

def get_bbox256(mask_256, bbox_shift=3):
    """
    Get the bounding box coordinates from the mask (256x256)

    Parameters
    ----------
    mask_256 : numpy.ndarray
        the mask of the resized image

    bbox_shift : int
        Add perturbation to the bounding box coordinates

    Returns
    -------
    numpy.ndarray
        bounding box coordinates in the resized image
    """
    y_indices, x_indices = np.where(mask_256 > 0)
    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)
    # add perturbation to bounding box coordinates and test the robustness
    # this can be removed if you do not want to test the robustness
    H, W = mask_256.shape
    x_min = max(0, x_min - bbox_shift)
    x_max = min(W, x_max + bbox_shift)
    y_min = max(0, y_min - bbox_shift)
    y_max = min(H, y_max + bbox_shift)

    bboxes256 = np.array([x_min, y_min, x_max, y_max])

    return bboxes256

def resize_box_to_256(box, original_size):
    """
    the input bounding box is obtained from the original image
    here, we rescale it to the coordinates of the resized image

    Parameters
    ----------
    box : numpy.ndarray
        bounding box coordinates in the original image
    original_size : tuple
        the original size of the image

    Returns
    -------
    numpy.ndarray
        bounding box coordinates in the resized image
    """
    new_box = np.zeros_like(box)
    ratio = 256 / max(original_size)
    for i in range(len(box)):
        new_box[i] = int(box[i] * ratio)

    return new_box, ratio


def get_points_256(box, gt2D):
    gt2D = np.mean(gt2D, axis=-1)
    if len(box)==1:
        x_min, y_min, x_max, y_max = box[0]
    else:
        x_min, y_min, x_max, y_max = box

    try:
        bounder_shiftx = np.random.randint(int((x_max-x_min)/5), int(2*(x_max-x_min)/5), (1,))
        # bounder_shiftx = int((x_max-x_min)/5)
    except:
        bounder_shiftx = 0
    try:
        bounder_shifty = np.random.randint(int((y_max-y_min)/5), int(2*(y_max-y_min)/5), (1,))
        # bounder_shifty = int((y_max-y_min)/5)
    except:
        bounder_shifty = 0

    mid_x = int((x_min+x_max)//2)
    mid_y = int((y_min+y_max)//2)
    x_min = int(x_min+bounder_shiftx)
    x_max = int(x_max-bounder_shiftx)
    y_min = int(y_min+bounder_shifty)
    y_max = int(y_max-bounder_shifty)
    cl = [[y_min, mid_y, x_min, mid_x], [mid_y,y_max,x_min,mid_x], [mid_y,y_max, mid_x,x_max], [y_min,mid_y, mid_x,x_max]]

    coords = []
    for i in range(4):
        gt2D_tmp = np.zeros((256, 256))
        gt2D_tmp[cl[i][0]:cl[i][1], cl[i][2]:cl[i][3]] = gt2D[cl[i][0]:cl[i][1], cl[i][2]:cl[i][3]]
        y_indices, x_indices = np.where(gt2D_tmp > 0)
        if y_indices.size==0:
            coords.append([mid_x, mid_y])
        else:
            x_point = np.random.choice(x_indices)
            y_point = np.random.choice(y_indices)
            coords.append([x_point, y_point])
    coords = np.array(coords).reshape(4, 2)
    coords = torch.tensor(coords).float()
    return coords

def get_points_256_v0(box, gt2D):
    gt2D = np.mean(gt2D, axis=-1)
    if len(box)==1:
        x_min, y_min, x_max, y_max = box[0]
    else:
        x_min, y_min, x_max, y_max = box
    mid_x = int((x_min+x_max)//2)
    mid_y = int((y_min+y_max)//2)
    try:
        bounder_shiftx = np.random.randint(int((x_max-x_min)/3), int(2*(x_max-x_min)/4)-1, (1,))
        # bounder_shiftx = 0
    except:
        bounder_shiftx = 0
    try:
        bounder_shifty = np.random.randint(int((y_max-y_min)/3), int(2*(y_max-y_min)/4)-1, (1,))
        # bounder_shifty = 0
    except:
        bounder_shifty = 0
    x_min = int(x_min+bounder_shiftx)
    x_max = int(x_max-bounder_shiftx)
    y_min = int(y_min+bounder_shifty)
    y_max = int(y_max-bounder_shifty)
    # cl = [[y_min, mid_y, x_min, mid_x], [mid_y,y_max,x_min,mid_x], [mid_y,y_max, mid_x,x_max], [y_min,mid_y, mid_x,x_max]]

    coords = []
    gt2D_tmp = np.zeros((256, 256))
    gt2D_tmp[y_min:y_max, x_min:x_max] = gt2D[y_min:y_max, x_min:x_max]
    for i in range(4):
        y_indices, x_indices = np.where(gt2D_tmp > 0)
        if y_indices.size==0:
            coords.append([mid_x, mid_y])
        else:
            x_point = np.random.choice(x_indices)
            y_point = np.random.choice(y_indices)
            coords.append([x_point, y_point])
    coords = np.array(coords).reshape(4, 2)
    coords = torch.tensor(coords).float()
    return coords

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_256, features, crops, text_features, category_idx, new_size, original_size):
    """
    Perform inference using the LiteMedSAM model.

    Args:
        medsam_model (MedSAMModel): The MedSAM model.
        img_embed (torch.Tensor): The image embeddings.
        box_256 (numpy.ndarray): The bounding box coordinates.
        new_size (tuple): The new size of the image.
        original_size (tuple): The original size of the image.
    Returns:
        tuple: A tuple containing the segmented image and the intersection over union (IoU) score.
    """
    box_torch = torch.as_tensor(box_256[None, None, ...], dtype=torch.float, device=img_embed.device)
    features = features.unsqueeze(0).to(device)
    crops = crops.unsqueeze(0).to(device)
    category_idx = torch.tensor([category_idx]).to(device)
    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
        features=features,
        crops=crops,
        text_features = text_features,
        category_idx=category_idx
    )

    low_res_logits, iou, _, _, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed, # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
        multimask_output=False
    )

    low_res_pred = medsam_model.postprocess_masks(low_res_logits, new_size, original_size)
    low_res_pred = torch.sigmoid(low_res_pred)
    low_res_pred = low_res_pred.squeeze().cpu().numpy()
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg, iou

medsam_lite_image_encoder = TinyViT(
    img_size=256,
    in_chans=3,
    embed_dims=[
        64, ## (64, 256, 256)
        128, ## (128, 128, 128)
        160, ## (160, 64, 64)
        320 ## (320, 64, 64)
    ],
    depths=[2, 2, 6, 2],
    num_heads=[2, 4, 5, 10],
    window_sizes=[7, 7, 14, 7],
    mlp_ratio=4.,
    drop_rate=0.,
    drop_path_rate=0.0,
    use_checkpoint=False,
    mbconv_expand_ratio=4.0,
    local_conv_size=3,
    layer_lr_decay=0.8
)

medsam_lite_prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(256, 256),
    mask_in_chans=16
)

medsam_lite_mask_decoder = MaskDecoder_F4(
    num_multimask_outputs=3,
    transformer=TwoWayTransformer(
        depth=2,
        embedding_dim=256,
        mlp_dim=2048,
        num_heads=8,
    ),
    modality=True,
    contents=True,
    transformer_dim=256,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
)


medsam_lite_model = MedSAM_Lite(
    image_encoder = medsam_lite_image_encoder,
    mask_decoder = medsam_lite_mask_decoder,
    prompt_encoder = medsam_lite_prompt_encoder
)

lite_medsam_checkpoint = torch.load(lite_medsam_checkpoint_path, map_location='cpu')
medsam_lite_model.load_state_dict(lite_medsam_checkpoint["model"])
medsam_lite_model.to(device)
medsam_lite_model.eval()


def m2_pre_img(image_data, image_size=224):
    transform1 = transforms.Compose([
        transforms.ToTensor(), # normalize to [0.0,1.0]
        transforms.Resize([image_size, image_size], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)
        ]
    )

    resize_img_torch = transform1(image_data)
    return resize_img_torch

def get_contents(img, box):
    if len(box)==1:
        x_mino, y_mino, x_maxo, y_maxo = box[0]
    else:
        x_mino, y_mino, x_maxo, y_maxo = box
    crops = img[y_mino:y_maxo,x_mino:x_maxo,:]
    crops_128 = m2_pre_img(crops, image_size=64)
    crops_224 = m2_pre_img(crops)
    crops_224 = crops_224.unsqueeze(0)
    with torch.no_grad():
        image_features = model1.get_image_features(crops_224)
    return crops_128, image_features

def get_text_features(modality_text):

    text_token = tokenizer(modality_text, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
    with torch.no_grad():
        text_features = model1.get_text_features(text_token)
    return text_features


def get_category(idx):
    categories_map = {
        "CT": 0,
        "MR": 1,
        "Endoscopy": 2,
        "XRay": 3,
        "X-Ray": 3,
        "PET": 4,
        "Dermoscopy": 5,
        "Mammography": 6,
        "Mammo": 6,
        "US": 7,
        "OCT": 8,
        "Fundus": 9,
        "Microscopy": 10,
        "Microscope": 10
    }
    return categories_map[idx]

def change_name(name):
    if name=="Microscope":
        name = "Microscopy"
    return name

def MedSAM_infer_npz_2D(img_npz_file):
    npz_name = basename(img_npz_file)
    c_name = change_name(npz_name.split('_')[1])
    modality_text = f"{c_name} Image"
    category_idx = get_category(c_name)
    npz_data = np.load(img_npz_file, 'r', allow_pickle=True) # (H, W, 3)
    img_3c = npz_data['imgs'] # (H, W, 3)
    assert np.max(img_3c)<256, f'input data should be in range [0, 255], but got {np.unique(img_3c)}'
    H, W = img_3c.shape[:2]
    boxes = npz_data['boxes']
    segs = np.zeros(img_3c.shape[:2], dtype=np.uint8)
    text_features = get_text_features(modality_text)
    text_features = torch.tensor(text_features).unsqueeze(0).to(device)

    ## preprocessing
    img_256 = resize_longest_side(img_3c, 256)
    newh, neww = img_256.shape[:2]
    img_256_norm = (img_256 - img_256.min()) / np.clip(
        img_256.max() - img_256.min(), a_min=1e-8, a_max=None
    )
    img_256_padded = pad_image(img_256_norm, 256)
    img_256_tensor = torch.tensor(img_256_padded).float().permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        image_embedding = medsam_lite_model.image_encoder(img_256_tensor)

    for idx, box in enumerate(boxes, start=1):
        crops, features = get_contents(img_3c, box)
        box256, ratio = resize_box_to_256(box, original_size=(H, W))
        box256 = box256[None, ...] # (1, 4)
        medsam_mask, iou_pred = medsam_inference(medsam_lite_model, image_embedding, box256, features, crops, text_features, category_idx, (newh, neww), (H, W))
        segs[medsam_mask>0] = idx%256
        # print(f'{npz_name}, box: {box}, predicted iou: {np.round(iou_pred.item(), 4)}')

    np.savez_compressed(
        join(pred_save_dir, npz_name),
        segs=segs,
    )

    # visualize image, mask and bounding box
    if save_overlay and "Microscope" not in npz_name:
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].imshow(img_3c)
        ax[1].imshow(img_3c)
        ax[0].set_title("Image")
        ax[1].set_title("LiteMedSAM Segmentation")
        ax[0].axis('off')
        ax[1].axis('off')

        for i, box in enumerate(boxes):
            color = np.random.rand(3)
            box_viz = box
            show_box(box_viz, ax[1], edgecolor=color)
            # show_points(points[i], ax[1])
            show_mask((segs == i+1).astype(np.uint8), ax[1], mask_color=color)

        plt.tight_layout()
        plt.savefig(join(png_save_dir, npz_name.split(".")[0] + '.png'), dpi=300)
        plt.close()


def MedSAM_infer_npz_3D(img_npz_file):
    npz_name = basename(img_npz_file)
    c_name = change_name(npz_name.split('_')[1])
    modality_text = f"{c_name} Image"
    category_idx = get_category(c_name)
    npz_data = np.load(img_npz_file, 'r', allow_pickle=True)
    img_3D = npz_data['imgs'] # (D, H, W)
    # not used in this demo because it treats each slice independently
    # spacing = npz_data['spacing']
    segs = np.zeros_like(img_3D, dtype=np.uint8)
    boxes_3D = npz_data['boxes'] # [[x_min, y_min, z_min, x_max, y_max, z_max]]
    text_features = get_text_features(modality_text)
    text_features = torch.tensor(text_features).unsqueeze(0).to(device)

    for idx, box3D in enumerate(boxes_3D, start=1):
        segs_3d_temp = np.zeros_like(img_3D, dtype=np.uint8)
        x_min, y_min, z_min, x_max, y_max, z_max = box3D
        assert z_min < z_max, f"z_min should be smaller than z_max, but got {z_min=} and {z_max=}"
        mid_slice_bbox_2d = np.array([x_min, y_min, x_max, y_max])
        z_middle = int((z_max - z_min)/2 + z_min)

        # infer from middle slice to the z_max
        # print(npz_name, 'infer from middle slice to the z_max')
        for z in range(z_middle, z_max):
            img_2d = img_3D[z, :, :]
            if len(img_2d.shape) == 2:
                img_3c = np.repeat(img_2d[:, :, None], 3, axis=-1)
            else:
                img_3c = img_2d
            H, W, _ = img_3c.shape

            img_256 = resize_longest_side(img_3c, 256)
            new_H, new_W = img_256.shape[:2]

            img_256 = (img_256 - img_256.min()) / np.clip(
                img_256.max() - img_256.min(), a_min=1e-8, a_max=None
            ) # normalize to [0, 1], (H, W, 3)
            ## Pad image to 256x256
            img_256 = pad_image(img_256)

            # convert the shape to (3, H, W)
            img_256_tensor = torch.tensor(img_256).float().permute(2, 0, 1).unsqueeze(0).to(device)
            # get the image embedding
            with torch.no_grad():
                image_embedding = medsam_lite_model.image_encoder(img_256_tensor) # (1, 256, 64, 64)
            if z == z_middle:
                crops, features = get_contents(img_3c, mid_slice_bbox_2d)
                box_256, _ = resize_box_to_256(mid_slice_bbox_2d, original_size=(H, W))
            else:
                pre_seg = segs_3d_temp[z-1, :, :]
                if np.max(pre_seg) > 0:
                    box_original = get_bbox256(pre_seg)
                    crops, features = get_contents(img_3c, box_original)
                    pre_seg256 = resize_longest_side(pre_seg)
                    pre_seg256 = pad_image(pre_seg256)
                    box_256 = get_bbox256(pre_seg256)
                else:
                    crops, features = get_contents(img_3c, mid_slice_bbox_2d)
                    box_256, _ = resize_box_to_256(mid_slice_bbox_2d, original_size=(H, W))
            img_2d_seg, iou_pred = medsam_inference(medsam_lite_model, image_embedding, box_256, features, crops, text_features, category_idx, [new_H, new_W], [H, W])
            segs_3d_temp[z, img_2d_seg>0] = idx

        # infer from middle slice to the z_max
        # print(npz_name, 'infer from middle slice to the z_min')
        for z in range(z_middle-1, z_min, -1):
            img_2d = img_3D[z, :, :]
            if len(img_2d.shape) == 2:
                img_3c = np.repeat(img_2d[:, :, None], 3, axis=-1)
            else:
                img_3c = img_2d
            H, W, _ = img_3c.shape

            img_256 = resize_longest_side(img_3c)
            new_H, new_W = img_256.shape[:2]

            img_256 = (img_256 - img_256.min()) / np.clip(
                img_256.max() - img_256.min(), a_min=1e-8, a_max=None
            ) # normalize to [0, 1], (H, W, 3)
            ## Pad image to 256x256
            img_256 = pad_image(img_256)

            img_256_tensor = torch.tensor(img_256).float().permute(2, 0, 1).unsqueeze(0).to(device)
            # get the image embedding
            with torch.no_grad():
                image_embedding = medsam_lite_model.image_encoder(img_256_tensor) # (1, 256, 64, 64)

            pre_seg = segs_3d_temp[z+1, :, :]
            # pre_seg = segs[z+1, :, :]
            if np.max(pre_seg) > 0:
                box_original = get_bbox256(pre_seg)
                crops, features = get_contents(img_3c, box_original)
                pre_seg256 = resize_longest_side(pre_seg)
                pre_seg256 = pad_image(pre_seg256)
                box_256 = get_bbox256(pre_seg256)
            else:
                crops, features = get_contents(img_3c, mid_slice_bbox_2d)
                scale_256 = 256 / max(H, W)
                box_256 = mid_slice_bbox_2d * scale_256
            img_2d_seg, iou_pred = medsam_inference(medsam_lite_model, image_embedding, box_256, features, crops, text_features, category_idx, [new_H, new_W], [H, W])
            segs_3d_temp[z, img_2d_seg>0] = idx
        segs[segs_3d_temp>0] = idx
    np.savez_compressed(
        join(pred_save_dir, npz_name),
        segs=segs,
    )

    # visualize image, mask and bounding box
    if save_overlay and "Microscope" not in npz_name:
        idx = int(segs.shape[0] / 2)
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].imshow(img_3D[idx], cmap='gray')
        ax[1].imshow(img_3D[idx], cmap='gray')
        ax[0].set_title("Image")
        ax[1].set_title("LiteMedSAM Segmentation")
        ax[0].axis('off')
        ax[1].axis('off')

        for i, box3D in enumerate(boxes_3D, start=1):
            if np.sum(segs[idx]==i) > 0:
                color = np.random.rand(3)
                x_min, y_min, z_min, x_max, y_max, z_max = box3D
                box_viz = np.array([x_min, y_min, x_max, y_max])
                show_box(box_viz, ax[1], edgecolor=color)
                show_mask(segs[idx]==i, ax[1], mask_color=color)

        plt.tight_layout()
        plt.savefig(join(png_save_dir, npz_name.split(".")[0] + '.png'), dpi=300)
        plt.close()


if __name__ == '__main__':

    img_npz_files = sorted(glob(join(data_root, '*.npz'), recursive=True))
    efficiency = OrderedDict()
    efficiency['case'] = []
    efficiency['time'] = []
    for img_npz_file in tqdm(img_npz_files):
        start_time = time()
        if basename(img_npz_file).startswith('3D'):
            MedSAM_infer_npz_3D(img_npz_file)
        else:
            MedSAM_infer_npz_2D(img_npz_file)
        end_time = time()
        efficiency['case'].append(basename(img_npz_file))
        efficiency['time'].append(end_time - start_time)
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # print(current_time, 'file name:', basename(img_npz_file), 'time cost:', np.round(end_time - start_time, 4))
    efficiency_df = pd.DataFrame(efficiency)
    efficiency_df.to_csv(join(pred_save_dir, 'efficiency.csv'), index=False)
modality_npz_dataset.py
ADDED
@@ -0,0 +1,317 @@
import numpy as np
import matplotlib.pyplot as plt
import os
from torchvision import transforms
from torch.utils.data import Dataset
import torch
import cv2
from transformers import CLIPModel, CLIPTokenizer
from os.path import join, exists, isfile, isdir, basename
import random

join = os.path.join
import json


def reshape_MR(img):

    original_shape = img.shape
    sorted_axes = np.argsort(original_shape)
    new_img = img.transpose(sorted_axes)

    return new_img

class ModalityNpzDataset(Dataset):
    def __init__(self,
                 data_root,
                 points=True,
                 contents=True,
                 image_size=256,
                 bbox_shift=5,
                 data_aug=True):

        self.data_root = data_root


        json_data = json.load(open("case_data.json", "r"))
        self.file_paths = json_data

        assert len(self.file_paths) == 11

        self.image_size = image_size
        self.target_length = image_size
        self.bbox_shift = bbox_shift
        self.data_aug = data_aug
        self.points = points
        self.contents = contents

        self.categories_map = {
            "CT": 0,
            "MR": 1,
            "Endoscopy": 2,
            "XRay": 3,
            "X-Ray": 3,
            "PET": 4,
            "Dermoscopy": 5,
            "Mammography": 6,
            "Mammo": 6,
            "US": 7,
            "OCT": 8,
            "Fundus": 9,
            "Microscopy": 10,
            "Microscope": 10
        }

        self.model1 = CLIPModel.from_pretrained("flaviagiammarino/pubmed-clip-vit-base-patch32")
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16")
        self.model1.requires_grad_(False)



    def show_box(self, box, ax):
        x0, y0 = box[0], box[1]
        w, h = box[2] - box[0], box[3] - box[1]
        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2))

    def vis(self, image, bboxes, title):
        _, axs = plt.subplots(1, 2, figsize=(10, 10))

        axs[0].imshow(image, cmap="gray")
        self.show_box(bboxes, axs[0])
        axs[0].axis('off')
        axs[0].set_title(title)

        plt.subplots_adjust(wspace=0.01, hspace=0)
        plt.savefig(
            "test.png",
            bbox_inches='tight',
            dpi=300
        )
        plt.close()

    def vis_crop(self, image, title):

        plt.imshow(np.transpose(image, (1,2,0)))
        plt.axis('off')
        plt.title(title)

        plt.savefig(
            "test.png",
            bbox_inches='tight',
            dpi=300
        )
        plt.close()

    def __getitem__(self, index):
        #! add the random index

        modality_map = [
            "CT",
            "MR",
            "Endoscopy",
            "X-ray",
            "PET",
            "Dermoscopy",
            "Mammography",
            "US",
            "OCT",
            "Fundus",
            "Microscopy"
        ]
        modality_index = random.randint(0, 10)
        index = random.randint(0, len(self.file_paths[modality_map[modality_index]])-1)
        file_path = self.file_paths[modality_map[modality_index]][index][0]
        temp = '/'.join(file_path.split('/')[7:])
        file_path = self.data_root+'/'+temp


        npz = np.load(file_path, 'r', allow_pickle=True)
        img_name = basename(file_path)

        mt = img_name.split("_")[0]
        if mt=="2D" or mt=="3D":
            mt = img_name.split("_")[1]
        category_text = f"{mt} Image"
        category_idx = self.categories_map[mt]
        gts = npz["gts"]
        img = npz["imgs"]

        # special case for MR_totalseg
        if "MR_totalseg" in img_name:
            img = reshape_MR(img)
            gts = reshape_MR(gts)
            if img.shape[1] <=100:
                return self.__getitem__(random.randint(0,len(self)-1))

        if len(gts.shape) > 2: ## 3D image
            i=random.randint(0,gts.shape[0]-1)
            img = img[i, :, :]
            gts = gts[i, :, :]
            img_3c = np.repeat(img[:, :, None], 3, axis=-1) # (H, W, 3)
            img_resized = self.resize_longest_side(img_3c)
        else:
            if len(img.shape) < 3:
                img_3c = np.repeat(img[:, :, None], 3, axis=-1)
            else:
                img_3c = img
            img_resized = self.resize_longest_side(img_3c)
            gts = np.uint16(gts)

        # Resizing
        img_resized = (img_resized - img_resized.min()) / np.clip(img_resized.max() - img_resized.min(), a_min=1e-8, a_max=None) # normalize to [0, 1], (H, W, 3)
        img_padded = self.pad_image(img_resized) #self.pad_image(img_resize) # (256, 256, 3)
        # convert the shape to (3, H, W)
        img_padded = np.transpose(img_padded, (2, 0, 1)) # (3, 256, 256)
        assert np.max(img_padded)<=1.0 and np.min(img_padded)>=0.0, 'image should be normalized to [0, 1]'

        label_ids = np.unique(gts)
        label_ids = label_ids.tolist()

        try:
            label_ids.remove(0)
            label_id = random.choice(label_ids)
            gt2D_original = np.uint8(gts == label_id)
            gt = cv2.resize(
                gt2D_original,
                (img_resized.shape[1], img_resized.shape[0]),
                interpolation=cv2.INTER_NEAREST
            ).astype(np.uint8)
            gt2D = self.pad_image(gt)

        except:
            return self.__getitem__(random.randint(0,len(self)-1))


        box_original = self.get_bbox(gt2D_original)
        x_mino, y_mino, x_maxo, y_maxo = box_original

        if self.data_aug:
            if random.random() > 0.5:
                img_padded = np.ascontiguousarray(np.flip(img_padded, axis=-1))
                gt2D = np.ascontiguousarray(np.flip(gt2D, axis=-1))
            if random.random() > 0.5:
                img_padded = np.ascontiguousarray(np.flip(img_padded, axis=-2))
                gt2D = np.ascontiguousarray(np.flip(gt2D, axis=-2))

        try:
            gt2D = np.uint8(gt2D > 0)
            y_indices, x_indices = np.where(gt2D > 0)
            x_min, x_max = np.min(x_indices), np.max(x_indices)
            y_min, y_max = np.min(y_indices), np.max(y_indices)
            H, W = gt2D.shape
            x_min = max(0, x_min - random.randint(0, self.bbox_shift))
            x_max = min(W, x_max + random.randint(0, self.bbox_shift))
            y_min = max(0, y_min - random.randint(0, self.bbox_shift))
            y_max = min(H, y_max + random.randint(0, self.bbox_shift))
            bboxes = np.array([x_min, y_min, x_max, y_max])
        except:
            return self.__getitem__(random.randint(0,len(self)-1))

        if self.points:
            mid_x = (x_min+x_max)//2
            mid_y = (y_min+y_max)//2
            cl = [[y_min, mid_y, x_min, mid_x], [mid_y,y_max,x_min,mid_x], [mid_y,y_max, mid_x,x_max], [y_min,mid_y, mid_x,x_max]]
            coords = []
            for i in range(4):
                gt2D_tmp = np.zeros((H, W))
                gt2D_tmp[cl[i][0]:cl[i][1], cl[i][2]:cl[i][3]] = gt2D[cl[i][0]:cl[i][1], cl[i][2]:cl[i][3]]
                y_indices, x_indices = np.where(gt2D_tmp > 0)
                if y_indices.size==0:
                    coords.append([mid_x, mid_y])
                else:
                    x_point = np.random.choice(x_indices)
                    y_point = np.random.choice(y_indices)
                    coords.append([x_point, y_point])
            coords = np.array(coords).reshape(4, 2)
            coords = torch.tensor(coords).float()
        else:
            coords = None

        if self.contents:
            try:
                crops = img_3c[y_mino:y_maxo,x_mino:x_maxo,:]
                crops_64 = self.m2_pre_img(crops, image_size=64) # change here for the size of cropped part
                crops_224 = self.m2_pre_img(crops)
            except:
                crops_64 = torch.zeros((3, 64, 64))
                crops_224 = torch.zeros((3, 224, 224))
            crops_224 = crops_224.unsqueeze(0)
            text_token = self.tokenizer(category_text, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
            with torch.no_grad():
                image_features = self.model1.get_image_features(crops_224)
                text_features = self.model1.get_text_features(text_token)
        else:
            crops_64 = None
            image_features = None
            text_features = None


        return {
            "image": torch.tensor(img_padded).float(),
            "gt2D": torch.tensor(gt2D[None, :,:]).long(),
            "coords": coords,
            "bboxes": torch.tensor(bboxes[None, None, ...]).float(),
            "image_crop": crops_64.float(),
            "image_feature": image_features.float(),
            "text_feature": text_features.float(),
            "category_idx": category_idx,
            "image_name": img_name,
            "new_size": torch.tensor(np.array([img_padded.shape[0], img_padded.shape[1]])).long(),
            "original_size": torch.tensor(np.array([img_3c.shape[0], img_3c.shape[1]])).long()
        }

    def __len__(self):
        return 108714

    def get_bbox(self, mask_256, bbox_shift=5):
        y_indices, x_indices = np.where(mask_256 > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        H, W = mask_256.shape
        x_min = max(0, x_min - random.randint(0, bbox_shift))
        x_max = min(W, x_max + random.randint(0, bbox_shift))
        y_min = max(0, y_min - random.randint(0, bbox_shift))
        y_max = min(H, y_max + random.randint(0, bbox_shift))

        bboxes256 = np.array([x_min, y_min, x_max, y_max])

        return bboxes256

    def m2_pre_img(self, image_data, image_size=224):
        transform1 = transforms.Compose([
            transforms.ToTensor(), # normalize to [0.0,1.0]
            transforms.Resize([image_size, image_size], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)
            ]
        )

        resize_img_torch = transform1(image_data)
        return resize_img_torch

    def resize_longest_side(self, image):
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        long_side_length = self.target_length
        oldh, oldw = image.shape[0], image.shape[1]
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww, newh = int(neww + 0.5), int(newh + 0.5)
        target_size = (neww, newh)

        return cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)

    def pad_image(self, image):
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        # Pad
        h, w = image.shape[0], image.shape[1]
        padh = self.image_size - h
        padw = self.image_size - w
        if len(image.shape) == 3: ## Pad image
            image_padded = np.pad(image, ((0, padh), (0, padw), (0, 0)))
        else: ## Pad gt mask
            image_padded = np.pad(image, ((0, padh), (0, padw)))

        return image_padded
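
The dataset class above builds its per-modality index from a `case_data.json` file whose layout is not shown in this commit; the sketch below is a hypothetical reconstruction inferred from how `__getitem__` reads it, with placeholder paths and one demo case per modality. Each referenced `.npz` file is expected to contain `imgs` and `gts` arrays.

```python
# Hypothetical sketch of case_data.json as read by ModalityNpzDataset.
# Eleven keys (one per entry of modality_map); each value is a list of cases,
# and each case is a list whose first element is the stored .npz path.
# Inside __getitem__ the first seven path components are stripped and the rest
# is re-rooted onto data_root; the file name must follow "<2D|3D>_<Modality>_..."
# with the modality spelled as in categories_map.
import json

modalities = ["CT", "MR", "Endoscopy", "X-ray", "PET", "Dermoscopy",
              "Mammography", "US", "OCT", "Fundus", "Microscopy"]
prefix = "/some/absolute/prefix/kept/by/preprocessing"  # placeholder, 7 components
fname_token = {"X-ray": "XRay"}  # file-name spelling differs from the json key here

case_data = {
    m: [[f"{prefix}/{m}/2D_{fname_token.get(m, m)}_demo_case.npz"]]
    for m in modalities
}
with open("case_data.json", "w") as f:
    json.dump(case_data, f, indent=2)
```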
models/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .mask_decoder import MaskDecoder, MaskDecoder_F4
from .prompt_encoder import PromptEncoder
from .transformer import TwoWayTransformer
from .tiny_vit import TinyViT
models/common.py
ADDED
@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from typing import Type


class MLPBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
models/lite_medsam.py
ADDED
@@ -0,0 +1,54 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from .mask_decoder import MaskDecoder
|
| 6 |
+
from .prompt_encoder import PromptEncoder
|
| 7 |
+
from .transform import TwoWayTransformer
|
| 8 |
+
|
| 9 |
+
class MedSAM_Lite(nn.Module):
|
| 10 |
+
def __init__(self,
|
| 11 |
+
image_encoder,
|
| 12 |
+
mask_decoder,
|
| 13 |
+
prompt_encoder
|
| 14 |
+
):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.image_encoder = image_encoder
|
| 17 |
+
self.mask_decoder = mask_decoder
|
| 18 |
+
self.prompt_encoder = prompt_encoder
|
| 19 |
+
|
| 20 |
+
def forward(self, image, boxes):
|
| 21 |
+
image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
|
| 22 |
+
|
| 23 |
+
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
| 24 |
+
points=None,
|
| 25 |
+
boxes=boxes,
|
| 26 |
+
masks=None,
|
| 27 |
+
)
|
| 28 |
+
low_res_masks, iou_predictions = self.mask_decoder(
|
| 29 |
+
image_embeddings=image_embedding, # (B, 256, 64, 64)
|
| 30 |
+
image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
|
| 31 |
+
sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
|
| 32 |
+
dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
|
| 33 |
+
multimask_output=False,
|
| 34 |
+
) # (B, 1, 256, 256)
|
| 35 |
+
|
| 36 |
+
return low_res_masks, iou_predictions
|
| 37 |
+
|
| 38 |
+
@torch.no_grad()
|
| 39 |
+
def postprocess_masks(self, masks, new_size, original_size):
|
| 40 |
+
"""
|
| 41 |
+
Do cropping and resizing
|
| 42 |
+
"""
|
| 43 |
+
# Crop
|
| 44 |
+
masks = masks[:, :, :new_size[0], :new_size[1]]
|
| 45 |
+
# Resize
|
| 46 |
+
masks = F.interpolate(
|
| 47 |
+
masks,
|
| 48 |
+
size=(original_size[0], original_size[1]),
|
| 49 |
+
mode="bilinear",
|
| 50 |
+
align_corners=False,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
return masks
|
| 54 |
+
|
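postprocess_masks undoes the resize-and-pad preprocessing: it first crops the prediction back to the un-padded size, then resizes it to the raw image resolution. A minimal sketch with dummy sub-modules (postprocess_masks does not touch them; all sizes below are assumptions for illustration):

```python
import torch
import torch.nn as nn
from models.lite_medsam import MedSAM_Lite

# Dummy encoder/decoder stand-ins: only postprocess_masks is exercised here.
model = MedSAM_Lite(image_encoder=nn.Identity(),
                    mask_decoder=nn.Identity(),
                    prompt_encoder=nn.Identity())

low_res_masks = torch.randn(1, 1, 256, 256)   # decoder output on the padded 256x256 grid
new_size = (256, 212)                          # H, W after resizing the raw image (before padding)
original_size = (512, 424)                     # H, W of the raw image
masks = model.postprocess_masks(low_res_masks, new_size, original_size)
print(masks.shape)                             # torch.Size([1, 1, 512, 424])
```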
models/mask_decoder.py
ADDED
|
@@ -0,0 +1,465 @@
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
|
| 12 |
+
from typing import List, Tuple, Type
|
| 13 |
+
|
| 14 |
+
from .common import LayerNorm2d
|
| 15 |
+
from .transformer import TwoWayTransformer
|
| 16 |
+
|
| 17 |
+
class Classifier(nn.Module):
|
| 18 |
+
def __init__(self, in_dim, hid_dim=None, out_dim=None, act=nn.GELU, drop=0.):
|
| 19 |
+
super().__init__()
|
| 20 |
+
out_dim = out_dim or in_dim
|
| 21 |
+
hid_dim = hid_dim or in_dim
|
| 22 |
+
self.fc1 = nn.Linear(in_dim, hid_dim)
|
| 23 |
+
self.act = act()
|
| 24 |
+
self.fc2 = nn.Linear(hid_dim, out_dim)
|
| 25 |
+
self.drop = nn.Dropout(drop)
|
| 26 |
+
|
| 27 |
+
def forward(self, x):
|
| 28 |
+
x = self.fc1(x)
|
| 29 |
+
x = self.act(x)
|
| 30 |
+
x = self.drop(x)
|
| 31 |
+
x = self.fc2(x)
|
| 32 |
+
return x
|
| 33 |
+
|
| 34 |
+
class Block(nn.Module):
|
| 35 |
+
def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
|
| 36 |
+
super(Block, self).__init__()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
|
| 40 |
+
self.batch_norm1 = nn.BatchNorm2d(out_channels)
|
| 41 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
|
| 42 |
+
|
| 43 |
+
self.i_downsample = i_downsample
|
| 44 |
+
self.stride = stride
|
| 45 |
+
self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
| 46 |
+
|
| 47 |
+
def forward(self, x):
|
| 48 |
+
identity = x.clone()
|
| 49 |
+
|
| 50 |
+
x = self.relu(self.batch_norm1(self.conv1(x)))
|
| 51 |
+
x = self.conv2(x)
|
| 52 |
+
|
| 53 |
+
if self.i_downsample is not None:
|
| 54 |
+
identity = self.i_downsample(identity)
|
| 55 |
+
|
| 56 |
+
x += identity
|
| 57 |
+
return x
|
| 58 |
+
|
| 59 |
+
class MaskDecoder(nn.Module):
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
*,
|
| 63 |
+
transformer_dim: int,
|
| 64 |
+
transformer: nn.Module,
|
| 65 |
+
modality,
|
| 66 |
+
contents,
|
| 67 |
+
num_multimask_outputs: int = 3,
|
| 68 |
+
activation: Type[nn.Module] = nn.GELU,
|
| 69 |
+
iou_head_depth: int = 3,
|
| 70 |
+
iou_head_hidden_dim: int = 256,
|
| 71 |
+
category_num = 11
|
| 72 |
+
) -> None:
|
| 73 |
+
"""
|
| 74 |
+
Predicts masks given an image and prompt embeddings, using a
|
| 75 |
+
transformer architecture.
|
| 76 |
+
|
| 77 |
+
Arguments:
|
| 78 |
+
transformer_dim (int): the channel dimension of the transformer
|
| 79 |
+
transformer (nn.Module): the transformer used to predict masks
|
| 80 |
+
num_multimask_outputs (int): the number of masks to predict
|
| 81 |
+
when disambiguating masks
|
| 82 |
+
activation (nn.Module): the type of activation to use when
|
| 83 |
+
upscaling masks
|
| 84 |
+
iou_head_depth (int): the depth of the MLP used to predict
|
| 85 |
+
mask quality
|
| 86 |
+
iou_head_hidden_dim (int): the hidden dimension of the MLP
|
| 87 |
+
used to predict mask quality
|
| 88 |
+
"""
|
| 89 |
+
super().__init__()
|
| 90 |
+
self.transformer_dim = transformer_dim
|
| 91 |
+
self.transformer = transformer
|
| 92 |
+
self.category_num = category_num
|
| 93 |
+
self.modality = modality
|
| 94 |
+
self.contents = contents
|
| 95 |
+
|
| 96 |
+
self.num_multimask_outputs = num_multimask_outputs
|
| 97 |
+
|
| 98 |
+
self.iou_token = nn.Embedding(1, transformer_dim)
|
| 99 |
+
self.num_mask_tokens = num_multimask_outputs + 1
|
| 100 |
+
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
| 101 |
+
|
| 102 |
+
self.convs = Block(transformer_dim, transformer_dim)
|
| 103 |
+
self.w_lin = nn.Linear(transformer_dim, transformer_dim)
|
| 104 |
+
self.b_lin = nn.Linear(transformer_dim, transformer_dim)
|
| 105 |
+
|
| 106 |
+
self.output_upscaling = nn.Sequential(
|
| 107 |
+
nn.ConvTranspose2d(
|
| 108 |
+
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
|
| 109 |
+
),
|
| 110 |
+
LayerNorm2d(transformer_dim // 4),
|
| 111 |
+
activation(),
|
| 112 |
+
nn.ConvTranspose2d(
|
| 113 |
+
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
|
| 114 |
+
),
|
| 115 |
+
activation(),
|
| 116 |
+
)
|
| 117 |
+
self.output_hypernetworks_mlps = nn.ModuleList(
|
| 118 |
+
[
|
| 119 |
+
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
| 120 |
+
for i in range(self.num_mask_tokens)
|
| 121 |
+
]
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
self.iou_prediction_head = MLP(
|
| 125 |
+
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
self.category_prediction_head = Classifier(
|
| 129 |
+
transformer_dim, transformer_dim//4, category_num
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
def forward(
|
| 133 |
+
self,
|
| 134 |
+
image_embeddings: torch.Tensor,
|
| 135 |
+
image_pe: torch.Tensor,
|
| 136 |
+
sparse_prompt_embeddings: torch.Tensor,
|
| 137 |
+
dense_prompt_embeddings: torch.Tensor,
|
| 138 |
+
multimask_output: bool,
|
| 139 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 140 |
+
"""
|
| 141 |
+
Predict masks given image and prompt embeddings.
|
| 142 |
+
|
| 143 |
+
Arguments:
|
| 144 |
+
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
| 145 |
+
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
| 146 |
+
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
| 147 |
+
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
|
| 148 |
+
multimask_output (bool): Whether to return multiple masks or a single
|
| 149 |
+
mask.
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
torch.Tensor: batched predicted masks
|
| 153 |
+
torch.Tensor: batched predictions of mask quality
|
| 154 |
+
"""
|
| 155 |
+
masks, iou_pred, category_pred, clip_tokens_out, image_tokens_out = self.predict_masks(
|
| 156 |
+
image_embeddings=image_embeddings,
|
| 157 |
+
image_pe=image_pe,
|
| 158 |
+
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
| 159 |
+
dense_prompt_embeddings=dense_prompt_embeddings,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Select the correct mask or masks for output
|
| 163 |
+
if multimask_output:
|
| 164 |
+
mask_slice = slice(1, None)
|
| 165 |
+
else:
|
| 166 |
+
mask_slice = slice(0, 1)
|
| 167 |
+
masks = masks[:, mask_slice, :, :]
|
| 168 |
+
iou_pred = iou_pred[:, mask_slice]
|
| 169 |
+
|
| 170 |
+
# Prepare output
|
| 171 |
+
return masks, iou_pred, category_pred, clip_tokens_out, image_tokens_out
|
| 172 |
+
|
| 173 |
+
def predict_masks(
|
| 174 |
+
self,
|
| 175 |
+
image_embeddings: torch.Tensor,
|
| 176 |
+
image_pe: torch.Tensor,
|
| 177 |
+
sparse_prompt_embeddings: torch.Tensor,
|
| 178 |
+
dense_prompt_embeddings: torch.Tensor,
|
| 179 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 180 |
+
"""Predicts masks. See 'forward' for more details."""
|
| 181 |
+
# Concatenate output tokens
|
| 182 |
+
output_tokens = torch.cat(
|
| 183 |
+
[self.iou_token.weight, self.mask_tokens.weight], dim=0
|
| 184 |
+
)
|
| 185 |
+
output_tokens = output_tokens.unsqueeze(0).expand(
|
| 186 |
+
sparse_prompt_embeddings.size(0), -1, -1
|
| 187 |
+
)
|
| 188 |
+
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
| 189 |
+
|
| 190 |
+
# Expand per-image data in batch direction to be per-mask
|
| 191 |
+
if image_embeddings.shape[0] != tokens.shape[0]:
|
| 192 |
+
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
| 193 |
+
else:
|
| 194 |
+
src = image_embeddings
|
| 195 |
+
src = src + dense_prompt_embeddings
|
| 196 |
+
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
| 197 |
+
b, c, h, w = src.shape
|
| 198 |
+
|
| 199 |
+
# Run the transformer
|
| 200 |
+
hs, src = self.transformer(src, pos_src, tokens)
|
| 201 |
+
iou_token_out = hs[:, 0, :]
|
| 202 |
+
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
|
| 203 |
+
|
| 204 |
+
# Upscale mask embeddings and predict masks using the mask tokens
|
| 205 |
+
src = src.transpose(1, 2).view(b, c, h, w)
|
| 206 |
+
if self.contents:
|
| 207 |
+
clip_tokens_out = tokens[:,-2,:]
|
| 208 |
+
image_tokens_out = F.adaptive_avg_pool2d(dense_prompt_embeddings, output_size=(1, 1)).squeeze(-1).squeeze(-1)
|
| 209 |
+
clip_new_out = hs[:,-2,:].unsqueeze(-1).unsqueeze(-1)
|
| 210 |
+
src = dense_prompt_embeddings+src+clip_new_out
|
| 211 |
+
src = self.convs(src)
|
| 212 |
+
else:
|
| 213 |
+
clip_tokens_out = None
|
| 214 |
+
image_tokens_out = None
|
| 215 |
+
|
| 216 |
+
if self.modality:
|
| 217 |
+
category_tokens_out = hs[:,-1,:]
|
| 218 |
+
wc = self.w_lin(category_tokens_out).unsqueeze(-1).unsqueeze(-1)
|
| 219 |
+
bc = self.b_lin(category_tokens_out).unsqueeze(-1).unsqueeze(-1)
|
| 220 |
+
src = wc*src+bc+src
|
| 221 |
+
category_pred = self.category_prediction_head(category_tokens_out)
|
| 222 |
+
else:
|
| 223 |
+
category_pred = None
|
| 224 |
+
|
| 225 |
+
upscaled_embedding = self.output_upscaling(src)
|
| 226 |
+
hyper_in_list: List[torch.Tensor] = []
|
| 227 |
+
for i in range(self.num_mask_tokens):
|
| 228 |
+
hyper_in_list.append(
|
| 229 |
+
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
|
| 230 |
+
)
|
| 231 |
+
hyper_in = torch.stack(hyper_in_list, dim=1)
|
| 232 |
+
b, c, h, w = upscaled_embedding.shape
|
| 233 |
+
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
| 234 |
+
|
| 235 |
+
# Generate mask quality predictions
|
| 236 |
+
iou_pred = self.iou_prediction_head(iou_token_out)
|
| 237 |
+
|
| 238 |
+
return masks, iou_pred, category_pred, clip_tokens_out, image_tokens_out
|
| 239 |
+
|
| 240 |
+
# Lightly adapted from
|
| 241 |
+
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
|
| 242 |
+
class MLP(nn.Module):
|
| 243 |
+
def __init__(
|
| 244 |
+
self,
|
| 245 |
+
input_dim: int,
|
| 246 |
+
hidden_dim: int,
|
| 247 |
+
output_dim: int,
|
| 248 |
+
num_layers: int,
|
| 249 |
+
sigmoid_output: bool = False,
|
| 250 |
+
) -> None:
|
| 251 |
+
super().__init__()
|
| 252 |
+
self.num_layers = num_layers
|
| 253 |
+
h = [hidden_dim] * (num_layers - 1)
|
| 254 |
+
self.layers = nn.ModuleList(
|
| 255 |
+
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
|
| 256 |
+
)
|
| 257 |
+
self.sigmoid_output = sigmoid_output
|
| 258 |
+
|
| 259 |
+
def forward(self, x):
|
| 260 |
+
for i, layer in enumerate(self.layers):
|
| 261 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
| 262 |
+
if self.sigmoid_output:
|
| 263 |
+
x = F.sigmoid(x)
|
| 264 |
+
return x
|
| 265 |
+
|
| 266 |
+
class MaskDecoder_F4(nn.Module):
|
| 267 |
+
def __init__(
|
| 268 |
+
self,
|
| 269 |
+
*,
|
| 270 |
+
transformer_dim: int,
|
| 271 |
+
transformer: nn.Module,
|
| 272 |
+
modality,
|
| 273 |
+
contents,
|
| 274 |
+
num_multimask_outputs: int = 3,
|
| 275 |
+
activation: Type[nn.Module] = nn.GELU,
|
| 276 |
+
iou_head_depth: int = 3,
|
| 277 |
+
iou_head_hidden_dim: int = 256,
|
| 278 |
+
category_num = 11
|
| 279 |
+
) -> None:
|
| 280 |
+
"""
|
| 281 |
+
Predicts masks given an image and prompt embeddings, using a
|
| 282 |
+
transformer architecture.
|
| 283 |
+
|
| 284 |
+
Arguments:
|
| 285 |
+
transformer_dim (int): the channel dimension of the transformer
|
| 286 |
+
transformer (nn.Module): the transformer used to predict masks
|
| 287 |
+
num_multimask_outputs (int): the number of masks to predict
|
| 288 |
+
when disambiguating masks
|
| 289 |
+
activation (nn.Module): the type of activation to use when
|
| 290 |
+
upscaling masks
|
| 291 |
+
iou_head_depth (int): the depth of the MLP used to predict
|
| 292 |
+
mask quality
|
| 293 |
+
iou_head_hidden_dim (int): the hidden dimension of the MLP
|
| 294 |
+
used to predict mask quality
|
| 295 |
+
"""
|
| 296 |
+
super().__init__()
|
| 297 |
+
self.transformer_dim = transformer_dim
|
| 298 |
+
self.transformer = transformer
|
| 299 |
+
self.category_num = category_num
|
| 300 |
+
self.modality = modality
|
| 301 |
+
self.contents = contents
|
| 302 |
+
|
| 303 |
+
self.num_multimask_outputs = num_multimask_outputs
|
| 304 |
+
|
| 305 |
+
self.iou_token = nn.Embedding(1, transformer_dim)
|
| 306 |
+
self.num_mask_tokens = num_multimask_outputs + 1
|
| 307 |
+
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
| 308 |
+
|
| 309 |
+
self.convs = Block(transformer_dim, transformer_dim)
|
| 310 |
+
self.conv1 = nn.Conv2d(transformer_dim*2, transformer_dim, 1)
|
| 311 |
+
self.c_conv = Block(transformer_dim, transformer_dim)
|
| 312 |
+
self.w_lin = nn.Linear(transformer_dim, transformer_dim)
|
| 313 |
+
self.b_lin = nn.Linear(transformer_dim, transformer_dim)
|
| 314 |
+
self.m_conv = Block(transformer_dim, transformer_dim)
|
| 315 |
+
|
| 316 |
+
self.output_upscaling = nn.Sequential(
|
| 317 |
+
nn.ConvTranspose2d(
|
| 318 |
+
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
|
| 319 |
+
),
|
| 320 |
+
LayerNorm2d(transformer_dim // 4),
|
| 321 |
+
activation(),
|
| 322 |
+
nn.ConvTranspose2d(
|
| 323 |
+
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
|
| 324 |
+
),
|
| 325 |
+
activation(),
|
| 326 |
+
)
|
| 327 |
+
self.output_hypernetworks_mlps = nn.ModuleList(
|
| 328 |
+
[
|
| 329 |
+
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
| 330 |
+
for i in range(self.num_mask_tokens)
|
| 331 |
+
]
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
self.iou_prediction_head = MLP(
|
| 335 |
+
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# self.category_prediction_head = Classifier(
|
| 339 |
+
# transformer_dim, transformer_dim//4, category_num
|
| 340 |
+
# )
|
| 341 |
+
self.category_prediction_head = Classifier(
|
| 342 |
+
transformer_dim, transformer_dim//4, category_num
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
def forward(
|
| 346 |
+
self,
|
| 347 |
+
image_embeddings: torch.Tensor,
|
| 348 |
+
image_pe: torch.Tensor,
|
| 349 |
+
sparse_prompt_embeddings: torch.Tensor,
|
| 350 |
+
dense_prompt_embeddings: torch.Tensor,
|
| 351 |
+
multimask_output: bool,
|
| 352 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 353 |
+
"""
|
| 354 |
+
Predict masks given image and prompt embeddings.
|
| 355 |
+
|
| 356 |
+
Arguments:
|
| 357 |
+
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
| 358 |
+
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
| 359 |
+
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
| 360 |
+
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
|
| 361 |
+
multimask_output (bool): Whether to return multiple masks or a single
|
| 362 |
+
mask.
|
| 363 |
+
|
| 364 |
+
Returns:
|
| 365 |
+
torch.Tensor: batched predicted masks
|
| 366 |
+
torch.Tensor: batched predictions of mask quality
|
| 367 |
+
"""
|
| 368 |
+
masks, iou_pred, category_pred, clip_tokens_out, image_tokens_out = self.predict_masks(
|
| 369 |
+
image_embeddings=image_embeddings,
|
| 370 |
+
image_pe=image_pe,
|
| 371 |
+
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
| 372 |
+
dense_prompt_embeddings=dense_prompt_embeddings,
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
# Select the correct mask or masks for output
|
| 376 |
+
if multimask_output:
|
| 377 |
+
mask_slice = slice(1, None)
|
| 378 |
+
else:
|
| 379 |
+
mask_slice = slice(0, 1)
|
| 380 |
+
masks = masks[:, mask_slice, :, :]
|
| 381 |
+
iou_pred = iou_pred[:, mask_slice]
|
| 382 |
+
|
| 383 |
+
# Prepare output
|
| 384 |
+
return masks, iou_pred, category_pred, clip_tokens_out, image_tokens_out
|
| 385 |
+
|
| 386 |
+
def predict_masks(
|
| 387 |
+
self,
|
| 388 |
+
image_embeddings: torch.Tensor,
|
| 389 |
+
image_pe: torch.Tensor,
|
| 390 |
+
sparse_prompt_embeddings: torch.Tensor,
|
| 391 |
+
dense_prompt_embeddings: torch.Tensor,
|
| 392 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 393 |
+
"""Predicts masks. See 'forward' for more details."""
|
| 394 |
+
# Concatenate output tokens
|
| 395 |
+
output_tokens = torch.cat(
|
| 396 |
+
[self.iou_token.weight, self.mask_tokens.weight], dim=0
|
| 397 |
+
)
|
| 398 |
+
output_tokens = output_tokens.unsqueeze(0).expand(
|
| 399 |
+
sparse_prompt_embeddings.size(0), -1, -1
|
| 400 |
+
)
|
| 401 |
+
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
| 402 |
+
m_token = tokens[:,-1,:]
|
| 403 |
+
|
| 404 |
+
# Expand per-image data in batch direction to be per-mask
|
| 405 |
+
if image_embeddings.shape[0] != tokens.shape[0]:
|
| 406 |
+
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
| 407 |
+
else:
|
| 408 |
+
src = image_embeddings
|
| 409 |
+
src = src + dense_prompt_embeddings
|
| 410 |
+
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
| 411 |
+
b, c, h, w = src.shape
|
| 412 |
+
|
| 413 |
+
# Run the transformer
|
| 414 |
+
hs, src = self.transformer(src, pos_src, tokens)
|
| 415 |
+
iou_token_out = hs[:, 0, :]
|
| 416 |
+
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
|
| 417 |
+
|
| 418 |
+
# Upscale mask embeddings and predict masks using the mask tokens
|
| 419 |
+
src = src.transpose(1, 2).view(b, c, h, w)
|
| 420 |
+
|
| 421 |
+
if self.modality:
|
| 422 |
+
category_tokens_out = hs[:,-1,:]
|
| 423 |
+
wc = self.w_lin(category_tokens_out).unsqueeze(-1).unsqueeze(-1)
|
| 424 |
+
bc = self.b_lin(category_tokens_out).unsqueeze(-1).unsqueeze(-1)
|
| 425 |
+
src_m = wc*src+bc+src
|
| 426 |
+
m_info = wc.squeeze(-1).squeeze(-1)+bc.squeeze(-1).squeeze(-1)+category_tokens_out
|
| 427 |
+
category_pred = self.category_prediction_head(m_info)
|
| 428 |
+
src_m = self.m_conv(src_m)
|
| 429 |
+
else:
|
| 430 |
+
category_pred = None
|
| 431 |
+
|
| 432 |
+
if self.contents:
|
| 433 |
+
clip_tokens_out = tokens[:,-2,:]
|
| 434 |
+
image_tokens_out = F.adaptive_avg_pool2d(dense_prompt_embeddings, output_size=(1, 1)).squeeze(-1).squeeze(-1)
|
| 435 |
+
clip_new_out = hs[:,-2,:].unsqueeze(-1).unsqueeze(-1)
|
| 436 |
+
src_vp = dense_prompt_embeddings+src+clip_new_out
|
| 437 |
+
src_vp = self.convs(src_vp)
|
| 438 |
+
else:
|
| 439 |
+
clip_tokens_out = None
|
| 440 |
+
image_tokens_out = None
|
| 441 |
+
|
| 442 |
+
if self.contents and self.modality:
|
| 443 |
+
src = torch.cat((src_m, src_vp), dim=1)
|
| 444 |
+
src = self.conv1(src)
|
| 445 |
+
src = self.c_conv(src)
|
| 446 |
+
elif self.contents:
|
| 447 |
+
src = src_vp
|
| 448 |
+
elif self.modality:
|
| 449 |
+
src = src_m
|
| 450 |
+
|
| 451 |
+
upscaled_embedding = self.output_upscaling(src)
|
| 452 |
+
hyper_in_list: List[torch.Tensor] = []
|
| 453 |
+
for i in range(self.num_mask_tokens):
|
| 454 |
+
hyper_in_list.append(
|
| 455 |
+
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
|
| 456 |
+
)
|
| 457 |
+
hyper_in = torch.stack(hyper_in_list, dim=1)
|
| 458 |
+
b, c, h, w = upscaled_embedding.shape
|
| 459 |
+
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
| 460 |
+
|
| 461 |
+
# Generate mask quality predictions
|
| 462 |
+
iou_pred = self.iou_prediction_head(iou_token_out)
|
| 463 |
+
|
| 464 |
+
return masks, iou_pred, category_pred, clip_tokens_out, image_tokens_out
|
| 465 |
+
|
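The key difference from SAM's decoder is that MaskDecoder_F4 conditions the image features on the two extra prompt tokens: the modality token is turned into a per-channel scale and bias (w_lin / b_lin), and the dense content prompt plus the content token are fused back into the feature map before upscaling, with an auxiliary modality classification head on the side. Below is a minimal sketch with toy tensors; it assumes the SAM-style TwoWayTransformer constructor signature and that the last two sparse prompt tokens are the content token and the modality token, matching how predict_masks indexes hs[:, -2, :] and hs[:, -1, :]:

```python
import torch
from models import MaskDecoder_F4, TwoWayTransformer

# Minimal sketch, not the exact training setup: toy random tensors stand in for real prompts.
decoder = MaskDecoder_F4(
    transformer_dim=256,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048),
    modality=True,
    contents=True,
)

image_embeddings = torch.randn(2, 256, 64, 64)   # image encoder output
image_pe = torch.randn(1, 256, 64, 64)           # dense positional encoding
sparse_prompts = torch.randn(2, 4, 256)          # 2 box corners + content token + modality token
dense_prompts = torch.randn(2, 256, 64, 64)      # dense content prompt

masks, iou_pred, category_pred, clip_tok, img_tok = decoder(
    image_embeddings=image_embeddings,
    image_pe=image_pe,
    sparse_prompt_embeddings=sparse_prompts,
    dense_prompt_embeddings=dense_prompts,
    multimask_output=False,
)
print(masks.shape, iou_pred.shape, category_pred.shape)  # (2, 1, 256, 256) (2, 1) (2, 11)
```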
models/prompt_encoder.py
ADDED
|
@@ -0,0 +1,306 @@
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from typing import Any, Optional, Tuple, Type
|
| 13 |
+
|
| 14 |
+
from .common import LayerNorm2d
|
| 15 |
+
|
| 16 |
+
class PositionEmbeddingRandom(nn.Module):
|
| 17 |
+
"""
|
| 18 |
+
Positional encoding using random spatial frequencies.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
| 22 |
+
super().__init__()
|
| 23 |
+
if scale is None or scale <= 0.0:
|
| 24 |
+
scale = 1.0
|
| 25 |
+
self.register_buffer(
|
| 26 |
+
"positional_encoding_gaussian_matrix",
|
| 27 |
+
scale * torch.randn((2, num_pos_feats)),
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
| 31 |
+
"""Positionally encode points that are normalized to [0,1]."""
|
| 32 |
+
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
| 33 |
+
coords = 2 * coords - 1
|
| 34 |
+
coords = coords @ self.positional_encoding_gaussian_matrix
|
| 35 |
+
coords = 2 * np.pi * coords
|
| 36 |
+
# outputs d_1 x ... x d_n x C shape
|
| 37 |
+
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
| 38 |
+
|
| 39 |
+
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
| 40 |
+
"""Generate positional encoding for a grid of the specified size."""
|
| 41 |
+
h, w = size
|
| 42 |
+
device: Any = self.positional_encoding_gaussian_matrix.device
|
| 43 |
+
grid = torch.ones((h, w), device=device, dtype=torch.float32)
|
| 44 |
+
y_embed = grid.cumsum(dim=0) - 0.5
|
| 45 |
+
x_embed = grid.cumsum(dim=1) - 0.5
|
| 46 |
+
y_embed = y_embed / h
|
| 47 |
+
x_embed = x_embed / w
|
| 48 |
+
|
| 49 |
+
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
| 50 |
+
return pe.permute(2, 0, 1) # C x H x W
|
| 51 |
+
|
| 52 |
+
def forward_with_coords(
|
| 53 |
+
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
|
| 54 |
+
) -> torch.Tensor:
|
| 55 |
+
"""Positionally encode points that are not normalized to [0,1]."""
|
| 56 |
+
coords = coords_input.clone()
|
| 57 |
+
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
| 58 |
+
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
| 59 |
+
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
| 60 |
+
|
| 61 |
+
class Block(nn.Module):
|
| 62 |
+
def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
|
| 63 |
+
super(Block, self).__init__()
|
| 64 |
+
|
| 65 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
|
| 66 |
+
self.batch_norm1 = nn.BatchNorm2d(out_channels)
|
| 67 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
|
| 68 |
+
self.batch_norm2 = nn.BatchNorm2d(out_channels)
|
| 69 |
+
|
| 70 |
+
self.i_downsample = i_downsample
|
| 71 |
+
self.stride = stride
|
| 72 |
+
self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
| 73 |
+
|
| 74 |
+
def forward(self, x):
|
| 75 |
+
identity = x.clone()
|
| 76 |
+
|
| 77 |
+
x = self.relu(self.batch_norm1(self.conv1(x)))
|
| 78 |
+
x = self.batch_norm2(self.conv2(x))
|
| 79 |
+
|
| 80 |
+
if self.i_downsample is not None:
|
| 81 |
+
identity = self.i_downsample(identity)
|
| 82 |
+
|
| 83 |
+
x += identity
|
| 84 |
+
x = self.relu(x)
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
class Crop_Net_New(nn.Module):
|
| 88 |
+
def __init__(self, dim):
|
| 89 |
+
super().__init__()
|
| 90 |
+
self.conv = nn.Conv2d(3, dim, 3, 1, 1)
|
| 91 |
+
|
| 92 |
+
self.conv1 = Block(dim, dim)
|
| 93 |
+
self.conv2 = Block(dim, dim)
|
| 94 |
+
self.conv3 = Block(dim, dim)
|
| 95 |
+
|
| 96 |
+
self.conv4 = nn.Conv2d(dim, dim, 5, 1, 2)
|
| 97 |
+
|
| 98 |
+
def forward(self, x):
|
| 99 |
+
x = self.conv(x)
|
| 100 |
+
x = self.conv1(x)
|
| 101 |
+
x = self.conv2(x)
|
| 102 |
+
x = self.conv3(x)
|
| 103 |
+
return self.conv4(x)
|
| 104 |
+
|
| 105 |
+
class Mlp(nn.Module):
|
| 106 |
+
def __init__(self, in_dim, hid_dim=None, out_dim=None, act=nn.GELU, drop=0.):
|
| 107 |
+
super().__init__()
|
| 108 |
+
out_dim = out_dim or in_dim
|
| 109 |
+
hid_dim = hid_dim or in_dim
|
| 110 |
+
self.fc1 = nn.Linear(in_dim, hid_dim)
|
| 111 |
+
self.act = act()
|
| 112 |
+
self.fc2 = nn.Linear(hid_dim, out_dim)
|
| 113 |
+
self.drop = nn.Dropout(drop)
|
| 114 |
+
|
| 115 |
+
def forward(self, x):
|
| 116 |
+
x = self.fc1(x)
|
| 117 |
+
x = self.act(x)
|
| 118 |
+
x = self.drop(x)
|
| 119 |
+
x = self.fc2(x)
|
| 120 |
+
x = self.drop(x)
|
| 121 |
+
return x
|
| 122 |
+
|
| 123 |
+
class PromptEncoder(nn.Module):
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
embed_dim: int,
|
| 127 |
+
image_embedding_size: Tuple[int, int],
|
| 128 |
+
input_image_size: Tuple[int, int],
|
| 129 |
+
mask_in_chans: int,
|
| 130 |
+
activation: Type[nn.Module] = nn.GELU,
|
| 131 |
+
) -> None:
|
| 132 |
+
"""
|
| 133 |
+
Encodes prompts for input to SAM's mask decoder.
|
| 134 |
+
|
| 135 |
+
Arguments:
|
| 136 |
+
embed_dim (int): The prompts' embedding dimension
|
| 137 |
+
image_embedding_size (tuple(int, int)): The spatial size of the
|
| 138 |
+
image embedding, as (H, W).
|
| 139 |
+
input_image_size (int): The padded size of the image as input
|
| 140 |
+
to the image encoder, as (H, W).
|
| 141 |
+
mask_in_chans (int): The number of hidden channels used for
|
| 142 |
+
encoding input masks.
|
| 143 |
+
activation (nn.Module): The activation to use when encoding
|
| 144 |
+
input masks.
|
| 145 |
+
"""
|
| 146 |
+
super().__init__()
|
| 147 |
+
self.embed_dim = embed_dim
|
| 148 |
+
self.input_image_size = input_image_size
|
| 149 |
+
self.image_embedding_size = image_embedding_size
|
| 150 |
+
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
| 151 |
+
|
| 152 |
+
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
| 153 |
+
point_embeddings = [
|
| 154 |
+
nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
|
| 155 |
+
]
|
| 156 |
+
self.point_embeddings = nn.ModuleList(point_embeddings)
|
| 157 |
+
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
| 158 |
+
|
| 159 |
+
self.mask_input_size = (
|
| 160 |
+
4 * image_embedding_size[0],
|
| 161 |
+
4 * image_embedding_size[1],
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
| 165 |
+
|
| 166 |
+
self.crop_nets = Crop_Net_New(embed_dim)
|
| 167 |
+
|
| 168 |
+
self.clip_img_mlp = Mlp(in_dim=512, hid_dim=256, out_dim=256)
|
| 169 |
+
self.clip_text_mlp = Mlp(in_dim=512, hid_dim=256, out_dim=256)
|
| 170 |
+
self.mlps = Mlp(in_dim=512, hid_dim=512, out_dim=256)
|
| 171 |
+
|
| 172 |
+
self.categories = nn.Embedding(11, 256)
|
| 173 |
+
|
| 174 |
+
def get_dense_pe(self) -> torch.Tensor:
|
| 175 |
+
"""
|
| 176 |
+
Returns the positional encoding used to encode point prompts,
|
| 177 |
+
applied to a dense set of points the shape of the image encoding.
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
torch.Tensor: Positional encoding with shape
|
| 181 |
+
1x(embed_dim)x(embedding_h)x(embedding_w)
|
| 182 |
+
"""
|
| 183 |
+
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
| 184 |
+
|
| 185 |
+
def _embed_points(
|
| 186 |
+
self,
|
| 187 |
+
points: torch.Tensor,
|
| 188 |
+
labels: torch.Tensor,
|
| 189 |
+
pad: bool,
|
| 190 |
+
) -> torch.Tensor:
|
| 191 |
+
"""Embeds point prompts."""
|
| 192 |
+
points = points + 0.5 # Shift to center of pixel
|
| 193 |
+
if pad:
|
| 194 |
+
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
|
| 195 |
+
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
| 196 |
+
points = torch.cat([points, padding_point], dim=1)
|
| 197 |
+
labels = torch.cat([labels, padding_label], dim=1)
|
| 198 |
+
point_embedding = self.pe_layer.forward_with_coords(
|
| 199 |
+
points, self.input_image_size
|
| 200 |
+
)
|
| 201 |
+
point_embedding[labels == -1] = 0.0
|
| 202 |
+
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
| 203 |
+
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
| 204 |
+
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
| 205 |
+
return point_embedding
|
| 206 |
+
|
| 207 |
+
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
| 208 |
+
"""Embeds box prompts."""
|
| 209 |
+
boxes = boxes + 0.5 # Shift to center of pixel
|
| 210 |
+
coords = boxes.reshape(-1, 2, 2)
|
| 211 |
+
corner_embedding = self.pe_layer.forward_with_coords(
|
| 212 |
+
coords, self.input_image_size
|
| 213 |
+
)
|
| 214 |
+
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
| 215 |
+
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
| 216 |
+
return corner_embedding
|
| 217 |
+
|
| 218 |
+
# def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
| 219 |
+
# """Embeds mask inputs."""
|
| 220 |
+
# mask_embedding = self.mask_downscaling(masks)
|
| 221 |
+
# return mask_embedding
|
| 222 |
+
|
| 223 |
+
def _get_batch_size(
|
| 224 |
+
self,
|
| 225 |
+
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 226 |
+
boxes: Optional[torch.Tensor],
|
| 227 |
+
masks: Optional[torch.Tensor],
|
| 228 |
+
) -> int:
|
| 229 |
+
"""
|
| 230 |
+
Gets the batch size of the output given the batch size of the input prompts.
|
| 231 |
+
"""
|
| 232 |
+
if points is not None:
|
| 233 |
+
return points[0].shape[0]
|
| 234 |
+
elif boxes is not None:
|
| 235 |
+
return boxes.shape[0]
|
| 236 |
+
# elif tokens is not None:
|
| 237 |
+
# return tokens.shape[0]
|
| 238 |
+
elif masks is not None:
|
| 239 |
+
return masks.shape[0]
|
| 240 |
+
else:
|
| 241 |
+
return 1
|
| 242 |
+
|
| 243 |
+
def _get_device(self) -> torch.device:
|
| 244 |
+
return self.point_embeddings[0].weight.device
|
| 245 |
+
|
| 246 |
+
def forward(
|
| 247 |
+
self,
|
| 248 |
+
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 249 |
+
boxes: Optional[torch.Tensor],
|
| 250 |
+
masks,
|
| 251 |
+
features,
|
| 252 |
+
crops,
|
| 253 |
+
text_features,
|
| 254 |
+
category_idx
|
| 255 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 256 |
+
"""
|
| 257 |
+
Embeds different types of prompts, returning both sparse and dense
|
| 258 |
+
embeddings.
|
| 259 |
+
|
| 260 |
+
Arguments:
|
| 261 |
+
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
|
| 262 |
+
and labels to embed.
|
| 263 |
+
boxes (torch.Tensor or none): boxes to embed
|
| 264 |
+
masks (torch.Tensor or none): masks to embed
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
torch.Tensor: sparse embeddings for the points and boxes, with shape
|
| 268 |
+
BxNx(embed_dim), where N is determined by the number of input points
|
| 269 |
+
and boxes.
|
| 270 |
+
torch.Tensor: dense embeddings for the masks, in the shape
|
| 271 |
+
Bx(embed_dim)x(embed_H)x(embed_W)
|
| 272 |
+
"""
|
| 273 |
+
bs = self._get_batch_size(points, boxes, masks)
|
| 274 |
+
sparse_embeddings = torch.empty(
|
| 275 |
+
(bs, 0, self.embed_dim), device=self._get_device()
|
| 276 |
+
)
|
| 277 |
+
if points is not None:
|
| 278 |
+
coords, labels = points
|
| 279 |
+
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
| 280 |
+
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
| 281 |
+
|
| 282 |
+
if boxes is not None:
|
| 283 |
+
box_embeddings = self._embed_boxes(boxes)
|
| 284 |
+
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
|
| 285 |
+
|
| 286 |
+
if features is not None:
|
| 287 |
+
clip_embeddings = self.clip_img_mlp(features)
|
| 288 |
+
sparse_embeddings = torch.cat([sparse_embeddings, clip_embeddings], dim=1)
|
| 289 |
+
|
| 290 |
+
if category_idx is not None:
|
| 291 |
+
text_embeddings = self.clip_text_mlp(text_features)
|
| 292 |
+
category_embeddings = torch.zeros((bs, 1, 256)).to(boxes.device)
|
| 293 |
+
for i in range(bs):
|
| 294 |
+
category_embeddings[i,0,:] = self.categories(category_idx[i].long())
|
| 295 |
+
modality_embeddings = torch.cat((text_embeddings, category_embeddings), dim=-1)
|
| 296 |
+
text_embeddings = self.mlps(modality_embeddings)
|
| 297 |
+
sparse_embeddings = torch.cat([sparse_embeddings, text_embeddings], dim=1)
|
| 298 |
+
|
| 299 |
+
if crops is not None:
|
| 300 |
+
dense_embeddings = self.crop_nets(crops)
|
| 301 |
+
else:
|
| 302 |
+
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
| 303 |
+
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
return sparse_embeddings, dense_embeddings
|
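Compared with SAM's prompt encoder, the forward pass here also accepts `features` (a 512-d feature of the box region that becomes an extra sparse content token via clip_img_mlp), `crops` (the resized box crop, encoded by Crop_Net_New into the dense embedding), and `text_features` plus `category_idx` (a 512-d modality text feature combined with one of the 11 learned modality embeddings to form the modality token). A minimal sketch with toy tensors; in the real pipeline the 512-d features are assumed to come from a pretrained CLIP image/text encoder, and the shapes below are illustrative:

```python
import torch
from models import PromptEncoder

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(256, 256),
    mask_in_chans=16,
)

boxes = torch.tensor([[10., 20., 120., 160.]])   # (B, 4) box in padded-image coordinates
clip_image_feat = torch.randn(1, 1, 512)         # stand-in for the CLIP image feature of the box crop
clip_text_feat = torch.randn(1, 1, 512)          # stand-in for the CLIP text feature of the modality name
category_idx = torch.tensor([3])                 # modality index in [0, 10]
crop = torch.randn(1, 3, 64, 64)                 # resized box crop -> dense content prompt

sparse_embeddings, dense_embeddings = prompt_encoder(
    points=None,
    boxes=boxes,
    masks=None,
    features=clip_image_feat,
    crops=crop,
    text_features=clip_text_feat,
    category_idx=category_idx,
)
print(sparse_embeddings.shape, dense_embeddings.shape)  # (1, 4, 256) (1, 256, 64, 64)
```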
models/tiny_vit.py
ADDED
|
@@ -0,0 +1,645 @@
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# TinyViT Model Architecture
|
| 3 |
+
# Copyright (c) 2022 Microsoft
|
| 4 |
+
# Adapted from LeViT and Swin Transformer
|
| 5 |
+
# LeViT: (https://github.com/facebookresearch/levit)
|
| 6 |
+
# Swin: (https://github.com/microsoft/swin-transformer)
|
| 7 |
+
# Build the TinyViT Model
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
# The TinyViT model is adapted from MobileSAM's variant.
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
import itertools
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import torch.utils.checkpoint as checkpoint
|
| 17 |
+
from timm.models.layers import DropPath as TimmDropPath,\
|
| 18 |
+
to_2tuple, trunc_normal_
|
| 19 |
+
from typing import Tuple
|
| 20 |
+
|
| 21 |
+
class Conv2d_BN(torch.nn.Sequential):
|
| 22 |
+
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
|
| 23 |
+
groups=1, bn_weight_init=1):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.add_module('c', torch.nn.Conv2d(
|
| 26 |
+
a, b, ks, stride, pad, dilation, groups, bias=False))
|
| 27 |
+
bn = torch.nn.BatchNorm2d(b)
|
| 28 |
+
torch.nn.init.constant_(bn.weight, bn_weight_init)
|
| 29 |
+
torch.nn.init.constant_(bn.bias, 0)
|
| 30 |
+
self.add_module('bn', bn)
|
| 31 |
+
|
| 32 |
+
@torch.no_grad()
|
| 33 |
+
def fuse(self):
|
| 34 |
+
c, bn = self._modules.values()
|
| 35 |
+
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
| 36 |
+
w = c.weight * w[:, None, None, None]
|
| 37 |
+
b = bn.bias - bn.running_mean * bn.weight / \
|
| 38 |
+
(bn.running_var + bn.eps)**0.5
|
| 39 |
+
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
|
| 40 |
+
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
|
| 41 |
+
m.weight.data.copy_(w)
|
| 42 |
+
m.bias.data.copy_(b)
|
| 43 |
+
return m
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class DropPath(TimmDropPath):
|
| 47 |
+
def __init__(self, drop_prob=None):
|
| 48 |
+
super().__init__(drop_prob=drop_prob)
|
| 49 |
+
self.drop_prob = drop_prob
|
| 50 |
+
|
| 51 |
+
def __repr__(self):
|
| 52 |
+
msg = super().__repr__()
|
| 53 |
+
msg += f'(drop_prob={self.drop_prob})'
|
| 54 |
+
return msg
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class PatchEmbed(nn.Module):
|
| 58 |
+
def __init__(self, in_chans, embed_dim, resolution, activation):
|
| 59 |
+
super().__init__()
|
| 60 |
+
img_size: Tuple[int, int] = to_2tuple(resolution)
|
| 61 |
+
#self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
|
| 62 |
+
self.patches_resolution = img_size
|
| 63 |
+
self.num_patches = self.patches_resolution[0] * \
|
| 64 |
+
self.patches_resolution[1]
|
| 65 |
+
self.in_chans = in_chans
|
| 66 |
+
self.embed_dim = embed_dim
|
| 67 |
+
n = embed_dim
|
| 68 |
+
#self.seq = nn.Sequential(
|
| 69 |
+
# Conv2d_BN(in_chans, n // 2, 3, 2, 1),
|
| 70 |
+
# activation(),
|
| 71 |
+
# Conv2d_BN(n // 2, n, 3, 2, 1),
|
| 72 |
+
#)
|
| 73 |
+
self.seq = nn.Sequential(
|
| 74 |
+
Conv2d_BN(in_chans, n // 2, 1, 1, 0),
|
| 75 |
+
activation(),
|
| 76 |
+
Conv2d_BN(n // 2, n, 1, 1, 0),
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
def forward(self, x):
|
| 80 |
+
return self.seq(x)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class MBConv(nn.Module):
|
| 84 |
+
def __init__(self, in_chans, out_chans, expand_ratio,
|
| 85 |
+
activation, drop_path):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.in_chans = in_chans
|
| 88 |
+
self.hidden_chans = int(in_chans * expand_ratio)
|
| 89 |
+
self.out_chans = out_chans
|
| 90 |
+
|
| 91 |
+
self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
|
| 92 |
+
self.act1 = activation()
|
| 93 |
+
|
| 94 |
+
self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans,
|
| 95 |
+
ks=3, stride=1, pad=1, groups=self.hidden_chans)
|
| 96 |
+
self.act2 = activation()
|
| 97 |
+
|
| 98 |
+
self.conv3 = Conv2d_BN(
|
| 99 |
+
self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
|
| 100 |
+
self.act3 = activation()
|
| 101 |
+
|
| 102 |
+
self.drop_path = DropPath(
|
| 103 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
| 104 |
+
|
| 105 |
+
def forward(self, x):
|
| 106 |
+
shortcut = x
|
| 107 |
+
|
| 108 |
+
x = self.conv1(x)
|
| 109 |
+
x = self.act1(x)
|
| 110 |
+
|
| 111 |
+
x = self.conv2(x)
|
| 112 |
+
x = self.act2(x)
|
| 113 |
+
|
| 114 |
+
x = self.conv3(x)
|
| 115 |
+
|
| 116 |
+
x = self.drop_path(x)
|
| 117 |
+
|
| 118 |
+
x += shortcut
|
| 119 |
+
x = self.act3(x)
|
| 120 |
+
|
| 121 |
+
return x
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class PatchMerging(nn.Module):
|
| 125 |
+
def __init__(self, input_resolution, dim, out_dim, activation):
|
| 126 |
+
super().__init__()
|
| 127 |
+
|
| 128 |
+
self.input_resolution = input_resolution
|
| 129 |
+
self.dim = dim
|
| 130 |
+
self.out_dim = out_dim
|
| 131 |
+
self.act = activation()
|
| 132 |
+
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
|
| 133 |
+
stride_c=2
|
| 134 |
+
if(out_dim==320 or out_dim==448 or out_dim==576):
|
| 135 |
+
stride_c=1
|
| 136 |
+
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
|
| 137 |
+
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
| 138 |
+
|
| 139 |
+
def forward(self, x):
|
| 140 |
+
if x.ndim == 3:
|
| 141 |
+
H, W = self.input_resolution
|
| 142 |
+
B = len(x)
|
| 143 |
+
# (B, C, H, W)
|
| 144 |
+
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
| 145 |
+
|
| 146 |
+
x = self.conv1(x)
|
| 147 |
+
x = self.act(x)
|
| 148 |
+
|
| 149 |
+
x = self.conv2(x)
|
| 150 |
+
x = self.act(x)
|
| 151 |
+
x = self.conv3(x)
|
| 152 |
+
x = x.flatten(2).transpose(1, 2)
|
| 153 |
+
return x
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class ConvLayer(nn.Module):
|
| 157 |
+
def __init__(self, dim, input_resolution, depth,
|
| 158 |
+
activation,
|
| 159 |
+
drop_path=0., downsample=None, use_checkpoint=False,
|
| 160 |
+
out_dim=None,
|
| 161 |
+
conv_expand_ratio=4.,
|
| 162 |
+
):
|
| 163 |
+
|
| 164 |
+
super().__init__()
|
| 165 |
+
self.dim = dim
|
| 166 |
+
self.input_resolution = input_resolution
|
| 167 |
+
self.depth = depth
|
| 168 |
+
self.use_checkpoint = use_checkpoint
|
| 169 |
+
|
| 170 |
+
# build blocks
|
| 171 |
+
self.blocks = nn.ModuleList([
|
| 172 |
+
MBConv(dim, dim, conv_expand_ratio, activation,
|
| 173 |
+
drop_path[i] if isinstance(drop_path, list) else drop_path,
|
| 174 |
+
)
|
| 175 |
+
for i in range(depth)])
|
| 176 |
+
|
| 177 |
+
# patch merging layer
|
| 178 |
+
if downsample is not None:
|
| 179 |
+
self.downsample = downsample(
|
| 180 |
+
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
| 181 |
+
else:
|
| 182 |
+
self.downsample = None
|
| 183 |
+
|
| 184 |
+
def forward(self, x):
|
| 185 |
+
for blk in self.blocks:
|
| 186 |
+
if self.use_checkpoint:
|
| 187 |
+
x = checkpoint.checkpoint(blk, x)
|
| 188 |
+
else:
|
| 189 |
+
x = blk(x)
|
| 190 |
+
if self.downsample is not None:
|
| 191 |
+
x = self.downsample(x)
|
| 192 |
+
return x
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class Mlp(nn.Module):
|
| 196 |
+
def __init__(self, in_features, hidden_features=None,
|
| 197 |
+
out_features=None, act_layer=nn.GELU, drop=0.):
|
| 198 |
+
super().__init__()
|
| 199 |
+
out_features = out_features or in_features
|
| 200 |
+
hidden_features = hidden_features or in_features
|
| 201 |
+
self.norm = nn.LayerNorm(in_features)
|
| 202 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 203 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 204 |
+
self.act = act_layer()
|
| 205 |
+
self.drop = nn.Dropout(drop)
|
| 206 |
+
|
| 207 |
+
def forward(self, x):
|
| 208 |
+
x = self.norm(x)
|
| 209 |
+
|
| 210 |
+
x = self.fc1(x)
|
| 211 |
+
x = self.act(x)
|
| 212 |
+
x = self.drop(x)
|
| 213 |
+
x = self.fc2(x)
|
| 214 |
+
x = self.drop(x)
|
| 215 |
+
return x
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class Attention(torch.nn.Module):
|
| 219 |
+
def __init__(self, dim, key_dim, num_heads=8,
|
| 220 |
+
attn_ratio=4,
|
| 221 |
+
resolution=(14, 14),
|
| 222 |
+
):
|
| 223 |
+
super().__init__()
|
| 224 |
+
# (h, w)
|
| 225 |
+
assert isinstance(resolution, tuple) and len(resolution) == 2
|
| 226 |
+
self.num_heads = num_heads
|
| 227 |
+
self.scale = key_dim ** -0.5
|
| 228 |
+
self.key_dim = key_dim
|
| 229 |
+
self.nh_kd = nh_kd = key_dim * num_heads
|
| 230 |
+
self.d = int(attn_ratio * key_dim)
|
| 231 |
+
self.dh = int(attn_ratio * key_dim) * num_heads
|
| 232 |
+
self.attn_ratio = attn_ratio
|
| 233 |
+
h = self.dh + nh_kd * 2
|
| 234 |
+
|
| 235 |
+
self.norm = nn.LayerNorm(dim)
|
| 236 |
+
self.qkv = nn.Linear(dim, h)
|
| 237 |
+
self.proj = nn.Linear(self.dh, dim)
|
| 238 |
+
|
| 239 |
+
points = list(itertools.product(
|
| 240 |
+
range(resolution[0]), range(resolution[1])))
|
| 241 |
+
N = len(points)
|
| 242 |
+
attention_offsets = {}
|
| 243 |
+
idxs = []
|
| 244 |
+
for p1 in points:
|
| 245 |
+
for p2 in points:
|
| 246 |
+
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
|
| 247 |
+
if offset not in attention_offsets:
|
| 248 |
+
attention_offsets[offset] = len(attention_offsets)
|
| 249 |
+
idxs.append(attention_offsets[offset])
|
| 250 |
+
self.attention_biases = torch.nn.Parameter(
|
| 251 |
+
torch.zeros(num_heads, len(attention_offsets)))
|
| 252 |
+
self.register_buffer('attention_bias_idxs',
|
| 253 |
+
torch.LongTensor(idxs).view(N, N),
|
| 254 |
+
persistent=False)
|
| 255 |
+
|
| 256 |
+
@torch.no_grad()
|
| 257 |
+
def train(self, mode=True):
|
| 258 |
+
super().train(mode)
|
| 259 |
+
if mode and hasattr(self, 'ab'):
|
| 260 |
+
del self.ab
|
| 261 |
+
else:
|
| 262 |
+
self.register_buffer('ab',
|
| 263 |
+
self.attention_biases[:, self.attention_bias_idxs],
|
| 264 |
+
persistent=False)
|
| 265 |
+
|
| 266 |
+
def forward(self, x): # x (B,N,C)
|
| 267 |
+
B, N, _ = x.shape
|
| 268 |
+
|
| 269 |
+
# Normalization
|
| 270 |
+
x = self.norm(x)
|
| 271 |
+
|
| 272 |
+
qkv = self.qkv(x)
|
| 273 |
+
# (B, N, num_heads, d)
|
| 274 |
+
q, k, v = qkv.view(B, N, self.num_heads, -
|
| 275 |
+
1).split([self.key_dim, self.key_dim, self.d], dim=3)
|
| 276 |
+
# (B, num_heads, N, d)
|
| 277 |
+
q = q.permute(0, 2, 1, 3)
|
| 278 |
+
k = k.permute(0, 2, 1, 3)
|
| 279 |
+
v = v.permute(0, 2, 1, 3)
|
| 280 |
+
|
| 281 |
+
attn = (
|
| 282 |
+
(q @ k.transpose(-2, -1)) * self.scale
|
| 283 |
+
+
|
| 284 |
+
(self.attention_biases[:, self.attention_bias_idxs]
|
| 285 |
+
if self.training else self.ab)
|
| 286 |
+
)
|
| 287 |
+
attn = attn.softmax(dim=-1)
|
| 288 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
| 289 |
+
x = self.proj(x)
|
| 290 |
+
return x
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class TinyViTBlock(nn.Module):
|
| 294 |
+
r""" TinyViT Block.
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
dim (int): Number of input channels.
|
| 298 |
+
input_resolution (tuple[int, int]): Input resolution.
|
| 299 |
+
num_heads (int): Number of attention heads.
|
| 300 |
+
window_size (int): Window size.
|
| 301 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 302 |
+
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        local_conv_size (int): the kernel size of the convolution between
            Attention and MLP. Default: 3
        activation: the activation function. Default: nn.GELU
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7,
                 mlp_ratio=4., drop=0., drop_path=0.,
                 local_conv_size=3,
                 activation=nn.GELU,
                 ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        assert window_size > 0, 'window_size must be greater than 0'
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        assert dim % num_heads == 0, 'dim must be divisible by num_heads'
        head_dim = dim // num_heads

        window_resolution = (window_size, window_size)
        self.attn = Attention(dim, head_dim, num_heads,
                              attn_ratio=1, resolution=window_resolution)

        mlp_hidden_dim = int(dim * mlp_ratio)
        mlp_activation = activation
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=mlp_activation, drop=drop)

        pad = local_conv_size // 2
        self.local_conv = Conv2d_BN(
            dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        res_x = x
        if H == self.window_size and W == self.window_size:
            x = self.attn(x)
        else:
            x = x.view(B, H, W, C)
            pad_b = (self.window_size - H %
                     self.window_size) % self.window_size
            pad_r = (self.window_size - W %
                     self.window_size) % self.window_size
            padding = pad_b > 0 or pad_r > 0

            if padding:
                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

            pH, pW = H + pad_b, W + pad_r
            nH = pH // self.window_size
            nW = pW // self.window_size
            # window partition
            x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape(
                B * nH * nW, self.window_size * self.window_size, C)
            x = self.attn(x)
            # window reverse
            x = x.view(B, nH, nW, self.window_size, self.window_size,
                       C).transpose(2, 3).reshape(B, pH, pW, C)

            if padding:
                x = x[:, :H, :W].contiguous()

            x = x.view(B, L, C)

        x = res_x + self.drop_path(x)

        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.local_conv(x)
        x = x.view(B, C, L).transpose(1, 2)

        x = x + self.drop_path(self.mlp(x))
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"


class BasicLayer(nn.Module):
    """ A basic TinyViT layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
        activation: the activation function. Default: nn.GELU
        out_dim: the output dimension of the layer. Default: dim
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., drop=0.,
                 drop_path=0., downsample=None, use_checkpoint=False,
                 local_conv_size=3,
                 activation=nn.GELU,
                 out_dim=None,
                 ):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            TinyViTBlock(dim=dim, input_resolution=input_resolution,
                         num_heads=num_heads, window_size=window_size,
                         mlp_ratio=mlp_ratio,
                         drop=drop,
                         drop_path=drop_path[i] if isinstance(
                             drop_path, list) else drop_path,
                         local_conv_size=local_conv_size,
                         activation=activation,
                         )
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, out_dim=out_dim, activation=activation)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class TinyViT(nn.Module):
    def __init__(self,
                 img_size=224,
                 in_chans=3,
                 # num_classes=1000,
                 embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_sizes=[7, 7, 14, 7],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 local_conv_size=3,
                 layer_lr_decay=1.0,
                 ):
        super().__init__()
        self.img_size = img_size
        # self.num_classes = num_classes
        self.depths = depths
        self.num_layers = len(depths)
        self.mlp_ratio = mlp_ratio

        activation = nn.GELU

        self.patch_embed = PatchEmbed(in_chans=in_chans,
                                      embed_dim=embed_dims[0],
                                      resolution=img_size,
                                      activation=activation)

        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
                                                sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            kwargs = dict(dim=embed_dims[i_layer],
                          input_resolution=(
                              patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
                              patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))
                          ),
                          # input_resolution=(patches_resolution[0] // (2 ** i_layer),
                          #                   patches_resolution[1] // (2 ** i_layer)),
                          depth=depths[i_layer],
                          drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                          downsample=PatchMerging if (
                              i_layer < self.num_layers - 1) else None,
                          use_checkpoint=use_checkpoint,
                          out_dim=embed_dims[min(
                              i_layer + 1, len(embed_dims) - 1)],
                          activation=activation,
                          )
            if i_layer == 0:
                layer = ConvLayer(
                    conv_expand_ratio=mbconv_expand_ratio,
                    **kwargs,
                )
            else:
                layer = BasicLayer(
                    num_heads=num_heads[i_layer],
                    window_size=window_sizes[i_layer],
                    mlp_ratio=self.mlp_ratio,
                    drop=drop_rate,
                    local_conv_size=local_conv_size,
                    **kwargs)
            self.layers.append(layer)

        # init weights
        self.apply(self._init_weights)
        self.set_layer_lr_decay(layer_lr_decay)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dims[-1],
                256,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(256),
            nn.Conv2d(
                256,
                256,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(256),
        )

    def set_layer_lr_decay(self, layer_lr_decay):
        decay_rate = layer_lr_decay

        # layers -> blocks (depth)
        depth = sum(self.depths)
        lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]

        def _set_lr_scale(m, scale):
            for p in m.parameters():
                p.lr_scale = scale

        self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
        i = 0
        for layer in self.layers:
            for block in layer.blocks:
                block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
                i += 1
            if layer.downsample is not None:
                layer.downsample.apply(
                    lambda x: _set_lr_scale(x, lr_scales[i - 1]))
        assert i == depth

        for k, p in self.named_parameters():
            p.param_name = k

        def _check_lr_scale(m):
            for p in m.parameters():
                assert hasattr(p, 'lr_scale'), p.param_name

        self.apply(_check_lr_scale)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'attention_biases'}

    def forward_features(self, x):
        # x: (N, C, H, W)
        x = self.patch_embed(x)
        x = self.layers[0](x)
        start_i = 1

        for i in range(start_i, len(self.layers)):
            layer = self.layers[i]
            x = layer(x)
        B, _, C = x.size()
        x = x.view(B, 64, 64, C)
        x = x.permute(0, 3, 1, 2)
        x = self.neck(x)

        return x

    def forward(self, x):
        x = self.forward_features(x)
        return x

# model = TinyViT(
#     img_size=256,
#     in_chans=3,
#     embed_dims=[
#         64,   ## (64, 256, 256)
#         128,  ## (128, 128, 128)
#         160,  ## (160, 64, 64)
#         320   ## (320, 64, 64)
#     ],
#     depths=[2, 2, 6, 2],
#     num_heads=[2, 4, 5, 10],
#     window_sizes=[7, 7, 14, 7],
#     mlp_ratio=4.,
#     drop_rate=0.,
#     drop_path_rate=0.0,
#     use_checkpoint=False,
#     mbconv_expand_ratio=4.0,
#     local_conv_size=3,
#     layer_lr_decay=0.8
# )
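The commented-out block above mirrors the encoder configuration that `train.py` uses. As a hedged illustration only (not part of the upload), a minimal smoke test of this encoder might look like the sketch below; it assumes the rest of `models/tiny_vit.py` from this commit (PatchEmbed, ConvLayer, PatchMerging, Attention, Mlp, Conv2d_BN, DropPath) is importable through the `models` package, and the expected output shape is taken from the `(B, 256, 64, 64)` comments in `train.py`.

```python
# Hypothetical smoke test for the TinyViT image encoder (illustration, not repo code).
import torch
from models import TinyViT  # exported by models/__init__.py, as in train.py

encoder = TinyViT(
    img_size=256,
    in_chans=3,
    embed_dims=[64, 128, 160, 320],
    depths=[2, 2, 6, 2],
    num_heads=[2, 4, 5, 10],
    window_sizes=[7, 7, 14, 7],
    drop_path_rate=0.0,
    layer_lr_decay=0.8,
)

with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256)  # one normalized 256x256 slice
    feats = encoder(dummy)               # neck projects the last stage to 256 channels

print(feats.shape)  # expected per train.py's comments: torch.Size([1, 256, 64, 64])
```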
models/transformer.py
ADDED
@@ -0,0 +1,243 @@
# -*- coding: utf-8 -*-
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import Tensor, nn

import math
from typing import Tuple, Type

from .common import MLPBlock


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys


class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
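For intuition on how the mask decoder drives this module, a minimal shape-check sketch is given below. It uses random tensors with the shapes documented in the `forward` docstrings above and the 256-dim / 64x64 configuration from `train.py`; the tensors are dummies, not real prompt or image embeddings.

```python
# Hypothetical shape check for TwoWayTransformer (illustration, not repo code).
import torch
from models.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8)

B = 2
image_embedding = torch.randn(B, 256, 64, 64)  # B x C x H x W image features
image_pe = torch.randn(B, 256, 64, 64)         # dense positional encoding, same shape
point_embedding = torch.randn(B, 5, 256)       # B x N_tokens x C sparse prompt tokens

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape)  # torch.Size([2, 5, 256])    refined prompt tokens
print(keys.shape)     # torch.Size([2, 4096, 256]) flattened, updated image tokens
```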
train.py
ADDED
@@ -0,0 +1,502 @@
import os
import random
import monai
from os import listdir, makedirs
from os.path import join, exists, isfile, isdir, basename
from tqdm import tqdm
from time import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datetime import datetime
from shutil import copyfile
from models import PromptEncoder, TwoWayTransformer, TinyViT, MaskDecoder_F4
import torch.nn.functional as F
import gc
from matplotlib import pyplot as plt
import argparse
from modality_npz_dataset import ModalityNpzDataset

torch.cuda.empty_cache()
os.environ["OMP_NUM_THREADS"] = "4"  # export OMP_NUM_THREADS=4
os.environ["OPENBLAS_NUM_THREADS"] = "4"  # export OPENBLAS_NUM_THREADS=4
os.environ["MKL_NUM_THREADS"] = "6"  # export MKL_NUM_THREADS=6
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"  # export VECLIB_MAXIMUM_THREADS=4
os.environ["NUMEXPR_NUM_THREADS"] = "6"  # export NUMEXPR_NUM_THREADS=6


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


setup_seed(2024)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root",
                        type=str,
                        default="",
                        help="Path to the npy data root.")

    parser.add_argument('--task_name', type=str, default='MedSAM-Lite-All')

    parser.add_argument("--pretrained_checkpoint",
                        type=str,
                        default=None,
                        help="Path to the pretrained Lite-MedSAM checkpoint.")

    parser.add_argument("--resume",
                        type=str,
                        default=None,
                        help="Path to the checkpoint to continue training.")
    parser.add_argument(
        "--work_dir",
        type=str,
        default="./work_dir",
        help="Path to the working directory where checkpoints and logs will be saved.")

    parser.add_argument('--data_aug',
                        action='store_true',
                        default=False,
                        help='use data augmentation during training')

    parser.add_argument("--num_epochs",
                        type=int,
                        default=25,
                        help="Number of epochs to train.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=16,
                        help="Batch size.")
    parser.add_argument("--num_workers",
                        type=int,
                        default=8,
                        help="Number of workers for dataloader.")

    parser.add_argument(
        "--bbox_shift",
        type=int,
        default=5,
        help="Perturbation to bounding box coordinates during training.")

    parser.add_argument("-lr", type=float, default=2e-4, help="Learning rate.")

    parser.add_argument("-weight_decay",
                        type=float,
                        default=0.001,
                        help="Weight decay.")

    parser.add_argument("-iou_loss_weight",
                        type=float,
                        default=1.0,
                        help="Weight of IoU loss.")

    parser.add_argument("-seg_loss_weight",
                        type=float,
                        default=1.0,
                        help="Weight of segmentation loss.")
    parser.add_argument("-ce_loss_weight",
                        type=float,
                        default=1.0,
                        help="Weight of cross entropy loss.")

    parser.add_argument("--sanity_check",
                        action="store_true",
                        default=True,
                        help="Whether to do sanity check for dataloading.")

    args = parser.parse_args()
    return args


def show_mask(mask, ax, random_color=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.45])], axis=0)
    else:
        color = np.array([251 / 255, 252 / 255, 30 / 255, 0.45])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0),
                      w,
                      h,
                      edgecolor='blue',
                      facecolor=(0, 0, 0, 0),
                      lw=2))


def show_points(points, ax):
    for i, (x, y) in enumerate(points):
        ax.scatter(x, y, color='red', s=10)


def cal_iou(result, reference):

    intersection = torch.count_nonzero(torch.logical_and(result, reference),
                                       dim=[i for i in range(1, result.ndim)])
    union = torch.count_nonzero(torch.logical_or(result, reference),
                                dim=[i for i in range(1, result.ndim)])

    iou = intersection.float() / union.float()

    return iou.unsqueeze(1)


def sanity_check_dataset(args):

    tr_dataset = ModalityNpzDataset(args.data_root, data_aug=True)
    tr_dataloader = DataLoader(tr_dataset, batch_size=8, shuffle=True)

    for step, batch in enumerate(tr_dataloader):
        # show the example
        _, axs = plt.subplots(1, 2, figsize=(10, 10))
        idx = random.randint(0, 4)

        image = batch["image"]
        gt = batch["gt2D"]
        bboxes = batch["bboxes"]
        names_temp = batch["image_name"]

        axs[0].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
        show_mask(gt[idx].cpu().squeeze().numpy(), axs[0])
        show_box(bboxes[idx].numpy().squeeze(), axs[0])
        axs[0].axis('off')
        # set title
        axs[0].set_title(names_temp[idx])
        idx = random.randint(4, 7)
        axs[1].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
        show_mask(gt[idx].cpu().squeeze().numpy(), axs[1])
        show_box(bboxes[idx].numpy().squeeze(), axs[1])
        axs[1].axis('off')
        # set title
        axs[1].set_title(names_temp[idx])
        plt.subplots_adjust(wspace=0.01, hspace=0)
        plt.savefig(join(args.work_dir, 'Sanitycheck_DA.png'),
                    bbox_inches='tight',
                    dpi=300)
        plt.close()
        break


class MedSAM_Lite(nn.Module):

    def __init__(
        self,
        image_encoder,
        mask_decoder,
        prompt_encoder,
    ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        encoder_weight_file = ""  # path for vision encoder (tiny vit) weights

        self.image_encoder.load_state_dict(torch.load(encoder_weight_file))

    def forward(self, image, points, boxes, masks, features, crops,
                text_features, category_idx):
        image_embedding = self.image_encoder(image)

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=masks,
            features=features,
            crops=crops,
            text_features=text_features,
            category_idx=category_idx)

        low_res_masks, iou_predictions, category_predictions, clip_vec, img_vec = self.mask_decoder(
            image_embeddings=image_embedding,  # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
            multimask_output=False,
        )  # (B, 1, 256, 256)

        return low_res_masks, iou_predictions, category_predictions, clip_vec, img_vec

    @torch.no_grad()
    def postprocess_masks(self, masks, new_size, original_size):
        """
        Do cropping and resizing
        """
        # Crop
        masks = masks[:, :, :new_size[0], :new_size[1]]
        # Resize
        masks = F.interpolate(
            masks,
            size=(original_size[0], original_size[1]),
            mode="bilinear",
            align_corners=False,
        )

        return masks


def collate_fn(batch):
    """
    Collate function for PyTorch DataLoader.
    """
    batch_dict = {}
    for key in batch[0].keys():
        if key == "image_name" or key == "category_idx":
            batch_dict[key] = [sample[key] for sample in batch]
        else:
            batch_dict[key] = torch.stack([sample[key] for sample in batch],
                                          dim=0)

    return batch_dict


if __name__ == "__main__":

    args = get_args()
    sanity_check_dataset(args)

    run_id = datetime.now().strftime("%Y%m%d-%H%M")
    print(f"Run ID: {run_id}")

    model_save_path = join(args.work_dir, args.task_name + "-" + run_id)
    makedirs(model_save_path, exist_ok=True)
    copyfile(__file__,
             join(model_save_path, run_id + "_" + os.path.basename(__file__)))

    device = torch.device("cuda")

    num_epochs = args.num_epochs
    batch_size = args.batch_size
    num_workers = args.num_workers

    medsam_lite_image_encoder = TinyViT(
        img_size=256,
        in_chans=3,
        embed_dims=[
            64,   ## (64, 256, 256)
            128,  ## (128, 128, 128)
            160,  ## (160, 64, 64)
            320   ## (320, 64, 64)
        ],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
        mlp_ratio=4.,
        drop_rate=0.,
        drop_path_rate=0.0,
        use_checkpoint=False,
        mbconv_expand_ratio=4.0,
        local_conv_size=3,
        layer_lr_decay=0.8)

    medsam_lite_prompt_encoder = PromptEncoder(embed_dim=256,
                                               image_embedding_size=(64, 64),
                                               input_image_size=(256, 256),
                                               mask_in_chans=16)

    medsam_lite_mask_decoder = MaskDecoder_F4(
        num_multimask_outputs=3,
        transformer=TwoWayTransformer(
            depth=2,
            embedding_dim=256,
            mlp_dim=2048,
            num_heads=8,
        ),
        modality=True,
        contents=True,
        transformer_dim=256,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )

    medsam_lite_model = MedSAM_Lite(image_encoder=medsam_lite_image_encoder,
                                    mask_decoder=medsam_lite_mask_decoder,
                                    prompt_encoder=medsam_lite_prompt_encoder)

    if args.resume is None and args.pretrained_checkpoint is not None:
        ## Load pretrained checkpoint if there's no checkpoint to resume from and there's a pretrained checkpoint
        print(
            f"Loading pretrained checkpoint from {args.pretrained_checkpoint}")
        medsam_lite_checkpoint = torch.load(args.pretrained_checkpoint,
                                            map_location="cpu")
        medsam_lite_model.load_state_dict(medsam_lite_checkpoint["model"],
                                          strict=True)

    medsam_lite_model = medsam_lite_model.to(device)

    medsam_lite_model.train()

    print(
        f"MedSAM Lite size: {sum(p.numel() for p in medsam_lite_model.parameters())}"
    )

    print('lr:', args.lr)

    optimizer = optim.AdamW(
        medsam_lite_model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.weight_decay,
    )
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                        mode='min',
                                                        factor=0.9,
                                                        patience=5,
                                                        cooldown=0)
    seg_loss = monai.losses.DiceLoss(sigmoid=True,
                                     squared_pred=True,
                                     reduction='mean')
    bce_loss = nn.BCEWithLogitsLoss(reduction='mean')
    iou_loss = nn.MSELoss(reduction='mean')
    ce_loss = nn.CrossEntropyLoss(reduction='mean')

    train_dataset = ModalityNpzDataset(data_root=args.data_root, data_aug=True)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=True)

    if args.resume is not None:
        ckpt_folders = sorted(listdir(args.resume))
        ckpt_folders = [
            f for f in ckpt_folders
            if (f.startswith(args.task_name)
                and isfile(join(args.resume, f, 'medsam_lite_latest.pth')))
        ]
        print('*' * 20)
        print('existing ckpts in', args.resume, ckpt_folders)
        # find the latest ckpt folders
        time_strings = [
            f.split(args.task_name + '-')[-1] for f in ckpt_folders
        ]
        dates = [datetime.strptime(f, '%Y%m%d-%H%M') for f in time_strings]
        latest_date = max(dates)
        latest_ckpt = join(
            args.work_dir,
            args.task_name + '-' + latest_date.strftime('%Y%m%d-%H%M'),
            'medsam_lite_latest.pth')
        print('Loading from', latest_ckpt)
        checkpoint = torch.load(latest_ckpt, map_location=device)
        medsam_lite_model.module.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        start_epoch = checkpoint["epoch"] + 1
        best_loss = checkpoint["loss"]
        print(f"Loaded checkpoint from epoch {start_epoch}")
    else:
        start_epoch = 0
        best_loss = 1e10

    train_losses = []
    epoch_times = []

    print("Training")
    for epoch in range(start_epoch, num_epochs):
        if epoch == num_epochs - 1:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 5e-5

        epoch_loss = [1e10 for _ in range(len(train_loader))]
        epoch_start_time = time()
        pbar = tqdm(train_loader)
        for step, batch in enumerate(pbar):
            gc.collect()
            torch.cuda.empty_cache()
            image = batch["image"]
            gt2D = batch["gt2D"]
            boxes = batch["bboxes"]
            coords = batch["coords"]
            crops = batch["image_crop"]
            features = batch["image_feature"]
            text_features = batch["text_feature"]
            class_idx = batch["category_idx"]
            class_idx = torch.tensor(class_idx)

            optimizer.zero_grad()
            image, gt2D, boxes, coords, crops, features, text_features, class_idx = image.to(
                device), gt2D.to(device), boxes.to(device), coords.to(
                    device), crops.to(device), features.to(
                        device), text_features.to(device), class_idx.to(device)
            labels_torch = torch.ones(coords.shape[0]).long()
            labels_torch = labels_torch.unsqueeze(1).expand(-1, 4)
            labels_torch = labels_torch.to(device)
            point_prompt = (coords, labels_torch)
            logits_pred, iou_pred, category_predictions, clip_vec, img_vec = medsam_lite_model(
                image, None, boxes, None, features, crops, text_features, class_idx)

            clip_img_features = clip_vec / clip_vec.norm(dim=-1, keepdim=True)
            img_features = img_vec / img_vec.norm(dim=-1, keepdim=True)
            similarity1 = torch.matmul(clip_img_features, img_features.T)
            similarity2 = torch.matmul(img_features, clip_img_features.T)
            sim_labels = torch.arange(similarity1.shape[0]).to(image.device)

            l_seg = seg_loss(logits_pred, gt2D)
            l_bce = bce_loss(logits_pred, gt2D.float())
            l_ce_sim = 0.5 * (ce_loss(similarity1, sim_labels.long()) +
                              ce_loss(similarity2, sim_labels.long()))
            l_ce = ce_loss(category_predictions, class_idx.long())
            mask_loss = l_seg + l_bce
            with torch.no_grad():
                iou_gt = cal_iou(torch.sigmoid(logits_pred) > 0.5, gt2D.bool())
            l_iou = iou_loss(iou_pred, iou_gt)
            loss = mask_loss + l_iou + 0.01 * l_ce_sim + 0.01 * l_ce
            epoch_loss[step] = loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            pbar.set_description(
                f"Epoch {epoch} at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, loss: {loss.item():.4f}"
            )

        epoch_end_time = time()
        epoch_duration = epoch_end_time - epoch_start_time
        epoch_times.append(epoch_duration)

        epoch_loss_reduced = sum(epoch_loss) / len(epoch_loss)

        train_losses.append(epoch_loss_reduced)
        lr_scheduler.step(epoch_loss_reduced)

        model_weights = medsam_lite_model.state_dict()

        checkpoint = {
            "model": model_weights,
            "epoch": epoch,
            "optimizer": optimizer.state_dict(),
            "loss": epoch_loss_reduced,
            "best_loss": best_loss,
        }
        torch.save(checkpoint, join(model_save_path, "medsam_lite_latest.pth"))

        if epoch_loss_reduced < best_loss:
            print(
                f"New best loss: {best_loss:.4f} -> {epoch_loss_reduced:.4f}")
            best_loss = epoch_loss_reduced
            checkpoint["best_loss"] = best_loss
            torch.save(checkpoint, join(model_save_path,
                                        "medsam_lite_best.pth"))
        epoch_loss_reduced = 1e10

        fig, axes = plt.subplots(2, 1, figsize=(10, 8))
        axes[0].title.set_text("Dice + Binary Cross Entropy + IoU Loss")
        axes[0].plot(train_losses)
        axes[0].set_ylabel("Loss")
        axes[1].plot(epoch_times)
        axes[1].title.set_text("Epoch Duration")
        axes[1].set_ylabel("Duration (s)")
        axes[1].set_xlabel("Epoch")
        plt.tight_layout()
        plt.savefig(join(model_save_path, "log.png"))
        plt.close()
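The IoU regression target for the mask decoder's IoU head is produced by `cal_iou` above. As a small, self-contained illustration of its behavior (toy 4x4 masks, values computed by hand; not data or code from this upload beyond the copied function), the sketch below shows the per-sample IoU it returns.

```python
# Toy illustration of cal_iou from train.py (hypothetical masks, for intuition only).
import torch


def cal_iou(result, reference):
    # Per-sample IoU over all non-batch dimensions, returned with shape (B, 1).
    intersection = torch.count_nonzero(torch.logical_and(result, reference),
                                       dim=[i for i in range(1, result.ndim)])
    union = torch.count_nonzero(torch.logical_or(result, reference),
                                dim=[i for i in range(1, result.ndim)])
    return (intersection.float() / union.float()).unsqueeze(1)


pred = torch.zeros(1, 1, 4, 4, dtype=torch.bool)
gt = torch.zeros(1, 1, 4, 4, dtype=torch.bool)
pred[0, 0, :2, :2] = True  # predicted mask covers a 2x2 corner
gt[0, 0, :2, :4] = True    # ground truth covers a 2x4 strip

print(cal_iou(pred, gt))   # tensor([[0.5000]]): 4 overlapping pixels / 8 in the union
```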