| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import torch |
| | import math |
| | import pickle |
| | import random |
| | import os |
| | import numpy as np |
| |
|
| | from collections import deque |
| | from typing import Optional, Tuple, List |
| | from .processor import ( |
| | Processor, |
| | MetaProcessor, |
| | TextProcessor, |
| | Aligner, |
| | MMAttentionMask2DProcessor |
| | ) |
| |
|
| | from ..utils import ShardedTensor |
| |
|
| |
|
class How2MetaProcessor(MetaProcessor):
    """Meta processor backed by a plain-text list of video ids.

    The split file resolved by `self._get_split_path(config)` is read once
    at construction; each line, stripped of surrounding whitespace, becomes
    one entry of `self.data`.
    """

    def __init__(self, config):
        super().__init__(config)
        split_path = self._get_split_path(config)
        with open(split_path) as split_file:
            self.data = list(map(str.strip, split_file))

    def __getitem__(self, idx):
        """Return the entry at `idx` twice, as a `(video_id, video_id)` pair."""
        entry = self.data[idx]
        return entry, entry
| |
|
| |
|
class ShardedHow2MetaProcessor(How2MetaProcessor):
    """`How2MetaProcessor` that also knows in which shard each video lives.

    A pickled `<prefix>_meta.pkl` file under `config.vfeat_dir` maps a shard
    id to the ordered collection of video ids stored in that shard; it is
    inverted into `self.video_id_to_shard` at construction time.
    """

    def __init__(self, config):
        super().__init__(config)
        self.split = str(config.split)
        self.vfeat_dir = config.vfeat_dir
        self._init_shard()

    def _init_shard(self):
        """Load the shard meta file and build `video_id -> (shard_id, index)`.

        Raises:
            ValueError: if `self.split` is not "train", "valid" or "test".
        """
        if self.split == "train":
            prefix = "train"
        elif self.split == "valid":
            prefix = "val"
        elif self.split == "test":
            # There is no dedicated test meta file; the validation one is used.
            print("use how2 val as test.")
            prefix = "val"
        else:
            raise ValueError("unsupported for MetaProcessor:", self.split)

        meta_fn = os.path.join(self.vfeat_dir, prefix + "_meta.pkl")
        # NOTE: pickle executes arbitrary code on load; only point
        # `vfeat_dir` at trusted feature dumps.
        with open(meta_fn, "rb") as fr:
            meta = pickle.load(fr)

        # If a video id appears more than once, the last occurrence wins.
        self.video_id_to_shard = {
            video_id: (shard_id, video_idx)
            for shard_id in meta
            for video_idx, video_id in enumerate(meta[shard_id])
        }

    def __getitem__(self, idx):
        """Return `(meta, meta)` with `meta = (video_id, idx, shard_id, shard_idx)`."""
        _, video_id = super().__getitem__(idx)
        shard_id, shard_idx = self.video_id_to_shard[video_id]
        record = (video_id, idx, shard_id, shard_idx)
        return record, record
| |
|
| |
|
class ShardedVideoProcessor(Processor):
    """
    mmaped shards of numpy video features.
    """

    def __init__(self, config):
        self.split = str(config.split)
        self.vfeat_dir = config.vfeat_dir

    def __call__(self, video_id):
        """Return the feature of one video from its shard.

        `video_id` is a 4-tuple whose last two fields are the shard id and
        the index of the video inside that shard (the first two are ignored
        here). The "test" split reads the same "val" shards as "valid".

        Raises:
            ValueError: if `self.split` is not "train", "valid" or "test".
        """
        _, _, shard_id, video_idx = video_id
        prefix = {"train": "train", "valid": "val", "test": "val"}.get(
            self.split)
        if prefix is None:
            raise ValueError("unknown split", self.split)
        shard_path = os.path.join(
            self.vfeat_dir, prefix + "_" + str(shard_id))
        shard = ShardedTensor.load(shard_path, "r")
        return shard[video_idx]
| |
|
| |
|
class ShardedTextProcessor(Processor):
    """Reads per-video caption token ids and clip start/end times from shards."""

    def __init__(self, config):
        self.tfeat_dir = str(config.tfeat_dir)
        self.split = str(config.split)

    def __call__(self, video_id):
        """Return `{"start": [...], "end": [...], "cap": [[token ids], ...]}`.

        `video_id` is a 4-tuple whose last two fields are the shard id and
        the index inside the shard. The shard path is built by plain string
        concatenation onto `tfeat_dir` (no path separator is inserted).
        Entries equal to -1 in the caption id rows are dropped.

        Raises:
            ValueError: if `self.split` is not "train", "valid" or "test".
        """
        _, _, shard_id, shard_idx = video_id
        prefix = {"train": "train", "valid": "val", "test": "val"}.get(
            self.split)
        if prefix is None:
            raise ValueError("unknown split", self.split)
        target_path = self.tfeat_dir + prefix + "_" + str(shard_id)

        startend = ShardedTensor.load(
            target_path + ".startends", "r")[shard_idx]
        cap_ids = ShardedTensor.load(
            target_path + ".caps_ids", "r")[shard_idx]

        clip_rows = (cap_ids[clip_idx] for clip_idx in range(len(cap_ids)))
        cap = [row[row != -1].tolist() for row in clip_rows]

        return {
            "start": startend[:, 0].tolist(),
            "end": startend[:, 1].tolist(),
            "cap": cap,
        }
| |
|
| |
|
class FixedLenAligner(Aligner):
    """
    In the model we assume text is on the left (closer to BERT formulation)
    and video is on the right.
    We fix the total length of text + video.
    max_video_len is in number of secs.
    max_text_len is in number of tokens.

    special tokens formats:
    we use the format [CLS] [SEP] text tokens [SEP] [PAD] ...
    [CLS] will be split out into:
    [CLS] video tokens [SEP] text tokens [SEP] [PAD] ...
    token_type_ids will be generated by the model (for now).
    0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
    | first sequence    | second sequence |
    so each sequence owns a [SEP] token for no-ops.
    """

    def __init__(self, config):
        super().__init__(config)
        # 3 positions are reserved for special tokens (see class docstring).
        self.text_clip_sampler = TextClipSamplingProcessor(
            self.max_len - self.max_video_len - 3
        )
        # decide subsampling:
        # `config.subsampling` will change batch_size in trainer.
        # `config.clip_per_video` (used by RetriTask) doesn't
        # change batch_size in trainer.
        # When both are set, `clip_per_video` takes precedence.
        subsampling = config.subsampling
        if config.clip_per_video is not None:
            subsampling = config.clip_per_video
        self.subsampling = subsampling

    def _get_text_maxlen(self):
        """Return the text budget (in tokens) used for one sampled clip."""
        return self.text_clip_sampler.max_text_len

    def __call__(self, video_id, video_feature, text_feature):
        """Sample `self.subsampling` aligned clips from one video and collate.

        Args:
            video_id: indexable meta record; element 1 is forwarded to
                `sampling` and element 0 is stored as `batch["video_id"]`
                (a plain `str` is stored as is).
            video_feature: video features of the whole video.
            text_feature: dict with "start", "end" and "cap" lists.

        Returns:
            The collated batch dict, with "video_id" added.

        Raises:
            ValueError: if `self.subsampling` is None or smaller than 1.
        """
        from transformers import default_data_collator
        video_idx = video_id[1]

        subsampling_enabled = (
            self.subsampling is not None and self.subsampling >= 1)
        if not subsampling_enabled:
            raise ValueError(
                "dataset.subsampling must be >= 1 for efficient video loading.")

        batch = []
        for _ in range(self.subsampling):
            centerclip_idx = random.randint(
                0, len(text_feature["start"]) - 1)
            batch.append(
                self.sampling(
                    video_idx,
                    video_feature,
                    text_feature,
                    centerclip_idx,
                    self._get_text_maxlen()
                ))
        batch = self.batch_post_processing(batch, video_feature)
        batch = default_data_collator(batch)

        batch["video_id"] = video_id if isinstance(video_id, str) \
            else video_id[0]
        # make sure the video features have been turned into a tensor.
        assert torch.is_tensor(batch["vfeats"])
        return batch

    def sampling(
        self,
        video_idx,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
    ):
        """Sample a run of text clips and the video span that covers them.

        `video_idx` is accepted for interface compatibility and is not used
        here. The video span runs from the floor of the first selected
        clip's start to the ceiling of the last selected clip's end,
        clamped to `[0, video_len]`.

        Returns:
            dict with "caps", "cmasks", "vfeats", "vmasks" plus the
            "video_start"/"video_end" span and the half-open
            "text_start"/"text_end" clip index range.
        """
        text_clip_indexs = self.text_clip_sampler(
            text_feature, centerclip_idx,
            sampled_max_text_len
        )
        if isinstance(video_feature, np.ndarray):
            video_len = len(video_feature)
        else:
            # no array to measure: fall back to the end of the last caption.
            video_len = math.ceil(text_feature["end"][-1])

        video_end = min(
            math.ceil(text_feature["end"][text_clip_indexs[-1]]),
            video_len
        )
        video_start = max(
            min(
                math.floor(text_feature["start"][text_clip_indexs[0]]),
                video_end),
            0
        )

        video_clips = {"start": [video_start], "end": [video_end]}

        vfeats, vmasks = self._build_video_seq(
            video_feature, video_clips
        )
        caps, cmasks = self._build_text_seq(
            text_feature, text_clip_indexs
        )

        text_start = text_clip_indexs[0]
        text_end = text_clip_indexs[-1] + 1

        return {
            "caps": caps,
            "cmasks": cmasks,
            "vfeats": vfeats,
            "vmasks": vmasks,
            "video_start": video_start,
            "video_end": video_end,
            "text_start": text_start,
            "text_end": text_end,
        }
| |
|
| |
|
class VariedLenAligner(FixedLenAligner):
    """`FixedLenAligner` whose per-sample text budget is drawn at random.

    The budget is a uniform integer in
    `[config.sampled_min_len, config.sampled_max_len]` (both inclusive).
    """

    def __init__(self, config):
        super().__init__(config)
        self.sampled_min_len, self.sampled_max_len = (
            config.sampled_min_len,
            config.sampled_max_len,
        )

    def _get_text_maxlen(self):
        """Draw a fresh text length limit for the next sampled clip."""
        lower, upper = self.sampled_min_len, self.sampled_max_len
        return random.randint(lower, upper)
| |
|
| |
|
class StartClipAligner(VariedLenAligner):
    """Aligner that always grows the sampled text window from clip 0."""

    def sampling(
        self,
        video_idx,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
    ):
        """Delegate to the parent `sampling` with the center clip fixed to 0.

        The incoming `centerclip_idx` is ignored. `sampled_max_text_len` is
        not forwarded either, so the parent falls back to its default text
        length.
        NOTE(review): confirm that dropping `sampled_max_text_len` is
        intended for a `VariedLenAligner` subclass.
        """
        return super().sampling(
            video_idx,
            video_feature,
            text_feature,
            centerclip_idx=0,
        )
| |
|
| |
|
class OverlappedAligner(VariedLenAligner):
    """video clip and text clip has overlappings
    but may not be the same start/end."""

    def __init__(self, config):
        super().__init__(config)
        self.sampled_video_min_len = config.sampled_video_min_len
        self.sampled_video_max_len = config.sampled_video_max_len
        self.video_clip_sampler = VideoClipSamplingProcessor()

    def _get_video_maxlen(self):
        """Draw the video clip length limit, uniform in the configured range."""
        lower, upper = self.sampled_video_min_len, self.sampled_video_max_len
        return random.randint(lower, upper)

    def sampling(
        self,
        video_idx,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
    ):
        """Sample text clips, then a video clip grown around a nearby center.

        The center is drawn uniformly between the floor of the first
        selected text clip's start and the ceiling of the last one's end
        (their midpoint when that range is empty), then clamped to a valid
        index of `video_feature`. `video_idx` is not used here.

        NOTE(review): the clamp uses `video_feature.shape[0]` while the
        clip sampler receives `video_len`; the two only provably coincide
        when `video_feature` is a numpy array.

        Returns:
            dict with "caps", "cmasks", "vfeats", "vmasks", the sampled
            "video_start"/"video_end" and the half-open
            "text_start"/"text_end" clip index range.
        """
        text_clip_indexs = self.text_clip_sampler(
            text_feature, centerclip_idx,
            sampled_max_text_len
        )

        video_len = (
            len(video_feature)
            if isinstance(video_feature, np.ndarray)
            else math.ceil(text_feature["end"][-1])
        )

        first_clip, last_clip = text_clip_indexs[0], text_clip_indexs[-1]
        low = math.floor(text_feature["start"][first_clip])
        high = math.ceil(text_feature["end"][last_clip])
        center = (
            random.randint(low, high) if low < high
            else int((low + high) // 2)
        )
        last_frame = video_feature.shape[0] - 1
        center = max(0, min(last_frame, center))

        assert 0 <= center < video_feature.shape[0]

        # The length limit is drawn before the clip sampler consumes
        # randomness, matching the argument evaluation order of a direct call.
        sampled_video_len = self._get_video_maxlen()
        video_clips = self.video_clip_sampler(
            video_len, sampled_video_len, center
        )

        vfeats, vmasks = self._build_video_seq(
            video_feature, video_clips
        )
        caps, cmasks = self._build_text_seq(
            text_feature, text_clip_indexs
        )

        return {
            "caps": caps,
            "cmasks": cmasks,
            "vfeats": vfeats,
            "vmasks": vmasks,
            "video_start": video_clips["start"][0],
            "video_end": video_clips["end"][0],
            "text_start": text_clip_indexs[0],
            "text_end": text_clip_indexs[-1] + 1,
        }
| |
|
| |
|
class MFMMLMAligner(FixedLenAligner):
    """
    `FixedLenAligner` with Masked Language Model and Masked Frame Model.
    """

    def __init__(self, config):
        super().__init__(config)
        keep_prob = 1.0 if config.keep_prob is None else config.keep_prob
        # Replaces the sampler built by the parent with one that can skip clips.
        self.text_clip_sampler = TextClipSamplingProcessor(
            self.max_len - self.max_video_len - 3, keep_prob
        )
        self.sampled_min_len = config.sampled_min_len
        self.sampled_max_len = config.sampled_max_len
        self.masked_token_sampler = TextMaskingProcessor(config)
        self.mm_type = "full" if config.mm_type is None else config.mm_type
        # A 2D attention mask is only produced for the "textgen" mode.
        if self.mm_type == "textgen":
            self.attnmasker = MMAttentionMask2DProcessor()
        else:
            self.attnmasker = None
        self.masked_frame_sampler = FrameMaskingProcessor(config)
        if config.lazy_vfeat_mask is None:
            self.lazy_vfeat_mask = False
        else:
            self.lazy_vfeat_mask = config.lazy_vfeat_mask
        self.mm_prob = 0. if config.mm_prob is None else config.mm_prob

    def __call__(self, video_id, video_feature, text_feature):
        """Build one masked example, or a collated batch when subsampling > 1.

        With subsampling, each sample gets its own random center clip and a
        text budget drawn uniformly from
        `[sampled_min_len, sampled_max_len]`. The full `video_id` record
        (not just one field) is passed to `sampling`.
        """
        from transformers import default_data_collator

        use_subsampling = (
            self.subsampling is not None and self.subsampling > 1)
        if use_subsampling:
            samples = []
            for _ in range(self.subsampling):
                centerclip_idx = random.randint(
                    0, len(text_feature["start"]) - 1)
                sampled_max_text_len = random.randint(
                    self.sampled_min_len, self.sampled_max_len
                )
                sample = self.sampling(
                    video_id,
                    video_feature,
                    text_feature,
                    centerclip_idx,
                    sampled_max_text_len,
                )
                samples.append(sample)
            batch = default_data_collator(
                self.batch_post_processing(samples, video_feature))
        else:
            single = self.sampling(video_id, video_feature, text_feature)
            batch = self.batch_post_processing(single, video_feature)

        if isinstance(video_id, str):
            batch["video_id"] = video_id
        else:
            batch["video_id"] = video_id[0]
        return batch

    def sampling(
        self,
        video_id,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
    ):
        """Sample an aligned clip and attach text/video masking labels.

        With probability `self.mm_prob` exactly one modality is masked as a
        whole: either the text (using `self.mm_type`, video untouched) or
        the video ("full", text untouched), each with probability 0.5.
        Otherwise both samplers receive `None` as masking mode.

        Returns:
            the `FixedLenAligner.sampling` dict with "caps" replaced and
            "video_label"/"text_label" added, plus "attention_mask" when an
            attention masker is configured.
        """
        output = FixedLenAligner.sampling(
            self,
            video_id, video_feature, text_feature,
            centerclip_idx, sampled_max_text_len)

        masking_text = masking_video = None
        if random.random() < self.mm_prob:
            mask_text_side = random.random() > 0.5
            if mask_text_side:
                masking_text, masking_video = self.mm_type, "no"
            else:
                masking_text, masking_video = "no", "full"

        # With lazy masking the features are not handed to the frame
        # sampler, so it only produces labels.
        video_feats = None if self.lazy_vfeat_mask else output["vfeats"]
        video_label = self.masked_frame_sampler(
            output["vmasks"], masking_video, vfeats=video_feats)
        caps, text_label = self.masked_token_sampler(
            output["caps"], masking_text)

        output["caps"] = caps
        output["video_label"] = video_label
        output["text_label"] = text_label

        if self.attnmasker is not None:
            output["attention_mask"] = self.attnmasker(
                output["vmasks"], output["cmasks"], masking_text)
        return output
| |
|
| |
|
class FrameMaskingProcessor(Processor):
    """Samples which video frames are masked for the masked frame model."""

    def __init__(self, config):
        configured = config.mfm_probability
        self.mfm_probability = 0.15 if configured is None else configured

    def __call__(self, vmasks, modality_masking=None, vfeats=None):
        """
        We perform lazy masking to save data transfer time.
        It only generates video_labels by default and MFM model
        will do the actual masking.

        `modality_masking` selects the per-frame masking probability:
        None -> `self.mfm_probability`, "full" -> 1, "no" -> 0,
        "inverse" -> 1 - `self.mfm_probability`.
        When `vfeats` is given, the rows selected by the label are zeroed
        in place.

        Return: `video_label` is a binary mask.
        Raises: ValueError on any other `modality_masking` value.
        """
        video_label = vmasks.clone()

        if modality_masking is None:
            mask_prob = self.mfm_probability
        elif modality_masking == "full":
            mask_prob = 1.
        elif modality_masking == "no":
            mask_prob = 0.
        elif modality_masking == "inverse":
            mask_prob = 1. - self.mfm_probability
        else:
            raise ValueError("unknown modality masking.", modality_masking)

        probability_matrix = torch.full(video_label.shape, mask_prob)
        not_sampled = ~torch.bernoulli(probability_matrix).bool()
        # Only frames that were valid in `vmasks` AND sampled stay set.
        video_label[not_sampled] = 0

        if vfeats is not None:
            vfeats[video_label, :] = 0.0
        return video_label
| |
|
| |
|
class TextGenerationProcessor(Processor):
    """Turns a token id sequence into shifted inputs/labels for generation."""

    def __init__(self, tokenizer):
        self.bos_token_id = tokenizer.bos_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, inputs):
        """Return `(inputs, labels)`; `inputs` is modified in place.

        Labels are a copy of the incoming ids with the first two positions
        and all padding positions set to -100. From position 2 on, the
        inputs are shifted right by one with a BOS id inserted in front;
        positions that were padding are reset to the pad id afterwards.
        """
        targets = inputs.clone()
        # the two leading positions never contribute to the labels.
        targets[:2] = -100
        # padding is located after the prefix has been overwritten.
        pad_mask = targets == self.pad_token_id
        targets[pad_mask] = -100

        bos = torch.LongTensor([self.bos_token_id])
        shifted_tail = torch.cat([bos, inputs[2:-1]])
        inputs[2:] = shifted_tail
        inputs[pad_mask] = self.pad_token_id

        assert len(inputs) == len(targets)
        return inputs, targets
| |
|
| |
|
class TextMaskingProcessor(Processor):
    """Token masking for masked language modeling and text generation.

    The masking recipe is borrowed from
    `transformers/data/data_collator.DataCollatorForLanguageModeling`.
    """

    def __init__(self, config):
        """Read `mlm_probability` (default 0.15) and load the tokenizer
        named by `config.bert_name`."""
        self.mlm_probability = 0.15
        if config.mlm_probability is not None:
            self.mlm_probability = config.mlm_probability
        self.bert_name = config.bert_name
        # imported lazily so that the module can be loaded without transformers.
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.bert_name, bos_token="[CLS]", eos_token="[SEP]")
        self.textgen = TextGenerationProcessor(self.tokenizer)

    def __call__(
        self, inputs: torch.Tensor,
        modality_masking=None,
        special_tokens_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling:
        80% MASK, 10% random, 10% original. `inputs` is modified in place.

        `modality_masking` selects the behaviour:
            None: traditional bert masking with `self.mlm_probability`.
            "no": no masking.
            "full": every non-special token is selected for masking.
            "inverse": masking with probability `1 - self.mlm_probability`.
            "mask": inputs are corrupted but all labels are -100.
            "textgen...": autoregressive generation inputs/labels; inputs
                are additionally corrupted when the mode name contains
                "mask".

        Returns:
            `(inputs, labels)`; labels are -100 wherever no loss applies.

        Raises:
            ValueError: on an unknown `modality_masking` value.
        """
        labels = inputs.clone()

        if modality_masking is None:
            probability = self.mlm_probability
        elif modality_masking == "full":
            probability = 1.
        elif modality_masking == "no":
            probability = 0.
        elif modality_masking.startswith("textgen"):
            # [CLS] [SEP] <s> ...
            inputs, labels = self.textgen(inputs)
            if "mask" in modality_masking:
                inputs = self.mask_input(inputs, special_tokens_mask)
            return inputs, labels
        elif modality_masking == "mask":
            inputs = self.mask_input(inputs, special_tokens_mask)
            labels = torch.full(inputs.shape, -100)
            return inputs, labels
        elif modality_masking == "inverse":
            probability = 1. - self.mlm_probability
        else:
            raise ValueError("unknown modality masking.", modality_masking)

        masked_indices = self._sample_masked_indices(
            inputs, probability, special_tokens_mask)
        # We only compute loss on masked tokens.
        labels[~masked_indices] = -100
        inputs = self._corrupt_inputs(inputs, masked_indices)
        return inputs, labels

    def mask_input(self, inputs, special_tokens_mask=None):
        """Corrupt `inputs` in place with `self.mlm_probability`; no labels."""
        masked_indices = self._sample_masked_indices(
            inputs, self.mlm_probability, special_tokens_mask)
        return self._corrupt_inputs(inputs, masked_indices)

    def _sample_masked_indices(
        self, inputs, probability, special_tokens_mask=None
    ):
        """Draw a boolean mask of positions to corrupt; special tokens are never picked."""
        probability_matrix = torch.full(inputs.shape, probability)
        if special_tokens_mask is None:
            special_tokens_mask = self.get_special_tokens_mask(
                inputs.tolist(), already_has_special_tokens=True
            )
            special_tokens_mask = torch.tensor(
                special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        return torch.bernoulli(probability_matrix).bool()

    def _corrupt_inputs(self, inputs, masked_indices):
        """Apply the 80% [MASK] / 10% random / 10% unchanged rule in place."""
        # 80% of the time, we replace masked input tokens with [MASK].
        indices_replaced = (
            torch.bernoulli(
                torch.full(inputs.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.mask_token
        )

        # Half of the remaining 20% (10% overall) get a random word.
        indices_random = (
            torch.bernoulli(torch.full(inputs.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(
            len(self.tokenizer), inputs.shape, dtype=torch.long
        )
        inputs[indices_random] = random_words[indices_random]

        # The rest of the selected tokens are kept unchanged.
        return inputs

    def get_special_tokens_mask(
        self, token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Return a 0/1 list marking special tokens (1) in a sequence.

        Note: the version from transformers do not consider pad
        as special tokens.

        Raises:
            ValueError: if `already_has_special_tokens` is set and a second
                sequence is supplied.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if "
                    "the provided sequence of "
                    "ids is already formated with special tokens "
                    "for the model."
                )
            special_ids = (
                self.tokenizer.sep_token_id,
                self.tokenizer.cls_token_id,
                self.tokenizer.pad_token_id,
            )
            return [1 if x in special_ids else 0 for x in token_ids_0]

        if token_ids_1 is not None:
            return (
                [1] + ([0] * len(token_ids_0))
                + [1] + ([0] * len(token_ids_1)) + [1]
            )
        return [1] + ([0] * len(token_ids_0)) + [1]
| |
|
| |
|
class TextClipSamplingProcessor(Processor):
    """Grows a run of caption clips around a center clip.

    The run is extended until either the accumulated number of caption
    tokens or the covered time span reaches its limit, or no clip is left
    on either side.
    """

    def __init__(self, max_text_len, keep_prob=1.0):
        self.max_text_len = max_text_len
        self.max_video_len = 256
        self.keep_prob = keep_prob

    def __call__(
        self,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
        sampled_max_video_len=None,
    ):
        """Return a deque of selected clip indexes in ascending order.

        Args:
            text_feature: dict with parallel "start", "end" and "cap" lists.
            centerclip_idx: seed clip; drawn uniformly when None.
            sampled_max_text_len: overrides `self.max_text_len` when given.
            sampled_max_video_len: overrides `self.max_video_len` when given.

        At every step a fair coin picks the side to extend, falling back to
        the other side when the chosen one is exhausted. With probability
        `1 - keep_prob` the immediate neighbour is skipped, provided a
        further clip exists on that side.
        """
        max_text_len = (
            self.max_text_len if sampled_max_text_len is None
            else sampled_max_text_len
        )
        max_video_len = (
            self.max_video_len if sampled_max_video_len is None
            else sampled_max_video_len
        )

        num_clips = len(text_feature["start"])
        if centerclip_idx is None:
            centerclip_idx = random.randint(0, num_clips - 1)

        # `left` is the smallest index consumed so far; `right` is the next
        # candidate index on the right-hand side.
        left, right = centerclip_idx, centerclip_idx + 1
        chosen = deque([left])
        caps = text_feature["cap"]
        starts, ends = text_feature["start"], text_feature["end"]
        text_len = len(caps[left])
        video_len = max(0, ends[left] - starts[left])

        while (
            (left > 0 or right < num_clips)
            and text_len < max_text_len
            and video_len < max_video_len
        ):
            # The coin is always tossed first so the random stream is
            # consumed identically whichever side ends up being extended.
            prefer_right = random.random() > 0.5
            can_go_right = right < num_clips
            if can_go_right and (prefer_right or left <= 0):
                # skip the next one?
                if random.random() > self.keep_prob \
                        and (right + 1) < num_clips:
                    right += 1
                chosen.append(right)
                text_len += len(caps[right])
                right += 1
            elif left > 0:
                if random.random() > self.keep_prob and (left - 1) > 0:
                    left -= 1
                left -= 1
                chosen.appendleft(left)
                text_len += len(caps[left])
            else:
                return chosen
            video_len = max(0, ends[chosen[-1]] - starts[chosen[0]])
        return chosen
| |
|
| |
|
class VideoClipSamplingProcessor(Processor):
    """Grows a `[start, end)` video window around a center position."""

    def __call__(self, video_len, max_video_len, center):
        """
        `video_len`: length of the video.
        `max_video_len`: maximum video tokens allowed in a sequence.
        `center`: initial starting index.

        The window starts empty at `center` and is widened one step at a
        time, to the right or to the left with equal probability (or to
        the only side still available), until it spans `max_video_len`
        steps or covers the whole video.

        Returns `{"start": [start], "end": [end]}`.
        """
        assert 0 <= center < video_len
        start = end = center
        # `end - start` equals the number of growth steps taken so far.
        while (start > 0 or end < video_len) and end - start < max_video_len:
            # No random number is drawn when only one side can still grow.
            grow_right = start <= 0 or (
                end < video_len and random.random() > 0.5)
            if grow_right:
                end += 1
            else:
                start -= 1
        return {"start": [start], "end": [end]}
| |
|
| |
|
class How2MILNCEAligner(FixedLenAligner):
    """reference: `antoine77340/MIL-NCE_HowTo100M/video_loader.py`"""

    def __init__(self, config):
        super().__init__(config)
        # number of neighbouring captions paired with one video clip.
        self.num_candidates = 4
        # minimum time span (same unit as caption start/end) to sample from.
        self.min_time = 5.0
        # length of the sampled video window.
        self.num_sec = 3.2

    def sampling(
        self,
        video_id,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None
    ):
        """Sample one short video window and its candidate captions.

        `video_id`, `centerclip_idx` and `sampled_max_text_len` are accepted
        for interface compatibility and are not used here.

        Returns:
            dict with stacked "caps"/"cmasks" (one row per candidate) and
            "vfeats"/"vmasks" padded to `self.max_video_len` rows.
        """
        text, start, end = self._get_text(text_feature)
        video = self._get_video(video_feature, start, end)

        vfeats = torch.zeros((self.max_video_len, video_feature.shape[1]))
        vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool)
        vfeats[: video.shape[0]] = torch.from_numpy(np.array(video))
        vmasks[: video.shape[0]] = 1

        caps, cmasks = [], []
        for words in text:
            # each `words` is a list of clip indexes for one candidate.
            cap, cmask = self._build_text_seq(text_feature, words)
            caps.append(cap)
            cmasks.append(cmask)
        caps = torch.stack(caps)
        cmasks = torch.stack(cmasks)

        return {
            "caps": caps,
            "cmasks": cmasks,
            "vfeats": vfeats,
            "vmasks": vmasks,
        }

    def _get_video(self, video_feature, start, end):
        """Slice `int(num_sec)` rows starting at a random offset in `[start, end - num_sec]`."""
        start_seek = random.randint(start, int(max(start, end - self.num_sec)))
        return video_feature[start_seek : int(start_seek + self.num_sec)]

    def _get_text(self, cap):
        """Pick a random caption and its neighbouring candidates.

        Returns:
            `(words, start, end)` where `words` is a list with one entry per
            candidate, each entry being a list of clip indexes, and
            `start`/`end` are the integer bounds of the time span to sample
            video from (widened to at least `self.min_time`).
        """
        ind = random.randint(0, len(cap["start"]) - 1)
        if self.num_candidates == 1:
            # one candidate holding one clip; keeps the same nested shape as
            # the multi-candidate branch so `sampling` can iterate it.
            words = [[ind]]
        else:
            cap_start = self._find_nearest_candidates(cap, ind)
            last_clip = len(cap["cap"]) - 1
            # clamp, since the candidate window may start before clip 0.
            words = [
                [max(0, min(last_clip, cap_start + i))]
                for i in range(self.num_candidates)
            ]

        start, end = cap["start"][ind], cap["end"][ind]
        # expand the span around the caption when it is shorter than
        # the minimum time.
        if end - start < self.min_time:
            diff = self.min_time - end + start
            start = max(0, start - diff / 2)
            end = start + self.min_time
        return words, int(start), int(end)

    def _find_nearest_candidates(self, caption, ind):
        """find the range of the clips.

        Greedily grows a window of `self.num_candidates` clips around `ind`,
        each time extending towards the side that yields the shorter time
        span, and returns the index of the first clip of that window. The
        result can be negative when the window hits the last clip early;
        callers are expected to clamp.
        """
        start, end = ind, ind
        n_candidate = 1
        while n_candidate < self.num_candidates:
            # the first clip: the window can only extend to the right.
            if start == 0:
                return 0
            # the last clip: all remaining candidates come from the left.
            elif end == (len(caption["start"]) - 1):
                return start - (self.num_candidates - n_candidate)
            elif (caption["end"][end] - caption["start"][start - 1]) < (
                caption["end"][end + 1] - caption["start"][start]
            ):
                start -= 1
            else:
                end += 1
            n_candidate += 1
        return start
| |
|
| |
|
class PKLJSONStrTextProcessor(TextProcessor):
    """`caption.json` from howto100m are preprocessed as a
    dict `[video_id, json_str]`.
    Json parsing tokenization are conducted on-the-fly and cached into dict.
    """

    def __init__(self, config, max_clip_text_len=96):
        print("[Warning] PKLJSONStrTextProcessor is slow for num_workers > 0.")
        self.caption_pkl_path = str(config.caption_pkl_path)
        # NOTE: pickle executes arbitrary code on load; only use trusted files.
        with open(self.caption_pkl_path, "rb") as fd:
            self.data = pickle.load(fd)
        self.max_clip_text_len = max_clip_text_len
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(config.bert_name), use_fast=config.use_fast
        )

    def __call__(self, video_id):
        """Return the caption dict of `video_id`, tokenizing it on first use.

        While the stored entry is still a JSON string it is parsed, every
        clip under "text" is cut to `max_clip_text_len` characters and
        tokenized without special tokens into "cap" (non-string clips give
        an empty list), "text" is dropped, and the result replaces the
        string in `self.data`. Later calls return the cached dict.
        """
        caption = self.data[video_id]
        if not isinstance(caption, str):
            # already parsed and tokenized by an earlier call.
            return caption

        import json
        caption = json.loads(caption)

        def encode(text_clip):
            if not isinstance(text_clip, str):
                return []
            return self.tokenizer(
                text_clip[: self.max_clip_text_len],
                add_special_tokens=False
            )["input_ids"]

        caption["cap"] = [encode(text_clip) for text_clip in caption["text"]]
        # the raw text is no longer needed once token ids are stored.
        caption.pop("text")
        self.data[video_id] = caption
        return caption
| |
|