| | |
| |
|
| | import torch |
| | import os |
| | import numpy as np |
| | import pickle |
| |
|
| | from . import retri |
| | from ..utils import get_local_rank |
| |
|
| |
|
class VectorPool(object):
    """
    Base class of retrieval space.

    Holds a retriever instance (class taken from the project `retri`
    module) that indexes pooled hidden states by video/text ids.
    """

    def __init__(self, config):
        # Local import: transformers is only needed to read the hidden size
        # of the configured BERT model.
        from transformers import AutoConfig
        self.hidden_size = AutoConfig.from_pretrained(
            config.dataset.bert_name).hidden_size
        self.retriever_cls = getattr(retri, config.retriever_cls)

    def __call__(self, sample, **kwargs):
        raise NotImplementedError

    def build_retriver(
        self,
        retriever_cls=None,
        hidden_size=None,
        centroids=512,
        db_type="flatl2",
        examples_per_cent_to_train=48
    ):
        """Build, cache on `self.retriver`, and return the retriever.

        Falls back to `self.retriever_cls` / `self.hidden_size` when the
        corresponding argument is None — mirrors
        `DistributedVectorPool.build_retriver` so calling with no arguments
        works on the base class too instead of raising ``TypeError``.
        """
        if retriever_cls is None:
            retriever_cls = self.retriever_cls
        if hidden_size is None:
            hidden_size = self.hidden_size
        self.retriver = retriever_cls(
            hidden_size, centroids, db_type, examples_per_cent_to_train)
        return self.retriver

    def __repr__(self):
        # Show the retriever size when one has been built; otherwise a
        # stable placeholder so repr never raises.
        if hasattr(self, "retriver"):
            retriver_name = str(len(self.retriver))
        else:
            retriver_name = "no retriver field yet"
        return self.__class__.__name__ \
            + "(" + retriver_name + ")"
| |
|
| |
|
class VideoVectorPool(VectorPool):
    """
    average clips of a video as video representation.
    """

    def __init__(self, config):
        super().__init__(config)
        self.build_retriver(self.retriever_cls, self.hidden_size)

    def __call__(self, sample, subsampling, **kwargs):
        # Fuse the two modalities, then average the `subsampling` clips
        # belonging to each video into one vector per video.
        fused = (sample["pooled_video"] + sample["pooled_text"]) / 2.
        per_video = fused.view(-1, subsampling, fused.size(-1)).mean(dim=1)
        vectors = per_video.detach().cpu().numpy()

        # A 3-tuple id presumably carries (video_id, start, end) — keep
        # only the id component; anything else is used as-is.
        video_ids = [
            vid[0] if isinstance(vid, tuple) and len(vid) == 3 else vid
            for vid in sample["video_id"]
        ]
        assert len(video_ids) == len(vectors)
        self.retriver.add(
            vectors.astype("float32"),
            video_ids
        )
| |
|
| |
|
class DistributedVectorPool(VectorPool):
    """
    support sync of multiple gpus/nodes.
    """

    def __init__(self, config):
        super().__init__(config)
        # Each rank dumps its local vectors under <save_dir>/retri; the
        # dumps are merged later by `build_retriver`.
        self.out_dir = os.path.join(
            config.fairseq.checkpoint.save_dir,
            "retri")
        os.makedirs(self.out_dir, exist_ok=True)
        # In-memory accumulators flushed to disk by `save()`:
        # hidden_states holds per-batch numpy arrays, video_ids their ids.
        self.hidden_states = []
        self.video_ids = []

    def build_retriver(
        self,
        retriever_cls=None,
        hidden_size=None,
        centroids=4096,
        db_type="flatl2",
        examples_per_cent_to_train=48
    ):
        """Merge results from all ranks and build one retriever over them.

        Under torch.distributed, every rank saves its local pool, waits on
        a barrier so all writes finish, then each rank loads every rank's
        dump and adds it to a freshly built retriever. Without distributed
        init, only rank 0's dump is loaded.
        """
        if retriever_cls is None:
            retriever_cls = self.retriever_cls
        if hidden_size is None:
            hidden_size = self.hidden_size
        """merge results from multiple gpus and return a retriver.."""
        if torch.distributed.is_initialized():
            self.save()
            # Barrier: no rank may start reading before all ranks have
            # finished writing their dumps.
            torch.distributed.barrier()
            world_size = torch.distributed.get_world_size()
        else:
            world_size = 1
        self.retriver = retriever_cls(
            hidden_size, centroids, db_type, examples_per_cent_to_train)
        # Index every rank's dump; only local rank 0 prints progress.
        for local_rank in range(world_size):
            if get_local_rank() == 0:
                print("load local_rank", local_rank)
            hidden_states, video_ids = self.load(local_rank)
            hidden_states = hidden_states.astype("float32")
            self.retriver.add(hidden_states, video_ids)
        return self.retriver

    def load(self, local_rank):
        """Load the (hidden_states, video_ids) pair that `save()` wrote
        for the given rank from `self.out_dir`."""
        hidden_states = np.load(
            os.path.join(
                self.out_dir,
                "hidden_state" + str(local_rank) + ".npy"
            )
        )

        with open(
            os.path.join(
                self.out_dir, "video_id" + str(local_rank) + ".pkl"),
                "rb") as fr:
            video_ids = pickle.load(fr)
        return hidden_states, video_ids

    def save(self):
        """Flush this rank's accumulated vectors and ids to `self.out_dir`
        as `hidden_state<rank>.npy` / `video_id<rank>.pkl`."""
        hidden_states = np.vstack(self.hidden_states)
        assert len(hidden_states) == len(self.video_ids), "{}, {}".format(
            len(hidden_states),
            len(self.video_ids)
        )
        # Rank 0 is assumed when torch.distributed is not initialized.
        local_rank = torch.distributed.get_rank() \
            if torch.distributed.is_initialized() else 0

        np.save(
            os.path.join(
                self.out_dir,
                "hidden_state" + str(local_rank) + ".npy"),
            hidden_states)

        with open(
                os.path.join(
                    self.out_dir,
                    "video_id" + str(local_rank) + ".pkl"),
                "wb") as fw:
            pickle.dump(
                self.video_ids,
                fw,
                protocol=pickle.HIGHEST_PROTOCOL
            )
| |
|
| |
|
class DistributedVideoVectorPool(DistributedVectorPool):
    """
    average clips of a video as video representation.
    """

    def __call__(self, sample, subsampling, **kwargs):
        # Average the two modalities, then collapse each group of
        # `subsampling` consecutive clip rows into one vector per video.
        fused = (sample["pooled_video"] + sample["pooled_text"]) / 2.
        fused = fused.view(-1, subsampling, fused.size(-1)).mean(dim=1)
        vectors = fused.detach().cpu().numpy()

        video_ids = []
        for vid in sample["video_id"]:
            # presumably a 3-tuple is (video_id, start, end) — keep the id.
            if isinstance(vid, tuple) and len(vid) == 3:
                vid = vid[0]
            video_ids.append(vid)
        assert len(video_ids) == len(vectors)

        # Buffer locally; `DistributedVectorPool.save` flushes to disk.
        self.hidden_states.append(vectors)
        self.video_ids.extend(video_ids)
| |
|
| |
|
| | |
| |
|
class TextClipVectorPool(VectorPool):
    """
    Retrieval space over pooled text vectors, keyed by clip metadata.
    """

    def __init__(self, config):
        # Consistency fix: reuse VectorPool.__init__ instead of duplicating
        # the AutoConfig / retriever-class lookup inline. This additionally
        # exposes self.hidden_size / self.retriever_cls like the sibling
        # pools do (backward compatible — same retriever is built).
        super().__init__(config)
        self.build_retriver(self.retriever_cls, self.hidden_size)

    def __call__(self, sample, **kwargs):
        clip_meta = sample["clip_meta"].cpu()
        # Columns 4/5 are presumably (start, end): starts must not exceed
        # ends — TODO confirm the clip_meta column layout with the caller.
        assert torch.all(torch.le(clip_meta[:, 4], clip_meta[:, 5]))
        text_meta = [tuple(item.tolist()) for item in clip_meta[:, 3:]]

        if hasattr(self, "retriver"):
            # Index pooled text vectors under their metadata tuples.
            self.retriver.add(
                sample["pooled_text"].cpu().numpy().astype("float32"),
                text_meta
            )
        else:
            # Deferred / on-the-fly retriever building is not supported.
            raise NotImplementedError
| |
|
| |
|
class MMClipVectorPool(VectorPool):
    """
    Multimodal Clip-level vector pool.
    """

    def __init__(self, out_dir):
        """use hidden_states to store `(video, text)`;
        use video_ids to store `(video_id, start, end)`.
        """
        super().__init__(out_dir)

    def __call__(self, sample, **kwargs):
        # Stack video and text vectors side by side: each batch entry
        # becomes a (2, D) pair in a (B, 2, D) array.
        video_vec = sample["pooled_video"].cpu().unsqueeze(1).numpy()
        text_vec = sample["pooled_text"].cpu().unsqueeze(1).numpy()
        # NOTE(review): assumes self.hidden_states / self.video_ids exist;
        # VectorPool.__init__ does not create them — confirm a subclass or
        # caller sets them up (DistributedVectorPool does).
        self.hidden_states.append(
            np.concatenate([video_vec, text_vec], axis=1)
        )

        video_starts = sample["video_start"].cpu()
        video_ends = sample["video_end"].cpu()
        assert torch.all(torch.le(video_starts, video_ends))

        text_starts = sample["text_start"].cpu()
        text_ends = sample["text_end"].cpu()
        assert torch.all(torch.le(text_starts, text_ends))

        # Each video id covers `subsample_size` consecutive clip rows.
        subsample_size = sample["pooled_video"].size(0) // len(sample["video_id"])
        repeated_ids = [
            vid for vid in sample["video_id"] for _ in range(subsample_size)
        ]
        spans = zip(
            repeated_ids, video_starts, video_ends, text_starts, text_ends)
        for vid, v_start, v_end, t_start, t_end in spans:
            self.video_ids.append((
                vid,
                (int(v_start), int(v_end)),
                (int(t_start), int(t_end)),
            ))
| |
|