Update modeling_upstream_finetune.py
Browse files- modeling_upstream_finetune.py +36 -25
modeling_upstream_finetune.py
CHANGED
|
@@ -3,11 +3,13 @@ import torch.nn as nn
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
import os
|
| 5 |
from transformers import PretrainedConfig, PreTrainedModel, AutoProcessor, AutoModel
|
|
|
|
| 6 |
|
| 7 |
class UpstreamFinetuneConfig(PretrainedConfig):
|
| 8 |
model_type = "wav2vec2-emodualhead"
|
| 9 |
def __init__(
|
| 10 |
self,
|
|
|
|
| 11 |
upstream_model="wav2vec2-base-960h", # Reference to base model
|
| 12 |
finetune_layers = 0 , # Prevent overhead gpu usage
|
| 13 |
hidden_dim = 64,
|
|
@@ -17,6 +19,7 @@ class UpstreamFinetuneConfig(PretrainedConfig):
|
|
| 17 |
regressor_output_dim=2,
|
| 18 |
**kwargs
|
| 19 |
):
|
|
|
|
| 20 |
self.upstream_model = upstream_model
|
| 21 |
self.dropout = dropout
|
| 22 |
self.finetune_layers = finetune_layers
|
|
@@ -96,7 +99,10 @@ class UpstreamFinetune(PreTrainedModel):
|
|
| 96 |
config_class = UpstreamFinetuneConfig
|
| 97 |
def __init__(self, config, pretrained_path = None,device = None):
|
| 98 |
super().__init__(config)
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
| 100 |
self.feature_extractor = AutoProcessor.from_pretrained(upstream_path,use_fast=False)
|
| 101 |
self.upstream = AutoModel.from_pretrained(upstream_path)
|
| 102 |
self.finetune_layers = config.finetune_layers
|
|
@@ -160,31 +166,36 @@ class UpstreamFinetune(PreTrainedModel):
|
|
| 160 |
|
| 161 |
return category, dim
|
| 162 |
@classmethod
|
| 163 |
-
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 164 |
-
# Extract
|
| 165 |
device = kwargs.pop('device', None)
|
| 166 |
pretrained_path = kwargs.pop('pretrained_path', None)
|
| 167 |
|
| 168 |
-
#
|
| 169 |
-
|
|
|
|
|
|
|
| 170 |
|
| 171 |
-
#
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
import os
|
| 5 |
from transformers import PretrainedConfig, PreTrainedModel, AutoProcessor, AutoModel
|
| 6 |
+
from safetensors.torch import load_file
|
| 7 |
|
| 8 |
class UpstreamFinetuneConfig(PretrainedConfig):
|
| 9 |
model_type = "wav2vec2-emodualhead"
|
| 10 |
def __init__(
|
| 11 |
self,
|
| 12 |
+
origin_upstream_url = "facebook/wav2vec2-base-960h",
|
| 13 |
upstream_model="wav2vec2-base-960h", # Reference to base model
|
| 14 |
finetune_layers = 0 , # Prevent overhead gpu usage
|
| 15 |
hidden_dim = 64,
|
|
|
|
| 19 |
regressor_output_dim=2,
|
| 20 |
**kwargs
|
| 21 |
):
|
| 22 |
+
self.origin_upstream_url = origin_upstream_url
|
| 23 |
self.upstream_model = upstream_model
|
| 24 |
self.dropout = dropout
|
| 25 |
self.finetune_layers = finetune_layers
|
|
|
|
| 99 |
config_class = UpstreamFinetuneConfig
|
| 100 |
def __init__(self, config, pretrained_path = None,device = None):
|
| 101 |
super().__init__(config)
|
| 102 |
+
if pretrained_path is None:
|
| 103 |
+
upstream_path = config.origin_upstream_url
|
| 104 |
+
else:
|
| 105 |
+
upstream_path = os.path.join(pretrained_path, config.upstream_model)
|
| 106 |
self.feature_extractor = AutoProcessor.from_pretrained(upstream_path,use_fast=False)
|
| 107 |
self.upstream = AutoModel.from_pretrained(upstream_path)
|
| 108 |
self.finetune_layers = config.finetune_layers
|
|
|
|
| 166 |
|
| 167 |
return category, dim
|
| 168 |
@classmethod
def from_pretrained(cls, model_path, pretrained_model_name_or_path=None, *model_args, **kwargs):
    """Load a fine-tuned model from a local checkpoint directory.

    Args:
        model_path: Directory containing the saved config plus either
            ``model.safetensors`` or ``pytorch_model.bin``.
        pretrained_model_name_or_path: Optional location of the upstream
            backbone, forwarded to the constructor as ``pretrained_path``.
            When ``None``, ``__init__`` falls back to
            ``config.origin_upstream_url``.
        device: (kwarg) optional device, forwarded to the constructor.
        pretrained_path: (kwarg) alternative way to supply the upstream
            location; takes precedence over ``pretrained_model_name_or_path``.
        config: (kwarg) pre-built config object; when given, skips reading
            the config from ``model_path``.

    Returns:
        The model instance with fine-tuned weights loaded, set to eval mode.

    Raises:
        FileNotFoundError: if neither weights file exists in ``model_path``.
    """
    device = kwargs.pop('device', None)
    # BUG FIX: previously the popped 'pretrained_path' kwarg was discarded and
    # only the positional argument reached the constructor. Honor the kwarg,
    # falling back to the positional argument for backward compatibility.
    pretrained_path = kwargs.pop('pretrained_path', None)
    if pretrained_path is None:
        pretrained_path = pretrained_model_name_or_path

    # Load the configuration unless one was passed in explicitly.
    config = kwargs.pop('config', None)
    if config is None:
        config = cls.config_class.from_pretrained(model_path, **kwargs)

    # Build the model skeleton; the upstream backbone is resolved in __init__.
    model = cls(config=config, pretrained_path=pretrained_path, device=device, *model_args, **kwargs)

    model_safetensors_path = os.path.join(model_path, "model.safetensors")
    model_bin_path = os.path.join(model_path, "pytorch_model.bin")

    # Prefer safetensors (no pickle execution) over the legacy .bin format.
    if os.path.exists(model_safetensors_path):
        print(f"Loading model weights from {model_safetensors_path}...")
        state_dict = load_file(model_safetensors_path)
        model.load_state_dict(state_dict)
    elif os.path.exists(model_bin_path):
        print(f"Loading model weights from {model_bin_path}...")
        # weights_only=True blocks arbitrary-code execution via the pickle
        # payload; a plain state dict (tensors only) loads fine with it.
        state_dict = torch.load(model_bin_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
    else:
        raise FileNotFoundError(f"No model weights found at {model_path}. Expected either 'pytorch_model.bin' or 'model.safetensors'")

    # Inference is the expected use after loading a checkpoint.
    model.eval()

    return model