Spaces:
Running
Running
| import random | |
| from typing import List, Tuple | |
| from itertools import islice | |
| import datasets | |
| from datasets import load_dataset, concatenate_datasets | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| import os | |
| from torchvision.transforms import RandAugment | |
def get_randaugment_transform(n=2, m=9):
    """Build a torchvision RandAugment transform.

    Args:
        n: number of random augmentation operations applied per call
            (forwarded as ``num_ops``).
        m: strength of those operations (forwarded as ``magnitude``).

    Returns:
        A ``RandAugment`` instance configured with ``n`` and ``m``.
    """
    settings = {"num_ops": n, "magnitude": m}
    return RandAugment(**settings)
def add_prompt_template(data):
    """Prefix the query, positive and hard-negative texts with the Phi image token.

    The row dict is modified in place (each value is stringified first) and
    returned, so it can be passed directly to ``Dataset.map``.
    """
    for field in ("qry", "pos_text", "hard_neg_text"):
        data[field] = "<|image_1|>" + f"{data[field]}"
    return data
# Image placeholder tokens used in the text prompts of each backbone.
# Dataset text is written with the Phi token; the dataset classes in this
# module swap it for the backbone-specific token with str.replace.
Phi_Image_token = "<|image_1|>"
Qwen_Image_token = "<|image_pad|>"
# LLaVA-NeXT and InternVL use the same placeholder string.
Llava_Image_token = Internvl_Image_token = "<image>"
class TrainDataset(Dataset):
    """Training set of (query, positive[, hard negative]) text/image examples.

    Rows come from a Hugging Face dataset, or from several subsets that are
    concatenated and shuffled. Text fields are stored with the Phi-style image
    token and rewritten here to the token of ``model_args.model_backbone``.
    """

    # Square resize presets understood by ``_process_image``.
    _RESIZE_TARGETS = {"high": (1344, 1344), "low": (336, 336), "clip": (224, 224)}
    # Which preset each backbone uses; unlisted backbones get the raw image.
    # TODO: make it configurable
    _BACKBONE_RESOLUTION = {"llava_next": "high", "qwen": "low", "internvl_2_5": "high"}

    def __init__(self, data_args, model_args):
        """Load (and optionally subsample) the training split.

        Args:
            data_args: reads ``randaugment``, ``subset_name``, ``dataset_name``,
                ``dataset_split``, ``num_samples`` and (later) ``image_dir``.
            model_args: reads ``model_backbone``.
        """
        self.data_args = data_args
        self.model_args = model_args
        # Optional RandAugment; applied to query images only (see __getitem__).
        self.transform = None
        if self.data_args.randaugment:
            self.transform = get_randaugment_transform()

        if data_args.subset_name is not None:
            print(f"Loading {len(data_args.subset_name)} datasets: {data_args.subset_name}")
            train_data = [
                load_dataset(
                    os.path.join(self.data_args.dataset_name, subset),
                    split=f"{self.data_args.dataset_split}",
                )
                for subset in data_args.subset_name
            ]
            self.train_data = concatenate_datasets(train_data)
            self.train_data = self.train_data.shuffle(seed=42)
        else:
            train_data = load_dataset(
                self.data_args.dataset_name,
                split=f"{self.data_args.dataset_split}",
            )
            if "hard_neg" in self.data_args.dataset_name:
                # Hard-negative rows are used as-is; __getitem__ additionally
                # reads hard_neg_text / hard_neg_image_path from them.
                print(train_data)
            # Assigned on both paths: the hard_neg path previously left
            # self.train_data unset and crashed below.
            self.train_data = train_data

        if self.data_args.num_samples:
            self.train_data = self.train_data.select(range(self.data_args.num_samples))
        print(f"len of train_data: {len(self.train_data)}")

    def __len__(self):
        return len(self.train_data)

    def _process_image(self, image, resolution):
        """Resize *image* to a square preset ("high" 1344, "low" 336, "clip" 224).

        Unknown resolutions leave the image unchanged; ``None`` passes through.
        """
        if image is None:
            return None
        target = self._RESIZE_TARGETS.get(resolution)
        if target is not None:
            image = image.resize(target)
        return image

    def _get_image(self, img_path):
        """Open the image at *img_path* and resize it for the current backbone.

        Paths starting with "/" are used as-is; others are joined onto
        ``data_args.image_dir``. An empty or missing (``None``) path yields ``None``.
        """
        if not img_path:
            return None
        if img_path.startswith('/'):
            full_img_path = img_path
        else:
            full_img_path = os.path.join(self.data_args.image_dir, img_path)
        image = Image.open(full_img_path)
        resolution = self._BACKBONE_RESOLUTION.get(self.model_args.model_backbone)
        if resolution is None:
            return image
        return self._process_image(image, resolution)

    def _to_backbone_image_token(self, text):
        """Replace the Phi image token in *text* with the backbone's own token.

        Backbones without a mapping here keep the Phi token unchanged.
        """
        backbone = self.model_args.model_backbone
        if backbone == "llava_next":
            return text.replace(Phi_Image_token, Llava_Image_token)
        if backbone == "qwen":
            return text.replace(Phi_Image_token, Qwen_Image_token)
        if backbone == "internvl_2_5":
            return text.replace(Phi_Image_token, Internvl_Image_token)
        return text

    def __getitem__(self, item) -> tuple:
        """Return ``(qry_text, qry_image, pos_text, pos_image)``.

        When ``data_args.dataset_name`` contains ``"hard_neg"`` the tuple is
        extended with ``(hard_neg_text, hard_neg_image)``. Image entries are
        ``None`` when the row's path is empty.
        """
        data_item = self.train_data[item]
        qry_text = self._to_backbone_image_token(data_item["qry"])
        pos_text = self._to_backbone_image_token(data_item["pos_text"])

        qry_image = self._get_image(data_item["qry_image_path"])
        # Text-only queries have no image; RandAugment cannot take None.
        if self.transform is not None and qry_image is not None:
            qry_image = self.transform(qry_image)
        pos_image = self._get_image(data_item["pos_image_path"])

        if "hard_neg" not in self.data_args.dataset_name:
            return qry_text, qry_image, pos_text, pos_image

        # Same token rewrite as qry/pos (qwen was previously skipped here).
        hard_neg_text = self._to_backbone_image_token(data_item["hard_neg_text"])
        hard_neg_image = self._get_image(data_item["hard_neg_image_path"])
        return qry_text, qry_image, pos_text, pos_image, hard_neg_text, hard_neg_image
class EvalDataset(Dataset):
    """Evaluation set of unique (text, image path) pairs from one side of a split.

    Duplicate pairs across rows are collapsed, keeping first-seen order, so the
    item order is deterministic across processes and runs.
    """

    def __init__(self, data_args, model_args, subset, text_field, img_path_field):
        """Load the eval split and build the de-duplicated pair dataset.

        Args:
            data_args: reads ``dataset_name``, ``subset_name``, ``dataset_split``
                and (later) ``image_dir``.
            model_args: reads ``model_backbone``.
            subset: dataset config name, used only when ``data_args.subset_name``
                is not None.
            text_field, img_path_field: the column pair to read, e.g.
                ("qry_text", "qry_img_path") or ("tgt_text", "tgt_img_path").
        """
        self.data_args = data_args
        self.model_args = model_args
        if data_args.subset_name is not None:
            self.eval_data = load_dataset(
                self.data_args.dataset_name,
                subset,
                split=self.data_args.dataset_split,
            )
        else:
            self.eval_data = load_dataset(
                self.data_args.dataset_name,
                split=self.data_args.dataset_split,
            )
        print(f"len of eval_data: {len(self.eval_data)}")
        self.paired_data = self.get_paired_data(text_field, img_path_field)
        self.paired_dataset = datasets.Dataset.from_dict({
            "text": [pair["text"] for pair in self.paired_data],
            "img_path": [pair["img_path"] for pair in self.paired_data],
        })

    def __len__(self):
        return len(self.paired_dataset)

    def __getitem__(self, item):
        """Return ``(text, image)`` for the item-th pair; image is ``None`` for an empty path."""
        pair = self.paired_dataset[item]
        text = self._to_backbone_image_token(pair["text"])
        return text, self._get_image(pair["img_path"])

    def _to_backbone_image_token(self, text):
        """Replace the Phi image token in *text* with the backbone's own token.

        Backbones without a mapping here keep the Phi token unchanged.
        """
        backbone = self.model_args.model_backbone
        if backbone == "llava_next":
            return text.replace(Phi_Image_token, Llava_Image_token)
        if backbone == "qwen":
            return text.replace(Phi_Image_token, Qwen_Image_token)
        if backbone == "internvl_2_5":
            return text.replace(Phi_Image_token, Internvl_Image_token)
        return text

    def _process_image(self, image, resolution):
        """Resize to 1344x1344 for "high", otherwise 336x336; ``None`` passes through."""
        if image is None:
            return None
        if resolution == "high":
            return image.resize((1344, 1344))
        return image.resize((336, 336))

    def _get_image(self, img_path):
        """Open the image at *img_path* ("/"-rooted, or relative to ``data_args.image_dir``).

        An empty or missing (``None``) path yields ``None``.
        """
        if not img_path:
            return None
        if img_path.startswith("/"):
            full_img_path = img_path
        else:
            full_img_path = os.path.join(self.data_args.image_dir, img_path)
        image = Image.open(full_img_path)
        # NOTE(review): TrainDataset resizes qwen images to 336x336, but they are
        # passed through unchanged here — confirm the train/eval mismatch is intended.
        if self.model_args.model_backbone in ("llava_next", "internvl_2_5"):
            return self._process_image(image, "high")
        return image

    def get_paired_data(self, text_field, img_path_field):
        """Collect unique (text, image path) pairs from ``self.eval_data``.

        For each row:
        * non-empty string text: one pair with the row's image path;
        * empty string text: one pair per path if the path field is a list,
          otherwise one pair with that path;
        * list text: zipped element-wise with an equal-length path list.
        Rows whose text is neither a string nor a list are skipped.

        Returns:
            A list of ``{"text": ..., "img_path": ...}`` dicts in first-seen
            order. A dict is used for de-duplication instead of a set, because
            set order of strings changes with hash randomization between processes.

        Raises:
            ValueError: if list-valued text is not paired with a list of image
                paths of the same length.
        """
        unique_pairs = {}
        for row in self.eval_data:
            text, img_path = row[text_field], row[img_path_field]
            if isinstance(text, str):
                if not text and isinstance(img_path, list):
                    for path in img_path:
                        unique_pairs[(text, path)] = None
                else:
                    unique_pairs[(text, img_path)] = None
            elif isinstance(text, list):
                if not isinstance(img_path, list) or len(img_path) != len(text):
                    raise ValueError(
                        f"{text_field!r} list must pair with an equal-length "
                        f"{img_path_field!r} list"
                    )
                for pair_text, pair_path in zip(text, img_path):
                    unique_pairs[(pair_text, pair_path)] = None
        return [{"text": text, "img_path": img_path} for text, img_path in unique_pairs]
class FlickrDataset(Dataset):
    """Flickr 1k test split ("nlphuji/flickr_1k_test_image_text_retrieval") for retrieval eval.

    With ``modality == "image"`` each item is an instruction prompt plus the
    row's image. Otherwise each item is one caption with no image.
    ``image_names`` holds the source filename of every item, in item order.
    """

    def __init__(self, modality, model_backbone):
        self.model_backbone = model_backbone
        self.modality = modality
        self.raw_data = load_dataset("nlphuji/flickr_1k_test_image_text_retrieval", split="test")
        if modality == "image":
            self.eval_data, self.image_names = self.get_image_data()
        else:
            self.eval_data, self.image_names = self.get_text_data()

    def __len__(self):
        return len(self.eval_data)

    def __getitem__(self, idx):
        """Return ``(text, image)`` with the backbone's image token; image may be ``None``."""
        text, image = self.eval_data[idx]
        text = self._to_backbone_image_token(text)
        # Same backbones get the "high" resize as in EvalDataset._get_image.
        if self.model_backbone in ("llava_next", "internvl_2_5"):
            image = self._process_image(image, "high")
        return text, image

    def _to_backbone_image_token(self, text):
        """Replace the Phi image token in *text* with the backbone's own token.

        Backbones without a mapping here keep the Phi token unchanged.
        """
        if self.model_backbone == "llava_next":
            return text.replace(Phi_Image_token, Llava_Image_token)
        if self.model_backbone == "qwen":
            return text.replace(Phi_Image_token, Qwen_Image_token)
        if self.model_backbone == "internvl_2_5":
            return text.replace(Phi_Image_token, Internvl_Image_token)
        return text

    def _process_image(self, image, resolution):
        """Resize to 1344x1344 for "high", otherwise 336x336; ``None`` passes through."""
        if image is None:
            return None
        if resolution == "high":
            return image.resize((1344, 1344))
        return image.resize((336, 336))

    def _get_image(self, img_path):
        """Open an image relative to ``data_args.image_dir``; empty path yields ``None``.

        NOTE(review): this class never sets ``self.data_args``; its images come
        straight from ``raw_data``. A non-empty path therefore raises
        AttributeError here. The method appears unused — confirm before relying on it.
        """
        if img_path == "":
            return None
        full_img_path = os.path.join(self.data_args.image_dir, img_path)
        image = Image.open(full_img_path)
        if self.model_backbone == "llava_next":
            return self._process_image(image, "high")
        return image

    def get_image_data(self):
        """Build one (instruction, image) item per row, plus the row filenames."""
        # i2t instruction. Alternatives tried (llava-1344-step1k4 scores):
        #   "<|image_1|> Find an image caption describing the given image."      i2t=94.0, t2i=80.26
        #   "<|image_1|> Represent the given image for image caption retrieval."  i2t=94.6, t2i=78.98
        # For MSCOCO t2i: "<|image_1|> Represent the given image."
        inst = "<|image_1|> Find an image caption describing the given image."
        eval_data, image_names = [], []
        for row in self.raw_data:
            eval_data.append((inst, row["image"]))
            image_names.append(row["filename"])
        return eval_data, image_names

    def get_text_data(self):
        """Build one (caption, None) item per caption, repeating the row filename."""
        # i2t uses the bare caption. t2i alternatives that were tried:
        #   "Retrieve an image that matches the given caption: "
        #   "Find me an everyday image that matches the given caption."  (MSCOCO t2i)
        inst = ""
        eval_data, image_names = [], []
        for row in self.raw_data:
            for caption in row["caption"]:
                eval_data.append((inst + caption, None))
                image_names.append(row["filename"])
        return eval_data, image_names