# PicoAudio2 / model.py
# (Hugging Face Hub page residue removed: uploader "rookie9", commit "Update model.py", 32aa2ea verified.)
import torch
from transformers import PreTrainedModel, PretrainedConfig
import inspect, importlib
from safetensors.torch import load_file
from models.diffusion import AudioDiffusion
class PicoAudio2Config(PretrainedConfig):
    """Configuration for the PicoAudio2 diffusion model.

    Stores the ``_target_``-style construction specs for the three
    submodules (autoencoder, content encoder, backbone) plus the
    diffusion training and sampling hyper-parameters. All values are
    simply recorded on the instance; construction happens in
    ``PicoAudio2HF``.
    """

    model_type = "PicoAudio2"

    def __init__(
        self,
        autoencoder=None,
        content_encoder=None,
        backbone=None,
        frame_resolution: float = 0.005,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float = 5.0,
        classifier_free_guidance: bool = True,
        cfg_drop_ratio: float = 0.2,
        num_steps: int = 50,
        guidance_scale: float = 7.5,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Record every hyper-parameter verbatim on the instance so that
        # HF serialization (to_dict / save_pretrained) round-trips them.
        for attr_name, attr_value in (
            ("autoencoder", autoencoder),
            ("content_encoder", content_encoder),
            ("backbone", backbone),
            ("frame_resolution", frame_resolution),
            ("noise_scheduler_name", noise_scheduler_name),
            ("snr_gamma", snr_gamma),
            ("classifier_free_guidance", classifier_free_guidance),
            ("cfg_drop_ratio", cfg_drop_ratio),
            ("num_steps", num_steps),
            ("guidance_scale", guidance_scale),
        ):
            setattr(self, attr_name, attr_value)
class PicoAudio2HF(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``AudioDiffusion``.

    Submodules are instantiated dynamically from ``_target_``-style config
    dicts (``{"_target_": "pkg.mod.Class", **kwargs}``) held on the config,
    then composed into an ``AudioDiffusion`` stored as ``self.inner_model``.
    """

    config_class = PicoAudio2Config

    def __init__(self, config: PicoAudio2Config):
        super().__init__(config)
        autoencoder = self._build_submodule(config.autoencoder)
        content_encoder = self.build_content_encoder_from_config(config.content_encoder)
        backbone = self._build_submodule(config.backbone)
        self.inner_model = AudioDiffusion(
            autoencoder=autoencoder,
            content_encoder=content_encoder,
            backbone=backbone,
            frame_resolution=config.frame_resolution,
            noise_scheduler_name=config.noise_scheduler_name,
            snr_gamma=config.snr_gamma,
            classifier_free_guidance=config.classifier_free_guidance,
            cfg_drop_ratio=config.cfg_drop_ratio,
        )

    @staticmethod
    def _resolve_class(target: str):
        """Import and return the class named by a dotted ``_target_`` path.

        ``"models.encoder.TextEncoder"`` -> the ``TextEncoder`` attribute of
        module ``models.encoder``. Shared by all dynamic-construction paths
        so the import mechanism is consistent (previously a mix of
        ``importlib.import_module`` and ``__import__``).
        """
        module_path, class_name = target.rsplit(".", 1)
        return getattr(importlib.import_module(module_path), class_name)

    def build_content_encoder_from_config(self, content_encoder_cfg):
        """Build the content encoder, first constructing its inner text encoder.

        Expects ``content_encoder_cfg`` to carry a ``_target_`` path and a
        nested ``text_encoder`` dict with its own ``_target_`` and
        ``model_name`` keys (schema inferred from this method's reads —
        confirm against the shipped config).
        """
        te_cfg = content_encoder_cfg['text_encoder']
        text_encoder_cls = self._resolve_class(te_cfg['_target_'])
        text_encoder = text_encoder_cls(model_name=te_cfg['model_name'])
        content_encoder_cls = self._resolve_class(content_encoder_cfg['_target_'])
        return content_encoder_cls(text_encoder=text_encoder)

    def _build_submodule(self, sub_config):
        """Recursively instantiate a ``_target_``-style config dict.

        ``None`` passes through as ``None``; anything that is not a dict with
        a ``_target_`` key is assumed to be an already-built object (or plain
        value) and returned unchanged. Nested ``_target_`` dicts among the
        kwargs are built first, depth-first.
        """
        if sub_config is None:
            return None
        if not (isinstance(sub_config, dict) and "_target_" in sub_config):
            return sub_config
        kwargs = {
            key: self._build_submodule(value)
            if isinstance(value, dict) and "_target_" in value
            else value
            for key, value in sub_config.items()
            if key != "_target_"
        }
        cls = self._resolve_class(sub_config["_target_"])
        return cls(**kwargs)

    def forward(
        self,
        content,
        num_steps=None,
        guidance_scale=None,
        guidance_rescale=0.0,
        disable_progress=True,
        num_samples_per_content=1,
        **kwargs
    ):
        """Run inference, delegating to ``self.inner_model.inference``.

        ``num_steps`` / ``guidance_scale`` fall back to the config defaults
        when not given. ``content`` is wrapped in a single-element list, so
        each call processes one content item; extra kwargs are forwarded
        untouched.
        """
        num_steps = num_steps if num_steps is not None else self.config.num_steps
        guidance_scale = guidance_scale if guidance_scale is not None else self.config.guidance_scale
        return self.inner_model.inference(
            content=[content],
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            guidance_rescale=guidance_rescale,
            disable_progress=disable_progress,
            num_samples_per_content=num_samples_per_content,
            **kwargs
        )