othsueh commited on
Commit
7db9424
·
verified ·
1 Parent(s): f0aa0a3

Create modeling_upstream_finetune.py

Browse files
Files changed (1) hide show
  1. modeling_upstream_finetune.py +179 -0
modeling_upstream_finetune.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import os
5
+ from transformers import PretrainedConfig, PreTrainedModel, AutoProcessor, AutoModel
6
+ from safetensors.torch import load_file
7
+
8
+ class UpstreamFinetuneConfig(PretrainedConfig):
9
+ model_type = "wav2vec2-emodualhead"
10
+ def __init__(
11
+ self,
12
+ origin_upstream_url = "facebook/wav2vec2-base",
13
+ upstream_model="wav2vec2-base", # Reference to base model
14
+ finetune_layers = 0 , # Prevent overhead gpu usage
15
+ hidden_dim = 64,
16
+ dropout=0.2,
17
+ num_layers=2,
18
+ classifier_output_dim=8,
19
+ regressor_output_dim=2,
20
+ **kwargs
21
+ ):
22
+ self.origin_upstream_url = origin_upstream_url
23
+ self.upstream_model = upstream_model
24
+ self.dropout = dropout
25
+ self.finetune_layers = finetune_layers
26
+ self.num_layers = num_layers
27
+ self.hidden_dim = hidden_dim
28
+ self.classifier_output_dim = classifier_output_dim
29
+ self.regressor_output_dim = regressor_output_dim
30
+ super().__init__(**kwargs)
31
+
32
+
33
+ class ClassificationHead(nn.Module):
34
+ def __init__(self, first_dim, hidden_dim, dropout, num_layers, num_labels):
35
+ super().__init__()
36
+ self.hidden_layers = nn.Sequential(*[
37
+ layer for i in range(num_layers)
38
+ for layer in (nn.Linear(first_dim if i == 0 else hidden_dim, hidden_dim), nn.Tanh(), nn.Dropout(dropout))
39
+ ])
40
+ self.out_proj = nn.Linear(hidden_dim, num_labels)
41
+ self.embedding_dim = hidden_dim
42
+
43
+ def forward(self, x, return_embedding=False):
44
+ embedding = self.hidden_layers(x)
45
+ output = self.out_proj(embedding)
46
+ return (output, embedding) if return_embedding else output
47
+
48
+ class HierarchicalDCRegressionHead(nn.Module):
49
+ def __init__(self, classifier_embed_dim, cont_embed_dim, dropout, min_score=0.0, max_score=1.0):
50
+ super().__init__()
51
+ self.min_score = min_score
52
+ self.max_score = max_score
53
+ self.fusion_layer = nn.Sequential(
54
+ nn.Linear(classifier_embed_dim + cont_embed_dim, cont_embed_dim),
55
+ nn.Tanh(),
56
+ nn.Dropout(dropout),
57
+ nn.Linear(cont_embed_dim, 2)
58
+ )
59
+
60
+ def forward(self, ed, ec):
61
+ x = torch.cat([ed, ec], dim=-1)
62
+ out = self.fusion_layer(x)
63
+ return torch.sigmoid(out) * (self.max_score - self.min_score) + self.min_score
64
+
65
+
66
+ class UpstreamFinetune(PreTrainedModel):
67
+ config_class = UpstreamFinetuneConfig
68
+ def __init__(self, config, pretrained_path = None,device = None):
69
+ super().__init__(config)
70
+ if pretrained_path is None:
71
+ upstream_path = config.origin_upstream_url
72
+ else:
73
+ upstream_path = os.path.join(pretrained_path, config.upstream_model)
74
+ self.feature_extractor = AutoProcessor.from_pretrained(upstream_path,use_fast=False)
75
+ self.upstream = AutoModel.from_pretrained(upstream_path)
76
+ self.finetune_layers = config.finetune_layers
77
+
78
+ # Comment out for wav2vec2 base
79
+ # Explicitly initialize the masked_spec_embed parameter if it's causing issues
80
+ # if hasattr(self.upstream, 'masked_spec_embed'):
81
+ # self.upstream.masked_spec_embed = nn.Parameter(torch.zeros(self.upstream.config.hidden_size))
82
+
83
+ for param in self.upstream.parameters():
84
+ param.requires_grad = False
85
+
86
+ for i in range(1, self.finetune_layers + 1):
87
+ for param in self.upstream.encoder.layers[-i].parameters():
88
+ param.requires_grad = True
89
+
90
+ input_dim = self.upstream.config.hidden_size
91
+ self.classifier = ClassificationHead(input_dim, config.hidden_dim, config.dropout, config.num_layers, config.classifier_output_dim)
92
+ self.cont_proj = nn.Sequential(
93
+ nn.Linear(input_dim, config.hidden_dim),
94
+ nn.Tanh(),
95
+ nn.Dropout(config.dropout)
96
+ )
97
+ self.regressor = HierarchicalDCRegressionHead(
98
+ classifier_embed_dim=config.hidden_dim,
99
+ cont_embed_dim=config.hidden_dim,
100
+ dropout=config.dropout
101
+ )
102
+ self.to(device)
103
+
104
+ def forward(self, x, sr):
105
+ with torch.no_grad():
106
+ # Extract features from upstream model
107
+ features = self.feature_extractor(x, sampling_rate=sr, return_tensors='pt', padding=True).input_values
108
+ features = features.squeeze(0).squeeze(1)
109
+ features = features.cuda()
110
+
111
+ if torch.isnan(features).any():
112
+ print("Warning: NaN detected in features")
113
+ features = torch.nan_to_num(features, nan=0.0)
114
+
115
+ # Process through upstream model
116
+ outputs = self.upstream(features)
117
+ hidden_states = outputs.last_hidden_state
118
+
119
+ # For using multiple hidden states
120
+ # upstream_hidden_state = self.upstream(features,output_hidden_states=True).hidden_states
121
+ # upstream_hidden_state = torch.stack(upstream_hidden_state[-1:])
122
+ # upstream_hidden_state = torch.mean(upstream_hidden_state, dim=0)
123
+
124
+ # DEBUG field
125
+ if torch.isnan(hidden_states).any():
126
+ print("Warning: NaN detected in hidden state")
127
+ hidden_states = torch.nan_to_num(hidden_states, nan=0.0)
128
+
129
+ # Global average pooling over the sequence length
130
+ pooled_features = torch.mean(hidden_states, dim=1)
131
+
132
+ # DEBUG field
133
+ if torch.isnan(pooled_features).any():
134
+ print("Warning: NaN detected in pooled features")
135
+
136
+ # Pass through classifier
137
+ # Get discrete output and embedding
138
+ category, ed = self.classifier(pooled_features, return_embedding=True)
139
+
140
+ # Get continuous embedding
141
+ ec = self.cont_proj(pooled_features)
142
+
143
+ # Use ED and EC to predict continuous values
144
+ dim = self.regressor(ed, ec)
145
+
146
+ return category, dim
147
+
148
+ @classmethod
149
+ def from_pretrained(cls, model_path, pretrained_model_name_or_path = None, *model_args, **kwargs):
150
+ # Extract config and device from kwargs if provided
151
+ device = kwargs.pop('device', None)
152
+ pretrained_path = kwargs.pop('pretrained_path', None)
153
+
154
+ # Load the configuration
155
+ config = kwargs.pop('config', None)
156
+ if config is None:
157
+ config = cls.config_class.from_pretrained(model_path, **kwargs)
158
+
159
+ # Create model instance with the config
160
+ model = cls(config=config, pretrained_path=pretrained_model_name_or_path, device=device, *model_args, **kwargs)
161
+
162
+ model_bin_path = os.path.join(model_path, "pytorch_model.bin")
163
+ model_safetensors_path = os.path.join(model_path, "model.safetensors")
164
+
165
+ if os.path.exists(model_safetensors_path):
166
+ print(f"Loading model weights from {model_safetensors_path}...")
167
+ state_dict = load_file(model_safetensors_path)
168
+ model.load_state_dict(state_dict)
169
+ elif os.path.exists(model_bin_path):
170
+ print(f"Loading model weights from {model_bin_path}...")
171
+ state_dict = torch.load(model_bin_path, map_location="cpu")
172
+ model.load_state_dict(state_dict)
173
+ else:
174
+ raise FileNotFoundError(f"No model weights found at {model_path}. Expected either 'pytorch_model.bin' or 'model.safetensors'")
175
+
176
+ # Set model to eval mode by default
177
+ model.eval()
178
+
179
+ return model