yubo0306 committed (verified)
Commit 425850c · 1 Parent(s): 3ca7108

Upload AVHubertForConditionalGeneration

README.md ADDED
@@ -0,0 +1,199 @@
---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->



## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]
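As a placeholder until the snippet above is filled in, here is a minimal, untested sketch of how a repository that ships custom code files (as this one does) is typically loaded; the repository id is hypothetical and the `auto_map` wiring in `config.json` is assumed:

```python
from transformers import AutoConfig, AutoModel

repo_id = "<user>/<repo>"  # hypothetical - replace with this repository's Hub id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
```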
## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]


#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary



## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
config.json ADDED
The diff for this file is too large to render. See raw diff
 
configuration_avhubert.py ADDED
@@ -0,0 +1,151 @@
from transformers import HubertConfig, PretrainedConfig


class AVHubertConfig(PretrainedConfig):
    model_type: str = "avhubert"

    def __init__(
        self,
        label_rate: int = 100,
        encoder_layers: int = 12,
        encoder_embed_dim: int = 768,
        encoder_ffn_embed_dim: int = 3072,
        encoder_attention_heads: int = 12,
        activation_fn: str = "gelu",
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        encoder_layerdrop: float = 0.0,
        dropout_input: float = 0.0,
        conv_dim: tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512),
        conv_stride: tuple[int, ...] = (5, 2, 2, 2, 2, 2, 2),
        conv_kernel: tuple[int, ...] = (10, 3, 3, 3, 3, 2, 2),
        conv_bias: bool = False,
        conv_pos: int = 128,
        conv_pos_groups: int = 16,
        resnet_relu_type: str = "prelu",
        audio_feat_dim: int = 104,
        modality_fuse: str = "concat",
        decoder_embed_dim: int = 768,
        decoder_ffn_embed_dim: int = 3072,
        decoder_layers: int = 6,
        decoder_layerdrop: float = 0.0,
        decoder_attention_heads: int = 4,
        decoder_learned_pos: bool = False,
        decoder_normalize_before: bool = False,
        no_token_positional_embeddings: bool = False,
        decoder_dropout: float = 0.1,
        decoder_attention_dropout: float = 0.1,
        decoder_activation_dropout: float = 0.0,
        max_target_positions: int = 2048,
        share_decoder_input_output_embed: bool = False,
        no_scale_embedding: bool = True,
        sample_rate: int = 25,
        num_labels: int = 100,
        initializer_range: float = 0.02,
        do_stable_layer_norm: bool = False,
        vocab_size: int | None = None,
        freeze_feature_encoder: bool = False,
        freeze_base_model: bool = False,
        ctc_loss_reduction: str = "mean",
        ctc_zero_infinity: bool = False,
        ctc_loss_weight: float = 0.3,
        special_ids: list[int] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.label_rate = label_rate
        self.encoder_layers = encoder_layers
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.activation_fn = activation_fn
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.encoder_layerdrop = encoder_layerdrop
        self.dropout_input = dropout_input
        self.conv_dim = conv_dim
        self.conv_kernel = conv_kernel
        self.conv_stride = conv_stride
        self.conv_bias = conv_bias
        self.conv_pos = conv_pos
        self.conv_pos_groups = conv_pos_groups
        self.resnet_relu_type = resnet_relu_type
        self.audio_feat_dim = audio_feat_dim
        self.modality_fuse = modality_fuse
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_ffn_embed_dim = decoder_ffn_embed_dim
        self.decoder_layers = decoder_layers
        self.decoder_layerdrop = decoder_layerdrop
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_learned_pos = decoder_learned_pos
        self.decoder_normalize_before = decoder_normalize_before
        self.no_token_positional_embeddings = no_token_positional_embeddings
        self.decoder_dropout = decoder_dropout
        self.decoder_attention_dropout = decoder_attention_dropout
        self.decoder_activation_dropout = decoder_activation_dropout
        self.max_target_positions = max_target_positions
        self.share_decoder_input_output_embed = share_decoder_input_output_embed
        self.no_scale_embedding = no_scale_embedding
        self.sample_rate = sample_rate
        self.num_labels = num_labels
        self.initializer_range = initializer_range
        self.do_stable_layer_norm = do_stable_layer_norm
        self.vocab_size = vocab_size
        self.freeze_feature_encoder = freeze_feature_encoder
        self.freeze_base_model = freeze_base_model
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity
        self.ctc_loss_weight = ctc_loss_weight
        self.special_ids = special_ids

    @property
    def encoder_config(self) -> HubertConfig:
        return HubertConfig(
            hidden_size=self.encoder_embed_dim,
            num_hidden_layers=self.encoder_layers,
            num_attention_heads=self.encoder_attention_heads,
            intermediate_size=self.encoder_ffn_embed_dim,
            hidden_act=self.activation_fn,
            hidden_dropout=self.dropout,
            activation_dropout=self.activation_dropout,
            attention_dropout=self.attention_dropout,
            layerdrop=self.encoder_layerdrop,
            conv_dim=self.conv_dim,
            conv_kernel=self.conv_kernel,
            conv_stride=self.conv_stride,
            conv_bias=self.conv_bias,
            num_conv_pos_embeddings=self.conv_pos,
            num_conv_pos_embedding_groups=self.conv_pos_groups,
            feat_extract_activation="gelu",
            do_stable_layer_norm=self.do_stable_layer_norm,
            max_position_embeddings=self.max_target_positions,
            learned_pos=self.decoder_learned_pos,
            share_input_output_embed=self.share_decoder_input_output_embed,
        )

    @property
    def decoder_config(self) -> HubertConfig:
        return HubertConfig(
            hidden_size=self.decoder_embed_dim,
            num_hidden_layers=self.decoder_layers,
            num_attention_heads=self.decoder_attention_heads,
            intermediate_size=self.decoder_ffn_embed_dim,
            hidden_act=self.activation_fn,
            hidden_dropout=self.decoder_dropout,
            activation_dropout=self.decoder_activation_dropout,
            attention_dropout=self.decoder_attention_dropout,
            layerdrop=self.decoder_layerdrop,
            conv_dim=self.conv_dim,
            conv_kernel=self.conv_kernel,
            conv_stride=self.conv_stride,
            conv_bias=self.conv_bias,
            num_conv_pos_embeddings=self.conv_pos,
            num_conv_pos_embedding_groups=self.conv_pos_groups,
            feat_extract_activation="gelu",
            do_stable_layer_norm=self.do_stable_layer_norm,
            max_position_embeddings=self.max_target_positions,
            learned_pos=self.decoder_learned_pos,
            share_input_output_embed=self.share_decoder_input_output_embed,
        )
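A quick usage sketch for this configuration (not part of the upload; it assumes the file above is importable from the working directory). The `encoder_config` and `decoder_config` properties expose HubertConfig views that the modeling code can build its encoder and decoder stacks from:

from configuration_avhubert import AVHubertConfig  # the file above, assumed saved locally

config = AVHubertConfig()            # the default hyper-parameters listed above
enc_cfg = config.encoder_config      # HubertConfig view used for the encoder stack
dec_cfg = config.decoder_config      # HubertConfig view used for the decoder stack
print(enc_cfg.hidden_size, enc_cfg.num_hidden_layers)  # 768 12
print(dec_cfg.hidden_size, dec_cfg.num_hidden_layers)  # 768 6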
configuration_resnet.py ADDED
@@ -0,0 +1,17 @@
from transformers import PretrainedConfig


class ResEncoderConfig(PretrainedConfig):
    model_type = "modified_resnet"

    def __init__(
        self,
        relu_type="prelu",
        frontend_nout=64,
        backend_out=512,
        **kwargs,
    ):
        self.relu_type = relu_type
        self.frontend_nout = frontend_nout
        self.backend_out = backend_out
        super().__init__(**kwargs)
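Similarly, a tiny instantiation sketch for the visual front-end configuration (again not part of the upload; the file is assumed importable locally and the field meanings are inferred from their names):

from configuration_resnet import ResEncoderConfig

res_cfg = ResEncoderConfig()                   # defaults: prelu activation, frontend_nout=64, backend_out=512
print(res_cfg.relu_type, res_cfg.backend_out)  # prelu 512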
decoder.py ADDED
@@ -0,0 +1,1097 @@
from typing import Callable, Optional, Tuple, TypedDict, Union

import numpy as np
import torch
import torch.nn as nn
from transformers.cache_utils import Cache, EncoderDecoderCache, StaticCache
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_attention_mask,
    _prepare_4d_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.hubert.configuration_hubert import HubertConfig
from transformers.models.hubert.modeling_hubert import (
    HubertAttnAdapterLayer,
    HubertFeedForward,
    is_deepspeed_zero3_enabled,
)
from transformers.utils import is_torchdynamo_compiling, logging
from typing_extensions import Unpack

logger = logging.get_logger(__name__)


# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py#L428
class FlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for Flash Attention with Compile.

    Attributes:
        cumulative_seqlens_q (`torch.LongTensor`, *optional*)
            Gets cumulative sequence length for query state.
        cumulative_seqlens_k (`torch.LongTensor`, *optional*)
            Gets cumulative sequence length for key state.
        max_length_q (`int`, *optional*):
            Maximum sequence length for query state.
        max_length_k (`int`, *optional*):
            Maximum sequence length for key state.
    """

    cumulative_seqlens_q: Optional[torch.LongTensor]
    cumulative_seqlens_k: Optional[torch.LongTensor]
    max_length_q: Optional[int]
    max_length_k: Optional[int]


class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        weight = torch.empty(
            (
                config.max_position_embeddings,
                config.hidden_size,
            ),
            requires_grad=False,
        )
        self._init_sinusoidal_embedding(weight)
        self.register_buffer("position_embeddings", weight)

    def _init_sinusoidal_embedding(self, embeddings: torch.Tensor) -> None:
        T, D = embeddings.size()
        position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / D) for j in range(D)] for pos in range(T)])
        embeddings[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        embeddings[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))

    def forward(
        self,
        inputs: torch.Tensor,
        past_key_values_length: int = 0,  # Offset
        position_ids: torch.LongTensor | None = None,
    ) -> torch.Tensor:
        if position_ids is None:
            bsz, seq_len = inputs.shape[:2]
            position_ids = torch.arange(
                past_key_values_length,
                past_key_values_length + seq_len,
                dtype=torch.long,
                device=self.position_embeddings.device,
            ).expand(bsz, -1)
        else:
            position_ids = position_ids.unsqueeze(0)
        return self.position_embeddings[position_ids]
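An illustrative aside (not part of decoder.py): the table built by `_init_sinusoidal_embedding` follows the usual sinusoidal recipe, with sines in the even feature slots and cosines in the odd ones. A toy check, using a stand-in config object:

from types import SimpleNamespace
import torch
from decoder import SinusoidalPositionalEmbedding  # the class defined just above, assuming the file is saved as decoder.py

toy_cfg = SimpleNamespace(max_position_embeddings=4, hidden_size=8)  # toy sizes, for illustration only
pos = SinusoidalPositionalEmbedding(toy_cfg)
x = torch.zeros(1, 3, 8)      # (batch, seq, hidden); only the shape is read
emb = pos(x)                  # looks up positions 0..2
print(emb.shape)              # torch.Size([1, 3, 8])
print(emb[0, 0])              # position 0: sine slots are 0.0, cosine slots are 1.0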
# Copied from https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/bart/modeling_bart.py#L116
class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        # self.offset = 2
        # super().__init__(num_embeddings + self.offset, embedding_dim)
        super().__init__(num_embeddings, embedding_dim)

    def forward(
        self,
        input_ids: torch.Tensor,
        past_key_values_length: int = 0,
        position_ids: torch.LongTensor = None,
    ):
        """`input_ids' shape is expected to be [bsz x seqlen]."""

        if position_ids is None:
            bsz, seq_len = input_ids.shape[:2]
            position_ids = torch.arange(
                past_key_values_length,
                past_key_values_length + seq_len,
                dtype=torch.long,
                device=self.weight.device,
            ).expand(bsz, -1)
        else:
            position_ids = position_ids.unsqueeze(0)

        # return super().forward(positions + self.offset)
        return super().forward(position_ids)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
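An illustrative aside (not part of decoder.py): `eager_attention_forward` is plain scaled dot-product attention, so with no mask, no head mask and zero dropout it should match PyTorch's fused kernel up to the output layout. A small sanity check under those assumptions:

import torch
import torch.nn.functional as F
from decoder import eager_attention_forward  # the function defined just above, assuming the file is saved as decoder.py

q = torch.randn(1, 2, 4, 8)   # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

out, _ = eager_attention_forward(torch.nn.Module(), q, k, v, attention_mask=None)
ref = F.scaled_dot_product_attention(q, k, v)               # PyTorch reference, same 1/sqrt(head_dim) scaling
print(torch.allclose(out, ref.transpose(1, 2), atol=1e-5))  # True: same math, (batch, seq, heads, head_dim) layout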
class AVHubertAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[HubertConfig] = None,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        bsz, tgt_len = hidden_states.shape[:-1]
        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len

        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
        kv_input_shape = (bsz, src_len, -1, self.head_dim)

        # get query proj
        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

        if past_key_value is not None:
            if isinstance(past_key_value, EncoderDecoderCache):
                is_updated = past_key_value.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_value = past_key_value.cross_attention_cache
                else:
                    curr_past_key_value = past_key_value.self_attention_cache
            else:
                curr_past_key_value = past_key_value

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.key_cache[self.layer_idx]
            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            key_states = key_states.view(*kv_input_shape).transpose(1, 2)
            value_states = value_states.view(*kv_input_shape).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states,
                    value_states,
                    self.layer_idx,
                    {"cache_position": cache_position},
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        attention_interface: Callable = eager_attention_forward
        # TODO: attn implementation other than eager attention
        # if self.config._attn_implementation != "eager":
        #     attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            output_attentions=output_attentions,
            head_mask=layer_head_mask,
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights, past_key_value


class AVHubertDecoderLayer(nn.Module):
    def __init__(self, config: HubertConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.attention = AVHubertAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
            layer_idx=layer_idx,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.encoder_attn = AVHubertAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
            layer_idx=layer_idx,
        )

        self.encoder_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = HubertFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        if getattr(config, "adapter_attn_dim", None) is not None:
            self.adapter_layer = HubertAttnAdapterLayer(config)
        else:
            self.adapter_layer = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states, self_attn_weights, past_key_value = self.attention(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states, cross_attn_weights, _ = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            hidden_states = self.dropout(hidden_states)
            hidden_states = hidden_states + residual
            hidden_states = self.encoder_layer_norm(hidden_states)

        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        if self.adapter_layer is not None:
            hidden_states = hidden_states + self.adapter_layer(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (past_key_value,)

        return outputs


class AVHubertDecoderLayerStableLayerNorm(nn.Module):
    def __init__(self, config: HubertConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.attention = AVHubertAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
            layer_idx=layer_idx,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.encoder_attn = AVHubertAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
            layer_idx=layer_idx,
        )
        self.encoder_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = HubertFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        if getattr(config, "adapter_attn_dim", None) is not None:
            self.adapter_layer = HubertAttnAdapterLayer(config)
        else:
            self.adapter_layer = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)

        hidden_states, self_attn_weights, past_key_value = self.attention(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = residual + hidden_states

        # Cross-Attention Block
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states = self.encoder_layer_norm(hidden_states)

            hidden_states, cross_attn_weights, _ = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            hidden_states = self.dropout(hidden_states)
            hidden_states = hidden_states + residual

        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

        if self.adapter_layer is not None:
            hidden_states = hidden_states + self.adapter_layer(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (past_key_value,)

        return outputs


class AVHubertDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        if config.learned_pos:
            self.pos_embed = LearnedPositionalEmbedding(
                num_embeddings=config.max_position_embeddings,
                embedding_dim=config.hidden_size,
            )
        else:
            self.pos_embed = SinusoidalPositionalEmbedding(config=config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList([AVHubertDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        head_mask: torch.Tensor | None = None,
        cross_attn_head_mask: torch.Tensor | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        use_cache: bool | None = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        cache_position: torch.LongTensor | None = None,
    ):
        input_shape = inputs_embeds.shape[:-1]
        if use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache.from_legacy_cache()
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        batch_size, seq_length = inputs_embeds.size()[:-1]
        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length,
                past_key_values_length + seq_length,
                device=inputs_embeds.device,
            )

        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past cache
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        self_attn_cache = (
            past_key_values.self_attention_cache
            if isinstance(past_key_values, EncoderDecoderCache)
            else past_key_values
        )

        attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, self_attn_cache)
        encoder_attention_mask = self._update_cross_attn_mask(
            encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds
        )

        # embed positions
        position_embeddings = self.pos_embed(inputs_embeds, past_key_values_length, position_ids=cache_position)
        hidden_states = inputs_embeds + position_embeddings
        hidden_states = self.dropout(hidden_states)

        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != (len(self.layers)):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {head_mask.size()[0]}."
                    )

        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        encoder_hidden_states,
                        encoder_attention_mask,
                        output_attentions,
                    )
                    raise NotImplementedError("Currently, gradient checkpointing is not supported.")
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_hidden_states,
                        encoder_attention_mask=encoder_attention_mask,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        cross_attn_layer_head_mask=(
                            cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                        ),
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None, None, None)

            if use_cache:
                next_decoder_cache = layer_outputs[3 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )

    def _update_causal_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
    ):
        if self.config._attn_implementation == "flex_attention":
            raise NotImplementedError

        if self.config._attn_implementation == "flash_attention_2":
            raise NotImplementedError

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_compilable_cache = True if isinstance(past_key_values, StaticCache) else False

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_compilable_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=cache_position.device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask

    def _update_cross_attn_mask(
        self,
        encoder_hidden_states: Union[torch.Tensor, None],
        encoder_attention_mask: Union[torch.Tensor, None],
        input_shape: torch.Size,
        inputs_embeds: torch.Tensor,
    ):
        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self.config._attn_implementation == "flash_attention_2":
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            elif self.config._attn_implementation == "sdpa":
                # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
                # the manual implementation that requires a 4D causal mask in all cases.
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask,
                    inputs_embeds.dtype,
                    tgt_len=input_shape[-1],
                )
            elif self.config._attn_implementation == "flex_attention":
                raise NotImplementedError
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        return encoder_attention_mask
779
+
780
+
781
+ class AVHubertDecoderStableLayerNorm(nn.Module):
782
+ def __init__(self, config):
783
+ super().__init__()
784
+ self.config = config
785
+ if config.learned_pos:
786
+ self.pos_embed = LearnedPositionalEmbedding(
787
+ num_embeddings=config.max_position_embeddings,
788
+ embedding_dim=config.hidden_size,
789
+ )
790
+ else:
791
+ self.pos_embed = SinusoidalPositionalEmbedding(config=config)
792
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
793
+ self.dropout = nn.Dropout(config.hidden_dropout)
794
+ self.layers = nn.ModuleList(
795
+ [
796
+ AVHubertDecoderLayerStableLayerNorm(config, layer_idx=layer_idx)
797
+ for layer_idx in range(config.num_hidden_layers)
798
+ ]
799
+ )
800
+ self.gradient_checkpointing = False
801
+
802
+ def forward(
803
+ self,
804
+ inputs_embeds: torch.Tensor | None = None,
805
+ attention_mask: torch.Tensor | None = None,
806
+ encoder_hidden_states: torch.Tensor | None = None,
807
+ encoder_attention_mask: torch.Tensor | None = None,
808
+ head_mask: torch.Tensor | None = None,
809
+ cross_attn_head_mask: torch.Tensor | None = None,
810
+ past_key_values: EncoderDecoderCache | None = None,
811
+ use_cache: bool | None = None,
812
+ output_attentions: bool = False,
813
+ output_hidden_states: bool = False,
814
+ return_dict: bool = True,
815
+ cache_position: torch.LongTensor | None = None,
816
+ ):
817
+ input_shape = inputs_embeds.shape[:-1]
818
+ if use_cache and past_key_values is None:
819
+ past_key_values = EncoderDecoderCache.from_legacy_cache()
820
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
821
+ batch_size, seq_length = inputs_embeds.size()[:-1]
822
+ if cache_position is None:
823
+ cache_position = torch.arange(
824
+ past_key_values_length,
825
+ past_key_values_length + seq_length,
826
+ device=inputs_embeds.device,
827
+ )
828
+
829
+ if attention_mask is None and not is_torchdynamo_compiling():
830
+ # required mask seq length can be calculated via length of past cache
831
+ mask_seq_length = past_key_values_length + seq_length
832
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
833
+
834
+ self_attn_cache = (
835
+ past_key_values.self_attention_cache
836
+ if isinstance(past_key_values, EncoderDecoderCache)
837
+ else past_key_values
838
+ )
839
+
840
+ attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, self_attn_cache)
841
+ encoder_attention_mask = self._update_cross_attn_mask(
842
+ encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds
843
+ )
844
+
845
+ # embed positions
846
+ position_embeddings = self.pos_embed(inputs_embeds, past_key_values_length, position_ids=cache_position)
847
+ hidden_states = inputs_embeds + position_embeddings
848
+ hidden_states = self.dropout(hidden_states)
849
+
850
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
851
+
852
+ # decoder layers
853
+ all_hidden_states = () if output_hidden_states else None
854
+ all_self_attns = () if output_attentions else None
855
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
856
+ next_decoder_cache = None
857
+
858
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
859
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
860
+ if attn_mask is not None:
861
+ if attn_mask.size()[0] != (len(self.layers)):
862
+ raise ValueError(
863
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
864
+ f" {head_mask.size()[0]}."
865
+ )
866
+
867
+ for idx, layer in enumerate(self.layers):
868
+ if output_hidden_states:
869
+ all_hidden_states += (hidden_states,)
870
+
871
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
872
+ dropout_probability = torch.rand([])
873
+
874
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
875
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
876
+ # under deepspeed zero3 all gpus must run in sync
877
+ # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
878
+ if self.gradient_checkpointing and self.training:
879
+ layer_outputs = self._gradient_checkpointing_func(
880
+ layer.__call__,
881
+ hidden_states,
882
+ attention_mask,
883
+ encoder_hidden_states,
884
+ encoder_attention_mask,
885
+ output_attentions,
886
+ )
887
+ raise NotImplementedError("Currently, gradient checkpointing is not supported.")
888
+ else:
889
+ layer_outputs = layer(
890
+ hidden_states,
891
+ attention_mask=attention_mask,
892
+ encoder_hidden_states=encoder_hidden_states,
893
+ encoder_attention_mask=encoder_attention_mask,
894
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
895
+ cross_attn_layer_head_mask=(
896
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
897
+ ),
898
+ past_key_value=past_key_values,
899
+ output_attentions=output_attentions,
900
+ use_cache=use_cache,
901
+ cache_position=cache_position,
902
+ )
903
+ hidden_states = layer_outputs[0]
904
+
905
+ if skip_the_layer:
906
+ layer_outputs = (None, None, None, None)
907
+
908
+ if use_cache:
909
+ next_decoder_cache = layer_outputs[3 if output_attentions else 1]
910
+
911
+ if output_attentions:
912
+ all_self_attns += (layer_outputs[1],)
913
+
914
+ if encoder_hidden_states is not None:
915
+ all_cross_attentions += (layer_outputs[2],)
916
+
917
+ hidden_states = self.layer_norm(hidden_states)
918
+
919
+ # add hidden states from the last decoder layer
920
+ if output_hidden_states:
921
+ all_hidden_states += (hidden_states,)
922
+
923
+ next_cache = next_decoder_cache if use_cache else None
924
+
925
+ if output_hidden_states:
926
+ all_hidden_states = all_hidden_states + (hidden_states,)
927
+
928
+ if not return_dict:
929
+ return tuple(
930
+ v
931
+ for v in [
932
+ hidden_states,
933
+ next_cache,
934
+ all_hidden_states,
935
+ all_self_attns,
936
+ all_cross_attentions,
937
+ ]
938
+ if v is not None
939
+ )
940
+ return BaseModelOutputWithPastAndCrossAttentions(
941
+ last_hidden_state=hidden_states,
942
+ past_key_values=next_cache,
943
+ hidden_states=all_hidden_states,
944
+ attentions=all_self_attns,
945
+ cross_attentions=all_cross_attentions,
946
+ )
947
+
948
+ def _update_causal_mask(
949
+ self,
950
+ attention_mask: Optional[torch.Tensor],
951
+ input_tensor: torch.Tensor,
952
+ cache_position: torch.Tensor,
953
+ past_key_values: Cache,
954
+ ):
955
+ if self.config._attn_implementation == "flex_attention":
956
+ raise NotImplementedError
957
+
958
+ if self.config._attn_implementation == "flash_attention_2":
959
+ raise NotImplementedError
960
+
961
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
962
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
963
+ # to infer the attention mask.
964
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
965
+ using_compilable_cache = True if isinstance(past_key_values, StaticCache) else False
966
+
967
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
968
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
969
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
970
+ attention_mask,
971
+ inputs_embeds=input_tensor,
972
+ past_key_values_length=past_seen_tokens,
973
+ is_training=self.training,
974
+ ):
975
+ return None
976
+
977
+ dtype = input_tensor.dtype
978
+ sequence_length = input_tensor.shape[1]
979
+ if using_compilable_cache:
980
+ target_length = past_key_values.get_max_cache_shape()
981
+ else:
982
+ target_length = (
983
+ attention_mask.shape[-1]
984
+ if isinstance(attention_mask, torch.Tensor)
985
+ else past_seen_tokens + sequence_length + 1
986
+ )
987
+
988
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
989
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
990
+ attention_mask,
991
+ sequence_length=sequence_length,
992
+ target_length=target_length,
993
+ dtype=dtype,
994
+ cache_position=cache_position,
995
+ batch_size=input_tensor.shape[0],
996
+ )
997
+
998
+ if (
999
+ self.config._attn_implementation == "sdpa"
1000
+ and attention_mask is not None
1001
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
1002
+ ):
1003
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1004
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1005
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1006
+ min_dtype = torch.finfo(dtype).min
1007
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1008
+
1009
+ return causal_mask
1010
+
1011
+ @staticmethod
1012
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1013
+ attention_mask: torch.Tensor,
1014
+ sequence_length: int,
1015
+ target_length: int,
1016
+ dtype: torch.dtype,
1017
+ cache_position: torch.Tensor,
1018
+ batch_size: int,
1019
+ **kwargs,
1020
+ ):
1021
+ """
1022
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1023
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1024
+
1025
+ Args:
1026
+ attention_mask (`torch.Tensor`):
1027
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1028
+ `(batch_size, 1, query_length, key_value_length)`.
1029
+ sequence_length (`int`):
1030
+ The sequence length being processed.
1031
+ target_length (`int`):
1032
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1033
+ to account for the 0 padding, the part of the cache that is not filled yet.
1034
+ dtype (`torch.dtype`):
1035
+ The dtype to use for the 4D attention mask.
1036
+ cache_position (`torch.Tensor`):
1037
+ Indices depicting the position of the input sequence tokens in the sequence.
1038
+ batch_size (`torch.Tensor`):
1039
+ Batch size.
1040
+ """
1041
+ if attention_mask is not None and attention_mask.dim() == 4:
1042
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1043
+ causal_mask = attention_mask
1044
+ else:
1045
+ min_dtype = torch.finfo(dtype).min
1046
+ causal_mask = torch.full(
1047
+ (sequence_length, target_length),
1048
+ fill_value=min_dtype,
1049
+ dtype=dtype,
1050
+ device=cache_position.device,
1051
+ )
1052
+ if sequence_length != 1:
1053
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1054
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
1055
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1056
+ if attention_mask is not None:
1057
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1058
+ mask_length = attention_mask.shape[-1]
1059
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
1060
+ causal_mask.device
1061
+ )
1062
+ padding_mask = padding_mask == 0
1063
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1064
+ padding_mask, min_dtype
1065
+ )
1066
+
1067
+ return causal_mask
1068
+
1069
+ def _update_cross_attn_mask(
1070
+ self,
1071
+ encoder_hidden_states: Union[torch.Tensor, None],
1072
+ encoder_attention_mask: Union[torch.Tensor, None],
1073
+ input_shape: torch.Size,
1074
+ inputs_embeds: torch.Tensor,
1075
+ ):
1076
+ # expand encoder attention mask
1077
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1078
+ if self.config._attn_implementation == "flash_attention_2":
1079
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
1080
+ elif self.config._attn_implementation == "sdpa":
1081
+ # output_attentions=True & cross_attn_head_mask cannot be supported when using SDPA; we fall back on
1082
+ # the manual implementation that requires a 4D causal mask in all cases.
1083
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1084
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1085
+ encoder_attention_mask,
1086
+ inputs_embeds.dtype,
1087
+ tgt_len=input_shape[-1],
1088
+ )
1089
+ elif self.config._attn_implementation == "flex_attention":
1090
+ raise NotImplementedError
1091
+ else:
1092
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1093
+ encoder_attention_mask = _prepare_4d_attention_mask(
1094
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1095
+ )
1096
+
1097
+ return encoder_attention_mask
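
As a quick sanity check on the mask logic above, here is a minimal standalone sketch of what `_prepare_4d_causal_attention_mask_with_cache_position` computes for a 2D padding mask; the batch size, lengths, and token layout are illustrative assumptions, not values used by the model.

```python
import torch


def build_causal_mask(attention_mask, sequence_length, target_length, dtype, cache_position, batch_size):
    """Simplified re-implementation of the 4D causal-mask construction above."""
    min_dtype = torch.finfo(dtype).min
    # Start from a fully blocked upper triangle: 0.0 = attend, min_dtype = blocked.
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    causal_mask = torch.triu(causal_mask, diagonal=1)
    # Allow every key position at or before the current cache position.
    causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    # Fold in the 2D padding mask (1 = real token, 0 = padding).
    mask_length = attention_mask.shape[-1]
    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        padding_mask == 0, min_dtype
    )
    return causal_mask


# One sequence of 4 query tokens, no cache, last key position is padding.
mask = build_causal_mask(
    attention_mask=torch.tensor([[1, 1, 1, 0]]),
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    cache_position=torch.arange(4),
    batch_size=1,
)
print(mask.shape)  # torch.Size([1, 1, 4, 4])
```

Blocked positions carry the dtype minimum so they vanish after softmax, while allowed positions stay at zero.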
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 3000,
4
+ "eos_token_id": 3001,
5
+ "pad_token_id": 3002,
6
+ "transformers_version": "4.53.3"
7
+ }
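
The ids above (3000/3001/3002) are the BOS/EOS/PAD tokens the decoder relies on; since `pad_token_id` is used as the embedding's `padding_idx` in `modeling_avhubert.py`, the configured `vocab_size` presumably has to be at least 3003. A small sketch of reading the file back through `transformers` (the repo id is a placeholder):

```python
from transformers import GenerationConfig

# Placeholder repository id; substitute the actual model repo.
gen_config = GenerationConfig.from_pretrained("your-namespace/av_hubert")
print(gen_config.bos_token_id, gen_config.eos_token_id, gen_config.pad_token_id)
# Expected from generation_config.json: 3000 3001 3002
```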
modeling_avhubert.py ADDED
@@ -0,0 +1,391 @@
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+ import einops
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from transformers import PreTrainedModel
10
+ from transformers.cache_utils import StaticCache
11
+ from transformers.generation import GenerationMixin
12
+ from transformers.generation.utils import GenerationConfig, GenerationMode
13
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
14
+ from transformers.modeling_outputs import Seq2SeqLMOutput
15
+ from transformers.models.hubert.modeling_hubert import (
16
+ HubertEncoder,
17
+ HubertEncoderStableLayerNorm,
18
+ )
19
+ from transformers.utils import ModelOutput
20
+
21
+ from .configuration_avhubert import AVHubertConfig
22
+ from .configuration_resnet import ResEncoderConfig
23
+ from .decoder import AVHubertDecoder, AVHubertDecoderStableLayerNorm
24
+ from .modeling_resnet import ResEncoder
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ NEED_SETUP_CACHE_CLASSES_MAPPING = {
29
+ "static": StaticCache,
30
+ }
31
+
32
+
33
+ @dataclass
34
+ class AVHubertOutput:
35
+ last_hidden_state: Optional[torch.Tensor] = None
36
+ hidden_states: Optional[torch.Tensor] = None
37
+ attentions: Optional[torch.Tensor] = None
38
+
39
+
40
+ class AudioFeatureExtractor(nn.Module):
41
+ def __init__(self, input_dim: int, output_dim: int) -> None:
42
+ super(AudioFeatureExtractor, self).__init__()
43
+ self.proj = nn.Linear(in_features=input_dim, out_features=output_dim)
44
+
45
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
46
+ x = self.proj(x) # [B, T, F]
47
+ return einops.rearrange(x, "b t f -> b f t") # [B, F, T]
48
+
49
+
50
+ class VideoFeatureExtractor(nn.Module):
51
+ def __init__(self, config: ResEncoderConfig, output_dim: int) -> None:
52
+ super(VideoFeatureExtractor, self).__init__()
53
+ self.resnet = ResEncoder(config=config)
54
+ self.proj = nn.Linear(
55
+ in_features=self.resnet.backend_out,
56
+ out_features=output_dim,
57
+ )
58
+
59
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
60
+ x = self.resnet(einops.rearrange(x, "b t c h w -> b c t h w")) # [B, F, T]
61
+ x = self.proj(einops.rearrange(x, "b f t -> b t f")) # [B, T, F]
62
+ return einops.rearrange(x, "b t f -> b f t") # [B, F, T]
63
+
64
+
65
+ class AVHubertPreTrainedModel(PreTrainedModel):
66
+ """
67
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
68
+ models.
69
+ """
70
+
71
+ config_class = AVHubertConfig
72
+ base_model_prefix = "avhubert"
73
+ supports_gradient_checkpointing = False
74
+
75
+ def _init_weights(self, module):
76
+ """Initialize the weights"""
77
+ if isinstance(module, (nn.Linear, nn.Embedding)):
78
+ # Slightly different from the TF version which uses truncated_normal for initialization
79
+ # cf https://github.com/pytorch/pytorch/pull/5617
80
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
81
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
82
+ module.bias.data.zero_()
83
+ module.weight.data.fill_(1.0)
84
+ elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
85
+ if is_deepspeed_zero3_enabled():
86
+ import deepspeed
87
+
88
+ if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
89
+ with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
90
+ nn.init.kaiming_normal_(module.weight.data)
91
+ else:
92
+ with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
93
+ nn.init.kaiming_normal_(module.weight.data)
94
+ else:
95
+ if hasattr(module, "parametrizations"):
96
+ nn.init.kaiming_normal_(module.parametrizations.weight.original0.data)
97
+ nn.init.kaiming_normal_(module.parametrizations.weight.original1.data)
98
+ nn.init.kaiming_normal_(module.weight.data)
99
+
100
+ if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)) and module.bias is not None:
101
+ module.bias.data.zero_()
102
+
103
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
104
+ """
105
+ Computes the output length of the convolutional layers
106
+ """
107
+
108
+ def _conv_out_length(input_length, kernel_size, stride):
109
+ # 1D convolutional layer output length formula taken
110
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
111
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
112
+
113
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
114
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
115
+
116
+ return input_lengths
117
+
118
+
119
+ class AVHubertModel(AVHubertPreTrainedModel):
120
+ def __init__(self, config: AVHubertConfig, **kwargs):
121
+ super().__init__(config, **kwargs)
122
+ self.config = config
123
+ self.feat2tar_ratio = config.label_rate / config.sample_rate
124
+
125
+ # feature extractor
126
+ resnet_config = ResEncoderConfig(relu_type=config.resnet_relu_type)
127
+ self.feature_extractor_audio = AudioFeatureExtractor(
128
+ input_dim=config.audio_feat_dim,
129
+ output_dim=config.encoder_embed_dim,
130
+ )
131
+ self.feature_extractor_video = VideoFeatureExtractor(config=resnet_config, output_dim=config.encoder_embed_dim)
132
+
133
+ self.encoder_embed_dim = config.encoder_embed_dim
134
+ if config.modality_fuse == "concat":
135
+ embed = config.encoder_embed_dim * 2
136
+ elif config.modality_fuse == "add":
137
+ embed = config.encoder_embed_dim
138
+ self.post_extract_proj = (
139
+ nn.Linear(embed, config.encoder_embed_dim) if embed != config.encoder_embed_dim else None
140
+ )
141
+
142
+ # dropout
143
+ self.dropout_input = nn.Dropout(config.dropout_input)
144
+
145
+ # transformer encoder
146
+ transformer_config = config.encoder_config
147
+ if transformer_config.do_stable_layer_norm:
148
+ self.encoder = HubertEncoderStableLayerNorm(config=transformer_config)
149
+ else:
150
+ self.encoder = HubertEncoder(config=transformer_config)
151
+ self.layer_norm = nn.LayerNorm(embed)
152
+
153
+ def forward_mask(self, features: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
154
+ extra = attention_mask.size(1) % features.size(1)
155
+ if extra > 0:
156
+ attention_mask = attention_mask[:, :-extra]
157
+ attention_mask = attention_mask.view(attention_mask.size(0), features.size(1), -1)
158
+ attention_mask = attention_mask.all(-1)
159
+ return attention_mask
160
+
161
+ def forward(
162
+ self,
163
+ input_values: Optional[torch.Tensor] = None,
164
+ pixel_values: Optional[torch.Tensor] = None,
165
+ padding_mask: Optional[torch.Tensor] = None,
166
+ output_attentions: bool = False,
167
+ output_hidden_states: bool = False,
168
+ **kwargs,
169
+ ) -> ModelOutput:
170
+ if input_values is not None and pixel_values is None:
171
+ features_audio = self.feature_extractor_audio(input_values) # [B, F, T]
172
+ features_video = torch.zeros_like(features_audio) # [B, F, T]
173
+ elif input_values is None and pixel_values is not None:
174
+ features_video = self.feature_extractor_video(pixel_values) # [B, F, T]
175
+ features_audio = torch.zeros_like(features_video) # [B, F, T]
176
+ elif input_values is not None and pixel_values is not None:
177
+ features_audio = self.feature_extractor_audio(input_values) # [B, F, T]
178
+ features_video = self.feature_extractor_video(pixel_values) # [B, F, T]
179
+ else:
180
+ raise ValueError("Either `input_values` or `pixel_values` must be passed")
181
+
182
+ # fuse modality
183
+ if self.config.modality_fuse == "concat":
184
+ features = torch.cat([features_audio, features_video], dim=1)
185
+ elif self.config.modality_fuse == "add":
186
+ features = features_audio + features_video
187
+
188
+ features = features.transpose(1, 2)
189
+ features = self.layer_norm(features)
190
+
191
+ if padding_mask is not None:
192
+ padding_mask = self.forward_mask(features, padding_mask)
193
+ else:
194
+ padding_mask = torch.zeros(features.size()[:2], dtype=torch.bool, device=features.device)
195
+
196
+ if self.post_extract_proj is not None:
197
+ features = self.post_extract_proj(features)
198
+
199
+ features = self.dropout_input(features)
200
+
201
+ # transformer encoder
202
+ encoder_out = self.encoder(
203
+ hidden_states=features,
204
+ attention_mask=~padding_mask.bool(),
205
+ output_attentions=output_attentions,
206
+ output_hidden_states=output_hidden_states,
207
+ )
208
+
209
+ return AVHubertOutput(
210
+ last_hidden_state=encoder_out.last_hidden_state,
211
+ hidden_states=encoder_out.hidden_states,
212
+ attentions=encoder_out.attentions,
213
+ )
214
+
215
+
216
+ class AVHubertForConditionalGeneration(AVHubertPreTrainedModel, GenerationMixin):
217
+ def __init__(
218
+ self,
219
+ config: AVHubertConfig,
220
+ **kwargs,
221
+ ) -> None:
222
+ super().__init__(config=config, **kwargs)
223
+ self.config = config
224
+
225
+ self.avhubert = AVHubertModel(config=config)
226
+ if config.freeze_base_model:
227
+ self.freeze_base_model()
228
+ if config.freeze_feature_encoder:
229
+ self.freeze_feature_encoder()
230
+
231
+ if config.vocab_size is None:
232
+ raise ValueError(
233
+ f"You are trying to instantiate {self.__class__} with a configuration that "
234
+ "does not define the vocabulary size of the language model head. Please "
235
+ "instantiate the model as follows: `AVHubertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
236
+ "or define `vocab_size` of your model's configuration."
237
+ )
238
+
239
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.decoder_embed_dim, padding_idx=config.pad_token_id)
240
+ transformer_config = config.decoder_config
241
+ if transformer_config.do_stable_layer_norm:
242
+ self.decoder = AVHubertDecoderStableLayerNorm(config=transformer_config)
243
+ else:
244
+ self.decoder = AVHubertDecoder(config=transformer_config)
245
+
246
+ self.lm_head = nn.Linear(config.decoder_embed_dim, config.vocab_size, bias=False)
247
+ if config.share_decoder_input_output_embed:
248
+ # If the decoder shares input and output embeddings, the LM head weight is tied
249
+ # to the token embedding matrix; during training, `forward` applies the token
250
+ # embeddings directly (via F.linear) instead of calling the LM head module.
251
+ self.lm_head.weight = self.embed_tokens.weight
252
+ else:
253
+ nn.init.normal_(self.lm_head.weight, mean=0, std=config.decoder_embed_dim**-0.5)
254
+
255
+ self.post_init()
256
+
257
+ def freeze_feature_encoder(self):
258
+ """
259
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
260
+ not be updated during training.
261
+ """
262
+ for param in self.avhubert.feature_extractor_audio.parameters():
263
+ param.requires_grad = False
264
+ for param in self.avhubert.feature_extractor_video.parameters():
265
+ param.requires_grad = False
266
+
267
+ def freeze_base_model(self):
268
+ """
269
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
270
+ be updated during training. Only the classification head will be updated.
271
+ """
272
+ for param in self.avhubert.parameters():
273
+ param.requires_grad = False
274
+
275
+ def get_encoder(self):
276
+ return self.avhubert
277
+
278
+ def forward(
279
+ self,
280
+ input_values: Optional[torch.Tensor] = None,
281
+ pixel_values: Optional[torch.Tensor] = None,
282
+ padding_mask: Optional[torch.Tensor] = None,
283
+ decoder_input_ids: Optional[torch.Tensor] = None,
284
+ decoder_attention_mask: Optional[torch.Tensor] = None,
285
+ labels: Optional[torch.Tensor] = None,
286
+ output_attentions: bool = False,
287
+ output_hidden_states: bool = False,
288
+ return_dict: bool = True,
289
+ ) -> ModelOutput:
290
+ encoder_outs = self.avhubert(
291
+ input_values=input_values,
292
+ pixel_values=pixel_values,
293
+ padding_mask=padding_mask,
294
+ output_attentions=output_attentions,
295
+ output_hidden_states=output_hidden_states,
296
+ )
297
+
298
+ embed_tokens = self.embed_tokens(decoder_input_ids)
299
+ hidden_states = self.decoder(
300
+ inputs_embeds=embed_tokens,
301
+ attention_mask=decoder_attention_mask,
302
+ encoder_hidden_states=encoder_outs.last_hidden_state,
303
+ encoder_attention_mask=~padding_mask.bool(),
304
+ output_attentions=output_attentions,
305
+ output_hidden_states=output_hidden_states,
306
+ )
307
+
308
+ if self.config.share_decoder_input_output_embed:
309
+ logits = F.linear(hidden_states.last_hidden_state, weight=self.embed_tokens.weight)
310
+ else:
311
+ logits = self.lm_head(hidden_states.last_hidden_state)
312
+
313
+ loss = None
314
+ if labels is not None:
315
+ loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
316
+ loss = loss_fn(logits.view(-1, self.config.vocab_size), labels.reshape(-1))
317
+
318
+ return Seq2SeqLMOutput(
319
+ loss=loss,
320
+ logits=logits,
321
+ past_key_values=None,
322
+ decoder_hidden_states=hidden_states.hidden_states,
323
+ decoder_attentions=hidden_states.attentions,
324
+ cross_attentions=None,
325
+ encoder_last_hidden_state=encoder_outs.last_hidden_state,
326
+ encoder_hidden_states=encoder_outs.hidden_states,
327
+ encoder_attentions=encoder_outs.attentions,
328
+ )
329
+
330
+ def _get_generation_mode(
331
+ self,
332
+ generation_config: GenerationConfig,
333
+ assistant_model: PreTrainedModel | None,
334
+ ) -> GenerationMode:
335
+ """
336
+ Returns the generation mode triggered by a [`GenerationConfig`] instance.
337
+ """
338
+ if generation_config.constraints is not None or generation_config.force_words_ids is not None:
339
+ generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
340
+ elif generation_config.num_beams == 1:
341
+ if generation_config.do_sample is False:
342
+ if (
343
+ generation_config.top_k is not None
344
+ and generation_config.top_k > 1
345
+ and generation_config.penalty_alpha is not None
346
+ and generation_config.penalty_alpha > 0
347
+ ):
348
+ generation_mode = GenerationMode.CONTRASTIVE_SEARCH
349
+ else:
350
+ generation_mode = GenerationMode.GREEDY_SEARCH
351
+ else:
352
+ generation_mode = GenerationMode.SAMPLE
353
+ else:
354
+ if generation_config.num_beam_groups > 1:
355
+ generation_mode = GenerationMode.GROUP_BEAM_SEARCH
356
+ elif generation_config.do_sample is True:
357
+ generation_mode = GenerationMode.BEAM_SAMPLE
358
+ else:
359
+ generation_mode = GenerationMode.BEAM_SEARCH
360
+
361
+ # Assisted generation may extend some generation modes
362
+ if assistant_model is not None or generation_config.prompt_lookup_num_tokens is not None:
363
+ if generation_mode in ("greedy_search", "sample"):
364
+ generation_mode = GenerationMode.ASSISTED_GENERATION
365
+ else:
366
+ raise ValueError(
367
+ "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
368
+ "is only supported with Greedy Search and Sample."
369
+ )
370
+ return generation_mode
371
+
372
+ def prepare_inputs_for_generation(
373
+ self,
374
+ input_ids: Optional[torch.Tensor] = None,
375
+ input_values: Optional[torch.Tensor] = None,
376
+ pixel_values: Optional[torch.Tensor] = None,
377
+ decoder_input_ids: Optional[torch.Tensor] = None,
378
+ decoder_attention_mask: Optional[torch.Tensor] = None,
379
+ padding_mask: Optional[torch.Tensor] = None,
380
+ **kwargs,
381
+ ):
382
+ if decoder_input_ids is None:
383
+ decoder_input_ids = input_ids
384
+ decoder_attention_mask = torch.ones_like(input_ids)
385
+ return {
386
+ "input_values": input_values,
387
+ "pixel_values": pixel_values,
388
+ "decoder_input_ids": decoder_input_ids,
389
+ "decoder_attention_mask": decoder_attention_mask,
390
+ "padding_mask": padding_mask,
391
+ }
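
For orientation, a dummy teacher-forced call against the `forward` signature defined above; `model` is assumed to be an already-instantiated `AVHubertForConditionalGeneration`, and every shape (25 video frames of 88x88 single-channel crops, 8 decoder tokens) is an illustrative assumption rather than a documented preprocessing contract.

```python
import torch

batch, frames, seq_len = 1, 25, 8

# Video-only input: [B, T, C, H, W]; the audio branch is zero-filled internally.
pixel_values = torch.randn(batch, frames, 1, 88, 88)
# True marks padded frames; here every frame is real.
padding_mask = torch.zeros(batch, frames, dtype=torch.bool)

decoder_input_ids = torch.full((batch, seq_len), 3000)    # bos_token_id from generation_config.json
decoder_attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
labels = torch.randint(0, 100, (batch, seq_len))           # dummy targets

outputs = model(
    pixel_values=pixel_values,
    padding_mask=padding_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
    labels=labels,
)
print(outputs.logits.shape)  # [batch, seq_len, vocab_size]
print(outputs.loss)          # cross-entropy with label smoothing 0.1
```

Inference would instead go through `model.generate(pixel_values=..., padding_mask=...)`, which routes through `prepare_inputs_for_generation` above; whether that works out of the box also depends on configuration fields (e.g. `is_encoder_decoder`) that are not part of this upload.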
modeling_resnet.py ADDED
@@ -0,0 +1,178 @@
1
+ import math
2
+
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel
5
+
6
+ from .configuration_resnet import ResEncoderConfig
7
+
8
+
9
+ def conv3x3(in_planes, out_planes, stride=1):
10
+ return nn.Conv2d(
11
+ in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
12
+ )
13
+
14
+
15
+ def downsample_basic_block(inplanes, outplanes, stride):
16
+ return nn.Sequential(
17
+ nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False),
18
+ nn.BatchNorm2d(outplanes),
19
+ )
20
+
21
+
22
+ def downsample_basic_block_v2(inplanes, outplanes, stride):
23
+ return nn.Sequential(
24
+ nn.AvgPool2d(
25
+ kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False
26
+ ),
27
+ nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False),
28
+ nn.BatchNorm2d(outplanes),
29
+ )
30
+
31
+
32
+ class BasicBlock(nn.Module):
33
+ expansion = 1
34
+
35
+ def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type="relu"):
36
+ super(BasicBlock, self).__init__()
37
+
38
+ assert relu_type in ["relu", "prelu"]
39
+
40
+ self.conv1 = conv3x3(inplanes, planes, stride)
41
+ self.bn1 = nn.BatchNorm2d(planes)
42
+
43
+ if relu_type == "relu":
44
+ self.relu1 = nn.ReLU(inplace=True)
45
+ self.relu2 = nn.ReLU(inplace=True)
46
+ elif relu_type == "prelu":
47
+ self.relu1 = nn.PReLU(num_parameters=planes)
48
+ self.relu2 = nn.PReLU(num_parameters=planes)
49
+ else:
50
+ raise Exception("relu type not implemented")
51
+
52
+ self.conv2 = conv3x3(planes, planes)
53
+ self.bn2 = nn.BatchNorm2d(planes)
54
+
55
+ self.downsample = downsample
56
+ self.stride = stride
57
+
58
+ def forward(self, x):
59
+ residual = x
60
+ out = self.conv1(x)
61
+ out = self.bn1(out)
62
+ out = self.relu1(out)
63
+ out = self.conv2(out)
64
+ out = self.bn2(out)
65
+ if self.downsample is not None:
66
+ residual = self.downsample(x)
67
+
68
+ out += residual
69
+ out = self.relu2(out)
70
+
71
+ return out
72
+
73
+
74
+ class ResNet(nn.Module):
75
+ def __init__(
76
+ self,
77
+ block,
78
+ layers,
79
+ num_classes=1000,
80
+ relu_type="relu",
81
+ gamma_zero=False,
82
+ avg_pool_downsample=False,
83
+ ):
84
+ self.inplanes = 64
85
+ self.relu_type = relu_type
86
+ self.gamma_zero = gamma_zero
87
+ self.downsample_block = (
88
+ downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block
89
+ )
90
+
91
+ super(ResNet, self).__init__()
92
+ self.layer1 = self._make_layer(block, 64, layers[0])
93
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
94
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
95
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
96
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
97
+
98
+ for m in self.modules():
99
+ if isinstance(m, nn.Conv2d):
100
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
101
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
102
+ elif isinstance(m, nn.BatchNorm2d):
103
+ m.weight.data.fill_(1)
104
+ m.bias.data.zero_()
105
+
106
+ if self.gamma_zero:
107
+ for m in self.modules():
108
+ if isinstance(m, BasicBlock):
109
+ m.bn2.weight.data.zero_()
110
+
111
+ def _make_layer(self, block, planes, blocks, stride=1):
112
+ downsample = None
113
+ if stride != 1 or self.inplanes != planes * block.expansion:
114
+ downsample = self.downsample_block(
115
+ inplanes=self.inplanes,
116
+ outplanes=planes * block.expansion,
117
+ stride=stride,
118
+ )
119
+
120
+ layers = []
121
+ layers.append(
122
+ block(self.inplanes, planes, stride, downsample, relu_type=self.relu_type)
123
+ )
124
+ self.inplanes = planes * block.expansion
125
+ for _ in range(1, blocks):
126
+ layers.append(block(self.inplanes, planes, relu_type=self.relu_type))
127
+
128
+ return nn.Sequential(*layers)
129
+
130
+ def forward(self, x):
131
+ x = self.layer1(x)
132
+ x = self.layer2(x)
133
+ x = self.layer3(x)
134
+ x = self.layer4(x)
135
+ x = self.avgpool(x)
136
+ x = x.view(x.size(0), -1)
137
+ return x
138
+
139
+
140
+ class ResEncoder(PreTrainedModel):
141
+ def __init__(self, config: ResEncoderConfig):
142
+ super(ResEncoder, self).__init__(config=config)
143
+ self.frontend_nout = config.frontend_nout
144
+ self.backend_out = config.backend_out
145
+ frontend_relu = (
146
+ nn.PReLU(num_parameters=self.frontend_nout)
147
+ if config.relu_type == "prelu"
148
+ else nn.ReLU()
149
+ )
150
+ self.frontend3D = nn.Sequential(
151
+ nn.Conv3d(
152
+ 1,
153
+ self.frontend_nout,
154
+ kernel_size=(5, 7, 7),
155
+ stride=(1, 2, 2),
156
+ padding=(2, 3, 3),
157
+ bias=False,
158
+ ),
159
+ nn.BatchNorm3d(self.frontend_nout),
160
+ frontend_relu,
161
+ nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
162
+ )
163
+ self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=config.relu_type)
164
+
165
+ def forward(self, x):
166
+ B, C, T, H, W = x.size()
167
+ x = self.frontend3D(x)
168
+ Tnew = x.shape[2]
169
+ x = self.threeD_to_2D_tensor(x)
170
+ x = self.trunk(x)
171
+ x = x.view(B, Tnew, x.size(1))
172
+ x = x.transpose(1, 2).contiguous()
173
+ return x
174
+
175
+ def threeD_to_2D_tensor(self, x):
176
+ n_batch, n_channels, s_time, sx, sy = x.shape
177
+ x = x.transpose(1, 2).contiguous()
178
+ return x.reshape(n_batch * s_time, n_channels, sx, sy)
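
A shape trace for `ResEncoder.forward`, re-built standalone so the tensor bookkeeping is easy to follow; the 88x88 crop size and `frontend_nout = 64` are assumptions (the config file is not shown here), while the 512-dim trunk output follows from the BasicBlock `[2, 2, 2, 2]` ResNet ending in a 512-channel stage.

```python
import torch
import torch.nn as nn

frontend_nout = 64  # assumed config value
frontend3D = nn.Sequential(
    nn.Conv3d(1, frontend_nout, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False),
    nn.BatchNorm3d(frontend_nout),
    nn.ReLU(),
    nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
)

x = torch.randn(2, 1, 25, 88, 88)             # [B, C, T, H, W]: 2 clips, 25 grayscale frames
x = frontend3D(x)                              # [2, 64, 25, 22, 22]: time kept, H and W divided by 4
B, C, T, H, W = x.shape
x = x.transpose(1, 2).reshape(B * T, C, H, W)  # fold time into the batch: [50, 64, 22, 22]
print(x.shape)
# The ResNet trunk plus adaptive average pooling then maps each frame to a 512-dim
# vector ([50, 512]), which ResEncoder reshapes back to a per-frame sequence [2, 512, 25].
```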
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:35ee1a95844cd8f2f45822d0c8c5f167727337bc5a616e95a02b4b0a4341ca2b
3
+ size 653053499
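
Finally, because the modeling code ships inside the repository, loading presumably goes through the remote-code path; the sketch below assumes `config.json` (not part of this upload) maps the custom classes via `auto_map`, and the repo id is a placeholder. The LFS pointer above puts the checkpoint at roughly 650 MB.

```python
from transformers import AutoConfig, AutoModel

repo_id = "your-namespace/av_hubert"  # placeholder

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()
```

Depending on the `auto_map` entries, the seq2seq head may be exposed through a different auto class or have to be imported directly from `modeling_avhubert.py`.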