| import torch | |
| from transformers import PreTrainedModel, PretrainedConfig | |
| import inspect, importlib | |
| from safetensors.torch import load_file | |
| from models.diffusion import AudioDiffusion | |
class PicoAudio2Config(PretrainedConfig):
    """Configuration object for the PicoAudio2 text-to-audio diffusion model.

    Carries the sub-module configs (``autoencoder``, ``content_encoder``,
    ``backbone``), the diffusion/training hyper-parameters, and the
    inference-time defaults (``num_steps``, ``guidance_scale``).
    """

    model_type = "PicoAudio2"

    def __init__(
        self,
        autoencoder=None,
        content_encoder=None,
        backbone=None,
        frame_resolution: float = 0.005,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float = 5.0,
        classifier_free_guidance: bool = True,
        cfg_drop_ratio: float = 0.2,
        num_steps: int = 50,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Store every constructor argument verbatim as a config attribute.
        for attr_name, attr_value in (
            ("autoencoder", autoencoder),
            ("content_encoder", content_encoder),
            ("backbone", backbone),
            ("frame_resolution", frame_resolution),
            ("noise_scheduler_name", noise_scheduler_name),
            ("snr_gamma", snr_gamma),
            ("classifier_free_guidance", classifier_free_guidance),
            ("cfg_drop_ratio", cfg_drop_ratio),
            ("num_steps", num_steps),
            ("guidance_scale", guidance_scale),
        ):
            setattr(self, attr_name, attr_value)
class PicoAudio2HF(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``AudioDiffusion``.

    Sub-modules (autoencoder, content encoder, backbone) are instantiated
    dynamically from config dicts that carry a ``_target_`` dotted class path,
    then assembled into a single ``AudioDiffusion`` inner model.
    """

    config_class = PicoAudio2Config

    def __init__(self, config: PicoAudio2Config):
        super().__init__(config)
        autoencoder = self._build_submodule(config.autoencoder)
        content_encoder = self.build_content_encoder_from_config(config.content_encoder)
        backbone = self._build_submodule(config.backbone)
        self.inner_model = AudioDiffusion(
            autoencoder=autoencoder,
            content_encoder=content_encoder,
            backbone=backbone,
            frame_resolution=config.frame_resolution,
            noise_scheduler_name=config.noise_scheduler_name,
            snr_gamma=config.snr_gamma,
            classifier_free_guidance=config.classifier_free_guidance,
            cfg_drop_ratio=config.cfg_drop_ratio,
        )

    @staticmethod
    def _import_class(dotted_path: str):
        """Resolve a dotted ``module.ClassName`` path to the class object.

        Single import mechanism shared by all builders (the original mixed
        ``importlib.import_module`` with raw ``__import__``).
        """
        module_path, class_name = dotted_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)

    def build_content_encoder_from_config(self, content_encoder_cfg):
        """Build the content encoder, wiring in its text encoder first.

        ``content_encoder_cfg`` is a dict with a ``_target_`` dotted path and a
        nested ``text_encoder`` dict carrying its own ``_target_`` plus a
        ``model_name`` — only ``model_name`` is forwarded to the text encoder.
        """
        te_cfg = content_encoder_cfg["text_encoder"]
        text_encoder_cls = self._import_class(te_cfg["_target_"])
        text_encoder = text_encoder_cls(model_name=te_cfg["model_name"])
        content_encoder_cls = self._import_class(content_encoder_cfg["_target_"])
        return content_encoder_cls(text_encoder=text_encoder)

    def _build_submodule(self, sub_config):
        """Recursively instantiate a sub-module from a ``_target_`` config dict.

        ``None`` stays ``None``; values that are not ``_target_`` dicts are
        returned unchanged so already-built modules pass straight through.
        Nested ``_target_`` dicts among the kwargs are built recursively.
        """
        if sub_config is None:
            return None
        if not (isinstance(sub_config, dict) and "_target_" in sub_config):
            return sub_config
        kwargs = {
            k: self._build_submodule(v)
            if isinstance(v, dict) and "_target_" in v
            else v
            for k, v in sub_config.items()
            if k != "_target_"
        }
        cls = self._import_class(sub_config["_target_"])
        return cls(**kwargs)

    def forward(
        self,
        content,
        num_steps=None,
        guidance_scale=None,
        guidance_rescale=0.0,
        disable_progress=True,
        num_samples_per_content=1,
        **kwargs,
    ):
        """Run diffusion inference for a single ``content`` prompt.

        ``num_steps`` and ``guidance_scale`` fall back to the config defaults
        when not given. Returns whatever ``AudioDiffusion.inference`` returns
        (opaque from this chunk — presumably generated audio latents/waveform).
        """
        if num_steps is None:
            num_steps = self.config.num_steps
        if guidance_scale is None:
            guidance_scale = self.config.guidance_scale
        return self.inner_model.inference(
            content=[content],  # inference expects a batch (list) of contents
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            guidance_rescale=guidance_rescale,
            disable_progress=disable_progress,
            num_samples_per_content=num_samples_per_content,
            **kwargs,
        )