hemantn committed on
Commit
9b241ab
·
1 Parent(s): 0cdf9af

weights type issue fixed

Browse files
Files changed (1) hide show
  1. modeling_ablang2paired.py +41 -26
modeling_ablang2paired.py CHANGED
@@ -69,34 +69,49 @@ class AbLang2PairedHFModel(PreTrainedModel):
69
 
70
  @classmethod
71
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
72
- # Check if we have custom weights
73
- model_path = pretrained_model_name_or_path
74
- custom_weights_path = os.path.join(model_path, "model.pt")
 
 
75
 
76
- if os.path.exists(custom_weights_path):
77
- # Load config
78
- config = kwargs.get("config")
79
- if config is None:
80
- from transformers import AutoConfig
81
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
82
-
83
- # Create model with only the config argument
84
- model = cls(config)
85
-
86
- # Load custom weights
87
- state_dict = torch.load(custom_weights_path, map_location="cpu", weights_only=True)
88
- model.model.load_state_dict(state_dict)
89
-
90
- # Move model to appropriate device (GPU if available, otherwise CPU)
91
- device = kwargs.get("device", None)
92
- if device is None:
93
- device = "cuda" if torch.cuda.is_available() else "cpu"
94
- model = model.to(device)
95
 
96
- return model
97
- else:
98
- # Fall back to standard Hugging Face loading
99
- return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  def save_pretrained(self, save_directory, **kwargs):
102
  os.makedirs(save_directory, exist_ok=True)
 
69
 
70
  @classmethod
71
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
72
+ # Load config first
73
+ config = kwargs.get("config")
74
+ if config is None:
75
+ from transformers import AutoConfig
76
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
77
 
78
+ # Create model with config
79
+ model = cls(config)
80
+
81
+ # Try to load custom weights
82
+ try:
83
+ from transformers.utils import cached_file
84
+ custom_weights_path = cached_file(
85
+ pretrained_model_name_or_path,
86
+ "model.pt",
87
+ cache_dir=kwargs.get("cache_dir"),
88
+ force_download=kwargs.get("force_download", False),
89
+ resume_download=kwargs.get("resume_download", False),
90
+ proxies=kwargs.get("proxies"),
91
+ token=kwargs.get("token"),
92
+ revision=kwargs.get("revision"),
93
+ local_files_only=kwargs.get("local_files_only", False),
94
+ )
 
 
95
 
96
+ if custom_weights_path is not None and os.path.exists(custom_weights_path):
97
+ # Load custom weights
98
+ state_dict = torch.load(custom_weights_path, map_location="cpu", weights_only=True)
99
+ model.model.load_state_dict(state_dict)
100
+ print(f"✅ Loaded custom weights from: {custom_weights_path}")
101
+ else:
102
+ print("⚠️ No custom weights found, using initialized model")
103
+
104
+ except Exception as e:
105
+ print(f"⚠️ Could not load custom weights: {e}")
106
+ print("Using initialized model")
107
+
108
+ # Move model to appropriate device (GPU if available, otherwise CPU)
109
+ device = kwargs.get("device", None)
110
+ if device is None:
111
+ device = "cuda" if torch.cuda.is_available() else "cpu"
112
+ model = model.to(device)
113
+
114
+ return model
115
 
116
  def save_pretrained(self, save_directory, **kwargs):
117
  os.makedirs(save_directory, exist_ok=True)