# ACE-Step-v1-ov-int4 / ov_ace_helper.py
import os
import gc
import math
import torch
import types
import torchaudio
import torchvision.transforms as transforms
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from diffusers.utils.torch_utils import randn_tensor
from typing import Dict, Optional, List, Union, Type
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
import nncf
import openvino as ov
from openvino.tools.ovc import convert_model
from openvino_tokenizers import convert_tokenizer
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
from acestep.language_segmentation import LangSegment, language_filters
from acestep.models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer
from acestep.pipeline_ace_step import ACEStepPipeline
from acestep.models.ace_step_transformer import Transformer2DModelOutput
from acestep.music_dcae.music_dcae_pipeline import MusicDCAE
from acestep.schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
from acestep.schedulers.scheduling_flow_match_pingpong import FlowMatchPingPongScheduler
from acestep.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from acestep.apg_guidance import (
apg_forward,
MomentumBuffer,
cfg_forward,
cfg_zero_star,
cfg_double_condition_forward,
)
torch.set_float32_matmul_precision("high")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
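# File names of the exported OpenVINO IR models.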
TOKENIZER_MODEL_NAME = "openvino_tokenizer.xml"
TEXT_ENCODER_MODEL_NAME = "ov_text_encoder_model.xml"
DCAE_ENCODER_MODEL_NAME = "ov_dcae_encoder_model.xml"
DCAE_DECODER_MODEL_NAME = "ov_dcae_decoder_model.xml"
VOCODER_DECODE_MODEL_NAME = "ov_vocoder_decode_model.xml"
VOCODER_MEL_TRANSFORM_MODEL_NAME = "ov_vocoder_mel_transform_model.xml"
TRANSFORMER_DECODER_MODEL_NAME = "ov_transformer_decoder_model.xml"
TRANSFORMER_ENCODER_MODEL_NAME = "ov_transformer_encoder_model.xml"
def cleanup_torchscript_cache():
"""
Helper for removing cached model representation
"""
torch._C._jit_clear_class_registry()
torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
torch.jit._state._clear_class_state()
def ov_convert(
model_dir_path: str,
ov_model_name: str,
inputs: Dict,
orig_model: Type[torch.nn.Module],
model_name: str,
quantization_config: Dict = None,
    force_conversion: bool = False,
):
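    """Trace `orig_model` with the given example `inputs`, convert it to OpenVINO IR and save it
    as `ov_model_name` under `model_dir_path`.

    Conversion is skipped when the IR file already exists unless `force_conversion` is set.
    When `quantization_config` is given, it is forwarded to `nncf.compress_weights` for
    weight-only compression before saving.
    """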
try:
ov_model_path = Path(model_dir_path, ov_model_name)
        if not ov_model_path.exists() or force_conversion:
print(f"⌛ Convert {model_name} model")
orig_model.eval()
__make_16bit_traceable(orig_model)
ov_model = convert_model(orig_model, example_input=inputs)
if quantization_config is not None:
print(f"⌛ Weights compression with {quantization_config['mode']} mode started")
ov_model = nncf.compress_weights(ov_model, **quantization_config)
print("✅ Weights compression finished")
ov.save_model(ov_model, ov_model_path)
del ov_model
cleanup_torchscript_cache()
gc.collect()
print(f"✅ {model_name} model converted")
except Exception as e:
        print(f"❌ {model_name} model was not converted. Error: {e}")
def convert_transformer_models(pipeline: ACEStepPipeline, model_dir: str = "ov_converted", orig_checkpoint_path: str = "", quantization_config: Dict = None):
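    """Export the ACE-Step DiT as two OpenVINO IRs: a text/lyric/speaker encoder and a latent
    decoder. Both are wrapped so that an extra `tau` input scales selected attention query
    projections, which lets Entropy Rectifying Guidance (ERG) be applied at inference time.
    """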
# Transformer Encoder model
def encode_with_temperature_wrap(
self,
encoder_text_hidden_states: torch.Tensor = None,
text_attention_mask: torch.LongTensor = None,
speaker_embeds: torch.FloatTensor = None,
lyric_token_idx: torch.LongTensor = None,
lyric_mask: torch.LongTensor = None,
tau: torch.FloatTensor = torch.Tensor([0.01]),
):
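        # Temporarily scale the query projections of lyric-encoder layers 4 and 5 by `tau`
        # via forward hooks, run the regular `encode`, then remove the hooks.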
handlers = []
def hook(module, input, output):
output[:] *= tau[0]
return output
l_min = 4
l_max = 6
for i in range(l_min, l_max):
handler = self.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook)
handlers.append(handler)
encoder_hidden_states, encoder_hidden_mask = self.encode(
encoder_text_hidden_states=encoder_text_hidden_states,
text_attention_mask=text_attention_mask,
speaker_embeds=speaker_embeds,
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
)
        for handler in handlers:
            handler.remove()
return encoder_hidden_states, encoder_hidden_mask
inputs = {
"encoder_text_hidden_states": torch.randn(size=(1, 15, 768), dtype=torch.float),
"text_attention_mask": torch.ones([1, 15], dtype=torch.int64),
"speaker_embeds": torch.zeros(size=(1, 512), dtype=torch.float),
"lyric_token_idx": torch.randint(10000, [1, 543], dtype=torch.int64),
"lyric_mask": torch.ones([1, 543], dtype=torch.int64),
"tau": torch.Tensor([0.01]),
}
    transformer_encoder_erg_model = pipeline.ace_step_transformer
    transformer_encoder_erg_model.forward = types.MethodType(encode_with_temperature_wrap, transformer_encoder_erg_model)
ov_convert(
model_dir,
TRANSFORMER_ENCODER_MODEL_NAME,
inputs,
transformer_encoder_erg_model,
"Transformer Encoder with Entropy Rectifying Guidance",
quantization_config=quantization_config,
)
# Transformer Decoder model
def decode_with_temperature_wrap(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_mask: torch.Tensor,
timestep: torch.Tensor = None,
# ssl_hidden_states: List[torch.Tensor] = None,
output_length: int = 0,
# block_controlnet_hidden_states: Union[List[torch.Tensor], torch.Tensor] = None,
# controlnet_scale: Union[float, torch.Tensor] = 1.0,
tau: torch.FloatTensor = torch.Tensor([0.01]),
):
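        # Temporarily scale the self- and cross-attention query projections of transformer
        # blocks 5-9 by `tau` via forward hooks, run the regular `decode`, then remove the hooks.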
handlers = []
def hook(module, input, output):
output[:] *= tau[0]
return output
l_min = 5
l_max = 10
for i in range(l_min, l_max):
handler = self.transformer_blocks[i].attn.to_q.register_forward_hook(hook)
handlers.append(handler)
handler = self.transformer_blocks[i].cross_attn.to_q.register_forward_hook(hook)
handlers.append(handler)
sample = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_mask=encoder_hidden_mask,
output_length=output_length,
timestep=timestep,
).sample
        for handler in handlers:
            handler.remove()
return sample
inputs = {
"hidden_states": torch.randn(size=(1, 8, 16, 151), dtype=torch.float),
"attention_mask": torch.ones([1, 151], dtype=torch.int64),
"encoder_hidden_states": torch.randn(size=(1, 559, 2560), dtype=torch.float),
"encoder_hidden_mask": torch.ones([1, 559], dtype=torch.float),
"output_length": torch.tensor(151),
"timestep": torch.randn([1], dtype=torch.float),
"tau": torch.Tensor([0.01]),
}
transformer_decoder_erg_model = pipeline.ace_step_transformer
transformer_decoder_erg_model.forward = types.MethodType(decode_with_temperature_wrap, transformer_decoder_erg_model)
ov_convert(
model_dir,
TRANSFORMER_DECODER_MODEL_NAME,
inputs,
transformer_decoder_erg_model,
"Transformer Decoder with Entropy Rectifying Guidance",
quantization_config=quantization_config,
)
def convert_models(pipeline: ACEStepPipeline, model_dir: str = "ov_converted_new", orig_checkpoint_path: str = "", quantization_config: Dict = None):
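    """Convert every sub-model of an ACEStepPipeline (tokenizer, UMT5 text encoder, DCAE encoder
    and decoder, vocoder mel transform and decoder, and the DiT) to OpenVINO IR files inside
    `model_dir`. `quantization_config` is applied only to the transformer conversion.
    """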
    print("⌛ Conversion started. Be patient, it may take some time.")
if not pipeline.loaded or (orig_checkpoint_path and not Path(orig_checkpoint_path).exists()):
print("⌛ Load Original model checkpoints")
pipeline.load_checkpoint(orig_checkpoint_path)
print("✅ Original model checkpoints successfully loaded")
# Tokenizer
ov_tokenizer_path = Path(model_dir, TOKENIZER_MODEL_NAME)
    if not ov_tokenizer_path.exists():
        print("⌛ Convert Tokenizer")
        ov_tokenizer = convert_tokenizer(pipeline.text_tokenizer, with_detokenizer=False)
        ov.save_model(ov_tokenizer, ov_tokenizer_path)
        print("✅ Tokenizer is converted")
# Text Encoder Model
inputs = {
"input_ids": torch.randint(1000, size=(1, 15), dtype=torch.int64),
"attention_mask": torch.ones([1, 15], dtype=torch.int64),
}
ov_convert(model_dir, TEXT_ENCODER_MODEL_NAME, inputs, pipeline.text_encoder_model, "UMT5 Encoder")
# DCAE Encoder model
inputs = {"hidden_states": torch.randn([1, 2, 128, 1208], dtype=torch.float)}
ov_convert(model_dir, DCAE_ENCODER_MODEL_NAME, inputs, pipeline.music_dcae.dcae.encoder, "Sana's Deep Compression AutoEncoder")
# DCAE Decoder model
inputs = {"hidden_states": torch.randn([1, 8, 16, 151], dtype=torch.float)}
ov_convert(model_dir, DCAE_DECODER_MODEL_NAME, inputs, pipeline.music_dcae.dcae.decoder, "Sana's Deep Compression AutoEncoder Decoder")
# Vocoder Mel Transform model
inputs = {"x": torch.randn([2, 618496], dtype=torch.float)}
ov_convert(model_dir, VOCODER_MEL_TRANSFORM_MODEL_NAME, inputs, pipeline.music_dcae.vocoder.mel_transform, "Vocoder Mel Transform")
# Vocoder Decoder model
inputs = {"mel": torch.randn([1, 128, 856], dtype=torch.float)}
ov_convert(model_dir, VOCODER_DECODE_MODEL_NAME, inputs, pipeline.music_dcae.vocoder, "Vocoder Decoder")
# DiT
convert_transformer_models(pipeline, model_dir, orig_checkpoint_path, quantization_config)
class MusicDCAEWrapper(MusicDCAE):
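    """MusicDCAE variant that skips loading the PyTorch DCAE/vocoder weights.

    Only the resampler, mel normalization constants and latent scale/shift values are set up
    here; `dcae` and `vocoder` are filled in later with the OpenVINO wrappers below.
    """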
def __init__(self, source_sample_rate=None):
torch.nn.Module.__init__(self)
self.dcae = None
self.vocoder = None
if source_sample_rate is None:
source_sample_rate = 48000
self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
self.transform = transforms.Compose(
[
transforms.Normalize(0.5, 0.5),
]
)
self.min_mel_value = -11.0
self.max_mel_value = 3.0
self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
self.mel_chunk_size = 1024
self.time_dimention_multiple = 8
self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
self.scale_factor = 0.1786
self.shift_factor = -1.9091
class OVDCAECompiledModels(torch.nn.Module):
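    """Thin callable wrapper around a compiled OpenVINO DCAE encoder or decoder model."""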
    def __init__(self, compiled_model):
        super().__init__()
        self.compiled_model = compiled_model
def __call__(self, inputs):
if not self.compiled_model:
logger.error("OVDCAECompiledModels: compiled model is not defined")
output = self.compiled_model({"hidden_states": inputs.to(dtype=torch.float32)})
return torch.from_numpy(output[0])
@classmethod
def from_pretrained(cls, ov_model_path, device, ov_core):
ov_dcae_model = ov_core.read_model(ov_model_path)
compiled_model = ov_core.compile_model(ov_dcae_model, device)
return cls(compiled_model)
class OVWrapperAutoencoderDC(torch.nn.Module):
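    """Drop-in replacement for the DCAE autoencoder that holds the compiled OpenVINO encoder
    and decoder models.
    """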
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
@classmethod
def from_pretrained(cls, ov_core, ov_models_path, device="CPU"):
encoder = OVDCAECompiledModels.from_pretrained(Path(ov_models_path, DCAE_ENCODER_MODEL_NAME), device, ov_core)
decoder = OVDCAECompiledModels.from_pretrained(Path(ov_models_path, DCAE_DECODER_MODEL_NAME), device, ov_core)
return cls(encoder, decoder)
class OVWrapperADaMoSHiFiGANV1(torch.nn.Module):
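    """Drop-in replacement for the ADaMoS HiFi-GAN vocoder that runs the decoder and the mel
    transform as compiled OpenVINO models.
    """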
    def __init__(self, decoder_compiled_model, mel_transform_compiled_model):
        super().__init__()
        self.decoder = decoder_compiled_model
        self.mel_transform_model = mel_transform_compiled_model
    @classmethod
    def from_pretrained(cls, ov_core, ov_models_path, device="CPU"):
        ov_vocoder_decoder_model = ov_core.read_model(Path(ov_models_path, VOCODER_DECODE_MODEL_NAME))
        decoder = ov_core.compile_model(ov_vocoder_decoder_model, device)
        ov_vocoder_mel_transform_model = ov_core.read_model(Path(ov_models_path, VOCODER_MEL_TRANSFORM_MODEL_NAME))
        mel_transform_compiled = ov_core.compile_model(ov_vocoder_mel_transform_model, device)
        return cls(decoder, mel_transform_compiled)
def decode(self, inputs):
output = self.decoder({"mel": inputs.to(dtype=torch.float32)})
return torch.from_numpy(output[0])
    def mel_transform(self, inputs):
        output = self.mel_transform_model({"x": inputs.to(dtype=torch.float32)})
        return torch.from_numpy(output[0])
def forward(self, inputs):
return self.decode(inputs)
class OvWrapperACEStepTransformer2DModel(torch.nn.Module):
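    """Replacement for the ACE-Step transformer that routes `encode`/`decode` calls to the
    compiled OpenVINO encoder and decoder IRs.

    The `*_with_temperature` variants pass `tau` through to the ERG-enabled graphs; the plain
    `encode`/`decode` methods call them with `tau = 1`, which leaves the attention queries
    unscaled.
    """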
def __init__(self, encoder_model, decoder_model):
super().__init__()
self.ov_lyric_encoder_compiled = encoder_model
self.ov_decoder_compiled_model = decoder_model
@classmethod
def from_pretrained(cls, ov_core, ov_models_path, device="CPU"):
ov_model_encoder = ov_core.read_model(Path(ov_models_path, TRANSFORMER_ENCODER_MODEL_NAME))
compiled_model_encoder = ov_core.compile_model(ov_model_encoder, device)
ov_model_decoder = ov_core.read_model(Path(ov_models_path, TRANSFORMER_DECODER_MODEL_NAME))
compiled_model_decoder = ov_core.compile_model(ov_model_decoder, device)
return cls(compiled_model_encoder, compiled_model_decoder)
def encode_with_temperature(
self,
encoder_text_hidden_states: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
tau: Optional[torch.FloatTensor] = torch.Tensor([0.01]),
):
output = None
if self.ov_lyric_encoder_compiled:
output = self.ov_lyric_encoder_compiled(
{
"encoder_text_hidden_states": encoder_text_hidden_states,
"text_attention_mask": text_attention_mask,
"speaker_embeds": speaker_embeds,
"lyric_token_idx": lyric_token_idx,
"lyric_mask": lyric_mask,
"tau": tau,
}
)
return torch.from_numpy(output[0]), torch.from_numpy(output[1])
def decode_with_temperature(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_mask: torch.Tensor,
timestep: Optional[torch.Tensor],
ssl_hidden_states: Optional[List[torch.Tensor]] = None,
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
return_dict: bool = True,
tau: Optional[torch.FloatTensor] = torch.Tensor([0.01]),
):
output = None
if self.ov_decoder_compiled_model:
output = self.ov_decoder_compiled_model(
{
"hidden_states": hidden_states,
"attention_mask": attention_mask,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_mask": encoder_hidden_mask,
"output_length": output_length,
"timestep": timestep,
"tau": tau,
}
)
sample = torch.from_numpy(output[0]) if output is not None else None
return sample
def encode(
self,
encoder_text_hidden_states: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
):
return self.encode_with_temperature(
encoder_text_hidden_states=encoder_text_hidden_states,
text_attention_mask=text_attention_mask,
speaker_embeds=speaker_embeds,
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
tau=torch.Tensor([1]),
)
def decode(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_mask: torch.Tensor,
timestep: Optional[torch.Tensor],
ssl_hidden_states: Optional[List[torch.Tensor]] = None,
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
return_dict: bool = True,
):
sample = self.decode_with_temperature(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_mask=encoder_hidden_mask,
timestep=timestep,
ssl_hidden_states=ssl_hidden_states,
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
return_dict=return_dict,
tau=torch.Tensor([1]),
)
return Transformer2DModelOutput(sample, None)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
encoder_text_hidden_states: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
timestep: Optional[torch.Tensor] = None,
ssl_hidden_states: Optional[List[torch.Tensor]] = None,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
return_dict: bool = True,
):
encoder_hidden_states, encoder_hidden_mask = self.encode(
encoder_text_hidden_states=encoder_text_hidden_states,
text_attention_mask=text_attention_mask,
speaker_embeds=speaker_embeds,
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
)
output_length = hidden_states.shape[-1]
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_mask=encoder_hidden_mask,
timestep=timestep,
ssl_hidden_states=ssl_hidden_states,
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
return_dict=return_dict,
)
return output
class OVACEStepPipeline(ACEStepPipeline):
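    """ACEStepPipeline that runs every sub-model (tokenizer, text encoder, DCAE, vocoder and
    DiT) as compiled OpenVINO models instead of PyTorch modules.
    """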
def __init__(self):
super().__init__(checkpoint_dir="", dtype="float32")
self.core = ov.Core()
self.dcae_decoder = None
self.vocoder_encode = None
self.vocoder_decoder = None
self.transformer_encode = None
self.transformer_encode_with_temperature = None
self.transformer_decode = None
self.transformer_decode_with_temperature = None
self.ace_step_transformer_origin = None
self.ace_step_transformer = None
self.music_dcae = None
self.text_tokenizer = None
self.text_encoder_model = None
def get_checkpoint_path(self, checkpoint_dir, repo):
pass
def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False):
pass
    def load_models(self, ov_models_path: Optional[str] = None, device: str = "CPU"):
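        """Read and compile all OpenVINO IRs from `ov_models_path` on the selected `device`
        (the tokenizer is always compiled on CPU), and set up the language segmenter and lyric
        tokenizer used for prompt preprocessing.
        """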
self.loaded = True
        if ov_models_path and Path(ov_models_path).exists():
ov_text_encoder_model = self.core.read_model(Path(ov_models_path, TEXT_ENCODER_MODEL_NAME))
self.text_encoder_model = self.core.compile_model(ov_text_encoder_model, device)
            ov_text_tokenizer_model = self.core.read_model(Path(ov_models_path, TOKENIZER_MODEL_NAME))
            self.text_tokenizer = self.core.compile_model(ov_text_tokenizer_model, "CPU")  # the tokenizer graph can only run on CPU
self.music_dcae = MusicDCAEWrapper()
self.music_dcae.dcae = OVWrapperAutoencoderDC.from_pretrained(self.core, ov_models_path, device)
self.music_dcae.vocoder = OVWrapperADaMoSHiFiGANV1.from_pretrained(self.core, ov_models_path, device)
self.ace_step_transformer = OvWrapperACEStepTransformer2DModel.from_pretrained(self.core, ov_models_path, device)
else:
            logger.error(f"Path does not exist: {ov_models_path}")
lang_segment = LangSegment()
lang_segment.setfilters(language_filters.default)
self.lang_segment = lang_segment
self.lyric_tokenizer = VoiceBpeTokenizer()
def load_quantized_checkpoint(self, checkpoint_dir=None):
pass
def get_text_embeddings(self, texts, text_max_length=256):
inputs = self.text_tokenizer(texts)
inputs = {"attention_mask": inputs["attention_mask"], "input_ids": inputs["input_ids"]}
last_hidden_states = self.text_encoder_model(inputs)
attention_mask = inputs["attention_mask"]
return torch.from_numpy(last_hidden_states[0]), torch.from_numpy(attention_mask)
def get_text_embeddings_null(self, texts, text_max_length=256, tau=0.01, l_min=8, l_max=10):
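        """Return text embeddings for the "null" (unconditional) branch.

        `tau`, `l_min` and `l_max` are accepted for interface compatibility but are not used:
        the compiled text encoder is run as-is.
        """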
inputs = self.text_tokenizer(texts)
inputs = {"attention_mask": inputs["attention_mask"], "input_ids": inputs["input_ids"]}
last_hidden_states = self.text_encoder_model(inputs)
return torch.from_numpy(last_hidden_states[0])
def text2music_diffusion_process(
self,
duration,
encoder_text_hidden_states,
text_attention_mask,
speaker_embds,
lyric_token_ids,
lyric_mask,
random_generators=None,
infer_steps=60,
guidance_scale=15.0,
omega_scale=10.0,
scheduler_type="euler",
cfg_type="apg",
zero_steps=1,
use_zero_init=True,
guidance_interval=0.5,
guidance_interval_decay=1.0,
min_guidance_scale=3.0,
oss_steps=[],
encoder_text_hidden_states_null=None,
use_erg_lyric=False,
use_erg_diffusion=False,
retake_random_generators=None,
retake_variance=0.5,
add_retake_noise=False,
guidance_scale_text=0.0,
guidance_scale_lyric=0.0,
repaint_start=0,
repaint_end=0,
src_latents=None,
audio2audio_enable=False,
ref_audio_strength=0.5,
ref_latents=None,
):
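        """Flow-matching denoising loop mirroring the original ACEStepPipeline, with all
        transformer calls routed through the compiled OpenVINO encoder/decoder.

        Supports APG/CFG/CFG-zero-star guidance inside the configured guidance interval,
        optional double-condition (text/lyric) guidance, ERG via the `tau`-aware graphs when
        `use_erg_lyric`/`use_erg_diffusion` are set, retake/repaint/extend editing, and
        audio2audio starting from `ref_latents`. Returns the denoised latents.
        """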
logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
do_classifier_free_guidance = True
if guidance_scale == 0.0 or guidance_scale == 1.0:
do_classifier_free_guidance = False
do_double_condition_guidance = False
if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0:
do_double_condition_guidance = True
logger.info(
"do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format(
do_double_condition_guidance,
guidance_scale_text,
guidance_scale_lyric,
)
)
bsz = encoder_text_hidden_states.shape[0]
if scheduler_type == "euler":
scheduler = FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000,
shift=3.0,
)
elif scheduler_type == "heun":
scheduler = FlowMatchHeunDiscreteScheduler(
num_train_timesteps=1000,
shift=3.0,
)
elif scheduler_type == "pingpong":
scheduler = FlowMatchPingPongScheduler(
num_train_timesteps=1000,
shift=3.0,
)
frame_length = int(duration * 44100 / 512 / 8)
if src_latents is not None:
frame_length = src_latents.shape[-1]
if ref_latents is not None:
frame_length = ref_latents.shape[-1]
if len(oss_steps) > 0:
infer_steps = max(oss_steps)
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps=infer_steps,
device=self.device,
timesteps=None,
)
new_timesteps = torch.zeros(len(oss_steps), dtype=self.dtype, device=self.device)
for idx in range(len(oss_steps)):
new_timesteps[idx] = timesteps[oss_steps[idx] - 1]
num_inference_steps = len(oss_steps)
sigmas = (new_timesteps / 1000).float().cpu().numpy()
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps=num_inference_steps,
device=self.device,
sigmas=sigmas,
)
logger.info(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}")
else:
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps=infer_steps,
device=self.device,
timesteps=None,
)
target_latents = randn_tensor(
shape=(bsz, 8, 16, frame_length),
generator=random_generators,
device=self.device,
dtype=self.dtype,
)
is_repaint = False
is_extend = False
if add_retake_noise:
n_min = int(infer_steps * (1 - retake_variance))
retake_variance = torch.tensor(retake_variance * math.pi / 2).to(self.device).to(self.dtype)
retake_latents = randn_tensor(
shape=(bsz, 8, 16, frame_length),
generator=retake_random_generators,
device=self.device,
dtype=self.dtype,
)
repaint_start_frame = int(repaint_start * 44100 / 512 / 8)
repaint_end_frame = int(repaint_end * 44100 / 512 / 8)
x0 = src_latents
# retake
is_repaint = repaint_end_frame - repaint_start_frame != frame_length
is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length)
if is_extend:
is_repaint = True
# TODO: train a mask aware repainting controlnet
# to make sure mean = 0, std = 1
if not is_repaint:
target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
elif not is_extend:
# if repaint_end_frame
repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype)
repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0
repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents)
zt_edit = x0.clone()
z0 = repaint_noise
elif is_extend:
to_right_pad_gt_latents = None
to_left_pad_gt_latents = None
gt_latents = src_latents
src_latents_length = gt_latents.shape[-1]
max_infer_fame_length = int(240 * 44100 / 512 / 8)
left_pad_frame_length = 0
right_pad_frame_length = 0
right_trim_length = 0
left_trim_length = 0
if repaint_start_frame < 0:
left_pad_frame_length = abs(repaint_start_frame)
frame_length = left_pad_frame_length + gt_latents.shape[-1]
extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0)
if frame_length > max_infer_fame_length:
right_trim_length = frame_length - max_infer_fame_length
extend_gt_latents = extend_gt_latents[:, :, :, :max_infer_fame_length]
to_right_pad_gt_latents = extend_gt_latents[:, :, :, -right_trim_length:]
frame_length = max_infer_fame_length
repaint_start_frame = 0
gt_latents = extend_gt_latents
if repaint_end_frame > src_latents_length:
right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1]
frame_length = gt_latents.shape[-1] + right_pad_frame_length
extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0)
if frame_length > max_infer_fame_length:
left_trim_length = frame_length - max_infer_fame_length
extend_gt_latents = extend_gt_latents[:, :, :, -max_infer_fame_length:]
to_left_pad_gt_latents = extend_gt_latents[:, :, :, :left_trim_length]
frame_length = max_infer_fame_length
repaint_end_frame = frame_length
gt_latents = extend_gt_latents
repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype)
if left_pad_frame_length > 0:
repaint_mask[:, :, :, :left_pad_frame_length] = 1.0
if right_pad_frame_length > 0:
repaint_mask[:, :, :, -right_pad_frame_length:] = 1.0
x0 = gt_latents
padd_list = []
if left_pad_frame_length > 0:
padd_list.append(retake_latents[:, :, :, :left_pad_frame_length])
padd_list.append(
target_latents[
:,
:,
:,
left_trim_length : target_latents.shape[-1] - right_trim_length,
]
)
if right_pad_frame_length > 0:
padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:])
target_latents = torch.cat(padd_list, dim=-1)
assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}"
zt_edit = x0.clone()
z0 = target_latents
if audio2audio_enable and ref_latents is not None:
logger.info(f"audio2audio_enable: {audio2audio_enable}, ref_latents: {ref_latents.shape}")
target_latents, timesteps, scheduler, num_inference_steps = self.add_latents_noise(
gt_latents=ref_latents,
sigma_max=(1 - ref_audio_strength),
noise=target_latents,
scheduler_type=scheduler_type,
infer_steps=infer_steps,
)
attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype)
# guidance interval
start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
logger.info(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}")
momentum_buffer = MomentumBuffer()
# P(speaker, text, lyric)
encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(
encoder_text_hidden_states,
text_attention_mask,
speaker_embds,
lyric_token_ids,
lyric_mask,
)
if use_erg_lyric:
# P(null_speaker, text_weaker, lyric_weaker)
encoder_hidden_states_null, _ = self.ace_step_transformer.encode_with_temperature(
encoder_text_hidden_states=(
encoder_text_hidden_states_null if encoder_text_hidden_states_null is not None else torch.zeros_like(encoder_text_hidden_states)
),
text_attention_mask=text_attention_mask,
speaker_embeds=torch.zeros_like(speaker_embds),
lyric_token_idx=lyric_token_ids,
lyric_mask=lyric_mask,
)
else:
# P(null_speaker, null_text, null_lyric)
encoder_hidden_states_null, _ = self.ace_step_transformer.encode(
torch.zeros_like(encoder_text_hidden_states),
text_attention_mask,
torch.zeros_like(speaker_embds),
torch.zeros_like(lyric_token_ids),
lyric_mask,
)
encoder_hidden_states_no_lyric = None
if do_double_condition_guidance:
# P(null_speaker, text, lyric_weaker)
if use_erg_lyric:
encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode_with_temperature(
encoder_text_hidden_states=encoder_text_hidden_states,
text_attention_mask=text_attention_mask,
speaker_embeds=torch.zeros_like(speaker_embds),
lyric_token_idx=lyric_token_ids,
lyric_mask=lyric_mask,
)
# P(null_speaker, text, no_lyric)
else:
encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode(
encoder_text_hidden_states,
text_attention_mask,
torch.zeros_like(speaker_embds),
torch.zeros_like(lyric_token_ids),
lyric_mask,
)
for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
if is_repaint:
if i < n_min:
continue
elif i == n_min:
t_i = t / 1000
zt_src = (1 - t_i) * x0 + (t_i) * z0
target_latents = zt_edit + zt_src - x0
logger.info(f"repaint start from {n_min} add {t_i} level of noise")
# expand the latents if we are doing classifier free guidance
latents = target_latents
is_in_guidance_interval = start_idx <= i < end_idx
if is_in_guidance_interval and do_classifier_free_guidance:
# compute current guidance scale
if guidance_interval_decay > 0:
# Linearly interpolate to calculate the current guidance scale
                    progress = (i - start_idx) / (end_idx - start_idx - 1)  # normalize to [0, 1]
current_guidance_scale = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay
else:
current_guidance_scale = guidance_scale
latent_model_input = latents
timestep = t.expand(latent_model_input.shape[0])
output_length = latent_model_input.shape[-1]
# P(x|speaker, text, lyric)
noise_pred_with_cond = self.ace_step_transformer.decode(
hidden_states=latent_model_input,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_mask=encoder_hidden_mask,
output_length=output_length,
timestep=timestep,
).sample
noise_pred_with_only_text_cond = None
if do_double_condition_guidance and encoder_hidden_states_no_lyric is not None:
noise_pred_with_only_text_cond = self.ace_step_transformer.decode(
hidden_states=latent_model_input,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states_no_lyric,
encoder_hidden_mask=encoder_hidden_mask,
output_length=output_length,
timestep=timestep,
).sample
if use_erg_diffusion:
noise_pred_uncond = self.ace_step_transformer.decode_with_temperature(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states_null,
encoder_hidden_mask=encoder_hidden_mask,
output_length=output_length,
attention_mask=attention_mask,
)
else:
noise_pred_uncond = self.ace_step_transformer.decode(
hidden_states=latent_model_input,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states_null,
encoder_hidden_mask=encoder_hidden_mask,
output_length=output_length,
timestep=timestep,
).sample
if do_double_condition_guidance and noise_pred_with_only_text_cond is not None:
noise_pred = cfg_double_condition_forward(
cond_output=noise_pred_with_cond,
uncond_output=noise_pred_uncond,
only_text_cond_output=noise_pred_with_only_text_cond,
guidance_scale_text=guidance_scale_text,
guidance_scale_lyric=guidance_scale_lyric,
)
elif cfg_type == "apg":
noise_pred = apg_forward(
pred_cond=noise_pred_with_cond,
pred_uncond=noise_pred_uncond,
guidance_scale=current_guidance_scale,
momentum_buffer=momentum_buffer,
)
elif cfg_type == "cfg":
noise_pred = cfg_forward(
cond_output=noise_pred_with_cond,
uncond_output=noise_pred_uncond,
cfg_strength=current_guidance_scale,
)
elif cfg_type == "cfg_star":
noise_pred = cfg_zero_star(
noise_pred_with_cond=noise_pred_with_cond,
noise_pred_uncond=noise_pred_uncond,
guidance_scale=current_guidance_scale,
i=i,
zero_steps=zero_steps,
use_zero_init=use_zero_init,
)
else:
latent_model_input = latents
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.ace_step_transformer.decode(
hidden_states=latent_model_input,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_mask=encoder_hidden_mask,
output_length=latent_model_input.shape[-1],
timestep=timestep,
).sample
if is_repaint and i >= n_min:
t_i = t / 1000
if i + 1 < len(timesteps):
t_im1 = (timesteps[i + 1]) / 1000
else:
t_im1 = torch.zeros_like(t_i).to(self.device)
target_latents = target_latents.to(torch.float32)
prev_sample = target_latents + (t_im1 - t_i) * noise_pred
prev_sample = prev_sample.to(self.dtype)
target_latents = prev_sample
zt_src = (1 - t_im1) * x0 + (t_im1) * z0
target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src)
else:
target_latents = scheduler.step(
model_output=noise_pred,
timestep=t,
sample=target_latents,
return_dict=False,
omega=omega_scale,
generator=random_generators[0],
)[0]
if is_extend:
if to_right_pad_gt_latents is not None:
target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1)
if to_left_pad_gt_latents is not None:
                target_latents = torch.cat([to_left_pad_gt_latents, target_latents], dim=-1)
return target_latents
def load_lora(self, model_with_lora_path, device="CPU"):
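        """Swap the active transformer for the IRs found in `model_with_lora_path` (e.g. a
        LoRA-merged export), or restore the original transformer when "none" is passed.
        """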
if model_with_lora_path == "none":
if self.ace_step_transformer_origin:
self.ace_step_transformer = self.ace_step_transformer_origin
else:
self.ace_step_transformer_origin = self.ace_step_transformer
self.update_transformer_model(model_with_lora_path, device)
def update_transformer_model(self, new_transformer_path, device="CPU"):
self.ace_step_transformer = OvWrapperACEStepTransformer2DModel.from_pretrained(self.core, new_transformer_path, device)
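# Rough usage sketch (values are illustrative; it assumes the original ACE-Step checkpoints can
# be loaded by ACEStepPipeline and that the IR files listed at the top of this module end up in
# "ov_converted"). Conversion is a one-off step; afterwards only the OpenVINO pipeline is needed:
#
#   torch_pipeline = ACEStepPipeline(checkpoint_dir="", dtype="float32")
#   convert_models(torch_pipeline, model_dir="ov_converted",
#                  quantization_config={"mode": nncf.CompressWeightsMode.INT4_ASYM})
#
#   ov_pipeline = OVACEStepPipeline()
#   ov_pipeline.load_models("ov_converted", device="CPU")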