Fix encoder load: resolve a local directory or a Hugging Face Hub repo ID in `from_pretrained`
Browse files
- modeling_moss_speech_codec.py (+65 -19)
modeling_moss_speech_codec.py
CHANGED
|
@@ -475,33 +475,79 @@ class MossSpeechCodec(PreTrainedModel):
|
|
| 475 |
@classmethod
|
| 476 |
def from_pretrained(
|
| 477 |
cls,
|
| 478 |
-
|
| 479 |
-
*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
**kwargs,
|
| 481 |
):
|
| 482 |
-
"""Instantiate codec from a directory
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
-
|
| 486 |
-
|
| 487 |
-
- `
|
| 488 |
-
- `
|
|
|
|
|
|
|
| 489 |
"""
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
tokenizer_dir = base
|
| 496 |
-
flow_dir = base / "flow"
|
| 497 |
else:
|
| 498 |
-
|
| 499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
encoder_weight_path = str(tokenizer_dir / "model.safetensors")
|
| 501 |
encoder_config_path = str(tokenizer_dir / "config.json")
|
| 502 |
encoder_feature_extractor_path = str(tokenizer_dir)
|
| 503 |
flow_path = str(flow_dir)
|
| 504 |
-
|
| 505 |
return cls(
|
| 506 |
encoder_weight_path=encoder_weight_path,
|
| 507 |
encoder_config_path=encoder_config_path,
|
|
|
|
| 475 |
@classmethod
|
| 476 |
def from_pretrained(
|
| 477 |
cls,
|
| 478 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
| 479 |
+
*,
|
| 480 |
+
revision: Optional[str] = None,
|
| 481 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
| 482 |
+
force_download: bool = False,
|
| 483 |
+
local_files_only: bool = False,
|
| 484 |
+
token: Optional[Union[str, bool]] = None,
|
| 485 |
+
use_auth_token: Optional[Union[str, bool]] = None, # back-compat with HF Transformers kwarg
|
| 486 |
+
subfolder: Optional[str] = None,
|
| 487 |
**kwargs,
|
| 488 |
):
|
| 489 |
+
"""Instantiate codec from a local directory or a Hugging Face Hub repo.
|
| 490 |
+
This mirrors the typical Hugging Face ``from_pretrained`` behavior:
|
| 491 |
+
- If ``pretrained_model_name_or_path`` is a local folder, files are loaded from it.
|
| 492 |
+
- Otherwise, it is treated as a Hub repo ID and downloaded with ``snapshot_download``.
|
| 493 |
+
Expected layout inside the resolved base folder:
|
| 494 |
+
- ``model.safetensors`` (Whisper VQ encoder weights)
|
| 495 |
+
- ``config.json`` (Whisper VQ config)
|
| 496 |
+
- ``preprocessor_config.json`` (WhisperFeatureExtractor params)
|
| 497 |
+
- ``flow/{config.yaml, flow.pt, hift.pt, campplus.onnx}``
|
| 498 |
"""
|
| 499 |
+
# Resolve local directory vs HF Hub repo.
|
| 500 |
+
base: Path
|
| 501 |
+
path_str = str(pretrained_model_name_or_path)
|
| 502 |
+
if os.path.isdir(path_str):
|
| 503 |
+
base = Path(path_str)
|
|
|
|
|
|
|
| 504 |
else:
|
| 505 |
+
try:
|
| 506 |
+
from huggingface_hub import snapshot_download # lazy import to avoid hard dependency at import time
|
| 507 |
+
except Exception as exc: # pragma: no cover
|
| 508 |
+
raise RuntimeError(
|
| 509 |
+
"huggingface_hub is required to load from a repo id; please `pip install huggingface_hub`."
|
| 510 |
+
) from exc
|
| 511 |
+
# HF Transformers historically supports both `token` and deprecated `use_auth_token`.
|
| 512 |
+
if token is None and use_auth_token is not None:
|
| 513 |
+
token = use_auth_token
|
| 514 |
+
snapshot_path = snapshot_download(
|
| 515 |
+
repo_id=path_str,
|
| 516 |
+
revision=revision,
|
| 517 |
+
cache_dir=str(cache_dir) if cache_dir is not None else None,
|
| 518 |
+
force_download=force_download,
|
| 519 |
+
local_files_only=local_files_only,
|
| 520 |
+
token=token,
|
| 521 |
+
)
|
| 522 |
+
base = Path(snapshot_path)
|
| 523 |
+
if subfolder:
|
| 524 |
+
base = base / subfolder
|
| 525 |
+
tokenizer_dir = base
|
| 526 |
+
flow_dir = base / "flow"
|
| 527 |
+
# Validate expected files and provide actionable error messages, similar to HF patterns.
|
| 528 |
+
missing: List[str] = []
|
| 529 |
+
if not (tokenizer_dir / "model.safetensors").exists():
|
| 530 |
+
missing.append(str(tokenizer_dir / "model.safetensors"))
|
| 531 |
+
if not (tokenizer_dir / "config.json").exists():
|
| 532 |
+
missing.append(str(tokenizer_dir / "config.json"))
|
| 533 |
+
if not (tokenizer_dir / "preprocessor_config.json").exists():
|
| 534 |
+
missing.append(str(tokenizer_dir / "preprocessor_config.json"))
|
| 535 |
+
for fname in ("config.yaml", "flow.pt", "hift.pt"):
|
| 536 |
+
if not (flow_dir / fname).exists():
|
| 537 |
+
missing.append(str(flow_dir / fname))
|
| 538 |
+
# `campplus.onnx` may be named differently in some drops; only warn if absent.
|
| 539 |
+
has_campplus = (flow_dir / "campplus.onnx").exists()
|
| 540 |
+
if missing:
|
| 541 |
+
raise FileNotFoundError(
|
| 542 |
+
"Missing required codec assets under resolved path. The following files were not found: "
|
| 543 |
+
+ ", ".join(missing)
|
| 544 |
+
)
|
| 545 |
+
if not has_campplus:
|
| 546 |
+
logger.warning("campplus.onnx not found under %s; decoding speaker embedding may fail.", flow_dir)
|
| 547 |
encoder_weight_path = str(tokenizer_dir / "model.safetensors")
|
| 548 |
encoder_config_path = str(tokenizer_dir / "config.json")
|
| 549 |
encoder_feature_extractor_path = str(tokenizer_dir)
|
| 550 |
flow_path = str(flow_dir)
|
|
|
|
| 551 |
return cls(
|
| 552 |
encoder_weight_path=encoder_weight_path,
|
| 553 |
encoder_config_path=encoder_config_path,
|