Hguimaraes committed on
Commit
8ff90c3
·
verified ·
1 Parent(s): 862e4ad

Upload model

README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
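Until this section is filled in, a minimal loading sketch can be inferred from `config.json`: its `auto_map` routes `AutoModel` to `RDDistillerModel` in `distiller_model.py`, so `trust_remote_code=True` is required. `"<repo-id>"` below is a placeholder for this repository's Hub id, which is not stated in this card.

```python
import torch
from transformers import AutoModel

# "<repo-id>" is a placeholder; the actual Hub id is not given in this card.
# trust_remote_code=True lets transformers import the RDDistillerModel
# architecture from the Python files shipped in this repository.
model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)

wav = torch.randn(16000)  # one second of 16 kHz mono audio
outputs = model(wav)      # dict with "last_hidden_state" and "hidden_states"
```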
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,46 @@
+ {
+   "activation_dropout": 0.1,
+   "activation_fn": "gelu",
+   "architectures": [
+     "RDDistillerModel"
+   ],
+   "attention_dropout": 0.1,
+   "attention_type": "original",
+   "auto_map": {
+     "AutoConfig": "configuration_distiller.DistillerConfig",
+     "AutoModel": "distiller_model.RDDistillerModel"
+   },
+   "conv_pos": 128,
+   "conv_pos_groups": 16,
+   "cosine_loss": 1.0,
+   "dropout": 0.1,
+   "dtype": "float32",
+   "encoder_attention_heads": 12,
+   "encoder_embed_dim": 768,
+   "encoder_ffn_embed_dim": 3072,
+   "encoder_layerdrop": 0.0,
+   "encoder_layers": 2,
+   "extractor_conv_feature_layers": "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+   "extractor_dropout": 0.0,
+   "extractor_mode": "default",
+   "feat_pen_loss": 0.0,
+   "feature_grad_mult": 0.1,
+   "final_dim": 768,
+   "init_teacher_conv_layers": true,
+   "init_teacher_encoder_layers": true,
+   "layer_emb_size": 0,
+   "layer_norm_first": false,
+   "loss_type": "l1",
+   "model_type": "rd_distiller",
+   "n_tasks": 3,
+   "out_layer_inter_dim": -1,
+   "out_layer_type": "expand-last",
+   "pred_layer_id": [
+     4,
+     8,
+     12
+   ],
+   "task_emb_size": 0,
+   "task_emb_type": "expand-last",
+   "transformers_version": "5.1.0"
+ }
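The `extractor_conv_feature_layers` string above is parsed (via `eval` in `modeling_distiller.py`) into a list of `(channels, kernel, stride)` tuples for the convolutional waveform front-end. The number of output frames follows the standard strided-convolution formula, which is the same arithmetic `cal_pad_mask` applies to recompute the padding mask; `output_frames` below is a hypothetical helper name:

```python
# Parse the conv spec exactly as the model does; each tuple is
# (channels, kernel_size, stride).
conv_layers = eval("[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2")

def output_frames(n_samples: int) -> int:
    """Frames produced by the feature extractor for a waveform of n_samples."""
    for _, kernel, stride in conv_layers:
        n_samples = (n_samples - kernel) // stride + 1
    return n_samples

print(output_frames(16000))  # one second of 16 kHz audio -> 49 frames
```

The strides multiply to 5 * 2**6 = 320, so the extractor emits roughly one 512-dimensional frame per 20 ms of 16 kHz audio.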
configuration_distiller.py ADDED
@@ -0,0 +1,93 @@
+ """
+ Distiller Configuration
+ Author: Heng-Jui Chang (https://github.com/vectominist)
+ """
+ from transformers import PreTrainedConfig
+
+ class DistillerConfig(PreTrainedConfig):
+     """
+     Configuration class
+     """
+     model_type = "rd_distiller"
+
+     def __init__(
+         self,
+         extractor_mode: str = "default",
+         extractor_conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+         extractor_dropout: float = 0.0,
+         feature_grad_mult: float = 1.0,
+         conv_pos: int = 128,
+         conv_pos_groups: int = 16,
+         encoder_layers: int = 1,
+         encoder_embed_dim: int = 768,
+         encoder_ffn_embed_dim: int = 3072,
+         encoder_attention_heads: int = 12,
+         activation_fn: str = "gelu",
+         layer_norm_first: bool = False,
+         attention_type: str = "original",
+         dropout: float = 0.1,
+         attention_dropout: float = 0.1,
+         activation_dropout: float = 0.1,
+         encoder_layerdrop: float = 0.0,
+         final_dim: int = 768,
+         out_layer_type: str = "expand-last",
+         out_layer_inter_dim: int = -1,
+         n_tasks: int = 12,
+         task_emb_type: str = "expand-last",
+         task_emb_size: int = 0,
+         layer_emb_size: int = 0,
+         loss_type: str = "l1",
+         feat_pen_loss: float = 0.0,
+         cosine_loss: float = 0.0,
+         pred_layer_id: list = range(1, 12 + 1),
+         init_teacher_conv_layers: bool = False,
+         init_teacher_encoder_layers: bool = False,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+
+         # Feature extractor
+         self.extractor_mode = extractor_mode
+         self.extractor_conv_feature_layers = extractor_conv_feature_layers
+         self.extractor_dropout = extractor_dropout
+         self.feature_grad_mult = feature_grad_mult
+
+         # Convolutional relative positional encoding
+         self.conv_pos = conv_pos
+         self.conv_pos_groups = conv_pos_groups
+
+         # Transformer encoder
+         self.encoder_layers = encoder_layers
+         self.encoder_embed_dim = encoder_embed_dim
+         self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
+         self.encoder_attention_heads = encoder_attention_heads
+         self.activation_fn = activation_fn
+         self.layer_norm_first = layer_norm_first
+         self.attention_type = attention_type
+
+         # Dropout
+         self.dropout = dropout
+         self.attention_dropout = attention_dropout
+         self.activation_dropout = activation_dropout
+         self.encoder_layerdrop = encoder_layerdrop
+
+         # Output
+         self.final_dim = final_dim
+         self.out_layer_type = out_layer_type
+         self.out_layer_inter_dim = out_layer_inter_dim
+
+         # Task & loss
+         self.n_tasks = n_tasks
+         self.task_emb_type = task_emb_type
+         self.task_emb_size = task_emb_size
+         self.layer_emb_size = layer_emb_size
+         self.loss_type = loss_type
+         self.feat_pen_loss = feat_pen_loss
+         self.cosine_loss = cosine_loss
+
+         # When task_emb_type == 'expand-last' only
+         self.pred_layer_id = pred_layer_id
+
+         # Initialization
+         self.init_teacher_conv_layers = init_teacher_conv_layers
+         self.init_teacher_encoder_layers = init_teacher_encoder_layers
distiller_model.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+ import numpy as np
+ from torch.nn.utils.rnn import pad_sequence
+ from transformers import PreTrainedModel
+
+ from .modeling_distiller import DistillerModel
+ from .configuration_distiller import DistillerConfig
+
+
+ class RDDistillerModel(PreTrainedModel):
+     config_class = DistillerConfig
+
+     def __init__(self, config: DistillerConfig):
+         super().__init__(config)
+         self.model = DistillerModel(config)
+         self.post_init()
+
+     def prepare_input_data(
+         self,
+         wavs: torch.Tensor,
+         wav_lens: torch.Tensor = None
+     ):
+         # add an arbitrary batch axis B if input `wavs` has shape (T,),
+         # before per-example lengths are computed below
+         if isinstance(wavs, torch.Tensor) and wavs.dim() == 1:
+             wavs = wavs.unsqueeze(0)
+
+         if isinstance(wavs, list):
+             wav_lens = [len(wave) for wave in wavs]
+             wavs = pad_sequence(wavs, batch_first=True)
+         elif isinstance(wavs, torch.Tensor) and wav_lens is None:
+             wav_lens = [wav.shape[0] for wav in wavs]
+
+         if wavs.dim() > 2:
+             raise ValueError(f"Expected `wavs` with 1 or 2 dimensions, got {wavs.dim()}")
+
+         batch_size = wavs.shape[0]
+         seq_len = wavs.shape[1]
+
+         pad_mask = np.ones((batch_size, seq_len))  # (batch_size, seq_len)
+
+         # zero out the padded positions of each example
+         for idx in range(batch_size):
+             pad_mask[idx, wav_lens[idx]:] = 0
+
+         wavs = wavs.to(dtype=torch.float32)  # (batch_size, seq_len)
+         pad_mask = torch.FloatTensor(pad_mask).to(
+             device=wavs.device, dtype=torch.float32
+         )  # (batch_size, seq_len)
+         return wavs, pad_mask  # (x, pad_mask)
+
+     def forward(
+         self,
+         wavs: torch.Tensor,
+         wav_lens: torch.Tensor = None,
+     ):
+         wavs, pad_mask = self.prepare_input_data(wavs, wav_lens)
+         _, feat_final, pred, _, layer_hidden = self.model(
+             wavs, pad_mask, get_hidden=True, no_pred=False
+         )
+
+         # pred: B x N x T x D -> list of N tensors of shape B x T x D
+         hidden_feats = pred.transpose(0, 1).split(1, 0)
+         hidden_feats = [hid.squeeze(0) for hid in hidden_feats]
+         hidden_feats = [feat_final] + layer_hidden + hidden_feats
+
+         return {
+             "last_hidden_state": hidden_feats[-1],
+             "hidden_states": hidden_feats,
+         }
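The padding mask built in `prepare_input_data` marks real samples with 1.0 and padding with 0.0, one row per example. The core logic can be sketched framework-free (`build_pad_mask` is a hypothetical helper name):

```python
def build_pad_mask(wav_lens, seq_len):
    # 1.0 marks real samples, 0.0 marks padding, per example,
    # mirroring the loop in prepare_input_data
    return [[1.0] * n + [0.0] * (seq_len - n) for n in wav_lens]

print(build_pad_mask([3, 5], 5))
# [[1.0, 1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0]]
```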
distiller_modules.py ADDED
@@ -0,0 +1,312 @@
+ """
+ Distiller Modules
+ Author: Heng-Jui Chang (https://github.com/vectominist)
+ """
+
+ import math
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ import torch.nn.functional as F
+
+ from .distiller_w2v2_modules import (
+     MultiheadAttention,
+     SamePad,
+     get_activation_fn,
+ )
+
+
+ def init_bert_params(module):
+     """
+     Initialize the weights specific to the BERT Model.
+     This overrides the default initializations depending on the specified arguments.
+     1. If normal_init_linear_weights is set then weights of linear
+        layer will be initialized using the normal distribution and
+        bias will be set to the specified value.
+     2. If normal_init_embed_weights is set then weights of embedding
+        layer will be initialized using the normal distribution.
+     3. If normal_init_proj_weights is set then weights of
+        in_project_weight for MultiHeadAttention initialized using
+        the normal distribution (to be validated).
+     """
+
+     def normal_(data):
+         # Skip tensors on the meta device; real weights will be loaded later
+         if data.is_meta:
+             return
+
+         # with FSDP, module params will be on CUDA, so we cast them back to CPU
+         # so that the RNG is consistent with and without FSDP
+         data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
+
+     if isinstance(module, nn.Linear):
+         normal_(module.weight.data)
+         if module.bias is not None:
+             module.bias.data.zero_()
+     if isinstance(module, nn.Embedding):
+         normal_(module.weight.data)
+         if module.padding_idx is not None:
+             module.weight.data[module.padding_idx].zero_()
+     if isinstance(module, MultiheadAttention):
+         normal_(module.q_proj.weight.data)
+         normal_(module.k_proj.weight.data)
+         normal_(module.v_proj.weight.data)
+
+
+ class SplitLinear(nn.Module):
+     """Split Linear Layer: N independent Din -> Dout projections in one op"""
+
+     def __init__(self, in_dim, in_split, out_dim):
+         super().__init__()
+
+         self.in_dim = in_dim  # Din
+         self.in_split = in_split  # N
+         self.out_dim = out_dim  # Dout
+
+         if in_split > 1:
+             weight = torch.zeros((self.in_split, self.in_dim, self.out_dim))
+             self.weight = nn.Parameter(weight, requires_grad=True)
+             nn.init.uniform_(self.weight, -(self.in_dim**-0.5), self.in_dim**-0.5)
+
+             bias = torch.zeros((1, 1, self.in_split, self.out_dim))
+             self.bias = nn.Parameter(bias, requires_grad=True)
+             nn.init.uniform_(self.bias, -(self.in_dim**-0.5), self.in_dim**-0.5)
+         else:
+             self.layer = nn.Linear(self.in_dim, self.out_dim)
+
+     def forward(self, x: torch.Tensor):
+         # x: shape = B x T x (N * Din)
+
+         if self.in_split == 1:
+             return self.layer(x)
+         else:
+             x = x.reshape(x.shape[0], x.shape[1], self.in_split, 1, self.in_dim)
+             # x: B x T x N x 1 x Din
+
+             out = torch.einsum("...klm,kmn->...kln", x, self.weight).squeeze(3)
+             # out: B x T x N x Dout
+             out = out + self.bias
+
+             return out.reshape(x.shape[0], x.shape[1], -1)
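The einsum in `SplitLinear.forward` is equivalent to applying N independent linear layers, one per Din-sized chunk of the input, then concatenating the results. A NumPy sketch of that equivalence (dimensions chosen arbitrarily for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, N, Din, Dout = 2, 4, 3, 5, 6
x = rng.standard_normal((B, T, N * Din))
w = rng.standard_normal((N, Din, Dout))   # analogue of self.weight
b = rng.standard_normal((1, 1, N, Dout))  # analogue of self.bias

# SplitLinear.forward: reshape, batched einsum, add bias, flatten back
xs = x.reshape(B, T, N, 1, Din)
out = np.einsum("...klm,kmn->...kln", xs, w).squeeze(3) + b
out = out.reshape(B, T, N * Dout)

# Reference: N independent matmuls over the N chunks of the last axis
ref = np.concatenate(
    [x[:, :, k * Din:(k + 1) * Din] @ w[k] + b[0, 0, k] for k in range(N)],
    axis=-1,
)
assert np.allclose(out, ref)
```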
+
+
+ class TransformerSentenceEncoderLayer(nn.Module):
+     """
+     Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
+     models.
+     """
+
+     def __init__(
+         self,
+         embedding_dim: int = 768,
+         ffn_embedding_dim: int = 3072,
+         num_attention_heads: int = 8,
+         dropout: float = 0.1,
+         attention_dropout: float = 0.1,
+         activation_dropout: float = 0.1,
+         activation_fn: str = "relu",
+         layer_norm_first: bool = False,
+         attention_type: str = "original",
+     ) -> None:
+         super().__init__()
+         # Initialize parameters
+         self.embedding_dim = embedding_dim
+         self.dropout = dropout
+         self.activation_dropout = activation_dropout
+
+         # Initialize blocks
+         self.activation_fn = get_activation_fn(activation_fn)
+         self.attention_type = attention_type
+         if attention_type == "original":
+             self.self_attn = MultiheadAttention(
+                 self.embedding_dim,
+                 num_attention_heads,
+                 dropout=attention_dropout,
+                 self_attention=True,
+             )
+         else:
+             raise NotImplementedError(f"Unknown attention type {attention_type}")
+
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(self.activation_dropout)
+         self.dropout3 = nn.Dropout(dropout)
+
+         self.layer_norm_first = layer_norm_first
+
+         # layer norm associated with the self attention layer
+         self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
+         self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
+         self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+
+         # layer norm associated with the position wise feed-forward NN
+         self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
+
+     def forward_self_attn(
+         self,
+         x: torch.Tensor,
+         self_attn_mask: torch.Tensor = None,
+         self_attn_padding_mask: torch.Tensor = None,
+         need_weights: bool = False,
+     ):
+         if self.attention_type in ["original", "sparse"]:
+             x, attn = self.self_attn(
+                 query=x,
+                 key=x,
+                 value=x,
+                 key_padding_mask=self_attn_padding_mask,
+                 need_weights=need_weights,
+                 attn_mask=self_attn_mask,
+             )
+         elif self.attention_type == "dynamic":
+             x = self.self_attn(x)
+             attn = None
+
+         return x, attn
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         self_attn_mask: torch.Tensor = None,
+         self_attn_padding_mask: torch.Tensor = None,
+         need_weights: bool = False,
+         att_args=None,
+     ):
+         """
+         LayerNorm is applied either before or after the self-attention/ffn
+         modules, similar to the original Transformer implementation.
+         """
+         residual = x
+
+         if self.layer_norm_first:
+             x = self.self_attn_layer_norm(x)
+             x, attn = self.forward_self_attn(
+                 x,
+                 self_attn_mask=self_attn_mask,
+                 need_weights=False,
+                 self_attn_padding_mask=self_attn_padding_mask,
+             )
+             x = self.dropout1(x)
+             x = residual + x
+
+             residual = x
+             x = self.final_layer_norm(x)
+             x = self.activation_fn(self.fc1(x))
+             x = self.dropout2(x)
+             x = self.fc2(x)
+             x = self.dropout3(x)
+             x = residual + x
+         else:
+             x, attn = self.forward_self_attn(
+                 x,
+                 self_attn_mask=self_attn_mask,
+                 need_weights=need_weights,
+                 self_attn_padding_mask=self_attn_padding_mask,
+             )
+
+             x = self.dropout1(x)
+             x = residual + x
+             x = self.self_attn_layer_norm(x)
+
+             residual = x
+             x = self.activation_fn(self.fc1(x))
+             x = self.dropout2(x)
+             x = self.fc2(x)
+             x = self.dropout3(x)
+             x = residual + x
+             x = self.final_layer_norm(x)
+
+         return x, attn
+
+
+ class TransformerEncoder(nn.Module):
+     def __init__(self, args):
+         super().__init__()
+
+         self.dropout = args.dropout
+         self.embedding_dim = args.encoder_embed_dim
+
+         self.pos_conv = nn.Conv1d(
+             self.embedding_dim,
+             self.embedding_dim,
+             kernel_size=args.conv_pos,
+             padding=args.conv_pos // 2,
+             groups=args.conv_pos_groups,
+         )
+         dropout = 0
+         std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
+         nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
+         nn.init.constant_(self.pos_conv.bias, 0)
+
+         self.pos_conv = nn.utils.parametrizations.weight_norm(self.pos_conv, name="weight", dim=2)
+         self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
+
+         print(f"[TransformerEncoder] - Attention type = {args.attention_type}")
+         self.layers = nn.ModuleList(
+             [
+                 TransformerSentenceEncoderLayer(
+                     embedding_dim=self.embedding_dim,
+                     ffn_embedding_dim=args.encoder_ffn_embed_dim,
+                     num_attention_heads=args.encoder_attention_heads,
+                     dropout=self.dropout,
+                     attention_dropout=args.attention_dropout,
+                     activation_dropout=args.activation_dropout,
+                     activation_fn=args.activation_fn,
+                     layer_norm_first=args.layer_norm_first,
+                     attention_type=args.attention_type,
+                 )
+                 for _ in range(args.encoder_layers)
+             ]
+         )
+
+         self.layer_norm_first = args.layer_norm_first
+         self.layer_norm = nn.LayerNorm(self.embedding_dim)
+         self.layerdrop = args.encoder_layerdrop
+
+         self.apply(init_bert_params)
+
+     def forward(self, x, padding_mask=None, attn_mask=None, get_hidden=False):
+         x, layer_results = self.extract_features(
+             x, padding_mask, attn_mask, get_hidden=get_hidden
+         )
+
+         if self.layer_norm_first:
+             x = self.layer_norm(x)
+
+         return x, layer_results
+
+     def extract_features(self, x, padding_mask=None, attn_mask=None, get_hidden=False):
+         if padding_mask is not None:
+             x[padding_mask] = 0
+
+         x_conv = self.pos_conv(x.transpose(1, 2))
+         x_conv = x_conv.transpose(1, 2)
+         x = x + x_conv
+
+         if not self.layer_norm_first:
+             x = self.layer_norm(x)
+
+         x = F.dropout(x, p=self.dropout, training=self.training)
+
+         # B x T x C -> T x B x C
+         x = x.transpose(0, 1)
+
+         layer_results = []
+         for i, layer in enumerate(self.layers):
+             dropout_probability = np.random.random()
+             if not self.training or (dropout_probability > self.layerdrop):
+                 x, z = layer(
+                     x,
+                     self_attn_padding_mask=padding_mask,
+                     need_weights=False,
+                     self_attn_mask=attn_mask,
+                 )
+                 if get_hidden:
+                     layer_results.append(x.transpose(0, 1))
+
+         # T x B x C -> B x T x C
+         x = x.transpose(0, 1)
+
+         return x, layer_results
distiller_w2v2_modules.py ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:36d8fc11ed208055e4efdf2dcf36d97c30b18964be7b0d27330fb39604d45f4d
+ size 108144824
modeling_distiller.py ADDED
@@ -0,0 +1,226 @@
+ """
+ Builder for Distiller
+ Author: Heng-Jui Chang (https://github.com/vectominist)
+ """
+
+ import torch
+ from torch import nn
+
+ from .configuration_distiller import DistillerConfig
+ from .distiller_w2v2_modules import (
+     ConvFeatureExtractionModel,
+     GradMultiply,
+ )
+ from .distiller_modules import (
+     TransformerEncoder,
+     SplitLinear,
+ )
+
+ class DistillerModel(nn.Module):
+     """
+     Distiller Model
+     """
+
+     def __init__(self, config: DistillerConfig):
+         super().__init__()
+
+         self.config = config
+
+         self.conv_layers = eval(config.extractor_conv_feature_layers)
+         feat_emb_dim = self.conv_layers[-1][0]
+         self.feature_extractor = ConvFeatureExtractionModel(
+             self.conv_layers,
+             dropout=config.extractor_dropout,
+             mode=config.extractor_mode,
+             conv_bias=False,
+         )
+         self.feature_grad_mult = config.feature_grad_mult
+
+         self.n_tasks = config.n_tasks
+         self.task_emb_type = config.task_emb_type
+
+         final_emb_size = config.encoder_embed_dim
+         if self.task_emb_type == "add":
+             self.task_embedding = nn.Embedding(config.n_tasks, config.encoder_embed_dim)
+             nn.init.normal_(self.task_embedding.weight, 0.0, 0.1)
+         elif self.task_emb_type == "concat":
+             assert config.task_emb_size > 0
+             feat_emb_dim += config.task_emb_size
+             self.task_embedding = nn.Embedding(config.n_tasks, config.task_emb_size)
+         elif self.task_emb_type == "concat-last":
+             assert config.task_emb_size > 0
+             self.task_embedding = nn.Embedding(config.n_tasks, config.task_emb_size)
+             final_emb_size += config.task_emb_size
+         elif self.task_emb_type == "expand-last":
+             self.pred_layer_id = config.pred_layer_id
+             assert self.n_tasks == len(self.pred_layer_id)
+             print(
+                 f"[DistillerModel] - Expands the output dimension by {self.n_tasks} times"
+             )
+             print(f"[DistillerModel] - Pred layers: {self.pred_layer_id}")
+         elif self.task_emb_type == "self-hidden":
+             self.pred_layer_id = config.pred_layer_id
+             assert self.n_tasks == len(self.pred_layer_id)
+             assert self.n_tasks == config.encoder_layers + 1
+             print("[DistillerModel] - Predicting with self-hidden layers")
+             print(f"[DistillerModel] - Pred layers: {self.pred_layer_id}")
+         elif self.task_emb_type == "none":
+             print(
+                 f"[DistillerModel] - Disabled task embedding (predicts only layer {self.n_tasks})"
+             )
+         else:
+             raise NotImplementedError(f"Unknown task emb type {self.task_emb_type}")
+
+         self.post_extract_proj = (
+             nn.Linear(feat_emb_dim, config.encoder_embed_dim)
+             if feat_emb_dim != config.encoder_embed_dim
+             else None
+         )
+
+         if config.encoder_layers > 0:
+             self.encoder = TransformerEncoder(config)
+         else:
+             self.encoder = nn.GELU()
+
+         final_dim = config.final_dim * (
+             1 if self.task_emb_type != "expand-last" else self.n_tasks
+         )
+
+         inter_dim = config.out_layer_inter_dim
+         inter_dim = inter_dim if inter_dim > 0 else final_emb_size
+
+         print(f"[DistillerModel] - Out layer type: {config.out_layer_type}")
+         if config.out_layer_type == "expand-last":
+             assert self.task_emb_type == "expand-last"
+             print(f"[DistillerModel] - Inter dim = {inter_dim}")
+             self.output_layer = nn.Sequential(
+                 nn.Linear(final_emb_size, inter_dim * self.n_tasks),
+                 nn.GELU(),
+                 SplitLinear(inter_dim, self.n_tasks, config.final_dim),
+             )
+         elif config.out_layer_type in {"none", "self-hidden"}:
+             self.output_layer = None
+         else:
+             raise NotImplementedError(f"Unknown out layer type {config.out_layer_type}")
+
+     def forward_feature(self, wave, pad_mask):
+         """Forward feature extractor"""
+
+         if self.feature_grad_mult > 0:
+             feat = self.feature_extractor(wave)
+             if self.feature_grad_mult != 1.0:
+                 feat = GradMultiply.apply(feat, self.feature_grad_mult)
+         else:
+             with torch.no_grad():
+                 feat = self.feature_extractor(wave)
+
+         feat = feat.transpose(1, 2)  # B x T x D
+         pad_mask = self.cal_pad_mask(pad_mask, feat.shape[1])
+
+         return feat, pad_mask
+
+     def forward(self, wave, pad_mask, task_id=None, get_hidden=False, no_pred=False):
+         """
+         Forward function
+         Input:
+             wave (FloatTensor): B x T_wave
+             pad_mask (BoolTensor): B x T_wave
+             task_id (LongTensor): N >= 1
+         """
+
+         feat, pad_mask = self.forward_feature(wave, pad_mask)
+
+         if self.task_emb_type not in ["none", "expand-last", "self-hidden"]:
+             if task_id is None:
+                 task_id = self.generate_task_id(feat.device)
+             elif isinstance(task_id, list):
+                 task_id = torch.LongTensor(task_id).to(feat.device)
+             task_embs = self.task_embedding(task_id)
+             # N x D
+             n_sz = len(task_id)
+         else:
+             n_sz = 1
+         b_sz, t_sz, _ = feat.shape
+
+         if self.task_emb_type == "add":
+             # Add embs to feature
+             if self.post_extract_proj is not None:
+                 feat_final = self.post_extract_proj(feat)
+             else:
+                 feat_final = feat
+             feat_final = feat_final.unsqueeze(1) + task_embs.unsqueeze(0).unsqueeze(2)
+         elif self.task_emb_type == "concat":
+             # Concatenates embs to feature
+             feat_final = torch.cat(
+                 [
+                     feat.unsqueeze(1).expand(-1, n_sz, -1, -1),
+                     task_embs.unsqueeze(0).unsqueeze(2).expand(b_sz, -1, t_sz, -1),
+                 ],
+                 dim=-1,
+             )
+             if self.post_extract_proj is not None:
+                 feat_final = self.post_extract_proj(feat_final)
+         else:
+             if self.post_extract_proj is not None:
+                 feat_final = self.post_extract_proj(feat)
+             else:
+                 feat_final = feat
+             feat_final = feat_final.unsqueeze(1)
+         # feat_final: B x N x T x D or B x 1 x T x D
+
+         pad_mask = pad_mask.unsqueeze(1).expand(-1, n_sz, -1).reshape(b_sz * n_sz, t_sz)
+         # BN x T
+         feat_final = feat_final.reshape(b_sz * n_sz, t_sz, -1)
+         # BN x T x D
+
+         layer_hiddens = []
+         if self.config.encoder_layers > 0:
+             get_hidden_tmp = (
+                 True if (self.task_emb_type == "self-hidden") else get_hidden
+             )
+             hidden, layer_hiddens = self.encoder(
+                 feat_final, ~pad_mask.bool(), get_hidden=get_hidden_tmp
+             )
+         else:
+             hidden = self.encoder(feat_final)
+
+         if not no_pred:
+             if self.task_emb_type == "self-hidden":
+                 pred = torch.stack([feat_final] + layer_hiddens, dim=1)
+             else:
+                 pred = self.output_layer(hidden).reshape(b_sz, n_sz, t_sz, -1)
+                 # B x N x T x D
+         else:
+             pred = None
+
+         if (not no_pred) and self.task_emb_type == "expand-last":
+             assert n_sz == 1, n_sz
+             pred = (
+                 pred.squeeze(1)
+                 .reshape(b_sz, t_sz, self.n_tasks, -1)
+                 .permute(0, 2, 1, 3)
+             )
+             # B x N x T x D
+
+         if get_hidden:
+             return feat, feat_final, pred, pad_mask, layer_hiddens
+         else:
+             return feat, feat_final, pred, pad_mask
+
+     def cal_pad_mask(self, pad_mask, max_len):
+         """Calculates pad mask after conv."""
+         pad_len = (pad_mask > 0).sum(1).long()
+         for _, k_size, s_size in self.conv_layers:
+             pad_len = (pad_len - k_size) // s_size + 1
+
+         new_pad_mask = torch.ones(
+             (pad_mask.shape[0], max_len), dtype=pad_mask.dtype, device=pad_mask.device
+         )
+
+         for idx in range(pad_len.shape[0]):
+             new_pad_mask[idx, pad_len[idx]:] = 0
+
+         return new_pad_mask
+
+     def generate_task_id(self, device):
+         return torch.arange(self.n_tasks, device=device, dtype=torch.long)
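In the "expand-last" branch of `forward`, the output layer produces one flat prediction of size N*D per frame, which is then split into per-task predictions and reordered to B x N x T x D. The index shuffling can be checked with a small NumPy sketch (dimensions match `n_tasks=3` from `config.json`; `final_dim` is shrunk to 4 for readability):

```python
import numpy as np

b_sz, t_sz, n_tasks, final_dim = 2, 5, 3, 4
hidden = np.arange(b_sz * t_sz * n_tasks * final_dim, dtype=float)
pred = hidden.reshape(b_sz, t_sz, n_tasks * final_dim)  # output_layer result

# "expand-last": split the flat last axis into per-task chunks, then move
# the task axis ahead of time: B x T x (N*D) -> B x N x T x D
pred = pred.reshape(b_sz, t_sz, n_tasks, final_dim).transpose(0, 2, 1, 3)
assert pred.shape == (2, 3, 5, 4)
```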