rookie9 committed on
Commit
32aa2ea
·
verified ·
1 Parent(s): 131e947

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -41
model.py CHANGED
@@ -43,10 +43,6 @@ class PicoAudio2HF(PreTrainedModel):
43
  content_encoder = self.build_content_encoder_from_config(config.content_encoder)
44
  backbone = self._build_submodule(config.backbone)
45
 
46
- state_dict = load_file("model.safetensors")
47
- new_state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
48
- backbone.load_state_dict(new_state_dict, strict=False, assign=True)
49
-
50
  self.inner_model = AudioDiffusion(
51
  autoencoder=autoencoder,
52
  content_encoder=content_encoder,
@@ -57,6 +53,7 @@ class PicoAudio2HF(PreTrainedModel):
57
  classifier_free_guidance=config.classifier_free_guidance,
58
  cfg_drop_ratio=config.cfg_drop_ratio,
59
  )
 
60
  def build_content_encoder_from_config(self, content_encoder_cfg):
61
  te_cfg = content_encoder_cfg['text_encoder']
62
  te_mod_path, te_cls_name = te_cfg['_target_'].rsplit('.', 1)
@@ -88,36 +85,9 @@ class PicoAudio2HF(PreTrainedModel):
88
  module = __import__(module_path, fromlist=[class_name])
89
  cls = getattr(module, class_name)
90
  obj = cls(**kwargs)
91
- if "pretrained_ckpt" in sub_config:
92
- state_dict = torch.load(sub_config["pretrained_ckpt"])
93
- if "state_dict" in state_dict:
94
- new_state_dict = state_dict["state_dict"]
95
- state_dict = {k.replace("autoencoder.", ""): v for k, v in new_state_dict.items()}
96
-
97
- sig = inspect.signature(obj.load_state_dict)
98
- if "assign" in sig.parameters:
99
- result = obj.load_state_dict(state_dict, strict=False, assign=True)
100
- else:
101
- result = obj.load_state_dict(state_dict, strict=False)
102
-
103
- self._check_param_stats(obj, class_name)
104
  return obj
105
  else:
106
  return sub_config
107
- def _check_weights(self, module, name):
108
- if hasattr(module, "load_state_dict") and hasattr(module, "state_dict"):
109
- print(f"[{name}] parameter keys:", list(module.state_dict().keys())[:5], "...")
110
- for idx, (k, v) in enumerate(module.state_dict().items()):
111
- print(f"[{name}] {k}: mean={v.float().mean():.5f}, std={v.float().std():.5f}")
112
- if idx >= 2:
113
- break
114
-
115
- def _check_param_stats(self, module, name):
116
- if hasattr(module, "named_parameters"):
117
- for idx, (k, v) in enumerate(module.named_parameters()):
118
- print(f"[{name}] {k}: mean={v.data.float().mean():.5f}, std={v.data.float().std():.5f}")
119
- if idx >= 2:
120
- break
121
 
122
  def forward(
123
  self,
@@ -139,13 +109,4 @@ class PicoAudio2HF(PreTrainedModel):
139
  disable_progress=disable_progress,
140
  num_samples_per_content=num_samples_per_content,
141
  **kwargs
142
- )
143
-
144
- @classmethod
145
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
146
- config = PicoAudio2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
147
- model = cls(config)
148
- return model
149
-
150
- def load_state_dict(self, state_dict, *args, **kwargs):
151
- pass
 
43
  content_encoder = self.build_content_encoder_from_config(config.content_encoder)
44
  backbone = self._build_submodule(config.backbone)
45
 
 
 
 
 
46
  self.inner_model = AudioDiffusion(
47
  autoencoder=autoencoder,
48
  content_encoder=content_encoder,
 
53
  classifier_free_guidance=config.classifier_free_guidance,
54
  cfg_drop_ratio=config.cfg_drop_ratio,
55
  )
56
+
57
  def build_content_encoder_from_config(self, content_encoder_cfg):
58
  te_cfg = content_encoder_cfg['text_encoder']
59
  te_mod_path, te_cls_name = te_cfg['_target_'].rsplit('.', 1)
 
85
  module = __import__(module_path, fromlist=[class_name])
86
  cls = getattr(module, class_name)
87
  obj = cls(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  return obj
89
  else:
90
  return sub_config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  def forward(
93
  self,
 
109
  disable_progress=disable_progress,
110
  num_samples_per_content=num_samples_per_content,
111
  **kwargs
112
+ )