| """Model classes for RSP models compatible with transformers""" |
|
|
| import sys |
| import os |
| from pathlib import Path |
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| from safetensors.torch import load_file |
|
|
| |
| from modular_swin import SwinTransformer |
|
|
| |
import importlib.util

_parent_dir = Path(__file__).parent.parent


def _load_optional_module(module_name, file_path):
    """Import the Python source file at ``file_path`` under ``module_name``.

    Returns the executed module, or ``None`` when ``file_path`` does not exist.
    The module is registered in ``sys.modules`` before it is executed (the
    importlib "import a source file directly" recipe), so classes defined in it
    resolve through their ``__module__`` like normally imported ones. If
    executing the file raises, the half-initialised entry is removed from
    ``sys.modules`` and the exception propagates.
    """
    file_path = Path(file_path)
    if not file_path.exists():
        return None
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(module_name, None)
        raise
    return module


# Backbones living in sibling model directories are optional: when their source
# file is absent the corresponding names are bound to None, and the model
# classes that need them report the problem when instantiated.
_resnet_path = _parent_dir / "RSP-ResNet-50" / "modular_resnet.py"
resnet_module = _load_optional_module("modular_resnet_resnet", _resnet_path)
if resnet_module is not None:
    ResNet = resnet_module.ResNet
    Bottleneck = resnet_module.Bottleneck
else:
    ResNet = None
    Bottleneck = None

_vitae_path = _parent_dir / "RSP-ViTAEv2-S" / "modular_vitae_window_noshift.py"
vitae_module = _load_optional_module("modular_vitae_window_noshift_vitae", _vitae_path)
if vitae_module is not None:
    ViTAE_Window_NoShift_12_basic_stages4_14 = vitae_module.ViTAE_Window_NoShift_12_basic_stages4_14
else:
    ViTAE_Window_NoShift_12_basic_stages4_14 = None
|
|
| |
try:
    from configuration_rsp import RSPResNetConfig, RSPSwinConfig, RSPViTAEConfig
except ImportError:
    # ``configuration_rsp`` could not be imported by name (presumably because
    # this file's directory is not on sys.path), so execute the sibling
    # configuration_rsp.py file directly and pull the config classes from it.
    import importlib.util

    config_path = Path(__file__).with_name("configuration_rsp.py")
    spec = importlib.util.spec_from_file_location("configuration_rsp", config_path)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    RSPResNetConfig, RSPSwinConfig, RSPViTAEConfig = (
        getattr(config_module, class_name)
        for class_name in ("RSPResNetConfig", "RSPSwinConfig", "RSPViTAEConfig")
    )
|
|
|
|
class RSPResNetForImageClassification(PreTrainedModel):
    """RSP ResNet model for image classification.

    ``PreTrainedModel`` wrapper around the ``ResNet`` backbone loaded from
    ``RSP-ResNet-50/modular_resnet.py``. The backbone is stored in
    ``self.model`` and its output is used directly as the logits.
    """

    config_class = RSPResNetConfig

    def __init__(self, config):
        """Build the ResNet backbone described by ``config``.

        Args:
            config: ``RSPResNetConfig``; ``block``, ``layers`` and
                ``num_labels`` are used.

        Raises:
            ValueError: If ``config.block`` is not ``"Bottleneck"``.
            ImportError: If ``modular_resnet.py`` was not found when this
                module was imported (``ResNet``/``Bottleneck`` are None).
        """
        super().__init__(config)

        if config.block != "Bottleneck":
            raise ValueError(f"Unsupported block type: {config.block}")
        if ResNet is None or Bottleneck is None:
            # The backbone module is optional at import time; report the real
            # cause rather than letting it look like an unsupported block type.
            raise ImportError(
                f"RSP ResNet backbone is unavailable: {_resnet_path} was not "
                "found when this module was imported"
            )

        self.model = ResNet(
            block=Bottleneck,
            layers=config.layers,
            num_classes=config.num_labels,
        )

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Classify a batch of images.

        Args:
            pixel_values: Input images (B, C, H, W).
            labels: Optional class indices; when given, a cross-entropy loss
                over ``config.num_labels`` classes is computed.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            dict with ``"logits"`` and, when ``labels`` is given, ``"loss"``.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        logits = self.model(pixel_values)
        outputs = {"logits": logits}
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            outputs["loss"] = loss_fct(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )
        return outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a model from a local directory containing ``model.safetensors``.

        Only local directories are supported; ``model_args`` and every keyword
        argument other than ``config`` are ignored.

        Args:
            pretrained_model_name_or_path: Local directory holding the config
                and ``model.safetensors``.
            config: Optional keyword; when omitted the config is loaded from
                the directory with ``cls.config_class``.

        Returns:
            The model with weights loaded, switched to eval mode (matching the
            behavior of ``PreTrainedModel.from_pretrained``).

        Raises:
            FileNotFoundError: If ``model.safetensors`` does not exist.
        """
        import warnings

        config = kwargs.pop("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path)

        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        model = cls(config)
        prefix = "model."
        # Keys saved from this wrapper carry the "model." prefix of
        # ``self.model``; strip it so they match the backbone's own state dict.
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in load_file(str(safetensors_path)).items()
        }
        incompatible = model.model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            # Loading stays non-strict, but a silent mismatch would leave
            # randomly initialised weights in place without any signal.
            warnings.warn(
                f"{cls.__name__}: loaded {safetensors_path} with "
                f"{len(incompatible.missing_keys)} missing keys "
                f"{incompatible.missing_keys[:5]} and "
                f"{len(incompatible.unexpected_keys)} unexpected keys "
                f"{incompatible.unexpected_keys[:5]}",
                stacklevel=2,
            )

        model.eval()
        return model
|
|
|
|
class RSPSwinForImageClassification(PreTrainedModel):
    """RSP Swin Transformer model for image classification.

    ``PreTrainedModel`` wrapper around ``modular_swin.SwinTransformer``. The
    backbone is stored in ``self.model`` and its output is used directly as
    the logits.
    """

    config_class = RSPSwinConfig

    def __init__(self, config):
        """Build the Swin backbone from the fields of ``config``.

        Args:
            config: ``RSPSwinConfig``; ``image_size``, ``patch_size``,
                ``num_channels``, ``num_labels``, ``embed_dim``, ``depths``,
                ``num_heads``, ``window_size``, ``mlp_ratio``, ``qkv_bias``,
                ``ape`` and ``patch_norm`` are forwarded to the backbone.
        """
        super().__init__(config)

        self.model = SwinTransformer(
            img_size=config.image_size,
            patch_size=config.patch_size,
            in_chans=config.num_channels,
            num_classes=config.num_labels,
            embed_dim=config.embed_dim,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            ape=config.ape,
            patch_norm=config.patch_norm,
        )

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Classify a batch of images.

        Args:
            pixel_values: Input images (B, C, H, W).
            labels: Optional class indices; when given, a cross-entropy loss
                over ``config.num_labels`` classes is computed.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            dict with ``"logits"`` and, when ``labels`` is given, ``"loss"``.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        logits = self.model(pixel_values)
        outputs = {"logits": logits}
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            outputs["loss"] = loss_fct(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )
        return outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a model from a local directory containing ``model.safetensors``.

        Only local directories are supported; ``model_args`` and every keyword
        argument other than ``config`` are ignored.

        Args:
            pretrained_model_name_or_path: Local directory holding the config
                and ``model.safetensors``.
            config: Optional keyword; when omitted the config is loaded from
                the directory with ``cls.config_class``.

        Returns:
            The model with weights loaded, switched to eval mode (matching the
            behavior of ``PreTrainedModel.from_pretrained``).

        Raises:
            FileNotFoundError: If ``model.safetensors`` does not exist.
        """
        import warnings

        config = kwargs.pop("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path)

        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        model = cls(config)
        prefix = "model."
        # Keys saved from this wrapper carry the "model." prefix of
        # ``self.model``; strip it so they match the backbone's own state dict.
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in load_file(str(safetensors_path)).items()
        }
        incompatible = model.model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            # Loading stays non-strict, but a silent mismatch would leave
            # randomly initialised weights in place without any signal.
            warnings.warn(
                f"{cls.__name__}: loaded {safetensors_path} with "
                f"{len(incompatible.missing_keys)} missing keys "
                f"{incompatible.missing_keys[:5]} and "
                f"{len(incompatible.unexpected_keys)} unexpected keys "
                f"{incompatible.unexpected_keys[:5]}",
                stacklevel=2,
            )

        model.eval()
        return model
|
|
|
|
class RSPViTAEForImageClassification(PreTrainedModel):
    """RSP ViTAE model for image classification.

    ``PreTrainedModel`` wrapper around
    ``ViTAE_Window_NoShift_12_basic_stages4_14`` loaded from
    ``RSP-ViTAEv2-S/modular_vitae_window_noshift.py``. The backbone is stored
    in ``self.model`` and its output is used directly as the logits.
    """

    config_class = RSPViTAEConfig

    def __init__(self, config):
        """Build the ViTAE backbone.

        Args:
            config: ``RSPViTAEConfig``; ``image_size`` and ``num_labels`` are
                used.

        Raises:
            ImportError: If ``modular_vitae_window_noshift.py`` was not found
                when this module was imported (the factory is None).
        """
        super().__init__(config)

        if ViTAE_Window_NoShift_12_basic_stages4_14 is None:
            # The backbone module is optional at import time; fail with the
            # real cause instead of "'NoneType' object is not callable".
            raise ImportError(
                f"RSP ViTAE backbone is unavailable: {_vitae_path} was not "
                "found when this module was imported"
            )

        # NOTE(review): window_size is fixed at 7 rather than read from the
        # config; confirm it matches the checkpoints this wrapper loads.
        self.model = ViTAE_Window_NoShift_12_basic_stages4_14(
            pretrained=False,
            img_size=config.image_size,
            num_classes=config.num_labels,
            window_size=7,
        )

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Classify a batch of images.

        Args:
            pixel_values: Input images (B, C, H, W).
            labels: Optional class indices; when given, a cross-entropy loss
                over ``config.num_labels`` classes is computed.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            dict with ``"logits"`` and, when ``labels`` is given, ``"loss"``.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        logits = self.model(pixel_values)
        outputs = {"logits": logits}
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            outputs["loss"] = loss_fct(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )
        return outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a model from a local directory containing ``model.safetensors``.

        Only local directories are supported; ``model_args`` and every keyword
        argument other than ``config`` are ignored.

        Args:
            pretrained_model_name_or_path: Local directory holding the config
                and ``model.safetensors``.
            config: Optional keyword; when omitted the config is loaded from
                the directory with ``cls.config_class``.

        Returns:
            The model with weights loaded, switched to eval mode (matching the
            behavior of ``PreTrainedModel.from_pretrained``).

        Raises:
            FileNotFoundError: If ``model.safetensors`` does not exist.
        """
        import warnings

        config = kwargs.pop("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path)

        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        model = cls(config)
        prefix = "model."
        # Keys saved from this wrapper carry the "model." prefix of
        # ``self.model``; strip it so they match the backbone's own state dict.
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in load_file(str(safetensors_path)).items()
        }
        incompatible = model.model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            # Loading stays non-strict, but a silent mismatch would leave
            # randomly initialised weights in place without any signal.
            warnings.warn(
                f"{cls.__name__}: loaded {safetensors_path} with "
                f"{len(incompatible.missing_keys)} missing keys "
                f"{incompatible.missing_keys[:5]} and "
                f"{len(incompatible.unexpected_keys)} unexpected keys "
                f"{incompatible.unexpected_keys[:5]}",
                stacklevel=2,
            )

        model.eval()
        return model
|
|