|
|
from torch import nn, Tensor |
|
|
import open_clip |
|
|
from peft import get_peft_model, LoraConfig |
|
|
|
|
|
from ..utils import ConvRefine, ConvAdapter |
|
|
from ..utils import ConvUpsample, _get_norm_layer, _get_activation |
|
|
|
|
|
|
|
|
convnext_names_and_weights = { |
|
|
"convnext_base": ["laion400m_s13b_b51k"], |
|
|
"convnext_base_w": ["laion2b_s13b_b82k", "laion2b_s13b_b82k_augreg", "laion_aesthetic_s13b_b82k"], |
|
|
"convnext_base_w_320": ["laion_aesthetic_s13b_b82k", "laion_aesthetic_s13b_b82k_augreg"], |
|
|
"convnext_large_d": ["laion2b_s26b_b102k_augreg"], |
|
|
"convnext_large_d_320": ["laion2b_s29b_b131k_ft", "laion2b_s29b_b131k_ft_soup"], |
|
|
"convnext_xxlarge": ["laion2b_s34b_b82k_augreg", "laion2b_s34b_b82k_augreg_rewind", "laion2b_s34b_b82k_augreg_soup"] |
|
|
} |
|
|
|
|
|
refiner_channels = { |
|
|
"convnext_base": 1024, |
|
|
"convnext_base_w": 1024, |
|
|
"convnext_base_w_320": 1024, |
|
|
"convnext_large_d": 1536, |
|
|
"convnext_large_d_320": 1536, |
|
|
"convnext_xxlarge": 3072, |
|
|
} |
|
|
|
|
|
refiner_groups = { |
|
|
"convnext_base": 1, |
|
|
"convnext_base_w": 1, |
|
|
"convnext_base_w_320": 1, |
|
|
"convnext_large_d": refiner_channels["convnext_large_d"] // 512, |
|
|
"convnext_large_d_320": refiner_channels["convnext_large_d_320"] // 512, |
|
|
"convnext_xxlarge": refiner_channels["convnext_xxlarge"] // 512, |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
class ConvNeXt(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
model_name: str, |
|
|
weight_name: str, |
|
|
block_size: int = 16, |
|
|
adapter: bool = False, |
|
|
adapter_reduction: int = 4, |
|
|
norm: str = "none", |
|
|
act: str = "none" |
|
|
) -> None: |
|
|
super(ConvNeXt, self).__init__() |
|
|
assert model_name in convnext_names_and_weights, f"Model name should be one of {list(convnext_names_and_weights.keys())}, but got {model_name}." |
|
|
assert weight_name in convnext_names_and_weights[model_name], f"Pretrained should be one of {convnext_names_and_weights[model_name]}, but got {weight_name}." |
|
|
assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}" |
|
|
self.model_name, self.weight_name = model_name, weight_name |
|
|
self.block_size = block_size |
|
|
|
|
|
model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual |
|
|
|
|
|
self.adapter = adapter |
|
|
if adapter: |
|
|
self.adapter_reduction = adapter_reduction |
|
|
for param in model.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
self.stem = model.trunk.stem |
|
|
self.depth = len(model.trunk.stages) |
|
|
for idx, stage in enumerate(model.trunk.stages): |
|
|
setattr(self, f"stage{idx}", stage) |
|
|
if adapter: |
|
|
setattr(self, f"adapter{idx}", ConvAdapter( |
|
|
in_channels=stage.blocks[-1].mlp.fc2.out_features, |
|
|
bottleneck_channels=stage.blocks[-1].mlp.fc2.out_features // adapter_reduction, |
|
|
) if idx < self.depth - 1 else nn.Identity()) |
|
|
|
|
|
if self.model_name in ["convnext_base", "convnext_base_w", "convnext_base_w_320", "convnext_xxlarge"]: |
|
|
self.in_features, self.out_features = model.head.proj.in_features, model.head.proj.out_features |
|
|
else: |
|
|
self.in_features, self.out_features = model.head.mlp.fc1.in_features, model.head.mlp.fc2.out_features |
|
|
|
|
|
if norm == "bn": |
|
|
norm_layer = nn.BatchNorm2d |
|
|
elif norm == "ln": |
|
|
norm_layer = nn.LayerNorm |
|
|
else: |
|
|
norm_layer = _get_norm_layer(model) |
|
|
|
|
|
if act == "relu": |
|
|
activation = nn.ReLU(inplace=True) |
|
|
elif act == "gelu": |
|
|
activation = nn.GELU() |
|
|
else: |
|
|
activation = _get_activation(model) |
|
|
|
|
|
if block_size == 32: |
|
|
self.refiner = ConvRefine( |
|
|
in_channels=self.in_features, |
|
|
out_channels=self.in_features, |
|
|
norm_layer=norm_layer, |
|
|
activation=activation, |
|
|
groups=refiner_groups[self.model_name], |
|
|
) |
|
|
elif block_size == 16: |
|
|
self.refiner = ConvUpsample( |
|
|
in_channels=self.in_features, |
|
|
out_channels=self.in_features, |
|
|
norm_layer=norm_layer, |
|
|
activation=activation, |
|
|
groups=refiner_groups[self.model_name], |
|
|
) |
|
|
else: |
|
|
self.refiner = nn.Sequential( |
|
|
ConvUpsample( |
|
|
in_channels=self.in_features, |
|
|
out_channels=self.in_features, |
|
|
norm_layer=norm_layer, |
|
|
activation=activation, |
|
|
groups=refiner_groups[self.model_name], |
|
|
), |
|
|
ConvUpsample( |
|
|
in_channels=self.in_features, |
|
|
out_channels=self.in_features, |
|
|
norm_layer=norm_layer, |
|
|
activation=activation, |
|
|
groups=refiner_groups[self.model_name], |
|
|
), |
|
|
) |
|
|
|
|
|
def train(self, mode: bool = True): |
|
|
if self.adapter and mode: |
|
|
|
|
|
self.stem.eval() |
|
|
|
|
|
for idx in range(self.depth): |
|
|
getattr(self, f"stage{idx}").eval() |
|
|
getattr(self, f"adapter{idx}").train() |
|
|
|
|
|
self.refiner.train() |
|
|
|
|
|
else: |
|
|
|
|
|
for module in self.children(): |
|
|
module.train(mode) |
|
|
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
|
x = self.stem(x) |
|
|
|
|
|
for idx in range(self.depth): |
|
|
x = getattr(self, f"stage{idx}")(x) |
|
|
if self.adapter: |
|
|
x = getattr(self, f"adapter{idx}")(x) |
|
|
|
|
|
x = self.refiner(x) |
|
|
return x |
|
|
|
|
|
|
|
|
def _convnext( |
|
|
model_name: str, |
|
|
weight_name: str, |
|
|
block_size: int = 16, |
|
|
adapter: bool = False, |
|
|
adapter_reduction: int = 4, |
|
|
lora: bool = False, |
|
|
lora_rank: int = 16, |
|
|
lora_alpha: float = 32.0, |
|
|
lora_dropout: float = 0.1, |
|
|
norm: str = "none", |
|
|
act: str = "none" |
|
|
) -> ConvNeXt: |
|
|
assert not (lora and adapter), "Lora and adapter cannot be used together." |
|
|
model = ConvNeXt( |
|
|
model_name=model_name, |
|
|
weight_name=weight_name, |
|
|
block_size=block_size, |
|
|
adapter=adapter, |
|
|
adapter_reduction=adapter_reduction, |
|
|
norm=norm, |
|
|
act=act |
|
|
) |
|
|
|
|
|
if lora: |
|
|
target_modules = [] |
|
|
for name, module in model.named_modules(): |
|
|
if isinstance(module, (nn.Linear, nn.Conv2d)) and "refiner" not in name: |
|
|
target_modules.append(name) |
|
|
|
|
|
lora_config = LoraConfig( |
|
|
r=lora_rank, |
|
|
lora_alpha=lora_alpha, |
|
|
lora_dropout=lora_dropout, |
|
|
bias="none", |
|
|
target_modules=target_modules, |
|
|
) |
|
|
model = get_peft_model(model, lora_config) |
|
|
|
|
|
|
|
|
for name, module in model.named_modules(): |
|
|
if "refiner" in name: |
|
|
module.requires_grad_(True) |
|
|
|
|
|
return model |