denoise model update
Browse files- app.py +215 -3
- draco.yaml +19 -0
- draco/__init__.py +2 -0
- draco/configuration/__init__.py +4 -0
- draco/configuration/base.yaml +57 -0
- draco/configuration/config.py +52 -0
- draco/configuration/configurable.py +138 -0
- draco/configuration/draco2d-b_triplet_pretrain.yaml +20 -0
- draco/configuration/draco2d-h_triplet_pretrain.yaml +5 -0
- draco/configuration/draco2d-l_triplet_pretrain.yaml +4 -0
- draco/model/__init__.py +6 -0
- draco/model/build.py +22 -0
- draco/model/checkpoint.py +61 -0
- draco/model/draco2d.py +663 -0
- draco/model/draco_base.py +35 -0
- draco/model/layer/__init__.py +3 -0
- draco/model/layer/normalization.py +22 -0
- draco/model/utils/constant.py +24 -0
- requirements.txt +11 -0
app.py
CHANGED
|
@@ -1,7 +1,219 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
def greet(name):
|
| 4 |
-
return "Hello " + name + "!!"
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
demo.launch()
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import h5py
|
| 3 |
+
import mrcfile
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from omegaconf import DictConfig
|
| 7 |
+
import torch
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from torchvision.transforms import functional as F
|
| 10 |
+
import torchvision.transforms.v2 as v2
|
| 11 |
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
from draco.configuration import CfgNode
|
| 14 |
+
from draco.model import (
|
| 15 |
+
build_model,
|
| 16 |
+
load_pretrained
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
class DRACODenoiser(object):
|
| 20 |
+
def __init__(self,
|
| 21 |
+
cfg: DictConfig,
|
| 22 |
+
ckpt_path: Path,
|
| 23 |
+
) -> None:
|
| 24 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 25 |
+
|
| 26 |
+
self.transform = self.build_transform()
|
| 27 |
+
self.model = build_model(cfg).to(self.device).eval()
|
| 28 |
+
self.model = load_pretrained(self.model, ckpt_path, self.device)
|
| 29 |
+
self.patch_size = cfg.MODEL.PATCH_SIZE
|
| 30 |
+
|
| 31 |
+
def patchify(self, image: torch.Tensor) -> torch.Tensor:
|
| 32 |
+
B, C, H, W = image.shape
|
| 33 |
+
P = self.patch_size
|
| 34 |
+
if H % P != 0 or W % P != 0:
|
| 35 |
+
image = torch.nn.functional.pad(image, (0, (P - W % P) % P, 0, (P - H % P) % P), mode='constant', value=0)
|
| 36 |
+
|
| 37 |
+
patches = image.unfold(2, P, P).unfold(3, P, P)
|
| 38 |
+
patches = patches.permute(0, 2, 3, 4, 5, 1)
|
| 39 |
+
patches = patches.reshape(B, -1, P * P * C)
|
| 40 |
+
return patches
|
| 41 |
+
|
| 42 |
+
def unpatchify(self, patches: torch.Tensor, H: int, W: int) -> torch.Tensor:
|
| 43 |
+
B = patches.shape[0]
|
| 44 |
+
P = self.patch_size
|
| 45 |
+
|
| 46 |
+
images = patches.reshape(B, (H + P - 1) // P, (W + P - 1) // P, P, P, -1)
|
| 47 |
+
images = images.permute(0, 5, 1, 3, 2, 4)
|
| 48 |
+
images = images.reshape(B, -1, (H + P - 1) // P * P, (W + P - 1) // P * P)
|
| 49 |
+
images = images[..., :H, :W]
|
| 50 |
+
return images
|
| 51 |
+
|
| 52 |
+
@classmethod
|
| 53 |
+
def build_transform(cls) -> v2.Compose:
|
| 54 |
+
return v2.Compose([
|
| 55 |
+
v2.ToImage(),
|
| 56 |
+
v2.ToDtype(torch.float32, scale=True)
|
| 57 |
+
])
|
| 58 |
+
|
| 59 |
+
@torch.inference_mode()
|
| 60 |
+
def inference(self, image: Image.Image) -> None:
|
| 61 |
+
W, H = image.size
|
| 62 |
+
|
| 63 |
+
x = self.transform(image).unsqueeze(0).to(self.device)
|
| 64 |
+
y = self.model(x)
|
| 65 |
+
|
| 66 |
+
x = self.patchify(x).detach().cpu().numpy()
|
| 67 |
+
denoised = self.unpatchify(y, H, W).squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
|
| 68 |
+
|
| 69 |
+
return denoised
|
| 70 |
+
|
| 71 |
+
# Model Initialization
|
| 72 |
+
cfg = CfgNode.load_yaml_with_base(Path("draco.yaml"))
|
| 73 |
+
CfgNode.merge_with_dotlist(cfg, [])
|
| 74 |
+
ckpt_path = Path("denoise.ckpt")
|
| 75 |
+
denoiser = DRACODenoiser(cfg, ckpt_path)
|
| 76 |
+
|
| 77 |
+
def Auto_contrast(image, t_mean=150.0/255.0, t_sd=40.0/255.0) -> np.ndarray:
|
| 78 |
+
|
| 79 |
+
image = (image - image.min()) / (image.max() - image.min())
|
| 80 |
+
mean = image.mean()
|
| 81 |
+
std = image.std()
|
| 82 |
+
|
| 83 |
+
f = std / t_sd
|
| 84 |
+
|
| 85 |
+
black = mean - t_mean * f
|
| 86 |
+
white = mean + (1 - t_mean) * f
|
| 87 |
+
|
| 88 |
+
new_image = np.clip(image, black, white)
|
| 89 |
+
new_image = (new_image - black) / (white - black)
|
| 90 |
+
return new_image
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def load_data(file_path) -> np.ndarray:
|
| 94 |
+
if file_path.endswith('.h5'):
|
| 95 |
+
with h5py.File(file_path, "r") as f:
|
| 96 |
+
full_micrograph = f["micrograph"] if "micrograph" in f else f["data"]
|
| 97 |
+
full_mean = full_micrograph.attrs["mean"] if "mean" in full_micrograph.attrs else full_micrograph[:].astype(np.float32).mean()
|
| 98 |
+
full_std = full_micrograph.attrs["std"] if "std" in full_micrograph.attrs else full_micrograph[:].astype(np.float32).std()
|
| 99 |
+
data = full_micrograph[:].astype(np.float32)
|
| 100 |
+
elif file_path.endswith('.mrc'):
|
| 101 |
+
with mrcfile.open(file_path, "r") as f:
|
| 102 |
+
data = f.data[:].astype(np.float32)
|
| 103 |
+
full_mean = data.mean()
|
| 104 |
+
full_std = data.std()
|
| 105 |
+
else:
|
| 106 |
+
raise ValueError("Unsupported file format. Please upload a .mrc or .h5 file.")
|
| 107 |
+
data = (data - full_mean) / full_std
|
| 108 |
+
return data
|
| 109 |
+
|
| 110 |
+
def display_crop(data, x_offset, y_offset, auto_contrast) -> Image:
|
| 111 |
+
|
| 112 |
+
crop = data[y_offset:y_offset + 1024, x_offset:x_offset + 1024]
|
| 113 |
+
original_image_normalized = Auto_contrast(crop) if auto_contrast else (crop - crop.min()) / (crop.max() - crop.min())
|
| 114 |
+
input_image = Image.fromarray((original_image_normalized * 255).astype(np.uint8))
|
| 115 |
+
|
| 116 |
+
return input_image
|
| 117 |
+
|
| 118 |
+
def process_and_denoise(data, x_offset, y_offset, auto_contrast) -> Image:
|
| 119 |
+
|
| 120 |
+
crop = data[y_offset:y_offset + 1024, x_offset:x_offset + 1024]
|
| 121 |
+
denoised_data = denoiser.inference(Image.fromarray(crop))
|
| 122 |
+
|
| 123 |
+
denoised_data = denoised_data.squeeze()
|
| 124 |
+
denoised_image_normalized = Auto_contrast(denoised_data) if auto_contrast else (denoised_data - denoised_data.min()) / (denoised_data.max() - denoised_data.min())
|
| 125 |
+
denoised_image = Image.fromarray((denoised_image_normalized * 255).astype(np.uint8))
|
| 126 |
+
|
| 127 |
+
return denoised_image
|
| 128 |
+
|
| 129 |
+
def clear_images() -> tuple:
|
| 130 |
+
return None, None, None, gr.update(maximum=512), gr.update(maximum=512)
|
| 131 |
+
|
| 132 |
+
with gr.Blocks(css="""
|
| 133 |
+
.gradio-container {
|
| 134 |
+
background-color: #f7f9fc;
|
| 135 |
+
font-family: Arial, sans-serif;
|
| 136 |
+
}
|
| 137 |
+
.title-text {
|
| 138 |
+
text-align: center;
|
| 139 |
+
font-size: 30px;
|
| 140 |
+
font-weight: bold;
|
| 141 |
+
margin-bottom: 10px;
|
| 142 |
+
}
|
| 143 |
+
.description-text {
|
| 144 |
+
text-align: center;
|
| 145 |
+
font-size: 18px;
|
| 146 |
+
margin-bottom: 20px;
|
| 147 |
+
}
|
| 148 |
+
""") as demo:
|
| 149 |
+
# Centered Title and Description
|
| 150 |
+
with gr.Column():
|
| 151 |
+
gr.Markdown(
|
| 152 |
+
"""
|
| 153 |
+
<div style="text-align: center; font-size: 30px; font-weight: bold; margin-bottom: 10px;">
|
| 154 |
+
Denoising Demo
|
| 155 |
+
</div>
|
| 156 |
+
<div style="text-align: center; font-size: 18px;">
|
| 157 |
+
Upload a Raw file or select an example to view the original and denoised images
|
| 158 |
+
</div>
|
| 159 |
+
"""
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
file_input = gr.File(label="Or upload a Micrograph File in .h5 or .mrc format")
|
| 163 |
+
auto_contrast = gr.Checkbox(label="Enable Auto Contrast", value=False)
|
| 164 |
+
|
| 165 |
+
x_slider = gr.Slider(0, 512, step=10, label="X Offset")
|
| 166 |
+
y_slider = gr.Slider(0, 512, step=10, label="Y Offset")
|
| 167 |
+
|
| 168 |
+
with gr.Row():
|
| 169 |
+
denoise_button = gr.Button("Denoise")
|
| 170 |
+
clear_button = gr.Button("Clear")
|
| 171 |
+
|
| 172 |
+
with gr.Row():
|
| 173 |
+
with gr.Column():
|
| 174 |
+
original_image = gr.Image(type="pil", label="Original Image")
|
| 175 |
+
with gr.Column():
|
| 176 |
+
denoised_image = gr.Image(type="pil", label="Denoised Image")
|
| 177 |
+
|
| 178 |
+
active_data = gr.State()
|
| 179 |
+
|
| 180 |
+
def load_image_and_update_sliders(file_path) -> tuple:
|
| 181 |
+
data = load_data(file_path)
|
| 182 |
+
h, w = data.shape[:2]
|
| 183 |
+
return data, gr.update(maximum=w-1024), gr.update(maximum=h-1024)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
file_input.clear(
|
| 187 |
+
clear_images,
|
| 188 |
+
inputs=None,
|
| 189 |
+
outputs=[original_image, denoised_image, active_data, x_slider, y_slider]
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
file_input.change(
|
| 193 |
+
lambda file: load_image_and_update_sliders(file.name) if file else (None, None, None, gr.update(maximum=512), gr.update(maximum=512)),
|
| 194 |
+
inputs=file_input,
|
| 195 |
+
outputs=[active_data, x_slider, y_slider]
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
x_slider.change(
|
| 199 |
+
display_crop,
|
| 200 |
+
inputs=[active_data, x_slider, y_slider, auto_contrast],
|
| 201 |
+
outputs=original_image
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
y_slider.change(
|
| 205 |
+
display_crop,
|
| 206 |
+
inputs=[active_data, x_slider, y_slider, auto_contrast],
|
| 207 |
+
outputs=original_image
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
denoise_button.click(
|
| 211 |
+
process_and_denoise,
|
| 212 |
+
inputs=[active_data, x_slider, y_slider, auto_contrast],
|
| 213 |
+
outputs=denoised_image
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
clear_button.click(clear_images, inputs=None, outputs=[original_image, denoised_image, active_data, x_slider, y_slider])
|
| 217 |
+
|
| 218 |
demo.launch()
|
| 219 |
+
|
draco.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL:
|
| 2 |
+
NAME: DracoDenoiseAutoencoder
|
| 3 |
+
DEVICE: cuda
|
| 4 |
+
|
| 5 |
+
IMG_SIZE: 1024
|
| 6 |
+
PATCH_SIZE: 32
|
| 7 |
+
IN_CHANS: 1
|
| 8 |
+
VIT_SCALE: base
|
| 9 |
+
DYNAMIC_IMG_SIZE: True
|
| 10 |
+
DYNAMIC_IMG_PAD: True
|
| 11 |
+
DECODER_EMBED_DIM: 512
|
| 12 |
+
DECODER_DEPTH: 8
|
| 13 |
+
DECODER_NUM_HEADS: 16
|
| 14 |
+
DECODER_USE_NECK: True
|
| 15 |
+
DECODER_NECK_DIM: 256
|
| 16 |
+
USE_ABS_POS: true
|
| 17 |
+
USE_DECODER_NECK: True
|
| 18 |
+
WINDOW_SIZE: 28
|
| 19 |
+
DECODER_GLOBAL_ATTN_INDEXES: [3, 7]
|
draco/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import draco.model
|
draco/configuration/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config import CfgNode
|
| 2 |
+
from .configurable import configurable
|
| 3 |
+
|
| 4 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
draco/configuration/base.yaml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATALOADER:
|
| 2 |
+
BATCH_SIZE: 0
|
| 3 |
+
NUM_WORKERS: 0
|
| 4 |
+
PIN_MEMORY: False
|
| 5 |
+
DROP_LAST: False
|
| 6 |
+
PERSISTENT_WORKERS: False
|
| 7 |
+
|
| 8 |
+
DATASET:
|
| 9 |
+
NAME: null
|
| 10 |
+
|
| 11 |
+
TRANSFORM:
|
| 12 |
+
NAME: null
|
| 13 |
+
|
| 14 |
+
MODEL:
|
| 15 |
+
NAME: null
|
| 16 |
+
DEVICE: cuda
|
| 17 |
+
|
| 18 |
+
METRIC:
|
| 19 |
+
NAME: null
|
| 20 |
+
TYPE: null
|
| 21 |
+
|
| 22 |
+
MODULE:
|
| 23 |
+
NAME: null
|
| 24 |
+
COMPILE: False
|
| 25 |
+
|
| 26 |
+
OPTIMIZER:
|
| 27 |
+
NAME: null
|
| 28 |
+
|
| 29 |
+
SCHEDULER:
|
| 30 |
+
NAME: null
|
| 31 |
+
|
| 32 |
+
TRAINER:
|
| 33 |
+
STRATEGY: auto # Set to `auto`, `ddp`, `deepspeed_stage_2`, `deepspeed_stage_3` ...
|
| 34 |
+
MIXED_PRECISION: False
|
| 35 |
+
CHECKPOINT:
|
| 36 |
+
EVERY_N_EPOCHS: 10
|
| 37 |
+
|
| 38 |
+
SAVE_BEST: False # If True, monitor will be required
|
| 39 |
+
MONITOR: null
|
| 40 |
+
MONITOR_MODE: min # Set to `min` or `max`
|
| 41 |
+
|
| 42 |
+
MAX_EPOCHS: -1 # If profiler is enabled, this will be *automatically* set to 1
|
| 43 |
+
LOG_EVERY_N_STEPS: 1
|
| 44 |
+
ACCUMULATE_GRAD_BATCHES: 1
|
| 45 |
+
|
| 46 |
+
CLIP_GRAD:
|
| 47 |
+
ALGORITHM: null
|
| 48 |
+
VALUE: null
|
| 49 |
+
|
| 50 |
+
DETERMINISTIC: False # Set to True to enable cudnn.deterministic
|
| 51 |
+
BENCHMARK: False # Set to True to enable cudnn.benchmark
|
| 52 |
+
PROFILER: null # Set to `advanced` or `pytorch` to enable profiling
|
| 53 |
+
DETECT_ANOMALY: False # Set to True to enable anomaly detection
|
| 54 |
+
SYNC_BATCHNORM: False # Set to True to enable sync batchnorm
|
| 55 |
+
|
| 56 |
+
SEED: null
|
| 57 |
+
OUTPUT_DIR: null
|
draco/configuration/config.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig, OmegaConf
|
| 5 |
+
|
| 6 |
+
BASE_KEY = "_BASE_"
|
| 7 |
+
ROOT_KEY = "cfg"
|
| 8 |
+
|
| 9 |
+
class CfgNode(OmegaConf):
|
| 10 |
+
"""
|
| 11 |
+
A wrapper around OmegaConf that provides some additional functionality.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
@staticmethod
|
| 15 |
+
def load_yaml_with_base(filename: str) -> DictConfig:
|
| 16 |
+
cfg = OmegaConf.load(filename)
|
| 17 |
+
|
| 18 |
+
def _load_with_base(base_cfg_file: str) -> dict[str, Any]:
|
| 19 |
+
if base_cfg_file.startswith("~"):
|
| 20 |
+
base_cfg_file = os.path.expanduser(base_cfg_file)
|
| 21 |
+
if not any(map(base_cfg_file.startswith, ["/", "https://", "http://"])):
|
| 22 |
+
# the path to base cfg is relative to the config file itself.
|
| 23 |
+
base_cfg_file = os.path.join(os.path.dirname(filename), base_cfg_file)
|
| 24 |
+
return CfgNode.load_yaml_with_base(base_cfg_file)
|
| 25 |
+
|
| 26 |
+
if BASE_KEY in cfg:
|
| 27 |
+
if isinstance(cfg[BASE_KEY], list):
|
| 28 |
+
base_cfg: dict[str, Any] = {}
|
| 29 |
+
base_cfg_files = cfg[BASE_KEY]
|
| 30 |
+
for base_cfg_file in base_cfg_files:
|
| 31 |
+
base_cfg = CfgNode.merge(base_cfg, _load_with_base(base_cfg_file))
|
| 32 |
+
else:
|
| 33 |
+
base_cfg_file = cfg[BASE_KEY]
|
| 34 |
+
base_cfg = _load_with_base(base_cfg_file)
|
| 35 |
+
del cfg[BASE_KEY]
|
| 36 |
+
|
| 37 |
+
base_cfg = CfgNode.merge(base_cfg, cfg)
|
| 38 |
+
return base_cfg
|
| 39 |
+
|
| 40 |
+
if ROOT_KEY in cfg:
|
| 41 |
+
return cfg[ROOT_KEY]
|
| 42 |
+
return cfg
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def merge_with_dotlist(cfg: DictConfig, dotlist: list[str]) -> None:
|
| 46 |
+
if len(dotlist) == 0:
|
| 47 |
+
return
|
| 48 |
+
|
| 49 |
+
new_dotlist = []
|
| 50 |
+
for key, value in zip(dotlist[::2], dotlist[1::2]):
|
| 51 |
+
new_dotlist.append(f"{key}={value}")
|
| 52 |
+
cfg.merge_with_dotlist(new_dotlist)
|
draco/configuration/configurable.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import inspect
|
| 3 |
+
from typing import Any, Callable
|
| 4 |
+
|
| 5 |
+
from omegaconf import DictConfig
|
| 6 |
+
|
| 7 |
+
__all__ = ["configurable"]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _called_with_cfg(*args, **kwargs) -> bool:
|
| 11 |
+
"""
|
| 12 |
+
Check if the function is called with a `DictConfig` as the first argument.
|
| 13 |
+
|
| 14 |
+
Returns:
|
| 15 |
+
(bool): whether the function is called with a `DictConfig` as the first argument.
|
| 16 |
+
Or the `cfg` keyword argument is a `DictConfig`.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
if len(args) > 0 and isinstance(args[0], DictConfig):
|
| 20 |
+
return True
|
| 21 |
+
if isinstance(kwargs.get("cfg", None), DictConfig):
|
| 22 |
+
return True
|
| 23 |
+
return False
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _get_args_from_cfg(from_config_func: Callable[[Any], dict[str, Any]], *args, **kwargs) -> dict[str, Any]:
|
| 27 |
+
"""
|
| 28 |
+
Get the input arguments of the decorated function from a `DictConfig` object.
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
(dict): The input arguments of the class `__init__` method.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
signature = inspect.signature(from_config_func)
|
| 35 |
+
if list(signature.parameters.keys())[0] != "cfg":
|
| 36 |
+
raise ValueError("The first argument of `{}` must be named as `cfg`.".format(from_config_func.__name__))
|
| 37 |
+
|
| 38 |
+
# Forwarding all arguments to `from_config`, if the arguments of `from_config` are only `*args` or `*kwargs`.
|
| 39 |
+
if any(param.kind in [param.VAR_POSITIONAL or param.VAR_KEYWORD] for param in signature.parameters.values()):
|
| 40 |
+
result = from_config_func(*args, **kwargs)
|
| 41 |
+
|
| 42 |
+
# If there is any positional arguments.
|
| 43 |
+
else:
|
| 44 |
+
positional_args_name = set(signature.parameters.keys())
|
| 45 |
+
extra_kwargs = {}
|
| 46 |
+
for name in kwargs.keys():
|
| 47 |
+
if name not in positional_args_name:
|
| 48 |
+
extra_kwargs[name] = kwargs.pop(name)
|
| 49 |
+
result = from_config_func(*args, **kwargs)
|
| 50 |
+
# These args are forwarded directly to `__init__` method.
|
| 51 |
+
result.update(extra_kwargs)
|
| 52 |
+
|
| 53 |
+
return result
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def configurable(init_func: Callable = None, *, from_config: Callable[[Any], dict[str, Any]] | None = None) -> Callable:
|
| 57 |
+
"""
|
| 58 |
+
A decorator of a function or a class `__init__` method,
|
| 59 |
+
to make it configurable by a `DictConfig` object.
|
| 60 |
+
|
| 61 |
+
Example:
|
| 62 |
+
```python
|
| 63 |
+
# 1. Decorate a function.
|
| 64 |
+
@configurable(from_config=lambda cfg: { "x": cfg.x })
|
| 65 |
+
def func(x, y=2, z=3):
|
| 66 |
+
pass
|
| 67 |
+
|
| 68 |
+
a1 = func(x=1, y=2) # Call with regular args.
|
| 69 |
+
a2 = func(cfg) # Call with a `DictConfig` object.
|
| 70 |
+
a3 = func(cfg, y=2, z=3) # Call with a `DictConfig` object and regular arguments.
|
| 71 |
+
|
| 72 |
+
# 2. Decorate a class `__init__` method.
|
| 73 |
+
class A:
|
| 74 |
+
@configurable
|
| 75 |
+
def __init__(self, *args, **kwargs) -> None:
|
| 76 |
+
pass
|
| 77 |
+
|
| 78 |
+
@classmethod
|
| 79 |
+
def from_config(cls, cfg) -> dict:
|
| 80 |
+
pass
|
| 81 |
+
|
| 82 |
+
a1 = A(x, y) # Call with regular constructor.
|
| 83 |
+
a2 = A(cfg) # Call with a `DictConfig` object.
|
| 84 |
+
a3 = A(cfg, x, y) # Call with a `DictConfig` object and regular arguments.
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
`init_func` (callable): a function or a class method.
|
| 89 |
+
`from_config` (callable): a function that converts a `DictConfig` to the
|
| 90 |
+
input arguments of the decorated function.
|
| 91 |
+
It is always required.
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
# Decorating a function
|
| 95 |
+
if init_func is None:
|
| 96 |
+
# Prevent common misuse: `@configurable()`.
|
| 97 |
+
if from_config is None:
|
| 98 |
+
return configurable
|
| 99 |
+
|
| 100 |
+
assert inspect.isfunction(from_config), "`from_config` must be a function."
|
| 101 |
+
|
| 102 |
+
def wrapper(func):
|
| 103 |
+
@functools.wraps(func)
|
| 104 |
+
def wrapped(*args, **kwargs):
|
| 105 |
+
if _called_with_cfg(*args, **kwargs):
|
| 106 |
+
explicit_args = _get_args_from_cfg(from_config, *args, **kwargs)
|
| 107 |
+
return func(**explicit_args)
|
| 108 |
+
else:
|
| 109 |
+
return func(*args, **kwargs)
|
| 110 |
+
|
| 111 |
+
wrapped.from_config = from_config
|
| 112 |
+
return wrapped
|
| 113 |
+
|
| 114 |
+
return wrapper
|
| 115 |
+
|
| 116 |
+
# Decorating a class `__init__` method
|
| 117 |
+
else:
|
| 118 |
+
assert(
|
| 119 |
+
inspect.isfunction(init_func) and from_config is None and init_func.__name__ == "__init__"
|
| 120 |
+
), "Invalid usage of @configurable."
|
| 121 |
+
|
| 122 |
+
@functools.wraps(init_func)
|
| 123 |
+
def wrapped(self, *args, **kwargs):
|
| 124 |
+
try:
|
| 125 |
+
from_config_func = getattr(self, "from_config")
|
| 126 |
+
except AttributeError as e:
|
| 127 |
+
raise AttributeError("Class with `@configurable` should have a `from_config` classmethod.") from e
|
| 128 |
+
|
| 129 |
+
if not inspect.ismethod(from_config_func):
|
| 130 |
+
raise AttributeError("Class with `@configurable` should have a `from_config` classmethod.")
|
| 131 |
+
|
| 132 |
+
if _called_with_cfg(*args, **kwargs):
|
| 133 |
+
explicit_args = _get_args_from_cfg(from_config_func, *args, **kwargs)
|
| 134 |
+
init_func(self, **explicit_args)
|
| 135 |
+
else:
|
| 136 |
+
init_func(self, *args, **kwargs)
|
| 137 |
+
|
| 138 |
+
return wrapped
|
draco/configuration/draco2d-b_triplet_pretrain.yaml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_BASE_: base.yaml
|
| 2 |
+
|
| 3 |
+
MODEL:
|
| 4 |
+
NAME: DenoisingReconstructionAutoencoderVisionTransformer2d
|
| 5 |
+
|
| 6 |
+
IMG_SIZE: 256
|
| 7 |
+
PATCH_SIZE: 16
|
| 8 |
+
IN_CHANS: 1
|
| 9 |
+
VIT_SCALE: base
|
| 10 |
+
DYNAMIC_IMG_SIZE: False
|
| 11 |
+
DYNAMIC_IMG_PAD: False
|
| 12 |
+
USE_ABS_POS: True
|
| 13 |
+
DECODER_EMBED_DIM: 512
|
| 14 |
+
DECODER_DEPTH: 8
|
| 15 |
+
DECODER_NUM_HEADS: 16
|
| 16 |
+
DECODER_USE_NECK: True
|
| 17 |
+
DECODER_NECK_DIM: 256
|
| 18 |
+
|
| 19 |
+
SEED: 0
|
| 20 |
+
OUTPUT_DIR: null
|
draco/configuration/draco2d-h_triplet_pretrain.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_BASE_: draco-b_imagenet_pretrain.yaml
|
| 2 |
+
|
| 3 |
+
MODEL:
|
| 4 |
+
PATCH_SIZE: 14
|
| 5 |
+
VIT_SCALE: huge
|
draco/configuration/draco2d-l_triplet_pretrain.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_BASE_: draco-b_imagenet_pretrain.yaml
|
| 2 |
+
|
| 3 |
+
MODEL:
|
| 4 |
+
VIT_SCALE: large
|
draco/model/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .build import MODEL_REGISTRY, build_model
|
| 2 |
+
from .checkpoint import load_pretrained
|
| 3 |
+
|
| 4 |
+
from .draco2d import DenoisingReconstructionAutoencoderVisionTransformer2d, DracoDenoiseAutoencoder
|
| 5 |
+
|
| 6 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
draco/model/build.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fvcore.common.registry import Registry
|
| 2 |
+
from omegaconf import DictConfig
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
__all__ = ["MODEL_REGISTRY", "build_model"]
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
MODEL_REGISTRY = Registry("MODEL")
|
| 9 |
+
MODEL_REGISTRY.__doc__ = "Registry for the model."
|
| 10 |
+
|
| 11 |
+
def build_model(cfg: DictConfig) -> torch.nn.Module:
|
| 12 |
+
"""
|
| 13 |
+
Build the model defined by `cfg.MODEL.NAME`.
|
| 14 |
+
It moves the model to the device defined by `cfg.MODEL.DEVICE`.
|
| 15 |
+
It does not load checkpoints from `cfg`.
|
| 16 |
+
"""
|
| 17 |
+
model_name = cfg.MODEL.NAME
|
| 18 |
+
try:
|
| 19 |
+
model = MODEL_REGISTRY.get(model_name)(cfg)
|
| 20 |
+
except KeyError as e:
|
| 21 |
+
raise KeyError(MODEL_REGISTRY) from e
|
| 22 |
+
return model
|
draco/model/checkpoint.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _strip_prefix_if_present(state_dict: dict[str, Any], prefix: str) -> None:
|
| 8 |
+
"""
|
| 9 |
+
Strip the prefix in metadata, if any.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
state_dict (OrderedDict): a state-dict to be loaded to the model.
|
| 13 |
+
prefix (str): prefix.
|
| 14 |
+
"""
|
| 15 |
+
keys = sorted(state_dict.keys())
|
| 16 |
+
if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
|
| 17 |
+
return
|
| 18 |
+
|
| 19 |
+
for key in keys:
|
| 20 |
+
newkey = key[len(prefix) :]
|
| 21 |
+
state_dict[newkey] = state_dict.pop(key)
|
| 22 |
+
|
| 23 |
+
# also strip the prefix in metadata, if any..
|
| 24 |
+
try:
|
| 25 |
+
metadata = state_dict._metadata # pyre-ignore
|
| 26 |
+
except AttributeError:
|
| 27 |
+
pass
|
| 28 |
+
else:
|
| 29 |
+
for key in list(metadata.keys()):
|
| 30 |
+
# for the metadata dict, the key can be:
|
| 31 |
+
# '': for the DDP module, which we want to remove.
|
| 32 |
+
# 'module': for the actual model.
|
| 33 |
+
# 'module.xx.xx': for the rest.
|
| 34 |
+
|
| 35 |
+
if len(key) == 0:
|
| 36 |
+
continue
|
| 37 |
+
newkey = key[len(prefix) :]
|
| 38 |
+
metadata[newkey] = metadata.pop(key)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_pretrained(model: torch.nn.Module, ckpt_path: Path, device: torch.device = "cuda") -> torch.nn.Module:
|
| 42 |
+
"""
|
| 43 |
+
Load the pre-trained model from the checkpoint file.
|
| 44 |
+
"""
|
| 45 |
+
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
|
| 46 |
+
|
| 47 |
+
if "state_dict" in ckpt:
|
| 48 |
+
checkpoint_state_dict = ckpt["state_dict"]
|
| 49 |
+
elif "model" in ckpt:
|
| 50 |
+
checkpoint_state_dict = ckpt["model"]
|
| 51 |
+
else:
|
| 52 |
+
checkpoint_state_dict = ckpt
|
| 53 |
+
|
| 54 |
+
_strip_prefix_if_present(checkpoint_state_dict, "module.") # for DistributedDataParallel
|
| 55 |
+
_strip_prefix_if_present(checkpoint_state_dict, "model.") # for PyTorch Lightning Module
|
| 56 |
+
_strip_prefix_if_present(checkpoint_state_dict, "_orig_mod.") # for torch.compile
|
| 57 |
+
|
| 58 |
+
msg = model.load_state_dict(checkpoint_state_dict, strict=False)
|
| 59 |
+
print(f"Loaded pre-trained model from {ckpt_path} with message: {msg}")
|
| 60 |
+
|
| 61 |
+
return model
|
draco/model/draco2d.py
ADDED
|
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
from typing import Any, Callable
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
from timm.layers import build_sincos2d_pos_embed, resample_abs_pos_embed_nhwc, PatchEmbed, Mlp, LayerType
|
| 6 |
+
from timm.models.vision_transformer import Block
|
| 7 |
+
from timm.models.vision_transformer_sam import Block as SAMBlock
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from draco.configuration import configurable
|
| 12 |
+
from .build import MODEL_REGISTRY
|
| 13 |
+
from .layer import LayerNorm2d
|
| 14 |
+
from .draco_base import DenoisingReconstructionAutoencoderVisionTransformerBase
|
| 15 |
+
from .utils.constant import get_vit_scale, get_global_attn_indexes
|
| 16 |
+
|
| 17 |
+
__all__ = ["DenoisingReconstructionAutoencoderVisionTransformer2d", "DracoDenoiseAutoencoder"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@MODEL_REGISTRY.register()
|
| 21 |
+
class DenoisingReconstructionAutoencoderVisionTransformer2d(DenoisingReconstructionAutoencoderVisionTransformerBase):
|
| 22 |
+
@configurable
|
| 23 |
+
def __init__(self, *,
|
| 24 |
+
img_size: int = 224,
|
| 25 |
+
patch_size: int = 16,
|
| 26 |
+
in_chans: int = 3,
|
| 27 |
+
embed_layer: Callable = PatchEmbed,
|
| 28 |
+
dynamic_img_size: bool = False,
|
| 29 |
+
dynamic_img_pad: bool = False,
|
| 30 |
+
use_abs_pos: bool = True,
|
| 31 |
+
block_fn: nn.Module = Block,
|
| 32 |
+
norm_layer: LayerType = partial(nn.LayerNorm, eps=1e-6),
|
| 33 |
+
act_layer: LayerType = nn.GELU,
|
| 34 |
+
mlp_layer: nn.Module = Mlp,
|
| 35 |
+
embed_dim: int = 768,
|
| 36 |
+
depth: int = 12,
|
| 37 |
+
num_heads: int = 12,
|
| 38 |
+
mlp_ratio: float = 4.0,
|
| 39 |
+
qkv_bias: bool = True,
|
| 40 |
+
qk_norm: bool = False,
|
| 41 |
+
decoder_block_fn: nn.Module = Block,
|
| 42 |
+
decoder_norm_layer: LayerType = partial(nn.LayerNorm, eps=1e-6),
|
| 43 |
+
decoder_act_layer: LayerType = nn.GELU,
|
| 44 |
+
decoder_mlp_layer: nn.Module = Mlp,
|
| 45 |
+
decoder_embed_dim: int = 512,
|
| 46 |
+
decoder_depth: int = 8,
|
| 47 |
+
decoder_num_heads: int = 16,
|
| 48 |
+
decoder_use_neck: bool = True,
|
| 49 |
+
decoder_neck_dim: int = 256,
|
| 50 |
+
) -> None:
|
| 51 |
+
super().__init__()
|
| 52 |
+
|
| 53 |
+
self.dynamic_img_size = dynamic_img_size
|
| 54 |
+
self.decoder_use_neck = decoder_use_neck
|
| 55 |
+
|
| 56 |
+
self.init_encoder(
|
| 57 |
+
img_size=img_size,
|
| 58 |
+
patch_size=patch_size,
|
| 59 |
+
in_chans=in_chans,
|
| 60 |
+
embed_layer=embed_layer,
|
| 61 |
+
dynamic_img_size=dynamic_img_size,
|
| 62 |
+
dynamic_img_pad=dynamic_img_pad,
|
| 63 |
+
use_abs_pos=use_abs_pos,
|
| 64 |
+
block_fn=block_fn,
|
| 65 |
+
norm_layer=norm_layer,
|
| 66 |
+
act_layer=act_layer,
|
| 67 |
+
mlp_layer=mlp_layer,
|
| 68 |
+
embed_dim=embed_dim,
|
| 69 |
+
depth=depth,
|
| 70 |
+
num_heads=num_heads,
|
| 71 |
+
mlp_ratio=mlp_ratio,
|
| 72 |
+
qkv_bias=qkv_bias,
|
| 73 |
+
qk_norm=qk_norm,
|
| 74 |
+
)
|
| 75 |
+
self.init_decoder(
|
| 76 |
+
patch_size=patch_size,
|
| 77 |
+
in_chans=in_chans,
|
| 78 |
+
embed_dim=embed_dim,
|
| 79 |
+
use_abs_pos=use_abs_pos,
|
| 80 |
+
decoder_block_fn=decoder_block_fn,
|
| 81 |
+
decoder_norm_layer=decoder_norm_layer,
|
| 82 |
+
decoder_act_layer=decoder_act_layer,
|
| 83 |
+
decoder_mlp_layer=decoder_mlp_layer,
|
| 84 |
+
decoder_embed_dim=decoder_embed_dim,
|
| 85 |
+
decoder_depth=decoder_depth,
|
| 86 |
+
decoder_num_heads=decoder_num_heads,
|
| 87 |
+
decoder_use_neck=decoder_use_neck,
|
| 88 |
+
decoder_neck_dim=decoder_neck_dim,
|
| 89 |
+
mlp_ratio=mlp_ratio,
|
| 90 |
+
qkv_bias=qkv_bias,
|
| 91 |
+
qk_norm=qk_norm,
|
| 92 |
+
)
|
| 93 |
+
self.init_weights(
|
| 94 |
+
grid_size=self.patch_embed.grid_size,
|
| 95 |
+
embed_dim=embed_dim,
|
| 96 |
+
decoder_embed_dim=decoder_embed_dim,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
@classmethod
|
| 100 |
+
def from_config(cls, cfg: DictConfig) -> dict[str, Any]:
|
| 101 |
+
embed_dim, depth, num_heads = get_vit_scale(cfg.MODEL.VIT_SCALE)
|
| 102 |
+
return {
|
| 103 |
+
"img_size": cfg.MODEL.IMG_SIZE,
|
| 104 |
+
"patch_size": cfg.MODEL.PATCH_SIZE,
|
| 105 |
+
"in_chans": cfg.MODEL.IN_CHANS,
|
| 106 |
+
"dynamic_img_size": cfg.MODEL.DYNAMIC_IMG_SIZE,
|
| 107 |
+
"dynamic_img_pad": cfg.MODEL.DYNAMIC_IMG_PAD,
|
| 108 |
+
"use_abs_pos": cfg.MODEL.USE_ABS_POS,
|
| 109 |
+
"embed_dim": embed_dim,
|
| 110 |
+
"depth": depth,
|
| 111 |
+
"num_heads": num_heads,
|
| 112 |
+
"decoder_embed_dim": cfg.MODEL.DECODER_EMBED_DIM,
|
| 113 |
+
"decoder_depth": cfg.MODEL.DECODER_DEPTH,
|
| 114 |
+
"decoder_num_heads": cfg.MODEL.DECODER_NUM_HEADS,
|
| 115 |
+
"decoder_use_neck": cfg.MODEL.DECODER_USE_NECK,
|
| 116 |
+
"decoder_neck_dim": cfg.MODEL.DECODER_NECK_DIM,
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
def init_encoder(self, *,
|
| 120 |
+
img_size: int,
|
| 121 |
+
patch_size: int,
|
| 122 |
+
in_chans: int,
|
| 123 |
+
embed_layer: Callable,
|
| 124 |
+
dynamic_img_size: bool,
|
| 125 |
+
dynamic_img_pad: bool,
|
| 126 |
+
use_abs_pos: bool,
|
| 127 |
+
block_fn: nn.Module,
|
| 128 |
+
norm_layer: LayerType | None,
|
| 129 |
+
act_layer: LayerType | None,
|
| 130 |
+
mlp_layer: nn.Module,
|
| 131 |
+
embed_dim: int,
|
| 132 |
+
depth: int,
|
| 133 |
+
num_heads: int,
|
| 134 |
+
mlp_ratio: float,
|
| 135 |
+
qkv_bias: bool,
|
| 136 |
+
qk_norm: bool,
|
| 137 |
+
) -> None:
|
| 138 |
+
embed_args = {}
|
| 139 |
+
if dynamic_img_size:
|
| 140 |
+
embed_args.update(dict(strict_img_size=False))
|
| 141 |
+
self.patch_embed = embed_layer(
|
| 142 |
+
img_size=img_size,
|
| 143 |
+
patch_size=patch_size,
|
| 144 |
+
in_chans=in_chans,
|
| 145 |
+
embed_dim=embed_dim,
|
| 146 |
+
dynamic_img_pad=dynamic_img_pad,
|
| 147 |
+
output_fmt="NHWC",
|
| 148 |
+
**embed_args
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, *self.patch_embed.grid_size, embed_dim)) if use_abs_pos else None
|
| 152 |
+
self.blocks = nn.ModuleList([
|
| 153 |
+
block_fn(
|
| 154 |
+
dim=embed_dim,
|
| 155 |
+
num_heads=num_heads,
|
| 156 |
+
mlp_ratio=mlp_ratio,
|
| 157 |
+
qkv_bias=qkv_bias,
|
| 158 |
+
qk_norm=qk_norm,
|
| 159 |
+
norm_layer=norm_layer,
|
| 160 |
+
act_layer=act_layer,
|
| 161 |
+
mlp_layer=mlp_layer,
|
| 162 |
+
) for _ in range(depth)
|
| 163 |
+
])
|
| 164 |
+
self.norm = norm_layer(embed_dim)
|
| 165 |
+
|
| 166 |
+
def init_decoder(self, *,
|
| 167 |
+
patch_size: int,
|
| 168 |
+
in_chans: int,
|
| 169 |
+
embed_dim: int,
|
| 170 |
+
use_abs_pos: bool,
|
| 171 |
+
decoder_block_fn: nn.Module,
|
| 172 |
+
decoder_norm_layer: LayerType | None,
|
| 173 |
+
decoder_act_layer: LayerType | None,
|
| 174 |
+
decoder_mlp_layer: nn.Module,
|
| 175 |
+
decoder_embed_dim: int,
|
| 176 |
+
decoder_depth: int,
|
| 177 |
+
decoder_num_heads: int,
|
| 178 |
+
decoder_use_neck: bool,
|
| 179 |
+
decoder_neck_dim: int,
|
| 180 |
+
mlp_ratio: float,
|
| 181 |
+
qkv_bias: bool,
|
| 182 |
+
qk_norm: bool,
|
| 183 |
+
) -> None:
|
| 184 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
| 185 |
+
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
|
| 186 |
+
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, *self.patch_embed.grid_size, decoder_embed_dim)) if use_abs_pos else None
|
| 187 |
+
self.decoder_blocks = nn.ModuleList([
|
| 188 |
+
decoder_block_fn(
|
| 189 |
+
dim=decoder_embed_dim,
|
| 190 |
+
num_heads=decoder_num_heads,
|
| 191 |
+
mlp_ratio=mlp_ratio,
|
| 192 |
+
qkv_bias=qkv_bias,
|
| 193 |
+
qk_norm=qk_norm,
|
| 194 |
+
norm_layer=decoder_norm_layer,
|
| 195 |
+
act_layer=decoder_act_layer,
|
| 196 |
+
mlp_layer=decoder_mlp_layer,
|
| 197 |
+
) for _ in range(decoder_depth)
|
| 198 |
+
])
|
| 199 |
+
self.decoder_norm = decoder_norm_layer(decoder_embed_dim)
|
| 200 |
+
if decoder_use_neck:
|
| 201 |
+
self.decoder_neck = nn.Sequential(
|
| 202 |
+
nn.Conv2d(
|
| 203 |
+
in_channels=decoder_embed_dim,
|
| 204 |
+
out_channels=decoder_neck_dim,
|
| 205 |
+
kernel_size=1,
|
| 206 |
+
bias=False,
|
| 207 |
+
),
|
| 208 |
+
LayerNorm2d(decoder_neck_dim),
|
| 209 |
+
decoder_act_layer(),
|
| 210 |
+
nn.Conv2d(
|
| 211 |
+
in_channels=decoder_neck_dim,
|
| 212 |
+
out_channels=decoder_neck_dim,
|
| 213 |
+
kernel_size=3,
|
| 214 |
+
padding=1,
|
| 215 |
+
bias=False,
|
| 216 |
+
),
|
| 217 |
+
LayerNorm2d(decoder_neck_dim),
|
| 218 |
+
decoder_act_layer(),
|
| 219 |
+
nn.Conv2d(
|
| 220 |
+
in_channels=decoder_neck_dim,
|
| 221 |
+
out_channels=decoder_embed_dim,
|
| 222 |
+
kernel_size=1,
|
| 223 |
+
bias=False,
|
| 224 |
+
),
|
| 225 |
+
LayerNorm2d(decoder_embed_dim),
|
| 226 |
+
)
|
| 227 |
+
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans)
|
| 228 |
+
|
| 229 |
+
def init_weights(self, *,
|
| 230 |
+
grid_size: tuple[int, int],
|
| 231 |
+
embed_dim: int,
|
| 232 |
+
decoder_embed_dim: int
|
| 233 |
+
) -> None:
|
| 234 |
+
w = self.patch_embed.proj.weight.data
|
| 235 |
+
torch.nn.init.xavier_uniform_(w.view(w.size(0), -1))
|
| 236 |
+
|
| 237 |
+
torch.nn.init.normal_(self.mask_token, std=0.02)
|
| 238 |
+
|
| 239 |
+
if self.pos_embed is not None:
|
| 240 |
+
self.pos_embed.data.copy_(build_sincos2d_pos_embed(
|
| 241 |
+
feat_shape=grid_size,
|
| 242 |
+
dim=embed_dim,
|
| 243 |
+
interleave_sin_cos=True
|
| 244 |
+
).reshape(1, *grid_size, -1).transpose(1, 2))
|
| 245 |
+
|
| 246 |
+
if self.decoder_pos_embed is not None:
|
| 247 |
+
self.decoder_pos_embed.data.copy_(build_sincos2d_pos_embed(
|
| 248 |
+
feat_shape=grid_size,
|
| 249 |
+
dim=decoder_embed_dim,
|
| 250 |
+
interleave_sin_cos=True
|
| 251 |
+
).reshape(1, *grid_size, -1).transpose(1, 2))
|
| 252 |
+
|
| 253 |
+
if self.decoder_use_neck:
|
| 254 |
+
for m in self.decoder_neck.modules():
|
| 255 |
+
if isinstance(m, nn.Conv2d):
|
| 256 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 257 |
+
if m.bias is not None:
|
| 258 |
+
nn.init.zeros_(m.bias)
|
| 259 |
+
nn.init.zeros_(self.decoder_neck[-1].weight)
|
| 260 |
+
nn.init.zeros_(self.decoder_neck[-1].bias)
|
| 261 |
+
|
| 262 |
+
self.apply(self._init_weights)
|
| 263 |
+
|
| 264 |
+
def _init_weights(self, module: nn.Module) -> None:
|
| 265 |
+
if isinstance(module, nn.Linear):
|
| 266 |
+
nn.init.xavier_uniform_(module.weight)
|
| 267 |
+
if module.bias is not None:
|
| 268 |
+
nn.init.constant_(module.bias, 0.0)
|
| 269 |
+
|
| 270 |
+
def forward_encoder(self, x: torch.Tensor, mask_ratio: float) -> tuple[torch.Tensor, torch.BoolTensor, int, int]:
|
| 271 |
+
x = self.patch_embed(x)
|
| 272 |
+
B, H, W, E = x.shape
|
| 273 |
+
if self.pos_embed is not None:
|
| 274 |
+
x = x + resample_abs_pos_embed_nhwc(self.pos_embed, (H, W))
|
| 275 |
+
x = x.view(B, -1, E)
|
| 276 |
+
|
| 277 |
+
mask = super().random_masking(x, mask_ratio)
|
| 278 |
+
x = x[~mask].reshape(B, -1, E)
|
| 279 |
+
|
| 280 |
+
for block in self.blocks:
|
| 281 |
+
x = block(x)
|
| 282 |
+
x = self.norm(x)
|
| 283 |
+
|
| 284 |
+
return x, mask, H, W
|
| 285 |
+
|
| 286 |
+
def forward_decoder(self, x: torch.Tensor, mask: torch.BoolTensor, H: int, W: int) -> torch.Tensor:
|
| 287 |
+
x = self.decoder_embed(x)
|
| 288 |
+
|
| 289 |
+
B, L = mask.shape
|
| 290 |
+
E = x.shape[-1]
|
| 291 |
+
mask_tokens = self.mask_token.repeat(B, L, 1).to(x.dtype)
|
| 292 |
+
mask_tokens[~mask] = x.reshape(-1, E)
|
| 293 |
+
x = mask_tokens
|
| 294 |
+
|
| 295 |
+
if self.decoder_pos_embed is not None:
|
| 296 |
+
x = x.view(B, H, W, E)
|
| 297 |
+
x = x + resample_abs_pos_embed_nhwc(self.decoder_pos_embed, (H, W))
|
| 298 |
+
x = x.view(B, -1, E)
|
| 299 |
+
|
| 300 |
+
for block in self.decoder_blocks:
|
| 301 |
+
x = block(x)
|
| 302 |
+
x = self.decoder_norm(x)
|
| 303 |
+
if self.decoder_use_neck:
|
| 304 |
+
x = x + self.decoder_neck(
|
| 305 |
+
x.permute(0, 2, 1).reshape(B, E, H, W).contiguous()
|
| 306 |
+
).permute(0, 2, 3, 1).reshape(B, L, -1).contiguous()
|
| 307 |
+
x = self.decoder_pred(x)
|
| 308 |
+
|
| 309 |
+
return x
|
| 310 |
+
|
| 311 |
+
def forward(self, x: torch.Tensor, mask_ratio: float) -> tuple[torch.Tensor, torch.BoolTensor]:
|
| 312 |
+
x, mask, H, W = self.forward_encoder(x, mask_ratio)
|
| 313 |
+
x = self.forward_decoder(x, mask, H, W)
|
| 314 |
+
return x, mask
|
| 315 |
+
|
| 316 |
+
@MODEL_REGISTRY.register()
|
| 317 |
+
class DracoDenoiseAutoencoder(DenoisingReconstructionAutoencoderVisionTransformerBase):
|
| 318 |
+
"""
|
| 319 |
+
Masked Autoencoder (MAE) with Vision Transformer backbone.
|
| 320 |
+
Note that `cls_token` is discarded.
|
| 321 |
+
"""
|
| 322 |
+
|
| 323 |
+
@configurable
|
| 324 |
+
def __init__(self, *,
|
| 325 |
+
img_size: int = 224,
|
| 326 |
+
patch_size: int = 16,
|
| 327 |
+
in_chans: int = 3,
|
| 328 |
+
embed_layer: Callable = PatchEmbed,
|
| 329 |
+
dynamic_img_size: bool = False,
|
| 330 |
+
dynamic_img_pad: bool = False,
|
| 331 |
+
use_abs_pos: bool = True,
|
| 332 |
+
block_fn: nn.Module = SAMBlock,
|
| 333 |
+
norm_layer: LayerType = partial(nn.LayerNorm, eps=1e-6),
|
| 334 |
+
act_layer: LayerType = nn.GELU,
|
| 335 |
+
mlp_layer: nn.Module = Mlp,
|
| 336 |
+
embed_dim: int = 768,
|
| 337 |
+
depth: int = 12,
|
| 338 |
+
num_heads: int = 12,
|
| 339 |
+
mlp_ratio: float = 4.0,
|
| 340 |
+
qkv_bias: bool = True,
|
| 341 |
+
qk_norm: bool = False,
|
| 342 |
+
window_size: int = 16,
|
| 343 |
+
global_attn_indexes: list[int] = [2, 5, 8, 11],
|
| 344 |
+
decoder_block_fn: nn.Module = SAMBlock,
|
| 345 |
+
decoder_norm_layer: LayerType = partial(nn.LayerNorm, eps=1e-6),
|
| 346 |
+
decoder_act_layer: LayerType = nn.GELU,
|
| 347 |
+
decoder_mlp_layer: nn.Module = Mlp,
|
| 348 |
+
decoder_embed_dim: int = 512,
|
| 349 |
+
decoder_depth: int = 8,
|
| 350 |
+
decoder_num_heads: int = 16,
|
| 351 |
+
decoder_use_neck: bool = True,
|
| 352 |
+
decoder_neck_dim: int = 256,
|
| 353 |
+
decoder_global_attn_indexes: list[int] = [3, 7],
|
| 354 |
+
) -> None:
|
| 355 |
+
super().__init__()
|
| 356 |
+
|
| 357 |
+
self.dynamic_img_size = dynamic_img_size
|
| 358 |
+
self.decoder_use_neck = decoder_use_neck
|
| 359 |
+
|
| 360 |
+
self.init_encoder(
|
| 361 |
+
img_size=img_size,
|
| 362 |
+
patch_size=patch_size,
|
| 363 |
+
in_chans=in_chans,
|
| 364 |
+
embed_layer=embed_layer,
|
| 365 |
+
dynamic_img_size=dynamic_img_size,
|
| 366 |
+
dynamic_img_pad=dynamic_img_pad,
|
| 367 |
+
use_abs_pos=use_abs_pos,
|
| 368 |
+
block_fn=block_fn,
|
| 369 |
+
norm_layer=norm_layer,
|
| 370 |
+
act_layer=act_layer,
|
| 371 |
+
mlp_layer=mlp_layer,
|
| 372 |
+
embed_dim=embed_dim,
|
| 373 |
+
depth=depth,
|
| 374 |
+
num_heads=num_heads,
|
| 375 |
+
mlp_ratio=mlp_ratio,
|
| 376 |
+
qkv_bias=qkv_bias,
|
| 377 |
+
qk_norm=qk_norm,
|
| 378 |
+
window_size=window_size,
|
| 379 |
+
global_attn_indexes=global_attn_indexes
|
| 380 |
+
)
|
| 381 |
+
self.init_decoder(
|
| 382 |
+
img_size=img_size,
|
| 383 |
+
patch_size=patch_size,
|
| 384 |
+
in_chans=in_chans,
|
| 385 |
+
embed_dim=embed_dim,
|
| 386 |
+
use_abs_pos=use_abs_pos,
|
| 387 |
+
decoder_block_fn=decoder_block_fn,
|
| 388 |
+
decoder_norm_layer=decoder_norm_layer,
|
| 389 |
+
decoder_act_layer=decoder_act_layer,
|
| 390 |
+
decoder_mlp_layer=decoder_mlp_layer,
|
| 391 |
+
decoder_embed_dim=decoder_embed_dim,
|
| 392 |
+
decoder_depth=decoder_depth,
|
| 393 |
+
decoder_num_heads=decoder_num_heads,
|
| 394 |
+
decoder_use_neck=decoder_use_neck,
|
| 395 |
+
decoder_neck_dim=decoder_neck_dim,
|
| 396 |
+
mlp_ratio=mlp_ratio,
|
| 397 |
+
qkv_bias=qkv_bias,
|
| 398 |
+
qk_norm=qk_norm,
|
| 399 |
+
window_size=window_size,
|
| 400 |
+
decoder_global_attn_indexes=decoder_global_attn_indexes
|
| 401 |
+
)
|
| 402 |
+
self.init_weights(
|
| 403 |
+
grid_size=self.patch_embed.grid_size,
|
| 404 |
+
embed_dim=embed_dim,
|
| 405 |
+
decoder_embed_dim=decoder_embed_dim,
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
@classmethod
|
| 409 |
+
def from_config(cls, cfg: DictConfig) -> dict[str, Any]:
|
| 410 |
+
embed_dim, depth, num_heads = get_vit_scale(cfg.MODEL.VIT_SCALE)
|
| 411 |
+
global_attn_indexes = get_global_attn_indexes(depth)
|
| 412 |
+
return {
|
| 413 |
+
"img_size": cfg.MODEL.IMG_SIZE,
|
| 414 |
+
"patch_size": cfg.MODEL.PATCH_SIZE,
|
| 415 |
+
"in_chans": cfg.MODEL.IN_CHANS,
|
| 416 |
+
"dynamic_img_size": cfg.MODEL.DYNAMIC_IMG_SIZE,
|
| 417 |
+
"dynamic_img_pad": cfg.MODEL.DYNAMIC_IMG_PAD,
|
| 418 |
+
"use_abs_pos": cfg.MODEL.USE_ABS_POS,
|
| 419 |
+
"embed_dim": embed_dim,
|
| 420 |
+
"depth": depth,
|
| 421 |
+
"num_heads": num_heads,
|
| 422 |
+
"window_size": cfg.MODEL.WINDOW_SIZE,
|
| 423 |
+
"global_attn_indexes": global_attn_indexes,
|
| 424 |
+
"decoder_embed_dim": cfg.MODEL.DECODER_EMBED_DIM,
|
| 425 |
+
"decoder_depth": cfg.MODEL.DECODER_DEPTH,
|
| 426 |
+
"decoder_num_heads": cfg.MODEL.DECODER_NUM_HEADS,
|
| 427 |
+
"decoder_use_neck": cfg.MODEL.DECODER_USE_NECK,
|
| 428 |
+
"decoder_neck_dim": cfg.MODEL.DECODER_NECK_DIM,
|
| 429 |
+
"decoder_global_attn_indexes": cfg.MODEL.DECODER_GLOBAL_ATTN_INDEXES,
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
def init_encoder(self, *,
|
| 433 |
+
img_size: int,
|
| 434 |
+
patch_size: int,
|
| 435 |
+
in_chans: int,
|
| 436 |
+
embed_layer: Callable,
|
| 437 |
+
dynamic_img_size: bool,
|
| 438 |
+
dynamic_img_pad: bool,
|
| 439 |
+
use_abs_pos: bool,
|
| 440 |
+
block_fn: nn.Module,
|
| 441 |
+
norm_layer: LayerType | None,
|
| 442 |
+
act_layer: LayerType | None,
|
| 443 |
+
mlp_layer: nn.Module,
|
| 444 |
+
embed_dim: int,
|
| 445 |
+
depth: int,
|
| 446 |
+
num_heads: int,
|
| 447 |
+
mlp_ratio: float,
|
| 448 |
+
qkv_bias: bool,
|
| 449 |
+
qk_norm: bool,
|
| 450 |
+
window_size: int,
|
| 451 |
+
global_attn_indexes: list,
|
| 452 |
+
) -> None:
|
| 453 |
+
embed_args = {}
|
| 454 |
+
if dynamic_img_size:
|
| 455 |
+
# flatten deferred until after pos embed
|
| 456 |
+
embed_args.update(dict(strict_img_size=False))
|
| 457 |
+
self.patch_embed = embed_layer(
|
| 458 |
+
img_size=img_size,
|
| 459 |
+
patch_size=patch_size,
|
| 460 |
+
in_chans=in_chans,
|
| 461 |
+
embed_dim=embed_dim,
|
| 462 |
+
dynamic_img_pad=dynamic_img_pad,
|
| 463 |
+
output_fmt="NHWC",
|
| 464 |
+
**embed_args
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, *self.patch_embed.grid_size, embed_dim)) if use_abs_pos else None
|
| 468 |
+
self.blocks = nn.ModuleList(
|
| 469 |
+
block_fn(
|
| 470 |
+
dim=embed_dim,
|
| 471 |
+
num_heads=num_heads,
|
| 472 |
+
mlp_ratio=mlp_ratio,
|
| 473 |
+
qkv_bias=qkv_bias,
|
| 474 |
+
qk_norm=qk_norm,
|
| 475 |
+
norm_layer=norm_layer,
|
| 476 |
+
act_layer=act_layer,
|
| 477 |
+
mlp_layer=mlp_layer,
|
| 478 |
+
use_rel_pos=True,
|
| 479 |
+
window_size=window_size if i not in global_attn_indexes else 0,
|
| 480 |
+
input_size=(img_size // patch_size, img_size // patch_size),
|
| 481 |
+
) for i in range(depth)
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
self.norm = norm_layer(embed_dim)
|
| 485 |
+
|
| 486 |
+
def init_decoder(self, *,
|
| 487 |
+
img_size: int,
|
| 488 |
+
patch_size: int,
|
| 489 |
+
in_chans: int,
|
| 490 |
+
embed_dim: int,
|
| 491 |
+
use_abs_pos: bool,
|
| 492 |
+
decoder_block_fn: nn.Module,
|
| 493 |
+
decoder_norm_layer: LayerType | None,
|
| 494 |
+
decoder_act_layer: LayerType | None,
|
| 495 |
+
decoder_mlp_layer: nn.Module,
|
| 496 |
+
decoder_embed_dim: int,
|
| 497 |
+
decoder_depth: int,
|
| 498 |
+
decoder_num_heads: int,
|
| 499 |
+
decoder_use_neck: bool,
|
| 500 |
+
decoder_neck_dim: int,
|
| 501 |
+
mlp_ratio: float,
|
| 502 |
+
qkv_bias: bool,
|
| 503 |
+
qk_norm: bool,
|
| 504 |
+
window_size: int,
|
| 505 |
+
decoder_global_attn_indexes: list[int]
|
| 506 |
+
) -> None:
|
| 507 |
+
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
|
| 508 |
+
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, *self.patch_embed.grid_size, decoder_embed_dim)) if use_abs_pos else None
|
| 509 |
+
self.decoder_blocks = nn.ModuleList(
|
| 510 |
+
decoder_block_fn(
|
| 511 |
+
dim=decoder_embed_dim,
|
| 512 |
+
num_heads=decoder_num_heads,
|
| 513 |
+
mlp_ratio=mlp_ratio,
|
| 514 |
+
qkv_bias=qkv_bias,
|
| 515 |
+
qk_norm=qk_norm,
|
| 516 |
+
norm_layer=decoder_norm_layer,
|
| 517 |
+
act_layer=decoder_act_layer,
|
| 518 |
+
mlp_layer=decoder_mlp_layer,
|
| 519 |
+
use_rel_pos=True,
|
| 520 |
+
window_size=window_size if i not in decoder_global_attn_indexes else 0,
|
| 521 |
+
input_size=(img_size // patch_size, img_size // patch_size),
|
| 522 |
+
) for i in range(decoder_depth)
|
| 523 |
+
)
|
| 524 |
+
self.decoder_norm = decoder_norm_layer(decoder_embed_dim)
|
| 525 |
+
if decoder_use_neck:
|
| 526 |
+
self.decoder_neck = nn.Sequential(
|
| 527 |
+
nn.Conv2d(
|
| 528 |
+
in_channels=decoder_embed_dim,
|
| 529 |
+
out_channels=decoder_neck_dim,
|
| 530 |
+
kernel_size=1,
|
| 531 |
+
bias=False,
|
| 532 |
+
),
|
| 533 |
+
LayerNorm2d(decoder_neck_dim),
|
| 534 |
+
decoder_act_layer(),
|
| 535 |
+
nn.Conv2d(
|
| 536 |
+
in_channels=decoder_neck_dim,
|
| 537 |
+
out_channels=decoder_neck_dim,
|
| 538 |
+
kernel_size=3,
|
| 539 |
+
padding=1,
|
| 540 |
+
bias=False,
|
| 541 |
+
),
|
| 542 |
+
LayerNorm2d(decoder_neck_dim),
|
| 543 |
+
decoder_act_layer(),
|
| 544 |
+
nn.Conv2d(
|
| 545 |
+
in_channels=decoder_neck_dim,
|
| 546 |
+
out_channels=decoder_embed_dim,
|
| 547 |
+
kernel_size=1,
|
| 548 |
+
bias=False,
|
| 549 |
+
),
|
| 550 |
+
LayerNorm2d(decoder_embed_dim),
|
| 551 |
+
)
|
| 552 |
+
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans)
|
| 553 |
+
|
| 554 |
+
def init_weights(self, *,
|
| 555 |
+
grid_size: tuple[int, int],
|
| 556 |
+
embed_dim: int,
|
| 557 |
+
decoder_embed_dim: int
|
| 558 |
+
) -> None:
|
| 559 |
+
w = self.patch_embed.proj.weight.data
|
| 560 |
+
torch.nn.init.xavier_uniform_(w.view(w.size(0), -1))
|
| 561 |
+
|
| 562 |
+
if self.pos_embed is not None:
|
| 563 |
+
self.pos_embed.data.copy_(build_sincos2d_pos_embed(
|
| 564 |
+
feat_shape=grid_size,
|
| 565 |
+
dim=embed_dim,
|
| 566 |
+
interleave_sin_cos=True
|
| 567 |
+
).reshape(1, *grid_size, -1).transpose(1, 2))
|
| 568 |
+
|
| 569 |
+
if self.decoder_pos_embed is not None:
|
| 570 |
+
self.decoder_pos_embed.data.copy_(build_sincos2d_pos_embed(
|
| 571 |
+
feat_shape=grid_size,
|
| 572 |
+
dim=decoder_embed_dim,
|
| 573 |
+
interleave_sin_cos=True
|
| 574 |
+
).reshape(1, *grid_size, -1).transpose(1, 2))
|
| 575 |
+
|
| 576 |
+
# Zero-initialize the neck
|
| 577 |
+
if self.decoder_use_neck:
|
| 578 |
+
for m in self.decoder_neck.modules():
|
| 579 |
+
if isinstance(m, nn.Conv2d):
|
| 580 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 581 |
+
if m.bias is not None:
|
| 582 |
+
nn.init.zeros_(m.bias)
|
| 583 |
+
nn.init.zeros_(self.decoder_neck[-1].weight)
|
| 584 |
+
nn.init.zeros_(self.decoder_neck[-1].bias)
|
| 585 |
+
|
| 586 |
+
self.apply(self._init_weights)
|
| 587 |
+
|
| 588 |
+
def _init_weights(self, module: nn.Module) -> None:
|
| 589 |
+
if isinstance(module, nn.Linear):
|
| 590 |
+
nn.init.xavier_uniform_(module.weight)
|
| 591 |
+
if module.bias is not None:
|
| 592 |
+
nn.init.constant_(module.bias, 0.0)
|
| 593 |
+
|
| 594 |
+
def forward_encoder(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
|
| 595 |
+
"""
|
| 596 |
+
Forward pass of the encoder.
|
| 597 |
+
|
| 598 |
+
Args:
|
| 599 |
+
`x` (torch.Tensor): Image of shape [B, C, H, W].
|
| 600 |
+
|
| 601 |
+
Returns:
|
| 602 |
+
(torch.Tensor): Encoded image of shape [B, num_kept, E].
|
| 603 |
+
(int): Height of the encoded tokens.
|
| 604 |
+
(int): Width of the encoded tokens.
|
| 605 |
+
"""
|
| 606 |
+
x = self.patch_embed(x)
|
| 607 |
+
B, H, W, E = x.shape
|
| 608 |
+
|
| 609 |
+
if self.pos_embed is not None:
|
| 610 |
+
x = x + resample_abs_pos_embed_nhwc(self.pos_embed, (H, W))
|
| 611 |
+
|
| 612 |
+
for block in self.blocks:
|
| 613 |
+
x = block(x)
|
| 614 |
+
|
| 615 |
+
x = x.view(B, -1, E)
|
| 616 |
+
x = self.norm(x)
|
| 617 |
+
|
| 618 |
+
return x, H, W
|
| 619 |
+
|
| 620 |
+
def forward_decoder(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
|
| 621 |
+
"""
|
| 622 |
+
Forward pass of the decoder.
|
| 623 |
+
|
| 624 |
+
Args:
|
| 625 |
+
`x` (torch.Tensor): Encoded image of shape [B, num_kept, E].
|
| 626 |
+
`H` (int): Height of the encoded tokens.
|
| 627 |
+
`W` (int): Width of the encoded tokens.
|
| 628 |
+
|
| 629 |
+
Returns:
|
| 630 |
+
(torch.Tensor): Decoded image of shape [B, L, E].
|
| 631 |
+
"""
|
| 632 |
+
x = self.decoder_embed(x) # [B, num_kept, E]
|
| 633 |
+
B, L, E = x.shape
|
| 634 |
+
|
| 635 |
+
if self.decoder_pos_embed is not None:
|
| 636 |
+
x = x.view(B, H, W, E)
|
| 637 |
+
x = x + resample_abs_pos_embed_nhwc(self.decoder_pos_embed, (H, W))
|
| 638 |
+
|
| 639 |
+
for block in self.decoder_blocks:
|
| 640 |
+
x = block(x)
|
| 641 |
+
x = x.view(B, -1, E)
|
| 642 |
+
|
| 643 |
+
x = self.decoder_norm(x)
|
| 644 |
+
if self.decoder_use_neck:
|
| 645 |
+
x = x + self.decoder_neck(
|
| 646 |
+
x.permute(0, 2, 1).reshape(B, E, H, W).contiguous()
|
| 647 |
+
).permute(0, 2, 3, 1).reshape(B, L, -1).contiguous()
|
| 648 |
+
x = self.decoder_pred(x)
|
| 649 |
+
|
| 650 |
+
return x
|
| 651 |
+
|
| 652 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 653 |
+
"""
|
| 654 |
+
Args:
|
| 655 |
+
`x` (torch.Tensor): Image of shape [B, C, H, W].
|
| 656 |
+
|
| 657 |
+
Returns:
|
| 658 |
+
(torch.Tensor): The prediction of shape [B, L, E].
|
| 659 |
+
"""
|
| 660 |
+
x, H, W = self.forward_encoder(x)
|
| 661 |
+
x = self.forward_decoder(x, H, W)
|
| 662 |
+
return x
|
| 663 |
+
|
draco/model/draco_base.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABCMeta, abstractmethod
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class DenoisingReconstructionAutoencoderVisionTransformerBase(nn.Module, metaclass=ABCMeta):
|
| 8 |
+
def __init__(self) -> None:
|
| 9 |
+
super().__init__()
|
| 10 |
+
|
| 11 |
+
@torch.jit.ignore
|
| 12 |
+
def no_weight_decay(self) -> set:
|
| 13 |
+
return {"cls_token"}
|
| 14 |
+
|
| 15 |
+
@torch.jit.ignore
|
| 16 |
+
def group_matcher(self, coarse: bool = False) -> dict:
|
| 17 |
+
return dict(
|
| 18 |
+
stem=r'^(?:_orig_mod\.)?cls_token|^(?:_orig_mod\.)?pos_embed|^(?:_orig_mod\.)?patch_embed',
|
| 19 |
+
blocks=[(r'^(?:_orig_mod\.)?blocks\.(\d+)', None), (r'^(?:_orig_mod\.)?norm', (99999,))]
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
@classmethod
|
| 23 |
+
def random_masking(cls, x: torch.Tensor, mask_ratio: float) -> torch.BoolTensor:
|
| 24 |
+
B, L = x.shape[:2]
|
| 25 |
+
num_masked = int(L * mask_ratio)
|
| 26 |
+
|
| 27 |
+
noise = torch.rand(B, L, device=x.device)
|
| 28 |
+
rank = noise.argsort(dim=1)
|
| 29 |
+
mask = rank < num_masked
|
| 30 |
+
|
| 31 |
+
return mask
|
| 32 |
+
|
| 33 |
+
@abstractmethod
|
| 34 |
+
def forward(self) -> None:
|
| 35 |
+
raise NotImplementedError
|
draco/model/layer/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .normalization import LayerNorm2d
|
| 2 |
+
|
| 3 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
draco/model/layer/normalization.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"LayerNorm2d",
|
| 6 |
+
]
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LayerNorm2d(nn.Module):
|
| 10 |
+
def __init__(self, num_features: int, eps: float = 1e-6) -> None:
|
| 11 |
+
super().__init__()
|
| 12 |
+
|
| 13 |
+
self.weight = nn.Parameter(torch.ones(num_features))
|
| 14 |
+
self.bias = nn.Parameter(torch.zeros(num_features))
|
| 15 |
+
self.eps = eps
|
| 16 |
+
|
| 17 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 18 |
+
u = x.mean(1, keepdim=True)
|
| 19 |
+
s = (x - u).square().mean(1, keepdim=True)
|
| 20 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 21 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
| 22 |
+
return x
|
draco/model/utils/constant.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def get_vit_scale(scale: str) -> tuple[int, int, int]:
|
| 2 |
+
if scale == "tiny":
|
| 3 |
+
return 192, 12, 3
|
| 4 |
+
elif scale == "small":
|
| 5 |
+
return 384, 12, 6
|
| 6 |
+
elif scale == "base":
|
| 7 |
+
return 768, 12, 12
|
| 8 |
+
elif scale == "large":
|
| 9 |
+
return 1024, 24, 16
|
| 10 |
+
elif scale == "huge":
|
| 11 |
+
return 1280, 32, 16
|
| 12 |
+
else:
|
| 13 |
+
raise KeyError(f"Unknown Vision Transformer scale: {scale}")
|
| 14 |
+
|
| 15 |
+
def get_global_attn_indexes(num_layers: int) -> list[int]:
|
| 16 |
+
"""
|
| 17 |
+
Args:
|
| 18 |
+
num_layers (int): The number of layers.
|
| 19 |
+
|
| 20 |
+
Returns:
|
| 21 |
+
List[int]: The global attention indexes.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
return list(range(num_layers // 4 - 1, num_layers, num_layers // 4))
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.5.1
|
| 2 |
+
torchvision==0.20.1
|
| 3 |
+
h5py==3.12.1
|
| 4 |
+
numpy==1.26.4
|
| 5 |
+
pandas==2.2.2
|
| 6 |
+
mrcfile==1.5.3
|
| 7 |
+
scipy==1.13.1
|
| 8 |
+
pycocotools==2.0.8
|
| 9 |
+
omegaconf==2.3.0
|
| 10 |
+
pillow
|
| 11 |
+
fvcore
|