|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pickle |
|
|
import os |
|
|
import argparse |
|
|
import numpy as np |
|
|
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from mmpt.processors import PKLJSONStrTextProcessor |
|
|
from mmpt.utils import ShardedTensor, recursive_config |
|
|
|
|
|
|
|
|
class TokenizerDataset(Dataset):
    """Map-style dataset yielding ``(video_id, tokenized_caption)`` pairs.

    Wraps a ``PKLJSONStrTextProcessor`` and exposes one example per video id
    found in the processor's ``data`` mapping.
    """

    def __init__(self, config):
        self.text_processor = PKLJSONStrTextProcessor(config)
        # Freeze the iteration order of the processor's keys up front.
        self.video_ids = [key for key in self.text_processor.data]

    def __getitem__(self, idx):
        vid = self.video_ids[idx]
        caption = self.text_processor(vid)
        return vid, caption

    def __len__(self):
        return len(self.video_ids)
|
|
|
|
|
|
|
|
def numpify(shard_idx, video_ids, captions, target_dir, split, prefix, max_cap_len=32):
    """Convert one shard's captions to padded numpy arrays and save them.

    For each video in the shard, builds a ``(num_clips, 2)`` float32 array
    of ``[start, end]`` timestamps and a ``(num_clips, max_cap_len)`` int32
    array of token ids right-padded with ``-1``, then serializes both lists
    via ``ShardedTensor`` under
    ``<target_dir>/<prefix><split>_<shard_idx>.{startends,caps_ids}``.
    """
    shard_startends = []
    shard_caps = []
    for vid in video_ids:
        entry = captions[vid]
        spans = []
        padded_rows = []
        for begin, finish, tokens in zip(
                entry["start"], entry["end"], entry["cap"]):
            spans.append(np.array([begin, finish]).astype("float32"))
            # Fixed-width row: -1 marks padding beyond the caption length.
            row = np.full((max_cap_len,), -1, dtype=np.int32)
            tokens = tokens[:max_cap_len]
            row[:len(tokens)] = tokens
            padded_rows.append(row)
        shard_startends.append(np.stack(spans))
        shard_caps.append(np.stack(padded_rows))

    sharded_startends = ShardedTensor.from_list(shard_startends)
    target_path = os.path.join(
        target_dir, prefix + split + "_" + str(shard_idx))
    print("save to", target_path)
    sharded_startends.save(target_path + ".startends")
    sharded_caps = ShardedTensor.from_list(shard_caps)
    sharded_caps.save(target_path + ".caps_ids")
|
|
|
|
|
|
|
|
def sharding(config, out_file):
    """Split the tokenized-caption pickle into per-shard numpy files.

    Loads the tokenized captions from ``out_file``, then for each split
    ("train" / "val") reads the shard meta pickle ``<split>_meta.pkl`` from
    ``config.target_dir`` and materializes every shard via ``numpify``.
    """
    with open(out_file, "rb") as fr:
        captions = pickle.load(fr)
    target_dir = config.target_dir
    # Output files are prefixed "<caption-stem>.<bert_name>.".
    stem = os.path.basename(os.path.splitext(config.caption_pkl_path)[0])
    prefix = stem + "." + config.bert_name + "."
    for split in ("train", "val"):
        target_path = os.path.join(target_dir, split + "_meta")
        with open(target_path + ".pkl", "rb") as fr:
            meta = pickle.load(fr)
        print("load meta", target_path, len(meta))
        # meta maps shard_id -> list of video ids in that shard
        # (presumably; verify against the meta-writer).
        for shard_id, shard_video_ids in meta.items():
            numpify(
                shard_id, shard_video_ids, captions,
                target_dir, split, prefix)
|
|
|
|
|
|
|
|
def tokenize(config, out_file):
    """Tokenize every caption in the dataset and pickle the result.

    Iterates a ``TokenizerDataset`` through a multi-worker ``DataLoader``
    with an identity collate function, accumulates a
    ``{video_id: tokenized_caption}`` dict, and dumps it to ``out_file``.
    """
    def collator(items):
        # Identity collate: keep the (video_id, caption) tuples untouched.
        return items

    loader = DataLoader(
        TokenizerDataset(config), collate_fn=collator, num_workers=16)
    data = {}
    for idx, batch in enumerate(loader):
        for video_id, caption in batch:
            data[video_id] = caption
        if idx % 5000 == 0:
            print(idx)  # coarse progress indicator
    with open(out_file, "wb") as fw:
        pickle.dump(data, fw, pickle.HIGHEST_PROTOCOL)
|
|
|
|
|
|
|
|
def main(args):
    """Entry point: tokenize captions (unless cached) and shard them.

    The tokenized pickle path is derived from ``caption_pkl_path`` and
    ``bert_name``; tokenization is skipped when that file already exists.
    """
    config = recursive_config(args.config).dataset

    stem = os.path.splitext(config.caption_pkl_path)[0]
    out_file = stem + "." + config.bert_name + ".pkl"
    if not os.path.isfile(out_file):
        tokenize(config, out_file)
    sharding(config, out_file)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI: a single positional argument pointing at the YAML/recursive config.
    parser = argparse.ArgumentParser(
        description="pretokenize (raw_)caption.json into pkl.")
    parser.add_argument('config', type=str)
    main(parser.parse_args())
|
|
|