|
|
"""Model classes for RSP models compatible with transformers""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
from pathlib import Path |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import PreTrainedModel |
|
|
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
|
from modular_resnet import ResNet, Bottleneck |
|
|
|
|
|
|
|
|
_parent_dir = Path(__file__).parent.parent |
|
|
import importlib.util |
|
|
|
|
|
|
|
|
_swin_path = _parent_dir / "RSP-Swin-T" / "modular_swin.py" |
|
|
if _swin_path.exists(): |
|
|
spec = importlib.util.spec_from_file_location("modular_swin_swin", _swin_path) |
|
|
swin_module = importlib.util.module_from_spec(spec) |
|
|
spec.loader.exec_module(swin_module) |
|
|
SwinTransformer = swin_module.SwinTransformer |
|
|
else: |
|
|
SwinTransformer = None |
|
|
|
|
|
|
|
|
_vitae_path = _parent_dir / "RSP-ViTAEv2-S" / "modular_vitae_window_noshift.py" |
|
|
if _vitae_path.exists(): |
|
|
spec = importlib.util.spec_from_file_location("modular_vitae_window_noshift_vitae", _vitae_path) |
|
|
vitae_module = importlib.util.module_from_spec(spec) |
|
|
spec.loader.exec_module(vitae_module) |
|
|
ViTAE_Window_NoShift_12_basic_stages4_14 = vitae_module.ViTAE_Window_NoShift_12_basic_stages4_14 |
|
|
else: |
|
|
ViTAE_Window_NoShift_12_basic_stages4_14 = None |
|
|
|
|
|
|
|
|
# Prefer a normal import of the config classes (works when this directory is
# on sys.path, e.g. when loaded as a package by transformers' auto classes).
try:
    from configuration_rsp import RSPResNetConfig, RSPSwinConfig, RSPViTAEConfig
except ImportError:

    # Fallback: load configuration_rsp.py by file path. This covers the case
    # where the file is executed from another working directory and the module
    # is not importable by name.
    import importlib.util
    config_path = Path(__file__).parent / "configuration_rsp.py"
    spec = importlib.util.spec_from_file_location("configuration_rsp", config_path)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    # Re-export the three config classes under the same names the try-branch
    # would have bound, so the rest of the module is path-agnostic.
    RSPResNetConfig = config_module.RSPResNetConfig
    RSPSwinConfig = config_module.RSPSwinConfig
    RSPViTAEConfig = config_module.RSPViTAEConfig
|
|
|
|
|
|
|
|
class RSPResNetForImageClassification(PreTrainedModel):
    """RSP ResNet model for image classification.

    Thin transformers-compatible wrapper around the plain-torch ``ResNet``
    backbone from ``modular_resnet``.
    """

    config_class = RSPResNetConfig

    def __init__(self, config):
        super().__init__(config)

        # Only the Bottleneck block type is supported by this wrapper.
        if config.block != "Bottleneck":
            raise ValueError(f"Unsupported block type: {config.block}")

        self.model = ResNet(
            block=Bottleneck,
            layers=config.layers,
            num_classes=config.num_labels,
        )

    def forward(self, pixel_values=None, labels=None, return_dict=None, **kwargs):
        """Run the backbone and optionally compute a cross-entropy loss.

        Args:
            pixel_values: Input images (B, C, H, W).
            labels: Optional labels for loss computation.
            return_dict: Whether to return a ModelOutput instead of a plain tuple.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        logits = self.model(pixel_values)

        loss = None
        if labels is not None:
            criterion = nn.CrossEntropyLoss()
            loss = criterion(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            if loss is None:
                return (logits,)
            return (loss, logits)

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=None,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model and load weights from a local checkpoint directory."""
        config = kwargs.pop("config", None)
        if config is None:
            config = RSPResNetConfig.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        raw_state = load_file(str(safetensors_path))
        # Checkpoints may save weights under a leading "model." scope; strip it
        # so the keys line up with the bare backbone held in self.model.
        cleaned = {
            (key[6:] if key.startswith("model.") else key): tensor
            for key, tensor in raw_state.items()
        }
        model.model.load_state_dict(cleaned, strict=False)

        return model
|
|
|
|
|
|
|
|
class RSPSwinForImageClassification(PreTrainedModel):
    """RSP Swin Transformer model for image classification.

    Transformers-compatible wrapper around the standalone ``SwinTransformer``
    backbone, which is loaded dynamically from the sibling RSP-Swin-T checkout
    at module import time.
    """

    config_class = RSPSwinConfig

    def __init__(self, config):
        super().__init__(config)

        # The backbone is None when the sibling RSP-Swin-T checkout was not
        # found at import time; fail with an explicit message instead of the
        # opaque "'NoneType' object is not callable".
        if SwinTransformer is None:
            raise ImportError(
                "SwinTransformer backbone could not be loaded; expected "
                "RSP-Swin-T/modular_swin.py next to this repository"
            )

        self.model = SwinTransformer(
            img_size=config.image_size,
            patch_size=config.patch_size,
            in_chans=config.num_channels,
            num_classes=config.num_labels,
            embed_dim=config.embed_dim,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            ape=config.ape,
            patch_norm=config.patch_norm,
        )

    def forward(self, pixel_values=None, labels=None, return_dict=None, **kwargs):
        """Run the backbone and optionally compute a cross-entropy loss.

        Args:
            pixel_values: Input images (B, C, H, W).
            labels: Optional labels for loss computation.
            return_dict: Whether to return a ModelOutput instead of a plain tuple.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        logits = self.model(pixel_values)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=None,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model from a local pretrained checkpoint directory.

        Expects ``model.safetensors`` inside the given directory; keys saved
        under a leading ``model.`` scope are re-mapped onto the bare backbone.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RSPSwinConfig.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        state_dict = load_file(str(safetensors_path))
        # Strip the "model." prefix so keys line up with self.model.
        state_dict_clean = {}
        for k, v in state_dict.items():
            if k.startswith("model."):
                state_dict_clean[k[len("model."):]] = v
            else:
                state_dict_clean[k] = v
        # NOTE(review): strict=False silently ignores missing/unexpected keys
        # (presumably to tolerate head-size mismatches) — confirm intended.
        model.model.load_state_dict(state_dict_clean, strict=False)

        return model
|
|
|
|
|
|
|
|
class RSPViTAEForImageClassification(PreTrainedModel):
    """RSP ViTAE model for image classification.

    Transformers-compatible wrapper around the standalone
    ``ViTAE_Window_NoShift_12_basic_stages4_14`` backbone, loaded dynamically
    from the sibling RSP-ViTAEv2-S checkout at module import time.
    """

    config_class = RSPViTAEConfig

    def __init__(self, config):
        super().__init__(config)

        # The backbone is None when the sibling RSP-ViTAEv2-S checkout was not
        # found at import time; fail with an explicit message instead of the
        # opaque "'NoneType' object is not callable".
        if ViTAE_Window_NoShift_12_basic_stages4_14 is None:
            raise ImportError(
                "ViTAE backbone could not be loaded; expected "
                "RSP-ViTAEv2-S/modular_vitae_window_noshift.py next to this repository"
            )

        self.model = ViTAE_Window_NoShift_12_basic_stages4_14(
            pretrained=False,
            img_size=config.image_size,
            num_classes=config.num_labels,
            # Default 7 preserves the previous hard-coded value; configs that
            # declare a window_size can now override it.
            window_size=getattr(config, "window_size", 7),
        )

    def forward(self, pixel_values=None, labels=None, return_dict=None, **kwargs):
        """Run the backbone and optionally compute a cross-entropy loss.

        Args:
            pixel_values: Input images (B, C, H, W).
            labels: Optional labels for loss computation.
            return_dict: Whether to return a ModelOutput instead of a plain tuple.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        logits = self.model(pixel_values)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=None,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model from a local pretrained checkpoint directory.

        Expects ``model.safetensors`` inside the given directory; keys saved
        under a leading ``model.`` scope are re-mapped onto the bare backbone.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RSPViTAEConfig.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        state_dict = load_file(str(safetensors_path))
        # Strip the "model." prefix so keys line up with self.model.
        state_dict_clean = {}
        for k, v in state_dict.items():
            if k.startswith("model."):
                state_dict_clean[k[len("model."):]] = v
            else:
                state_dict_clean[k] = v
        # NOTE(review): strict=False silently ignores missing/unexpected keys
        # (presumably to tolerate head-size mismatches) — confirm intended.
        model.model.load_state_dict(state_dict_clean, strict=False)

        return model
|
|
|