Update model.py
Browse files
model.py
CHANGED
|
@@ -43,10 +43,6 @@ class PicoAudio2HF(PreTrainedModel):
|
|
| 43 |
content_encoder = self.build_content_encoder_from_config(config.content_encoder)
|
| 44 |
backbone = self._build_submodule(config.backbone)
|
| 45 |
|
| 46 |
-
state_dict = load_file("model.safetensors")
|
| 47 |
-
new_state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
| 48 |
-
backbone.load_state_dict(new_state_dict, strict=False, assign=True)
|
| 49 |
-
|
| 50 |
self.inner_model = AudioDiffusion(
|
| 51 |
autoencoder=autoencoder,
|
| 52 |
content_encoder=content_encoder,
|
|
@@ -57,6 +53,7 @@ class PicoAudio2HF(PreTrainedModel):
|
|
| 57 |
classifier_free_guidance=config.classifier_free_guidance,
|
| 58 |
cfg_drop_ratio=config.cfg_drop_ratio,
|
| 59 |
)
|
|
|
|
| 60 |
def build_content_encoder_from_config(self, content_encoder_cfg):
|
| 61 |
te_cfg = content_encoder_cfg['text_encoder']
|
| 62 |
te_mod_path, te_cls_name = te_cfg['_target_'].rsplit('.', 1)
|
|
@@ -88,36 +85,9 @@ class PicoAudio2HF(PreTrainedModel):
|
|
| 88 |
module = __import__(module_path, fromlist=[class_name])
|
| 89 |
cls = getattr(module, class_name)
|
| 90 |
obj = cls(**kwargs)
|
| 91 |
-
if "pretrained_ckpt" in sub_config:
|
| 92 |
-
state_dict = torch.load(sub_config["pretrained_ckpt"])
|
| 93 |
-
if "state_dict" in state_dict:
|
| 94 |
-
new_state_dict = state_dict["state_dict"]
|
| 95 |
-
state_dict = {k.replace("autoencoder.", ""): v for k, v in new_state_dict.items()}
|
| 96 |
-
|
| 97 |
-
sig = inspect.signature(obj.load_state_dict)
|
| 98 |
-
if "assign" in sig.parameters:
|
| 99 |
-
result = obj.load_state_dict(state_dict, strict=False, assign=True)
|
| 100 |
-
else:
|
| 101 |
-
result = obj.load_state_dict(state_dict, strict=False)
|
| 102 |
-
|
| 103 |
-
self._check_param_stats(obj, class_name)
|
| 104 |
return obj
|
| 105 |
else:
|
| 106 |
return sub_config
|
| 107 |
-
def _check_weights(self, module, name):
|
| 108 |
-
if hasattr(module, "load_state_dict") and hasattr(module, "state_dict"):
|
| 109 |
-
print(f"[{name}] parameter keys:", list(module.state_dict().keys())[:5], "...")
|
| 110 |
-
for idx, (k, v) in enumerate(module.state_dict().items()):
|
| 111 |
-
print(f"[{name}] {k}: mean={v.float().mean():.5f}, std={v.float().std():.5f}")
|
| 112 |
-
if idx >= 2:
|
| 113 |
-
break
|
| 114 |
-
|
| 115 |
-
def _check_param_stats(self, module, name):
|
| 116 |
-
if hasattr(module, "named_parameters"):
|
| 117 |
-
for idx, (k, v) in enumerate(module.named_parameters()):
|
| 118 |
-
print(f"[{name}] {k}: mean={v.data.float().mean():.5f}, std={v.data.float().std():.5f}")
|
| 119 |
-
if idx >= 2:
|
| 120 |
-
break
|
| 121 |
|
| 122 |
def forward(
|
| 123 |
self,
|
|
@@ -139,13 +109,4 @@ class PicoAudio2HF(PreTrainedModel):
|
|
| 139 |
disable_progress=disable_progress,
|
| 140 |
num_samples_per_content=num_samples_per_content,
|
| 141 |
**kwargs
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
-
@classmethod
|
| 145 |
-
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 146 |
-
config = PicoAudio2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 147 |
-
model = cls(config)
|
| 148 |
-
return model
|
| 149 |
-
|
| 150 |
-
def load_state_dict(self, state_dict, *args, **kwargs):
|
| 151 |
-
pass
|
|
|
|
| 43 |
content_encoder = self.build_content_encoder_from_config(config.content_encoder)
|
| 44 |
backbone = self._build_submodule(config.backbone)
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
self.inner_model = AudioDiffusion(
|
| 47 |
autoencoder=autoencoder,
|
| 48 |
content_encoder=content_encoder,
|
|
|
|
| 53 |
classifier_free_guidance=config.classifier_free_guidance,
|
| 54 |
cfg_drop_ratio=config.cfg_drop_ratio,
|
| 55 |
)
|
| 56 |
+
|
| 57 |
def build_content_encoder_from_config(self, content_encoder_cfg):
|
| 58 |
te_cfg = content_encoder_cfg['text_encoder']
|
| 59 |
te_mod_path, te_cls_name = te_cfg['_target_'].rsplit('.', 1)
|
|
|
|
| 85 |
module = __import__(module_path, fromlist=[class_name])
|
| 86 |
cls = getattr(module, class_name)
|
| 87 |
obj = cls(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
return obj
|
| 89 |
else:
|
| 90 |
return sub_config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
def forward(
|
| 93 |
self,
|
|
|
|
| 109 |
disable_progress=disable_progress,
|
| 110 |
num_samples_per_content=num_samples_per_content,
|
| 111 |
**kwargs
|
| 112 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|