Update custom model files, README, and requirements
asr_modeling.py CHANGED (+34 -3)
```diff
@@ -164,15 +164,46 @@ class ASRModel(PreTrainedModel):
         cls._pretrained_model_path = pretrained_model_name_or_path

         try:
-
-
-
+            from safetensors.torch import load_file
+            from transformers.utils.hub import cached_file
+
+            # Create model instance (loads encoder/decoder fresh from HF)
+            model = cls(config, **kwargs)
+
+            # Manually load model.safetensors to avoid corrupted generation_config.json
+            subfolder = kwargs.get("subfolder")
+            revision = kwargs.get("revision")
+            cache_kwargs = {}
+            if subfolder:
+                cache_kwargs["subfolder"] = subfolder
+            if revision:
+                cache_kwargs["revision"] = revision
+
+            model_file = cached_file(
+                pretrained_model_name_or_path,
+                "model.safetensors",
+                _raise_exceptions_for_missing_entries=False,
+                **cache_kwargs,
             )

+            if not model_file:
+                raise FileNotFoundError(
+                    f"model.safetensors not found in {pretrained_model_name_or_path}. "
+                    "The repository may not have been trained yet."
+                )
+
+            # Load trainable state (projector weights with "projector." prefix)
+            state_dict = load_file(model_file)
+            model.load_state_dict(state_dict, strict=False, assign=True)
+
             # Convert projector to target dtype after loading weights
             target_dtype = getattr(torch, config.model_dtype)
             model.projector = model.projector.to(dtype=target_dtype)

+            device = kwargs.get("device")
+            if device is not None:
+                model = model.to(device)
+
             return model
         finally:
             cls._is_loading_from_pretrained = False
```