# RSP-ResNet-50 / modeling_rsp.py
# Hugging Face Hub upload metadata: uploaded by BiliSakura using the
# upload-large-folder tool, commit 844428c (verified).
"""Model classes for RSP models compatible with transformers"""
import importlib.util
import os
import sys
import warnings
from pathlib import Path

import torch
import torch.nn as nn
from safetensors.torch import load_file
from transformers import PreTrainedModel
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
# Import local modular model
from modular_resnet import ResNet, Bottleneck
# Import other models from sibling directories if needed
_parent_dir = Path(__file__).parent.parent
import importlib.util


def _load_attr_from_file(module_name, file_path, attr_name):
    """Load ``attr_name`` from the Python source file at ``file_path``.

    Returns ``None`` when the file does not exist, so optional sibling
    checkouts degrade gracefully instead of breaking this module's import.
    """
    file_path = Path(file_path)
    if not file_path.exists():
        return None
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, attr_name)


# Import SwinTransformer from RSP-Swin-T (optional sibling directory)
SwinTransformer = _load_attr_from_file(
    "modular_swin_swin",
    _parent_dir / "RSP-Swin-T" / "modular_swin.py",
    "SwinTransformer",
)

# Import ViTAE from RSP-ViTAEv2-S (optional sibling directory)
ViTAE_Window_NoShift_12_basic_stages4_14 = _load_attr_from_file(
    "modular_vitae_window_noshift_vitae",
    _parent_dir / "RSP-ViTAEv2-S" / "modular_vitae_window_noshift.py",
    "ViTAE_Window_NoShift_12_basic_stages4_14",
)

# Import configuration - handle both relative and absolute imports
try:
    from configuration_rsp import RSPResNetConfig, RSPSwinConfig, RSPViTAEConfig
except ImportError:
    # Fallback: load the configuration module by path from this directory.
    # Unlike the optional sibling models above, this file is required, so a
    # missing file raises loudly here rather than deferring the failure.
    _config_path = Path(__file__).parent / "configuration_rsp.py"
    _spec = importlib.util.spec_from_file_location("configuration_rsp", _config_path)
    _config_module = importlib.util.module_from_spec(_spec)
    _spec.loader.exec_module(_config_module)
    RSPResNetConfig = _config_module.RSPResNetConfig
    RSPSwinConfig = _config_module.RSPSwinConfig
    RSPViTAEConfig = _config_module.RSPViTAEConfig
class RSPResNetForImageClassification(PreTrainedModel):
    """RSP ResNet model for image classification.

    Wraps the local ``ResNet`` backbone in a ``transformers``-compatible
    interface (config-driven construction, ``ImageClassifierOutputWithNoAttention``
    outputs, checkpoint loading from ``model.safetensors``).
    """

    config_class = RSPResNetConfig

    def __init__(self, config):
        super().__init__(config)
        # Build the ResNet backbone from config. Only "Bottleneck" blocks are
        # supported (ResNet-50 style); fail fast on anything else.
        block = Bottleneck if config.block == "Bottleneck" else None
        if block is None:
            raise ValueError(f"Unsupported block type: {config.block}")
        self.model = ResNet(
            block=block,
            layers=config.layers,
            num_classes=config.num_labels,
        )

    def forward(self, pixel_values=None, labels=None, return_dict=None, **kwargs):
        """Run a classification forward pass.

        Args:
            pixel_values: Input images (B, C, H, W).
            labels: Optional labels for loss computation.
            return_dict: Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            ``ImageClassifierOutputWithNoAttention`` (or a tuple when
            ``return_dict`` is falsy) with the cross-entropy loss (if labels
            were given) and the logits.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        logits = self.model(pixel_values)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output
        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=None,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model from a local checkpoint directory.

        NOTE(review): only a local directory containing ``model.safetensors``
        is supported; extra ``model_args``/``kwargs`` beyond ``config`` are
        currently ignored (no Hub download, no dtype/device handling).

        Raises:
            FileNotFoundError: If ``model.safetensors`` is missing.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RSPResNetConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")
        state_dict = load_file(str(safetensors_path))
        # Checkpoints may prefix keys with "model."; strip it so they line up
        # with the bare backbone's parameter names.
        prefix = "model."
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        # strict=False tolerates benign mismatches (e.g. a replaced head), but
        # surface what failed to load instead of swallowing it silently.
        missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            warnings.warn(
                f"Weights loaded with missing keys {missing} "
                f"and unexpected keys {unexpected}"
            )
        return model
class RSPSwinForImageClassification(PreTrainedModel):
    """RSP Swin Transformer model for image classification.

    Wraps the dynamically imported ``SwinTransformer`` backbone in a
    ``transformers``-compatible interface.
    """

    config_class = RSPSwinConfig

    def __init__(self, config):
        super().__init__(config)
        # SwinTransformer is imported from an optional sibling directory and
        # is None when that checkout is absent; raise a clear error instead of
        # the opaque "'NoneType' object is not callable" TypeError.
        if SwinTransformer is None:
            raise ImportError(
                "SwinTransformer backbone is unavailable: "
                "RSP-Swin-T/modular_swin.py was not found"
            )
        # Build the Swin backbone from config.
        self.model = SwinTransformer(
            img_size=config.image_size,
            patch_size=config.patch_size,
            in_chans=config.num_channels,
            num_classes=config.num_labels,
            embed_dim=config.embed_dim,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            ape=config.ape,
            patch_norm=config.patch_norm,
        )

    def forward(self, pixel_values=None, labels=None, return_dict=None, **kwargs):
        """Run a classification forward pass.

        Args:
            pixel_values: Input images (B, C, H, W).
            labels: Optional labels for loss computation.
            return_dict: Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            ``ImageClassifierOutputWithNoAttention`` (or a tuple when
            ``return_dict`` is falsy) with the cross-entropy loss (if labels
            were given) and the logits.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        logits = self.model(pixel_values)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output
        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=None,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model from a local checkpoint directory.

        NOTE(review): only a local directory containing ``model.safetensors``
        is supported; extra ``model_args``/``kwargs`` beyond ``config`` are
        currently ignored (no Hub download, no dtype/device handling).

        Raises:
            FileNotFoundError: If ``model.safetensors`` is missing.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RSPSwinConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")
        state_dict = load_file(str(safetensors_path))
        # Checkpoints may prefix keys with "model."; strip it so they line up
        # with the bare backbone's parameter names.
        prefix = "model."
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        # strict=False tolerates benign mismatches (e.g. a replaced head), but
        # surface what failed to load instead of swallowing it silently.
        missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            warnings.warn(
                f"Weights loaded with missing keys {missing} "
                f"and unexpected keys {unexpected}"
            )
        return model
class RSPViTAEForImageClassification(PreTrainedModel):
    """RSP ViTAE model for image classification.

    Wraps the dynamically imported ViTAE backbone in a
    ``transformers``-compatible interface.
    """

    config_class = RSPViTAEConfig

    def __init__(self, config):
        super().__init__(config)
        # The ViTAE factory is imported from an optional sibling directory and
        # is None when that checkout is absent; raise a clear error instead of
        # the opaque "'NoneType' object is not callable" TypeError.
        if ViTAE_Window_NoShift_12_basic_stages4_14 is None:
            raise ImportError(
                "ViTAE backbone is unavailable: "
                "RSP-ViTAEv2-S/modular_vitae_window_noshift.py was not found"
            )
        # Build ViTAE model from config.
        # Note: ViTAE_Window_NoShift_12_basic_stages4_14 already sets most
        # parameters as defaults:
        # - stages=4, embed_dims=[64, 64, 128, 256], token_dims=[64, 128, 256, 512]
        # - downsample_ratios=[4, 2, 2, 2], NC_depth=[2, 2, 8, 2], etc.
        # We only override img_size/num_classes; the factory accepts **kwargs,
        # so window_size can be passed through as well.
        self.model = ViTAE_Window_NoShift_12_basic_stages4_14(
            pretrained=False,
            img_size=config.image_size,
            num_classes=config.num_labels,
            window_size=7,
        )

    def forward(self, pixel_values=None, labels=None, return_dict=None, **kwargs):
        """Run a classification forward pass.

        Args:
            pixel_values: Input images (B, C, H, W).
            labels: Optional labels for loss computation.
            return_dict: Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            ``ImageClassifierOutputWithNoAttention`` (or a tuple when
            ``return_dict`` is falsy) with the cross-entropy loss (if labels
            were given) and the logits.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        logits = self.model(pixel_values)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output
        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=None,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model from a local checkpoint directory.

        NOTE(review): only a local directory containing ``model.safetensors``
        is supported; extra ``model_args``/``kwargs`` beyond ``config`` are
        currently ignored (no Hub download, no dtype/device handling).

        Raises:
            FileNotFoundError: If ``model.safetensors`` is missing.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RSPViTAEConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")
        state_dict = load_file(str(safetensors_path))
        # Checkpoints may prefix keys with "model."; strip it so they line up
        # with the bare backbone's parameter names.
        prefix = "model."
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        # strict=False tolerates benign mismatches (e.g. a replaced head), but
        # surface what failed to load instead of swallowing it silently.
        missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            warnings.warn(
                f"Weights loaded with missing keys {missing} "
                f"and unexpected keys {unexpected}"
            )
        return model