mazesmazes committed on
Commit
8512337
·
verified ·
1 Parent(s): fdd2a3d

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. asr_modeling.py +34 -3
asr_modeling.py CHANGED
@@ -164,15 +164,46 @@ class ASRModel(PreTrainedModel):
164
  cls._pretrained_model_path = pretrained_model_name_or_path
165
 
166
  try:
167
- # Let parent class handle loading config and model.safetensors
168
- model = super().from_pretrained(
169
- pretrained_model_name_or_path, *args, config=config, **kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  )
171
 
 
 
 
 
 
 
 
 
 
 
172
  # Convert projector to target dtype after loading weights
173
  target_dtype = getattr(torch, config.model_dtype)
174
  model.projector = model.projector.to(dtype=target_dtype)
175
 
 
 
 
 
176
  return model
177
  finally:
178
  cls._is_loading_from_pretrained = False
 
164
  cls._pretrained_model_path = pretrained_model_name_or_path
165
 
166
  try:
167
+ from safetensors.torch import load_file
168
+ from transformers.utils.hub import cached_file
169
+
170
+ # Create model instance (loads encoder/decoder fresh from HF)
171
+ model = cls(config, **kwargs)
172
+
173
+ # Manually load model.safetensors to avoid corrupted generation_config.json
174
+ subfolder = kwargs.get("subfolder")
175
+ revision = kwargs.get("revision")
176
+ cache_kwargs = {}
177
+ if subfolder:
178
+ cache_kwargs["subfolder"] = subfolder
179
+ if revision:
180
+ cache_kwargs["revision"] = revision
181
+
182
+ model_file = cached_file(
183
+ pretrained_model_name_or_path,
184
+ "model.safetensors",
185
+ _raise_exceptions_for_missing_entries=False,
186
+ **cache_kwargs,
187
  )
188
 
189
+ if not model_file:
190
+ raise FileNotFoundError(
191
+ f"model.safetensors not found in {pretrained_model_name_or_path}. "
192
+ "The repository may not have been trained yet."
193
+ )
194
+
195
+ # Load trainable state (projector weights with "projector." prefix)
196
+ state_dict = load_file(model_file)
197
+ model.load_state_dict(state_dict, strict=False, assign=True)
198
+
199
  # Convert projector to target dtype after loading weights
200
  target_dtype = getattr(torch, config.model_dtype)
201
  model.projector = model.projector.to(dtype=target_dtype)
202
 
203
+ device = kwargs.get("device")
204
+ if device is not None:
205
+ model = model.to(device)
206
+
207
  return model
208
  finally:
209
  cls._is_loading_from_pretrained = False