kyboface commited on
Commit
7af065d
·
verified ·
1 Parent(s): fdee621

Upload 5 files

Browse files
src/audio_analysis/torch_utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def get_mask_from_lengths(lengths, max_len=None):
6
+ lengths = lengths.to(torch.long)
7
+ if max_len is None:
8
+ max_len = torch.max(lengths).item()
9
+
10
+ ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
11
+ mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
12
+
13
+ return mask
14
+
15
+
16
+ def linear_interpolation(features, seq_len):
17
+ features = features.transpose(1, 2)
18
+ output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
19
+ return output_features.transpose(1, 2)
20
+
src/audio_analysis/wav2vec2.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Wav2Vec2Config, Wav2Vec2Model
2
+ from transformers.modeling_outputs import BaseModelOutput
3
+
4
+ from src.audio_analysis.torch_utils import linear_interpolation
5
+
6
+ # the implementation of Wav2Vec2Model is borrowed from
7
+ # https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
8
+ # initialize our encoder with the pre-trained wav2vec 2.0 weights.
9
+ class Wav2Vec2Model(Wav2Vec2Model):
10
+ def __init__(self, config: Wav2Vec2Config):
11
+ super().__init__(config)
12
+
13
+ def forward(
14
+ self,
15
+ input_values,
16
+ seq_len,
17
+ attention_mask=None,
18
+ mask_time_indices=None,
19
+ output_attentions=None,
20
+ output_hidden_states=None,
21
+ return_dict=None,
22
+ ):
23
+ self.config.output_attentions = True
24
+
25
+ output_hidden_states = (
26
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
27
+ )
28
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
29
+
30
+ extract_features = self.feature_extractor(input_values)
31
+ extract_features = extract_features.transpose(1, 2)
32
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
33
+
34
+ if attention_mask is not None:
35
+ # compute reduced attention_mask corresponding to feature vectors
36
+ attention_mask = self._get_feature_vector_attention_mask(
37
+ extract_features.shape[1], attention_mask, add_adapter=False
38
+ )
39
+
40
+ hidden_states, extract_features = self.feature_projection(extract_features)
41
+ hidden_states = self._mask_hidden_states(
42
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
43
+ )
44
+
45
+ encoder_outputs = self.encoder(
46
+ hidden_states,
47
+ attention_mask=attention_mask,
48
+ output_attentions=output_attentions,
49
+ output_hidden_states=output_hidden_states,
50
+ return_dict=return_dict,
51
+ )
52
+
53
+ hidden_states = encoder_outputs[0]
54
+
55
+ if self.adapter is not None:
56
+ hidden_states = self.adapter(hidden_states)
57
+
58
+ if not return_dict:
59
+ return (hidden_states, ) + encoder_outputs[1:]
60
+ return BaseModelOutput(
61
+ last_hidden_state=hidden_states,
62
+ hidden_states=encoder_outputs.hidden_states,
63
+ attentions=encoder_outputs.attentions,
64
+ )
65
+
66
+
67
+ def feature_extract(
68
+ self,
69
+ input_values,
70
+ seq_len,
71
+ ):
72
+ extract_features = self.feature_extractor(input_values)
73
+ extract_features = extract_features.transpose(1, 2)
74
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
75
+
76
+ return extract_features
77
+
78
+ def encode(
79
+ self,
80
+ extract_features,
81
+ attention_mask=None,
82
+ mask_time_indices=None,
83
+ output_attentions=None,
84
+ output_hidden_states=None,
85
+ return_dict=None,
86
+ ):
87
+ self.config.output_attentions = True
88
+
89
+ output_hidden_states = (
90
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
91
+ )
92
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
93
+
94
+ if attention_mask is not None:
95
+ # compute reduced attention_mask corresponding to feature vectors
96
+ attention_mask = self._get_feature_vector_attention_mask(
97
+ extract_features.shape[1], attention_mask, add_adapter=False
98
+ )
99
+
100
+
101
+ hidden_states, extract_features = self.feature_projection(extract_features)
102
+ hidden_states = self._mask_hidden_states(
103
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
104
+ )
105
+
106
+ encoder_outputs = self.encoder(
107
+ hidden_states,
108
+ attention_mask=attention_mask,
109
+ output_attentions=output_attentions,
110
+ output_hidden_states=output_hidden_states,
111
+ return_dict=return_dict,
112
+ )
113
+
114
+ hidden_states = encoder_outputs[0]
115
+
116
+ if self.adapter is not None:
117
+ hidden_states = self.adapter(hidden_states)
118
+
119
+ if not return_dict:
120
+ return (hidden_states, ) + encoder_outputs[1:]
121
+ return BaseModelOutput(
122
+ last_hidden_state=hidden_states,
123
+ hidden_states=encoder_outputs.hidden_states,
124
+ attentions=encoder_outputs.attentions,
125
+ )
src/utils.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+
3
+ import torch
4
+
5
+ @contextmanager
6
+ def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
7
+ old_register_parameter = torch.nn.Module.register_parameter
8
+ if include_buffers:
9
+ old_register_buffer = torch.nn.Module.register_buffer
10
+
11
+ def register_empty_parameter(module, name, param):
12
+ old_register_parameter(module, name, param)
13
+ if param is not None:
14
+ param_cls = type(module._parameters[name])
15
+ kwargs = module._parameters[name].__dict__
16
+ kwargs["requires_grad"] = param.requires_grad
17
+ module._parameters[name] = param_cls(
18
+ module._parameters[name].to(device), **kwargs
19
+ )
20
+
21
+ def register_empty_buffer(module, name, buffer, persistent=True):
22
+ old_register_buffer(module, name, buffer, persistent=persistent)
23
+ if buffer is not None:
24
+ module._buffers[name] = module._buffers[name].to(device)
25
+
26
+ def patch_tensor_constructor(fn):
27
+ def wrapper(*args, **kwargs):
28
+ kwargs["device"] = device
29
+ return fn(*args, **kwargs)
30
+
31
+ return wrapper
32
+
33
+ if include_buffers:
34
+ tensor_constructors_to_patch = {
35
+ torch_function_name: getattr(torch, torch_function_name)
36
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
37
+ }
38
+ else:
39
+ tensor_constructors_to_patch = {}
40
+
41
+ try:
42
+ torch.nn.Module.register_parameter = register_empty_parameter
43
+ if include_buffers:
44
+ torch.nn.Module.register_buffer = register_empty_buffer
45
+ for torch_function_name in tensor_constructors_to_patch.keys():
46
+ setattr(
47
+ torch,
48
+ torch_function_name,
49
+ patch_tensor_constructor(getattr(torch, torch_function_name)),
50
+ )
51
+ yield
52
+ finally:
53
+ torch.nn.Module.register_parameter = old_register_parameter
54
+ if include_buffers:
55
+ torch.nn.Module.register_buffer = old_register_buffer
56
+ for (
57
+ torch_function_name,
58
+ old_torch_function,
59
+ ) in tensor_constructors_to_patch.items():
60
+ setattr(torch, torch_function_name, old_torch_function)
src/vram_management/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .layers import *
src/vram_management/layers.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ import torch
4
+
5
+ from src.utils import init_weights_on_device
6
+ import optimum.quanto.nn.qlinear as qlinear
7
+
8
+ def cast_to(weight, dtype, device):
9
+ r = torch.empty_like(weight, dtype=dtype, device=device)
10
+ r.copy_(weight)
11
+ return r
12
+
13
+ def cast_to_device(weight, device):
14
+ if hasattr(weight, '__class__') and 'optimum.quanto' in str(weight.__class__):
15
+ return weight.to(device)
16
+ else:
17
+ r = torch.empty_like(weight, device=device)
18
+ r.copy_(weight)
19
+ return r
20
+
21
+ class AutoWrappedModule(torch.nn.Module):
22
+ def __init__(
23
+ self,
24
+ module: torch.nn.Module,
25
+ offload_dtype,
26
+ offload_device,
27
+ onload_dtype,
28
+ onload_device,
29
+ computation_dtype,
30
+ computation_device,
31
+ ):
32
+ super().__init__()
33
+ self.module = module.to(dtype=offload_dtype, device=offload_device)
34
+ self.offload_dtype = offload_dtype
35
+ self.offload_device = offload_device
36
+ self.onload_dtype = onload_dtype
37
+ self.onload_device = onload_device
38
+ self.computation_dtype = computation_dtype
39
+ self.computation_device = computation_device
40
+ self.state = 0
41
+
42
+ def offload(self):
43
+ if self.state == 1 and (
44
+ self.offload_dtype != self.onload_dtype
45
+ or self.offload_device != self.onload_device
46
+ ):
47
+ self.module.to(dtype=self.offload_dtype, device=self.offload_device)
48
+ self.state = 0
49
+
50
+ def onload(self):
51
+ if self.state == 0 and (
52
+ self.offload_dtype != self.onload_dtype
53
+ or self.offload_device != self.onload_device
54
+ ):
55
+ self.module.to(dtype=self.onload_dtype, device=self.onload_device)
56
+ self.state = 1
57
+
58
+ def forward(self, *args, **kwargs):
59
+ if (
60
+ self.onload_dtype == self.computation_dtype
61
+ and self.onload_device == self.computation_device
62
+ ):
63
+ module = self.module
64
+ else:
65
+ module = copy.deepcopy(self.module).to(
66
+ dtype=self.computation_dtype, device=self.computation_device
67
+ )
68
+ return module(*args, **kwargs)
69
+
70
+
71
+
72
+ class AutoWrappedQLinear(qlinear.QLinear):
73
+ def __init__(
74
+ self,
75
+ module: qlinear.QLinear,
76
+ offload_dtype,
77
+ offload_device,
78
+ onload_dtype,
79
+ onload_device,
80
+ computation_dtype,
81
+ computation_device,
82
+ ):
83
+ with init_weights_on_device(device=torch.device("meta")):
84
+ super().__init__(
85
+ in_features=module.in_features,
86
+ out_features=module.out_features,
87
+ bias=module.bias is not None,
88
+ device=offload_device,
89
+ )
90
+ self.weight = module.weight
91
+ self.bias = module.bias
92
+ self.offload_device = offload_device
93
+
94
+ self.onload_device = onload_device
95
+ self.computation_device = computation_device
96
+ self.state = 0
97
+
98
+ def offload(self):
99
+ if self.state == 1 and (
100
+ self.offload_device != self.onload_device
101
+ ):
102
+ self.to(device=self.offload_device)
103
+ self.state = 0
104
+
105
+ def onload(self):
106
+ if self.state == 0 and (
107
+ self.offload_device != self.onload_device
108
+ ):
109
+ self.to(device=self.onload_device)
110
+ self.state = 1
111
+
112
+ def forward(self, x, *args, **kwargs):
113
+ if (
114
+ self.onload_device == self.computation_device
115
+ ):
116
+
117
+ return torch.nn.functional.linear(x, self.weight, bias=self.bias)
118
+ else:
119
+
120
+ qweight = cast_to_device(self.weight, self.computation_device)
121
+ bias = (
122
+ None
123
+ if self.bias is None
124
+ else cast_to_device(self.bias, self.computation_device)
125
+ )
126
+ return torch.nn.functional.linear(x, qweight, bias)
127
+
128
+ class AutoWrappedLinear(torch.nn.Linear):
129
+ def __init__(
130
+ self,
131
+ module: torch.nn.Linear,
132
+ offload_dtype,
133
+ offload_device,
134
+ onload_dtype,
135
+ onload_device,
136
+ computation_dtype,
137
+ computation_device,
138
+ ):
139
+ with init_weights_on_device(device=torch.device("meta")):
140
+ super().__init__(
141
+ in_features=module.in_features,
142
+ out_features=module.out_features,
143
+ bias=module.bias is not None,
144
+ dtype=offload_dtype,
145
+ device=offload_device,
146
+ )
147
+ self.weight = module.weight
148
+ self.bias = module.bias
149
+ self.offload_dtype = offload_dtype
150
+ self.offload_device = offload_device
151
+ self.onload_dtype = onload_dtype
152
+ self.onload_device = onload_device
153
+ self.computation_dtype = computation_dtype
154
+ self.computation_device = computation_device
155
+ self.state = 0
156
+
157
+ def offload(self):
158
+ if self.state == 1 and (
159
+ self.offload_dtype != self.onload_dtype
160
+ or self.offload_device != self.onload_device
161
+ ):
162
+ self.to(dtype=self.offload_dtype, device=self.offload_device)
163
+ self.state = 0
164
+
165
+ def onload(self):
166
+ if self.state == 0 and (
167
+ self.offload_dtype != self.onload_dtype
168
+ or self.offload_device != self.onload_device
169
+ ):
170
+ self.to(dtype=self.onload_dtype, device=self.onload_device)
171
+ self.state = 1
172
+
173
+ def forward(self, x, *args, **kwargs):
174
+ if (
175
+ self.onload_dtype == self.computation_dtype
176
+ and self.onload_device == self.computation_device
177
+ ):
178
+ weight, bias = self.weight, self.bias
179
+ else:
180
+ weight = cast_to(
181
+ self.weight, self.computation_dtype, self.computation_device
182
+ )
183
+ bias = (
184
+ None
185
+ if self.bias is None
186
+ else cast_to(self.bias, self.computation_dtype, self.computation_device)
187
+ )
188
+ return torch.nn.functional.linear(x, weight, bias)
189
+
190
+
191
+ def enable_vram_management_recursively(
192
+ model: torch.nn.Module,
193
+ module_map: dict,
194
+ module_config: dict,
195
+ max_num_param=None,
196
+ overflow_module_config: dict = None,
197
+ total_num_param=0,
198
+ ):
199
+ for name, module in model.named_children():
200
+ for source_module, target_module in module_map.items():
201
+ if isinstance(module, source_module):
202
+ num_param = sum(p.numel() for p in module.parameters())
203
+ # print(str(module) + ':' + str(num_param))
204
+ if (
205
+ max_num_param is not None
206
+ and total_num_param + num_param > max_num_param
207
+ ):
208
+ # print(str(module) + '-->\t\t num:' + str(num_param) + "\t total:" + str(total_num_param))
209
+ module_config_ = overflow_module_config
210
+ else:
211
+ module_config_ = module_config
212
+ module_ = target_module(module, **module_config_)
213
+ setattr(model, name, module_)
214
+ total_num_param += num_param
215
+ break
216
+ else:
217
+ total_num_param = enable_vram_management_recursively(
218
+ module,
219
+ module_map,
220
+ module_config,
221
+ max_num_param,
222
+ overflow_module_config,
223
+ total_num_param,
224
+ )
225
+ return total_num_param
226
+
227
+
228
+ def enable_vram_management(
229
+ model: torch.nn.Module,
230
+ module_map: dict,
231
+ module_config: dict,
232
+ max_num_param=None,
233
+ overflow_module_config: dict = None,
234
+ ):
235
+ enable_vram_management_recursively(
236
+ model,
237
+ module_map,
238
+ module_config,
239
+ max_num_param,
240
+ overflow_module_config,
241
+ total_num_param=0,
242
+ )
243
+ model.vram_management_enabled = True