mateenahmed commited on
Commit
7e6a010
·
verified ·
1 Parent(s): 28da8c6

Upload modeling_isnet.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_isnet.py +43 -13
modeling_isnet.py CHANGED
@@ -7,7 +7,6 @@ import torch.nn as nn
7
  import torch.nn.functional as F
8
  import numpy as np
9
  from transformers import PreTrainedModel, PretrainedConfig
10
- from transformers.models.auto.modeling_auto import AutoModelForImageSegmentation
11
 
12
  # Import the ISNet model
13
  import sys
@@ -24,6 +23,7 @@ class ISNetConfig(PretrainedConfig):
24
  self.in_ch = in_ch
25
  self.out_ch = out_ch
26
  self.num_labels = out_ch # Required for AutoModelForImageSegmentation
 
27
 
28
  class ISNetForImageSegmentation(PreTrainedModel):
29
  """Transformers-compatible ISNet model for image segmentation"""
@@ -78,19 +78,49 @@ class ISNetForImageSegmentation(PreTrainedModel):
78
  )
79
  state_dict = torch.load(model_file, map_location="cpu")
80
  except:
81
- # Fallback to the original model file
82
- model_file = cached_file(
83
- pretrained_model_name_or_path,
84
- "supplyswap_isnet.pth",
85
- **kwargs
86
- )
87
- state_dict = torch.load(model_file, map_location="cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  # Load the weights into the ISNet model
90
- model.isnet.load_state_dict(state_dict)
 
 
 
 
 
 
91
  model.eval()
92
 
93
- return model
94
-
95
- # Register the model with transformers
96
- AutoModelForImageSegmentation.register(ISNetConfig, ISNetForImageSegmentation)
 
7
  import torch.nn.functional as F
8
  import numpy as np
9
  from transformers import PreTrainedModel, PretrainedConfig
 
10
 
11
  # Import the ISNet model
12
  import sys
 
23
  self.in_ch = in_ch
24
  self.out_ch = out_ch
25
  self.num_labels = out_ch # Required for AutoModelForImageSegmentation
26
+ self.architectures = ["ISNetForImageSegmentation"]
27
 
28
  class ISNetForImageSegmentation(PreTrainedModel):
29
  """Transformers-compatible ISNet model for image segmentation"""
 
78
  )
79
  state_dict = torch.load(model_file, map_location="cpu")
80
  except:
81
+ try:
82
+ # Try model.safetensors
83
+ model_file = cached_file(
84
+ pretrained_model_name_or_path,
85
+ "model.safetensors",
86
+ **kwargs
87
+ )
88
+ from safetensors import safe_open
89
+ with safe_open(model_file, framework="pt", device="cpu") as f:
90
+ state_dict = {key: f.get_tensor(key) for key in f.keys()}
91
+ except:
92
+ # Fallback to the original model file
93
+ model_file = cached_file(
94
+ pretrained_model_name_or_path,
95
+ "supplyswap_isnet.pth",
96
+ **kwargs
97
+ )
98
+ state_dict = torch.load(model_file, map_location="cpu")
99
+
100
+ # Handle different state dict formats
101
+ if isinstance(state_dict, dict):
102
+ # Check if the state dict has the expected keys
103
+ if any(key.startswith('isnet.') for key in state_dict.keys()):
104
+ # State dict already has the correct prefix
105
+ pass
106
+ elif any(key.startswith('conv_in.') or key.startswith('stage') for key in state_dict.keys()):
107
+ # State dict is from the original ISNet model, needs to be wrapped
108
+ wrapped_state_dict = {}
109
+ for key, value in state_dict.items():
110
+ wrapped_state_dict[f"isnet.{key}"] = value
111
+ state_dict = wrapped_state_dict
112
+ else:
113
+ # Try to load directly
114
+ pass
115
 
116
  # Load the weights into the ISNet model
117
+ try:
118
+ model.isnet.load_state_dict(state_dict)
119
+ except Exception as e:
120
+ print(f"Warning: Could not load state dict directly: {e}")
121
+ print("Attempting to load with strict=False...")
122
+ model.isnet.load_state_dict(state_dict, strict=False)
123
+
124
  model.eval()
125
 
126
+ return model