othsueh committed on
Commit 8a62a4b (verified) · 1 Parent(s): c7cfd89

Update modeling_upstream_finetune.py

Files changed (1): modeling_upstream_finetune.py (+36 -25)
modeling_upstream_finetune.py CHANGED
@@ -3,11 +3,13 @@ import torch.nn as nn
 import torch.nn.functional as F
 import os
 from transformers import PretrainedConfig, PreTrainedModel, AutoProcessor, AutoModel
+from safetensors.torch import load_file
 
 class UpstreamFinetuneConfig(PretrainedConfig):
     model_type = "wav2vec2-emodualhead"
     def __init__(
         self,
+        origin_upstream_url = "facebook/wav2vec2-base-960h",
         upstream_model="wav2vec2-base-960h", # Reference to base model
         finetune_layers = 0 , # Prevent overhead gpu usage
         hidden_dim = 64,
@@ -17,6 +19,7 @@ class UpstreamFinetuneConfig(PretrainedConfig):
         regressor_output_dim=2,
         **kwargs
     ):
+        self.origin_upstream_url = origin_upstream_url
         self.upstream_model = upstream_model
         self.dropout = dropout
         self.finetune_layers = finetune_layers
@@ -96,7 +99,10 @@ class UpstreamFinetune(PreTrainedModel):
     config_class = UpstreamFinetuneConfig
     def __init__(self, config, pretrained_path = None,device = None):
         super().__init__(config)
-        upstream_path = os.path.join(pretrained_path, config.upstream_model)
+        if pretrained_path is None:
+            upstream_path = config.origin_upstream_url
+        else:
+            upstream_path = os.path.join(pretrained_path, config.upstream_model)
         self.feature_extractor = AutoProcessor.from_pretrained(upstream_path,use_fast=False)
         self.upstream = AutoModel.from_pretrained(upstream_path)
         self.finetune_layers = config.finetune_layers
@@ -160,31 +166,36 @@ class UpstreamFinetune(PreTrainedModel):
 
         return category, dim
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        # Extract device and pretrained_path before calling parent method
+    def from_pretrained(cls, model_path, pretrained_model_name_or_path = None, *model_args, **kwargs):
+        # Extract config and device from kwargs if provided
         device = kwargs.pop('device', None)
         pretrained_path = kwargs.pop('pretrained_path', None)
 
-        # Call the parent class from_pretrained method
-        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        # Load the configuration
+        config = kwargs.pop('config', None)
+        if config is None:
+            config = cls.config_class.from_pretrained(model_path, **kwargs)
 
-        # If pretrained_path was provided, ensure it's properly set
-        if pretrained_path is not None and not hasattr(model, 'feature_extractor'):
-            config = model.config
-            upstream_path = os.path.join(pretrained_path, config.upstream_model)
-            model.feature_extractor = AutoProcessor.from_pretrained(upstream_path, use_fast=False)
-            model.upstream = AutoModel.from_pretrained(upstream_path)
-
-            # Set up finetuning layers
-            for param in model.upstream.parameters():
-                param.requires_grad = False
-
-            for i in range(1, model.finetune_layers + 1):
-                for param in model.upstream.encoder.layers[-i].parameters():
-                    param.requires_grad = True
-
-            # Handle the masked_spec_embed
-            if hasattr(model.upstream, 'masked_spec_embed'):
-                model.upstream.masked_spec_embed = nn.Parameter(torch.zeros(model.upstream.config.hidden_size))
-
-        return model
+        # Create model instance with the config
+        model = cls(config=config, pretrained_path=pretrained_model_name_or_path, device=device, *model_args, **kwargs)
+
+        model_bin_path = os.path.join(model_path, "pytorch_model.bin")
+        model_safetensors_path = os.path.join(model_path, "model.safetensors")
+
+        if os.path.exists(model_safetensors_path):
+            print(f"Loading model weights from {model_safetensors_path}...")
+            state_dict = load_file(model_safetensors_path)
+            model.load_state_dict(state_dict)
+        elif os.path.exists(model_bin_path):
+            print(f"Loading model weights from {model_bin_path}...")
+            state_dict = torch.load(model_bin_path, map_location="cpu")
+            model.load_state_dict(state_dict)
+        else:
+            raise FileNotFoundError(f"No model weights found at {model_path}. Expected either 'pytorch_model.bin' or 'model.safetensors'")
+
+        # Set model to eval mode by default
+        model.eval()
+
+        return model
+
+
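
For context, the new override bypasses the parent from_pretrained: it reads config.json from the checkpoint directory, instantiates the model (falling back to config.origin_upstream_url for the upstream encoder when no pretrained_path is passed), loads the local weights, and returns the model in eval mode. A minimal usage sketch, assuming the file is importable as modeling_upstream_finetune and that a local checkpoint directory holds config.json plus either model.safetensors or pytorch_model.bin; the directory name below is a placeholder, not part of the commit:

    from modeling_upstream_finetune import UpstreamFinetune

    # Placeholder path: any directory containing config.json and either
    # model.safetensors or pytorch_model.bin saved from this model.
    checkpoint_dir = "./wav2vec2-emodualhead-checkpoint"

    # The override reads the config, builds the model (the upstream wav2vec2
    # weights come from config.origin_upstream_url since pretrained_path is None),
    # loads the local state dict, and returns the model in eval mode.
    model = UpstreamFinetune.from_pretrained(checkpoint_dir)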