singularitys0 commited on
Commit
0383930
·
verified ·
1 Parent(s): 1a05ac7

fix encoder load

Browse files
Files changed (1) hide show
  1. modeling_moss_speech_codec.py +65 -19
modeling_moss_speech_codec.py CHANGED
@@ -475,33 +475,79 @@ class MossSpeechCodec(PreTrainedModel):
475
  @classmethod
476
  def from_pretrained(
477
  cls,
478
- model_dir: Union[str, os.PathLike],
479
- *args,
 
 
 
 
 
 
 
480
  **kwargs,
481
  ):
482
- """Instantiate codec from a directory containing encoder and decoder assets.
483
-
484
- Expected layout:
485
- - `model.safetensors` (Whisper VQ encoder weights)
486
- - `config.json` (Whisper VQ config)
487
- - `preprocessor_config.json` (WhisperFeatureExtractor params)
488
- - `flow/{config.yaml, flow.pt, hift.pt, campplus.onnx}`
 
 
489
  """
490
- base = Path(str(model_dir))
491
- # Support both layouts:
492
- # 1) <base>/{model.safetensors, config.json, preprocessor_config.json, flow/}
493
- # 2) <base>/speech_tokenizer/{model.safetensors, ...} and <base>/flow/
494
- if (base / "model.safetensors").exists():
495
- tokenizer_dir = base
496
- flow_dir = base / "flow"
497
  else:
498
- tokenizer_dir = base / "speech_tokenizer"
499
- flow_dir = base / "flow"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  encoder_weight_path = str(tokenizer_dir / "model.safetensors")
501
  encoder_config_path = str(tokenizer_dir / "config.json")
502
  encoder_feature_extractor_path = str(tokenizer_dir)
503
  flow_path = str(flow_dir)
504
-
505
  return cls(
506
  encoder_weight_path=encoder_weight_path,
507
  encoder_config_path=encoder_config_path,
 
475
  @classmethod
476
  def from_pretrained(
477
  cls,
478
+ pretrained_model_name_or_path: Union[str, os.PathLike],
479
+ *,
480
+ revision: Optional[str] = None,
481
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
482
+ force_download: bool = False,
483
+ local_files_only: bool = False,
484
+ token: Optional[Union[str, bool]] = None,
485
+ use_auth_token: Optional[Union[str, bool]] = None, # back-compat with HF Transformers kwarg
486
+ subfolder: Optional[str] = None,
487
  **kwargs,
488
  ):
489
+ """Instantiate codec from a local directory or a Hugging Face Hub repo.
490
+ This mirrors the typical Hugging Face ``from_pretrained`` behavior:
491
+ - If ``pretrained_model_name_or_path`` is a local folder, files are loaded from it.
492
+ - Otherwise, it is treated as a Hub repo ID and downloaded with ``snapshot_download``.
493
+ Expected layout inside the resolved base folder:
494
+ - ``model.safetensors`` (Whisper VQ encoder weights)
495
+ - ``config.json`` (Whisper VQ config)
496
+ - ``preprocessor_config.json`` (WhisperFeatureExtractor params)
497
+ - ``flow/{config.yaml, flow.pt, hift.pt, campplus.onnx}``
498
  """
499
+ # Resolve local directory vs HF Hub repo.
500
+ base: Path
501
+ path_str = str(pretrained_model_name_or_path)
502
+ if os.path.isdir(path_str):
503
+ base = Path(path_str)
 
 
504
  else:
505
+ try:
506
+ from huggingface_hub import snapshot_download # lazy import to avoid hard dependency at import time
507
+ except Exception as exc: # pragma: no cover
508
+ raise RuntimeError(
509
+ "huggingface_hub is required to load from a repo id; please `pip install huggingface_hub`."
510
+ ) from exc
511
+ # HF Transformers historically supports both `token` and deprecated `use_auth_token`.
512
+ if token is None and use_auth_token is not None:
513
+ token = use_auth_token
514
+ snapshot_path = snapshot_download(
515
+ repo_id=path_str,
516
+ revision=revision,
517
+ cache_dir=str(cache_dir) if cache_dir is not None else None,
518
+ force_download=force_download,
519
+ local_files_only=local_files_only,
520
+ token=token,
521
+ )
522
+ base = Path(snapshot_path)
523
+ if subfolder:
524
+ base = base / subfolder
525
+ tokenizer_dir = base
526
+ flow_dir = base / "flow"
527
+ # Validate expected files and provide actionable error messages, similar to HF patterns.
528
+ missing: List[str] = []
529
+ if not (tokenizer_dir / "model.safetensors").exists():
530
+ missing.append(str(tokenizer_dir / "model.safetensors"))
531
+ if not (tokenizer_dir / "config.json").exists():
532
+ missing.append(str(tokenizer_dir / "config.json"))
533
+ if not (tokenizer_dir / "preprocessor_config.json").exists():
534
+ missing.append(str(tokenizer_dir / "preprocessor_config.json"))
535
+ for fname in ("config.yaml", "flow.pt", "hift.pt"):
536
+ if not (flow_dir / fname).exists():
537
+ missing.append(str(flow_dir / fname))
538
+ # `campplus.onnx` may be named differently in some drops; only warn if absent.
539
+ has_campplus = (flow_dir / "campplus.onnx").exists()
540
+ if missing:
541
+ raise FileNotFoundError(
542
+ "Missing required codec assets under resolved path. The following files were not found: "
543
+ + ", ".join(missing)
544
+ )
545
+ if not has_campplus:
546
+ logger.warning("campplus.onnx not found under %s; decoding speaker embedding may fail.", flow_dir)
547
  encoder_weight_path = str(tokenizer_dir / "model.safetensors")
548
  encoder_config_path = str(tokenizer_dir / "config.json")
549
  encoder_feature_extractor_path = str(tokenizer_dir)
550
  flow_path = str(flow_dir)
 
551
  return cls(
552
  encoder_weight_path=encoder_weight_path,
553
  encoder_config_path=encoder_config_path,