# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from patch_utils import MindSpeedPatchesManager as aspm
import torchaudio
import torch
import logging


def get_vocos_mel_spectrogram_cpu(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
):
    """Compute a log-magnitude mel spectrogram on the CPU.

    The waveform is copied to the CPU, transformed there, and the result is
    moved back to the device the waveform originally lived on.

    Args:
        waveform: Tensor with 2 dimensions, or 3 dimensions where dimension 1
            can be squeezed away ('b 1 nw -> b nw').
        n_fft: FFT size passed to ``MelSpectrogram``.
        n_mel_channels: Number of mel bins (``n_mels``).
        target_sample_rate: Sample rate passed to ``MelSpectrogram``.
        hop_length: Hop length passed to ``MelSpectrogram``.
        win_length: Window length passed to ``MelSpectrogram``.

    Returns:
        ``log(clamp(mel, min=1e-5))`` on the waveform's original device.

    Raises:
        AssertionError: If the waveform is not 2-D after the optional squeeze.
    """
    wave_device = waveform.device
    waveform = waveform.cpu()

    mel_stft = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=n_mel_channels,
        power=1,
        center=True,
        normalized=False,
        norm=None,
    ).to(waveform.device)

    if len(waveform.shape) == 3:
        waveform = waveform.squeeze(1)  # 'b 1 nw -> b nw'

    assert len(waveform.shape) == 2

    mel = mel_stft(waveform)
    # Clamp before the log so zero-magnitude bins do not produce -inf.
    mel = mel.clamp(min=1e-5).log()

    # Only the result goes back to the caller's device; the local CPU copy of
    # the waveform is discarded (the caller's tensor was never modified).
    return mel.to(wave_device)


def load_checkpoint_npu(model, ckpt_path, device: str, dtype=None, use_ema=True):
    """Load model weights from a checkpoint file and move the model to ``device``.

    Args:
        model: Module exposing ``to`` and ``load_state_dict``.
        ckpt_path: Path to the checkpoint; a ``.safetensors`` suffix selects
            the safetensors loader, anything else goes through ``torch.load``.
        device: Device string (e.g. containing "cuda" or "npu") used both as
            the load target and for dtype auto-selection.
        dtype: Target dtype for the model. When ``None`` it is auto-selected:
            float16 on a "cuda"/"npu" device whose reported major capability
            is >= 6 and whose name does not end with "[ZLUDA]", otherwise
            float32.
        use_ema: When true, load the EMA weights (``ema_model_state_dict``);
            otherwise load ``model_state_dict``.

    Returns:
        The model with weights loaded, converted to ``dtype`` and moved to
        ``device``.
    """
    logging.info("Load checkpoint %s", ckpt_path)

    if dtype is None:
        # The accelerator test is evaluated as one unit. Without the explicit
        # grouping, ``a or b and c`` parses as ``a or (b and c)`` and every
        # CUDA device would get float16 without the capability/ZLUDA checks.
        on_accelerator = "cuda" in device or "npu" in device
        # NOTE(review): on NPU this relies on torch.cuda.* being usable for the
        # device (presumably redirected by the NPU adaptation layer) - confirm.
        use_half = (
            on_accelerator
            and torch.cuda.get_device_properties(device).major >= 6
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
        )
        dtype = torch.float16 if use_half else torch.float32
    model = model.to(dtype)

    ckpt_type = ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file

        checkpoint = load_file(ckpt_path, device=device)
    else:
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

    if use_ema:
        if ckpt_type == "safetensors":
            # safetensors files hold the flat state dict directly.
            checkpoint = {"ema_model_state_dict": checkpoint}
        # Strip the "ema_model." prefix and drop the EMA bookkeeping entries.
        checkpoint["model_state_dict"] = {
            k.replace("ema_model.", ""): v
            for k, v in checkpoint["ema_model_state_dict"].items()
            if k not in ["initted", "step"]
        }

        # patch for backward compatibility, 305e3ea
        for key in [
            "mel_spec.mel_stft.mel_scale.fb",
            "mel_spec.mel_stft.spectrogram.window",
        ]:
            checkpoint["model_state_dict"].pop(key, None)

        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        if ckpt_type == "safetensors":
            checkpoint = {"model_state_dict": checkpoint}
        model.load_state_dict(checkpoint["model_state_dict"])

    # Drop the loaded tensors before asking the allocator to release its cache.
    del checkpoint
    torch.cuda.empty_cache()

    return model.to(device)


def patch_for_npu():
    """Register and apply the NPU replacements for F5-TTS functions."""
    # Swap F5-TTS's checkpoint loader and vocos mel-spectrogram routine for
    # the variants defined in this module.
    aspm.register_patch('f5_tts.infer.utils_infer.load_checkpoint', load_checkpoint_npu)
    aspm.register_patch('f5_tts.model.modules.get_vocos_mel_spectrogram', get_vocos_mel_spectrogram_cpu)
    aspm.apply_patches()