# Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import random

import torch
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms


def tokenize_prompt(
    tokenizer,
    prompt,
    text_encoder_architecture="open_clip",
    padding="max_length",
    max_length=77,
    max_length_t5=256,
):
    """Tokenize `prompt` for the given text-encoder architecture.

    Single-encoder architectures ("CLIP", "open_clip") take one tokenizer and
    return a single tensor of input ids. Dual-encoder architectures
    ("t5_clip", "gemma") take a pair of tokenizers (1st for CLIP, 2nd for
    T5/Gemma) and return a two-element list of id tensors.
    """
    if text_encoder_architecture in ("CLIP", "open_clip"):
        input_ids = tokenizer(
            prompt,
            truncation=True,
            padding=padding,
            max_length=max_length,
            return_tensors="pt",
        ).input_ids
        return input_ids
    elif text_encoder_architecture in ("t5_clip", "gemma"):
        # Two tokenizers: the 1st for CLIP, the 2nd for T5 or Gemma.
        input_ids = []
        input_ids.append(tokenizer[0](
            prompt,
            truncation=True,
            padding=padding,
            max_length=max_length,
            return_tensors="pt",
        ).input_ids)
        input_ids.append(tokenizer[1](
            prompt,
            truncation=True,
            padding=padding,
            max_length=max_length_t5,
            return_tensors="pt",
        ).input_ids)
        return input_ids
    else:
        raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")
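
# Usage sketch (the checkpoint name below is illustrative, not necessarily the
# one used in training):
#
#     from transformers import CLIPTokenizer
#
#     clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#     ids = tokenize_prompt(clip_tok, "a photo of a cat", "CLIP")
#     # ids: LongTensor of shape (1, 77) due to max_length padding
#
# For "t5_clip" / "gemma", pass a (clip_tokenizer, second_tokenizer) pair and
# a two-element list of id tensors is returned instead.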


def encode_prompt(
    text_encoder,
    input_ids,
    text_encoder_architecture="open_clip",
):
    """Run tokenized prompts through the text encoder(s).

    Returns (encoder_hidden_states, cond_embeds): per-token hidden states for
    cross-attention and a pooled embedding for global conditioning.
    """
    if text_encoder_architecture in ("CLIP", "open_clip"):
        outputs = text_encoder(input_ids=input_ids, return_dict=True, output_hidden_states=True)
        encoder_hidden_states = outputs.hidden_states[-2]  # penultimate layer
        cond_embeds = outputs[0]  # first element of the encoder output (pooled embedding)
        return encoder_hidden_states, cond_embeds
    elif text_encoder_architecture in ("t5_clip", "gemma"):
        # Pooled embedding from CLIP; per-token states from the 2nd encoder (T5 or Gemma).
        outputs_clip = text_encoder[0](
            input_ids=input_ids[0],
            return_dict=True,
            output_hidden_states=True,
        )
        outputs_second = text_encoder[1](
            input_ids=input_ids[1],
            return_dict=True,
            output_hidden_states=True,
        )
        encoder_hidden_states = outputs_second.last_hidden_state
        cond_embeds = outputs_clip.text_embeds
        return encoder_hidden_states, cond_embeds
    else:
        raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")
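
# A minimal end-to-end sketch, assuming `transformers`' CLIPTokenizer and
# CLIPTextModelWithProjection (for which `outputs[0]` is the projected text
# embedding); the checkpoint name is illustrative:
#
#     from transformers import CLIPTextModelWithProjection, CLIPTokenizer
#
#     tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#     enc = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
#     ids = tokenize_prompt(tok, "a photo of a cat", "CLIP")
#     hidden, pooled = encode_prompt(enc, ids, "CLIP")
#     # hidden: (1, 77, 768) penultimate-layer states; pooled: (1, 768)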


def process_image(image, size, Norm=False, hps_score=6.0):
    image = exif_transpose(image)
    if not image.mode == "RGB":
        image = image.convert("RGB")
    orig_height = image.height
    orig_width = image.width
    image = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)(image)
    c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size))
    image = transforms.functional.crop(image, c_top, c_left, size, size)
    image = transforms.ToTensor()(image)
    if Norm:
        # Map pixel values from [0, 1] to [-1, 1].
        image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
    # Micro-conditioning signals: original resolution, crop origin, and an
    # aesthetic (HPS) score.
    micro_conds = torch.tensor(
        [orig_width, orig_height, c_top, c_left, hps_score],
    )
    return {"image": image, "micro_conds": micro_conds}


class ImageCaptionLargeDataset(Dataset):
    """Image-caption pairs stored as sibling files: `name.jpg`/`name.png` + `name.txt`."""

    def __init__(
        self,
        root_dir,
        tokenizer,
        size,
        text_encoder_architecture="CLIP",
        norm=False,
    ):
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.size = size
        self.text_encoder_architecture = text_encoder_architecture
        self.norm = norm
        self.data_list = []
        for root, dirnames, filenames in os.walk(root_dir):
            for filename in filenames:
                if filename.endswith(".jpg") or filename.endswith(".png"):
                    base_name = os.path.splitext(filename)[0]
                    txt_file = os.path.join(root, base_name + ".txt")
                    if os.path.exists(txt_file):
                        self.data_list.append((root, base_name + ".txt", filename))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        sub_dir, txtfilename, imgfilename = self.data_list[idx]
        img_path = os.path.join(sub_dir, imgfilename)
        caption_path = os.path.join(sub_dir, txtfilename)
        try:
            image = Image.open(img_path).convert("RGB")
            ret = process_image(image, self.size, self.norm)
            with open(caption_path, "r", encoding="utf-8") as f:
                caption = f.read().strip()
            ret["prompt_input_ids"] = tokenize_prompt(
                self.tokenizer, caption, self.text_encoder_architecture
            )
            return ret
        except Exception as e:
            print("===========================================")
            print(f"[Warning] Error at index {idx}: {img_path} ({e})")
            print("===========================================")
            # Skip the broken sample; wrap around instead of recursing on the
            # same (possibly broken) last index forever.
            return self.__getitem__((idx + 1) % len(self.data_list))
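
# Loading sketch (the folder path and tokenizer are hypothetical; default
# collation works because every field is a fixed-size tensor under
# max_length padding):
#
#     from torch.utils.data import DataLoader
#
#     ds = ImageCaptionLargeDataset("/data/captions", clip_tok, size=256)
#     dl = DataLoader(ds, batch_size=8, shuffle=True, num_workers=4)
#     batch = next(iter(dl))
#     # batch["image"]: (8, 3, 256, 256); batch["prompt_input_ids"]: (8, 1, 77)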


class MultiSourceVLDataset(Dataset):
    """
    A unified dataloader for
      • LLaVA-Instruct-150K
      • MMMU (multiple-choice QA)
      • VQAv2
      • Local caption files under `pdd3/`
    """

    def __init__(
        self,
        tokenizer,
        size: int,
        text_encoder_architecture: str = "CLIP",
        norm: bool = False,
        # ----- paths -----
        llava_json: str = None, llava_img_root: str = None,
        mmmu_json: str = None, mmmu_img_root: str = None,
        vqa_ann_json: str = None, vqa_img_root: str = None,
        gqa_json: str = None, gqa_img_root: str = None,
        coco_json: str = None, coco_img_root: str = None,
        coco_qa_json: str = None,
        mg_llava_json: str = None, mg_llava_root: str = None,
        pdd3_dir: str = None, caption_dir: str = None,
    ):
        self.tokenizer = tokenizer
        self.size = size
        self.arch = text_encoder_architecture
        self.norm = norm
        self.gen_samples = []  # [(img_path, prompt), ...]
        self.mmu_samples = []  # [(img_path, question, answer), ...]
        if llava_json:
            self._load_llava(llava_json, llava_img_root)
        if mmmu_json:
            self._load_mmmu(mmmu_json, mmmu_img_root)
        if vqa_ann_json:
            self._load_vqav2(vqa_ann_json, vqa_img_root)
        if coco_json:
            self._load_coco2014_captions(coco_json, coco_img_root)
        if coco_qa_json:
            self._load_coco2014_qa(coco_qa_json, coco_img_root)
        if gqa_json:
            self._load_gqa(gqa_json, gqa_img_root)
        if mg_llava_json:
            self._load_mg_llava(mg_llava_json, mg_llava_root)
        if caption_dir:
            self._load_caption(caption_dir)
        if pdd3_dir:
            self._load_pdd3(pdd3_dir)
        self.len_mmu = len(self.mmu_samples)
        self.len_gen = len(self.gen_samples)

    # ------------------------------------------------------------------ #
    #                           dataset parsers                           #
    # ------------------------------------------------------------------ #
    def _load_llava(self, json_path, img_root):
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for ex in data:
            img_file = os.path.join(img_root, ex["image"])
            human_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "human")
            gpt_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "gpt")
            self.mmu_samples.append((img_file, human_msg.strip(), gpt_msg.strip()))
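
    # Expected LLaVA-Instruct-style record (a sketch of the public format;
    # extra keys are ignored):
    #
    #     {
    #       "image": "000000033471.jpg",
    #       "conversations": [
    #         {"from": "human", "value": "<image>\nWhat is unusual about this image?"},
    #         {"from": "gpt", "value": "The man is ironing on the back of a taxi."}
    #       ]
    #     }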

    def _load_mmmu(self, json_path, img_root):
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for ex in data:
            img_file = os.path.join(img_root, ex["image"])
            # Render the options as "A. ...", "B. ...", etc.
            choices = "\n".join([f"{chr(65 + i)}. {c}" for i, c in enumerate(ex["choices"])])
            question = f"{ex['question'].strip()}\n{choices}"
            answer = f"{ex['answer']}"
            self.mmu_samples.append((img_file, question, answer))
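
    # For example, a record with question "What animal is shown?" and choices
    # ["cat", "dog"] yields the prompt:
    #
    #     What animal is shown?
    #     A. cat
    #     B. dog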

    def _load_coco2014_qa(self, ann_jsonl, img_root):
        with open(ann_jsonl, "r", encoding="utf-8") as file:
            data = [json.loads(line) for line in file if line.strip()]
        for ann in data:
            image = ann["image"]
            question = ann["question"]
            answer = ann["label"]
            image_path = os.path.join(img_root, image)
            self.mmu_samples.append((image_path, question, answer))

    def _load_coco2014_captions(self, ann_json, img_root):
        """
        Load COCO 2014 image-caption pairs from a caption annotation file.

        Args:
            ann_json (str): Path to a COCO-style captions JSON (e.g., captions_train2014.json).
            img_root (str): Directory containing COCO images (should include
                'train2014/' and 'val2014/' subdirs).
        """
        with open(ann_json, "r") as f:
            data = json.load(f)
        is_train = "train" in os.path.basename(ann_json).lower()
        img_subdir = "train2014" if is_train else "val2014"
        prefix = "COCO_train2014_" if is_train else "COCO_val2014_"
        for ann in data["annotations"]:
            image_id = ann["image_id"]
            caption = ann["caption"]
            image_filename = f"{prefix}{image_id:012d}.jpg"
            image_path = os.path.join(img_root, img_subdir, image_filename)
            question = "Please describe this image concisely."
            self.mmu_samples.append((image_path, question, caption))
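
    # Worked example of the filename convention: image_id 391895 in a train
    # annotation file maps to "train2014/COCO_train2014_000000391895.jpg".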

    def _load_vqav2(self, ann_json, img_root):
        with open(ann_json, "r") as file:
            annos = json.load(file)
        for ann in annos:
            q = ann["question"]
            answer = ann["answer"]
            img_path = ann["image"]
            img_file = os.path.join(
                img_root,
                img_path,  # if val, modify to val2014
            )
            self.mmu_samples.append((img_file, q, answer))

    def _load_gqa(self, ann_json_root, img_root):
        annos = {}
        for jsonfile in os.listdir(ann_json_root):
            # Skip stray non-JSON files in the annotation directory.
            if not jsonfile.endswith(".json"):
                continue
            jsonpath = os.path.join(ann_json_root, jsonfile)
            with open(jsonpath, "r") as file:
                anno = json.load(file)
            annos.update(anno)
        for ann in annos.values():
            q = ann["question"]
            answer = ann["fullAnswer"]
            img_name = ann["imageId"] + ".jpg"
            img_path = os.path.join(img_root, img_name)
            self.mmu_samples.append((img_path, q, answer))

    def _load_mg_llava(self, json_path, img_root):
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for ex in data:
            image = ex.get("image", None)
            if image is not None:
                img_file = os.path.join(img_root, image)
                if os.path.exists(img_file):
                    human_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "human")
                    gpt_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "gpt")
                    self.mmu_samples.append((img_file, human_msg.strip(), gpt_msg.strip()))

    def _load_caption(self, root_dir):
        for root, _, files in os.walk(root_dir):
            for f in files:
                if f.lower().endswith((".jpg", ".png")):
                    base = os.path.splitext(f)[0]
                    txt_path = os.path.join(root, base + ".txt")
                    if os.path.exists(txt_path):
                        with open(txt_path, "r", encoding="utf-8") as file:
                            caption = file.read().strip()
                        q = "Please describe this image."
                        self.mmu_samples.append((os.path.join(root, f), q, caption))

    def _load_pdd3(self, root_dir):
        for root, _, files in os.walk(root_dir):
            for f in files:
                if f.lower().endswith((".jpg", ".png")):
                    base = os.path.splitext(f)[0]
                    txt_path = os.path.join(root, base + ".txt")
                    if os.path.exists(txt_path):
                        with open(txt_path, "r", encoding="utf-8") as file:
                            caption = file.read().strip()
                        self.gen_samples.append((os.path.join(root, f), caption))

    # ------------------------------------------------------------------ #
    #                         PyTorch Dataset API                         #
    # ------------------------------------------------------------------ #
    def __len__(self):
        return max(self.len_gen, self.len_mmu)

    def __getitem__(self, idx):
        # Each item pairs one understanding (mmu) sample with one generation
        # (gen) sample; when `idx` is out of range for either list, resample.
        get_mmu_data = False
        get_gen_data = False
        while not get_mmu_data:
            try:
                mmu_img_path, question, answer = self.mmu_samples[idx]
                get_mmu_data = True
            except IndexError:
                idx = random.randint(0, self.len_mmu - 1)
        while not get_gen_data:
            try:
                gen_img_path, prompt = self.gen_samples[idx]
                get_gen_data = True
            except IndexError:
                idx = random.randint(0, self.len_gen - 1)
        try:
            # ---- image ----
            mmu_image = Image.open(mmu_img_path).convert("RGB")
            mmu_ret = process_image(mmu_image, self.size, self.norm)
            gen_image = Image.open(gen_img_path).convert("RGB")
            gen_ret = process_image(gen_image, self.size, self.norm)
            ret = dict(
                gen_image=gen_ret["image"],
                gen_micro_conds=gen_ret["micro_conds"],
                mmu_image=mmu_ret["image"],
                mmu_micro_conds=mmu_ret["micro_conds"],
            )
            # ---- text ----
            # Drop the <image> placeholder and collapse whitespace/newlines.
            question = question.replace("<image>", "")
            question = " ".join(question.split())
            question_ids = tokenize_prompt(
                self.tokenizer,
                question,
                self.arch,
                padding=False,
            )
            if isinstance(question_ids, list):
                # Dual-tokenizer architectures return one tensor per tokenizer;
                # measure the question length with the first (CLIP) tokenizer.
                question_ids = question_ids[0]
            question_ids = question_ids[:, :-1]  # drop the trailing EOS token
            q_len = len(question_ids[0])
            if answer:
                full_prompt = question + " " + answer
            else:
                full_prompt = question
            mmu_input_ids = tokenize_prompt(self.tokenizer, full_prompt, self.arch)
            gen_input_ids = tokenize_prompt(self.tokenizer, prompt, self.arch)
            ret.update({
                "gen_input_ids": gen_input_ids,
                "mmu_input_ids": mmu_input_ids,
                "question_len": torch.LongTensor([q_len]),
            })
            return ret
        except Exception as e:
            print("================================================================")
            print(f"There is something wrong with {mmu_img_path} or {gen_img_path}: {e}")
            print("================================================================")
            if idx < self.len_gen - 1 or idx < self.len_mmu - 1:
                return self.__getitem__(idx + 1)
            else:
                idx = random.randint(0, self.len_gen - 1)
                return self.__getitem__(idx)
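

# Construction sketch (paths and tokenizer are hypothetical). With a
# single-tokenizer architecture every id field is padded to a fixed
# max_length, so the default DataLoader collation stacks cleanly;
# dual-tokenizer architectures return lists of tensors and may need a
# custom collate_fn:
#
#     ds = MultiSourceVLDataset(
#         tokenizer=clip_tok,
#         size=256,
#         llava_json="/data/llava_instruct_150k.json",
#         llava_img_root="/data/coco/train2017",
#         pdd3_dir="/data/pdd3",
#     )
#     item = ds[0]
#     # item keys: gen_image, gen_micro_conds, mmu_image, mmu_micro_conds,
#     #            gen_input_ids, mmu_input_ids, question_len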