Upload 21 files
Browse files
- tools/bert-vocab.txt +0 -0
- tools/checkpoint_loader_megatron.py +273 -0
- tools/checkpoint_saver_megatron.py +322 -0
- tools/checkpoint_util.py +151 -0
- tools/linter.py +36 -0
- tools/merge_datasets.py +66 -0
- tools/merge_mp_partitions.py +352 -0
- tools/openwebtext/README.md +59 -0
- tools/openwebtext/add_id.py +67 -0
- tools/openwebtext/blacklist_urls.py +312 -0
- tools/openwebtext/cleanup_dataset.py +115 -0
- tools/openwebtext/cleanup_fix_dataset.py +191 -0
- tools/openwebtext/filter_ngrams.py +492 -0
- tools/openwebtext/find_duplicates.py +305 -0
- tools/openwebtext/group_duplicate_url.py +90 -0
- tools/openwebtext/merge_jsons.py +55 -0
- tools/openwebtext/remove_group_duplicates.py +69 -0
- tools/preprocess_data.py +205 -0
- tools/run_build_data.sh +16 -0
- tools/run_text_generation_server.py +90 -0
- tools/text_generation_cli.py +34 -0
tools/bert-vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tools/checkpoint_loader_megatron.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import types
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
def add_arguments(parser):
    """Register the Megatron-loader command-line options on *parser*.

    Adds an argument group with the options the loader side of the
    checkpoint conversion needs: an optional true (unpadded) vocab size,
    an optional vocab file used to derive that size, and the path to the
    Megatron repository to import from.
    """
    group = parser.add_argument_group(title='Megatron loader')

    group.add_argument('--true-vocab-size', type=int, default=None,
                       help='original size of vocab, if specified will trim padding from embedding table.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file. If specified will use this to get vocab size and '
                       'trim padding from the embedding table.')
    # Fixed: help text previously said "deepspeed repository"; this tool
    # imports Megatron, matching the saver's identical option.
    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of Megatron repository')
|
| 18 |
+
|
| 19 |
+
def _load_checkpoint(queue, args):
    """Load a Megatron checkpoint and stream its weights to the saver.

    Runs in the loader process of the conversion tool.  Sends, in order:
    a metadata namespace, an "embeddings" message, one message per
    transformer layer, a "final layernorm" message, BERT head messages
    when applicable, and finally the string "done".  On any fatal problem
    it sends the string "exit" so the saver process stops waiting.
    """

    # Search in directory above this
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    try:
        from megatron.arguments import parse_args, validate_args
        from megatron.global_vars import set_args, set_global_variables
        from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
        from megatron.model import ModelType, module
        from megatron import mpu, fused_kernels
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
        queue.put("exit")
        exit(1)

    # We want all arguments to come from us
    sys.argv = ['script.py',
                '--no-masked-softmax-fusion',
                '--no-bias-gelu-fusion',
                '--no-bias-dropout-fusion',
                '--use-cpu-initialization',
                '--micro-batch-size', '1',
                '--no-load-optim',
                '--no-load-rng',
                '--no-save-optim',
                '--no-save-rng',
                '--no-initialization',
                '--load', args.load_dir
                ]

    margs = parse_args()
    # Overwrite parsed defaults with the values stored in the checkpoint.
    margs = load_args_from_checkpoint(margs)

    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking we are plenty of processes
    margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size

    margs = validate_args(margs)

    def check_for_arg(arg_name):
        # Abort (and notify the saver) when the checkpoint did not record a
        # required argument.
        if getattr(margs, arg_name, None) is None:
            print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
            print(f"Arguments: {margs}")
            queue.put("exit")
            exit(1)

    check_for_arg('tensor_model_parallel_size')
    check_for_arg('pipeline_model_parallel_size')
    check_for_arg('num_layers')
    check_for_arg('hidden_size')
    check_for_arg('seq_length')
    check_for_arg('num_attention_heads')
    check_for_arg('max_position_embeddings')
    check_for_arg('tokenizer_type')
    check_for_arg('iteration')
    check_for_arg('bert_binary_head')
    check_for_arg('params_dtype')

    # Determine how to make our models
    if args.model_type == 'GPT':
        from pretrain_gpt import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    elif args.model_type == 'BERT':
        from pretrain_bert import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    else:
        raise Exception(f'unrecognized model type: {args.model_type}')

    # supress warning about torch.distributed not being initialized
    module.MegatronModule.embedding_warning_printed = True

    # Shared across get_models calls so every rank's counters can be
    # cross-checked for consistency.
    consumed_train_samples = None
    consumed_valid_samples = None
    def get_models(count, dtype, pre_process, post_process):
        # Build and load one model per tensor-parallel rank for the
        # currently selected pipeline stage.
        nonlocal consumed_train_samples
        nonlocal consumed_valid_samples
        models = []
        for rank in range(count):
            mpu.initialize.set_tensor_model_parallel_rank(rank)
            model_ = [model_provider(pre_process, post_process).to(dtype)]
            margs.consumed_train_samples = 0
            margs.consumed_valid_samples = 0
            load_checkpoint(model_, None, None)
            assert(len(model_) == 1)
            model_ = model_[0]
            # Every rank's checkpoint must report the same sample counters.
            if consumed_train_samples is not None:
                assert(margs.consumed_train_samples == consumed_train_samples)
            else:
                consumed_train_samples = margs.consumed_train_samples
            if consumed_valid_samples is not None:
                assert(margs.consumed_valid_samples == consumed_valid_samples)
            else:
                consumed_valid_samples = margs.consumed_valid_samples
            models.append(model_)
        return models

    if margs.num_layers_per_virtual_pipeline_stage is not None:
        print("Model with an interleaved pipeline schedule are not yet supported.")
        queue.put("exit")
        exit(1)

    set_global_variables(margs)
    mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
    mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
    fused_kernels.load(margs)

    # Get true (non-padded) vocab size
    if args.true_vocab_size is not None:
        true_vocab_size = args.true_vocab_size
    elif args.vocab_file is not None:
        # NOTE(review): the file handle from open() is never closed; a
        # with-block would be safer.
        vocab = json.load(open(args.vocab_file))
        true_vocab_size = len(vocab)
        # NOTE(review): this elif branch only runs when args.true_vocab_size
        # is None, so the mismatch check below can never fire.
        if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size:
            print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.")
            queue.put("exit")
            exit(1)
    else:
        true_vocab_size = None

    # short aliases
    tp_size = margs.tensor_model_parallel_size
    pp_size = margs.pipeline_model_parallel_size

    # metadata — sent first so the saver can configure itself before any
    # weights arrive
    md = types.SimpleNamespace()
    md.model_type = args.model_type
    md.num_layers = margs.num_layers
    md.hidden_size = margs.hidden_size
    md.seq_length = margs.seq_length
    md.num_attention_heads = margs.num_attention_heads
    md.max_position_embeddings = margs.max_position_embeddings
    md.tokenizer_type = margs.tokenizer_type
    md.iteration = margs.iteration
    md.params_dtype = margs.params_dtype
    md.bert_binary_head = margs.bert_binary_head
    md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
    md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
    md.true_vocab_size = true_vocab_size
    md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by

    # Get first pipe stage
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    post_process = pp_size == 1
    models = get_models(tp_size, md.params_dtype, True, post_process)

    # Counters are known only after the first get_models call.
    md.consumed_train_samples = consumed_train_samples
    md.consumed_valid_samples = consumed_valid_samples
    queue.put(md)

    def queue_put(name, msg):
        # Tag the message dict with its name and ship it to the saver.
        print(f"sending {name}")
        msg["name"] = name
        queue.put(msg)

    # Send embeddings
    message = {
        "position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
        "word embeddings": torch.cat(
            [models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
            dim = 0)
    }

    queue_put("embeddings", message)

    total_layer_num = 0
    for pp_rank in range(pp_size):
        # Rank 0's models were already built above; build the rest lazily.
        if pp_rank > 0:
            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
            post_process = pp_rank == pp_size - 1
            models = get_models(tp_size, md.params_dtype, False, post_process)
        for layer_num in range(len(models[0].language_model.encoder.layers)):
            message = {}

            # Get non-parallel tensors from tp_rank 0
            layer = models[0].language_model.encoder.layers[layer_num]
            message["input layernorm weight"] = layer.input_layernorm.weight.data
            message["input layernorm bias"] = layer.input_layernorm.bias.data
            message["dense bias"] = layer.self_attention.dense.bias.data
            message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
            message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
            message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data

            # Grab all parallel tensors for this layer
            qkv_weight = []
            qkv_bias = []
            dense_weight = []
            mlp_l0_weight = []
            mlp_l0_bias = []
            mlp_l1_weight = []
            for tp_rank, model in enumerate(models):
                layer = model.language_model.encoder.layers[layer_num]
                qkv_weight.append(layer.self_attention.query_key_value.weight.data)
                qkv_bias.append(layer.self_attention.query_key_value.bias.data)
                dense_weight.append(layer.self_attention.dense.weight.data)
                mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
                mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
                mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)

            # concat them — dim 0 for column-parallel weights, dim 1 for
            # row-parallel ones, mirroring how the saver re-chunks them
            message["qkv weight"] = torch.cat(qkv_weight, dim=0)
            message["qkv bias"] = torch.cat(qkv_bias, dim=0)
            message["dense weight"] = torch.cat(dense_weight, dim=1)
            message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
            message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
            message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)

            queue_put(f"transformer layer {total_layer_num}", message)

            total_layer_num = total_layer_num + 1

    # Send final layernorm from tp_rank 0
    message = {
        "weight": models[0].language_model.encoder.final_layernorm.weight.data,
        "bias": models[0].language_model.encoder.final_layernorm.bias.data
    }
    queue_put("final layernorm", message)

    # Send BERT lm head and binary head if it exists
    if md.model_type == 'BERT':
        print("Sending LM Pooler")
        message = {
            "weight": models[0].language_model.pooler.dense.weight.data,
            "bias": models[0].language_model.pooler.dense.bias.data
        }
        queue_put("pooler", message)

        message = {
            "dense weight": models[0].lm_head.dense.weight.data,
            "dense bias": models[0].lm_head.dense.bias.data,
            "layernorm weight": models[0].lm_head.layernorm.weight.data,
            "layernorm bias": models[0].lm_head.layernorm.bias.data
        }
        queue_put("lm head", message)

        if md.bert_binary_head:
            print("Sending BERT Binary head")
            # NOTE(review): this pushes the bare string "binary head" before
            # the real message below; the saver's queue_get indexes messages
            # with ["name"], which a bare string cannot satisfy — looks like
            # a stray leftover line. Verify against the saver's protocol.
            queue.put("binary head")
            message = {
                "weight": models[0].binary_head.weight.data,
                "bias": models[0].binary_head.bias.data
            }
            queue_put("binary head", message)
    queue.put("done")
|
| 267 |
+
|
| 268 |
+
def load_checkpoint(queue, args):
    """Public entry point for the loader process.

    Delegates the actual work to ``_load_checkpoint``.  If anything at all
    goes wrong, the saver process is told to stop waiting (via an "exit"
    message on the queue) and the exception is propagated so its traceback
    still surfaces in this process.
    """
    try:
        _load_checkpoint(queue, args)
    except BaseException:
        # Unblock the saver before letting the error escape.
        queue.put("exit")
        raise
|
tools/checkpoint_saver_megatron.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from collections.abc import Mapping
|
| 3 |
+
import concurrent.futures
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
def add_arguments(parser):
    """Register the Megatron-saver command-line options on *parser*.

    Adds an argument group with the Megatron repository path and the target
    tensor/pipeline model parallel sizes for the converted checkpoint.
    Both target sizes default to None, which means "inherit from the loader
    if it provides one, otherwise 1".
    """
    group = parser.add_argument_group(title='Megatron saver')

    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of Megatron repository')

    group.add_argument('--target-tensor-parallel-size', type=int,
                       help='Target tensor model parallel size, defaults to the tensor parallel size '
                       'in the input checkpoint if provided by the loader, otherwise to 1')
    # Fixed: help text previously said "Target tensor model parallel size"
    # (copy-paste from the option above) and misspelled "parallel".
    group.add_argument('--target-pipeline-parallel-size', type=int,
                       help='Target pipeline model parallel size, defaults to the pipeline parallel size '
                       'in the input checkpoint if provided by the loader, otherwise to 1')
|
| 21 |
+
|
| 22 |
+
def save_checkpoint(queue, args):
    """Receive weights from the loader over ``queue`` and write a new
    Megatron checkpoint with the target tensor/pipeline parallel layout.

    Runs in the saver process of the conversion tool.  Expects the message
    order the loader produces: metadata namespace, "embeddings", one
    message per transformer layer, "final layernorm", optional BERT head
    messages, then the string "done".  The string "exit" at any point
    means the loader failed and this process should abort.
    """

    # Search in directory above this
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    try:
        from megatron.arguments import (parse_args, validate_args)
        from megatron.checkpointing import save_checkpoint
        from megatron.global_vars import set_global_variables, get_args
        from megatron.model import ModelType
        from megatron.tokenizer.tokenizer import _vocab_size_with_padding
        from megatron import mpu, fused_kernels
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
        exit(1)

    def queue_get(name=None):
        # Block for the next message; when a name is given (and checking is
        # enabled) verify the message is the one we expected.
        val = queue.get()
        if val == "exit":
            print("Loader exited, exiting saver")
            exit(1)
        if name is not None and args.checking and val["name"] != name:
            val_name = val["name"]
            print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
            exit(1)
        if name is not None:
            print(f"received {name}")
        return val

    def check_message(msg):
        # After every expected tensor has been popped, anything left in the
        # message dict signals a loader/saver protocol mismatch.
        if not args.checking:
            return
        msg_name = msg.pop("name")
        if len(msg.keys()) > 0:
            print(f"Unexpected values in {msg_name}:")
            for key in msg.keys():
                print(f" {key}")
            print(f"Exiting. If you want to ignore this, use the argument --no-checking.")
            exit(1)


    md = queue_get()

    # Target sizes fall back to the loader-reported sizes, then to 1.
    if args.target_tensor_parallel_size is None:
        if hasattr(md, 'previous_tensor_parallel_size'):
            args.target_tensor_parallel_size = md.previous_tensor_parallel_size
        else:
            print("loader did not provide a tensor parallel size and --target-tensor-parallel-size not provided on command line. "
                  "Default to 1.")
            args.target_tensor_parallel_size = 1

    if args.target_pipeline_parallel_size is None:
        if hasattr(md, 'previous_pipeline_parallel_size'):
            args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size
        else:
            print("loader did not provide a pipeline parallel size and --target-pipeline-parallel-size not provided on command line. "
                  "Default to 1.")
            args.target_pipeline_parallel_size = 1


    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking we are plenty of processes
    if args.target_tensor_parallel_size is not None and args.target_pipeline_parallel_size is not None:
        os.environ["WORLD_SIZE"] = f'{args.target_tensor_parallel_size * args.target_pipeline_parallel_size}'

    # We want all arguments to come from us
    sys.argv = ['script.py',
                '--num-layers', str(md.num_layers),
                '--hidden-size', str(md.hidden_size),
                '--seq-length', str(md.seq_length),
                '--num-attention-heads', str(md.num_attention_heads),
                '--max-position-embeddings', str(md.max_position_embeddings),
                '--tokenizer-type', str(md.tokenizer_type),
                '--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
                '--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
                '--no-masked-softmax-fusion',
                '--no-bias-gelu-fusion',
                '--no-bias-dropout-fusion',
                '--use-cpu-initialization',
                '--micro-batch-size', '1',
                '--no-load-optim',
                '--no-load-rng',
                '--no-save-optim',
                '--no-save-rng',
                '--no-initialization',
                '--save-interval', '1',
                '--save', args.save_dir
                ]

    if md.make_vocab_size_divisible_by is not None:
        sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)])
    if md.params_dtype == torch.float16:
        sys.argv.append('--fp16')
    elif md.params_dtype == torch.bfloat16:
        sys.argv.append('--bf16')

    if md.model_type == 'BERT' and not md.bert_binary_head:
        sys.argv.append('--bert-no-binary-head')

    margs = parse_args()
    validate_args(margs)
    set_global_variables(margs)

    # margs = megatron args
    margs = get_args()

    if hasattr(md, 'consumed_train_samples'):
        margs.consumed_train_samples = md.consumed_train_samples
        margs.consumed_valid_samples = md.consumed_valid_samples
        print(f"Setting consumed_train_samples to {margs.consumed_train_samples}"
              f" and consumed_valid_samples to {margs.consumed_valid_samples}")
    else:
        print("consumed_train_samples not provided.")

    # Determine how to make our models
    if md.model_type == 'GPT':
        from pretrain_gpt import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    elif md.model_type == 'BERT':
        from pretrain_bert import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    else:
        # NOTE(review): the branches above test md.model_type but this
        # message interpolates args.model_type — probably meant md.model_type.
        raise Exception(f'unrecognized model type: {args.model_type}')

    def get_models(count, dtype, pre_process, post_process):
        # One fresh model instance per tensor-parallel rank for the
        # currently selected pipeline stage.
        models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
        return models

    # fake initializing distributed
    mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
    mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
    mpu.initialize.set_tensor_model_parallel_rank(0)
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    fused_kernels.load(margs)

    # Embeddings
    #-----------
    embeddings_msg = queue_get("embeddings")

    pos_embed = embeddings_msg.pop("position embeddings")
    orig_word_embed = embeddings_msg.pop("word embeddings")
    check_message(embeddings_msg)

    # Deal with padding
    if md.true_vocab_size is not None:
        # figure out what our padded vocab size is
        orig_vocab_size = orig_word_embed.shape[0]
        margs.padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)

        # Cut out extra padding we don't need
        if orig_vocab_size > margs.padded_vocab_size:
            full_word_embed = orig_word_embed[0:margs.padded_vocab_size,:]

        # Expanding embedding to larger size by replicating final entry
        elif orig_vocab_size < margs.padded_vocab_size:
            padding_size = margs.padded_vocab_size - orig_vocab_size

            full_word_embed = torch.cat((
                orig_word_embed,
                orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))

        # Same size!
        else:
            full_word_embed = orig_word_embed
    else:
        print("Original vocab size not specified, leaving embedding table as-is. "
              "If you've changed the tensor parallel size this could cause problems.")
        margs.padded_vocab_size = orig_word_embed.shape[0]
        full_word_embed = orig_word_embed

    # Split into new tensor model parallel sizes
    out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)

    # Make models for first pipeline stage and fill in embeddings
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    post_process = args.target_pipeline_parallel_size == 1
    models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
    for tp_rank, model in enumerate(models):
        print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
        model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
        model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)

    # Transformer layers
    #-------------------
    total_layer_num = 0
    for pp_rank in range(args.target_pipeline_parallel_size):
        # For later pipeline parallel ranks, make the new models
        if pp_rank > 0:
            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
            post_process = pp_rank == args.target_pipeline_parallel_size - 1
            models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)

        for layer in range(len(models[0].language_model.encoder.layers)):
            msg = queue_get(f"transformer layer {total_layer_num}")

            # duplicated tensors — identical on every target tp rank
            input_layernorm_weight = msg.pop("input layernorm weight")
            input_layernorm_bias = msg.pop("input layernorm bias")
            dense_bias = msg.pop("dense bias")
            post_layernorm_weight = msg.pop("post layernorm weight")
            post_layernorm_bias = msg.pop("post layernorm bias")
            mlp_l1_bias = msg.pop("mlp l1 bias")

            # Split up the parallel tensors — dim 0 for column-parallel,
            # dim 1 for row-parallel, mirroring the loader's concatenation
            qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
            qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
            dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
            mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
            mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
            mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)

            # Save them to the model
            for tp_rank in range(args.target_tensor_parallel_size):
                l = models[tp_rank].language_model.encoder.layers[layer]
                l.input_layernorm.weight.data.copy_(input_layernorm_weight)
                l.input_layernorm.bias.data.copy_(input_layernorm_bias)
                l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
                l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
                l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
                l.self_attention.dense.bias.data.copy_(dense_bias)
                l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
                l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
                l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
                l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
                l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
                l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
            total_layer_num = total_layer_num + 1
            check_message(msg)


        # Only the last pipeline stage carries the final layernorm and any
        # model heads.
        if post_process:
            msg = queue_get("final layernorm")
            final_layernorm_weight = msg.pop("weight")
            final_layernorm_bias = msg.pop("bias")
            for tp_rank in range(args.target_tensor_parallel_size):
                models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
                models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
                if pp_rank != 0:
                    # Copy word embeddings to final pipeline rank
                    models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
            del final_layernorm_weight
            del final_layernorm_bias
            check_message(msg)

            # The remaining messages (pooler / lm head / binary head) are
            # optional, so each is checked by name rather than demanded.
            msg = queue_get()
            if msg != "done" and msg["name"] == "pooler":
                if not hasattr(models[0].language_model, 'pooler'):
                    print("ERROR: got a pooler, but model does not have one")
                    exit(1)
                print("received pooler")
                pooler_weight = msg.pop("weight")
                pooler_bias = msg.pop("bias")
                for tp_rank in range(args.target_tensor_parallel_size):
                    models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
                    models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
                del pooler_weight
                del pooler_bias
                check_message(msg)
                msg = queue_get()

            if msg != "done" and msg["name"] == "lm head":
                if not hasattr(models[0], 'lm_head'):
                    print("ERROR: got an lm head, but model does not have one")
                    exit(1)
                print("received lm head")
                lm_head_dense_weight = msg.pop("dense weight")
                lm_head_dense_bias = msg.pop("dense bias")
                lm_head_layernorm_weight = msg.pop("layernorm weight")
                lm_head_layernorm_bias = msg.pop("layernorm bias")
                for tp_rank in range(args.target_tensor_parallel_size):
                    models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
                    models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
                    models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
                    models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
                check_message(msg)
                msg = queue_get()

            if msg != "done" and msg["name"] == "binary head":
                if not hasattr(models[0], 'binary_head'):
                    print("ERROR: got a binary head, but model does not have one")
                    exit(1)
                print("received binary head")
                binary_head_weight = msg.pop("weight")
                binary_head_bias = msg.pop("bias")
                for tp_rank in range(args.target_tensor_parallel_size):
                    models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
                    models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
                check_message(msg)
                msg = queue_get()

            if msg != "done":
                print("ERROR: got some more data but was expecting to be done")

        # Persist this pipeline stage: one checkpoint per tensor rank.
        for tp_rank in range(args.target_tensor_parallel_size):
            mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
            save_checkpoint(md.iteration, [models[tp_rank]], None, None)
    print("Done!")
|
tools/checkpoint_util.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import importlib
|
| 3 |
+
import torch.multiprocessing as mp
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
# A loader is a python file with at least two functions
|
| 8 |
+
# - add_arguments - takes in a parser and adds any arguments needed
|
| 9 |
+
# - load_checkpoint - takes in the queue and parsed arguments
|
| 10 |
+
|
| 11 |
+
# A saver is similar but has save_checkpoint instead of
|
| 12 |
+
# load_checkpoint
|
| 13 |
+
|
| 14 |
+
# The loader and saver process are each given a queue, the loader
|
| 15 |
+
# should load the checkpoint and send the weights in messages in the
|
| 16 |
+
# following order, the saver should receive them in this order and
|
| 17 |
+
# save the checkpoints. A message consists of a python dictionary with
|
| 18 |
+
# a "name" for error checking and an entry for each tensor as
|
| 19 |
+
# indicated below. Note that the weight sent over the queue are the
|
| 20 |
+
# full model weights, nothing split.
|
| 21 |
+
|
| 22 |
+
# If the loader ever sends "exit" to the queue, that means something
|
| 23 |
+
# went wrong and it is exiting.
|
| 24 |
+
|
| 25 |
+
# - Metadata Namespace with the following attributes:
|
| 26 |
+
# model_type - GPT, BERT, T5, etc. (Part of protocol to allow this to be deduced later instead of given on command line)
|
| 27 |
+
# num_layers - Number of transformer layers
|
| 28 |
+
# hidden_size
|
| 29 |
+
# seq_length
|
| 30 |
+
# num_attention_heads
|
| 31 |
+
# max_position_embeddings
|
| 32 |
+
# tokenizer_type
|
| 33 |
+
# iteration
|
| 34 |
+
# params_dtype
|
| 35 |
+
# bert_binary_head - Used only if model_type is BERT
|
| 36 |
+
# previous_tensor_parallel_size - Optional
|
| 37 |
+
# previous_pipeline_parallel_size - Optional
|
| 38 |
+
# true_vocab_size
|
| 39 |
+
# make_vocab_size_divisble_by
|
| 40 |
+
# consumed_train_samples
|
| 41 |
+
# consumed_valid_samples
|
| 42 |
+
# messages
|
| 43 |
+
# {
|
| 44 |
+
# "name": "embeddings"
|
| 45 |
+
# "position embeddings"
|
| 46 |
+
# "word embeddings"
|
| 47 |
+
# }
|
| 48 |
+
# (for each transformer layer):
|
| 49 |
+
# {
|
| 50 |
+
# "name": "transformer layer N"
|
| 51 |
+
# "input layernorm weight"
|
| 52 |
+
# "input layernorm bias"
|
| 53 |
+
# "qkv weight"
|
| 54 |
+
# "qkv bias"
|
| 55 |
+
# "dense weight"
|
| 56 |
+
# "dense bias"
|
| 57 |
+
# "post layernorm weight"
|
| 58 |
+
# "post layernorm bias"
|
| 59 |
+
# "mlp l0 weight"
|
| 60 |
+
# "mlp l0 bias"
|
| 61 |
+
# "mlp l1 weight"
|
| 62 |
+
# "mlp l1 bias"
|
| 63 |
+
# }
|
| 64 |
+
# {
|
| 65 |
+
# "name": "final layer norm"
|
| 66 |
+
# "weight"
|
| 67 |
+
# "bias"
|
| 68 |
+
# }
|
| 69 |
+
# if present (i.e. for BERT):
|
| 70 |
+
# {
|
| 71 |
+
# "name": "pooler"
|
| 72 |
+
# "weight"
|
| 73 |
+
# "bias"
|
| 74 |
+
# }
|
| 75 |
+
# {
|
| 76 |
+
# "name": "lm head"
|
| 77 |
+
# "dense weight"
|
| 78 |
+
# "dense bias"
|
| 79 |
+
# "layernorm weight"
|
| 80 |
+
# "layernorm bias"
|
| 81 |
+
# }
|
| 82 |
+
# {
|
| 83 |
+
# "name": "binary head"
|
| 84 |
+
# "weight"
|
| 85 |
+
# "bias"
|
| 86 |
+
# }
|
| 87 |
+
# - "done"
|
| 88 |
+
|
| 89 |
+
def load_plugin(plugin_type, name):
    """Import a checkpoint loader/saver plugin module and validate it.

    Import is attempted first under the canonical name
    ``checkpoint_<plugin_type>_<name>`` and then under ``name`` itself.
    Exits the process if neither module can be imported, or if the module
    that was found does not expose an ``add_arguments`` hook.
    """
    plugin = None
    module_name = None
    for candidate in (f"checkpoint_{plugin_type}_{name}", name):
        try:
            plugin = importlib.import_module(candidate)
        except ModuleNotFoundError:
            continue
        module_name = candidate
        break

    if plugin is None:
        sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")

    # Every plugin must at least contribute its own command-line arguments.
    if not hasattr(plugin, 'add_arguments'):
        sys.exit(f"{module_name} module is not a plugin. Exiting.")

    print(f"Loaded {module_name} as the {plugin_type}.")
    return plugin
|
| 105 |
+
|
| 106 |
+
def main():
    """Convert a checkpoint by piping weights from a loader to a saver.

    The loader runs in this process and the saver in a child process; they
    communicate through a bounded multiprocessing queue of weight messages
    (see the protocol description at the top of this file).
    """
    # NOTE: the redundant function-local `import argparse` was removed; the
    # module is already imported at file level.
    parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
                                     allow_abbrev=False, conflict_handler='resolve')

    parser.add_argument('--model-type', type=str, required=True,
                        choices=['GPT', 'BERT'],
                        help='Type of the model')
    parser.add_argument('--loader', type=str, default='megatron',
                        help='Module name to load checkpoint, should be on python path')
    parser.add_argument('--saver', type=str, default='megatron',
                        help='Module name to save checkpoint, should be on python path')
    parser.add_argument('--load-dir', type=str, required=True,
                        help='Directory to load model checkpoint from')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='Directory to save model checkpoint to')
    parser.add_argument('--max-queue-size', type=int, default=50,
                        help='Maximum number of tensors in the queue')
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')

    # Parse only the shared arguments first, so the plugins named on the
    # command line can be imported and contribute their own arguments.
    known_args, _ = parser.parse_known_args()
    loader = load_plugin('loader', known_args.loader)
    saver = load_plugin('saver', known_args.saver)

    loader.add_arguments(parser)
    saver.add_arguments(parser)

    args = parser.parse_args()

    # Bounded queue provides backpressure so the loader cannot run
    # arbitrarily far ahead of the saver.
    queue = mp.Queue(maxsize=args.max_queue_size)

    print("Starting saver...")
    saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
    saver_proc.start()

    print("Starting loader...")
    loader.load_checkpoint(queue, args)

    print("Waiting for saver to complete...")
    saver_proc.join()


if __name__ == '__main__':
    main()
|
tools/linter.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import pathlib
|
| 4 |
+
import subprocess
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def recursively_lint_files():
    """Recursively lint all python files in chosen subdirectories of megatron-lm.

    Runs ``autopep8 --in-place`` over every ``*.py`` file found at the
    repository top level and inside a fixed list of subdirectories.
    Prints a hint and returns early if autopep8 is not installed.
    """

    try:
        import autopep8
    except ModuleNotFoundError:
        print("Please first install autopep8 via `pip install autopep8`")
        return

    # Get all python file paths from the top level directory.
    # Bug fix: the original used `".py" in fname`, which also matched
    # .pyc/.pyx and names like `x.py.bak`; endswith targets real sources.
    file_dir = str(pathlib.Path(__file__).parent.absolute())
    working_dir = osp.join(file_dir, os.pardir)
    all_py_paths = set(os.path.join(working_dir, fname)
                       for fname in os.listdir(working_dir)
                       if fname.endswith(".py"))

    # Get all python file paths from the chosen subdirectories.
    check_dirs = ['docker', 'megatron', 'openwebtext', 'scripts', 'tasks']
    for sub_dir in check_dirs:
        for path, _, fnames in os.walk(osp.join(working_dir, sub_dir)):
            all_py_paths.update(osp.join(path, fname)
                                for fname in fnames if fname.endswith(".py"))

    print("Linting the following: ")
    for py_path in all_py_paths:
        print(py_path)
        # Bug fix: check_call was given a single command string without
        # shell=True, which tries to exec a program literally named
        # "autopep8 --max-line-length ..." and raises FileNotFoundError.
        # Pass an argument list instead (also avoids shell quoting issues).
        command = ['autopep8', '--max-line-length', '100', '--aggressive',
                   '--in-place', py_path]
        subprocess.check_call(command)


if __name__ == "__main__":
    recursively_lint_files()
|
tools/merge_datasets.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import argparse
|
| 5 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
|
| 6 |
+
os.path.pardir)))
|
| 7 |
+
|
| 8 |
+
from megatron.data import indexed_dataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main(args):
    """Merge every indexed dataset (.bin/.idx pair) under ``args.input``
    into a single dataset written to ``args.output_prefix``."""

    # Collect the unique dataset prefixes present in the input directory.
    prefixes = set()
    for entry in os.listdir(args.input):
        stem, extension = os.path.splitext(entry)

        if stem in prefixes:
            continue

        if not os.path.isfile(os.path.join(args.input, entry)):
            continue

        # Each dataset ships as a .bin/.idx pair; require the sibling file.
        ext_pair = '.bin' if extension == '.idx' else '.idx'
        assert os.path.isfile(os.path.join(args.input, stem) + ext_pair), \
            f'ERROR: {ext_pair} file not provided for {os.path.join(args.input, stem)}'

        prefixes.add(stem)

    builder = None
    for stem in sorted(prefixes):
        if builder is None:
            # Peek at the first dataset to choose a builder matching its
            # on-disk format (mmap vs. legacy).
            probe = indexed_dataset.make_dataset(os.path.join(args.input, stem), 'infer')

            if isinstance(probe, indexed_dataset.MMapIndexedDataset):
                builder = indexed_dataset.MMapIndexedDatasetBuilder(
                    args.output_prefix + '.bin', dtype=probe._index.dtype)
            else:
                builder = indexed_dataset.IndexedDatasetBuilder(args.output_prefix + '.bin')

            del probe

        builder.merge_file_(os.path.join(args.input, stem))

    builder.finalize(args.output_prefix + '.idx')
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    group = parser.add_argument_group(title='input data')
    group.add_argument('--input', type=str, required=True,
                       help='Path to directory containing all document files to merge')

    group = parser.add_argument_group(title='output data')
    group.add_argument('--output-prefix', type=str, required=True,
                       help='Path to binary output file without suffix')

    args = parser.parse_args()

    assert os.path.isdir(args.input), \
        f'ERROR: {args.input} is not a directory or does not exist'

    # Bug fix: for a bare output prefix (no directory component),
    # os.path.dirname() returns '' and isdir('') is always False, so the
    # original assert rejected valid prefixes in the current directory.
    # Treat '' as the current working directory.
    output_dir = os.path.dirname(args.output_prefix) or '.'
    assert os.path.isdir(output_dir), \
        f'ERROR: {output_dir} is not a directory or does not exist'

    main(args)
|
| 66 |
+
|
tools/merge_mp_partitions.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Merge model parallel partitions."""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import re
|
| 20 |
+
import sys
|
| 21 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
|
| 22 |
+
os.path.pardir)))
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
from megatron import mpu
|
| 27 |
+
from megatron.checkpointing import load_checkpoint, save_checkpoint
|
| 28 |
+
from megatron.checkpointing import ensure_directory_exists
|
| 29 |
+
from megatron.checkpointing import get_checkpoint_name
|
| 30 |
+
from megatron.checkpointing import get_checkpoint_version
|
| 31 |
+
from megatron.checkpointing import get_checkpoint_tracker_filename
|
| 32 |
+
from megatron.global_vars import set_global_variables, get_args
|
| 33 |
+
from megatron.global_vars import rebuild_tokenizer
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def split_into_partitions(tensor, num_partitions, partition_dim, stride):
    """Split ``tensor`` along ``partition_dim`` into ``num_partitions``
    partitions, interleaving stride-sized chunks round-robin so that each
    partition receives every ``num_partitions``-th chunk."""
    partition_size = mpu.utils.divide(tensor.size(partition_dim),
                                      num_partitions)
    stride_chunk_size = mpu.utils.divide(partition_size, stride)

    # Cut the tensor into stride-sized chunks, then deal them out
    # round-robin: partition r gets chunks r, r+num_partitions, ...
    stride_chunks = torch.split(tensor, stride_chunk_size, dim=partition_dim)

    return [torch.cat(stride_chunks[rank::num_partitions], dim=partition_dim)
            for rank in range(num_partitions)]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def merge_partitions(merged, partitions, partition_dim, stride):
    """Merge a list of equally-sized partition tensors into ``merged`` in
    place along ``partition_dim``, honoring the interleaving ``stride``
    (inverse of ``split_into_partitions``).

    If the concatenated size exceeds ``merged``'s size along the partition
    dimension, the result is truncated to fit (with a warning).
    """

    # Number and size of each partition; all partitions must agree on the
    # size along the partition dimension.
    num_partitions = len(partitions)
    per_partition_size = None
    for partition in partitions:
        if per_partition_size is None:
            per_partition_size = partition.size(partition_dim)
        else:
            assert per_partition_size == partition.size(partition_dim)

    def concat_partitions(partitions_):
        # Concatenate the given chunks into `merged` in place, truncating
        # when the total is larger than the destination.
        with torch.no_grad():
            if (per_partition_size * num_partitions) == merged.size(
                    partition_dim):
                # Sizes match exactly: write directly into the destination.
                torch.cat(partitions_, dim=partition_dim, out=merged)
            else:
                print(' ***WARNING*** sizes do not match. Will cut '
                      'the merged partitions by {} along dimension {} '
                      'to reduce the size from {} to {} ...'.format(
                          (per_partition_size * num_partitions) -
                          merged.size(partition_dim), partition_dim,
                          per_partition_size * num_partitions,
                          merged.size(partition_dim)))
                # Concatenate to a temporary, keep only the leading slice
                # that fits, and copy it into the destination.
                merged_ = torch.cat(partitions_, dim=partition_dim)
                merged_split = torch.split(merged_, merged.size(partition_dim),
                                           dim=partition_dim)
                merged_ = merged_split[0]
                assert merged_.size(partition_dim) == merged.size(partition_dim)
                merged.data.copy_(merged_.data)

    # If stride is 1, then do simple concatenation.
    if stride == 1:
        concat_partitions(partitions)
        return

    # For non-unity strides, first split based on stride and then group.
    per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
    # Chunk each partition and interleave the chunks round-robin so chunk
    # order matches the original (pre-split) layout.
    chunks = None
    for i, partition in enumerate(partitions):
        chunk = torch.split(partition,
                            per_partition_per_stride_size,
                            dim=partition_dim)

        if chunks is None:
            # Placeholder list; every slot is overwritten below.
            chunks = [0]*(num_partitions*len(chunk))
        chunks[i::num_partitions] = chunk

    # Concatenate.
    concat_partitions(chunks)

    return
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_model(model_type):
    """Build a half-precision model instance for ``model_type``.

    Args:
        model_type: one of 'BERT', 'GPT', 'RACE', 'MNLI', 'QQP'.

    Returns:
        The provider-built model converted to fp16 via ``.half()``.

    Raises:
        Exception: if ``model_type`` is not one of the supported types.
    """
    if model_type == 'BERT':
        from pretrain_bert import model_provider
    elif model_type == 'GPT':
        from pretrain_gpt import model_provider
    elif model_type == 'RACE':
        from tasks.race.finetune import model_provider
    elif model_type in ['MNLI', 'QQP']:
        # Bug fix: the original wrote `model_type == ['MNLI', 'QQP']`,
        # comparing a string against a list — always False — which made the
        # MNLI/QQP branch unreachable and sent those types to the error path.
        num_classes = 2
        if model_type == 'MNLI':
            num_classes = 3
        from megatron.model.classification import Classification
        def model_provider():
            return Classification(num_classes=num_classes, num_tokentypes=2)
    else:
        raise Exception('unrecognized model type: {}'.format(model_type))

    model = model_provider()
    model = model.half()

    return model
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_parallel_checkpoint_name(path):
    """Resolve the latest checkpoint filename and its iteration for ``path``.

    Reads the iteration number from the checkpoint tracker file and maps it
    to a concrete checkpoint name.
    """
    tracker_filename = get_checkpoint_tracker_filename(path)
    with open(tracker_filename, 'r') as tracker:
        iteration = int(tracker.read().strip())
    # A tracker that resolves to iteration 0 would mean no checkpoint
    # was ever written.
    assert iteration > 0
    return get_checkpoint_name(path, iteration), iteration
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def test_split_merge():
    """Round-trip sanity check: split a fixture tensor into strided
    partitions, merge them back, and print the maximum absolute
    reconstruction error (which should be exactly zero)."""

    print('testing split and merge ...')

    # Fixture laid out as [QKV.ROW-COL]: value g.rc encodes group g
    # (stride block), row r, column c, so interleaving mistakes are
    # visible in the merged result.
    #[QKV.ROW-COL]
    tensor = torch.FloatTensor([[1.11, 1.12, 1.13, 1.14, 1.15],
                                [1.21, 1.22, 1.23, 1.24, 1.25],
                                [1.31, 1.32, 1.33, 1.34, 1.35],
                                [1.41, 1.42, 1.43, 1.44, 1.45],
                                [2.11, 2.12, 2.13, 2.14, 2.15],
                                [2.21, 2.22, 2.23, 2.24, 2.25],
                                [2.31, 2.32, 2.33, 2.34, 2.35],
                                [2.41, 2.42, 2.43, 2.44, 2.45],
                                [3.11, 3.12, 3.13, 3.14, 3.15],
                                [3.21, 3.22, 3.23, 3.24, 3.25],
                                [3.31, 3.32, 3.33, 3.34, 3.35],
                                [3.41, 3.42, 3.43, 3.44, 3.45]])

    # Split row-wise (dim 0) into 2 partitions with an interleave stride
    # of 3, then merge back into a zero tensor of the same shape.
    num_partitions = 2
    partition_dim = 0
    stride = 3
    partitions = split_into_partitions(tensor, num_partitions,
                                       partition_dim, stride)

    merged = torch.zeros_like(tensor)
    merge_partitions(merged, partitions, partition_dim, stride)

    max_error = (merged - tensor).abs().max()
    print('  > max error (should be zero): {}'.format(max_error))
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def get_mp_merge_args(parser):
    """Provide extra command-line arguments required for merging.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with the 'mp merge' argument group added.
    """
    group = parser.add_argument_group(title='mp merge')

    group.add_argument('--model-type', type=str, required=True,
                       choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
                       # Fixed user-facing typo: "mdoel" -> "model".
                       help='Type of the model.')
    group.add_argument('--target-pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism in output model.')

    return parser
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def main():
    """Merge tensor-model-parallel checkpoint partitions into one model,
    then optionally re-split the merged model across pipeline stages and
    save the resulting checkpoint(s)."""

    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking we are plenty of processes
    os.environ["WORLD_SIZE"] = f'{2**31}'

    # Args
    set_global_variables(extra_args_provider=get_mp_merge_args,
                         args_defaults = {'use_cpu_initialization': True,
                                          'micro_batch_size': 1,
                                          'no_load_optim': True,
                                          'no_load_rng': True,
                                          'no_save_optim': True,
                                          'no_save_rng': True,
                                          'save_interval': 1})
    args = get_args()

    # Only tensor-parallel inputs are supported; bail out on pipeline-parallel
    # source checkpoints.
    if args.pipeline_model_parallel_size > 1:
        print("Checkpoints with pipeline model parallelism are not currently supported.")
        exit()

    model_type = args.model_type
    orig_tensor_model_parallel_size = args.tensor_model_parallel_size
    # Pretend TP=1 while constructing the full (merged) model.
    args.tensor_model_parallel_size = 1
    tokenizer = rebuild_tokenizer(args)

    print('\n merging model parallel partitions ...')
    print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
    print(' > checkpoint path: {}'.format(args.load))
    print(' > model parameters:')
    print(' number of tokens ................ {} '.format(
        tokenizer.vocab_size))
    print(' number of layers ................ {}'.format(args.num_layers))
    print(' hidden size ..................... {}'.format(args.hidden_size))
    print(' number of attention heads ....... {}'.format(
        args.num_attention_heads))
    print(' maximum position embeddings ..... {}'.format(
        args.max_position_embeddings))

    # Full model.
    print('> building the full model ...')
    mpu.initialize.set_tensor_model_parallel_world_size(1)
    mpu.initialize.set_tensor_model_parallel_rank(0)
    mpu.initialize.set_pipeline_model_parallel_world_size(1)
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    merged_model = get_model(model_type)

    # Build and load partitions: one model per original tensor-parallel rank.
    partitions = []
    iteration = 0
    args.tensor_model_parallel_size = orig_tensor_model_parallel_size
    tokenizer = rebuild_tokenizer(args)
    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
    for rank in range(args.tensor_model_parallel_size):
        # Reset these since load_checkpoint asserts they are 0, but we are loading
        # multiple checkpoints in the same process and they get set each time
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0

        mpu.initialize.set_tensor_model_parallel_rank(rank)
        checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
        model_ = get_model(model_type)
        print(f'> loading {checkpoint_name} ...')
        load_checkpoint(model_, None, None)
        print(f'> checkpoint version {get_checkpoint_version()}')
        partitions.append(model_)

    # Parameter generators so we can loop through them simultaneously.
    # Relies on named_parameters() yielding in the same order for the
    # merged model and every partition.
    merged_params_gen = merged_model.named_parameters()
    partitions_params_gen = [partition.named_parameters()
                             for partition in partitions]
    while True:
        try:

            # Get the params and check names.
            name, merged_param = next(merged_params_gen)
            print(' > working on {} ...'.format(name))
            print(' merged type: {}, size: {}'.format(
                merged_param.dtype, list(merged_param.size())))
            partitions_param = []
            for rank, partition_params_gen in enumerate(partitions_params_gen):
                partition_name, partition_param = next(partition_params_gen)
                assert partition_name == name
                partitions_param.append(partition_param)
                print(' partition {} type: {}, size: {}'.format(
                    rank, partition_param.dtype, list(partition_param.size())))

            # For the non-parallel parameters, simply copy the rank 0 values.
            if not hasattr(merged_param, 'tensor_model_parallel'):
                print(' none-parallel parameter, simple copy from rank 0')
                with torch.no_grad():
                    merged_param.data.copy_(partitions_param[0].data)
            # For parallel parameters, merge the values
            else:
                dim = merged_param.partition_dim
                stride = merged_param.partition_stride
                print(f' parallel parameter merge with stride {stride} along '
                      f'dimention {dim}')
                merge_partitions(merged_param,
                                 partitions_param,
                                 dim,
                                 stride)

        except StopIteration:
            # All parameter generators are exhausted: merging complete.
            break

    # Re-split the merged model across the requested pipeline stages.
    partitions = []
    args.tensor_model_parallel_size = 1
    args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size

    assert args.num_layers % args.pipeline_model_parallel_size == 0, \
        'num_layers must be divisible by target pipeline model parallel size'
    layers_per_part = args.num_layers // args.pipeline_model_parallel_size

    tokenizer = rebuild_tokenizer(args)
    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
    mpu.initialize.set_tensor_model_parallel_rank(0)
    mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)

    # regex to parse out layer number from param name
    layer_re = re.compile('layers\.([0-9]+)')

    if args.pipeline_model_parallel_size > 1:
        # Index merged parameters by name for lookup while copying into
        # each pipeline stage's model.
        merged_params = {}
        for name, merged_param in merged_model.named_parameters():
            merged_params[name] = merged_param

        for rank in range(args.pipeline_model_parallel_size):
            mpu.initialize.set_pipeline_model_parallel_rank(rank)
            model = get_model(model_type)
            def update_layer_num(m):
                # TODO! This assumes no interleaved pipeline execution
                # Offset this stage's local layer index by the number of
                # layers in the preceding stages.
                layer = int(m.group(1))
                layer += rank * layers_per_part
                return f'layers.{layer}'

            for dst_name, partition_param in model.named_parameters():
                if dst_name == "word_embeddings.weight":
                    # See comment in MegatronModule.initialize_word_embeddings()
                    src_name = "language_model.embedding.word_embeddings.weight"
                else:
                    # Translate destination layer number (0-N for each partition)
                    # to source layer number (single-model layer number)
                    src_name = re.sub(layer_re, update_layer_num, dst_name)
                print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
                partition_param.data.copy_(merged_params[src_name].data)

            partitions.append(model)
    else:
        # No pipeline split requested: save the merged model as-is.
        partitions = [merged_model]

    for rank, model in enumerate(partitions):
        mpu.initialize.set_pipeline_model_parallel_rank(rank)
        print(f"> saving rank {rank}'s model")
        save_checkpoint(iteration, model, None, None)

    print('done :-)')
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
if __name__ == '__main__':
    # Script entry point: run the merge when invoked directly.
    main()
|
tools/openwebtext/README.md
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
The following steps show how to prepare the training dataset used to train the model.
|
| 2 |
+
|
| 3 |
+
# Libraries to install
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract
|
| 7 |
+
git clone https://github.com/mattilyra/LSH
|
| 8 |
+
cd LSH
|
| 9 |
+
python setup.py install
|
| 10 |
+
```
|
| 11 |
+
|
| 12 |
+
# Download the dataset
|
| 13 |
+
|
| 14 |
+
1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ)
|
| 15 |
+
2. Remove blacklisted URLs.
|
| 16 |
+
```
|
| 17 |
+
python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for clean urls. e.g. clean_urls.txt>
|
| 18 |
+
```
|
| 19 |
+
3. Download the content from the clean urls with [openwebtext's utilities](https://github.com/eukaryote31/openwebtext/blob/master/download.py).
|
| 20 |
+
|
| 21 |
+
4. Merge the contents into one loose json file with 1 json per newline of the format `{'text': text, 'url': unique_url}`. It is important for the url to be unique.
|
| 22 |
+
|
| 23 |
+
# Prepare the data for GPT training:
|
| 24 |
+
|
| 25 |
+
1. Perform ftfy, english detection and remove documents with less than 128 tokens. This step can be sharded and run on shards.
|
| 26 |
+
```
|
| 27 |
+
python cleanup_dataset.py <input data file> <output cleaned data filename>
|
| 28 |
+
```
|
| 29 |
+
Additional cleanup (e.g. remove documents less than 512 characters or dataset specific cleaning like stories, realnews datasets) can be done using `cleanup_fix_dataset.py`. More details can be found by running `python cleanup_fix_dataset.py --help`.
|
| 30 |
+
2. Using LSH, find possible duplicates and store them in a file for later processing. The code supports saving and loading fingerprints for recurrent deduplications, and is also multithreaded for faster processing. More details can be found by running `python find_duplicates.py --help`.
|
| 31 |
+
```
|
| 32 |
+
python find_duplicates.py --inputs <pairlist list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
|
| 33 |
+
```
|
| 34 |
+
3. Based on the similarity measure defined inside the function `is_similar` (default: 0.9), group urls that are similar. For each group, we keep only one url and remove the rest.
|
| 35 |
+
```
|
| 36 |
+
python group_duplicate_urls.py <possible duplicate urls file> <output file containing similar urls>
|
| 37 |
+
```
|
| 38 |
+
4. Remove similar documents that were detected in the last step.
|
| 39 |
+
```
|
| 40 |
+
python remove_group_duplicates.py <file containing similar documents> <cleaned data file> <output file containing deduplicated data>
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
5. Shuffle the dataset.
|
| 44 |
+
```
|
| 45 |
+
shuf <cleaned deduped data file> -o train_data.json
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
# Deduplicating ngrams
|
| 49 |
+
|
| 50 |
+
To deduplicate the downstream tasks (e.g. lambada, squad) from the training dataset, we run the following command.
|
| 51 |
+
|
| 52 |
+
```
|
| 53 |
+
python filter_ngrams.py --tasks <name of the task, e.g. lambada, squad> --dedup-dataset <training dataset to deduplicate> <json key> --output <output training dataset>
|
| 54 |
+
```
|
| 55 |
+
We use 13-grams by default for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from both sides of the 13-gram. We also remove any split document with fewer than 200 characters, or any document that was split more than 10 times. These parameters can be changed using the corresponding arguments.
|
| 56 |
+
|
| 57 |
+
Only for the lambada task, we need to provide the path, `--lambada-path <path of the lambada test data>`.
|
| 58 |
+
|
| 59 |
+
Several other features (e.g. save and load dictionary) have been added, look at `python filter_ngrams.py --help` for details.
|
tools/openwebtext/add_id.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import time
|
| 20 |
+
|
| 21 |
+
"""
|
| 22 |
+
This code adds id to each json object in a json file. User can add prefix
|
| 23 |
+
to the ids.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
if __name__ == '__main__':

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-file', type=str, default=None, help='Input'\
                        ' json file where id needs to be added')
    parser.add_argument('--output-file', type=str, default=None, help=\
                        'Output file name with id')
    parser.add_argument('--id-prefix', type=str, default=None, help=\
                        'Id prefix')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Log interval')
    args = parser.parse_args()

    print('Adding ids to dataset ...')

    start_time = time.time()
    # Context managers guarantee both files are closed even when a row
    # fails to parse; the original left them open on any exception.
    with open(args.input_file, 'r', encoding='utf-8') as f_input, \
         open(args.output_file, 'wb') as f_output:
        # enumerate(..., 1) replaces the manual unique_ids counter; ids are
        # 1-based as in the original.
        for unique_ids, row in enumerate(f_input, 1):
            each_row = json.loads(row)
            # Zero-padded, prefix-qualified id, e.g. 'cc-0000000001'.
            # NOTE(review): --id-prefix has no default; a missing prefix
            # raises TypeError here, exactly as before -- confirm callers
            # always pass it.
            adlr_id_string = args.id_prefix + '-{:010d}'.format(unique_ids)
            each_row['adlr_id'] = adlr_id_string
            myjson = json.dumps(each_row, ensure_ascii=False)

            f_output.write(myjson.encode('utf-8'))
            f_output.write('\n'.encode('utf-8'))

            if unique_ids % args.log_interval == 0:
                print(' processed {:9d} documents in {:.2f} seconds ...'.format(
                    unique_ids, time.time() - start_time), flush=True)

    print('done :-)', flush=True)
|
tools/openwebtext/blacklist_urls.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import glob
|
| 18 |
+
import re
|
| 19 |
+
import time
|
| 20 |
+
import tldextract
|
| 21 |
+
import sys
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# List of the domains to blacklist. Entries are bare registered-domain
# names without a TLD: a url is rejected when the `.domain` component
# returned by tldextract matches one of these (see domain_is_in_blacklist
# below). Kept alphabetically sorted.
domain_blacklist = set([
    '500px',
    'aapks',
    'akamaihd',
    'amazon',
    'apple',
    'artifactfire',
    'artstation',
    'awwni',
    'bandcamp',
    'battleforthenet',
    'coinscalendar',
    'dailymotion',
    'deviantart',
    'discord',
    'discordapp',
    'dlapkandroid',
    'dropbox',
    'e621',
    'ebay',
    'edealinfo',
    'erome',
    'eroshare',
    'explosm',
    'facebook',
    'fbcdn',
    'flickr',
    'furaffinity',
    'futhead',
    'gatopardo',
    'gfycat',
    'gifsound',
    'gifsoup',
    'giphy',
    'github',
    'google',
    'gunprime',
    'gyazo',
    'hotdealstar',
    'imagefap',
    'imageshack',
    'imgflip',
    'imgur',
    'instagram',
    'karmadecay',
    'kryptocal',
    'kym-cdn',
    'liveleak',
    'livememe',
    'lmgtfy',
    'magaimg',
    'memegenerator',
    'minorplanetcenter',
    'minus',
    'mobafire',
    'morejpeg',
    'nocookie',
    'pcpartpicker',
    'photobucket',
    'pinimg',
    'pinterest',
    'pixiv',
    'pornhub',
    'prntscr',
    'puu',
    'qkme',
    'quickmeme',
    'radd',
    'redd',
    'reddit',
    'reddit-stream',
    'redditlog',
    'redditmedia',
    'reddituploads',
    'redtube',
    'reupp',
    'reverb',
    'roanoke',
    'rollingstone',
    'sli',
    'soundcloud',
    'soundgasm',
    'spankbang',
    'spotify',
    'strawpoll',
    'streamable',
    'timeanddate',
    'tinypic',
    'touhouradio',
    'tumblr',
    'twimg',
    'twitch',
    'twitter',
    'vid',
    'vimeo',
    'vine',
    'vkaao',
    'vocaroo',
    'voyagefusion',
    'walmart',
    'wciu',
    'wikimedia',
    'wikipedia',
    'xhamster',
    'xkcd',
    'xvideos',
    'youtu',
    'youtube',
    'youtubedoubler',
    'ytimg',
    'zillexplorer',
])
|
| 137 |
+
|
| 138 |
+
def domain_is_in_blacklist(url):
    """Return True when *url*'s registered domain is on the blacklist."""
    return tldextract.extract(url).domain in domain_blacklist
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# List of extensions to blacklist (the historical 'extentions_blacklist'
# spelling is kept because it is referenced elsewhere in this file).
# Matching is done on the lowercased url path, so entries must be lowercase.
extentions_blacklist = (
    '.3gp',
    # BUG FIX: the original was missing the comma after '.7z', so Python's
    # implicit string concatenation fused '.7z' and '.ai' into the single
    # dead entry '.7z.ai' and neither extension was ever blacklisted.
    '.7z',
    '.ai',
    '.aif',
    '.apk',
    '.app',
    '.avi',
    '.bin',
    '.bmp',
    '.bz2',
    '.css',
    '.csv',
    '.dat',
    '.deb',
    '.dmg',
    '.doc',
    '.docx',
    '.exe',
    '.gif',
    '.gifv',
    '.gz',
    '.iso',
    '.jar',
    '.jpeg',
    '.jpg',
    '.js',
    '.log',
    '.mid',
    '.midi',
    '.mkv',
    '.mov',
    '.mp3',
    '.mp4',
    '.mpeg',
    '.mpg',
    '.ogg',
    '.ogv',
    '.otf',
    '.pdf',
    '.pkg',
    '.png',
    '.pps',
    '.ppt',
    '.pptx',
    '.psd',
    '.py',
    '.qt',
    '.ram',
    '.rar',
    '.sql',
    '.svg',
    '.swf',
    '.tar.gz',
    '.tar',
    '.tgz',
    '.tiff',
    '.ttf',
    '.txt',
    '.wav',
    '.webm',
    '.wma',
    '.wmv',
    '.xls',
    '.xlsx',
    '.xml',
    '.xz',
    '.zip',
)

def extention_is_in_blacklist(url):
    """Return True when the url's path ends with a blacklisted extension.

    The query string (anything after the first '?') is ignored and the
    comparison is case-insensitive.
    """
    if url.split('?')[0].lower().endswith(extentions_blacklist):
        return True
    return False
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# Malformed urls.
# This regex is adapted from:
# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
url_regex = re.compile(
    r'^(?:http)s?://' # http:// or https://
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
    r'(?::\d+)?' # optional port
    r'(?:/?|[/?]\S+)$', re.IGNORECASE)


def url_is_malformed(url):
    """Return True when *url* does not look like a well-formed http(s) url."""
    return url_regex.match(url) is None
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def print_progress(prefix, start_time, urls_counter,
                   domain_blacklist_counter,
                   extention_blacklist_counter,
                   short_url_counter, malformed_url_counter,
                   duplicate_url_counter):
    """Print one ' | '-separated status line of url-filtering statistics."""
    fields = [
        prefix,
        'time elapsed (s): {:.2f}'.format(time.time() - start_time),
        'number of urls: {}'.format(urls_counter),
        'domain blacklisted: {}'.format(domain_blacklist_counter),
        'extention blacklisted: {}'.format(extention_blacklist_counter),
        'short urls (<=8): {}'.format(short_url_counter),
        'malformed urls: {}'.format(malformed_url_counter),
        'duplicate urls: {}'.format(duplicate_url_counter),
    ]
    print(' | '.join(fields), flush=True)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
if __name__ == '__main__':


    print('remove blacklisted urls ..')

    # argv[1]: directory containing url list files; argv[2]: output file.
    path = sys.argv[1]
    output = sys.argv[2]

    # Every *.txt file under the input directory is a list of urls.
    files = glob.glob(path + '/*.txt')
    print('> found {} files'.format(len(files)))

    kept_urls = set()
    n_urls = 0
    n_domain_blacklist = 0
    n_extention_blacklist = 0
    n_short = 0
    n_malformed = 0
    n_duplicate = 0
    start_time = time.time()
    for url_file in files:
        with open(url_file, 'r') as f:
            for line in f:
                url = line.strip()
                n_urls += 1
                # Classify the url; the first matching filter wins and the
                # url is dropped. Only the final 'else' keeps it.
                if domain_is_in_blacklist(url):
                    print('[DOMAIN BLACKLIST]: {}'.format(url), flush=True)
                    n_domain_blacklist += 1
                elif extention_is_in_blacklist(url):
                    print('[EXTENTION BLACKLIST]: {}'.format(url), flush=True)
                    n_extention_blacklist += 1
                elif len(url) <= 8:
                    print('[SHORT URL]: {}'.format(url), flush=True)
                    n_short += 1
                elif url_is_malformed(url):
                    print('[MALFORMED URL]: {}'.format(url), flush=True)
                    n_malformed += 1
                elif url in kept_urls:
                    print('[DUPLICATE URL]: {}'.format(url), flush=True)
                    n_duplicate += 1
                else:
                    kept_urls.add(url)
                if n_urls % 100000 == 0:
                    print_progress('PROGRESS', start_time, n_urls,
                                   n_domain_blacklist,
                                   n_extention_blacklist,
                                   n_short, n_malformed,
                                   n_duplicate)

    print_progress('FINAL', start_time, n_urls,
                   n_domain_blacklist,
                   n_extention_blacklist,
                   n_short, n_malformed,
                   n_duplicate)

    # Dump the surviving unique urls, one per line.
    print('> writing cleaned up url list to {}'.format(output))
    with open(output, 'w') as f:
        for url in kept_urls:
            f.write(url + '\n')

    print('done :-)')
|
tools/openwebtext/cleanup_dataset.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import ftfy
|
| 18 |
+
import json
|
| 19 |
+
from langdetect import detect
|
| 20 |
+
import numpy as np
|
| 21 |
+
import time
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
|
| 25 |
+
from tokenizer import Tokenizer
|
| 26 |
+
|
| 27 |
+
MIN_DOCUMENT_LENGHT = 128
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def print_progress(prefix, start_time, num_docs, num_fixed_text,
                   num_non_english_docs, chars_non_english_docs,
                   num_small_docs, chars_small_docs):
    """Print one ' | '-separated status line of the filtering counters."""
    parts = [
        prefix,
        'elapsed time: {:.2f}'.format(time.time() - start_time),
        'documents: {}'.format(num_docs),
        'fixed text: {}'.format(num_fixed_text),
        'non-english: {}'.format(num_non_english_docs),
        'non-english chars: {}'.format(chars_non_english_docs),
        'small docs: {}'.format(num_small_docs),
        'small docs chars: {}'.format(chars_small_docs),
    ]
    print(' | '.join(parts), flush=True)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def filter_corpus(filename, out_filename, print_interval=10000):
    """Read json documents (one per line) from *filename*, clean and filter
    them, and write the survivors to *out_filename* as utf-8 json lines.

    Per document: text is repaired with ftfy; non-english documents
    (per langdetect) are dropped; documents that tokenize to fewer than
    MIN_DOCUMENT_LENGHT tokens are dropped. A progress line is printed
    every *print_interval* documents, and a final summary at the end.
    Any exception for a single line is logged and that line is skipped.
    """

    print(' > filtering {}'.format(filename))

    # NOTE(review): Tokenizer is project-local; presumably it caches model
    # files under ./cache -- confirm against the tokenizer module.
    tokenizer = Tokenizer(cache_dir='./cache')

    # Counters reported by print_progress.
    num_docs = 0
    num_written_docs = 0
    num_small_docs = 0
    num_fixed_text = 0
    num_non_english_docs = 0
    chars_non_english_docs = 0
    chars_small_docs = 0
    start_time = time.time()
    with open(out_filename, 'wb') as f:
        with open(filename, 'r') as fin:
            for line in fin:
                # The broad try makes the pass best-effort: any failure
                # (bad json, missing 'text', langdetect error, ...) skips
                # just this line. num_docs is incremented inside the try,
                # so it counts attempted -- not necessarily parsed -- docs.
                try:
                    num_docs += 1
                    myjson = json.loads(line)
                    # Fix text encoding problems (mojibake) with ftfy.
                    text = ftfy.fix_text(myjson['text'])
                    if text != myjson['text']:
                        num_fixed_text += 1
                        myjson['text'] = text
                    # Detect language; keep english documents only.
                    if detect(text) != 'en':
                        print('[non-english text]', myjson)
                        num_non_english_docs += 1
                        chars_non_english_docs += len(text)
                        continue
                    # On average each token is 5 characters so 8 is an
                    # upper bound. Only tokenize (expensive) when the char
                    # count alone cannot prove the document is long enough.
                    if len(text) < (8 * MIN_DOCUMENT_LENGHT):
                        tokens = tokenizer.tokenize_document(text)
                        if len(tokens) < MIN_DOCUMENT_LENGHT:
                            print('[small document, skipping]:', myjson)
                            num_small_docs += 1
                            chars_small_docs += len(text)
                            continue
                    # Survivor: re-serialize (with possibly fixed text) and
                    # write as one utf-8 encoded json line.
                    myjson = json.dumps(myjson, ensure_ascii=False)
                    f.write(myjson.encode('utf-8'))
                    f.write('\n'.encode('utf-8'))
                    num_written_docs += 1
                    if num_docs % print_interval == 0:
                        print_progress('[PROGRESS]', start_time, num_docs,
                                       num_fixed_text, num_non_english_docs,
                                       chars_non_english_docs,
                                       num_small_docs, chars_small_docs)
                except Exception as e:
                    print(' skipping ', line, e)

    print_progress('[FINAL]', start_time, num_docs,
                   num_fixed_text, num_non_english_docs,
                   chars_non_english_docs,
                   num_small_docs, chars_small_docs)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == '__main__':

    print('building gpt2 dataset ...')

    # Positional arguments: <input json file> <output json file>.
    in_path, out_path = sys.argv[1], sys.argv[2]

    print('will be reading {}'.format(in_path))
    print('and will write the results to {}'.format(out_path))

    filter_corpus(in_path, out_path)
|
| 114 |
+
|
| 115 |
+
|
tools/openwebtext/cleanup_fix_dataset.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
Filter and clean documents:
|
| 18 |
+
Capable to clean docs with less than 512 characters, less than
|
| 19 |
+
256 characters and contains javascript, fix text and dataset specific
|
| 20 |
+
cleaning like stories and realnews datasets.
|
| 21 |
+
Program arguments have the details.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
from functools import partial
|
| 26 |
+
import glob
|
| 27 |
+
import ftfy
|
| 28 |
+
import json
|
| 29 |
+
from langdetect import detect
|
| 30 |
+
import multiprocessing
|
| 31 |
+
import os
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
import re
|
| 34 |
+
import time
|
| 35 |
+
|
| 36 |
+
def process_doc(json_line, args):
    """Apply the requested cleaning/filtering tasks to one json document.

    Returns a 4-tuple ``(flags, text, document, filter_out)``: ``flags``
    records which task (if any) fired, ``text`` is the possibly-fixed
    document text, ``document`` is the parsed json object, and
    ``filter_out`` tells the caller whether the document should go to the
    filtered output rather than the cleaned one. At most one task fires
    per document; the check order below is the precedence.
    """

    document = json.loads(json_line)
    text = document['text']

    # One flag per supported task.
    flags = {
        'remove_512': False,
        'remove_256_javascript': False,
        'remove_512_non_english': False,
        'ftfy_fix_text': False,
        'general_cleaning': False,
    }

    try:
        # Drop documents shorter than 512 characters.
        if "remove_512" in args.tasks and len(text) < 512:
            flags['remove_512'] = True
            return flags, text, document, True

        # Drop short (<256 chars) documents that mention javascript.
        if "remove_256_javascript" in args.tasks \
                and len(text) < 256 and 'javascript' in text.lower():
            flags['remove_256_javascript'] = True
            return flags, text, document, True

        # Drop short (<512 chars) non-english documents.
        if "remove_512_non_english" in args.tasks \
                and len(text) < 512 and detect(text) != 'en':
            flags['remove_512_non_english'] = True
            return flags, text, document, True

        # Repair mojibake with ftfy; the document is kept (filter False).
        if "ftfy_fix_text" in args.tasks:
            flags['ftfy_fix_text'] = True
            return flags, ftfy.fix_text(text), document, False

        # Collapse runs of spaces/newlines; the document is kept.
        if "general_cleaning" in args.tasks:
            flags['general_cleaning'] = True
            cleaned = re.sub(r" +|\b\n+ |\b\n+", " ", text)
            return flags, cleaned, document, False

    except Exception as e:
        # Best effort: log and route the offending document to the
        # filtered output.
        print('Error: *************************\n{}\ntext: {}'.format(e, \
            text), flush=True)
        return flags, text, document, True

    # No task matched: keep the document unchanged.
    return flags, text, document, False
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def process_set(args, input_file, output_f_cleaned, output_f_filtered):
    """Clean/filter every document of *input_file* using a worker pool.

    Each json line is pushed through process_doc; documents flagged for
    removal are written to *output_f_filtered*, everything else (possibly
    with fixed text) to *output_f_cleaned*. Prints periodic progress and a
    final per-task summary.
    """

    print(' > working on {} ...'.format(input_file), flush=True)

    # Per-task counters for the final statistics line.
    num_docs = num_remove_512 = num_remove_java = num_remove_512_non_english \
        = num_ftfy_fix_text = num_general_cleaning = 0

    start_time = time.time()

    # Degree of parallelism and the per-worker chunk size for imap.
    num_workers = 40
    process_doc_partial = partial(process_doc, args=args)

    # Context managers close both output files and the input file, and shut
    # the worker pool down, even when an exception interrupts the loop.
    # (The original never closed/joined the pool -- leaking worker
    # processes -- and leaked the file handles on error.)
    with open(input_file, 'r', encoding='utf-8') as fin, \
         open(output_f_cleaned, 'wb') as output_cleaned, \
         open(output_f_filtered, 'wb') as output_filtered, \
         multiprocessing.Pool(num_workers) as pool:

        processed_docs = pool.imap(process_doc_partial, fin, 500)

        # Route each processed document to the proper output file.
        for output, text, document, to_filter in processed_docs:
            num_docs += 1

            num_remove_512 += 1 if output['remove_512'] else 0
            num_remove_java += 1 if output['remove_256_javascript'] else 0
            num_remove_512_non_english += 1 \
                if output['remove_512_non_english'] else 0
            num_ftfy_fix_text += 1 if output['ftfy_fix_text'] else 0
            num_general_cleaning += 1 if output['general_cleaning'] else 0

            # Re-serialize with the (possibly fixed) text.
            document['text'] = text
            myjson = json.dumps(document, ensure_ascii=False)

            target = output_filtered if to_filter else output_cleaned
            target.write(myjson.encode('utf-8'))
            target.write('\n'.encode('utf-8'))

            if num_docs % args.log_interval == 0:
                print(' processed {:9d} documents in {:.2f} seconds ...'.format(
                    num_docs, time.time() - start_time), flush=True)

    # Print stats.
    print(' >> total docs: {} remove_512 {} remove_256_javascript {} '\
          'remove_512_non_english {} ftfy_fix_text {} general_cleaning {}'.\
          format(num_docs, num_remove_512, num_remove_java,\
                 num_remove_512_non_english, num_ftfy_fix_text, \
                 num_general_cleaning), flush=True)
|
| 154 |
+
|
| 155 |
+
if __name__ == '__main__':

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-files', nargs='*', required=True,
                        default=None,
                        help='Input json files that needs to be'
                             ' cleaned')
    parser.add_argument('--tasks', nargs='*', required=True, default=None,
                        help='Tasks to perform on the input files, '
                             'such as remove_512, remove_256_javascript, '
                             'remove_512_non_english, ftfy_fix_text, and '
                             'general_cleaning. 256 or 512 means the number'
                             ' of characters.')
    parser.add_argument('--output-path', type=str, default=None,
                        help='Directory where the output should go')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Log interval')

    args = parser.parse_args()

    print('cleanup dataset ...')

    for input_file in args.input_files:
        # Derive '<stem>_cleaned<ext>' / '<stem>_filtered<ext>' output
        # names inside the requested output directory.
        stem, ext = os.path.splitext(Path(input_file).name)

        output_f_cleaned = os.path.join(args.output_path,
                                        stem + "_cleaned" + ext)
        output_f_filtered = os.path.join(args.output_path,
                                         stem + "_filtered" + ext)

        process_set(args, input_file, output_f_cleaned, output_f_filtered)

    print('done :-)', flush=True)
|
tools/openwebtext/filter_ngrams.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
Deduplicate downstream tasks from training dataset. 13-grams have been used.
|
| 18 |
+
All split documents with less than 200 characters got filtered. Any document
|
| 19 |
+
with more than 10 splits got filtered as well.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
from functools import partial
|
| 24 |
+
import json
|
| 25 |
+
import multiprocessing
|
| 26 |
+
import nltk
|
| 27 |
+
import pickle
|
| 28 |
+
import re
|
| 29 |
+
import string
|
| 30 |
+
import sys
|
| 31 |
+
import time
|
| 32 |
+
|
| 33 |
+
def get_words(text):
    """Tokenize *text* into lowercase words.

    Returns a pair ``(words, positions)`` where ``words`` holds every
    ``\\w+`` token of the lowercased text and ``positions`` the character
    offset at which each token starts.
    """
    matches = list(re.finditer(r'\w+', text.lower()))
    return [m.group(0) for m in matches], [m.start() for m in matches]
|
| 40 |
+
|
| 41 |
+
# splits the text
|
| 42 |
+
def split_text(text, start_position, remove_char_each_side, seq):
    """Split *text* around the ngram *seq* that starts at *start_position*.

    A margin of *remove_char_each_side* characters is discarded on both
    sides of the match, then each cut is extended to the nearest
    sentence-ending punctuation.  Returns the (possibly empty) prefix and
    suffix surrounding the removed region.
    """
    sentence_ends = ".!?"

    # Walk left from the margin until a sentence boundary (or text start).
    left = start_position - remove_char_each_side
    while left > 0 and text[left] not in sentence_ends:
        left -= 1
    prefix = text[:left + 1] if left > 0 else ""

    # Walk right from the margin until a sentence boundary (or text end).
    right = start_position + len(seq) + remove_char_each_side
    while right < len(text) and text[right] not in sentence_ends:
        right += 1
    suffix = text[right + 1:] if right + 1 < len(text) else ""

    return prefix, suffix
|
| 63 |
+
|
| 64 |
+
def check_and_clean_text(args, words, ngrams, text, start_position, \
        text_buf_ngram_free, text_buf, local_ngram):
    """Check one candidate ngram of *text* against the *ngrams* dictionary.

    Returns True when the joined *words* sequence is ngram free.  On a hit
    it returns False after either bumping the local frequency counter
    (``args.get_ngram_freq_only``) or splitting the document around the
    match: the clean prefix goes to *text_buf_ngram_free*, the still
    unchecked suffix back onto *text_buf*.
    """
    seq = " ".join(words)
    if seq not in ngrams:
        # No match: the caller may keep scanning.
        return True

    print(" [matched]: {}".format(seq), flush=True)

    if args.get_ngram_freq_only:
        # Statistics pass: count the hit and only keep the tail of the
        # document for further matching.
        local_ngram[seq] = local_ngram.get(seq, 0) + 1
        tail_start = start_position + len(seq) + 1
        if tail_start < len(text):
            text_buf.append(text[tail_start:])
        return False

    # Filtering pass: cut the document around the matched ngram.
    head, tail = split_text(text, start_position,
                            args.remove_char_each_side, seq)

    # The prefix is guaranteed ngram free at this point.
    if len(head) > args.filter_text_char_len:
        text_buf_ngram_free.append(head)
    # The suffix still needs to be scanned.
    if len(tail) > args.filter_text_char_len:
        text_buf.append(tail)

    return False
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
    """Strip every known ngram from one jsonl document *line*.

    Pops text chunks off a work buffer; for each chunk, checks every
    max-size ngram window and (on a miss) every lower size listed in
    *ngrams_freq_sorted*.  check_and_clean_text either counts hits
    (args.get_ngram_freq_only) or splits the chunk, feeding the remainder
    back into the buffer.  Returns
    (clean_chunks, trimmed_flag, parsed_json, local_ngram_counts).
    """
    # remove all the ngrams

    try:
        myjson = json.loads(line)
        text_buf = [myjson[key]]
    except Exception as e:
        # NOTE(review): on a parse failure myjson stays unbound, so the
        # final `return ... myjson ...` would raise NameError — confirm
        # malformed lines cannot occur upstream.
        print("Error: {}".format(e), flush=True)
        text_buf = []

    text_buf_ngram_free = []
    local_ngram = {}
    while len(text_buf) > 0:

        # get the first one from the buffer
        text = text_buf.pop(0)
        words, positions = get_words(text)

        ngram_free = True
        # find each max n-grams and check dictionary
        for i in range(len(words) - args.max_ngram_size + 1):
            check_ngram_free = check_and_clean_text(args, words[i:\
                i+args.max_ngram_size], ngrams, text, positions[i], \
                text_buf_ngram_free, text_buf, local_ngram)

            # the seq is ngram free? if yes, break
            if not check_ngram_free:
                ngram_free = False
                break

            # if max ngrams doesn't match, check if any other lower n-grams
            # within max ngram macthes
            for ngram_len, _ in ngrams_freq_sorted:
                check_ngram_free = check_and_clean_text(args, words[i:\
                    i+ngram_len], ngrams, text, positions[i], \
                    text_buf_ngram_free, text_buf, local_ngram)

                # same check as above
                if not check_ngram_free:
                    ngram_free = False
                    break

            # check break from lower than max ngram loop above
            if not ngram_free:
                break

        # for the last max n-gram, check all the lower ngrams in it
        if ngram_free and len(words) - args.max_ngram_size > 0:
            # get the last words of the lax max ngram
            last_seq_words = words[(len(words)-args.max_ngram_size):len(words)]
            last_seq_start_position = len(words) - args.max_ngram_size

            # check all n-grams lower than the max
            for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):

                # ignore the max ngram as has been considered already
                if ngram_len == args.max_ngram_size:
                    continue

                # find each ngram of ngram_len in max n-grams and check
                for i in range(len(last_seq_words) - ngram_len + 1):
                    check_ngram_free = check_and_clean_text(args, \
                        last_seq_words[i:i+ngram_len], ngrams, text,\
                        positions[last_seq_start_position+i], \
                        text_buf_ngram_free, text_buf, local_ngram)

                    if not check_ngram_free:
                        ngram_free = False
                        break

                if not ngram_free:
                    break

        # texts are ngram free
        if ngram_free and not args.get_ngram_freq_only:
            text_buf_ngram_free.append(text)

    # check if the text has only been trimmed (one clean piece, shorter
    # than the original document)
    trimmed = 0
    if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
        len(text_buf_ngram_free[0]) < len(myjson[key]):
        trimmed = 1

    return text_buf_ngram_free, trimmed, myjson, local_ngram
|
| 185 |
+
|
| 186 |
+
# insert word sequence into dictionary
|
| 187 |
+
def insert_dict(words, ngrams, pos):
    """Register the word sequence *words* in *ngrams* with a zero count.

    *pos* is accepted for interface compatibility but unused.
    """
    key = " ".join(words)
    ngrams.setdefault(key, 0)
|
| 192 |
+
|
| 193 |
+
# insert each ngram from text into the ngrams dictionary
|
| 194 |
+
def compute_ngrams_insert_dict(args, text, ngrams):
    """Add every max-size ngram of *text* to the *ngrams* dictionary.

    Texts with fewer than args.min_ngram_size words are ignored; texts
    shorter than args.max_ngram_size are inserted whole as one (shorter)
    ngram.
    """
    words, positions = get_words(text)
    num_words = len(words)
    if num_words < args.min_ngram_size:
        return

    if num_words < args.max_ngram_size:
        # Too short for a full-size ngram: keep the whole text as one entry.
        insert_dict(words, ngrams, positions[0])

    size = args.max_ngram_size
    for start in range(num_words - size + 1):
        insert_dict(words[start:start + size], ngrams, positions[start])
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# Build ngrams for the lambada dataset
|
| 207 |
+
def process_task_lambda(args, task_file, ngrams):
    """Read the lambada jsonl file *task_file* and add its ngrams to *ngrams*."""
    print(' reading from {} and computing ngrams'.format(task_file))
    with open(task_file, 'r') as handle:
        for line in handle:
            try:
                record = json.loads(line)
                compute_ngrams_insert_dict(args, record['text'], ngrams)
            except Exception as e:
                # Best effort: skip unparsable lines.
                print('Error:', e)
    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# Build ngrams for the dataset of the given task
|
| 221 |
+
def process_task(args, task_name, ngrams):
    """Load the eval split of *task_name* via HuggingFace ``datasets`` and
    add the ngrams of its questions/prompts to *ngrams*."""

    print(' reading from {} and computing ngrams'.format('import datasets'))
    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
    # using validation/test data from datasets
    from datasets import load_dataset

    entities_in_ngrams = len(ngrams)

    # Map each supported task to its load_dataset positional args and split.
    task_specs = {
        'squad': (('squad_v2',), 'validation'),
        'natural_questions': (('natural_questions',), 'validation'),
        'triviaqa': (('trivia_qa', 'unfiltered'), 'test'),
        'webqa': (('web_questions',), 'test'),
        'race': (('race', 'all'), 'test'),
        'drop': (('drop',), 'validation'),
        'coqa': (('coqa',), 'validation'),
        'piqa': (('piqa',), 'test'),
    }
    if task_name not in task_specs:
        print("Invalid task name: {}".format(task_name), flush=True)
        return
    load_args, split = task_specs[task_name]
    dataset = load_dataset(*load_args, split=split)

    # read the dataset and add to ngrams
    for line in dataset:
        try:
            if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']:
                compute_ngrams_insert_dict(args, line['question'], ngrams)
            elif task_name == 'natural_questions':
                compute_ngrams_insert_dict(args, line['question']['text'],
                                           ngrams)
            elif task_name == 'coqa':
                # coqa carries a list of questions per document.
                for question in line['questions']:
                    compute_ngrams_insert_dict(args, question, ngrams)
            elif task_name == 'piqa':
                compute_ngrams_insert_dict(args, line['goal'], ngrams)
        except Exception as e:
            print('Error:', e)

    print(" After task {} entities in ngrams {}, added {}".format(task_name, \
        len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)
|
| 272 |
+
|
| 273 |
+
def compute_tasks_ngrams(args, ngrams):
    """Build the ngram dictionary over every task listed in args.tasks."""
    start_time = time.time()
    for task_name in args.tasks:
        print('Task: {}'.format(task_name), flush=True)
        if task_name == 'lambada':
            # Lambada is read from a local file rather than HF datasets.
            assert args.lambada_path is not None
            process_task_lambda(args, args.lambada_path, ngrams)
        else:
            process_task(args, task_name, ngrams)
    print(" Taken time to compute ngrams {:.2f}".format(time.time() - \
        start_time), flush=True)
|
| 284 |
+
|
| 285 |
+
def compute_ngram_freq_sorted(args, ngrams):
    """Return [(ngram_word_count, frequency), ...] sorted by word count.

    *args* is accepted for interface compatibility but unused.
    """
    length_counts = {}
    for key in ngrams:
        n_words = len(key.split())
        length_counts[n_words] = length_counts.get(n_words, 0) + 1

    freq_sorted = sorted(length_counts.items(), key=lambda item: item[0])
    print(" Ngram frequencies: {}".format(freq_sorted), flush=True)
    print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
        len(ngrams), freq_sorted[0][0], freq_sorted[-1][0]), flush=True)
    return freq_sorted
|
| 298 |
+
|
| 299 |
+
def get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
    dedup_file, dedup_key, ngrams_freq_sorted):
    """Count task-ngram occurrences in *dedup_file* and collect rare ones.

    First pass over the training data (in parallel) increments the count
    of every task ngram found; every ngram seen fewer than
    args.key_threshold times is then copied into *ngrams_below_threshold*.
    Mutates args.get_ngram_freq_only as a mode flag for free_ngram.
    """

    start_time = time.time()
    # get the ngrams frequency (statistics-only mode for free_ngram)
    args.get_ngram_freq_only = True

    # Open the large file to process in parallel
    num_workers = args.num_threads
    pool = multiprocessing.Pool(num_workers)
    fin = open(dedup_file, 'r', encoding='utf-8')
    # Bind everything except the jsonl line, which imap supplies per item.
    free_ngram_abt_partial=partial(free_ngram, args=args, key=dedup_key, \
        ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted)
    free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)

    counter = 0
    for _, _, _, local_ngram in free_ngrams_abt:
        counter += 1
        if counter % 1000 == 0:
            print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
                format(counter, time.time() - start_time), flush=True)
        # Fold each worker's local counts into the shared dictionary.
        for local_key in local_ngram:
            if local_key in ngrams:
                ngrams[local_key] += 1
        local_ngram = {}

    print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \
        start_time), flush=True)
    pool.close()
    pool.join()

    start_time = time.time()
    counter_threshold = 0
    # Get ngram below theadhold
    for local_key, local_val in ngrams.items():
        if ngrams[local_key] < args.key_threshold:
            print(" [threshold] {} {}".format(local_key, local_val), flush=True)
            counter_threshold += 1
            ngrams_below_threshold[local_key] = 1

    print(' Ngrams below threshold {}'.format(counter_threshold), flush=True)
    fin.close()
|
| 341 |
+
|
| 342 |
+
def clean_ngrams_below_threshold(args, ngrams_below_threshold, dedup_file, \
    dedup_key):
    """Filter *dedup_file*, removing every ngram in *ngrams_below_threshold*.

    Runs free_ngram over the file in parallel, then writes the surviving
    (possibly split) documents to args.output as jsonl, tagging each piece
    with a unique "split_id".  Documents split into more than
    args.splits_count pieces are dropped entirely.
    """

    start_time = time.time()
    # Now actually filter the dataset (not just count frequencies)
    args.get_ngram_freq_only = False
    #id_prefix = '-'.join(args.tasks[::2])
    id_prefix = '-'.join(args.tasks[::1])

    # get the range of the size of the ngrams
    ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_below_threshold)

    # Open the large file to process in parallel
    counter = splitted = ignored = split_mt_thld = trimmed_count = 0
    num_workers = args.num_threads
    pool = multiprocessing.Pool(num_workers)
    fin = open(dedup_file, 'r', encoding='utf-8')
    free_ngram_clean_partial=partial(free_ngram, args=args, key=dedup_key, \
        ngrams=ngrams_below_threshold, ngrams_freq_sorted=ngrams_freq_sorted)
    free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)

    out_f = open(args.output, 'wb')

    for text_buf_ngram_free, trimmed, myjson, _ in free_ngrams_clean:
        counter += 1
        try:

            trimmed_count += trimmed

            if len(text_buf_ngram_free) > 1:
                splitted += 1
            if len(text_buf_ngram_free) == 0:
                ignored += 1
            # more than 10 splits ignored (document dropped entirely)
            if len(text_buf_ngram_free) > args.splits_count:
                text_buf_ngram_free = []
                split_mt_thld += 1

            if args.output is not None:
                # Preserve any split_id assigned by an earlier pass.
                if "split_id" in myjson:
                    use_prefix = myjson["split_id"] + "-"
                else:
                    use_prefix = ""

                for i in range(len(text_buf_ngram_free)):
                    # split_id encodes task list + document counter + piece.
                    split_id_string = id_prefix + '-{:010d}'.format(int(\
                        counter)) + '-{:04d}'.format(int(i))
                    myjson[dedup_key] = text_buf_ngram_free[i]
                    myjson["split_id"] = use_prefix + split_id_string
                    outjson = json.dumps(myjson, ensure_ascii=False)
                    #outjson = json.dumps({"text":text_buf_ngram_free[i],
                    #    id_prefix+"_split_id":split_id_string},
                    #    ensure_ascii=False)
                    out_f.write(outjson.encode('utf-8'))
                    out_f.write('\n'.encode('utf-8'))

            if counter % 1000 == 0:
                print(' [final]> processed {} documents in {:.2f} seconds ...'.
                    format(counter, time.time() - start_time), flush=True)
        except Exception as e:
            print('Error:', e)

    print(' [final]> processed {} documents in {:.2f} seconds ...'.
        format(counter, time.time() - start_time), flush=True)

    print(' Total docs {} splitted {} ignored {} splits > theshold {} trimmed'\
        ' {}'.format(counter, splitted, ignored, split_mt_thld, trimmed_count)\
        , flush=True)

    pool.close()
    pool.join()

    out_f.close()
    fin.close()
|
| 416 |
+
|
| 417 |
+
if __name__ == '__main__':

    # we use 13-grams, any text less than 200 characters got removed
    # any text splitted more than 10 got removed as well

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
                        help = 'Tasks to use for deduplication: currently '
                        ' support [lambada, squad, natural_questions,'
                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
    parser.add_argument('--lambada-path', type=str, default=None,
                        help='Only Lambada task needs the path')
    parser.add_argument('--dedup-dataset', nargs = '*', default=None,
                        help='Dataset to deduplicate with the key to use'
                        ' e.g. cc.json text')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name to save dedup dataset')
    parser.add_argument('--num-threads', type=int, default=40,
                        help='Number of threads to use')
    # Default dedup values
    parser.add_argument('--max-ngram-size', type=int, default=13,
                        help='Maximum size of ngram to use.')
    parser.add_argument('--min-ngram-size', type=int, default=8,
                        help='Minimum size of ngram to use.')
    parser.add_argument('--filter-text-char-len', type=int, default=200,
                        help='Remove any text below this length.')
    parser.add_argument('--key-threshold', type=int, default=10,
                        help='Number of keys to consider as threshold')
    parser.add_argument('--save-dictionary', type=str, default=None,
                        help='Save the dictionary')
    parser.add_argument('--load-dictionary', type=str, default=None,
                        help='Load the dictionary')
    parser.add_argument('--splits-count', type=int, default=10,
                        help='Remove any documents more than this many splits')
    parser.add_argument('--remove-char-each-side', type=int, default=200,
                        help='Number of characters to remove on each side '
                        'of a matched ngram.')

    args = parser.parse_args()

    # --dedup-dataset takes exactly a (file, json-key) pair.
    assert len(args.dedup_dataset) == 2
    dedup_file = args.dedup_dataset[0]
    dedup_key = args.dedup_dataset[1]

    # Setup multi-processing
    num_workers = args.num_threads
    if args.load_dictionary is None:

        # Build ngrams from the downstream tasks
        ngrams = {}
        compute_tasks_ngrams(args, ngrams)

        # get the range of the size of the ngrams
        ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)

        # get ngram freq from large file in parallel
        # get ngrams below threshold
        ngrams_below_threshold = {}
        get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
            dedup_file, dedup_key, ngrams_freq_sorted)

        # save the dictionary if needed
        if args.save_dictionary is not None:
            with open(args.save_dictionary, 'wb') as save_dict_handle:
                pickle.dump(ngrams_below_threshold, save_dict_handle)
    else:
        # Reuse a previously computed dictionary instead of rebuilding.
        with open(args.load_dictionary, 'rb') as load_dict_handle:
            ngrams_below_threshold = pickle.load(load_dict_handle)

    # filter the large file
    if args.output is not None:
        clean_ngrams_below_threshold(args, ngrams_below_threshold, \
            dedup_file, dedup_key)

    print('done :-)')
|
tools/openwebtext/find_duplicates.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
from functools import partial
|
| 18 |
+
import itertools
|
| 19 |
+
import json
|
| 20 |
+
from lsh import cache, minhash
|
| 21 |
+
import multiprocessing
|
| 22 |
+
import numpy as np
|
| 23 |
+
import time
|
| 24 |
+
import pickle
|
| 25 |
+
import sys
|
| 26 |
+
import os
|
| 27 |
+
|
| 28 |
+
# This function is adapted from:
|
| 29 |
+
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
|
| 30 |
+
def shingles(text, char_ngram=5):
    """Return the set of overlapping character ngrams of *text*.

    NOTE(review): the range stops one short, so the ngram ending at the
    final character is excluded; kept as-is to match the upstream LSH
    example this was adapted from.
    """
    return {text[start:start + char_ngram]
            for start in range(0, len(text) - char_ngram)}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# This function is adapted from:
|
| 36 |
+
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
|
| 37 |
+
def jaccard(set_a, set_b, args):
|
| 38 |
+
if len(set_a) < 1 or len(set_b) < 1:
|
| 39 |
+
return 0.0
|
| 40 |
+
|
| 41 |
+
intersection = set_a & set_b
|
| 42 |
+
union = set_a | set_b
|
| 43 |
+
|
| 44 |
+
if args.jaccard == 'min':
|
| 45 |
+
return len(intersection) / min(len(set_a), len(set_b))
|
| 46 |
+
elif args.jaccard == 'max':
|
| 47 |
+
return len(intersection) / max(len(set_a), len(set_b))
|
| 48 |
+
else:
|
| 49 |
+
return len(intersection) / len(union)
|
| 50 |
+
|
| 51 |
+
def compute_fingerprint(line, key):
    """Parse one jsonl *line* and minhash-fingerprint its text.

    Returns (url, text, fingerprint, success); on any failure the first
    three entries are None and success is False.

    NOTE(review): relies on a module-level ``hasher`` created elsewhere
    (presumably a ``minhash`` hasher built in ``__main__``) — confirm.
    """
    try:
        document = json.loads(line)
        url = document[key]
        text = document['text']
        fingerprint = hasher.fingerprint(text)
    except Exception as e:
        print('Error:', e)
        return None, None, None, False

    return url, text, fingerprint, True
|
| 62 |
+
|
| 63 |
+
def url_pairs_to_remove(args, bucket_urls, url_doc):
    """Pick near-duplicate urls to drop from one LSH bucket.

    Repeatedly chooses a random "main" document, removes every other
    document in the bucket whose shingle similarity exceeds 0.5, then
    removes the main document itself and iterates until the bucket has at
    most one url (or args.heuristic_iter iterations have run).  Returns
    (remove_urls_list, num_deduped, num_comparisons); *bucket_urls* is
    mutated in place.
    """
    remove_urls_list = []
    deduped_local, counter_local = 0, 0
    iteration = 0
    while len(bucket_urls) > 1:
        # -1 means "run until the bucket is exhausted".
        if args.heuristic_iter != -1 and \
            iteration == args.heuristic_iter:
            break

        items = list(bucket_urls)
        remove_urls = []
        # Random anchor document for this round.
        main_url = items[np.random.randint(0, len(items))]
        main_dhingles = shingles(url_doc[main_url])

        for i in range(0, len(items)):
            counter_local += 1
            other_url = items[i]
            if other_url == main_url:
                continue
            other_shingles = shingles(url_doc[other_url])
            try:
                jaccard_sim = jaccard(main_dhingles, other_shingles, args)
            except Exception as e:
                print('Error:', e)
                jaccard_sim = 0.0
            if jaccard_sim > 0.5:
                remove_urls.append({other_url: jaccard_sim})
                deduped_local += 1
                # Safe: we iterate over the `items` snapshot, not the set.
                bucket_urls.remove(other_url)

        # The anchor survives this round but is not reconsidered.
        bucket_urls.remove(main_url)
        if len(remove_urls) > 0:
            remove_urls_list.append({main_url: remove_urls})
        iteration += 1
    return remove_urls_list, deduped_local, counter_local
|
| 98 |
+
|
| 99 |
+
def write_remove_urls_list(remove_urls_list, f_out):
    """Append each removal record as one utf-8 json line to binary *f_out*."""
    for record in remove_urls_list:
        serialized = json.dumps(record, ensure_ascii=False)
        f_out.write(serialized.encode('utf-8'))
        f_out.write('\n'.encode('utf-8'))
|
| 105 |
+
|
| 106 |
+
def compute_jaccard(each_bin, num_bins, start_time_local):
    """Worker: dedup every bucket of one LSH bin.

    Runs inside a multiprocessing worker; ``args`` and ``url_doc`` are
    globals inherited from the parent process rather than parameters.
    Returns (remove_urls_list, num_deduped, num_comparisons).
    """

    remove_urls_list = []
    deduped_local, counter_local, bucket_local = 0, 0, 0

    for bucket_id in each_bin:
        bucket_local += 1
        # Crude progress report from roughly one worker out of num_bins.
        if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
            print("Counter {}, progress {:.2f} time {:.2f}".\
                format(bucket_local, float(bucket_local)/float(len(each_bin)),\
                time.time() - start_time_local), flush=True)

        # Singleton buckets cannot contain duplicates.
        if len(each_bin[bucket_id]) <= 1:
            continue

        # Copy so url_pairs_to_remove can mutate freely.
        bucket_urls = each_bin[bucket_id].copy()
        remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
            url_pairs_to_remove(args, bucket_urls, url_doc)

        deduped_local += deduped_local_sub
        counter_local += counter_local_sub
        if len(remove_urls_list_sub) > 0:
            remove_urls_list.extend(remove_urls_list_sub)

    return remove_urls_list, deduped_local, counter_local
|
| 131 |
+
|
| 132 |
+
def find_pair_urls_parallel(args, lshcache, url_doc):
    """Find near-duplicate url pairs, one worker per LSH bin.

    Streams each bin through compute_jaccard in a process pool and writes
    the removal records to args.output as jsonl.
    """
    start_time = time.time()
    f_out = open(args.output, 'wb')
    deduped, counter = 0, 0

    # compute jaccards of buckets in bin in parallel (parallelism
    # limited to # of bins)
    num_bins = len(lshcache.bins)
    pool = multiprocessing.Pool(num_bins)
    compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins, \
        start_time_local=start_time)
    # don't need to pass args and url_doc as they are already shared
    # (inherited by the forked workers as globals)
    compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)

    print("multiprocessing init took {:.2f}".format(time.time() - start_time),\
        flush=True)
    for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
        deduped += deduped_local
        counter += counter_local
        write_remove_urls_list(remove_urls_list, f_out)
        print(' [write]> processed {} documents in {:.2f} '
            'seoncds and deduped {} documents ...'.format(counter, time.time()\
            - start_time, deduped), flush=True)

    pool.close()
    pool.join()
    f_out.close()

    print(' Taken time for jaccard similariries {:.2f} seconds'.format(\
        time.time() - start_time), flush=True)
|
| 162 |
+
|
| 163 |
+
def find_pair_urls_sequential(args, lshcache, url_doc):
    """Find candidate duplicate url pairs by scanning every LSH bucket
    sequentially, writing the urls to remove to ``args.output``.

    Args:
        args: parsed command-line namespace (``output`` used here).
        lshcache: LSH cache whose ``bins`` hold bucket-id -> url sets.
        url_doc: url -> document text mapping used for jaccard checks.
    """
    start_time = time.time()
    deduped, counter = 0, 0
    # 'with' guarantees the output file is closed even on error
    # (the original leaked the handle if url_pairs_to_remove raised).
    with open(args.output, 'wb') as f_out:
        for b in lshcache.bins:
            for bucket_id in b:
                # a singleton bucket cannot contain duplicates
                if len(b[bucket_id]) <= 1:
                    continue

                bucket_urls = b[bucket_id].copy()
                remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
                    url_pairs_to_remove(args, bucket_urls, url_doc)

                deduped += deduped_local_sub
                counter += counter_local_sub
                write_remove_urls_list(remove_urls_list_sub, f_out)
                # fixed typo in progress message ("seoncds" -> "seconds")
                if counter % 10000 == 0:
                    print(' [write]> processed {} documents in {:.2f} '
                          'seconds and deduped {} documents ...'.
                          format(counter, time.time() - start_time,
                                 deduped), flush=True)
    print(' [write]> processed {} documents in {:.2f} '
          'seconds and deduped {} documents ...'.
          format(counter, time.time() - start_time,
                 deduped), flush=True)
|
| 189 |
+
|
| 190 |
+
if __name__ == '__main__':

    # Driver: build (or load) MinHash-LSH fingerprints for json-lines
    # corpora, optionally persist them, then emit candidate duplicate
    # url pairs above the jaccard threshold.
    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy')
    parser.add_argument('--inputs', nargs = '*', default=None, help = \
                        'Pairwise list of the input files and keys, '
                        'e.g. --inputs cc.json cc_id news.json news_id')
    parser.add_argument('--load-fingerprints', nargs = '*', default=None,
                        help='Load fingerprints from a list of pickle files,'
                        ' e.g. cc.pkl news.pkl')
    parser.add_argument('--save-fingerprints', type=str, default=None,
                        help='Save the fingerprints of the inputs.')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name that consists of all ids'
                        ' with matching similarities')
    parser.add_argument('--jaccard', type=str, default='union',
                        choices=['union', 'min', 'max'], help='Jaccard'\
                        ' similarity computation')
    parser.add_argument('--heuristic-iter', type=int, default=1,
                        help='Number of iterations to run the heuristics'
                        ': use -1 for exact')
    parser.add_argument('--num-bands', type=int, default=10,
                        help='Number of bands to use in cache')
    parser.add_argument('--num-seeds', type=int, default=100,
                        help='Number of seeds to use for minhash. Note that'
                        ' this value should be divisible by num-bands')
    parser.add_argument('--jaccard-parallel', action='store_true',
                        help='Use this to process large number of documents.')
    args = parser.parse_args()

    print('finding possible duplicate content ...')

    # set seed and get an array of seeds of 100 integers
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 1e6, size=args.num_seeds)

    # initialize minhash and lsh cache
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)

    # url -> document text; populated from pickles and/or inputs below
    url_doc = {}

    # load fingerprints from pickle file if needed
    if args.load_fingerprints is not None:
        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
            print("Loading fingerprints from pickle file {}".format(
                fp_file_name), flush=True)
            fp = open(fp_file_name, "rb")
            if count_fp == 0:
                # assign directory for the first pkl
                lshcache = pickle.load(fp)
                url_doc = pickle.load(fp)
            else:
                # append these to lshcache and url_doc
                local_lshcache = pickle.load(fp)
                local_url_doc = pickle.load(fp)
                for url in local_lshcache.fingerprints.keys():
                    url_doc[url] = local_url_doc[url]
                    lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
            fp.close()

    counter = 0
    start_time = time.time()

    # compute finger prints of the inputs if any
    # input file and the key to use as id
    if args.inputs is not None:
        print("Computing fingerprints", flush=True)
        # --inputs must alternate file names and id keys
        assert len(args.inputs) % 2 == 0
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key),
                  flush=True)

            # compute fingerprints in parallel
            num_workers = 40
            pool = multiprocessing.Pool(num_workers)
            fin = open(input_file, 'r', encoding='utf-8')
            compute_fingerprint_partial = partial(compute_fingerprint, key=key)
            # chunksize 512 amortizes IPC cost per json line
            compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
                                                 fin, 512)
            # traverse all the texts and add fingerprints
            for url, text, fingerprint, flag in compute_fingerprint_iter:
                counter += 1
                if flag:
                    url_doc[url] = text
                    lshcache.add_fingerprint(fingerprint, url)
                if counter % 10000 == 0:
                    print(' [read]> processed {} documents in {:.2f} '
                          'seconds ...'.format(counter, time.time() - \
                                               start_time), flush=True)

            fin.close()
            pool.close()
            pool.join()

    # Save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
            args.save_fingerprints), flush=True)
        with open(args.save_fingerprints, 'wb') as f_save:
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)

    # compute jaccard index of the input texts and write to file if needed
    if args.output is not None:
        print("Compute jaccard similarity", flush=True)
        if args.jaccard_parallel:
            find_pair_urls_parallel(args, lshcache, url_doc)
        else:
            find_pair_urls_sequential(args, lshcache, url_doc)

    print('done :-)')
|
| 305 |
+
|
tools/openwebtext/group_duplicate_url.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import time
|
| 18 |
+
import sys
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
if __name__ == '__main__':

    # Group urls that the jaccard step flagged as similar into connected
    # components (a union-find-style merge over a dict and a list of sets).
    # Usage: group_duplicate_url.py <similarity-jsonl> <groups-jsonl-out>
    #        [jaccard-threshold]
    print('grouping duplicate urls ...')

    # renamed from `input` / `output`: `input` shadowed the builtin
    input_path = sys.argv[1]
    output_path = sys.argv[2]
    if len(sys.argv) > 3:
        jaccard_similarity_threshold = float(sys.argv[3])
    else:
        jaccard_similarity_threshold = 0.7

    # url -> group index; group index -> set of urls (None once merged away)
    url_to_index = {}
    index_to_urls = []
    counter = 0
    start_time = time.time()
    with open(input_path, 'r') as f:
        for line in f:
            counter += 1
            myjson = json.loads(line)
            urls = []
            # collect the main url plus every url whose jaccard score
            # meets the threshold
            for main_url in myjson.keys():
                urls.append(main_url)
                for value in myjson[main_url]:
                    for other_url, js in value.items():
                        if js >= jaccard_similarity_threshold:
                            urls.append(other_url)
            current_index = -1
            other_indices = set()
            # find an existing group for these urls, noting any extra
            # groups that must be merged in
            for url in urls:
                if url in url_to_index:
                    if current_index == -1:
                        current_index = url_to_index[url]
                    elif current_index != url_to_index[url]:
                        other_indices.add(url_to_index[url])
            if current_index == -1:
                # none seen before: open a new group
                current_index = len(index_to_urls)
                index_to_urls.append(set())
            for url in urls:
                url_to_index[url] = current_index
                index_to_urls[current_index].add(url)
            # fold any other groups these urls belonged to into the
            # current group and retire them
            for index in other_indices:
                for url in index_to_urls[index]:
                    index_to_urls[current_index].add(url)
                    url_to_index[url] = current_index
                index_to_urls[index] = None

            if counter % 100000 == 0:
                print(' > processed {} lines in {} seconds ...'.format(
                    counter, time.time() - start_time))

    total_remove = 0
    total_remain = 0
    for urls in index_to_urls:
        if urls is not None:
            if len(urls) > 1:
                # keep one representative per group, drop the rest
                total_remove += (len(urls) - 1)
                total_remain += 1
    print('out of {} urls, only {} are unique and {} should be removed'.format(
        total_remove+total_remain, total_remain, total_remove))

    with open(output_path, 'wb') as f:
        for i, urls in enumerate(index_to_urls):
            if urls is not None:
                if len(urls) > 1:
                    myjson = json.dumps({str(i): list(urls)},
                                        ensure_ascii=False)
                    f.write(myjson.encode('utf-8'))
                    f.write('\n'.encode('utf-8'))
|
tools/openwebtext/merge_jsons.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import glob
|
| 18 |
+
import sys
|
| 19 |
+
import json
|
| 20 |
+
import argparse
|
| 21 |
+
|
| 22 |
+
if __name__ == '__main__':

    # Concatenate every *.json (json-lines) file under --json_path into a
    # single --output_file, validating each line as JSON on the way.
    parser = argparse.ArgumentParser()
    parser.add_argument("--json_path", type=str, default=".",
                        help="path where all the json files are located")

    parser.add_argument("--output_file", type=str, default="merged_output.json",
                        help="filename where the merged json should go")

    args = parser.parse_args()

    json_files = glob.glob(args.json_path + '/*.json')

    counter = 0

    with open(args.output_file, 'w') as outfile:
        for fname in json_files:
            counter += 1

            if counter % 1024 == 0:
                print("Merging at ", counter, flush=True)

            with open(fname, 'r') as infile:
                for row in infile:
                    # parse each line so malformed JSON still raises, but
                    # drop the unused binding the original kept (`each_row`);
                    # the raw line is copied verbatim
                    json.loads(row)
                    outfile.write(row)

    print("Merged file", args.output_file, flush=True)
|
| 54 |
+
|
| 55 |
+
|
tools/openwebtext/remove_group_duplicates.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import time
|
| 19 |
+
import sys
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
if __name__ == '__main__':

    # Remove duplicate documents from a json-lines corpus.
    # Usage: remove_group_duplicates.py <groups-json> <data-json> <output-json>
    url_filename = sys.argv[1]
    data_filename = sys.argv[2]
    output_filename = sys.argv[3]

    # Collect every url to drop: for each duplicate group, keep the first
    # url (index 0) and mark the rest for removal.
    urls = set()
    with open(url_filename, 'r') as f:
        for line in f:
            myjson = json.loads(line)
            for key in myjson:
                this_urls = myjson[key]
                for i in range(1, len(this_urls)):
                    urls.add(this_urls[i])
    print('will be removing {} urls'.format(len(urls)), flush=True)

    written_docs = 0
    removed_docs = 0
    removed_chars = 0
    start_time = time.time()
    with open(output_filename, 'wb') as fout:
        with open(data_filename, 'r') as fin:
            for line in fin:
                try:
                    myjson = json.loads(line)
                    url = myjson['url']
                    if url in urls:
                        # document belongs to a duplicate group: drop it
                        print('removing', myjson)
                        removed_docs += 1
                        removed_chars += len(myjson['text'])
                        continue
                    myjson = json.dumps(myjson, ensure_ascii=False)
                    fout.write(myjson.encode('utf-8'))
                    fout.write('\n'.encode('utf-8'))
                    written_docs += 1
                    if written_docs % 10000 == 0:
                        print(' [PROCESSED] time (s): {:.2f} | written: {} '
                              '| removed: {} (char: {})'.format(
                                  time.time() - start_time,
                                  written_docs, removed_docs, removed_chars))
                except Exception as e:
                    # deliberate best-effort: skip malformed lines rather
                    # than abort the whole pass
                    print('[SKIPPING]', line, e)

    print(' [PROCESSED] time (s): {:.2f} | written: {} '
          '| removed: {} (char: {})'.format(
              time.time() - start_time,
              written_docs, removed_docs, removed_chars))
    print('done :-)')
|
tools/preprocess_data.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Processing data for pretraining."""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import json
|
| 20 |
+
import multiprocessing
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
|
| 24 |
+
os.path.pardir)))
|
| 25 |
+
import time
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
try:
|
| 29 |
+
import nltk
|
| 30 |
+
nltk_available = True
|
| 31 |
+
except ImportError:
|
| 32 |
+
nltk_available = False
|
| 33 |
+
|
| 34 |
+
from megatron.tokenizer import build_tokenizer
|
| 35 |
+
from megatron.data import indexed_dataset
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
    """Punkt language vars that keep trailing whitespace (e.g. newlines)
    attached to the sentence that precedes it, instead of eating it.

    NOTE(review): punkt compiles this template with re.VERBOSE, so the
    layout inside the raw string is not significant — only the pattern is.
    """

    _period_context_fmt = r"""
        \S*                          # some word material
        %(SentEndChars)s             # a potential sentence ending
        \s*                       #  <-- THIS is what I changed
        (?=(?P<after_tok>
            %(NonWord)s              # either other punctuation
            |
            (?P<next_tok>\S+)     #  <-- Normally you would have \s+ here
        ))"""
|
| 50 |
+
|
| 51 |
+
class IdentitySplitter(object):
    """No-op sentence splitter.

    Mirrors the ``tokenize`` interface of the nltk punkt splitter but
    returns its arguments unchanged (as a tuple), so a document is kept
    as a single "sentence" when sentence splitting is disabled.
    """

    def tokenize(self, *text):
        # Identity: hand back exactly what was passed in.
        return text
|
| 54 |
+
|
| 55 |
+
class Encoder(object):
    """Tokenizes json-lines documents inside multiprocessing workers.

    Heavy objects (tokenizer, sentence splitter) are built once per worker
    in ``initializer`` and stored as *class* attributes so they are not
    pickled with every task sent through the pool.
    """

    def __init__(self, args):
        # args: parsed namespace from get_args(); read lazily by workers.
        self.args = args

    def initializer(self):
        # Use Encoder class as a container for global data
        Encoder.tokenizer = build_tokenizer(self.args)
        if self.args.split_sentences:
            if not nltk_available:
                print("NLTK is not available to split sentences.")
                exit()
            splitter = nltk.load("tokenizers/punkt/english.pickle")
            if self.args.keep_newlines:
                # this prevents punkt from eating newlines after sentences
                Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                    train_text = splitter._params,
                    lang_vars = CustomLanguageVars())
            else:
                Encoder.splitter = splitter

        else:
            # no sentence splitting: whole document is one "sentence"
            Encoder.splitter = IdentitySplitter()

    def encode(self, json_line):
        """Tokenize one json line.

        Returns ``(ids, nbytes)`` where ``ids`` maps each requested json
        key to a list of per-sentence token-id lists, and ``nbytes`` is
        the length of the raw input line (for throughput reporting).
        """
        data = json.loads(json_line)
        ids = {}
        for key in self.args.json_keys:
            text = data[key]
            doc_ids = []
            for sentence in Encoder.splitter.tokenize(text):
                sentence_ids = Encoder.tokenizer.tokenize(sentence)
                if len(sentence_ids) > 0:
                    doc_ids.append(sentence_ids)
            # append the end-of-document token to the last sentence
            if len(doc_ids) > 0 and self.args.append_eod:
                doc_ids[-1].append(Encoder.tokenizer.eod)
            ids[key] = doc_ids
        return ids, len(json_line)
|
| 92 |
+
|
| 93 |
+
def get_args():
    """Parse and post-process command-line arguments for preprocessing.

    Returns the argparse namespace, augmented with the dummy fields the
    Megatron tokenizer builder expects (rank, parallel sizes, ...).
    """
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='input data')
    group.add_argument('--input', type=str, required=True,
                       help='Path to input JSON')
    group.add_argument('--json-keys', nargs='+', default=['text'],
                       help='space separate listed of keys to extract from json')
    group.add_argument('--split-sentences', action='store_true',
                       help='Split documents into sentences.')
    group.add_argument('--keep-newlines', action='store_true',
                       help='Keep newlines between sentences when splitting.')

    group = parser.add_argument_group(title='tokenizer')
    group.add_argument('--tokenizer-type', type=str, required=True,
                       choices=['BertWordPieceLowerCase','BertWordPieceCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file (if necessary).')
    group.add_argument('--append-eod', action='store_true',
                       help='Append an <eod> token to the end of a document.')


    group = parser.add_argument_group(title='output data')
    group.add_argument('--output-prefix', type=str, required=True,
                       help='Path to binary output file without suffix')
    group.add_argument('--dataset-impl', type=str, default='mmap',
                       choices=['lazy', 'cached', 'mmap'])

    group = parser.add_argument_group(title='runtime')
    group.add_argument('--workers', type=int, required=True,
                       help='Number of worker processes to launch')
    group.add_argument('--chunk-size', type=int, required=True,
                       help='Chunk size assigned to each worker process')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Interval between progress updates')
    args = parser.parse_args()
    args.keep_empty = False

    # warn (don't fail) when BERT is used without sentence splitting
    if args.tokenizer_type.lower().startswith('bert'):
        if not args.split_sentences:
            print("Bert tokenizer detected, are you sure you don't want to split sentences?")

    # some default/dummy values for the tokenizer
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
    args.tensor_model_parallel_size = 1
    args.vocab_extra_ids = 0

    return args
|
| 145 |
+
|
| 146 |
+
def main():
    """Tokenize a json-lines corpus into a Megatron indexed dataset.

    Reads args.input line by line, tokenizes in a worker pool, and writes
    one .bin/.idx builder pair per requested json key.
    """
    args = get_args()
    startup_start = time.time()

    print("Opening", args.input)
    fin = open(args.input, 'r', encoding='utf-8')

    if nltk_available and args.split_sentences:
        nltk.download("punkt", quiet=True)

    encoder = Encoder(args)
    tokenizer = build_tokenizer(args)
    # each worker builds its own tokenizer/splitter in Encoder.initializer
    pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
    encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size)
    #encoded_docs = map(encoder.encode, fin)

    # output granularity: one index entry per document or per sentence
    level = "document"
    if args.split_sentences:
        level = "sentence"

    print(f"Vocab size: {tokenizer.vocab_size}")
    print(f"Output prefix: {args.output_prefix}")
    output_bin_files = {}
    output_idx_files = {}
    builders = {}
    for key in args.json_keys:
        output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
                                                      key, level)
        output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
                                                      key, level)
        builders[key] = indexed_dataset.make_builder(output_bin_files[key],
                                                     impl=args.dataset_impl,
                                                     vocab_size=tokenizer.vocab_size)

    startup_end = time.time()
    proc_start = time.time()
    total_bytes_processed = 0
    print("Time to startup:", startup_end - startup_start)

    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
        total_bytes_processed += bytes_processed
        for key, sentences in doc.items():
            if len(sentences) == 0:
                continue
            for sentence in sentences:
                builders[key].add_item(torch.IntTensor(sentence))
            # mark document boundary in the index
            builders[key].end_document()
        if i % args.log_interval == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed/elapsed/1024/1024
            print(f"Processed {i} documents",
                  f"({i/elapsed} docs/s, {mbs} MB/s).",
                  file=sys.stderr)

    # flush the .idx files to disk
    for key in args.json_keys:
        builders[key].finalize(output_idx_files[key])

if __name__ == '__main__':
    main()
|
tools/run_build_data.sh
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build a BERT pretraining dataset from a json-lines corpus.
#INPUT="roberta_train_data_raw/valid.json"
INPUT="/mnt/nvme0/ouyangxuan/project_pretrain/make_pretrain_data/roberta_train_data_raw/valid.json"
# Use the canonical flag names: the original `--vocab` / `--worker`
# only worked via argparse prefix abbreviation, which silently breaks
# if preprocess_data.py ever grows another flag with the same prefix.
python preprocess_data.py \
	--input ${INPUT} \
	--output-prefix my-bert \
	--vocab-file bert-vocab.txt \
	--dataset-impl mmap \
	--workers 1 \
	--chunk-size 1 \
	--tokenizer-type BertWordPieceLowerCase \
	--split-sentences


#--input /mnt/nvme1/ouyangxuan/project_pretrain/find_framework/tmp_data/data.json \
#--input roberta_train_data_raw/train_1g.json \
|
tools/run_text_generation_server.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Sample Generate GPT"""
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
|
| 20 |
+
os.path.pardir)))
|
| 21 |
+
import socket
|
| 22 |
+
from megatron import get_args
|
| 23 |
+
from megatron import print_rank_0
|
| 24 |
+
from megatron import mpu
|
| 25 |
+
from megatron.checkpointing import load_checkpoint
|
| 26 |
+
from megatron.initialize import initialize_megatron
|
| 27 |
+
from megatron.model import GPTModel
|
| 28 |
+
from megatron.training import get_model
|
| 29 |
+
from megatron.text_generation_server import MegatronServer
|
| 30 |
+
from megatron.text_generation import generate_and_post_process
|
| 31 |
+
from megatron.text_generation import beam_search_and_post_process
|
| 32 |
+
import torch
|
| 33 |
+
|
| 34 |
+
def model_provider(pre_process=True, post_process=True):
    """Build a GPT model for text generation.

    Args:
        pre_process: include the input (embedding) pipeline stage.
        post_process: include the output pipeline stage.

    Returns:
        A ``GPTModel`` configured for inference (no parallel output).
    """
    print_rank_0('building GPT model ...')
    return GPTModel(num_tokentypes=0,
                    parallel_output=False,
                    pre_process=pre_process,
                    post_process=post_process)
|
| 41 |
+
|
| 42 |
+
def add_text_generate_args(parser):
    """Register text-generation sampling options on *parser* and return it."""
    group = parser.add_argument_group(title='text generation')

    # (flag, type, default, help) tuples keep the options in one place.
    options = (
        ("--temperature", float, 1.0, 'Sampling temperature.'),
        ("--top_p", float, 0.0, 'Top p sampling.'),
        ("--top_k", int, 0, 'Top k sampling.'),
        ("--out-seq-length", int, 1024, 'Size of the output generated text.'),
    )
    for flag, flag_type, default, doc in options:
        group.add_argument(flag, type=flag_type, default=default, help=doc)
    return parser
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
if __name__ == "__main__":
    # Launch a Megatron text-generation HTTP server: rank (pp-first, tp-0)
    # serves requests; every other rank loops waiting for broadcast work.
    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()
    # Set up model and load checkpoint
    model = get_model(model_provider, wrap_with_ddp=False)

    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]
    # Only the first pipeline stage / tensor-parallel rank 0 runs the
    # HTTP server; server.run blocks on that rank.
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        server = MegatronServer(model)
        server.run("0.0.0.0")

    # Worker ranks: receive a broadcast op code from rank 0 forever
    # (0 = sampling generation, 1 = beam search).
    while True:
        choice = torch.cuda.LongTensor(1)
        torch.distributed.broadcast(choice, 0)
        if choice[0].item() == 0:
            try:
                generate_and_post_process(model)
            # ValueError from a bad request is handled on the server rank;
            # workers just move on to the next broadcast
            except ValueError as ve:
                pass
        elif choice[0].item() == 1:
            try:
                beam_search_and_post_process(model)
            except ValueError as ve:
                pass
|
tools/text_generation_cli.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import json
|
| 16 |
+
import sys
|
| 17 |
+
import urllib2
|
| 18 |
+
# NOTE(review): this module was written for Python 2 (`urllib2`,
# `raw_input`) while the rest of the tools are Python 3. Ported below;
# the stale `import urllib2` at the top of the file should be deleted
# once no Python 2 callers remain.
import urllib.request


class PutRequest(urllib.request.Request):
    '''Request subclass that issues HTTP PUT instead of the default method.'''

    def get_method(self, *args, **kwargs):
        return 'PUT'


if __name__ == "__main__":
    # Usage: text_generation_cli.py <server-url>
    url = sys.argv[1]
    while True:
        sentence = input("Enter prompt: ")
        tokens_to_generate = int(input("Enter number of tokens to generate: "))
        data = json.dumps({"prompts": [sentence],
                           "tokens_to_generate": tokens_to_generate})
        # request bodies must be bytes in Python 3
        req = PutRequest(url, data.encode('utf-8'),
                         {'Content-Type': 'application/json'})
        response = urllib.request.urlopen(req)
        resp_sentences = json.load(response)
        print("Megatron Response: ")
        print(resp_sentences["text"][0])
|