File size: 4,210 Bytes
79f3e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32aa2ea
79f3e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32aa2ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
from transformers import PreTrainedModel, PretrainedConfig
import inspect, importlib
from safetensors.torch import load_file
from models.diffusion import AudioDiffusion

class PicoAudio2Config(PretrainedConfig):
    """Hugging Face configuration object for the PicoAudio2 model.

    Holds three sub-module specifications (``autoencoder``, ``content_encoder``,
    ``backbone``) plus diffusion training and sampling hyper-parameters. The
    sub-module specifications are stored verbatim; presumably they are
    ``_target_``-style dicts that ``PicoAudio2HF`` turns into modules — TODO
    confirm against saved configs.

    Args:
        autoencoder: Specification of the autoencoder sub-module (or None).
        content_encoder: Specification of the content encoder (or None).
        backbone: Specification of the diffusion backbone (or None).
        frame_resolution: Frame resolution handed to ``AudioDiffusion``
            (presumably seconds per frame — TODO confirm units).
        noise_scheduler_name: Name/path of the noise scheduler to load.
        snr_gamma: SNR-gamma value handed to ``AudioDiffusion``.
        classifier_free_guidance: Whether classifier-free guidance is enabled.
        cfg_drop_ratio: Drop ratio handed to ``AudioDiffusion`` (presumably the
            probability of dropping the condition during training).
        num_steps: Default sampling step count used when the caller gives none.
        guidance_scale: Default guidance scale used when the caller gives none.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "PicoAudio2"

    def __init__(
        self,
        autoencoder=None,
        content_encoder=None,
        backbone=None,
        frame_resolution: float = 0.005,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float = 5.0,
        classifier_free_guidance: bool = True,
        cfg_drop_ratio: float = 0.2,
        num_steps: int = 50,
        guidance_scale: float = 7.5,
        **kwargs
    ):
        # The base class consumes the generic HF options first; the
        # model-specific fields are attached afterwards in declaration order.
        super().__init__(**kwargs)
        model_fields = {
            "autoencoder": autoencoder,
            "content_encoder": content_encoder,
            "backbone": backbone,
            "frame_resolution": frame_resolution,
            "noise_scheduler_name": noise_scheduler_name,
            "snr_gamma": snr_gamma,
            "classifier_free_guidance": classifier_free_guidance,
            "cfg_drop_ratio": cfg_drop_ratio,
            "num_steps": num_steps,
            "guidance_scale": guidance_scale,
        }
        for field_name, field_value in model_fields.items():
            setattr(self, field_name, field_value)


class PicoAudio2HF(PreTrainedModel):
    """``PreTrainedModel`` wrapper that builds and drives an ``AudioDiffusion``.

    On construction, the sub-module specifications stored in the config are
    instantiated. They are passed, together with the diffusion
    hyper-parameters, to an ``AudioDiffusion`` instance kept in
    ``self.inner_model``. ``forward`` delegates to ``inner_model.inference``.
    """

    config_class = PicoAudio2Config

    def __init__(self, config: PicoAudio2Config):
        super().__init__(config)

        autoencoder = self._build_submodule(config.autoencoder)
        content_encoder = self.build_content_encoder_from_config(config.content_encoder)
        backbone = self._build_submodule(config.backbone)

        # num_steps / guidance_scale are not passed here; forward() reads them
        # from the config as inference-time defaults.
        self.inner_model = AudioDiffusion(
            autoencoder=autoencoder,
            content_encoder=content_encoder,
            backbone=backbone,
            frame_resolution=config.frame_resolution,
            noise_scheduler_name=config.noise_scheduler_name,
            snr_gamma=config.snr_gamma,
            classifier_free_guidance=config.classifier_free_guidance,
            cfg_drop_ratio=config.cfg_drop_ratio,
        )

    @staticmethod
    def _import_target(target: str):
        """Import and return the attribute named by a dotted ``"pkg.module.Name"`` path.

        Raises:
            ValueError: If ``target`` contains no dot.
            ImportError: If the module part cannot be imported.
            AttributeError: If the module has no attribute with that name.
        """
        module_path, attr_name = target.rsplit(".", 1)
        return getattr(importlib.import_module(module_path), attr_name)

    def build_content_encoder_from_config(self, content_encoder_cfg):
        """Build the content encoder and the text encoder it wraps.

        Only these keys are read:

        * ``content_encoder_cfg["_target_"]``
        * ``content_encoder_cfg["text_encoder"]["_target_"]``
        * ``content_encoder_cfg["text_encoder"]["model_name"]``

        Any other keys are ignored.

        Args:
            content_encoder_cfg: Mapping describing the content encoder, or None.

        Returns:
            ``ContentEncoderClass(text_encoder=TextEncoderClass(model_name=...))``,
            or None when ``content_encoder_cfg`` is None. Returning None matches
            ``_build_submodule``; AudioDiffusion must then cope with a missing
            content encoder — TODO confirm.

        Raises:
            KeyError: If one of the required keys above is missing.
        """
        if content_encoder_cfg is None:
            return None

        te_cfg = content_encoder_cfg["text_encoder"]
        TextEncoderClass = self._import_target(te_cfg["_target_"])
        text_encoder = TextEncoderClass(model_name=te_cfg["model_name"])

        ContentEncoderClass = self._import_target(content_encoder_cfg["_target_"])
        return ContentEncoderClass(text_encoder=text_encoder)

    def _build_submodule(self, sub_config):
        """Recursively instantiate a ``_target_``-style specification.

        A dict containing ``"_target_"`` is turned into
        ``Target(**other_keys)``. Nested dict values that themselves contain
        ``"_target_"`` are instantiated first. Any other value, including
        None, is returned unchanged.

        Args:
            sub_config: A specification dict, or any other value.

        Returns:
            The instantiated object, or ``sub_config`` itself when it is not a
            ``_target_`` dict.
        """
        if not (isinstance(sub_config, dict) and "_target_" in sub_config):
            return sub_config

        # Non-_target_ values come back unchanged from the recursive call,
        # so only nested specifications are actually instantiated.
        kwargs = {
            key: self._build_submodule(value)
            for key, value in sub_config.items()
            if key != "_target_"
        }
        cls = self._import_target(sub_config["_target_"])
        return cls(**kwargs)

    def forward(
        self,
        content,
        num_steps=None,
        guidance_scale=None,
        guidance_rescale=0.0,
        disable_progress=True,
        num_samples_per_content=1,
        **kwargs
    ):
        """Run ``AudioDiffusion.inference`` on a single ``content`` item.

        Args:
            content: Conditioning input for one sample. It is wrapped in a
                one-element list before being passed on, so callers should
                pass one item, not a batch (presumably a caption string —
                TODO confirm expected type).
            num_steps: Sampling steps; falls back to ``config.num_steps`` when None.
            guidance_scale: Guidance scale; falls back to
                ``config.guidance_scale`` when None.
            guidance_rescale: Forwarded unchanged to ``inference``.
            disable_progress: Forwarded unchanged to ``inference``.
            num_samples_per_content: Forwarded unchanged to ``inference``.
            **kwargs: Additional keyword arguments forwarded to ``inference``.

        Returns:
            Whatever ``self.inner_model.inference`` returns.
        """
        num_steps = num_steps if num_steps is not None else self.config.num_steps
        guidance_scale = guidance_scale if guidance_scale is not None else self.config.guidance_scale
        return self.inner_model.inference(
            content=[content],
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            guidance_rescale=guidance_rescale,
            disable_progress=disable_progress,
            num_samples_per_content=num_samples_per_content,
            **kwargs
        )