import os
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from PIL import Image
from sklearn.cluster import KMeans
from transformers.data.data_collator import (
    _torch_collate_batch,
    pad_without_fast_tokenizer_warning,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
|
class DataCollatorMixin:
    """Dispatch `__call__` to the framework-specific collate method.

    Only `torch_call` is implemented by the collators in this module; the
    `tf_call`/`numpy_call` branches are kept for interface parity with
    `transformers.DataCollatorMixin`.
    """

    def __call__(self, features, return_tensors=None):
| if return_tensors is None: |
| return_tensors = self.return_tensors |
| if return_tensors == "tf": |
| return self.tf_call(features) |
| elif return_tensors == "pt": |
| return self.torch_call(features) |
| elif return_tensors == "np": |
| return self.numpy_call(features) |
| else: |
| raise ValueError(f"Framework '{return_tensors}' not recognized!") |


@dataclass
| class MyDataCollatorForQFormerPatchPretrain(DataCollatorMixin): |
| image_processor: Any |
| tokenizer: PreTrainedTokenizerBase |
| mlm: bool = False |
| mlm_probability: float = 0.15 |
| pad_to_multiple_of: Optional[int] = None |
| tf_experimental_compile: bool = False |
| return_tensors: str = "pt" |
|
|
    def __post_init__(self):
        if self.mlm:
            raise ValueError(
                "This collator does not implement masked language modeling. "
                "Pass `mlm=False` to train with causal language modeling instead."
            )
|
|
    def _resize_image(self, image, min_size=448, max_size=1024):
        """
        Resize the image, preserving aspect ratio: if the shorter side is below
        `min_size`, upscale it to `min_size`; otherwise, if the longer side
        exceeds `max_size`, downscale it to `max_size`. Images already within
        range are returned unchanged.
        """
| width, height = image.size |
| if width < min_size or height < min_size: |
| if width < height: |
| new_width = min_size |
| new_height = int(height * (min_size / width)) |
| else: |
| new_height = min_size |
| new_width = int(width * (min_size / height)) |
| return image.resize((new_width, new_height), Image.Resampling.LANCZOS) |
| |
| elif width > max_size or height > max_size: |
| if width > height: |
| new_width = max_size |
| new_height = int(height * (max_size / width)) |
| else: |
| new_height = max_size |
| new_width = int(width * (max_size / height)) |
| return image.resize((new_width, new_height), Image.Resampling.LANCZOS) |
|
|
| return image |
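
    # Example: a 300x600 image is upscaled to 448x896 (shorter side 300 < 448),
    # while a 1200x900 image is downscaled to 1024x768 (longer side 1200 > 1024).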
|
|
    def _crop_image(self, image, crop_size=448, overlap=0.5):
        """
        Crop the image into `crop_size` x `crop_size` patches with the given
        fractional overlap. Assumes both sides are at least `crop_size`;
        `_resize_image` ensures this for images of moderate aspect ratio.
        """
        width, height = image.size
        step = int(crop_size * (1 - overlap))
        patches = []
        for top in range(0, height - crop_size + 1, step):
            for left in range(0, width - crop_size + 1, step):
                box = (left, top, left + crop_size, top + crop_size)
                patches.append(image.crop(box))

        # The grid above can stop short of the right/bottom borders; add
        # edge-aligned patches whenever the last stride does not land exactly
        # on the border.
        if (width - crop_size) % step != 0:
            for top in range(0, height - crop_size + 1, step):
                patches.append(image.crop((width - crop_size, top, width, top + crop_size)))
        if (height - crop_size) % step != 0:
            for left in range(0, width - crop_size + 1, step):
                patches.append(image.crop((left, height - crop_size, left + crop_size, height)))
        if (width - crop_size) % step != 0 and (height - crop_size) % step != 0:
            patches.append(image.crop((width - crop_size, height - crop_size, width, height)))

        return patches
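
    # Example: a 1024x896 image with crop_size=448 and overlap=0.5 (step 224)
    # gives a 3x3 interior grid; the bottom border is covered exactly, but the
    # right border falls 128 px short, so three right-aligned patches are
    # appended, for 12 patches in total.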
|
|
    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        patch_list = []
        num_list = []
        text_list = []

        for d in examples:
            image = self._resize_image(d["image"])
            patches = [self.image_processor(patch) for patch in self._crop_image(image)]
            patch_list += patches
            num_list.append(len(patches))
            del d["image"]

        for d in examples:
            text_list.append(d["text"])
            del d["text"]

        batch = {"text": text_list}
        batch["image"] = torch.stack(patch_list)
        batch["patch_num"] = num_list
        return batch
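
    # Note: batch["image"] stacks every patch in the batch into one tensor of
    # shape (sum(patch_num), C, H, W), assuming the image_processor returns
    # equally sized CHW tensors; "patch_num" lets the model split the stack
    # back into per-example groups.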
|
|
| @dataclass |
| class MyDataCollatorForQFormerPatchInstruct(MyDataCollatorForQFormerPatchPretrain): |
| tokenizer: PreTrainedTokenizerBase |
| image_processor: Any |
| mlm: bool = False |
| mlm_probability: float = 0.15 |
| pad_to_multiple_of: Optional[int] = None |
| tf_experimental_compile: bool = False |
| return_tensors: str = "pt" |
| test: bool = False |
| |
| def pad_token_id_list(self, input_id_list, padding_value=0): |
| """ |
| Pad the list of token ID lists to the maximum length of lists in the input. |
| |
| Args: |
| input_id_list (List[List[int]]): List of token ID lists, each list represents a sequence of token IDs. |
| padding_value (int, optional): The value used for padding shorter lists. Defaults to 0. |
| |
| Returns: |
| List[List[int]]: A new list where all inner lists are padded to the maximum length found in the original list. |
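
        Example:
            pad_token_id_list([[1, 2, 3], [4]], padding_value=0)
            -> [[1, 2, 3], [4, 0, 0]]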
| """ |
| |
| max_length = max(len(inner_list) for inner_list in input_id_list) |
| |
| |
| padded_list = [inner_list + [padding_value] * (max_length - len(inner_list)) for inner_list in input_id_list] |
| |
| return padded_list |
| |
| def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: |
| |
| patch_list = [] |
| num_list = [] |
| input_id_list = [] |
| text_list = [] |
| ans_list = [] |
| attention_mask_list = [] |
| text_input_list = [] |
|
|
| for d in examples: |
| image = self._resize_image(d["image"]) |
| patches = self._crop_image(image) |
| patches = [self.image_processor(patch) for patch in patches] |
| patch_list += patches |
| num_list.append(len(patches)) |
| del d["image"] |
|
|
| for d in examples: |
| input_id_list.append(d["input_ids"]) |
| attention_mask_list.append(d["attention_mask"]) |
| text_input_list.append(d["text_input"]) |
| if self.test: |
| text_list.append(d["text"]) |
| ans_list.append(d["answer"]) |
| del d["answer"] |
| del d["text"] |
| del d["text_input"] |
| |
| input_id_list = self.pad_token_id_list(input_id_list, self.tokenizer.pad_token_id) |
| attention_mask_list = self.pad_token_id_list(attention_mask_list, 0) |
|
|
| batch = {"input_ids": torch.tensor(input_id_list)} |
| batch["attention_mask"] = torch.tensor(attention_mask_list) |
| |
| labels = batch["input_ids"].clone() |
| if self.tokenizer.pad_token_id is not None: |
| labels[labels == self.tokenizer.pad_token_id] = -100 |
|
|
| |
| if self.test: |
| batch["text"] = text_list |
| batch["answers"] = ans_list |
| |
| batch["text_input"] = text_input_list |
| batch["labels"] = labels |
| batch["image"] = torch.stack(patch_list) |
| batch["patch_num"] = num_list |
| return batch |
|
|
| @dataclass |
| class MyDataCollatorForPPathVLM(MyDataCollatorForQFormerPatchInstruct): |
| tokenizer: PreTrainedTokenizerBase |
| image_processor: Any |
| mlm: bool = False |
| mlm_probability: float = 0.15 |
| pad_to_multiple_of: Optional[int] = None |
| tf_experimental_compile: bool = False |
| return_tensors: str = "pt" |
| test: bool = False |
|
|
|
|
| def __post_init__(self): |
| if self.mlm and self.tokenizer.mask_token is None: |
| raise ValueError( |
| "This tokenizer does not have a mask token which is necessary for masked language modeling. " |
| "You should pass `mlm=False` to train on causal language modeling instead." |
| ) |
        # '<|Question|>' and '<|Answer|>' are assumed to have been added to the
        # tokenizer as special tokens; otherwise convert_tokens_to_ids would
        # return the unknown-token id.
        self.question_token_id = self.tokenizer.convert_tokens_to_ids('<|Question|>')
        self.answer_token_id = self.tokenizer.convert_tokens_to_ids('<|Answer|>')
        self.pad_token_id = self.tokenizer.pad_token_id
|
|
| def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: |
| |
| patch_list = [] |
| num_list = [] |
| input_id_list = [] |
| attention_mask_list = [] |
|
|
| if self.test: |
| ans_list = [] |
|
|
| for d in examples: |
| image = self._resize_image(d["image"]) |
| patches = self._crop_image(image) |
| patches = [self.image_processor(patch) for patch in patches] |
| patch_list += patches |
| num_list.append(len(patches)) |
| del d["image"] |
|
|
| for d in examples: |
| if self.test: |
| ans_list.append(d["answer"]) |
| del d["answer"] |
| input_id_list.append(d["input_ids"]) |
| attention_mask_list.append(d["attention_mask"]) |
| |
| input_id_list = self.pad_token_id_list(input_id_list, self.tokenizer.pad_token_id) |
| attention_mask_list = self.pad_token_id_list(attention_mask_list, 0) |
|
|
| batch = {"input_ids": torch.tensor(input_id_list)} |
| batch["attention_mask"] = torch.tensor(attention_mask_list) |
| |
| labels = batch["input_ids"].clone() |
| |
| labels[labels == 128000] = -100 |
| labels[labels == self.pad_token_id] = -100 |
|
|
        for row in labels:
            # Locate the first <|Question|> ... <|Answer|> span and mask it so
            # that only the answer tokens contribute to the loss.
| start_idx = (row == self.question_token_id).nonzero(as_tuple=True)[0] |
| end_idx = (row == self.answer_token_id).nonzero(as_tuple=True)[0] |
|
|
| |
| if len(start_idx) > 0 and len(end_idx) > 0: |
| start_idx = start_idx[0].item() |
| end_idx = end_idx[0].item() |
|
|
| if start_idx <= end_idx: |
| row[start_idx : end_idx + 1] = -100 |
|
|
| if self.test: |
| batch["answers"] = ans_list |
| batch["labels"] = labels |
| batch["image"] = torch.stack(patch_list) |
| batch["patch_num"] = num_list |
| return batch |
|
|
| @dataclass |
| class MyDataCollatorForWPathVLM(DataCollatorMixin): |
| tokenizer: PreTrainedTokenizerBase |
    fea_root: Optional[str] = None
| agg_strategy: str = 'abmil' |
| n_heads: List[int] = field(default_factory=lambda: [32, 16, 8]) |
| fea_name_list: List[str] = field(default_factory=lambda: ['f1024', 'f2048', 'f4096']) |
| fea_dim: int = 512 |
| n_level: int = 3 |
| mlm: bool = False |
| mlm_probability: float = 0.15 |
| pad_to_multiple_of: Optional[int] = None |
| tf_experimental_compile: bool = False |
| return_tensors: str = "pt" |
| test: bool = False |
|
|
| def __post_init__(self): |
| if self.mlm and self.tokenizer.mask_token is None: |
| raise ValueError( |
| "This tokenizer does not have a mask token which is necessary for masked language modeling. " |
| "You should pass `mlm=False` to train on causal language modeling instead." |
| ) |
| self.question_token_id = self.tokenizer.convert_tokens_to_ids('<|Question|>') |
| self.answer_token_id = self.tokenizer.convert_tokens_to_ids('<|Answer|>') |
| self.pad_token_id = self.tokenizer.pad_token_id |
|
|
    def __get_nic__(self, features, coords, size):
        """
        Arrange patch features into a dense 2D grid ("NIC" layout): each patch
        is placed at its tile index derived from its (x, y) pixel coordinates,
        with a 0/1 mask marking which grid cells hold a real patch.
        """
        w = coords[:, 0]
        h = coords[:, 1]
        w_min, w_max = w.min(), w.max()
        h_min, h_max = h.min(), h.max()
        image_shape = [(w_max - w_min) // size + 1, (h_max - h_min) // size + 1]
        mask = np.ones((image_shape[0], image_shape[1]))
        features_nic = np.ones((features.shape[-1], image_shape[0], image_shape[1])) * np.nan
        coords_nic = -np.ones((image_shape[0], image_shape[1], 2))

        # Scatter each patch feature into its grid cell.
        for patch_feature, x, y in zip(features, w, h):
            x_nic, y_nic = (x - w_min) // size, (y - h_min) // size
            features_nic[:, x_nic, y_nic] = patch_feature
            coords_nic[x_nic, y_nic] = [x, y]

        # Cells never written stay NaN: clear their mask and zero them out.
        mask[np.isnan(features_nic)[0]] = 0
        features_nic[np.isnan(features_nic)] = 0
        return features_nic, mask
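
    # Example: two patches at pixel coordinates (0, 0) and (896, 0) with
    # size=448 produce a 3x1 grid whose middle cell was never written, so its
    # mask entry is 0 and its features are zeroed.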
| |
    def __feature_trans__(self, examples: List[Union[List[int], Any, Dict[str, Any]]], key: str, cor: str):
        """
        Zero-pad the per-example feature matrices (and their coordinates) to
        the largest patch count in the batch, returning tensors plus a 0/1
        patch mask marking the real (unpadded) rows.
        """
        fea_list = []
        cor_list = []
        patch_masks = []
        max_dim = max(d[key].shape[0] for d in examples)

        for d in examples:
            current_dim = d[key].shape[0]

            padded_data = np.zeros([max_dim, self.fea_dim])
            cor_data = np.zeros([max_dim, 2], dtype=int)
            patch_mask = np.zeros(max_dim, dtype=int)

            padded_data[:current_dim, :] = d[key]
            cor_data[:current_dim, :] = d[cor]
            patch_mask[:current_dim] = 1

            fea_list.append(torch.from_numpy(padded_data))
            cor_list.append(torch.from_numpy(cor_data))
            patch_masks.append(patch_mask)

        return fea_list, cor_list, patch_masks
|
|
    def __load_full_feature__(self, fea_path_ori: str):
        # Remap the stored path onto fea_root, then load a pickled dict with a
        # 'feature' matrix and an 'index' list of patch filenames whose first
        # two underscore-separated fields are the (x, y) coordinates.
        fea_path = '/'.join(fea_path_ori.split('/')[-2:])
        fea_path = os.path.join(self.fea_root, fea_path)
        fea = np.load(fea_path, allow_pickle=True)
        f = fea[()]['feature']
        cor = fea[()]['index']
        cor = np.array([filename.split('_')[:2] for filename in cor], dtype=int)
        return f, cor
|
|
    def __sample_feature__(self, f, cor, n_head: int):
        """
        Spatially stratified subsampling: cluster the patch coordinates into
        `n_head` KMeans groups and draw one random patch from each, so the
        sample covers the slide rather than a single region.
        """
        num_samples = cor.shape[0]
        if n_head >= num_samples:
            return f, cor

        kmeans = KMeans(n_clusters=n_head, random_state=42)
        labels = kmeans.fit_predict(cor)

        sampled_indices = []
        for i in range(n_head):
            group_indices = np.where(labels == i)[0]
            if len(group_indices) > 0:
                sampled_indices.append(np.random.choice(group_indices, 1)[0])

        return f[sampled_indices], cor[sampled_indices]
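
    # Example: with 1000 patch coordinates and n_head=32, the coordinates are
    # grouped into 32 spatial clusters and one patch is drawn at random from
    # each, so the subsample spans the whole slide.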
|
|
    def __load_clusters_feature__(self, fea_path_ori: str):
        # Pre-clustered features are stored per slide: the file name is derived
        # from the stored path by dropping everything after the first
        # underscore and re-appending '.npy'. The file holds a dict of
        # 'f1024'/'f2048'/'f4096' matrices; no coordinates are stored, so zero
        # placeholders are returned.
        fea_path = '/'.join(fea_path_ori.split('/')[-2:])
        fea_path = fea_path.split('_')[0]
        fea_path = fea_path.split('.')[0] + '.npy'
        fea_path = os.path.join(self.fea_root, fea_path)
        fea = np.load(fea_path, allow_pickle=True)
        f1 = fea[()]['f1024']
        f2 = fea[()]['f2048']
        f3 = fea[()]['f4096']
        cor1 = np.zeros((f1.shape[0], 2), dtype=int)
        cor2 = np.zeros((f2.shape[0], 2), dtype=int)
        cor3 = np.zeros((f3.shape[0], 2), dtype=int)
        return [f1, f2, f3], [cor1, cor2, cor3]
|
|
| def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: |
|
|
| fea_list = [] |
| cor_list = [] |
| patch_mask_list = [] |
| exa_list = [] |
| instructs = [] |
|
|
| if self.test: |
| ans_list = [] |
| qes_list = [] |
| slide_id_list = [] |
| for d in examples: |
| qes_list.append(d["question"]) |
| ans_list.append(d["answer"]) |
| slide_id_list.append(d["slide_id"]) |
| del d["question"],d["answer"],d["slide_id"] |
| |
| |
| for d in examples: |
| exa = {} |
|
|
            for i in range(len(self.fea_name_list)):
                fea_name = self.fea_name_list[i]
                fea_path_ori = d[fea_name]
                del d[fea_name]
                if self.agg_strategy in ['abmil', 'longnet', 'qformer']:
                    # These aggregators consume every patch feature.
                    f, cor = self.__load_full_feature__(fea_path_ori)
                elif self.agg_strategy == 'sample':
                    # Subsample down to n_heads[i] spatially spread patches.
                    f, cor = self.__load_full_feature__(fea_path_ori)
                    f, cor = self.__sample_feature__(f, cor, self.n_heads[i])
                else:
                    continue
                exa['f{}'.format(i)] = f
                exa['cor{}'.format(i)] = cor

            if self.agg_strategy in ['kmeans', 'gmm']:
                # Cluster-level features are stored per slide; the last
                # fea_path_ori from the loop above identifies the slide.
                f, cor = self.__load_clusters_feature__(fea_path_ori)
                for i in range(len(f)):
                    exa['f{}'.format(i)] = f[i]
                    exa['cor{}'.format(i)] = cor[i]
|
|
| exa_list.append(exa) |
|
|
| |
| for level in range(self.n_level): |
| fea, cor, patch_mask = self.__feature_trans__(exa_list, "f{}".format(level), "cor{}".format(level)) |
| fea_list.append(fea) |
| cor_list.append(cor) |
| patch_mask_list.append(patch_mask) |
|
|
| |
| if "input_ids_instruct" in examples[0].keys(): |
| for d in examples: |
| instruct = {} |
| instruct["input_ids"] = d["input_ids_instruct"] |
| instruct["attention_mask"] = d["attention_mask_instruct"] |
| instructs.append(instruct) |
| del d["input_ids_instruct"],d["attention_mask_instruct"] |
| |
| instruct_batch = pad_without_fast_tokenizer_warning( |
| self.tokenizer, instructs, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of |
| ) |
|
|
| if isinstance(examples[0], Mapping): |
| batch = pad_without_fast_tokenizer_warning( |
| self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of |
| ) |
| else: |
| batch = { |
| "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of) |
| } |
|
|
| |
| labels = batch["input_ids"].clone() |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| labels[labels == 128000] = -100 |
| labels[labels == self.pad_token_id] = -100 |
|
|
        for row in labels:
            # Locate the first <|Question|> ... <|Answer|> span and mask it so
            # that only the answer tokens contribute to the loss.
| start_idx = (row == self.question_token_id).nonzero(as_tuple=True)[0] |
| end_idx = (row == self.answer_token_id).nonzero(as_tuple=True)[0] |
|
|
| |
| if len(start_idx) > 0 and len(end_idx) > 0: |
| start_idx = start_idx[0].item() |
| end_idx = end_idx[0].item() |
|
|
                if start_idx <= end_idx:
                    row[start_idx : end_idx + 1] = -100

        batch["labels"] = labels
        if instructs:
            # Drop the first column (presumably the BOS token prepended at
            # tokenization) from the padded instruct sequences.
            batch["input_ids_instruct"] = instruct_batch["input_ids"][:, 1:]
            batch["attention_mask_instruct"] = instruct_batch["attention_mask"][:, 1:]
|
|
| if self.test: |
| batch["answers"] = ans_list |
| batch["questions"] = qes_list |
| batch["slide_ids"] = slide_id_list |
|
|
| for level in range(self.n_level): |
| batch["fea{}".format(level)] = torch.stack(fea_list[level]) |
| batch["mask{}".format(level)] = torch.from_numpy(np.array(patch_mask_list[level], dtype=int)) |
| batch["cor{}".format(level)] = torch.stack(cor_list[level]) |
|
|
| return batch |
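

if __name__ == "__main__":
    # Minimal smoke test (hypothetical setup): the tokenizer and image
    # processor are not exercised by the patching helpers, so None
    # placeholders suffice. A 1024x896 RGB image is already within
    # [448, 1024], so it is left unresized and cropped into 12 overlapping
    # 448x448 patches.
    collator = MyDataCollatorForQFormerPatchPretrain(image_processor=None, tokenizer=None)
    dummy = Image.new("RGB", (1024, 896))
    patches = collator._crop_image(collator._resize_image(dummy))
    print(len(patches), patches[0].size)  # -> 12 (448, 448)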