"""Model classes for RSP models compatible with transformers.

Defines ``PreTrainedModel`` wrappers around the RSP ResNet, Swin Transformer
and ViTAE backbones. The Swin and ViTAE implementations are loaded from the
sibling ``RSP-Swin-T`` and ``RSP-ViTAEv2-S`` directories when those files
exist; otherwise the corresponding wrapper cannot be instantiated.
"""

import importlib.util
import logging
import os
import sys
from pathlib import Path

import torch
import torch.nn as nn
from safetensors.torch import load_file
from transformers import PreTrainedModel
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention

logger = logging.getLogger(__name__)

_this_dir = Path(__file__).parent
_parent_dir = _this_dir.parent


def _load_module_from_path(module_name, file_path):
    """Execute the Python file at ``file_path`` and return it as a module object.

    The module is not registered in ``sys.modules``; it is only returned.

    Args:
        module_name: Name given to the import spec.
        file_path: Path of the ``.py`` file to execute.

    Raises:
        ImportError: If no import spec/loader can be created for ``file_path``.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {file_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Import local modular model. Like the configuration import below, fall back to
# loading the file from this directory when it is not importable by name; if
# the file is not there either, surface the original ImportError.
try:
    from modular_resnet import ResNet, Bottleneck
except ImportError:
    _resnet_path = _this_dir / "modular_resnet.py"
    if not _resnet_path.exists():
        raise
    _resnet_module = _load_module_from_path("modular_resnet", _resnet_path)
    ResNet = _resnet_module.ResNet
    Bottleneck = _resnet_module.Bottleneck

# Import SwinTransformer from the sibling RSP-Swin-T directory, if present.
_swin_path = _parent_dir / "RSP-Swin-T" / "modular_swin.py"
if _swin_path.exists():
    swin_module = _load_module_from_path("modular_swin_swin", _swin_path)
    SwinTransformer = swin_module.SwinTransformer
else:
    SwinTransformer = None

# Import ViTAE from the sibling RSP-ViTAEv2-S directory, if present.
_vitae_path = _parent_dir / "RSP-ViTAEv2-S" / "modular_vitae_window_noshift.py"
if _vitae_path.exists():
    vitae_module = _load_module_from_path("modular_vitae_window_noshift_vitae", _vitae_path)
    ViTAE_Window_NoShift_12_basic_stages4_14 = vitae_module.ViTAE_Window_NoShift_12_basic_stages4_14
else:
    ViTAE_Window_NoShift_12_basic_stages4_14 = None

# Import configuration - handle both relative and absolute imports
try:
    from configuration_rsp import RSPResNetConfig, RSPSwinConfig, RSPViTAEConfig
except ImportError:
    # Fallback: import from same directory
    config_module = _load_module_from_path("configuration_rsp", _this_dir / "configuration_rsp.py")
    RSPResNetConfig = config_module.RSPResNetConfig
    RSPSwinConfig = config_module.RSPSwinConfig
    RSPViTAEConfig = config_module.RSPViTAEConfig


# Checkpoint keys may carry this prefix because each wrapper stores its
# backbone as ``self.model``; it is stripped before loading into the backbone.
_WRAPPER_PREFIX = "model."


def _build_classification_output(config, logits, labels, return_dict):
    """Package backbone logits (plus an optional loss) for a classifier forward.

    Args:
        config: Model config; ``num_labels`` and ``use_return_dict`` are read.
        logits: Output of the backbone for the batch.
        labels: Optional targets; when given, a cross-entropy loss is computed
            against ``logits`` reshaped to ``(-1, config.num_labels)``.
        return_dict: ``None`` defers to ``config.use_return_dict``.

    Returns:
        ``ImageClassifierOutputWithNoAttention`` when ``return_dict`` resolves to
        true; otherwise the tuple ``(loss, logits)`` or ``(logits,)``.
    """
    if return_dict is None:
        return_dict = config.use_return_dict

    loss = None
    if labels is not None:
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, config.num_labels), labels.view(-1))

    if not return_dict:
        output = (logits,)
        return (loss,) + output if loss is not None else output

    return ImageClassifierOutputWithNoAttention(
        loss=loss,
        logits=logits,
        hidden_states=None,
    )


def _load_pretrained(model_cls, config_cls, pretrained_model_name_or_path, config=None):
    """Build ``model_cls`` and load its backbone weights from ``model.safetensors``.

    Args:
        model_cls: Wrapper class to instantiate with the config.
        config_cls: Config class used when ``config`` is not supplied.
        pretrained_model_name_or_path: Directory containing ``model.safetensors``
            (and the config, when ``config`` is ``None``).
        config: Optional already-built config instance.

    Returns:
        The model with weights loaded, switched to eval mode.

    Raises:
        FileNotFoundError: If ``model.safetensors`` is not found in the directory.
    """
    if config is None:
        config = config_cls.from_pretrained(pretrained_model_name_or_path)
    model = model_cls(config)

    safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
    if not safetensors_path.exists():
        raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

    # Remove 'model.' prefix if present
    state_dict = {
        (key[len(_WRAPPER_PREFIX):] if key.startswith(_WRAPPER_PREFIX) else key): tensor
        for key, tensor in load_file(str(safetensors_path)).items()
    }

    # strict=False tolerates small checkpoint/backbone differences, but report
    # mismatches instead of silently leaving parameters at their init values.
    incompatible = model.model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        logger.warning(
            "Keys missing from %s (left at their initial values): %s",
            safetensors_path,
            incompatible.missing_keys,
        )
    if incompatible.unexpected_keys:
        logger.warning(
            "Unexpected keys in %s (ignored): %s",
            safetensors_path,
            incompatible.unexpected_keys,
        )

    # Match transformers' from_pretrained contract: loaded models are returned
    # in eval mode so dropout/batch-norm behave as at inference time.
    model.eval()
    return model


class RSPResNetForImageClassification(PreTrainedModel):
    """RSP ResNet model for image classification"""

    config_class = RSPResNetConfig

    def __init__(self, config):
        super().__init__(config)

        # Build ResNet model from config; only the Bottleneck block is supported.
        if config.block != "Bottleneck":
            raise ValueError(f"Unsupported block type: {config.block}")

        self.model = ResNet(
            block=Bottleneck,
            layers=config.layers,
            num_classes=config.num_labels,
        )

    def forward(self, pixel_values=None, labels=None, return_dict=None, **kwargs):
        """
        Args:
            pixel_values: Input images (B, C, H, W)
            labels: Optional labels for loss computation
            return_dict: Whether to return a ModelOutput instead of a plain tuple
            **kwargs: Accepted for API compatibility and ignored.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        logits = self.model(pixel_values)
        return _build_classification_output(self.config, logits, labels, return_dict)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load model from a local pretrained checkpoint directory.

        Only the ``config`` keyword is honoured; other arguments are accepted for
        API compatibility and ignored. The returned model is in eval mode.
        """
        return _load_pretrained(
            cls,
            RSPResNetConfig,
            pretrained_model_name_or_path,
            config=kwargs.pop("config", None),
        )


class RSPSwinForImageClassification(PreTrainedModel):
    """RSP Swin Transformer model for image classification"""

    config_class = RSPSwinConfig

    def __init__(self, config):
        # SwinTransformer is None when the sibling RSP-Swin-T file is absent.
        if SwinTransformer is None:
            raise ImportError(f"SwinTransformer is unavailable: {_swin_path} does not exist")

        super().__init__(config)

        # Build SwinTransformer model from config
        self.model = SwinTransformer(
            img_size=config.image_size,
            patch_size=config.patch_size,
            in_chans=config.num_channels,
            num_classes=config.num_labels,
            embed_dim=config.embed_dim,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            ape=config.ape,
            patch_norm=config.patch_norm,
        )

    def forward(self, pixel_values=None, labels=None, return_dict=None, **kwargs):
        """
        Args:
            pixel_values: Input images (B, C, H, W)
            labels: Optional labels for loss computation
            return_dict: Whether to return a ModelOutput instead of a plain tuple
            **kwargs: Accepted for API compatibility and ignored.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        logits = self.model(pixel_values)
        return _build_classification_output(self.config, logits, labels, return_dict)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load model from a local pretrained checkpoint directory.

        Only the ``config`` keyword is honoured; other arguments are accepted for
        API compatibility and ignored. The returned model is in eval mode.
        """
        return _load_pretrained(
            cls,
            RSPSwinConfig,
            pretrained_model_name_or_path,
            config=kwargs.pop("config", None),
        )


class RSPViTAEForImageClassification(PreTrainedModel):
    """RSP ViTAE model for image classification"""

    config_class = RSPViTAEConfig

    def __init__(self, config):
        # The ViTAE factory is None when the sibling RSP-ViTAEv2-S file is absent.
        if ViTAE_Window_NoShift_12_basic_stages4_14 is None:
            raise ImportError(f"ViTAE backbone is unavailable: {_vitae_path} does not exist")

        super().__init__(config)

        # Build ViTAE model from config.
        # Per the original author's note, ViTAE_Window_NoShift_12_basic_stages4_14
        # already sets most architecture parameters as defaults, so only img_size,
        # num_classes and window_size are passed here.
        # NOTE(review): window_size is hard-coded to 7 rather than read from the
        # config — confirm this matches the shipped checkpoints.
        self.model = ViTAE_Window_NoShift_12_basic_stages4_14(
            pretrained=False,
            img_size=config.image_size,
            num_classes=config.num_labels,
            window_size=7,
        )

    def forward(self, pixel_values=None, labels=None, return_dict=None, **kwargs):
        """
        Args:
            pixel_values: Input images (B, C, H, W)
            labels: Optional labels for loss computation
            return_dict: Whether to return a ModelOutput instead of a plain tuple
            **kwargs: Accepted for API compatibility and ignored.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        logits = self.model(pixel_values)
        return _build_classification_output(self.config, logits, labels, return_dict)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load model from a local pretrained checkpoint directory.

        Only the ``config`` keyword is honoured; other arguments are accepted for
        API compatibility and ignored. The returned model is in eval mode.
        """
        return _load_pretrained(
            cls,
            RSPViTAEConfig,
            pretrained_model_name_or_path,
            config=kwargs.pop("config", None),
        )