import os
import gc
import math
import types

import torch
import torchaudio
import torchvision.transforms as transforms

from tqdm import tqdm
from pathlib import Path
from loguru import logger
from typing import Dict, Optional, List, Union

from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps

import nncf
import openvino as ov
from openvino.tools.ovc import convert_model
from openvino_tokenizers import convert_tokenizer
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable

from acestep.language_segmentation import LangSegment, language_filters
from acestep.models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer

from acestep.pipeline_ace_step import ACEStepPipeline
from acestep.models.ace_step_transformer import Transformer2DModelOutput
from acestep.music_dcae.music_dcae_pipeline import MusicDCAE
from acestep.schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
from acestep.schedulers.scheduling_flow_match_pingpong import FlowMatchPingPongScheduler
from acestep.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from acestep.apg_guidance import (
    apg_forward,
    MomentumBuffer,
    cfg_forward,
    cfg_zero_star,
    cfg_double_condition_forward,
)

torch.set_float32_matmul_precision("high")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

TOKENIZER_MODEL_NAME = "openvino_tokenizer.xml"
TEXT_ENCODER_MODEL_NAME = "ov_text_encoder_model.xml"
DCAE_ENCODER_MODEL_NAME = "ov_dcae_encoder_model.xml"
DCAE_DECODER_MODEL_NAME = "ov_dcae_decoder_model.xml"
VOCODER_DECODE_MODEL_NAME = "ov_vocoder_decode_model.xml"
VOCODER_MEL_TRANSFORM_MODEL_NAME = "ov_vocoder_mel_transform_model.xml"
TRANSFORMER_DECODER_MODEL_NAME = "ov_transformer_decoder_model.xml"
TRANSFORMER_ENCODER_MODEL_NAME = "ov_transformer_encoder_model.xml"


def cleanup_torchscript_cache():
    """Helper for removing cached TorchScript model representations."""
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()


def ov_convert(
    model_dir_path: str,
    ov_model_name: str,
    inputs: Dict,
    orig_model: torch.nn.Module,
    model_name: str,
    quantization_config: Optional[Dict] = None,
    force_conversion: bool = False,
):
    try:
        ov_model_path = Path(model_dir_path, ov_model_name)
        if not ov_model_path.exists() or force_conversion:
            print(f"⌛ Convert {model_name} model")
            orig_model.eval()
            # Patch 16-bit weights so the model can be traced by the OpenVINO converter.
            __make_16bit_traceable(orig_model)
            ov_model = convert_model(orig_model, example_input=inputs)
            if quantization_config is not None:
                print(f"⌛ Weights compression with {quantization_config['mode']} mode started")
                ov_model = nncf.compress_weights(ov_model, **quantization_config)
                print("✅ Weights compression finished")
            ov.save_model(ov_model, ov_model_path)

            del ov_model
            cleanup_torchscript_cache()
            gc.collect()
            print(f"✅ {model_name} model converted")
    except Exception as e:
        print(f"❌ {model_name} model is not converted. Error: {e}")


def convert_transformer_models(pipeline: ACEStepPipeline, model_dir: str = "ov_converted", orig_checkpoint_path: str = "", quantization_config: Optional[Dict] = None):

    def encode_with_temperature_wrap(
        self,
        encoder_text_hidden_states: torch.Tensor = None,
        text_attention_mask: torch.LongTensor = None,
        speaker_embeds: torch.FloatTensor = None,
        lyric_token_idx: torch.LongTensor = None,
        lyric_mask: torch.LongTensor = None,
        tau: torch.FloatTensor = torch.Tensor([0.01]),
    ):
        handlers = []

        # Entropy Rectifying Guidance: scale the attention query projections of a
        # band of lyric-encoder layers by tau, perturbing the attention temperature.
        def hook(module, input, output):
            output[:] *= tau[0]
            return output

        l_min = 4
        l_max = 6
        for i in range(l_min, l_max):
            handler = self.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook)
            handlers.append(handler)

        encoder_hidden_states, encoder_hidden_mask = self.encode(
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embeds=speaker_embeds,
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
        )

        for handler in handlers:
            handler.remove()

        return encoder_hidden_states, encoder_hidden_mask

    inputs = {
        "encoder_text_hidden_states": torch.randn(size=(1, 15, 768), dtype=torch.float),
        "text_attention_mask": torch.ones([1, 15], dtype=torch.int64),
        "speaker_embeds": torch.zeros(size=(1, 512), dtype=torch.float),
        "lyric_token_idx": torch.randint(10000, [1, 543], dtype=torch.int64),
        "lyric_mask": torch.ones([1, 543], dtype=torch.int64),
        "tau": torch.Tensor([0.01]),
    }
    # Tracing tau as a model input lets one compiled encoder serve both the plain
    # pass (tau=1, hooks are a no-op) and the ERG pass (tau < 1).
    transformer_encoder_erg_model = pipeline.ace_step_transformer
    transformer_encoder_erg_model.forward = types.MethodType(encode_with_temperature_wrap, transformer_encoder_erg_model)
    ov_convert(
        model_dir,
        TRANSFORMER_ENCODER_MODEL_NAME,
        inputs,
        transformer_encoder_erg_model,
        "Transformer Encoder with Entropy Rectifying Guidance",
        quantization_config=quantization_config,
    )

    def decode_with_temperature_wrap(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_mask: torch.Tensor,
        timestep: torch.Tensor = None,
        output_length: int = 0,
        tau: torch.FloatTensor = torch.Tensor([0.01]),
    ):
        handlers = []

        # Scale the self- and cross-attention query projections of the middle
        # transformer blocks by tau (Entropy Rectifying Guidance).
        def hook(module, input, output):
            output[:] *= tau[0]
            return output

        l_min = 5
        l_max = 10
        for i in range(l_min, l_max):
            handler = self.transformer_blocks[i].attn.to_q.register_forward_hook(hook)
            handlers.append(handler)
            handler = self.transformer_blocks[i].cross_attn.to_q.register_forward_hook(hook)
            handlers.append(handler)

        sample = self.decode(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_mask=encoder_hidden_mask,
            output_length=output_length,
            timestep=timestep,
        ).sample

        for handler in handlers:
            handler.remove()

        return sample

    inputs = {
        "hidden_states": torch.randn(size=(1, 8, 16, 151), dtype=torch.float),
        "attention_mask": torch.ones([1, 151], dtype=torch.int64),
        "encoder_hidden_states": torch.randn(size=(1, 559, 2560), dtype=torch.float),
        "encoder_hidden_mask": torch.ones([1, 559], dtype=torch.float),
        "output_length": torch.tensor(151),
        "timestep": torch.randn([1], dtype=torch.float),
        "tau": torch.Tensor([0.01]),
    }
    transformer_decoder_erg_model = pipeline.ace_step_transformer
    transformer_decoder_erg_model.forward = types.MethodType(decode_with_temperature_wrap, transformer_decoder_erg_model)
    ov_convert(
        model_dir,
        TRANSFORMER_DECODER_MODEL_NAME,
        inputs,
        transformer_decoder_erg_model,
        "Transformer Decoder with Entropy Rectifying Guidance",
        quantization_config=quantization_config,
    )


def convert_models(pipeline: ACEStepPipeline, model_dir: str = "ov_converted_new", orig_checkpoint_path: str = "", quantization_config: Optional[Dict] = None):
    print("⌛ Conversion started. Be patient, it may take some time.")

    if not pipeline.loaded or (orig_checkpoint_path and not Path(orig_checkpoint_path).exists()):
        print("⌛ Load original model checkpoints")
        pipeline.load_checkpoint(orig_checkpoint_path)
        print("✅ Original model checkpoints successfully loaded")

    ov_tokenizer_path = Path(model_dir, TOKENIZER_MODEL_NAME)
    if not ov_tokenizer_path.exists():
        print("⌛ Convert Tokenizer")
        ov_tokenizer = convert_tokenizer(pipeline.text_tokenizer, with_detokenizer=False)
        ov.save_model(ov_tokenizer, ov_tokenizer_path)
        print("✅ Tokenizer is converted")

    inputs = {
        "input_ids": torch.randint(1000, size=(1, 15), dtype=torch.int64),
        "attention_mask": torch.ones([1, 15], dtype=torch.int64),
    }
    ov_convert(model_dir, TEXT_ENCODER_MODEL_NAME, inputs, pipeline.text_encoder_model, "UMT5 Encoder")

    inputs = {"hidden_states": torch.randn([1, 2, 128, 1208], dtype=torch.float)}
    ov_convert(model_dir, DCAE_ENCODER_MODEL_NAME, inputs, pipeline.music_dcae.dcae.encoder, "Sana's Deep Compression AutoEncoder Encoder")

    inputs = {"hidden_states": torch.randn([1, 8, 16, 151], dtype=torch.float)}
    ov_convert(model_dir, DCAE_DECODER_MODEL_NAME, inputs, pipeline.music_dcae.dcae.decoder, "Sana's Deep Compression AutoEncoder Decoder")

    inputs = {"x": torch.randn([2, 618496], dtype=torch.float)}
    ov_convert(model_dir, VOCODER_MEL_TRANSFORM_MODEL_NAME, inputs, pipeline.music_dcae.vocoder.mel_transform, "Vocoder Mel Transform")

    inputs = {"mel": torch.randn([1, 128, 856], dtype=torch.float)}
    ov_convert(model_dir, VOCODER_DECODE_MODEL_NAME, inputs, pipeline.music_dcae.vocoder, "Vocoder Decoder")

    convert_transformer_models(pipeline, model_dir, orig_checkpoint_path, quantization_config)
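

# End-to-end conversion sketch (the checkpoint directory and the INT8 compression
# mode below are example values, not the only supported ones):
#
#     torch_pipeline = ACEStepPipeline(checkpoint_dir="checkpoints", dtype="float32")
#     convert_models(
#         torch_pipeline,
#         model_dir="ov_converted",
#         quantization_config={"mode": nncf.CompressWeightsMode.INT8_ASYM},
#     )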


class MusicDCAEWrapper(MusicDCAE):
    """MusicDCAE shell that keeps only inference-time constants and transforms;
    dcae and vocoder are filled in later with OpenVINO wrappers."""

    def __init__(self, source_sample_rate=None):
        torch.nn.Module.__init__(self)
        self.dcae = None
        self.vocoder = None

        if source_sample_rate is None:
            source_sample_rate = 48000

        self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)

        self.transform = transforms.Compose(
            [
                transforms.Normalize(0.5, 0.5),
            ]
        )
        self.min_mel_value = -11.0
        self.max_mel_value = 3.0
        self.audio_chunk_size = int(round(1024 * 512 / 44100 * 48000))
        self.mel_chunk_size = 1024
        # NB: "time_dimention_multiple" (sic) must keep this spelling; it is the
        # attribute name used by the MusicDCAE base class.
        self.time_dimention_multiple = 8
        self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
        self.scale_factor = 0.1786
        self.shift_factor = -1.9091


class OVDCAECompiledModels(torch.nn.Module):
    def __init__(self, compiled_model):
        super().__init__()
        self.compiled_model = compiled_model

    def __call__(self, inputs):
        if not self.compiled_model:
            logger.error("OVDCAECompiledModels: compiled model is not defined")
            raise RuntimeError("OVDCAECompiledModels: compiled model is not defined")

        output = self.compiled_model({"hidden_states": inputs.to(dtype=torch.float32)})
        return torch.from_numpy(output[0])

    @classmethod
    def from_pretrained(cls, ov_model_path, device, ov_core):
        ov_dcae_model = ov_core.read_model(ov_model_path)
        compiled_model = ov_core.compile_model(ov_dcae_model, device)
        return cls(compiled_model)


class OVWrapperAutoencoderDC(torch.nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    @classmethod
    def from_pretrained(cls, ov_core, ov_models_path, device="CPU"):
        encoder = OVDCAECompiledModels.from_pretrained(Path(ov_models_path, DCAE_ENCODER_MODEL_NAME), device, ov_core)
        decoder = OVDCAECompiledModels.from_pretrained(Path(ov_models_path, DCAE_DECODER_MODEL_NAME), device, ov_core)
        return cls(encoder, decoder)


class OVWrapperADaMoSHiFiGANV1(torch.nn.Module):
    def __init__(self, decoder_compiled_model, mel_transform_compiled_model):
        super().__init__()
        self.decoder = decoder_compiled_model
        self.mel_transform_model = mel_transform_compiled_model

    @classmethod
    def from_pretrained(cls, ov_core, ov_models_path, device="CPU"):
        ov_vocoder_decoder_model = ov_core.read_model(Path(ov_models_path, VOCODER_DECODE_MODEL_NAME))
        decoder = ov_core.compile_model(ov_vocoder_decoder_model, device)
        ov_vocoder_mel_transform_model = ov_core.read_model(Path(ov_models_path, VOCODER_MEL_TRANSFORM_MODEL_NAME))
        mel_transform_model = ov_core.compile_model(ov_vocoder_mel_transform_model, device)
        return cls(decoder, mel_transform_model)

    def decode(self, inputs):
        output = self.decoder({"mel": inputs.to(dtype=torch.float32)})
        return torch.from_numpy(output[0])

    def mel_transform(self, inputs):
        output = self.mel_transform_model({"x": inputs.to(dtype=torch.float32)})
        return torch.from_numpy(output[0])

    def forward(self, inputs):
        return self.decode(inputs)


class OvWrapperACEStepTransformer2DModel(torch.nn.Module):
    def __init__(self, encoder_model, decoder_model):
        super().__init__()
        self.ov_lyric_encoder_compiled = encoder_model
        self.ov_decoder_compiled_model = decoder_model

    @classmethod
    def from_pretrained(cls, ov_core, ov_models_path, device="CPU"):
        ov_model_encoder = ov_core.read_model(Path(ov_models_path, TRANSFORMER_ENCODER_MODEL_NAME))
        compiled_model_encoder = ov_core.compile_model(ov_model_encoder, device)

        ov_model_decoder = ov_core.read_model(Path(ov_models_path, TRANSFORMER_DECODER_MODEL_NAME))
        compiled_model_decoder = ov_core.compile_model(ov_model_decoder, device)
        return cls(compiled_model_encoder, compiled_model_decoder)

    def encode_with_temperature(
        self,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        tau: Optional[torch.FloatTensor] = torch.Tensor([0.01]),
    ):
        if not self.ov_lyric_encoder_compiled:
            logger.error("OvWrapperACEStepTransformer2DModel: encoder compiled model is not defined")
            raise RuntimeError("OvWrapperACEStepTransformer2DModel: encoder compiled model is not defined")

        output = self.ov_lyric_encoder_compiled(
            {
                "encoder_text_hidden_states": encoder_text_hidden_states,
                "text_attention_mask": text_attention_mask,
                "speaker_embeds": speaker_embeds,
                "lyric_token_idx": lyric_token_idx,
                "lyric_mask": lyric_mask,
                "tau": tau,
            }
        )
        return torch.from_numpy(output[0]), torch.from_numpy(output[1])

    def decode_with_temperature(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_mask: torch.Tensor,
        timestep: Optional[torch.Tensor],
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        output_length: int = 0,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
        tau: Optional[torch.FloatTensor] = torch.Tensor([0.01]),
    ):
        output = None
        if self.ov_decoder_compiled_model:
            output = self.ov_decoder_compiled_model(
                {
                    "hidden_states": hidden_states,
                    "attention_mask": attention_mask,
                    "encoder_hidden_states": encoder_hidden_states,
                    "encoder_hidden_mask": encoder_hidden_mask,
                    "output_length": output_length,
                    "timestep": timestep,
                    "tau": tau,
                }
            )

        sample = torch.from_numpy(output[0]) if output is not None else None
        return sample

    def encode(
        self,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
    ):
        # tau=1 makes the traced ERG hooks a no-op, reproducing the plain encoder.
        return self.encode_with_temperature(
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embeds=speaker_embeds,
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
            tau=torch.Tensor([1]),
        )

    def decode(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_mask: torch.Tensor,
        timestep: Optional[torch.Tensor],
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        output_length: int = 0,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
    ):
        sample = self.decode_with_temperature(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_mask=encoder_hidden_mask,
            timestep=timestep,
            ssl_hidden_states=ssl_hidden_states,
            output_length=output_length,
            block_controlnet_hidden_states=block_controlnet_hidden_states,
            controlnet_scale=controlnet_scale,
            return_dict=return_dict,
            tau=torch.Tensor([1]),
        )

        return Transformer2DModelOutput(sample, None)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        timestep: Optional[torch.Tensor] = None,
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
    ):
        encoder_hidden_states, encoder_hidden_mask = self.encode(
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embeds=speaker_embeds,
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
        )

        output_length = hidden_states.shape[-1]

        output = self.decode(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_mask=encoder_hidden_mask,
            timestep=timestep,
            ssl_hidden_states=ssl_hidden_states,
            output_length=output_length,
            block_controlnet_hidden_states=block_controlnet_hidden_states,
            controlnet_scale=controlnet_scale,
            return_dict=return_dict,
        )

        return output


class OVACEStepPipeline(ACEStepPipeline):
    def __init__(self):
        super().__init__(checkpoint_dir="", dtype="float32")
        self.core = ov.Core()

        self.dcae_decoder = None
        self.vocoder_encode = None
        self.vocoder_decoder = None
        self.transformer_encode = None
        self.transformer_encode_with_temperature = None
        self.transformer_decode = None
        self.transformer_decode_with_temperature = None

        self.ace_step_transformer_origin = None
        self.ace_step_transformer = None
        self.music_dcae = None
        self.text_tokenizer = None
        self.text_encoder_model = None

    def get_checkpoint_path(self, checkpoint_dir, repo):
        pass

    def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False):
        pass

    def load_models(self, ov_models_path: str = None, device: str = "CPU"):
        self.loaded = True
        if ov_models_path and Path(ov_models_path).exists():
            ov_text_encoder_model = self.core.read_model(Path(ov_models_path, TEXT_ENCODER_MODEL_NAME))
            self.text_encoder_model = self.core.compile_model(ov_text_encoder_model, device)

            # The converted tokenizer is always compiled for CPU.
            ov_text_tokenizer_model = self.core.read_model(Path(ov_models_path, TOKENIZER_MODEL_NAME))
            self.text_tokenizer = self.core.compile_model(ov_text_tokenizer_model, "CPU")

            self.music_dcae = MusicDCAEWrapper()
            self.music_dcae.dcae = OVWrapperAutoencoderDC.from_pretrained(self.core, ov_models_path, device)
            self.music_dcae.vocoder = OVWrapperADaMoSHiFiGANV1.from_pretrained(self.core, ov_models_path, device)

            self.ace_step_transformer = OvWrapperACEStepTransformer2DModel.from_pretrained(self.core, ov_models_path, device)
        else:
            logger.error(f"Path does not exist: {ov_models_path}")

        lang_segment = LangSegment()
        lang_segment.setfilters(language_filters.default)
        self.lang_segment = lang_segment
        self.lyric_tokenizer = VoiceBpeTokenizer()

    def load_quantized_checkpoint(self, checkpoint_dir=None):
        pass

    def get_text_embeddings(self, texts, text_max_length=256):
        inputs = self.text_tokenizer(texts)
        inputs = {"attention_mask": inputs["attention_mask"], "input_ids": inputs["input_ids"]}

        last_hidden_states = self.text_encoder_model(inputs)
        attention_mask = inputs["attention_mask"]
        return torch.from_numpy(last_hidden_states[0]), torch.from_numpy(attention_mask)

    def get_text_embeddings_null(self, texts, text_max_length=256, tau=0.01, l_min=8, l_max=10):
        inputs = self.text_tokenizer(texts)
        inputs = {"attention_mask": inputs["attention_mask"], "input_ids": inputs["input_ids"]}
        last_hidden_states = self.text_encoder_model(inputs)
        return torch.from_numpy(last_hidden_states[0])

    def text2music_diffusion_process(
        self,
        duration,
        encoder_text_hidden_states,
        text_attention_mask,
        speaker_embds,
        lyric_token_ids,
        lyric_mask,
        random_generators=None,
        infer_steps=60,
        guidance_scale=15.0,
        omega_scale=10.0,
        scheduler_type="euler",
        cfg_type="apg",
        zero_steps=1,
        use_zero_init=True,
        guidance_interval=0.5,
        guidance_interval_decay=1.0,
        min_guidance_scale=3.0,
        oss_steps=[],
        encoder_text_hidden_states_null=None,
        use_erg_lyric=False,
        use_erg_diffusion=False,
        retake_random_generators=None,
        retake_variance=0.5,
        add_retake_noise=False,
        guidance_scale_text=0.0,
        guidance_scale_lyric=0.0,
        repaint_start=0,
        repaint_end=0,
        src_latents=None,
        audio2audio_enable=False,
        ref_audio_strength=0.5,
        ref_latents=None,
    ):
        logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
        do_classifier_free_guidance = True
        if guidance_scale == 0.0 or guidance_scale == 1.0:
            do_classifier_free_guidance = False

        do_double_condition_guidance = False
        if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0:
            do_double_condition_guidance = True
            logger.info(
                "do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format(
                    do_double_condition_guidance,
                    guidance_scale_text,
                    guidance_scale_lyric,
                )
            )

        bsz = encoder_text_hidden_states.shape[0]

        if scheduler_type == "euler":
            scheduler = FlowMatchEulerDiscreteScheduler(
                num_train_timesteps=1000,
                shift=3.0,
            )
        elif scheduler_type == "heun":
            scheduler = FlowMatchHeunDiscreteScheduler(
                num_train_timesteps=1000,
                shift=3.0,
            )
        elif scheduler_type == "pingpong":
            scheduler = FlowMatchPingPongScheduler(
                num_train_timesteps=1000,
                shift=3.0,
            )

        # One latent frame covers 512 * 8 audio samples at 44.1 kHz.
        frame_length = int(duration * 44100 / 512 / 8)
        if src_latents is not None:
            frame_length = src_latents.shape[-1]

        if ref_latents is not None:
            frame_length = ref_latents.shape[-1]

        if len(oss_steps) > 0:
            # Optimal steps sampling: build the full schedule once, then keep only the selected steps.
            infer_steps = max(oss_steps)
            timesteps, num_inference_steps = retrieve_timesteps(
                scheduler,
                num_inference_steps=infer_steps,
                device=self.device,
                timesteps=None,
            )
            new_timesteps = torch.zeros(len(oss_steps), dtype=self.dtype, device=self.device)
            for idx in range(len(oss_steps)):
                new_timesteps[idx] = timesteps[oss_steps[idx] - 1]
            num_inference_steps = len(oss_steps)
            sigmas = (new_timesteps / 1000).float().cpu().numpy()
            timesteps, num_inference_steps = retrieve_timesteps(
                scheduler,
                num_inference_steps=num_inference_steps,
                device=self.device,
                sigmas=sigmas,
            )
            logger.info(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}")
        else:
            timesteps, num_inference_steps = retrieve_timesteps(
                scheduler,
                num_inference_steps=infer_steps,
                device=self.device,
                timesteps=None,
            )

        target_latents = randn_tensor(
            shape=(bsz, 8, 16, frame_length),
            generator=random_generators,
            device=self.device,
            dtype=self.dtype,
        )

        is_repaint = False
        is_extend = False

        if add_retake_noise:
            n_min = int(infer_steps * (1 - retake_variance))
            retake_variance = torch.tensor(retake_variance * math.pi / 2).to(self.device).to(self.dtype)
            retake_latents = randn_tensor(
                shape=(bsz, 8, 16, frame_length),
                generator=retake_random_generators,
                device=self.device,
                dtype=self.dtype,
            )
            repaint_start_frame = int(repaint_start * 44100 / 512 / 8)
            repaint_end_frame = int(repaint_end * 44100 / 512 / 8)
            x0 = src_latents

            # Pure retake only when the repaint region spans the whole clip; otherwise repaint.
            is_repaint = repaint_end_frame - repaint_start_frame != frame_length

            is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length)
            if is_extend:
                is_repaint = True

            if not is_repaint:
                # Retake: trigonometric mix of the original and retake noise.
                target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
            elif not is_extend:
                repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype)
                repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0
                repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
                repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents)
                zt_edit = x0.clone()
                z0 = repaint_noise
            elif is_extend:
                to_right_pad_gt_latents = None
                to_left_pad_gt_latents = None
                gt_latents = src_latents
                src_latents_length = gt_latents.shape[-1]
                max_infer_frame_length = int(240 * 44100 / 512 / 8)
                left_pad_frame_length = 0
                right_pad_frame_length = 0
                right_trim_length = 0
                left_trim_length = 0
                if repaint_start_frame < 0:
                    left_pad_frame_length = abs(repaint_start_frame)
                    frame_length = left_pad_frame_length + gt_latents.shape[-1]
                    extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0)
                    if frame_length > max_infer_frame_length:
                        right_trim_length = frame_length - max_infer_frame_length
                        extend_gt_latents = extend_gt_latents[:, :, :, :max_infer_frame_length]
                        to_right_pad_gt_latents = extend_gt_latents[:, :, :, -right_trim_length:]
                        frame_length = max_infer_frame_length
                    repaint_start_frame = 0
                    gt_latents = extend_gt_latents

                if repaint_end_frame > src_latents_length:
                    right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1]
                    frame_length = gt_latents.shape[-1] + right_pad_frame_length
                    extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0)
                    if frame_length > max_infer_frame_length:
                        left_trim_length = frame_length - max_infer_frame_length
                        extend_gt_latents = extend_gt_latents[:, :, :, -max_infer_frame_length:]
                        to_left_pad_gt_latents = extend_gt_latents[:, :, :, :left_trim_length]
                        frame_length = max_infer_frame_length
                    repaint_end_frame = frame_length
                    gt_latents = extend_gt_latents

                repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype)
                if left_pad_frame_length > 0:
                    repaint_mask[:, :, :, :left_pad_frame_length] = 1.0
                if right_pad_frame_length > 0:
                    repaint_mask[:, :, :, -right_pad_frame_length:] = 1.0
                x0 = gt_latents
                pad_list = []
                if left_pad_frame_length > 0:
                    pad_list.append(retake_latents[:, :, :, :left_pad_frame_length])
                pad_list.append(
                    target_latents[
                        :,
                        :,
                        :,
                        left_trim_length : target_latents.shape[-1] - right_trim_length,
                    ]
                )
                if right_pad_frame_length > 0:
                    pad_list.append(retake_latents[:, :, :, -right_pad_frame_length:])
                target_latents = torch.cat(pad_list, dim=-1)
                assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}"
                zt_edit = x0.clone()
                z0 = target_latents

        if audio2audio_enable and ref_latents is not None:
            logger.info(f"audio2audio_enable: {audio2audio_enable}, ref_latents: {ref_latents.shape}")
            target_latents, timesteps, scheduler, num_inference_steps = self.add_latents_noise(
                gt_latents=ref_latents,
                sigma_max=(1 - ref_audio_strength),
                noise=target_latents,
                scheduler_type=scheduler_type,
                infer_steps=infer_steps,
            )

        attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype)

        # Guidance is only applied inside a centered interval of the schedule.
        start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
        end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
        logger.info(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}")

        momentum_buffer = MomentumBuffer()

        # Condition tensors are encoded once, outside the denoising loop.
        encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(
            encoder_text_hidden_states,
            text_attention_mask,
            speaker_embds,
            lyric_token_ids,
            lyric_mask,
        )

        if use_erg_lyric:
            # Null condition via ERG: re-encode with damped lyric-encoder attention.
            encoder_hidden_states_null, _ = self.ace_step_transformer.encode_with_temperature(
                encoder_text_hidden_states=(
                    encoder_text_hidden_states_null if encoder_text_hidden_states_null is not None else torch.zeros_like(encoder_text_hidden_states)
                ),
                text_attention_mask=text_attention_mask,
                speaker_embeds=torch.zeros_like(speaker_embds),
                lyric_token_idx=lyric_token_ids,
                lyric_mask=lyric_mask,
            )
        else:
            # Plain null condition: zero out text, speaker, and lyric inputs.
            encoder_hidden_states_null, _ = self.ace_step_transformer.encode(
                torch.zeros_like(encoder_text_hidden_states),
                text_attention_mask,
                torch.zeros_like(speaker_embds),
                torch.zeros_like(lyric_token_ids),
                lyric_mask,
            )

        encoder_hidden_states_no_lyric = None
        if do_double_condition_guidance:
            # Text-only condition (speaker zeroed; lyrics damped via ERG or zeroed).
            if use_erg_lyric:
                encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode_with_temperature(
                    encoder_text_hidden_states=encoder_text_hidden_states,
                    text_attention_mask=text_attention_mask,
                    speaker_embeds=torch.zeros_like(speaker_embds),
                    lyric_token_idx=lyric_token_ids,
                    lyric_mask=lyric_mask,
                )
            else:
                encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode(
                    encoder_text_hidden_states,
                    text_attention_mask,
                    torch.zeros_like(speaker_embds),
                    torch.zeros_like(lyric_token_ids),
                    lyric_mask,
                )
        for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
            if is_repaint:
                if i < n_min:
                    continue
                elif i == n_min:
                    # Re-noise the edited latents to the current noise level before repainting.
                    t_i = t / 1000
                    zt_src = (1 - t_i) * x0 + (t_i) * z0
                    target_latents = zt_edit + zt_src - x0
                    logger.info(f"repaint: starting at step {n_min} with noise level {t_i}")

            latents = target_latents

            is_in_guidance_interval = start_idx <= i < end_idx
            if is_in_guidance_interval and do_classifier_free_guidance:
                if guidance_interval_decay > 0:
                    # Linearly decay the guidance scale towards min_guidance_scale across the interval.
                    progress = (i - start_idx) / (end_idx - start_idx - 1)
                    current_guidance_scale = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay
                else:
                    current_guidance_scale = guidance_scale

                latent_model_input = latents
                timestep = t.expand(latent_model_input.shape[0])
                output_length = latent_model_input.shape[-1]

                noise_pred_with_cond = self.ace_step_transformer.decode(
                    hidden_states=latent_model_input,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_mask=encoder_hidden_mask,
                    output_length=output_length,
                    timestep=timestep,
                ).sample

                noise_pred_with_only_text_cond = None
                if do_double_condition_guidance and encoder_hidden_states_no_lyric is not None:
                    noise_pred_with_only_text_cond = self.ace_step_transformer.decode(
                        hidden_states=latent_model_input,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_hidden_states_no_lyric,
                        encoder_hidden_mask=encoder_hidden_mask,
                        output_length=output_length,
                        timestep=timestep,
                    ).sample

                if use_erg_diffusion:
                    # Unconditional pass through the ERG decoder (tau defaults to 0.01, damping attention).
                    noise_pred_uncond = self.ace_step_transformer.decode_with_temperature(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=encoder_hidden_states_null,
                        encoder_hidden_mask=encoder_hidden_mask,
                        output_length=output_length,
                        attention_mask=attention_mask,
                    )
                else:
                    noise_pred_uncond = self.ace_step_transformer.decode(
                        hidden_states=latent_model_input,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_hidden_states_null,
                        encoder_hidden_mask=encoder_hidden_mask,
                        output_length=output_length,
                        timestep=timestep,
                    ).sample

                if do_double_condition_guidance and noise_pred_with_only_text_cond is not None:
                    noise_pred = cfg_double_condition_forward(
                        cond_output=noise_pred_with_cond,
                        uncond_output=noise_pred_uncond,
                        only_text_cond_output=noise_pred_with_only_text_cond,
                        guidance_scale_text=guidance_scale_text,
                        guidance_scale_lyric=guidance_scale_lyric,
                    )
                elif cfg_type == "apg":
                    noise_pred = apg_forward(
                        pred_cond=noise_pred_with_cond,
                        pred_uncond=noise_pred_uncond,
                        guidance_scale=current_guidance_scale,
                        momentum_buffer=momentum_buffer,
                    )
                elif cfg_type == "cfg":
                    noise_pred = cfg_forward(
                        cond_output=noise_pred_with_cond,
                        uncond_output=noise_pred_uncond,
                        cfg_strength=current_guidance_scale,
                    )
                elif cfg_type == "cfg_star":
                    noise_pred = cfg_zero_star(
                        noise_pred_with_cond=noise_pred_with_cond,
                        noise_pred_uncond=noise_pred_uncond,
                        guidance_scale=current_guidance_scale,
                        i=i,
                        zero_steps=zero_steps,
                        use_zero_init=use_zero_init,
                    )
            else:
                latent_model_input = latents
                timestep = t.expand(latent_model_input.shape[0])
                noise_pred = self.ace_step_transformer.decode(
                    hidden_states=latent_model_input,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_mask=encoder_hidden_mask,
                    output_length=latent_model_input.shape[-1],
                    timestep=timestep,
                ).sample

            if is_repaint and i >= n_min:
                # Manual Euler step, then keep the unpainted region pinned to the source trajectory.
                t_i = t / 1000
                if i + 1 < len(timesteps):
                    t_im1 = (timesteps[i + 1]) / 1000
                else:
                    t_im1 = torch.zeros_like(t_i).to(self.device)
                target_latents = target_latents.to(torch.float32)
                prev_sample = target_latents + (t_im1 - t_i) * noise_pred
                prev_sample = prev_sample.to(self.dtype)
                target_latents = prev_sample
                zt_src = (1 - t_im1) * x0 + (t_im1) * z0
                target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src)
            else:
                target_latents = scheduler.step(
                    model_output=noise_pred,
                    timestep=t,
                    sample=target_latents,
                    return_dict=False,
                    omega=omega_scale,
                    generator=random_generators[0],
                )[0]

        if is_extend:
            if to_right_pad_gt_latents is not None:
                target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1)
            if to_left_pad_gt_latents is not None:
                target_latents = torch.cat([to_left_pad_gt_latents, target_latents], dim=-1)
        return target_latents

    def load_lora(self, model_with_lora_path, device="CPU"):
        if model_with_lora_path == "none":
            # Restore the original (non-LoRA) transformer if it was swapped out earlier.
            if self.ace_step_transformer_origin:
                self.ace_step_transformer = self.ace_step_transformer_origin
        else:
            self.ace_step_transformer_origin = self.ace_step_transformer
            self.update_transformer_model(model_with_lora_path, device)

    def update_transformer_model(self, new_transformer_path, device="CPU"):
        self.ace_step_transformer = OvWrapperACEStepTransformer2DModel.from_pretrained(self.core, new_transformer_path, device)
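

# Example inference usage (a minimal sketch; the model directory and device name
# are illustrative):
#
#     ov_pipeline = OVACEStepPipeline()
#     ov_pipeline.load_models(ov_models_path="ov_converted", device="CPU")
#     # Generation then goes through the inherited ACEStepPipeline entry point,
#     # e.g. ov_pipeline(prompt=..., lyrics=..., audio_duration=..., save_path=...);
#     # the argument names follow the upstream ACEStepPipeline.__call__ signature
#     # and are assumptions here.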