| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | from patch_utils import MindSpeedPatchesManager as aspm |
| | import torchaudio |
| | import torch |
| | import logging |
| |
|
def get_vocos_mel_spectrogram_cpu(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
):
    """Compute a Vocos-style log-mel spectrogram, running the transform on CPU.

    Replacement for ``f5_tts.model.modules.get_vocos_mel_spectrogram``
    (registered by ``patch_for_npu``): the waveform is copied to the CPU, the
    mel transform runs there, and the result is moved back to the device the
    waveform came from.

    Args:
        waveform: audio tensor of shape ``(batch, samples)`` or
            ``(batch, 1, samples)``; a singleton axis 1 is squeezed away.
        n_fft, n_mel_channels, target_sample_rate, hop_length, win_length:
            forwarded to ``torchaudio.transforms.MelSpectrogram``.

    Returns:
        ``log(clamp(mel, min=1e-5))`` with shape
        ``(batch, n_mel_channels, frames)``, on ``waveform``'s original device.

    Raises:
        ValueError: if ``waveform`` is not 2-D after squeezing axis 1.
    """
    wave_device = waveform.device
    if waveform.dim() == 3:
        # (b, 1, n) -> (b, n); squeeze(1) is a no-op when axis 1 is not size 1.
        waveform = waveform.squeeze(1)
    # Explicit check instead of `assert` (stripped under `python -O`), done
    # before building the transform so malformed input fails cheaply.
    if waveform.dim() != 2:
        raise ValueError(
            "expected waveform of shape (batch, samples) or (batch, 1, samples), "
            f"got {tuple(waveform.shape)}"
        )

    # NOTE(review): assumes a float32 waveform — CPU STFT may not support
    # half precision; confirm callers never pass fp16 audio here.
    mel_stft = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=n_mel_channels,
        power=1,
        center=True,
        normalized=False,
        norm=None,
    )  # constructed on CPU, matching the CPU input below
    mel = mel_stft(waveform.cpu())
    mel = mel.clamp(min=1e-5).log()
    # Only the result goes back to the original device; the CPU copy of the
    # waveform is simply discarded.
    return mel.to(wave_device)
| |
|
| |
|
def load_checkpoint_npu(model, ckpt_path, device: str, dtype=None, use_ema=True):
    """Load F5-TTS weights into ``model`` and move it to ``device``.

    NPU-oriented replacement for ``f5_tts.infer.utils_infer.load_checkpoint``
    (registered by ``patch_for_npu``).

    Args:
        model: ``torch.nn.Module`` to populate; it is cast to ``dtype`` before
            the weights are loaded.
        ckpt_path: path (``str`` or path-like) to a ``.safetensors`` file or a
            ``torch.save`` checkpoint; the format is chosen by file extension.
        device: device string used as the load target (``map_location`` /
            safetensors ``device``) and as the model's final device.
        dtype: target dtype. When ``None``: ``torch.float16`` for any device
            string containing ``"cuda"``, or for an ``"npu"`` device that
            passes the capability checks below; otherwise ``torch.float32``.
        use_ema: if true, load the EMA weights — ``checkpoint["ema_model_state_dict"]``
            (or the whole safetensors file) with the ``"ema_model."`` prefix
            stripped and the ``"initted"``/``"step"`` entries dropped.
            Otherwise load ``checkpoint["model_state_dict"]`` (or the whole
            safetensors file).

    Returns:
        ``model`` moved to ``device``.

    Raises:
        KeyError: if a ``torch.save`` checkpoint lacks the expected
            ``"ema_model_state_dict"`` / ``"model_state_dict"`` entry.
        RuntimeError: from the strict ``load_state_dict`` on key/shape mismatch.
    """
    logging.info("Load checkpoint %s", ckpt_path)
    if dtype is None:
        # NOTE(review): `and` binds tighter than `or`, so "cuda" devices get
        # float16 without the capability/ZLUDA probes and only "npu" devices
        # are probed. The parentheses make that existing grouping explicit
        # (behavior unchanged); confirm it is intended. Probing an "npu"
        # device via torch.cuda presumably relies on torch_npu redirecting
        # torch.cuda calls — verify.
        dtype = (
            torch.float16
            if "cuda" in device
            or (
                "npu" in device
                and torch.cuda.get_device_properties(device).major >= 6
                and not torch.cuda.get_device_name().endswith("[ZLUDA]")
            )
            else torch.float32
        )
    model = model.to(dtype)

    # str() lets callers pass a pathlib.Path as well as a plain string.
    ckpt_type = str(ckpt_path).split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file

        checkpoint = load_file(ckpt_path, device=device)
    else:
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

    if use_ema:
        # A .safetensors file holds the EMA weights dict itself.
        ema_state = (
            checkpoint if ckpt_type == "safetensors" else checkpoint["ema_model_state_dict"]
        )
        state_dict = {
            k.replace("ema_model.", ""): v
            for k, v in ema_state.items()
            if k not in ("initted", "step")
        }
        del ema_state
        # Drop mel front-end buffers if present so the strict load below does
        # not reject them (presumably saved by older checkpoints but not
        # registered by the current model — confirm).
        for key in (
            "mel_spec.mel_stft.mel_scale.fb",
            "mel_spec.mel_stft.spectrogram.window",
        ):
            state_dict.pop(key, None)
    else:
        state_dict = (
            checkpoint if ckpt_type == "safetensors" else checkpoint["model_state_dict"]
        )

    model.load_state_dict(state_dict)

    # Drop every reference to the loaded tensors before releasing cached memory.
    del checkpoint, state_dict
    torch.cuda.empty_cache()
    # torch.cuda.empty_cache() only targets the CUDA caching allocator (unless
    # torch_npu redirects torch.cuda). `torch.npu` exists only once torch_npu
    # is imported, so this is skipped elsewhere.
    if hasattr(torch, "npu"):
        torch.npu.empty_cache()

    return model.to(device)
| |
|
def patch_for_npu():
    """Swap two F5-TTS functions for the replacements defined in this module.

    Registers, in order, ``load_checkpoint_npu`` for
    ``f5_tts.infer.utils_infer.load_checkpoint`` and
    ``get_vocos_mel_spectrogram_cpu`` for
    ``f5_tts.model.modules.get_vocos_mel_spectrogram`` with the MindSpeed
    patch manager, then applies all registered patches.
    """
    replacements = (
        ("f5_tts.infer.utils_infer.load_checkpoint", load_checkpoint_npu),
        ("f5_tts.model.modules.get_vocos_mel_spectrogram", get_vocos_mel_spectrogram_cpu),
    )
    for target, replacement in replacements:
        aspm.register_patch(target, replacement)
    # NOTE(review): presumably must run before F5-TTS code captures the
    # original functions — confirm against patch_utils semantics.
    aspm.apply_patches()