PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
878264b verified
raw
history blame
3.41 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pickle
import os
import argparse
import numpy as np
from torch.utils.data import Dataset, DataLoader
from mmpt.processors import PKLJSONStrTextProcessor
from mmpt.utils import ShardedTensor, recursive_config
class TokenizerDataset(Dataset):
    """Map-style dataset yielding one (video_id, tokenized_caption) pair per video.

    The heavy lifting is delegated to PKLJSONStrTextProcessor, whose `.data`
    dict (keyed by video id) defines the set of items.
    """

    def __init__(self, config):
        # Processor loads the raw caption pickle; its keys index the videos.
        self.text_processor = PKLJSONStrTextProcessor(config)
        self.video_ids = list(self.text_processor.data.keys())

    def __getitem__(self, index):
        vid = self.video_ids[index]
        return vid, self.text_processor(vid)

    def __len__(self):
        return len(self.video_ids)
def numpify(shard_idx, video_ids, captions, target_dir, split, prefix, max_cap_len=32):
    """Convert one shard's captions to numpy arrays and persist as ShardedTensors.

    For each video in `video_ids` this builds:
      * a (num_caps, 2) float32 array of [start, end] timestamps, and
      * a (num_caps, max_cap_len) int32 array of caption token ids,
        truncated to `max_cap_len` and right-padded with -1.

    Both per-video array lists are wrapped in ShardedTensor and written to
    `<target_dir>/<prefix><split>_<shard_idx>.startends` / `.caps_ids`.
    """
    all_spans = []
    all_tokens = []
    for vid in video_ids:
        entry = captions[vid]
        spans = []
        tokens = []
        for begin, finish, token_ids in zip(
                entry["start"], entry["end"], entry["cap"]):
            spans.append(np.array([begin, finish]).astype("float32"))
            # Fixed-width row: -1 marks padding beyond the caption length.
            padded = np.full((max_cap_len,), -1, dtype=np.int32)
            clipped = token_ids[:max_cap_len]
            padded[:len(clipped)] = clipped
            tokens.append(padded)
        all_spans.append(np.stack(spans))
        all_tokens.append(np.stack(tokens))
    out_path = os.path.join(
        target_dir,
        prefix + split + "_" + str(shard_idx)
    )
    print("save to", out_path)
    ShardedTensor.from_list(all_spans).save(out_path + ".startends")
    ShardedTensor.from_list(all_tokens).save(out_path + ".caps_ids")
def sharding(config, out_file):
    """Split the tokenized-caption pickle into per-shard numpy files.

    Loads the caption dict from `out_file`, then for each split reads
    `<target_dir>/<split>_meta.pkl` (a mapping of shard_id -> video ids)
    and hands every shard to numpify().
    """
    with open(out_file, "rb") as fr:
        captions = pickle.load(fr)
    target_dir = config.target_dir
    # Output file prefix: "<caption-basename>.<bert_name>."
    base = os.path.basename(os.path.splitext(config.caption_pkl_path)[0])
    prefix = base + "." + config.bert_name + "."
    for split in ["train", "val"]:
        meta_path = os.path.join(target_dir, split + "_meta")
        with open(meta_path + ".pkl", "rb") as fr:
            meta = pickle.load(fr)
        print("load meta", meta_path, len(meta))
        for shard_id, shard_video_ids in meta.items():
            numpify(
                shard_id, shard_video_ids, captions,
                target_dir, split, prefix
            )
def tokenize(config, out_file):
    """Tokenize every caption and pickle the result as {video_id: caption}.

    DataLoader is used purely to parallelize tokenization across 16 worker
    processes; the collator passes samples through untouched.
    """
    def collator(samples):
        return samples

    loader = DataLoader(
        TokenizerDataset(config), collate_fn=collator, num_workers=16)
    data = {}
    for idx, batch in enumerate(loader):
        for video_id, caption in batch:
            data[video_id] = caption
        # Coarse progress indicator.
        if idx % 5000 == 0:
            print(idx)
    with open(out_file, "wb") as fw:
        pickle.dump(data, fw, pickle.HIGHEST_PROTOCOL)
def main(args):
    """Tokenize captions (unless already cached on disk), then shard them."""
    config = recursive_config(args.config).dataset
    out_file = (
        os.path.splitext(config.caption_pkl_path)[0]
        + "." + config.bert_name + ".pkl"
    )
    # Resume-friendly: reuse an existing tokenized pickle.
    if not os.path.isfile(out_file):
        tokenize(config, out_file)
    sharding(config, out_file)
if __name__ == "__main__":
    # CLI entry point: takes a single positional config path.
    arg_parser = argparse.ArgumentParser(
        description="pretokenize (raw_)caption.json into pkl.")
    arg_parser.add_argument('config', type=str)
    main(arg_parser.parse_args())