| """
|
| ISNet model for transformers library
|
| This file is automatically loaded when trust_remote_code=True is used
|
| """
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import numpy as np
|
| from transformers import PreTrainedModel, PretrainedConfig
|
|
|
|
|
| import sys
|
| import os
|
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| from models.isnet import ISNetDIS
|
|
|
class ISNetConfig(PretrainedConfig):
    """Configuration holding the channel counts of an ISNet network.

    Args:
        in_ch: Value stored as ``in_ch`` (default ``3``).
        out_ch: Value stored as ``out_ch`` and mirrored into ``num_labels``
            (default ``1``).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "isnet"

    def __init__(self, in_ch=3, out_ch=1, **kwargs):
        # Let the base class consume the generic options first; the
        # assignments below then override anything it set under these names.
        super().__init__(**kwargs)

        self.in_ch, self.out_ch = in_ch, out_ch

        # ``num_labels`` always follows the output channel count.
        self.num_labels = out_ch

        # Always advertise the segmentation wrapper class by name.
        self.architectures = ["ISNetForImageSegmentation"]
|
|
|
class ISNetForImageSegmentation(PreTrainedModel):
    """Transformers-compatible wrapper around ``ISNetDIS`` for image segmentation.

    The wrapped network is stored as the submodule ``self.isnet``. Weights are
    always loaded directly into that submodule, so checkpoint keys are
    normalised to the submodule's own naming (no ``isnet.`` prefix) first.
    """

    config_class = ISNetConfig
    base_model_prefix = "isnet"

    # Weight files probed, in this order, when the caller supplies no
    # ``state_dict``. The last entry is the final fallback: a failure on it is
    # propagated to the caller.
    _ISNET_WEIGHT_FILES = (
        "pytorch_model.bin",
        "model.safetensors",
        "supplyswap_isnet.pth",
    )

    # Key prefix a checkpoint carries when it was saved from this wrapper
    # rather than from the bare inner network.
    _ISNET_KEY_PREFIX = "isnet."

    def __init__(self, config):
        """Build the inner ``ISNetDIS`` from ``config.in_ch`` / ``config.out_ch``."""
        super().__init__(config)
        self.isnet = ISNetDIS(in_ch=config.in_ch, out_ch=config.out_ch)

    def forward(self, pixel_values, labels=None, threshold=0.5):
        """Run the inner network on ``pixel_values``.

        Args:
            pixel_values: Input passed straight to ``self.isnet``.
            labels: Currently unused.
            threshold: Currently unused.

        Returns:
            If the inner network returns a 2-tuple, the first element of the
            tuple's first member (presumably the primary segmentation mask --
            confirm against ``ISNetDIS``); otherwise the raw output unchanged.
        """
        outputs = self.isnet(pixel_values)

        if isinstance(outputs, tuple) and len(outputs) == 2:
            segmentation_masks = outputs[0]
            return segmentation_masks[0]

        return outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model and load its weights.

        Args:
            pretrained_model_name_or_path: Identifier or path handed to
                ``ISNetConfig.from_pretrained`` and ``cached_file``.
            *model_args: Accepted but unused.
            **kwargs: ``config`` (a ready configuration) and ``state_dict``
                (ready weights) are consumed here; everything else is forwarded
                to ``ISNetConfig.from_pretrained`` and ``cached_file``.

        Returns:
            The model in evaluation mode.

        Raises:
            Exception: Whatever locating or reading the last candidate weight
                file raises when no earlier candidate could be loaded, or
                whatever the non-strict ``load_state_dict`` retry raises.
        """
        config = kwargs.pop("config", None)
        # Pop the weights so they are not forwarded to the config / hub
        # helpers; an explicit ``None`` means "load from files".
        state_dict = kwargs.pop("state_dict", None)

        if config is None:
            config = ISNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if state_dict is None:
            state_dict = cls._isnet_fetch_state_dict(pretrained_model_name_or_path, kwargs)

        if isinstance(state_dict, dict):
            state_dict = cls._isnet_to_inner_keys(state_dict)

        try:
            model.isnet.load_state_dict(state_dict)
        except RuntimeError as e:
            # Strict loading raises RuntimeError on key mismatches; fall back
            # to a best-effort load but report how much was actually matched.
            print(f"Warning: Could not load state dict directly: {e}")
            print("Attempting to load with strict=False...")
            result = model.isnet.load_state_dict(state_dict, strict=False)
            print(
                f"Loaded with strict=False: {len(result.missing_keys)} missing key(s), "
                f"{len(result.unexpected_keys)} unexpected key(s)."
            )

        model.eval()
        return model

    @classmethod
    def _isnet_fetch_state_dict(cls, pretrained_model_name_or_path, hub_kwargs):
        """Return the weights from the first candidate file that can be read."""
        from transformers.utils import cached_file

        *earlier_candidates, final_candidate = cls._ISNET_WEIGHT_FILES

        for filename in earlier_candidates:
            try:
                weight_path = cached_file(pretrained_model_name_or_path, filename, **hub_kwargs)
                return cls._isnet_read_weight_file(weight_path)
            except Exception:
                # Deliberately broad: a missing file, a missing optional
                # dependency, or an unreadable checkpoint all mean "try the
                # next format". KeyboardInterrupt/SystemExit still propagate.
                continue

        # Last resort: any failure here is the caller's to handle.
        weight_path = cached_file(pretrained_model_name_or_path, final_candidate, **hub_kwargs)
        return cls._isnet_read_weight_file(weight_path)

    @staticmethod
    def _isnet_read_weight_file(weight_path):
        """Read a checkpoint from disk onto the CPU, by file extension."""
        if str(weight_path).endswith(".safetensors"):
            from safetensors import safe_open

            with safe_open(weight_path, framework="pt", device="cpu") as f:
                return {key: f.get_tensor(key) for key in f.keys()}

        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from sources you trust.
        return torch.load(weight_path, map_location="cpu")

    @classmethod
    def _isnet_to_inner_keys(cls, state_dict):
        """Drop the wrapper's ``isnet.`` prefix so keys match ``self.isnet``.

        Keys without the prefix are kept as they are.
        """
        prefix = cls._ISNET_KEY_PREFIX
        return {
            (key[len(prefix):] if isinstance(key, str) and key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }