flaubert committed on
Commit
8b96d1e
·
verified ·
1 Parent(s): 811127f

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. model.safetensors +2 -2
  2. modeling_data2vec2.py +47 -8
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d1f3a5b7e501b52e0d9d4ad11a9396cb2956e01e2972e4062c2dbb844c1419b7
3
- size 496547472
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:273806a81277195257bb969cda50cc595f645d2528869b869f1a280dfc6f8368
3
+ size 498910032
modeling_data2vec2.py CHANGED
@@ -26,6 +26,7 @@ import math
26
  import warnings
27
  from typing import Optional, Tuple, Dict, List, Callable, Any
28
  from functools import partial
 
29
 
30
  import numpy as np
31
 
@@ -35,9 +36,7 @@ from torch import nn
35
  from torch import Tensor
36
 
37
  from transformers import PreTrainedModel
38
- from transformers.modeling_outputs import (
39
- Wav2Vec2BaseModelOutput,
40
- )
41
  from .configuration_data2vec2 import (
42
  Data2Vec2MultiConfig,
43
  D2v2ModalityConfig,
@@ -59,6 +58,15 @@ from .utils_data2vec2 import (
59
  )
60
 
61
 
 
 
 
 
 
 
 
 
 
62
  #################################################
63
  ### modeling_data2vec2_base.py
64
  # copied from fairseq.modules.grad_multiply
@@ -1221,6 +1229,22 @@ class TextEncoder(ModalitySpecificEncoder):
1221
  #################################################
1222
 
1223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1224
  class Data2Vec2MultiPreTrainedModel(PreTrainedModel):
1225
  # use init_bert_params from fairseq
1226
  # copied from fairseq.modules.transformer_sentence_encoder.py
@@ -1281,7 +1305,9 @@ class Data2Vec2MultiModel(Data2Vec2MultiPreTrainedModel):
1281
  config_class = Data2Vec2MultiConfig
1282
  base_model_prefix = "data2vec2"
1283
 
1284
- def __init__(self, config: Data2Vec2MultiConfig):
 
 
1285
  super().__init__(config)
1286
  self.config = config
1287
  modalities_cfg = config.modalities
@@ -1327,6 +1353,10 @@ class Data2Vec2MultiModel(Data2Vec2MultiPreTrainedModel):
1327
 
1328
  self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(config.depth)])
1329
 
 
 
 
 
1330
  self.norm = None
1331
  if config.layer_norm_first:
1332
  self.norm = make_layer_norm(config.embed_dim)
@@ -1355,6 +1385,9 @@ class Data2Vec2MultiModel(Data2Vec2MultiPreTrainedModel):
1355
  """
1356
  for mod in self.modalities:
1357
  self.modality_encoders[mod]._freeze_parameters()
 
 
 
1358
 
1359
  def make_modality_encoder(
1360
  self,
@@ -1405,7 +1438,7 @@ class Data2Vec2MultiModel(Data2Vec2MultiPreTrainedModel):
1405
  precomputed_mask=None,
1406
  )
1407
  x = extractor_out["x"]
1408
- extract_features = x
1409
 
1410
  # encoder_mask = extractor_out["encoder_mask"]
1411
  masked_padding_mask = extractor_out["padding_mask"]
@@ -1447,20 +1480,26 @@ class Data2Vec2MultiModel(Data2Vec2MultiPreTrainedModel):
1447
  :, feature_extractor.modality_cfg.num_extra_tokens :
1448
  ]
1449
 
 
 
 
 
1450
  if not return_dict:
1451
  return tuple(
1452
  v
1453
  for v in [
1454
  x,
1455
- extract_features,
 
1456
  layer_results,
1457
  ]
1458
  if v is not None
1459
  )
1460
 
1461
- return Wav2Vec2BaseModelOutput(
1462
  last_hidden_state=x,
1463
- extract_features=extract_features,
 
1464
  hidden_states=layer_results if output_hidden_states else None,
1465
  attentions=None, # switch to manual implementation with fast=False in forward pass of AltAttention as pytorch's dspa does not output attention weights
1466
  )
 
26
  import warnings
27
  from typing import Optional, Tuple, Dict, List, Callable, Any
28
  from functools import partial
29
+ from dataclasses import dataclass
30
 
31
  import numpy as np
32
 
 
36
  from torch import Tensor
37
 
38
  from transformers import PreTrainedModel
39
+ from transformers.utils import ModelOutput
 
 
40
  from .configuration_data2vec2 import (
41
  Data2Vec2MultiConfig,
42
  D2v2ModalityConfig,
 
58
  )
59
 
60
 
61
@dataclass
class Data2vec2BaseModelOutput(ModelOutput):
    """Output type for the Data2Vec2 multi-modal base model.

    Mirrors the shape of transformers' base-model outputs, with an extra
    ``local_features`` field exposing the representations produced *before*
    the Transformer encoder.
    """

    # Hidden states at the output of the last encoder layer.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Pooled output for text tasks: the first-token representation passed
    # through a dense layer and an activation function.
    pooler_output: Optional[torch.FloatTensor] = None
    # Features before the Transformer encoder.
    local_features: Optional[torch.FloatTensor] = None
    # Per-layer hidden states (populated when output_hidden_states=True).
    # NOTE: use typing.Tuple (imported at the top of this file) rather than
    # the builtin generic tuple[...], which requires Python >= 3.9.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # TODO: only supported with the manual attention implementation
    # (fast=False) in the forward pass of AltAttention, as pytorch's SDPA
    # does not output attention weights.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
68
+
69
+
70
  #################################################
71
  ### modeling_data2vec2_base.py
72
  # copied from fairseq.modules.grad_multiply
 
1229
  #################################################
1230
 
1231
 
1232
# copied from transformers.models.data2vec.modeling_data2vec.Data2VecTextPooler
class Data2VecTextPooler(nn.Module):
    """Pool a sequence by projecting its first token through dense + tanh.

    The "pooled" representation is simply the hidden state of the first
    token, transformed by a linear layer followed by a Tanh activation.
    """

    def __init__(self, config):
        super().__init__()
        # Keep attribute names (dense / activation) for checkpoint
        # compatibility.
        self.dense = nn.Linear(config.embed_dim, config.embed_dim)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pooling = take the first-token representation and transform it.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
1246
+
1247
+
1248
  class Data2Vec2MultiPreTrainedModel(PreTrainedModel):
1249
  # use init_bert_params from fairseq
1250
  # copied from fairseq.modules.transformer_sentence_encoder.py
 
1305
  config_class = Data2Vec2MultiConfig
1306
  base_model_prefix = "data2vec2"
1307
 
1308
+ def __init__(
1309
+ self, config: Data2Vec2MultiConfig, add_pooling_layer: bool = True
1310
+ ):
1311
  super().__init__(config)
1312
  self.config = config
1313
  modalities_cfg = config.modalities
 
1353
 
1354
  self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(config.depth)])
1355
 
1356
+ self.text_pooler = None
1357
+ if add_pooling_layer and config.supported_modality == "TEXT":
1358
+ self.text_pooler = Data2VecTextPooler(config)
1359
+
1360
  self.norm = None
1361
  if config.layer_norm_first:
1362
  self.norm = make_layer_norm(config.embed_dim)
 
1385
  """
1386
  for mod in self.modalities:
1387
  self.modality_encoders[mod]._freeze_parameters()
1388
+ for block in self.blocks:
1389
+ for p in block.parameters():
1390
+ p.requires_grad = False
1391
 
1392
  def make_modality_encoder(
1393
  self,
 
1438
  precomputed_mask=None,
1439
  )
1440
  x = extractor_out["x"]
1441
+ local_features = x
1442
 
1443
  # encoder_mask = extractor_out["encoder_mask"]
1444
  masked_padding_mask = extractor_out["padding_mask"]
 
1480
  :, feature_extractor.modality_cfg.num_extra_tokens :
1481
  ]
1482
 
1483
+ txt_pooled_output = (
1484
+ self.text_pooler(x) if self.text_pooler is not None else None
1485
+ )
1486
+
1487
  if not return_dict:
1488
  return tuple(
1489
  v
1490
  for v in [
1491
  x,
1492
+ txt_pooled_output,
1493
+ local_features,
1494
  layer_results,
1495
  ]
1496
  if v is not None
1497
  )
1498
 
1499
+ return Data2vec2BaseModelOutput(
1500
  last_hidden_state=x,
1501
+ pooler_output=txt_pooled_output,
1502
+ local_features=local_features,
1503
  hidden_states=layer_results if output_hidden_states else None,
1504
  attentions=None, # switch to manual implementation with fast=False in forward pass of AltAttention as pytorch's dspa does not output attention weights
1505
  )