Upload folder using huggingface_hub
- .gitattributes +1 -0
- Utils/PLBERT/config.yml +30 -0
- Utils/PLBERT/step_1000000.t7 +3 -0
- Utils/PLBERT/util.py +42 -0
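
The commit title suggests the files were pushed with huggingface_hub's upload_folder API. A minimal sketch of such a push (the repo_id and folder_path below are placeholders, not taken from this commit):

from huggingface_hub import HfApi

api = HfApi()
# Push the local working directory as a single commit; repo_id is a placeholder.
api.upload_folder(
    folder_path=".",                      # local folder containing Utils/, data/, ...
    repo_id="user/styletts2-finetune",    # hypothetical repository name
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)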
.gitattributes CHANGED
@@ -2332,3 +2332,4 @@ data/wavs/wavs/021_-_NonVerbal_Skills_For_Great_Leaders_3d5ba0fc_part034_02.wav
 data/wavs/wavs/021_-_NonVerbal_Skills_For_Great_Leaders_3d5ba0fc_part032_02.wav filter=lfs diff=lfs merge=lfs -text
 data/wavs/wavs/021_-_NonVerbal_Skills_For_Great_Leaders_3d5ba0fc_part032_01.wav filter=lfs diff=lfs merge=lfs -text
 Utils/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
+Utils/PLBERT/step_1000000.t7 filter=lfs diff=lfs merge=lfs -text
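
The new rule routes the PLBERT checkpoint through Git LFS, so only a small pointer is committed while the binary goes to LFS storage. A rough Python illustration of how such a rule can be read (simplified matching; real gitattributes semantics are richer than fnmatch):

from fnmatch import fnmatch

# Simplified reading of the added .gitattributes line.
rule_pattern = "Utils/PLBERT/step_1000000.t7"
rule_attrs = {"filter": "lfs", "diff": "lfs", "merge": "lfs", "text": False}

def is_lfs_tracked(path):
    # A path matching an LFS filter rule is stored as a pointer file.
    return fnmatch(path, rule_pattern) and rule_attrs["filter"] == "lfs"

print(is_lfs_tracked("Utils/PLBERT/step_1000000.t7"))  # True
print(is_lfs_tracked("Utils/PLBERT/config.yml"))       # False, committed as plain text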
Utils/PLBERT/config.yml ADDED
@@ -0,0 +1,30 @@
+log_dir: "Checkpoint"
+mixed_precision: "fp16"
+data_folder: "wikipedia_20220301.en.processed"
+batch_size: 192
+save_interval: 5000
+log_interval: 10
+num_process: 1 # number of GPUs
+num_steps: 1000000
+
+dataset_params:
+    tokenizer: "transfo-xl-wt103"
+    token_separator: " " # token used for phoneme separator (space)
+    token_mask: "M" # token used for phoneme mask (M)
+    word_separator: 3039 # token used for word separator (<formula>)
+    token_maps: "token_maps.pkl" # token map path
+
+    max_mel_length: 512 # max phoneme length
+
+    word_mask_prob: 0.15 # probability to mask the entire word
+    phoneme_mask_prob: 0.1 # probability to mask each phoneme
+    replace_prob: 0.2 # probability to replace phonemes
+
+model_params:
+    vocab_size: 178
+    hidden_size: 768
+    num_attention_heads: 12
+    intermediate_size: 2048
+    max_position_embeddings: 512
+    num_hidden_layers: 12
+    dropout: 0.1
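
The dataset_params masking probabilities describe a BERT-style corruption scheme over phonemes: a whole word is masked with word_mask_prob, and otherwise individual phonemes are masked or replaced. A hedged sketch of how these values are typically applied (the actual PL-BERT dataloader is not part of this commit and may differ in detail):

import random

def mask_phonemes(words, token_mask="M", token_separator=" ",
                  word_mask_prob=0.15, phoneme_mask_prob=0.1, replace_prob=0.2):
    """words: list of words, each a list of phoneme tokens."""
    masked = []
    for word in words:
        if random.random() < word_mask_prob:
            # Mask the entire word.
            masked.extend([token_mask] * len(word))
        else:
            for ph in word:
                r = random.random()
                if r < phoneme_mask_prob:
                    masked.append(token_mask)                 # mask a single phoneme
                elif r < phoneme_mask_prob + replace_prob:
                    masked.append(random.choice("abcdefg"))   # placeholder for a random vocab phoneme
                else:
                    masked.append(ph)                         # keep the phoneme unchanged
        masked.append(token_separator)                        # separator between words
    return masked

print(mask_phonemes([["h", "ə", "l", "oʊ"], ["w", "ɜː", "l", "d"]]))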
Utils/PLBERT/step_1000000.t7 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0714ff85804db43e06b3b0ac5749bf90cf206257c6c5916e8a98c5933b4c21e0
+size 25185187
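
These three lines are a Git LFS pointer, not the checkpoint itself; the roughly 25 MB .t7 file lives in LFS storage and is resolved on download. A minimal sketch of fetching the real file with huggingface_hub (repo_id is a placeholder):

from huggingface_hub import hf_hub_download

# Resolve the LFS pointer to the actual checkpoint; repo_id is a placeholder
# for whatever repository this commit belongs to.
ckpt_path = hf_hub_download(
    repo_id="user/styletts2-finetune",
    filename="Utils/PLBERT/step_1000000.t7",
)
print(ckpt_path)  # local cache path of the downloaded checkpoint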
Utils/PLBERT/util.py ADDED
@@ -0,0 +1,42 @@
+import os
+import yaml
+import torch
+from transformers import AlbertConfig, AlbertModel
+
+class CustomAlbert(AlbertModel):
+    def forward(self, *args, **kwargs):
+        # Call the original forward method
+        outputs = super().forward(*args, **kwargs)
+
+        # Only return the last_hidden_state
+        return outputs.last_hidden_state
+
+
+def load_plbert(log_dir):
+    config_path = os.path.join(log_dir, "config.yml")
+    plbert_config = yaml.safe_load(open(config_path))
+
+    albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
+    bert = CustomAlbert(albert_base_configuration)
+
+    files = os.listdir(log_dir)
+    ckpts = []
+    for f in os.listdir(log_dir):
+        if f.startswith("step_"): ckpts.append(f)
+
+    iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
+    iters = sorted(iters)[-1]
+
+    checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location='cpu')
+    state_dict = checkpoint['net']
+    from collections import OrderedDict
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        name = k[7:]  # remove `module.`
+        if name.startswith('encoder.'):
+            name = name[8:]  # remove `encoder.`
+        new_state_dict[name] = v
+    del new_state_dict["embeddings.position_ids"]
+    bert.load_state_dict(new_state_dict, strict=False)
+
+    return bert
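
A minimal usage sketch of the helper above, assuming the repository is cloned locally with the layout added in this commit (the input ids are made-up phoneme token indices):

import torch
from Utils.PLBERT.util import load_plbert

# Load the pre-trained phoneme-level BERT from the folder added in this commit.
plbert = load_plbert("Utils/PLBERT")
plbert.eval()

# Made-up batch of phoneme token ids (vocab_size is 178 per config.yml).
phoneme_ids = torch.randint(0, 178, (1, 50))

with torch.no_grad():
    hidden = plbert(phoneme_ids)  # CustomAlbert returns last_hidden_state directly

print(hidden.shape)  # expected: torch.Size([1, 50, 768]) given hidden_size: 768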