RSP-Swin-T / modeling_rsp.py
BiliSakura's picture
Add files using upload-large-folder tool
bc2f004 verified
"""Model classes for RSP models compatible with transformers"""
import sys
import os
from pathlib import Path
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from safetensors.torch import load_file
# Import local modular model
from modular_swin import SwinTransformer
# Import other models from sibling directories if needed
import importlib.util

# Directory that contains this model's folder and, optionally, its sibling model folders.
_parent_dir = Path(__file__).parent.parent


def _load_optional_module(module_name, module_path):
    """Execute the Python file at *module_path* and return it as a module.

    Args:
        module_name: Name given to the module spec (becomes ``module.__name__``).
        module_path: ``pathlib.Path`` of the source file to execute.

    Returns:
        The executed module, or ``None`` when the file does not exist or no
        import spec/loader can be created for it.

    Any exception raised while executing the file itself propagates unchanged.
    """
    if not module_path.exists():
        return None
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # spec_from_file_location may return None (or a spec without a loader);
    # treat that like a missing file instead of failing with AttributeError.
    if spec is None or spec.loader is None:
        return None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Import ResNet from RSP-ResNet-50
_resnet_path = _parent_dir / "RSP-ResNet-50" / "modular_resnet.py"
resnet_module = _load_optional_module("modular_resnet_resnet", _resnet_path)
if resnet_module is not None:
    ResNet = resnet_module.ResNet
    Bottleneck = resnet_module.Bottleneck
else:
    # Sibling directory not available: consumers must check for None.
    ResNet = None
    Bottleneck = None

# Import ViTAE from RSP-ViTAEv2-S
_vitae_path = _parent_dir / "RSP-ViTAEv2-S" / "modular_vitae_window_noshift.py"
vitae_module = _load_optional_module("modular_vitae_window_noshift_vitae", _vitae_path)
if vitae_module is not None:
    ViTAE_Window_NoShift_12_basic_stages4_14 = vitae_module.ViTAE_Window_NoShift_12_basic_stages4_14
else:
    # Sibling directory not available: consumers must check for None.
    ViTAE_Window_NoShift_12_basic_stages4_14 = None
# Resolve the three RSP configuration classes. A regular import is tried first;
# if ``configuration_rsp`` is not importable by name, it is executed directly
# from the file that sits next to this module.
try:
    from configuration_rsp import RSPResNetConfig, RSPSwinConfig, RSPViTAEConfig
except ImportError:
    # Fallback: import from same directory
    import importlib.util

    config_path = Path(__file__).with_name("configuration_rsp.py")
    spec = importlib.util.spec_from_file_location("configuration_rsp", config_path)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    RSPResNetConfig, RSPSwinConfig, RSPViTAEConfig = (
        config_module.RSPResNetConfig,
        config_module.RSPSwinConfig,
        config_module.RSPViTAEConfig,
    )
class RSPResNetForImageClassification(PreTrainedModel):
    """RSP ResNet model for image classification.

    Wraps the ``ResNet`` backbone loaded from the sibling ``RSP-ResNet-50``
    directory (kept in ``self.model``) behind the ``PreTrainedModel`` interface.
    """

    config_class = RSPResNetConfig

    def __init__(self, config):
        """Build the ResNet backbone described by *config*.

        Reads ``config.block``, ``config.layers`` and ``config.num_labels``.

        Raises:
            ValueError: If ``config.block`` is not ``"Bottleneck"``.
            ImportError: If the ResNet implementation could not be loaded when
                this module was imported (sibling directory missing).
        """
        super().__init__(config)
        # Build ResNet model from config
        if config.block != "Bottleneck":
            raise ValueError(f"Unsupported block type: {config.block}")
        # ResNet/Bottleneck are None when the sibling file was not found at import
        # time; report that explicitly instead of a misleading "unsupported block".
        if ResNet is None or Bottleneck is None:
            raise ImportError(
                "ResNet implementation not available: expected "
                "'RSP-ResNet-50/modular_resnet.py' next to this model directory"
            )
        self.model = ResNet(
            block=Bottleneck,
            layers=config.layers,
            num_classes=config.num_labels,
        )

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Run the backbone and optionally compute a cross-entropy loss.

        Args:
            pixel_values: Input images (B, C, H, W)
            labels: Optional labels for loss computation
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``{"logits": logits}`` when *labels* is None, otherwise
            ``{"logits": logits, "loss": loss}``.

        Raises:
            ValueError: If *pixel_values* is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        logits = self.model(pixel_values)
        if labels is None:
            return {"logits": logits}
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"logits": logits, "loss": loss}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load model from a local pretrained checkpoint directory.

        Args:
            pretrained_model_name_or_path: Directory containing ``model.safetensors``
                (and the configuration, unless ``config`` is passed).
            *model_args: Ignored.
            **kwargs: Only ``config`` is consumed; other entries are ignored.

        Returns:
            The model with weights loaded, in evaluation mode.

        Raises:
            FileNotFoundError: If ``model.safetensors`` is not in the directory.
        """
        import warnings

        config = kwargs.pop("config", None)
        if config is None:
            config = RSPResNetConfig.from_pretrained(pretrained_model_name_or_path)

        # Check for the weights before building the network so a bad path fails fast.
        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        model = cls(config)

        # Weights are loaded into the inner backbone, so drop a leading 'model.'
        # from checkpoint keys when present.
        prefix = "model."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): tensor
            for key, tensor in load_file(str(safetensors_path)).items()
        }
        load_result = model.model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates mismatches; surface them so a wrong checkpoint
        # does not silently leave layers randomly initialised.
        if load_result.missing_keys or load_result.unexpected_keys:
            warnings.warn(
                f"Checkpoint {safetensors_path} does not fully match the model: "
                f"missing keys {list(load_result.missing_keys)}, "
                f"unexpected keys {list(load_result.unexpected_keys)}",
                stacklevel=2,
            )

        # PreTrainedModel.from_pretrained returns models in evaluation mode;
        # keep that contract in this override. Call .train() before fine-tuning.
        model.eval()
        return model
class RSPSwinForImageClassification(PreTrainedModel):
    """RSP Swin Transformer model for image classification.

    Wraps the local ``SwinTransformer`` backbone (kept in ``self.model``) behind
    the ``PreTrainedModel`` interface.
    """

    config_class = RSPSwinConfig

    def __init__(self, config):
        """Build the SwinTransformer backbone from the fields of *config*."""
        super().__init__(config)
        # Build SwinTransformer model from config
        self.model = SwinTransformer(
            img_size=config.image_size,
            patch_size=config.patch_size,
            in_chans=config.num_channels,
            num_classes=config.num_labels,
            embed_dim=config.embed_dim,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            ape=config.ape,
            patch_norm=config.patch_norm,
        )

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Run the backbone and optionally compute a cross-entropy loss.

        Args:
            pixel_values: Input images (B, C, H, W)
            labels: Optional labels for loss computation
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``{"logits": logits}`` when *labels* is None, otherwise
            ``{"logits": logits, "loss": loss}``.

        Raises:
            ValueError: If *pixel_values* is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        logits = self.model(pixel_values)
        if labels is None:
            return {"logits": logits}
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"logits": logits, "loss": loss}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load model from a local pretrained checkpoint directory.

        Args:
            pretrained_model_name_or_path: Directory containing ``model.safetensors``
                (and the configuration, unless ``config`` is passed).
            *model_args: Ignored.
            **kwargs: Only ``config`` is consumed; other entries are ignored.

        Returns:
            The model with weights loaded, in evaluation mode.

        Raises:
            FileNotFoundError: If ``model.safetensors`` is not in the directory.
        """
        import warnings

        config = kwargs.pop("config", None)
        if config is None:
            config = RSPSwinConfig.from_pretrained(pretrained_model_name_or_path)

        # Check for the weights before building the network so a bad path fails fast.
        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        model = cls(config)

        # Weights are loaded into the inner backbone, so drop a leading 'model.'
        # from checkpoint keys when present.
        prefix = "model."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): tensor
            for key, tensor in load_file(str(safetensors_path)).items()
        }
        load_result = model.model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates mismatches; surface them so a wrong checkpoint
        # does not silently leave layers randomly initialised.
        if load_result.missing_keys or load_result.unexpected_keys:
            warnings.warn(
                f"Checkpoint {safetensors_path} does not fully match the model: "
                f"missing keys {list(load_result.missing_keys)}, "
                f"unexpected keys {list(load_result.unexpected_keys)}",
                stacklevel=2,
            )

        # PreTrainedModel.from_pretrained returns models in evaluation mode;
        # keep that contract in this override. Call .train() before fine-tuning.
        model.eval()
        return model
class RSPViTAEForImageClassification(PreTrainedModel):
    """RSP ViTAE model for image classification.

    Wraps the ViTAE backbone created by ``ViTAE_Window_NoShift_12_basic_stages4_14``
    (loaded from the sibling ``RSP-ViTAEv2-S`` directory and kept in
    ``self.model``) behind the ``PreTrainedModel`` interface.
    """

    config_class = RSPViTAEConfig

    def __init__(self, config):
        """Build the ViTAE backbone using ``config.image_size`` and ``config.num_labels``.

        Raises:
            ImportError: If the ViTAE implementation could not be loaded when
                this module was imported (sibling directory missing).
        """
        super().__init__(config)
        # The factory is None when the sibling file was not found at import time;
        # report that explicitly instead of "'NoneType' object is not callable".
        if ViTAE_Window_NoShift_12_basic_stages4_14 is None:
            raise ImportError(
                "ViTAE implementation not available: expected "
                "'RSP-ViTAEv2-S/modular_vitae_window_noshift.py' next to this model directory"
            )
        # Build ViTAE model from config
        # Note: ViTAE_Window_NoShift_12_basic_stages4_14 already sets most parameters as defaults:
        # - stages=4, embed_dims=[64, 64, 128, 256], token_dims=[64, 128, 256, 512]
        # - downsample_ratios=[4, 2, 2, 2], NC_depth=[2, 2, 8, 2], etc.
        # We only pass parameters that need to be overridden (img_size, num_classes)
        # The function accepts **kwargs, so we can pass window_size if needed
        self.model = ViTAE_Window_NoShift_12_basic_stages4_14(
            pretrained=False,
            img_size=config.image_size,
            num_classes=config.num_labels,
            window_size=7,
        )

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Run the backbone and optionally compute a cross-entropy loss.

        Args:
            pixel_values: Input images (B, C, H, W)
            labels: Optional labels for loss computation
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``{"logits": logits}`` when *labels* is None, otherwise
            ``{"logits": logits, "loss": loss}``.

        Raises:
            ValueError: If *pixel_values* is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        logits = self.model(pixel_values)
        if labels is None:
            return {"logits": logits}
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"logits": logits, "loss": loss}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load model from a local pretrained checkpoint directory.

        Args:
            pretrained_model_name_or_path: Directory containing ``model.safetensors``
                (and the configuration, unless ``config`` is passed).
            *model_args: Ignored.
            **kwargs: Only ``config`` is consumed; other entries are ignored.

        Returns:
            The model with weights loaded, in evaluation mode.

        Raises:
            FileNotFoundError: If ``model.safetensors`` is not in the directory.
        """
        import warnings

        config = kwargs.pop("config", None)
        if config is None:
            config = RSPViTAEConfig.from_pretrained(pretrained_model_name_or_path)

        # Check for the weights before building the network so a bad path fails fast.
        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        model = cls(config)

        # Weights are loaded into the inner backbone, so drop a leading 'model.'
        # from checkpoint keys when present.
        prefix = "model."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): tensor
            for key, tensor in load_file(str(safetensors_path)).items()
        }
        load_result = model.model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates mismatches; surface them so a wrong checkpoint
        # does not silently leave layers randomly initialised.
        if load_result.missing_keys or load_result.unexpected_keys:
            warnings.warn(
                f"Checkpoint {safetensors_path} does not fully match the model: "
                f"missing keys {list(load_result.missing_keys)}, "
                f"unexpected keys {list(load_result.unexpected_keys)}",
                stacklevel=2,
            )

        # PreTrainedModel.from_pretrained returns models in evaluation mode;
        # keep that contract in this override. Call .train() before fine-tuning.
        model.eval()
        return model