| # coding=utf-8 | |
| # Copyright 2025 MMaDA Team | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import bisect | |
| import csv | |
| import logging | |
| import itertools | |
| import json | |
| import math | |
| import os | |
| import hashlib | |
| import contextlib | |
| from pathlib import Path | |
| from accelerate import Accelerator | |
| from itertools import chain | |
| # Video handling / download helpers | |
| import os.path as osp | |
| import time | |
| import requests | |
| import random | |
| import re | |
| import datasets | |
| import pandas as pd | |
| from functools import partial | |
| from typing import List, Optional, Union, Dict, Any, Sequence | |
| from glob import glob | |
| from tqdm import tqdm | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import torch | |
| from dataclasses import dataclass | |
| from datasets import Dataset as HFDataset | |
| from datasets import load_dataset, get_dataset_config_names | |
| from io import BytesIO | |
| Image.warnings.simplefilter('error', Image.DecompressionBombWarning) | |
| import webdataset as wds | |
| import yaml | |
| from braceexpand import braceexpand | |
| from torch.utils.data import default_collate, Dataset | |
| from torchvision import transforms | |
| from transformers import PreTrainedTokenizer | |
| from datasets import ( | |
| load_dataset, | |
| load_from_disk, | |
| DatasetDict, | |
| DownloadConfig, | |
| get_dataset_config_names, | |
| concatenate_datasets, | |
| ) | |
| import warnings | |
| from training.utils import image_transform as utils_image_transform, image_transform_squash as utils_image_transform_squash | |
| from webdataset.tariterators import ( | |
| base_plus_ext, | |
| tar_file_expander, | |
| url_opener, | |
| valid_sample, | |
| ) | |
| logger = logging.getLogger(__name__) | |
| S2T_INSTRUCTION = ["Transcribe the given audio.", | |
| "Write down what you hear in the audio.", | |
| "Provide a transcript for the given speech.", | |
| "What does the speaker in the audio say?", | |
| "Convert the speech in the audio to text.", | |
| "Listen to the audio and write out the text."] | |
| T2S_INSTRUCTION = ["Generate speech for the given text.", | |
| "Read the given sentence aloud.", | |
| "Say the given words.", | |
| "Convert the given text into spoken audio.", | |
| "Speak the given text.", | |
| "Synthesize the text into speech."] | |
| V2T_INSTRUCTION = ["Describe the video in detail.", | |
| "Please provide a detailed description of the video.", | |
| "What is happening in the video?", | |
| "Describe the content of the video in detail.",] | |
| V2S_INSTRUCTION = [ | |
| "Generate speech that describes the given video.", | |
| "Narrate the events happening in the video.", | |
| "Produce spoken audio describing the video content.", | |
| "Convert the video into a detailed spoken narration.", | |
| "Speak a description of what is shown in the video.", | |
| "Synthesize speech that explains the content of the video.", | |
| ] | |
| person_token = ["a person", "someone", "somebody"] | |
| def replace_person_token(t): | |
| "Used for CC12M - handles all case variations of <person> tag" | |
| t = re.sub(r"<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t, flags=re.IGNORECASE) | |
| person_pattern = re.compile(r"<person>", re.IGNORECASE) | |
| while person_pattern.search(t): | |
| match = person_pattern.search(t) | |
| t = t[:match.start()] + f" {random.choice(person_token)} " + t[match.end():] | |
| return t | |
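| # Illustrative behavior (the replacement is chosen at random from person_token): | |
| #   replace_person_token("<person> riding a bike")  # -> roughly " someone  riding a bike" | |
| #   replace_person_token("<person> and <person> at the beach")  # -> roughly " people  at the beach" | |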
| def filter_keys(key_set): | |
| def _f(dictionary): | |
| return {k: v for k, v in dictionary.items() if k in key_set} | |
| return _f | |
| def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None, src=None): | |
| """Return function over iterator that groups key, value pairs into samples. | |
| :param keys: function that splits the key into key and extension (base_plus_ext) | |
| :param lcase: convert suffixes to lower case (Default value = True) | |
| """ | |
| current_sample = None | |
| for filesample in data: | |
| assert isinstance(filesample, dict) | |
| if "fname" not in filesample.keys(): | |
| print(f"fname not in filesample.keys(): {filesample}") | |
| print(f"src: {src}") | |
| continue | |
| fname, value = filesample["fname"], filesample["data"] | |
| prefix, suffix = keys(fname) | |
| if prefix is None: | |
| continue | |
| if lcase: | |
| suffix = suffix.lower() | |
| if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: | |
| if valid_sample(current_sample): | |
| yield current_sample | |
| current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) | |
| if suffixes is None or suffix in suffixes: | |
| current_sample[suffix] = value | |
| if valid_sample(current_sample): | |
| yield current_sample | |
| def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): | |
| # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw | |
| streams = url_opener(src, handler=handler) | |
| files = tar_file_expander(streams, handler=handler)  # [{fname, data, __url__}, ...]; __url__ records which tar shard each file came from | |
| samples = group_by_keys_nothrow(files, handler=handler, src=src) | |
| return samples | |
| def image_transform(sample, resolution=256): | |
| image = sample["images"] | |
| image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image) | |
| image = transforms.CenterCrop((resolution, resolution))(image) | |
| image = transforms.ToTensor()(image) | |
| image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image) | |
| sample["images"] = image | |
| return sample | |
| def image_transform_squash(sample, resolution=256): | |
| image = sample["images"] | |
| image = transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC)(image) | |
| image = transforms.ToTensor()(image) | |
| image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(image) | |
| sample["images"] = image | |
| return sample | |
| def conditional_image_transform(sample, resolution=256): | |
| url = sample.get("__url__", "") | |
| special_datasets = ['ai2d', 'clevr', 'docvqa', 'geo'] | |
| use_squash = False | |
| for keyword in special_datasets: | |
| if keyword in url: | |
| use_squash = True | |
| break | |
| if use_squash: | |
| return image_transform_squash(sample, resolution) | |
| else: | |
| return image_transform(sample, resolution) | |
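| # Diagram/document-style datasets (ai2d, clevr, docvqa, geo) are squashed to the target | |
| # resolution so no content is cropped away; natural images are resized and center-cropped. | |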
| def remove_prefix(caption): | |
| caption = caption.replace('The image features ', '').replace('The image presents ', '').replace( | |
| "The image you've sent is, ", '').replace("In the center of the image, ", '').replace( | |
| "The image showcases ", '').replace("The image is ", '').replace( | |
| "The image captures ", '').replace("In the given image ", '').replace( | |
| "The image portrays ", '').replace("In the image, ", '').replace("In this image, we see ", '').replace( | |
| "The image depicts ", '').replace("This is ", '').replace("In this image, ", '').replace( | |
| "This image captures ", '') | |
| return caption | |
| def filter_long_samples(sample): | |
| return sample.get('input_ids') is not None | |
| class Text2ImageDataset: | |
| def __init__( | |
| self, | |
| train_shards_path_or_url: Union[str, List[str]], | |
| tokenizer: PreTrainedTokenizer, | |
| max_seq_length: int, | |
| num_train_examples: int, | |
| per_gpu_batch_size: int, | |
| global_batch_size: int, | |
| num_workers: int, | |
| resolution: int = 256, | |
| shuffle_buffer_size: int = 1000, | |
| pin_memory: bool = False, | |
| persistent_workers: bool = False, | |
| external_caption_path: Optional[str] = '', | |
| external_journeydb_caption_path: Optional[str] = '', | |
| external_laion12m_caption_path: Optional[str] = '', | |
| external_cc12m_caption_path: Optional[str] = '', | |
| external_text_to_image_2M_512_caption_path: Optional[str] = '', | |
| external_ai2d_caption_path: Optional[str] = '', | |
| external_clevr_caption_path: Optional[str] = '', | |
| external_docvqa_caption_path: Optional[str] = '', | |
| external_geo_caption_path: Optional[str] = '', | |
| is_captioning: bool = False, | |
| add_caption_prompt: bool = False, | |
| long_caption: bool = True, | |
| shuffle: bool = True, | |
| ): | |
| if f"{train_shards_path_or_url}.yaml" in os.listdir('./configs'): | |
| with open(f"./configs/{train_shards_path_or_url}.yaml") as f: | |
| train_shards_path_or_url = yaml.safe_load(f) | |
| self.long_caption = long_caption | |
| self.external_caption_path = external_caption_path | |
| self.external_journeydb_caption_path = external_journeydb_caption_path | |
| self.external_laion12m_caption_path = external_laion12m_caption_path | |
| self.external_cc12m_caption_path = external_cc12m_caption_path | |
| self.external_text_to_image_2M_512_caption_path = external_text_to_image_2M_512_caption_path | |
| self.is_captioning = is_captioning | |
| self.add_caption_prompt = add_caption_prompt | |
| if self.add_caption_prompt: | |
| with open("./training/questions.json") as f: | |
| self.caption_prompt = json.load(f) | |
| # self.caption_prompt = ['USER: \n' + prompt + ' ASSISTANT:' for prompt in self.caption_prompt] | |
| self.caption_prompt = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<eot_id><|start_header_id|>assistant<|end_header_id|>\n' for prompt in self.caption_prompt] | |
| else: | |
| self.caption_prompt = None | |
| if external_journeydb_caption_path != '': | |
| with open(external_journeydb_caption_path) as file: | |
| self.journeydb_caption = json.load(file) | |
| else: | |
| self.journeydb_caption = None | |
| if external_ai2d_caption_path != '': | |
| self.ai2d_caption = pd.read_csv(external_ai2d_caption_path) | |
| if external_clevr_caption_path != '': | |
| self.clevr_caption = pd.read_csv(external_clevr_caption_path) | |
| if external_docvqa_caption_path != '': | |
| self.docvqa_caption = pd.read_csv(external_docvqa_caption_path) | |
| if external_geo_caption_path != '': | |
| self.geo_caption = pd.read_csv(external_geo_caption_path) | |
| def tokenize(text): | |
| if tokenizer is not None: | |
| text = replace_person_token(text) | |
| encoding = tokenizer( | |
| text, | |
| truncation=True, | |
| max_length=2 * max_seq_length, | |
| padding=False, | |
| return_tensors="pt" | |
| ) | |
| full_input_ids = encoding.input_ids[0] | |
| if len(full_input_ids) > max_seq_length: | |
| return None | |
| else: | |
| return text | |
| else: | |
| return text | |
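| # Note: texts get a 2x-length headroom at encode time and are then dropped (None) if they | |
| # still exceed max_seq_length, so over-long captions are filtered out rather than truncated. | |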
| if not isinstance(train_shards_path_or_url, str): | |
| train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url] | |
| # flatten list using itertools | |
| train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) | |
| if external_caption_path != '': | |
| processing_pipeline = [ | |
| wds.decode("pil", handler=wds.ignore_and_continue), | |
| wds.map(self.load_external_caption, handler=wds.ignore_and_continue), | |
| wds.rename( | |
| images="jpg;png;jpeg;webp", | |
| input_ids="text;txt;caption", | |
| handler=wds.warn_and_continue, | |
| ), | |
| wds.map(partial(conditional_image_transform, resolution=resolution), handler=wds.warn_and_continue), | |
| wds.map(filter_keys(set(["images", "input_ids"]))), | |
| wds.map_dict( | |
| input_ids=tokenize, | |
| handler=wds.warn_and_continue, | |
| ), | |
| wds.select(filter_long_samples), | |
| ] | |
| else: | |
| processing_pipeline = [ | |
| wds.decode("pil", handler=wds.ignore_and_continue), | |
| wds.rename( | |
| images="jpg;png;jpeg;webp", | |
| input_ids="text;txt;caption", | |
| handler=wds.warn_and_continue, | |
| ), | |
| wds.map(partial(conditional_image_transform, resolution=resolution), handler=wds.warn_and_continue), | |
| wds.map(filter_keys(set(["images", "input_ids"]))), | |
| wds.map_dict( | |
| input_ids=tokenize, | |
| handler=wds.warn_and_continue, | |
| ), | |
| wds.select(filter_long_samples), | |
| ] | |
| pipeline = [ | |
| wds.ResampledShards(train_shards_path_or_url), | |
| tarfile_to_samples_nothrow, | |
| wds.shuffle(shuffle_buffer_size), | |
| *processing_pipeline, | |
| wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate), | |
| ] | |
| num_batches = math.ceil(num_train_examples / global_batch_size) | |
| num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker | |
| num_batches = num_worker_batches * num_workers | |
| num_samples = num_batches * global_batch_size | |
| self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches) | |
| self._train_dataloader = wds.WebLoader( | |
| self._train_dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=num_workers, | |
| pin_memory=pin_memory, | |
| persistent_workers=persistent_workers, | |
| ) | |
| # add meta-data to dataloader instance for convenience | |
| self._train_dataloader.num_batches = num_batches | |
| self._train_dataloader.num_samples = num_samples | |
| def load_external_caption(self, sample): | |
| if 'SA1B' in sample['__key__'] or 'sa' in sample['__key__']: | |
| captionf = f"{self.external_caption_path}/{sample['__key__'].split('/')[-1]}.txt" | |
| if os.path.exists(captionf): | |
| with open(captionf, "r") as reader: | |
| captions = reader.readlines()[0].replace('\n', '') | |
| else: | |
| captions = "" | |
| # for captioning | |
| if self.is_captioning: | |
| if self.add_caption_prompt: | |
| prompt = random.sample(self.caption_prompt, 1)[0] | |
| sample['txt'] = prompt + captions | |
| else: | |
| sample['txt'] = captions | |
| # for generation | |
| else: | |
| # randomly choose short and long captions | |
| if random.random() < 0.5: | |
| sample['txt'] = captions.split('.')[0] | |
| else: | |
| sample['txt'] = captions | |
| sample['txt'] = remove_prefix(sample['txt']) | |
| return sample | |
| elif 'laion' in sample['__url__']: | |
| url_part = sample['__url__'].split('/')[-1].split('.')[0] | |
| key = sample['__key__'].split('/')[-1] | |
| captionf = os.path.join(self.external_laion12m_caption_path, url_part, f"{key}.caption") | |
| if os.path.exists(captionf): | |
| with open(captionf, "r") as reader: | |
| captions = reader.read().strip() | |
| else: | |
| captions = "" | |
| # for captioning | |
| if self.is_captioning: | |
| if self.add_caption_prompt: | |
| prompt = random.sample(self.caption_prompt, 1)[0] | |
| sample['txt'] = prompt + captions | |
| else: | |
| sample['txt'] = captions | |
| # for generation | |
| else: | |
| # randomly choose short and long captions | |
| if random.random() < 0.5: | |
| sample['txt'] = captions.split('.')[0] | |
| else: | |
| sample['txt'] = captions | |
| sample['txt'] = remove_prefix(sample['txt']) | |
| return sample | |
| elif 'cc12m' in sample['__url__']: | |
| url_part = sample['__url__'].split('/')[-1].split('.')[0] | |
| key = sample['__key__'].split('/')[-1] | |
| captionf = os.path.join(self.external_cc12m_caption_path, url_part, f"{key}.caption") | |
| if os.path.exists(captionf): | |
| with open(captionf, "r") as reader: | |
| captions = reader.read().strip() | |
| else: | |
| captions = "" | |
| # for captioning | |
| if self.is_captioning: | |
| if self.add_caption_prompt: | |
| prompt = random.sample(self.caption_prompt, 1)[0] | |
| sample['txt'] = prompt + captions | |
| else: | |
| sample['txt'] = captions | |
| # for generation | |
| else: | |
| # randomly choose short and long captions | |
| if random.random() < 0.5: | |
| sample['txt'] = captions.split('.')[0] | |
| else: | |
| sample['txt'] = captions | |
| sample['txt'] = remove_prefix(sample['txt']) | |
| return sample | |
| elif "text-to-image-2M" in sample['__url__']: | |
| if "json" in sample and "prompt" in sample["json"]: | |
| captions = sample["json"]["prompt"] | |
| else: | |
| print(f"sample has no json or prompt: {sample}") | |
| captions = "" | |
| sample['txt'] = captions | |
| return sample | |
| elif 'ai2d' in sample['__url__']: | |
| key = sample['__key__'].split('/')[-1] | |
| df_row = self.ai2d_caption[self.ai2d_caption['image'].astype(str) == key + '.png'] | |
| if len(df_row) == 0: | |
| print(f"No captions available for key {sample['__key__']}") | |
| return sample | |
| elif len(df_row) > 1: | |
| # print(f"Multiple captions available for key {sample['__key__']}") | |
| df_row = df_row.sample(1) | |
| question = df_row['question'].values[0] | |
| solution = df_row['solution'].values[0] | |
| caption = ( | |
| '<|start_header_id|>user<|end_header_id|>\n' | |
| "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" | |
| f"{question}\n" | |
| '<eot_id><|start_header_id|>assistant<|end_header_id|>\n' | |
| f"{solution}" | |
| ) | |
| sample['txt'] = caption | |
| return sample | |
| elif 'clevr' in sample['__url__']: | |
| key = sample['__key__'].split('/')[-1] | |
| df_row = self.clevr_caption[self.clevr_caption['image'].astype(str) == key + ".jpg"] | |
| if len(df_row) == 0: | |
| print(f"No captions available for key {sample['__key__']}") | |
| return sample | |
| elif len(df_row) > 1: | |
| # print(f"Multiple captions available for key {sample['__key__']}") | |
| df_row = df_row.sample(1) | |
| question = df_row['question'].values[0] | |
| solution = df_row['solution'].values[0] | |
| caption = ( | |
| '<|start_header_id|>user<|end_header_id|>\n' | |
| "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" | |
| f"{question}\n" | |
| '<eot_id><|start_header_id|>assistant<|end_header_id|>\n' | |
| f"{solution}" | |
| ) | |
| sample['txt'] = caption | |
| return sample | |
| elif 'docvqa' in sample['__url__']: | |
| key = sample['__key__'].split('/')[-1] | |
| df_row = self.docvqa_caption[self.docvqa_caption['image'].astype(str) == key + ".png"] | |
| if len(df_row) == 0: | |
| print(f"No captions available for key {sample['__key__']}") | |
| return sample | |
| elif len(df_row) > 1: | |
| # print(f"Multiple captions available for key {sample['__key__']}") | |
| df_row = df_row.sample(1) | |
| question = df_row['question'].values[0] | |
| solution = df_row['solution'].values[0] | |
| caption = ( | |
| '<|start_header_id|>user<|end_header_id|>\n' | |
| "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" | |
| f"{question}\n" | |
| '<eot_id><|start_header_id|>assistant<|end_header_id|>\n' | |
| f"{solution}" | |
| ) | |
| sample['txt'] = caption | |
| return sample | |
| elif 'geo' in sample['__url__']: | |
| key = sample['__key__'].split('/')[-1] | |
| df_row = self.geo_caption[self.geo_caption['image'].astype(str) == key + ".jpg"] | |
| if len(df_row) == 0: | |
| print(f"No captions available for key {sample['__key__']}") | |
| return sample | |
| elif len(df_row) > 1: | |
| # print(f"Multiple captions available for key {sample['__key__']}") | |
| df_row = df_row.sample(1) | |
| question = df_row['question'].values[0] | |
| solution = df_row['solution'].values[0] | |
| caption = ( | |
| '<|start_header_id|>user<|end_header_id|>\n' | |
| "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" | |
| f"{question}\n" | |
| '<eot_id><|start_header_id|>assistant<|end_header_id|>\n' | |
| f"{solution}" | |
| ) | |
| sample['txt'] = caption | |
| return sample | |
| elif self.journeydb_caption is not None and sample['__key__'] in self.journeydb_caption: | |
| captions_list = self.journeydb_caption[sample['__key__']] | |
| if len(captions_list) == 0: | |
| print(f"No captions available for key {sample['__key__']}") | |
| return sample | |
| sample['txt'] = random.sample(captions_list, 1)[0] | |
| return sample | |
| else: | |
| print(f"none exist sample: {sample}") | |
| return sample | |
| def train_dataset(self): | |
| return self._train_dataset | |
| def train_dataloader(self): | |
| return self._train_dataloader | |
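| # Minimal usage sketch (illustrative; the shard pattern and tokenizer below are assumptions): | |
| # | |
| #   from transformers import AutoTokenizer | |
| #   tok = AutoTokenizer.from_pretrained("gpt2")  # any PreTrainedTokenizer | |
| #   t2i = Text2ImageDataset( | |
| #       train_shards_path_or_url="/data/shards/{00000..00099}.tar", | |
| #       tokenizer=tok, max_seq_length=128, num_train_examples=100_000, | |
| #       per_gpu_batch_size=8, global_batch_size=64, num_workers=4, | |
| #   ) | |
| #   for batch in t2i.train_dataloader(): | |
| #       images, texts = batch["images"], batch["input_ids"]  # (B, 3, 256, 256) tensor; list of caption strings | |
| #       break | |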
| # +++++ S2T/T2S Dataset Definition +++++ | |
| class SpeechTextDataset(Dataset): | |
| def __init__(self, dataset : str, subset : str, split : Optional[str] = None): | |
| self.dataset_name = dataset | |
| if self.dataset_name == "gigaspeech": # subset is either "xs" or "xl" | |
| self.hgf_dataset : datasets.Dataset = load_dataset("speechcolab/gigaspeech", subset, split=split) | |
| elif self.dataset_name == "librispeech": | |
| root_path = "/home/work/AIDAS/data/audio/LibriSpeech" | |
| self.dataset_path = root_path + "/" + subset # subset is like "train-clean-100", etc | |
| if split is not None: | |
| warnings.warn(f"Split parameter '{split}' is provided but will not be used for LibriSpeech dataset.") | |
| # librispeech path processing | |
| self.subdirs_path = sorted(list(glob(self.dataset_path + "/*/*"))) | |
| self.subdirs_len = [len(glob(subdir + "/*.flac")) for subdir in self.subdirs_path] | |
| self.subdirs_len_accum = list(itertools.accumulate(self.subdirs_len)) | |
| # handle wrong subset name | |
| if len(self.subdirs_path) == 0: | |
| raise ValueError(f"Invalid subset name '{subset}' for LibriSpeech dataset. Available subsets are: train-clean-100, train-clean-360") | |
| elif self.dataset_name == "commonvoice": | |
| self.commonvoice_path = "/home/work/AIDAS/data/audio/commonvoice/cv-corpus-22.0-2025-06-20/en" | |
| if split is not None: | |
| warnings.warn(f"Split parameter '{split}' is provided but will not be used for commonvoice dataset.") | |
| self.tsv = pd.read_csv(self.commonvoice_path + f"/{subset}.tsv", sep="\t", usecols=["path", "sentence"]) | |
| else: | |
| raise ValueError(f"Unsupported dataset: {dataset}. Supported datasets are: gigaspeech, librispeech, commonvoice.") | |
| def __len__(self): | |
| if self.dataset_name == "gigaspeech": | |
| return len(self.hgf_dataset) | |
| elif self.dataset_name == "librispeech": | |
| return self.subdirs_len_accum[-1] | |
| else: # commonvoice | |
| return len(self.tsv) | |
| def __getitem__(self, idx): | |
| audio_path : str; text : str | |
| if self.dataset_name == "gigaspeech": | |
| sample = self.hgf_dataset[idx] | |
| audio_path = sample["audio"]["path"] | |
| text = sample["text"] | |
| elif self.dataset_name == "librispeech": | |
| # idx overflow | |
| if idx >= self.subdirs_len_accum[-1]: | |
| raise IndexError(f"Index {idx} is out of bounds for the dataset with length {len(self)}.") | |
| # audio_path (flac) | |
| subdir_idx = bisect.bisect_right(self.subdirs_len_accum, idx) | |
| flac_idx = idx - self.subdirs_len_accum[subdir_idx - 1] if subdir_idx > 0 else idx | |
| audio_path = sorted(list(glob(self.subdirs_path[subdir_idx]+"/*.flac")))[flac_idx] | |
| # text | |
| txt_path = glob(self.subdirs_path[subdir_idx]+"/*.txt") | |
| assert len(txt_path) == 1, f"Expected one txt file in {self.subdirs_path[subdir_idx]}, found {len(txt_path)}" | |
| with open(txt_path[0], "r") as f: | |
| txt = f.readlines() | |
| text = " ".join(txt[flac_idx].split(" ")[1:]) # rip off the header, e.g., "103-1240-0007 [TEXT]" | |
| else: # commonvoice | |
| audio_path = self.commonvoice_path + "/clips/" + self.tsv.iloc[idx]["path"] | |
| text = self.tsv.iloc[idx]["sentence"] | |
| return {"audio_path": audio_path, "text": text} | |
| class MixedSpeechTextDataset(Dataset): | |
| def __init__(self, dataset_configs: list): | |
| """ | |
| Initializes and combines multiple speech datasets. | |
| Args: | |
| dataset_configs (list): A list of configuration dictionaries, | |
| where each dict defines a dataset to load. | |
| """ | |
| self.dataset_metadata = [] | |
| self.dataset_lengths = [] | |
| self._sha1 = hashlib.sha1 | |
| # Iterate through the list of dataset configurations from the YAML file | |
| for config in dataset_configs: | |
| name = config['name'] | |
| subset = config.get('subset') | |
| split = config.get('split') | |
| use_tokens = bool(config.get("use_precomputed_tokens", False)) | |
| token_root = config.get("precomputed_tokens_root") | |
| token_root_path = Path(token_root).expanduser() if token_root else None | |
| print(f"Initializing dataset: {name} (Subset: {subset}, Split: {split})") | |
| # --- Gigaspeech --- | |
| if name == "gigaspeech": | |
| hgf_dataset = datasets.load_dataset("speechcolab/gigaspeech", subset, split=split) | |
| self.dataset_metadata.append({ | |
| "name": name, | |
| "data": hgf_dataset, | |
| "use_precomputed_tokens": use_tokens and token_root_path is not None, | |
| "precomputed_tokens_root": token_root_path, | |
| }) | |
| self.dataset_lengths.append(len(hgf_dataset)) | |
| # --- LibriSpeech --- | |
| elif name == "librispeech": | |
| root_path = "/home/work/AIDAS/data/audio/LibriSpeech" | |
| dataset_path = f"{root_path}/{subset}" | |
| if split is not None: | |
| warnings.warn(f"Split parameter '{split}' is provided but will not be used for LibriSpeech.") | |
| subdirs_path = sorted(glob(f"{dataset_path}/*/*")) | |
| if not subdirs_path: | |
| raise ValueError(f"Invalid subset for LibriSpeech or path not found: {dataset_path}") | |
| subdirs_len = [len(glob(f"{subdir}/*.flac")) for subdir in subdirs_path] | |
| subdirs_len_accum = list(itertools.accumulate(subdirs_len)) | |
| metadata = { | |
| "name": name, | |
| "subdirs_path": subdirs_path, | |
| "subdirs_len_accum": subdirs_len_accum, | |
| "use_precomputed_tokens": use_tokens and token_root_path is not None, | |
| "precomputed_tokens_root": token_root_path, | |
| } | |
| self.dataset_metadata.append(metadata) | |
| self.dataset_lengths.append(subdirs_len_accum[-1]) | |
| # --- Common Voice --- | |
| elif name == "commonvoice": | |
| commonvoice_path = "/home/work/AIDAS/data/audio/commonvoice/cv-corpus-22.0-2025-06-20/en" | |
| if split is not None: | |
| warnings.warn(f"Split parameter '{split}' is provided but will not be used for Common Voice.") | |
| tsv_path = f"{commonvoice_path}/{subset}.tsv" | |
| tsv = pd.read_csv(tsv_path, sep="\t", usecols=["path", "sentence"]) | |
| metadata = { | |
| "name": name, | |
| "data_root": f"{commonvoice_path}/clips/", | |
| "tsv": tsv, | |
| "use_precomputed_tokens": use_tokens and token_root_path is not None, | |
| "precomputed_tokens_root": token_root_path, | |
| } | |
| self.dataset_metadata.append(metadata) | |
| self.dataset_lengths.append(len(tsv)) | |
| else: | |
| raise ValueError(f"Unsupported dataset: {name}.") | |
| # Calculate cumulative lengths to map a global index to a specific dataset | |
| self.cumulative_lengths = list(itertools.accumulate(self.dataset_lengths)) | |
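| # e.g. lengths [100, 250, 50] -> cumulative [100, 350, 400]; global idx 120 falls in | |
| # dataset 1 (bisect_right([100, 350, 400], 120) == 1) with local idx 120 - 100 = 20. | |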
| # print(f"✅ All datasets loaded for the SPEECH!. Total length: {self.__len__()} samples.") | |
| def __len__(self): | |
| """Returns the total number of samples across all datasets.""" | |
| return self.cumulative_lengths[-1] if self.cumulative_lengths else 0 | |
| def __getitem__(self, idx): | |
| """ | |
| Fetches a sample from the combined dataset. | |
| It first determines which dataset the global index `idx` belongs to, | |
| calculates the local index within that dataset, and then retrieves the item. | |
| """ | |
| if idx >= self.__len__(): | |
| raise IndexError(f"Index {idx} is out of bounds for the combined dataset with length {self.__len__()}.") | |
| # Find which dataset the index belongs to | |
| dataset_idx = bisect.bisect_right(self.cumulative_lengths, idx) | |
| # Calculate the local index within that dataset | |
| local_idx = idx - self.cumulative_lengths[dataset_idx - 1] if dataset_idx > 0 else idx | |
| metadata = self.dataset_metadata[dataset_idx] | |
| dataset_name = metadata["name"] | |
| dataset_length = self.dataset_lengths[dataset_idx] | |
| audio_path: str | |
| text: str | |
| audio_tokens: Optional[torch.Tensor] | |
| max_retry = 5 | |
| retry = 0 | |
| while retry < max_retry: | |
| try: | |
| audio_tokens = None | |
| if dataset_name == "gigaspeech": | |
| sample = metadata["data"][local_idx] | |
| audio_path = sample["audio"]["path"] | |
| text = sample["text"] | |
| # Preprocess special words to punctuation | |
| text = ( | |
| text.replace(" <COMMA>", ",") | |
| .replace(" <PERIOD>", ".") | |
| .replace(" <QUESTIONMARK>", "?") | |
| .replace(" <EXCLAMATIONMARK>", "!") | |
| ) | |
| elif dataset_name == "librispeech": | |
| # Find the specific subdirectory and file using the local index | |
| subdir_idx = bisect.bisect_right(metadata["subdirs_len_accum"], local_idx) | |
| flac_idx = local_idx - metadata["subdirs_len_accum"][subdir_idx - 1] if subdir_idx > 0 else local_idx | |
| subdir_path = metadata["subdirs_path"][subdir_idx] | |
| audio_path = sorted(glob(f"{subdir_path}/*.flac"))[flac_idx] | |
| # Read the corresponding transcript | |
| txt_path = glob(f"{subdir_path}/*.txt")[0] | |
| with open(txt_path, "r") as f: | |
| line = f.readlines()[flac_idx] | |
| text = " ".join(line.strip().split(" ")[1:]) | |
| else: # commonvoice | |
| row = metadata["tsv"].iloc[local_idx] | |
| audio_path = metadata["data_root"] + row["path"] | |
| text = row["sentence"] | |
| # Preprocess lower case to upper case | |
| text = text.upper() | |
| audio_tokens = self._maybe_load_precomputed_tokens(audio_path, metadata) | |
| return { | |
| "audio_path": audio_path, | |
| "text": text, | |
| "audio_tokens": audio_tokens, | |
| } | |
| except Exception as exc: | |
| print(f"[MixedSpeechTextDataset] Failed to load sample from '{dataset_name}' at local index {local_idx}: {exc!r}") | |
| retry += 1 | |
| if retry >= max_retry: | |
| break | |
| local_idx = random.randint(0, dataset_length - 1) | |
| continue | |
| raise RuntimeError(f"Unable to fetch a valid sample from dataset '{dataset_name}' after {max_retry} retries.") | |
| def _maybe_load_precomputed_tokens(self, audio_path: str, metadata: dict) -> Optional[torch.Tensor]: | |
| if not metadata.get("use_precomputed_tokens"): | |
| return None | |
| root: Optional[Path] = metadata.get("precomputed_tokens_root") | |
| if root is None: | |
| return None | |
| if not root.exists(): | |
| logger.warning("Precomputed token root missing: %s", root) | |
| return None | |
| key = os.path.abspath(audio_path) | |
| digest = self._sha1(key.encode("utf-8")).hexdigest() | |
| token_path = root / digest[:2] / digest[2:4] / f"{digest}.pt" | |
| if not token_path.exists(): | |
| logger.warning("Precomputed audio tokens not found: %s", token_path) | |
| return None | |
| try: | |
| tokens = torch.load(token_path, map_location="cpu") | |
| if isinstance(tokens, torch.Tensor): | |
| return tokens.clone() | |
| if isinstance(tokens, (list, tuple)): | |
| return torch.tensor(tokens, dtype=torch.long) | |
| logger.warning("Unexpected token format in %s (type=%s)", token_path, type(tokens)) | |
| except Exception as exc: | |
| logger.warning("Failed to load precomputed tokens %s: %s", token_path, exc) | |
| return None | |
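| # Precomputed tokens are looked up at <root>/<sha1[:2]>/<sha1[2:4]>/<sha1>.pt, where sha1 is | |
| # the digest of the absolute audio path. A writer that pre-extracts tokens should mirror | |
| # this layout, e.g. (illustrative sketch): | |
| #   digest = hashlib.sha1(os.path.abspath(audio_path).encode("utf-8")).hexdigest() | |
| #   out = Path(token_root) / digest[:2] / digest[2:4] / f"{digest}.pt" | |
| #   out.parent.mkdir(parents=True, exist_ok=True) | |
| #   torch.save(audio_tokens, out) | |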
| class Speech2SpeechDataset(Dataset): | |
| """ | |
| Mixed dataset of emova-sft and InstructS2S-200K. | |
| Return value of __getitem__ indicates a pair of (user, assistant) message (single-turn). | |
| Critically, the return type for emova_sft and instructs2s are different: | |
| emova_sft: tuple[list[int], list[int], Any] | |
| instructs2s: tuple[str, str, Optional[Any]] | |
| So in the main training code, we need to handle both types of return values. | |
| Notes: | |
| - Use `s2s_collate_fn` within DataLoader. | |
| - For emova_sft, Tensor cannot be returned because padding is done later in the main training code. | |
| For reference: | |
| Total samples: 496514 | |
| emova-sft (speech-text): 73.6k | |
| emova-sft (speech-image): 71.5k | |
| InstructS2S-200K: 422856 | |
| """ | |
| def __init__(self, dataset_configs: list): | |
| self.dataset_configs = dataset_configs # currently this arg is not used | |
| ## emova-sft (text + image splits) | |
| emova_sft_text = load_dataset("Emova-ollm/emova-sft-4m", "emova-speech-text-en", split='train') | |
| emova_sft_image = load_dataset("Emova-ollm/emova-sft-4m", "emova-speech-image-en", split='train') | |
| def _maybe_cast_image_columns(ds): | |
| for column in ("image", "images"): | |
| if column in ds.column_names: | |
| try: | |
| ds = ds.cast_column(column, datasets.Image(decode=False)) | |
| except Exception: | |
| # Column may already be raw bytes/str; ignore and keep as-is | |
| pass | |
| return ds | |
| emova_sft_text = _maybe_cast_image_columns(emova_sft_text) | |
| emova_sft_image = _maybe_cast_image_columns(emova_sft_image) | |
| def emova_sft_preprocess_batch(batch): | |
| # Extract conversations from the batch | |
| conversations_list = batch['conversations'] | |
| usr_ids_list = [] | |
| asst_ids_list = [] | |
| images_list = [] | |
| def normalize_emova_ids(ids: str) -> list[int]: | |
| unit_numbers = ids.replace('<|speech_', '').replace('|>', ' ').strip() | |
| unit_ids = [int(unit) for unit in unit_numbers.split(" ")] | |
| return unit_ids | |
| # Process each conversation in the batch | |
| for conversations in conversations_list: | |
| usr_raw = conversations[0]['value'] | |
| asst_raw = conversations[1]['value'] | |
| usr_ids: str = usr_raw.split("\n\nuser question speech:")[-1].strip() | |
| asst_ids: str = json.loads(asst_raw)['assistant response speech'].strip() | |
| usr_ids_list.append(normalize_emova_ids(usr_ids)) | |
| asst_ids_list.append(normalize_emova_ids(asst_ids)) | |
| raw_images = ( | |
| batch.get('image') | |
| or batch.get('images') | |
| or batch.get('image_base64') | |
| or [None] * len(conversations_list) | |
| ) | |
| if not isinstance(raw_images, (list, tuple)): | |
| raw_images = [raw_images] * len(conversations_list) | |
| else: | |
| raw_images = list(raw_images) | |
| if len(raw_images) != len(conversations_list): | |
| # Align lengths by padding/truncating with None without decoding payloads | |
| adjusted = raw_images[:len(conversations_list)] | |
| if len(adjusted) < len(conversations_list): | |
| adjusted.extend([None] * (len(conversations_list) - len(adjusted))) | |
| raw_images = adjusted | |
| images_list.extend(raw_images) | |
| # Return a dictionary with lists of processed data | |
| return { | |
| "usr_ids": usr_ids_list, | |
| "asst_ids": asst_ids_list, | |
| "image": images_list, | |
| } | |
| self.emova_sft_text = emova_sft_text.map( | |
| emova_sft_preprocess_batch, | |
| batched=True, | |
| batch_size=1024, | |
| remove_columns=['conversations'], | |
| desc="Processing emova-sft (text)", | |
| num_proc=16 | |
| ) | |
| self.emova_sft_image = emova_sft_image.map( | |
| emova_sft_preprocess_batch, | |
| batched=True, | |
| batch_size=1024, | |
| remove_columns=['conversations'], | |
| desc="Processing emova-sft (image)", | |
| num_proc=16 | |
| ) | |
| self._emova_text_len = len(self.emova_sft_text) | |
| self._emova_image_len = len(self.emova_sft_image) | |
| self._emova_total_len = self._emova_text_len + self._emova_image_len | |
| ## InstructS2S-200K (with caching) | |
| instructs2s_rootdir = "/home/work/AIDAS/data/InstructS2S-200K/en/wav" | |
| self.instructs2s_wav_pair_paths = [] | |
| pairs_txt = os.path.join(instructs2s_rootdir, "pairs.txt") | |
| if os.path.isfile(pairs_txt): | |
| with open(pairs_txt, "r") as f: | |
| for line in tqdm(f, desc="Loading InstructS2S-200K paths from cached file"): | |
| line = line.strip() | |
| if not line: | |
| continue | |
| parts = line.split() | |
| if len(parts) >= 2: | |
| self.instructs2s_wav_pair_paths.append((parts[0], parts[1])) | |
| else: | |
| instructs2s_wav_dirs = [p for p in glob(os.path.join(instructs2s_rootdir, "*")) if os.path.isdir(p)] | |
| # Walk each directory and collect (user, assistant) wav pairs | |
| for dir_path in tqdm(instructs2s_wav_dirs, desc="Processing instructs2s-200k"): | |
| dir_name = os.path.basename(dir_path) | |
| k = 1 | |
| while True: | |
| user_wav = os.path.join(dir_path, f"{dir_name}-{k}-user.wav") | |
| assistant_wav = os.path.join(dir_path, f"{dir_name}-{k}-assistant.wav") | |
| if os.path.isfile(user_wav) and os.path.isfile(assistant_wav): | |
| self.instructs2s_wav_pair_paths.append((user_wav, assistant_wav)) | |
| k += 1 | |
| continue | |
| break | |
| with open(pairs_txt, "w") as f: | |
| for u, a in self.instructs2s_wav_pair_paths: | |
| f.write(f"{u} {a}\n") | |
| ## Mixed dataset (ordered) | |
| self.mixed_dataset = [self.emova_sft_text, self.emova_sft_image, self.instructs2s_wav_pair_paths] | |
| def __len__(self): | |
| return sum([len(dataset) for dataset in self.mixed_dataset]) | |
| def __getitem__(self, idx) -> Union[tuple[list[int], list[int], Any], tuple[str, str, Optional[Any]]]: | |
| if idx < self._emova_text_len: # emova_sft text split | |
| sample = self.emova_sft_text[idx] | |
| elif idx < self._emova_total_len: # emova_sft image split | |
| sample = self.emova_sft_image[idx - self._emova_text_len] | |
| else: # instructs2s | |
| local_idx = idx - self._emova_total_len | |
| usr_wav, asst_wav = self.instructs2s_wav_pair_paths[local_idx] | |
| return usr_wav, asst_wav, None # tuple[str, str, Optional[image]]; wav file paths | |
| usr_ids = sample['usr_ids'] | |
| asst_ids = sample['asst_ids'] | |
| image = sample.get('image') | |
| return usr_ids, asst_ids, image # tuple[list[int], list[int], image] | |
| def s2s_collate_fn(batch): | |
| """ | |
| Collate function for Speech2SpeechDataset. | |
| """ | |
| emova_data = [] | |
| instructs2s_data = [] | |
| for item in batch: | |
| if isinstance(item[0], list): # emova_sft: (list[int], list[int], image or None) | |
| emova_data.append(item) | |
| else: # instructs2s: (user_wav_path, assistant_wav_path, None) | |
| instructs2s_data.append(item) | |
| return { | |
| 'emova_sft': emova_data, | |
| 'instructs2s': instructs2s_data, | |
| } | |
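| # Usage sketch (illustrative): the custom collate_fn keeps the two sources separate instead | |
| # of default-collating mixed tuple types. | |
| #   from torch.utils.data import DataLoader | |
| #   s2s_loader = DataLoader(Speech2SpeechDataset([]), batch_size=4, shuffle=True, | |
| #                           collate_fn=s2s_collate_fn) | |
| #   batch = next(iter(s2s_loader))  # {'emova_sft': [...], 'instructs2s': [...]} | |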
| class VideoCaptionDataset(Dataset): | |
| def __init__( | |
| self, | |
| transform, | |
| tokenizer, | |
| max_seq_length: int, | |
| resolution: int = 256, | |
| panda70m_path = "/home/work/AIDAS/data/video/panda70m/panda70m_training_2m", | |
| openvid1m_path = "/home/work/AIDAS/data/video/openvid1m/video", | |
| webvid10m_path = "/home/work/AIDAS/data/video/webvid10m", | |
| llavavid_path = "/home/work/AIDAS/data/video/LLaVA-Video-178K", | |
| dataset_name = "openvid1m", | |
| llavavid_local_files_only: bool = False, | |
| llavavid_skip_configs: Optional[Sequence[str]] = None, | |
| llavavid_skip_video_patterns: Optional[Sequence[str]] = None, | |
| sample_method='uniform', | |
| num_frames: int = 8, | |
| vq_model=None, | |
| ): | |
| available_datasets = ['panda70m', 'openvid1m', 'webvid10m', 'llavavid'] | |
| if dataset_name not in available_datasets: | |
| raise ValueError(f"Invalid dataset name: {dataset_name}. Available datasets: {available_datasets}") | |
| self.max_seq_length = max_seq_length | |
| self.transform = transform | |
| self.vq_model = vq_model | |
| self.tokenizer = tokenizer | |
| self.resolution = resolution | |
| self.sample_method = sample_method | |
| self.dataset_name = dataset_name | |
| self.num_frames = num_frames | |
| self.llavavid_local_files_only = llavavid_local_files_only | |
| self.llavavid_skip_configs = set(llavavid_skip_configs or []) | |
| self.llavavid_skip_video_patterns = tuple(llavavid_skip_video_patterns or []) | |
| self.caption_prompt = V2T_INSTRUCTION | |
| self.caption_prompt = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in self.caption_prompt] | |
| self.webvid10m_path = webvid10m_path | |
| if dataset_name == 'panda70m': | |
| self.vid_data = self._collect_panda70m(panda70m_path) | |
| self.dataset_root = panda70m_path | |
| elif dataset_name == 'webvid10m': | |
| self.vid_data = self._collect_webvid10m(webvid10m_path) | |
| self.dataset_root = webvid10m_path | |
| elif dataset_name == 'openvid1m': | |
| self.vid_data = self._collect_openvid1m(openvid1m_path) | |
| self.dataset_root = openvid1m_path | |
| elif dataset_name == 'llavavid': | |
| self.vid_data = self._collect_llavavid(llavavid_path) | |
| self.dataset_root = Path(llavavid_path) | |
| self.llavavid_video_root = Path(llavavid_path) | |
| else: | |
| raise ValueError(f"Invalid dataset name: {dataset_name}. Available datasets: panda70m, webvid10m") | |
| def _get_caption_prompt(self): | |
| """ | |
| Get a random caption prompt from the list of caption prompts. | |
| """ | |
| return np.random.choice(self.caption_prompt) | |
| def _tokenize(self, text): | |
| if self.tokenizer is not None: | |
| input_ids = self.tokenizer( | |
| text, | |
| truncation=True, | |
| max_length=2 * self.max_seq_length, | |
| padding=False, | |
| return_tensors="pt" | |
| ).input_ids[0] | |
| if len(input_ids) > self.max_seq_length: | |
| return None | |
| else: | |
| return input_ids | |
| else: | |
| raise ValueError("Tokenizer is not provided.") | |
| def _collect_webvid10m(self, root_path): | |
| print("Loading videos from WebVid10m dataset...") | |
| csv_path = osp.join(root_path, "webvid-10M-train.csv") | |
| webvid_pd = pd.read_csv(csv_path) | |
| self.dataset_length = len(webvid_pd) | |
| print(f"{len(webvid_pd)} videos has been loaded.") | |
| return webvid_pd | |
| def _collect_panda70m(self, root_path): | |
| video_caption_pairs = [] | |
| subdirs = sorted(os.listdir(root_path)) | |
| print("Loading videos from panda70m dataset...") | |
| for subdir in subdirs: | |
| full_subdir = os.path.join(root_path, subdir) | |
| if not os.path.isdir(full_subdir): | |
| continue | |
| video_paths = glob(os.path.join(full_subdir, "*.mp4")) | |
| for video_path in video_paths: | |
| caption_path = video_path.replace(".mp4", ".txt") | |
| if os.path.exists(caption_path): | |
| with open(caption_path, 'r') as f: | |
| caption = f.read().strip() | |
| prompt = self._get_caption_prompt() | |
| video_caption_pairs.append({ | |
| "video": video_path, | |
| "caption": prompt + caption | |
| }) | |
| print(f"{len(video_caption_pairs)} videos has been loaded.") | |
| return video_caption_pairs | |
| def _collect_openvid1m(self, root_path): | |
| csv_path = osp.join(root_path, "OpenVid-1M.csv") | |
| openvid_pd = pd.read_csv(csv_path) | |
| self.dataset_length = len(openvid_pd) | |
| print(f"{len(openvid_pd)} videos has been loaded.") | |
| return openvid_pd | |
| def _collect_llavavid( | |
| self, | |
| root_path="lmms-lab/LLaVA-Video-178K", | |
| cache_dir="/home/work/AIDAS/huggingface/datasets" | |
| ): | |
| """ | |
| Collect all available (and locally cached) subsets of the LLaVA-Video-178K dataset. | |
| Handles both on-disk exports (each config stored as subfolders of splits) and remote configs. | |
| Returns a single flattened HuggingFace Dataset that concatenates every successfully loaded config. | |
| """ | |
| DATASET_NAME = root_path | |
| local_root = Path(DATASET_NAME) | |
| configs: list[str] | |
| using_local_dirs = local_root.exists() | |
| configs = [] | |
| if using_local_dirs: | |
| for p in sorted(local_root.iterdir()): | |
| if not p.is_dir(): | |
| continue | |
| if p.name.startswith("."): | |
| continue | |
| split_exists = any((p / split_name).exists() for split_name in ("open_ended", "caption", "multi_choice")) | |
| if not split_exists: | |
| continue | |
| configs.append(p.name) | |
| if not configs: | |
| using_local_dirs = False | |
| if not configs: | |
| try: | |
| configs = get_dataset_config_names(DATASET_NAME) | |
| using_local_dirs = False | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to fetch configs for {DATASET_NAME}: {e}") | |
| skip_configs = getattr(self, "llavavid_skip_configs", set()) | |
| if skip_configs: | |
| existing = [cfg for cfg in configs if cfg in skip_configs] | |
| if existing: | |
| print(f"LLaVA-Vid: skipping configs {existing}") | |
| configs = [cfg for cfg in configs if cfg not in skip_configs] | |
| if not configs: | |
| raise RuntimeError("All LLaVA-Video configs were skipped; nothing left to load.") | |
| def _add_config_column(dataset: HFDataset, cfg_name: str, row_count: int): | |
| """Attach the originating config name so downstream can locate videos.""" | |
| if dataset is None or not cfg_name: | |
| return dataset | |
| if "llavavid_config" in dataset.column_names: | |
| return dataset | |
| return dataset.add_column("llavavid_config", [cfg_name] * row_count) | |
| def _flatten_dataset(ds_obj, label: str, cfg_name: str): | |
| """Convert DatasetDicts into a single Dataset and report the row count.""" | |
| if ds_obj is None: | |
| return None, 0 | |
| if isinstance(ds_obj, DatasetDict): | |
| splits = [split for split in ds_obj.values()] | |
| if not splits: | |
| print(f"Skipping {label}: dataset dict has no splits.") | |
| return None, 0 | |
| total_rows = sum(len(split) for split in splits) | |
| if len(splits) == 1: | |
| return _add_config_column(splits[0], cfg_name, total_rows), total_rows | |
| try: | |
| merged = concatenate_datasets(splits) | |
| except Exception as merge_err: | |
| print(f"Skipping {label}: failed to concatenate splits: {merge_err}") | |
| return None, 0 | |
| dataset = merged | |
| else: | |
| dataset = ds_obj | |
| try: | |
| total_rows = len(dataset) | |
| except Exception as len_err: | |
| print(f"Skipping {label}: unable to compute dataset length ({len_err}).") | |
| return None, 0 | |
| dataset = _add_config_column(dataset, cfg_name, total_rows) | |
| return dataset, total_rows | |
| def _load_local_config(cfg_name: str): | |
| """Attempt to read a single config from disk, handling split sub-directories if needed.""" | |
| cfg_root = local_root / cfg_name | |
| if not cfg_root.exists(): | |
| return None, 0 | |
| # First try loading the directory directly (Dataset or DatasetDict exports). | |
| try: | |
| ds_direct = load_from_disk(str(cfg_root)) | |
| except Exception as direct_err: | |
| print(f"Failed to load config {cfg_name} via load_from_disk: {direct_err}.") | |
| else: | |
| ds_flat, ds_count = _flatten_dataset(ds_direct, cfg_name, cfg_name) | |
| if ds_flat is not None and ds_count > 0: | |
| return ds_flat, ds_count | |
| # Fallback: iterate over split sub-directories (caption/open_ended/multi_choice, etc.). | |
| split_dirs = [p for p in sorted(cfg_root.iterdir()) if p.is_dir()] | |
| if not split_dirs: | |
| return None, 0 | |
| split_datasets = [] | |
| for split_dir in split_dirs: | |
| try: | |
| split_ds = load_from_disk(str(split_dir)) | |
| except Exception as split_err: | |
| print(f"Skipping {cfg_name}/{split_dir.name}: {split_err}") | |
| continue | |
| split_datasets.append(split_ds) | |
| if not split_datasets: | |
| return None, 0 | |
| split_total = sum(len(split_ds) for split_ds in split_datasets) | |
| if len(split_datasets) == 1: | |
| dataset = split_datasets[0] | |
| else: | |
| try: | |
| dataset = concatenate_datasets(split_datasets) | |
| except Exception as merge_err: | |
| print(f"Skipping {cfg_name}: failed to concatenate split datasets: {merge_err}") | |
| return None, 0 | |
| dataset = _add_config_column(dataset, cfg_name, split_total) | |
| return dataset, split_total | |
| datasets_loaded = [] | |
| total_count = 0 | |
| for cfg in configs: | |
| ds = None | |
| cfg_count = 0 | |
| if using_local_dirs: | |
| ds, cfg_count = _load_local_config(cfg) | |
| if ds is None or cfg_count == 0: | |
| download_cfg = None | |
| if self.llavavid_local_files_only: | |
| download_cfg = DownloadConfig(local_files_only=True) | |
| try: | |
| remote_ds = load_dataset( | |
| DATASET_NAME, | |
| name=cfg, | |
| cache_dir=cache_dir, | |
| verification_mode="no_checks", | |
| download_config=download_cfg, | |
| ) | |
| except Exception as remote_err: | |
| print(f"Skipping {cfg}: {remote_err}") | |
| continue | |
| ds, cfg_count = _flatten_dataset(remote_ds, cfg, cfg) | |
| if ds is None or cfg_count == 0: | |
| print(f"Skipping {cfg}: dataset empty after flattening.") | |
| continue | |
| datasets_loaded.append(ds) | |
| total_count += cfg_count | |
| if not datasets_loaded: | |
| raise RuntimeError("No valid configs could be loaded!") | |
| if len(datasets_loaded) == 1: | |
| global_dataset = datasets_loaded[0] | |
| else: | |
| try: | |
| global_dataset = concatenate_datasets(datasets_loaded) | |
| except Exception as merge_err: | |
| print(f"Failed to concatenate configs in one step: {merge_err}. Trying pairwise concatenation.") | |
| try: | |
| combined = datasets_loaded[0] | |
| for ds_next in datasets_loaded[1:]: | |
| combined = concatenate_datasets([combined, ds_next]) | |
| global_dataset = combined | |
| except Exception as pair_err: | |
| raise RuntimeError(f"Unable to merge LLaVA-Video configs: {pair_err}") from pair_err | |
| # Filter out samples whose video path matches known-bad patterns (e.g., missing shareVideoGPTV frames) | |
| skip_patterns = getattr(self, "llavavid_skip_video_patterns", tuple()) | |
| if skip_patterns: | |
| def _matches_skip(entry: dict[str, Any]) -> bool: | |
| video_entry = entry.get("video") | |
| if not isinstance(video_entry, str): | |
| return False | |
| return any(pattern in video_entry for pattern in skip_patterns) | |
| def _filter_dataset(ds_obj): | |
| if isinstance(ds_obj, list): | |
| filtered_list = [] | |
| removed_total = 0 | |
| for item in ds_obj: | |
| filtered_item, removed_item = _filter_dataset(item) | |
| removed_total += removed_item | |
| if filtered_item is None: | |
| continue | |
| filtered_list.append(filtered_item) | |
| return filtered_list, removed_total | |
| elif isinstance(ds_obj, HFDataset): | |
| before = len(ds_obj) | |
| filtered = ds_obj.filter(lambda ex: not _matches_skip(ex)) | |
| removed = before - len(filtered) | |
| return filtered, removed | |
| elif isinstance(ds_obj, dict): | |
| return (None, 1) if _matches_skip(ds_obj) else (ds_obj, 0) | |
| else: | |
| return ds_obj, 0 | |
| global_dataset, removed_samples = _filter_dataset(global_dataset) | |
| if removed_samples > 0: | |
| total_count -= removed_samples | |
| print(f"LLaVA-Vid: skipped {removed_samples} samples matching patterns {skip_patterns}.") | |
| print(f"LLaVA-Vid: {len(datasets_loaded)} configs loaded.") | |
| print(f"LLaVA-Vid: {total_count:,} total samples loaded.") | |
| self.dataset_length = total_count | |
| return global_dataset | |
| def __len__(self): | |
| return len(self.vid_data) | |
| def __getitem__(self, idx): | |
| max_try_count = 50 | |
| for try_count in range(max_try_count): | |
| try: | |
| data = self._sample_data(idx) | |
| except Exception as exc: | |
| logger.warning( | |
| "VideoCaptionDataset failed to fetch index %s on attempt %s/%s: %s", | |
| idx, | |
| try_count + 1, | |
| max_try_count, | |
| exc, | |
| ) | |
| idx = random.randint(0, self.dataset_length - 1) | |
| continue | |
| if data is not None: | |
| return { | |
| "video": data["video"], | |
| "caption": data["caption"], | |
| } | |
| idx = random.randint(0, self.dataset_length - 1) | |
| logger.warning( | |
| "VideoCaptionDataset exhausted %s attempts without a valid sample; returning None.", | |
| max_try_count, | |
| ) | |
| return None | |
| def _sample_data_webvid10m(self): | |
| store_path = osp.join(self.webvid10m_path, "video_store") | |
| row = self.vid_data.sample(1).iloc[0] | |
| video_id = str(row["videoid"]) | |
| url = row["contentUrl"] | |
| caption = row["name"] | |
| video_path = osp.join(store_path, f"{video_id}.mp4") | |
| if not osp.exists(video_path): # not downloaded yet | |
| download_video_url(url, video_path) | |
| # print(video_id) | |
| # print(_whoami_str()) | |
| return video_path, caption | |
| def _sample_data(self, idx): | |
| if self.dataset_name == 'webvid10m': | |
| # currently randomly sample from the dataset | |
| video_path, caption = self._sample_data_webvid10m() | |
| elif self.dataset_name == 'panda70m': | |
| raise NotImplementedError("Panda70m is not implemented yet.") | |
| # video_path, caption = self._sample_data_panda70m() | |
| elif self.dataset_name == 'openvid1m': | |
| data_row = self.vid_data.iloc[idx] | |
| video_path = osp.join(self.dataset_root, "video", data_row["video"]) | |
| caption = data_row["caption"] | |
| elif self.dataset_name == 'llavavid': | |
| data_row = self.vid_data[idx] | |
| video_entry = data_row['video'] | |
| cfg_name = data_row.get('llavavid_config') if isinstance(data_row, dict) else None | |
| caption = data_row['conversations'] # this is a list of turns in llavavid | |
| resolved_video_path = None | |
| if isinstance(video_entry, str): | |
| candidate_paths = [] | |
| video_path_obj = Path(video_entry) | |
| if video_path_obj.is_absolute() and video_path_obj.exists(): | |
| resolved_video_path = video_path_obj | |
| else: | |
| if hasattr(self, "llavavid_video_root"): | |
| base_root = Path(self.llavavid_video_root) | |
| if cfg_name: | |
| candidate_paths.append(base_root / cfg_name / video_entry) | |
| candidate_paths.append(base_root / video_entry) | |
| # Also allow treating the stored value as relative to current dir. | |
| candidate_paths.append(Path(video_entry)) | |
| for candidate in candidate_paths: | |
| if candidate.exists(): | |
| resolved_video_path = candidate | |
| break | |
| if resolved_video_path is None: | |
| logger.warning( | |
| "LLaVA-Video sample missing video file: %s (config=%s)", | |
| video_entry, | |
| cfg_name, | |
| ) | |
| return None | |
| if resolved_video_path.suffix.lower() == ".mkv": | |
| logger.warning( | |
| "LLaVA-Video skipping MKV file: %s (config=%s)", | |
| resolved_video_path, | |
| cfg_name, | |
| ) | |
| return None | |
| video_path = str(resolved_video_path) | |
| else: | |
| raise ValueError(f"Invalid dataset name: {self.dataset_name}. Available datasets: panda70m, webvid10m, openvid1m") | |
| try: | |
| frames = load_video_mp4( | |
| video_path=video_path, | |
| sample_method=self.sample_method, | |
| num_frames=self.num_frames, | |
| resolution=self.resolution, | |
| transform=self.transform, | |
| strict=False, | |
| ) | |
| except Exception as exc: | |
| logger.warning( | |
| "LLaVA-Video sample failed to load (%s): %s", | |
| video_path, | |
| exc, | |
| ) | |
| return None | |
| if frames is None: | |
| logger.warning( | |
| "LLaVA-Video sample timed out while reading frames (%s); skipping sample.", | |
| video_path, | |
| ) | |
| return None | |
| return { | |
| "video": frames, # torch tensor (T, C, H, W) | |
| "caption": caption # input_ids (seq_len); str | |
| } | |
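| # Usage sketch (illustrative; dataset roots above are machine-specific defaults, and | |
| # utils_image_transform is assumed to accept (image, resolution=...)): | |
| #   vid_ds = VideoCaptionDataset(transform=utils_image_transform, tokenizer=tok, | |
| #                                max_seq_length=512, dataset_name="openvid1m") | |
| #   sample = vid_ds[0]  # {"video": list of num_frames transformed frames, "caption": str} | |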
| def download_video_url(url: str, save_path, timeout=10, max_retries=3) -> bool: | |
| for attempt in range(1, max_retries + 1): | |
| try: | |
| with requests.get(url, stream=True, timeout=timeout) as r: | |
| r.raise_for_status() | |
| with open(save_path, 'wb') as f: | |
| for chunk in r.iter_content(chunk_size=8192): | |
| if chunk: | |
| f.write(chunk) | |
| return True # Success | |
| except Exception as e: | |
| print(f"[Attempt {attempt}/{max_retries}] Download failed: {e}") | |
| if attempt < max_retries: | |
| sleep_time = 2 ** (attempt - 1) # exponential backoff: 1,2,4,8,... | |
| time.sleep(sleep_time) | |
| else: | |
| return False # all attempts failed | |
| return False | |
| def load_video_mp4( | |
| video_path, | |
| sample_method: str = 'uniform', | |
| num_frames: int = 8, | |
| resolution: int = 256, | |
| transform=None, | |
| *, | |
| per_frame_timeout: float = 1.5, | |
| read_retry_interval: float = 0.05, | |
| strict: bool = True, | |
| ): | |
| """ | |
| Load video frames and return them as a list of PIL images. | |
| Args: | |
| video_path: Path to the video file. | |
| sample_method: Sampling method, 'uniform' or 'random'. | |
| num_frames: Number of frames to sample from the video. | |
| per_frame_timeout: Max seconds to block while seeking/reading a frame. | |
| read_retry_interval: Delay between read retries while waiting for a frame. | |
| strict: When False, return None on timeout/seek failure instead of raising. | |
| Returns: | |
| List[Image.Image] | None (if strict=False and a timeout/seek failure occurs) | |
| """ | |
| with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull): | |
| cap = cv2.VideoCapture(video_path) | |
| if not cap.isOpened(): | |
| raise IOError(f"Could not open video file {video_path}") | |
| if per_frame_timeout <= 0: | |
| per_frame_timeout = 0.1 | |
| if read_retry_interval <= 0: | |
| read_retry_interval = 0.01 | |
| def _read_frame_with_timeout(frame_index: Optional[int] = None): | |
| deadline = time.monotonic() + per_frame_timeout | |
| attempts = 0 | |
| while True: | |
| if frame_index is not None: | |
| cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index)) | |
| ret, frame = cap.read() | |
| if ret and frame is not None: | |
| return frame | |
| attempts += 1 | |
| if time.monotonic() >= deadline: | |
| return None | |
| time.sleep(min(read_retry_interval, max(deadline - time.monotonic(), 0.0))) | |
| try: | |
| frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| except Exception: | |
| frame_count = -1 | |
| if frame_count is None or frame_count <= 0: | |
| # Fallback: attempt to read sequentially but stop early on failure | |
| frames = [] | |
| try: | |
| while len(frames) < num_frames: | |
| frame = _read_frame_with_timeout() | |
| if frame is None: | |
| break | |
| frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))) | |
| finally: | |
| cap.release() | |
| if len(frames) < num_frames: | |
| msg = f"Video {video_path} has insufficient frames ({len(frames)})." | |
| if strict: | |
| raise ValueError(msg) | |
| logger.warning("%s Skipping sample.", msg) | |
| return None | |
| selected = frames | |
| else: | |
| if frame_count < num_frames: | |
| cap.release() | |
| msg = f"Video {video_path} has insufficient frames ({frame_count})." | |
| if strict: | |
| raise ValueError(msg) | |
| logger.warning("%s Skipping sample.", msg) | |
| return None | |
| if sample_method == 'uniform': | |
| indices = np.linspace(0, frame_count - 1, num_frames).astype(int) | |
| elif sample_method == 'random': | |
| indices = np.sort(np.random.choice(frame_count, num_frames, replace=False)) | |
| else: | |
| cap.release() | |
| raise ValueError(f"Sampling method {sample_method} not supported.") | |
| selected = [] | |
| try: | |
| for idx in indices: | |
| frame = _read_frame_with_timeout(idx) | |
| if frame is None: | |
| msg = ( | |
| f"Timed out ({per_frame_timeout:.2f}s) seeking frame {idx} in {video_path}" | |
| ) | |
| if strict: | |
| raise TimeoutError(msg) | |
| logger.warning("%s. Skipping sample.", msg) | |
| return None | |
| selected.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))) | |
| finally: | |
| cap.release() | |
| sampled_frames = [] | |
| for frame in selected: | |
| if transform: | |
| frame = transform(frame, resolution=resolution) | |
| sampled_frames.append(frame) | |
| return sampled_frames | |
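| def _example_load_video_mp4() -> None: | |
|     """Hedged usage sketch for load_video_mp4(); '/tmp/clip.mp4' is a placeholder path.""" | |
|     # With strict=False the loader returns None instead of raising on per-frame timeouts | |
|     # or on clips with too few frames, so callers can simply skip the sample. | |
|     frames = load_video_mp4( | |
|         "/tmp/clip.mp4", | |
|         sample_method="uniform", | |
|         num_frames=8, | |
|         resolution=256, | |
|         transform=None,  # None -> PIL frames; a callable is invoked as transform(frame, resolution=resolution) | |
|         strict=False, | |
|     ) | |
|     if frames is None: | |
|         logger.warning("Example clip could not be decoded; skipping.") | |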
| class VideoSpeechDataset(Dataset): | |
| """Loads paired video clips and speech audio paths or pre-tokenized speech.""" | |
| def __init__( | |
| self, | |
| *, | |
| transform=None, | |
| resolution: int = 256, | |
| num_frames: int = 8, | |
| video_root: str = "/home/work/AIDAS/data/video/openvid1m/video/video", | |
| audio_root: str = "/home/work/AIDAS/data/video-speech", | |
| speech_dir_name: str = "openvid-speech-trunc", | |
| index_path: str = "/home/work/AIDAS/data/video-speech/openvid-speech.csv", | |
| sample_method: str = "uniform", | |
| precomputed_tokens_root: Optional[str] = None, | |
| ) -> None: | |
| self.transform = transform | |
| self.resolution = resolution | |
| self.num_frames = num_frames | |
| self.sample_method = sample_method or "uniform" | |
| if self.sample_method not in {"uniform", "random"}: | |
| logger.warning("Unknown sample_method '%s', defaulting to 'uniform'", self.sample_method) | |
| self.sample_method = "uniform" | |
| self.video_root = Path(video_root).expanduser().resolve() | |
| audio_base = Path(audio_root).expanduser() | |
| if speech_dir_name: | |
| audio_base = audio_base / speech_dir_name | |
| self.audio_root = audio_base.resolve() | |
| self.index_path = Path(index_path).expanduser().resolve() | |
| if not self.index_path.exists(): | |
| raise FileNotFoundError(f"VideoSpeechDataset index not found: {self.index_path}") | |
| self.precomputed_tokens_root = ( | |
| Path(precomputed_tokens_root).expanduser().resolve() | |
| if precomputed_tokens_root | |
| else None | |
| ) | |
| if self.precomputed_tokens_root is not None and not self.precomputed_tokens_root.exists(): | |
| logger.warning( | |
| "Precomputed speech token root %s missing; falling back to raw audio paths.", | |
| self.precomputed_tokens_root, | |
| ) | |
| self.precomputed_tokens_root = None | |
| self._samples: list[tuple[Path, Path]] = [] | |
| self._token_cache: Dict[str, torch.Tensor] = {} | |
| self._token_cache_limit = 4096 | |
| self._load_index() | |
| if not self._samples: | |
| raise RuntimeError(f"VideoSpeechDataset found no valid samples in {self.index_path}") | |
| def _load_index(self) -> None: | |
| missing = 0 | |
| with self.index_path.open("r", newline="") as csvfile: | |
| reader = csv.reader(csvfile) | |
| for row in reader: | |
| if not row: | |
| continue | |
| base = row[0].strip() | |
| if not base: | |
| continue | |
| if base.lower().endswith(".wav"): | |
| base = base[:-4] | |
| video_path = self.video_root / f"{base}.mp4" | |
| audio_path = self.audio_root / f"{base}.wav" | |
| if not video_path.is_file() or not audio_path.is_file(): | |
| missing += 1 | |
| continue | |
| self._samples.append((video_path, audio_path)) | |
| if missing: | |
| logger.info( | |
| "VideoSpeechDataset skipped %d entries missing media (index=%s)", | |
| missing, | |
| self.index_path, | |
| ) | |
| def __len__(self) -> int: | |
| return len(self._samples) | |
| def _transform_frame(self, image: Image.Image, resolution: int) -> torch.Tensor: | |
| if self.transform is None: | |
| return utils_image_transform(image, resolution) | |
| try: | |
| return self.transform(image, resolution=resolution) | |
| except TypeError: | |
| return self.transform(image) | |
| def _resolve_token_path(self, audio_path: Path) -> Optional[Path]: | |
| if self.precomputed_tokens_root is None: | |
| return None | |
| digest = hashlib.sha1(os.path.abspath(str(audio_path)).encode("utf-8")).hexdigest() | |
| return self.precomputed_tokens_root / digest[:2] / digest[2:4] / f"{digest}.pt" | |
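|     # Worked example (hedged): for an audio file whose absolute path hashes to a sha1 digest | |
|     # starting with "ab12cd...", _resolve_token_path() expects the precomputed tokens at | |
|     #   {precomputed_tokens_root}/ab/12/ab12cd....pt | |
|     # i.e. two levels of 2-hex-character shard directories followed by "<digest>.pt". | |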
| def _get_precomputed_tokens(self, audio_path: Path) -> Optional[torch.Tensor]: | |
| cache_key = os.path.abspath(str(audio_path)) | |
| cached = self._token_cache.get(cache_key) | |
| if cached is not None: | |
| return cached.clone() | |
| token_path = self._resolve_token_path(audio_path) | |
| if token_path is None or not token_path.exists(): | |
| return None | |
| try: | |
| tokens = torch.load(token_path, map_location="cpu") | |
| except Exception as exc: | |
| logger.warning("Failed to load precomputed speech tokens %s: %s", token_path, exc) | |
| return None | |
| if not isinstance(tokens, torch.Tensor): | |
| return None | |
| tokens = tokens.to(dtype=torch.long, copy=False) | |
| if len(self._token_cache) < self._token_cache_limit: | |
| self._token_cache[cache_key] = tokens | |
| return tokens.clone() | |
| def _prepare_speech_entry(self, audio_path: Path): | |
| tokens = self._get_precomputed_tokens(audio_path) | |
| if tokens is not None: | |
| return tokens | |
| return str(audio_path) | |
| def __getitem__(self, idx: int) -> Dict[str, Any]: | |
| video_path, audio_path = self._samples[idx] | |
| frames = load_video_mp4( | |
| str(video_path), | |
| sample_method=self.sample_method, | |
| num_frames=self.num_frames, | |
| resolution=self.resolution, | |
| transform=self._transform_frame, | |
| ) | |
| speech_entry = self._prepare_speech_entry(audio_path) | |
| return { | |
| "video": frames, | |
| "speech": speech_entry, | |
| } | |
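| def _example_video_speech_dataset() -> None: | |
|     """Hedged usage sketch for VideoSpeechDataset; every path below is a placeholder.""" | |
|     ds = VideoSpeechDataset( | |
|         transform=None,                       # per-frame fallback is utils_image_transform | |
|         resolution=256, | |
|         num_frames=8, | |
|         video_root="/tmp/openvid/videos",     # hypothetical | |
|         audio_root="/tmp/openvid/speech",     # hypothetical | |
|         speech_dir_name="",                   # empty -> use audio_root directly | |
|         index_path="/tmp/openvid/index.csv",  # hypothetical CSV of clip basenames | |
|         precomputed_tokens_root=None,         # "speech" entries are then raw wav paths | |
|     ) | |
|     sample = ds[0] | |
|     # sample["video"]: list of (C, H, W) tensors; sample["speech"]: wav path or LongTensor of tokens | |
|     print(len(sample["video"]), type(sample["speech"])) | |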
| class TextImageInterleavedDataset: | |
| """ | |
| HF-backed dataset that yields rows of: | |
| { | |
| "image_paths": [str, ...], # absolute paths (no decoding) | |
| "user_text": str, | |
| "assistant_text": str, | |
| } | |
| """ | |
| def __init__( | |
| self, | |
| *, | |
| configs: Union[str, Sequence[str], None] = None, # default: all configs | |
| split: str = "train", | |
| data_root: str = "/home/work/AIDAS/data/TIGER-Lab/Mantis-Instruct", | |
| max_images: Optional[int] = None, | |
| filter_empty: bool = True, | |
| resolution: int = 256, | |
| # sampling controls | |
| per_config_fraction: float = 1/7, # ← sample 1/7 PER CONFIG | |
| sample_seed: int = 42, | |
| # kept for compatibility, not used in this 1/7-per-config version | |
| max_samples: Optional[int] = 1_000_000, | |
| local_data_root: Optional[str] = None, | |
| local_data_files: Optional[Dict[str, Any]] = None, | |
| local_files_only: bool = False, | |
| ): | |
| self.data_root = data_root | |
| self.split = self._normalize_split(split) | |
| self.max_images = max_images | |
| self.filter_empty = filter_empty | |
| self.resolution = resolution | |
| self.local_data_root = local_data_root | |
| self.local_data_files = local_data_files or {} | |
| self._download_config = DownloadConfig(local_files_only=True) if local_files_only else None | |
| # cache transforms | |
| self._tfm_crop = transforms.Compose([ | |
| transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC), | |
| transforms.CenterCrop((resolution, resolution)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]), | |
| ]) | |
| self._tfm_squash = transforms.Compose([ | |
| transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]), | |
| ]) | |
| # ---- resolve configs ---- | |
| if configs is None or configs == "all": | |
| cfgs = self._resolve_configs_from_local() | |
| if not cfgs: | |
| cfgs = sorted(get_dataset_config_names("TIGER-Lab/Mantis-Instruct")) | |
| elif isinstance(configs, str): | |
| cfgs = [configs] | |
| else: | |
| cfgs = list(configs) | |
| self.configs = cfgs | |
| rng = np.random.default_rng(sample_seed) | |
| per_cfg_ds: List[HFDataset] = [] | |
| acc = Accelerator() | |
| for cfg in cfgs: | |
| base_ds = self._load_base_dataset(cfg, acc) | |
| if base_ds is None: | |
| continue | |
| # --- SAMPLE 1/7 OF BASE ROWS *PER CONFIG* (before any map/expansion) --- | |
| n = len(base_ds) | |
| if n == 0: | |
| continue | |
| k = max(1, int(np.floor(n * per_config_fraction))) | |
| # reproducible uniform sample without replacement | |
| sel_idx = rng.choice(n, size=k, replace=False) | |
| base_ds = base_ds.select(list(sel_idx)) | |
| # locate image dir for this (cfg, split) | |
| img_dir = self._resolve_img_dir(cfg, self.split) | |
| if img_dir is None: | |
| raise FileNotFoundError(f"No image dir for config='{cfg}', split='{self.split}'") | |
| # (1) attach constants | |
| def add_const_cols(batch): | |
| m = len(next(iter(batch.values()))) if batch else 0 | |
| return {"config": [cfg]*m, "img_dir": [img_dir]*m} | |
| ds = base_ds.map(add_const_cols, batched=True) | |
| # (2) normalize image column → absolute string paths | |
| image_key = self._guess_image_key(ds.column_names) | |
| def make_abs_paths(batch): | |
| bases = batch["img_dir"] # list[str] per row | |
| rels = batch[image_key] # per-row: list[dict]|dict|list[str]|str|None | |
| def dict_to_rel(d: Dict[str, Any]) -> Optional[str]: | |
| # typical HF Image: {"path": "...", "bytes": ...} | |
| for k in ("path", "file_name", "filepath", "image_path", "name"): | |
| v = d.get(k) | |
| if isinstance(v, str) and v: | |
| return v | |
| # nested | |
| img = d.get("image") | |
| if isinstance(img, dict): | |
| v = img.get("path") | |
| if isinstance(v, str) and v: | |
| return v | |
| return None | |
| out_paths = [] | |
| for base, r in zip(bases, rels): | |
| # normalize r → list[str] | |
| if r is None: | |
| row = [] | |
| elif isinstance(r, str): | |
| row = [r] | |
| elif isinstance(r, dict): | |
| s = dict_to_rel(r) | |
| row = [s] if s else [] | |
| elif isinstance(r, list): | |
| tmp = [] | |
| for x in r: | |
| if isinstance(x, str): | |
| tmp.append(x) | |
| elif isinstance(x, dict): | |
| s = dict_to_rel(x) | |
| if s: | |
| tmp.append(s) | |
| row = tmp | |
| else: | |
| row = [] | |
| # join to absolute (keep absolute if already) | |
| abs_paths = [p if os.path.isabs(p) else os.path.join(base, p) for p in row if isinstance(p, str)] | |
| # cap if requested | |
| if self.max_images is not None and len(abs_paths) > self.max_images: | |
| abs_paths = abs_paths[: self.max_images] | |
| out_paths.append(abs_paths) | |
| return {"image_paths": out_paths} | |
| ds = ds.map(make_abs_paths, batched=True) | |
| # (3) expand conversation: one row per (user → assistant) turn | |
| conv_key = "conversation" | |
| def expand_turns(batch): | |
| image_paths_list = batch["image_paths"] | |
| conversations = batch.get(conv_key, [[]] * len(image_paths_list)) | |
| out_img_paths, out_user, out_assistant = [], [], [] | |
| for img_paths, conv in zip(image_paths_list, conversations): | |
| conv = conv or [] | |
| # walk adjacent pairs | |
| i = 0 | |
| while i < len(conv) - 1: | |
| a, b = conv[i], conv[i + 1] | |
| if (isinstance(a, dict) and isinstance(b, dict) | |
| and a.get("role") == "user" and b.get("role") == "assistant"): | |
| user_text = (a.get("content") or "").strip() | |
| assistant_text = (b.get("content") or "").strip() | |
| if (not self.filter_empty) or assistant_text: | |
| out_img_paths.append(img_paths) | |
| out_user.append(user_text) | |
| out_assistant.append(assistant_text) | |
| i += 2 | |
| else: | |
| i += 1 | |
| return { | |
| "image_paths": out_img_paths, | |
| "user_text": out_user, | |
| "assistant_text": out_assistant, | |
| } | |
| ds = ds.map(expand_turns, batched=True, remove_columns=ds.column_names) | |
| if self.filter_empty: | |
| ds = ds.filter(lambda e: bool(e["assistant_text"])) | |
| per_cfg_ds.append(ds) | |
| if not per_cfg_ds: | |
| raise ValueError("Empty dataset after per-config sampling and preprocessing.") | |
| self.dataset = concatenate_datasets(per_cfg_ds) if len(per_cfg_ds) > 1 else per_cfg_ds[0] | |
| self.dataset = self.dataset.with_format("python") | |
| print(f"[HF Dataset] per-config 1/7 sampled; configs={self.configs}, split='{self.split}', rows={len(self.dataset)}") | |
| # ---- public API ---- | |
| def __len__(self): | |
| return len(self.dataset) | |
| def __getitem__(self, idx): | |
| start_idx = idx | |
| attempts = 0 | |
| max_attempts = 10 | |
| while attempts < max_attempts: | |
| ex = self.dataset[idx] | |
| text = ( | |
| "<|start_header_id|>user<|end_header_id|>\n" | |
| f"{ex['user_text']}\n" | |
| "<eot_id><|start_header_id|>assistant<|end_header_id|>\n" | |
| f"{ex['assistant_text']}" | |
| ) | |
| paths = ex["image_paths"] | |
| imgs: list[torch.Tensor] = [] | |
| for path in paths: | |
| img = self._load_and_transform_one(path) | |
| if img is not None: | |
| imgs.append(img) | |
| if imgs: | |
| return { | |
| "images": imgs, | |
| "text": text, | |
| } | |
| attempts += 1 | |
| idx = (idx + 1) % len(self.dataset) | |
| if idx == start_idx: | |
| break | |
| raise RuntimeError("TextImageInterleavedDataset: no valid images found after retries.") | |
| # ---- helpers ---- | |
| @staticmethod | |
| def _normalize_split(split: str) -> str: | |
| s = split.lower() | |
| return {"val": "validation", "dev": "validation"}.get(s, s) | |
| def _resolve_img_dir(self, cfg: str, split: str) -> Optional[str]: | |
| # Typical local layout: | |
| # {data_root}/{cfg}/{split}_images | |
| # {data_root}/{cfg}/images | |
| cand1 = os.path.join(self.data_root, cfg, f"{split}_images") | |
| cand2 = os.path.join(self.data_root, cfg, "images") | |
| for c in (cand1, cand2): | |
| if os.path.isdir(c): | |
| return c | |
| return None | |
| def _load_and_transform_one(self, path: str): | |
| try: | |
| with Image.open(path) as im: | |
| im = im.convert("RGB") | |
| except FileNotFoundError: | |
| return None | |
| except Exception: | |
| return None | |
| return self._tfm_crop(im) | |
| @staticmethod | |
| def _guess_image_key(cols: List[str]) -> str: | |
| for k in ("images", "image_paths", "imgs", "paths", "image"): | |
| if k in cols: | |
| return k | |
| raise KeyError(f"Cannot find image column among {cols}") | |
| def _resolve_configs_from_local(self) -> List[str]: | |
| cfgs: List[str] = [] | |
| if self.local_data_root: | |
| root = Path(self.local_data_root) | |
| if root.is_dir(): | |
| for entry in sorted(root.iterdir()): | |
| if not entry.is_dir(): | |
| continue | |
| if self._has_split_data(entry): | |
| cfgs.append(entry.name) | |
| if not cfgs and self.local_data_files: | |
| cfgs = [k for k in sorted(self.local_data_files.keys()) if k != "default"] | |
| return cfgs | |
| def _has_split_data(self, cfg_path: Path) -> bool: | |
| split_dir = cfg_path / self.split | |
| if split_dir.is_dir(): | |
| return True | |
| alt_dirs = [ | |
| cfg_path / f"{self.split}.dataset", | |
| cfg_path / f"{self.split}.arrow", | |
| ] | |
| for candidate in alt_dirs: | |
| if candidate.is_dir(): | |
| return True | |
| patterns = [ | |
| cfg_path / self.split / "*.arrow", | |
| cfg_path / self.split / "*.parquet", | |
| cfg_path / f"{self.split}/*.arrow", | |
| cfg_path / f"{self.split}/*.parquet", | |
| cfg_path / f"{self.split}*.arrow", | |
| cfg_path / f"{self.split}*.parquet", | |
| ] | |
| for pattern in patterns: | |
| if glob(str(pattern)): | |
| return True | |
| return False | |
| def _load_base_dataset(self, cfg: str, acc: Accelerator) -> Optional[HFDataset]: | |
| base_ds: Optional[HFDataset] = None | |
| if self.local_data_root is not None: | |
| # print(self.local_data_root) | |
| base_ds = self._load_from_local_root(cfg) | |
| if base_ds is None and self.local_data_files: | |
| base_ds = self._load_from_local_data_files(cfg) | |
| if base_ds is not None: | |
| return base_ds | |
| kwargs = {} | |
| if self._download_config is not None: | |
| kwargs["download_config"] = self._download_config | |
| if acc.num_processes > 1: | |
| acc.wait_for_everyone() | |
| try: | |
| base_ds = load_dataset( | |
| "TIGER-Lab/Mantis-Instruct", | |
| cfg, | |
| split=self.split, | |
| **kwargs, | |
| ) | |
| except Exception as exc: | |
| if self._download_config is not None: | |
| raise RuntimeError( | |
| f"Failed to load local dataset for config='{cfg}'. " | |
| "Ensure that the dataset is cached or provide 'local_data_root'." | |
| ) from exc | |
| raise | |
| finally: | |
| if acc.num_processes > 1: | |
| acc.wait_for_everyone() | |
| return base_ds | |
| def _load_from_local_root(self, cfg: str) -> Optional[HFDataset]: | |
| cfg_root = os.path.join(self.local_data_root, cfg) | |
| if not os.path.exists(cfg_root): | |
| return None | |
| candidates = [ | |
| cfg_root, | |
| os.path.join(cfg_root, self.split), | |
| os.path.join(cfg_root, f"{self.split}.dataset"), | |
| ] | |
| for path in candidates: | |
| if not os.path.isdir(path): | |
| continue | |
| try: | |
| loaded = load_from_disk(path) | |
| if isinstance(loaded, DatasetDict): | |
| if self.split in loaded: | |
| return loaded[self.split] | |
| continue | |
| return loaded | |
| except Exception: | |
| continue | |
| patterns = [ | |
| os.path.join(cfg_root, f"{self.split}.parquet"), | |
| os.path.join(cfg_root, f"{self.split}/*.parquet"), | |
| os.path.join(cfg_root, f"{self.split}_*.parquet"), | |
| os.path.join(cfg_root, f"{self.split}.json"), | |
| os.path.join(cfg_root, f"{self.split}.jsonl"), | |
| os.path.join(cfg_root, f"{self.split}/*.jsonl"), | |
| os.path.join(cfg_root, f"{self.split}.arrow"), | |
| os.path.join(cfg_root, f"{self.split}/*.arrow"), | |
| ] | |
| for pattern in patterns: | |
| files = sorted(glob(pattern)) | |
| if files: | |
| return self._load_from_files(files) | |
| return None | |
| def _load_from_local_data_files(self, cfg: str) -> Optional[HFDataset]: | |
| spec = self.local_data_files.get(cfg) or self.local_data_files.get("default") | |
| if spec is None: | |
| return None | |
| if isinstance(spec, str): | |
| entries = [spec] | |
| loader = None | |
| elif isinstance(spec, dict): | |
| loader = spec.get("type") or spec.get("loader") or spec.get("format") | |
| files = spec.get(self.split) or spec.get("files") | |
| if files is None: | |
| return None | |
| entries = files if isinstance(files, list) else [files] | |
| else: | |
| entries = list(spec) | |
| loader = None | |
| resolved_files: list[str] = [] | |
| for entry in entries: | |
| if not entry: | |
| continue | |
| matched = sorted(glob(entry)) | |
| if matched: | |
| resolved_files.extend(matched) | |
| elif os.path.exists(entry): | |
| resolved_files.append(entry) | |
| if not resolved_files: | |
| return None | |
| return self._load_from_files(resolved_files, loader_hint=loader) | |
| def _load_from_files(self, files: list[str], loader_hint: Optional[str] = None) -> Optional[HFDataset]: | |
| if not files: | |
| return None | |
| ext = Path(files[0]).suffix.lower() | |
| loader = loader_hint | |
| if loader is None: | |
| if ext in (".parquet",): | |
| loader = "parquet" | |
| elif ext in (".json", ".jsonl"): | |
| loader = "json" | |
| elif ext in (".arrow", ".feather"): | |
| loader = "arrow" | |
| if loader == "parquet": | |
| return load_dataset("parquet", data_files={self.split: files}, split=self.split) | |
| if loader in {"json", "jsonl"}: | |
| return load_dataset("json", data_files={self.split: files}, split=self.split) | |
| if loader == "arrow": | |
| loaded_parts = [HFDataset.from_file(path) for path in files]  # avoid shadowing the imported `datasets` module | |
| return concatenate_datasets(loaded_parts) if len(loaded_parts) > 1 else loaded_parts[0] | |
| return None | |
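| def _example_text_image_interleaved_dataset() -> None: | |
|     """Hedged usage sketch for TextImageInterleavedDataset; the config name and data_root are placeholders.""" | |
|     ds = TextImageInterleavedDataset( | |
|         configs=["docvqa"],                # hypothetical single Mantis-Instruct config | |
|         split="train", | |
|         data_root="/tmp/Mantis-Instruct",  # hypothetical local image root | |
|         resolution=256, | |
|         per_config_fraction=1 / 7, | |
|     ) | |
|     ex = ds[0] | |
|     # ex["images"]: list of (3, 256, 256) tensors; ex["text"]: chat-formatted user/assistant string | |
|     print(len(ex["images"]), ex["text"][:80]) | |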
| class HFInstructionTextDataset(Dataset): | |
| """Mixed instruction-following text dataset sourced from multiple HF corpora.""" | |
| HF_SOURCES = ( | |
| { | |
| "name": "openai/gsm8k", | |
| "config": "main", | |
| "split": "train", | |
| "user_key": "question", | |
| "assistant_key": "answer", | |
| }, | |
| { | |
| "name": "qwedsacf/grade-school-math-instructions", | |
| "config": None, | |
| "split": "train", | |
| "user_key": "INSTRUCTION", | |
| "assistant_key": "RESPONSE", | |
| }, | |
| { | |
| "name": "alespalla/chatbot_instruction_prompts", | |
| "config": None, | |
| "split": "train", | |
| "user_key": "prompt", | |
| "assistant_key": "response", | |
| }, | |
| { | |
| "name": "TIGER-Lab/MathInstruct", | |
| "config": None, | |
| "split": "train", | |
| "user_key": "instruction", | |
| "assistant_key": "output", | |
| }, | |
| ) | |
| def __init__( | |
| self, | |
| *, | |
| split: str = "train", | |
| max_samples_per_source: Optional[int] = None, | |
| max_total_samples: Optional[int] = None, | |
| seed: int = 42, | |
| ) -> None: | |
| self.split = split | |
| self.seed = seed | |
| self.samples: List[str] = [] | |
| rng = random.Random(seed) | |
| for source in self.HF_SOURCES: | |
| desired_split = source.get("split", split) | |
| try: | |
| dataset_name = source["name"] | |
| dataset_config = source.get("config") | |
| if dataset_config is not None: | |
| hf_ds = load_dataset(dataset_name, dataset_config, split=desired_split) | |
| else: | |
| hf_ds = load_dataset(dataset_name, split=desired_split) | |
| except Exception as exc: | |
| print(f"[HFInstructionTextDataset] Failed to load {source['name']}: {exc}") | |
| continue | |
| if max_samples_per_source is not None and len(hf_ds) > max_samples_per_source: | |
| hf_ds = hf_ds.shuffle(seed=seed).select(range(max_samples_per_source)) | |
| user_key = source["user_key"] | |
| assistant_key = source["assistant_key"] | |
| for example in hf_ds: | |
| user_raw = str(example.get(user_key, "")).strip() | |
| assistant_raw = str(example.get(assistant_key, "")).strip() | |
| if not user_raw or not assistant_raw: | |
| continue | |
| formatted = self._format_dialogue(user_raw, assistant_raw) | |
| if formatted: | |
| self.samples.append(formatted) | |
| if not self.samples: | |
| raise ValueError("HFInstructionTextDataset loaded zero valid samples.") | |
| rng.shuffle(self.samples) | |
| if max_total_samples is not None: | |
| self.samples = self.samples[: max_total_samples] | |
| @staticmethod | |
| def _format_dialogue(user_text: str, assistant_text: str) -> str: | |
| return ( | |
| "<|start_header_id|>user<|end_header_id|>\n" | |
| f"{user_text}\n" | |
| "<|eot_id><|start_header_id|>assistant<|end_header_id|>\n" | |
| f"{assistant_text}" | |
| ) | |
| def __len__(self) -> int: | |
| return len(self.samples) | |
| def __getitem__(self, index: int) -> Dict[str, str]: | |
| return {"input_ids": self.samples[index]} | |
| @staticmethod | |
| def collate_fn(batch: List[Dict[str, str]]) -> Dict[str, List[str]]: | |
| return {"input_ids": [example["input_ids"] for example in batch]} | |
| class TextToImage2MDataset(Dataset): | |
| """Loads jackyhate/text-to-image-2M for text-to-image fine-tuning.""" | |
| def __init__( | |
| self, | |
| split: str = "train", | |
| resolution: int = 256, | |
| dataset_name: str = "jackyhate/text-to-image-2M", | |
| cache_dir: str | None = None, | |
| local_files_only: bool = False, | |
| ) -> None: | |
| self.resolution = resolution | |
| self.dataset_name = dataset_name | |
| self.cache_dir = cache_dir | |
| self.local_files_only = local_files_only | |
| download_cfg = None | |
| if local_files_only: | |
| download_cfg = DownloadConfig(local_files_only=True) | |
| self._dataset = load_dataset( | |
| dataset_name, | |
| split=split, | |
| cache_dir=cache_dir, | |
| download_config=download_cfg, | |
| ) | |
| def __len__(self) -> int: | |
| return len(self._dataset) | |
| def __getitem__(self, idx: int) -> Dict[str, Any]: | |
| sample = self._dataset[idx] | |
| prompt = None | |
| json_meta = sample.get("json") | |
| if isinstance(json_meta, dict): | |
| prompt = json_meta.get("prompt") | |
| if prompt is None: | |
| prompt = sample.get("prompt", "") | |
| image_field = sample.get("jpg") or sample.get("image") | |
| if image_field is None: | |
| raise KeyError("Expected image field 'jpg' in text-to-image-2M sample") | |
| if isinstance(image_field, Image.Image): | |
| image = image_field.convert("RGB") | |
| elif isinstance(image_field, bytes): | |
| image = Image.open(BytesIO(image_field)).convert("RGB") | |
| else: | |
| image = Image.fromarray(np.array(image_field)).convert("RGB") | |
| image_tensor = utils_image_transform(image, self.resolution) | |
| return { | |
| "input_prompt": prompt, | |
| "output_prompt": None, | |
| "edit_prompt": None, | |
| "inverse_prompt": None, | |
| "input_image": image_tensor, | |
| "output_image": image_tensor, | |
| } | |
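| def _example_text_to_image_2m() -> None: | |
|     """Hedged usage sketch for TextToImage2MDataset; assumes the HF dataset is already cached locally.""" | |
|     ds = TextToImage2MDataset(split="train", resolution=256, local_files_only=True) | |
|     item = ds[0] | |
|     # t2i samples fill the x2i schema by reusing the same tensor for input_image and output_image. | |
|     print(item["input_prompt"], item["input_image"].shape) | |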
| class HQEditX2IDataset(Dataset): | |
| def __init__( | |
| self, | |
| split: str = "train", | |
| resolution: int = 256, | |
| dataset_name: str = "UCSC-VLAA/HQ-Edit", | |
| cache_dir: str = "/home/work/AIDAS/huggingface/datasets", | |
| ): | |
| self.resolution = resolution | |
| self.cache_dir = cache_dir # retained for API compatibility | |
| self._dataset = load_dataset(dataset_name, split=split) | |
| def __len__(self) -> int: | |
| return len(self._dataset) | |
| def __getitem__(self, idx: int) -> Dict[str, Any]: | |
| sample = self._dataset[idx] | |
| input_tensor = utils_image_transform(sample['input_image'].convert("RGB"), self.resolution) | |
| output_tensor = utils_image_transform(sample['output_image'].convert("RGB"), self.resolution) | |
| return { | |
| "input_prompt": sample["input"], | |
| "output_prompt": sample["output"], | |
| "edit_prompt": sample["edit"], | |
| "inverse_prompt": sample["inverse_edit"], | |
| "input_image": input_tensor, | |
| "output_image": output_tensor, | |
| } | |
| class CombinedX2IDataset(Dataset): | |
| """Round-robin combination of multiple x2i-style datasets.""" | |
| def __init__(self, datasets: Sequence[Dataset]): | |
| if not datasets: | |
| raise ValueError("CombinedX2IDataset requires at least one dataset.") | |
| self.datasets = list(datasets) | |
| self.lengths = [len(ds) for ds in self.datasets] | |
| if any(length == 0 for length in self.lengths): | |
| raise ValueError("Underlying x2i dataset has zero length.") | |
| self.cumulative = list(itertools.accumulate(self.lengths)) | |
| self.total_length = self.cumulative[-1] | |
| def __len__(self) -> int: | |
| return self.total_length | |
| def __getitem__(self, idx: int) -> Dict[str, Any]: | |
| if idx < 0 or idx >= self.total_length: | |
| raise IndexError(f"Index {idx} out of bounds for CombinedX2IDataset of length {self.total_length}") | |
| dataset_idx = bisect.bisect_right(self.cumulative, idx) | |
| prev = self.cumulative[dataset_idx - 1] if dataset_idx > 0 else 0 | |
| local_idx = idx - prev | |
| return self.datasets[dataset_idx][local_idx] | |
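| def _example_combined_x2i() -> None: | |
|     """Hedged sketch: concatenating two x2i-style datasets; indices run through the first dataset, then the second.""" | |
|     t2i = TextToImage2MDataset(split="train", resolution=256) | |
|     edit = HQEditX2IDataset(split="train", resolution=256) | |
|     combined = CombinedX2IDataset([t2i, edit]) | |
|     # Global index len(t2i) maps to the first HQ-Edit sample via bisect over the cumulative lengths. | |
|     sample = combined[len(t2i)] | |
|     print(len(combined), sample["edit_prompt"]) | |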
| class OpenImageI2IDataset(Dataset): | |
| """ | |
| Image-to-image dataset built from local Open Images edit JSONL files. | |
| Supports three JSONL schemas: | |
| * SFT-style single-turn edits (text + output_image + local_input_image) | |
| * Preference data; only positive edits (output_image) are used by default | |
| * Multi-turn edits, which are flattened into single-turn pairs | |
| """ | |
| def __init__( | |
| self, | |
| resolution: int = 256, | |
| image_root: str | None = None, | |
| sft_jsonl: Union[str, Sequence[str], None] = None, | |
| pref_jsonl: Union[str, Sequence[str], None] = None, | |
| multi_turn_jsonl: Union[str, Sequence[str], None] = None, | |
| prefer_summarized_text: bool = True, | |
| pref_positive_only: bool = True, | |
| skip_missing: bool = True, | |
| max_samples_per_source: int | None = None, | |
| max_total_samples: int | None = None, | |
| seed: int | None = None, | |
| ) -> None: | |
| self.resolution = resolution | |
| self.image_root = image_root | |
| self.prefer_summarized_text = prefer_summarized_text | |
| self.pref_positive_only = pref_positive_only | |
| self.skip_missing = skip_missing | |
| self._rng = random.Random(seed if seed is not None else 0) | |
| self._per_source_limit = self._coerce_positive_int(max_samples_per_source) | |
| self._total_limit = self._coerce_positive_int(max_total_samples) | |
| self._samples: list[dict[str, str]] = [] | |
| self._stats: dict[str, int] = { | |
| "sft": 0, | |
| "pref": 0, | |
| "multi_turn": 0, | |
| "missing_paths": 0, | |
| "invalid_records": 0, | |
| } | |
| sft_paths = self._coerce_paths(sft_jsonl) | |
| pref_paths = self._coerce_paths(pref_jsonl) | |
| multi_turn_paths = self._coerce_paths(multi_turn_jsonl) | |
| for path in sft_paths: | |
| self._samples.extend(self._load_single_turn_file(path, source_key="sft")) | |
| for path in pref_paths: | |
| if not self.pref_positive_only: | |
| logger.warning("OpenImageI2IDataset currently only supports positive preference pairs.") | |
| self._samples.extend(self._load_single_turn_file(path, source_key="pref")) | |
| for path in multi_turn_paths: | |
| self._samples.extend(self._load_multi_turn_file(path)) | |
| if self._total_limit is not None and len(self._samples) > self._total_limit: | |
| self._rng.shuffle(self._samples) | |
| self._samples = self._samples[: self._total_limit] | |
| if not self._samples: | |
| raise ValueError("OpenImageI2IDataset could not load any valid examples.") | |
| logger.info( | |
| "Loaded %d OpenImage i2i samples (sft=%d, pref=%d, multi_turn=%d, missing_paths=%d, invalid=%d).", | |
| len(self._samples), | |
| self._stats["sft"], | |
| self._stats["pref"], | |
| self._stats["multi_turn"], | |
| self._stats["missing_paths"], | |
| self._stats["invalid_records"], | |
| ) | |
| def __len__(self) -> int: | |
| return len(self._samples) | |
| def __getitem__(self, idx: int) -> Dict[str, Any]: | |
| record = self._samples[idx] | |
| input_image = Image.open(record["input_path"]).convert("RGB") | |
| target_image = Image.open(record["target_path"]).convert("RGB") | |
| input_tensor = utils_image_transform(input_image, self.resolution) | |
| target_tensor = utils_image_transform(target_image, self.resolution) | |
| prompt = record["prompt"] | |
| return { | |
| "input_prompt": prompt, | |
| "output_prompt": None, | |
| "edit_prompt": prompt, | |
| "inverse_prompt": None, | |
| "input_image": input_tensor, | |
| "output_image": target_tensor, | |
| } | |
| def _load_single_turn_file(self, path: str, *, source_key: str) -> list[dict[str, str]]: | |
| file_path = os.path.abspath(os.path.expanduser(path)) | |
| if not os.path.exists(file_path): | |
| logger.warning("OpenImageI2IDataset: JSONL file not found: %s", file_path) | |
| return [] | |
| base_dir = os.path.dirname(file_path) | |
| samples: list[dict[str, str]] = [] | |
| with open(file_path, "r", encoding="utf-8") as handle: | |
| for line in handle: | |
| line = line.strip() | |
| if not line: | |
| continue | |
| try: | |
| record = json.loads(line) | |
| except json.JSONDecodeError: | |
| self._stats["invalid_records"] += 1 | |
| continue | |
| prompt = self._select_prompt(record) | |
| input_path = self._resolve_path(record.get("local_input_image"), base_dir=base_dir) | |
| output_path = self._resolve_path(record.get("output_image"), base_dir=base_dir) | |
| sample = self._build_sample(prompt, input_path, output_path) | |
| if sample: | |
| samples.append(sample) | |
| if self._per_source_limit is not None and len(samples) > self._per_source_limit: | |
| self._rng.shuffle(samples) | |
| samples = samples[: self._per_source_limit] | |
| self._stats[source_key] += len(samples) | |
| return samples | |
| def _load_multi_turn_file(self, path: str) -> list[dict[str, str]]: | |
| file_path = os.path.abspath(os.path.expanduser(path)) | |
| if not os.path.exists(file_path): | |
| logger.warning("OpenImageI2IDataset: JSONL file not found: %s", file_path) | |
| return [] | |
| base_dir = os.path.dirname(file_path) | |
| samples: list[dict[str, str]] = [] | |
| with open(file_path, "r", encoding="utf-8") as handle: | |
| for line in handle: | |
| line = line.strip() | |
| if not line: | |
| continue | |
| try: | |
| record = json.loads(line) | |
| except json.JSONDecodeError: | |
| self._stats["invalid_records"] += 1 | |
| continue | |
| multi_samples = self._expand_multi_turn(record, base_dir=base_dir) | |
| if multi_samples: | |
| samples.extend(multi_samples) | |
| if self._per_source_limit is not None and len(samples) > self._per_source_limit: | |
| self._rng.shuffle(samples) | |
| samples = samples[: self._per_source_limit] | |
| self._stats["multi_turn"] += len(samples) | |
| return samples | |
| def _expand_multi_turn(self, record: dict, *, base_dir: str) -> list[dict[str, str]]: | |
| prompts = record.get("metadata_edit_turn_prompts") or [] | |
| files = record.get("files") or [] | |
| if not prompts or not files: | |
| self._stats["invalid_records"] += 1 | |
| return [] | |
| outputs: dict[int, str] = {} | |
| final_image: str | None = None | |
| for entry in files: | |
| file_id = entry.get("id") | |
| url = entry.get("url") | |
| if not file_id or not url: | |
| continue | |
| if file_id.startswith("edit_turn"): | |
| try: | |
| idx = int(file_id.replace("edit_turn", "").strip()) | |
| except ValueError: | |
| continue | |
| outputs[idx] = url | |
| elif file_id == "final_image": | |
| final_image = url | |
| current_input = self._resolve_path(record.get("local_input_image"), base_dir=base_dir) | |
| if not current_input: | |
| return [] | |
| samples: list[dict[str, str]] = [] | |
| for turn_idx, prompt in enumerate(prompts, start=1): | |
| target_rel = outputs.get(turn_idx) | |
| if target_rel is None: | |
| if turn_idx == len(prompts): | |
| target_rel = final_image | |
| else: | |
| break | |
| target_path = self._resolve_path(target_rel, base_dir=base_dir) | |
| if not target_path: | |
| break | |
| sample = self._build_sample(prompt, current_input, target_path) | |
| if not sample: | |
| break | |
| samples.append(sample) | |
| current_input = target_path | |
| return samples | |
| def _select_prompt(self, record: dict) -> str | None: | |
| if self.prefer_summarized_text and record.get("summarized_text"): | |
| return record.get("summarized_text") | |
| if record.get("text"): | |
| return record.get("text") | |
| return record.get("metadata_edit_turn_prompt") | |
| def _build_sample(self, prompt: str | None, input_path: str | None, target_path: str | None) -> dict[str, str] | None: | |
| if not prompt: | |
| self._stats["invalid_records"] += 1 | |
| return None | |
| if not input_path or not target_path: | |
| return None | |
| return { | |
| "prompt": str(prompt).strip(), | |
| "input_path": input_path, | |
| "target_path": target_path, | |
| } | |
| def _resolve_path(self, path: str | None, *, base_dir: str | None = None) -> str | None: | |
| if not path or path.startswith("http://") or path.startswith("https://"): | |
| return None | |
| candidates: list[str] = [] | |
| normalized = path.replace("\\", "/") | |
| if os.path.isabs(normalized): | |
| candidates.append(os.path.normpath(normalized)) | |
| else: | |
| if self.image_root: | |
| candidates.append(os.path.normpath(os.path.join(self.image_root, normalized))) | |
| if base_dir: | |
| candidates.append(os.path.normpath(os.path.join(base_dir, normalized))) | |
| for candidate in candidates: | |
| if not self.skip_missing or os.path.exists(candidate): | |
| return candidate | |
| self._stats["missing_paths"] += 1 | |
| return None | |
| def _coerce_paths(self, value: Union[str, Sequence[str], None]) -> list[str]: | |
| if value is None: | |
| return [] | |
| if isinstance(value, str): | |
| values = [value] | |
| else: | |
| values = [item for item in value if item] | |
| return [os.path.abspath(os.path.expanduser(path)) for path in values] | |
| @staticmethod | |
| def _coerce_positive_int(value: Any) -> int | None: | |
| if value is None: | |
| return None | |
| try: | |
| int_value = int(value) | |
| except (TypeError, ValueError): | |
| return None | |
| return int_value if int_value > 0 else None | |
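| def _example_open_image_i2i() -> None: | |
|     """Hedged usage sketch for OpenImageI2IDataset; all JSONL and image paths below are placeholders.""" | |
|     ds = OpenImageI2IDataset( | |
|         resolution=256, | |
|         image_root="/tmp/openimage-edit/images",             # hypothetical | |
|         sft_jsonl="/tmp/openimage-edit/sft.jsonl",           # hypothetical single-turn edits | |
|         multi_turn_jsonl="/tmp/openimage-edit/multi.jsonl",  # hypothetical multi-turn edits (flattened) | |
|         max_total_samples=10_000, | |
|         seed=0, | |
|     ) | |
|     item = ds[0] | |
|     # Edit pairs share the prompt between "input_prompt" and "edit_prompt". | |
|     print(item["edit_prompt"], item["input_image"].shape, item["output_image"].shape) | |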
| # import os, socket | |
| # from typing import Optional | |
| # def _dist_identity(): | |
| # """Return a dict with rank info from env/torch if available.""" | |
| # info = {} | |
| # # Env fallbacks for different launchers (torchrun/SLURM/MPI) | |
| # def _get(*keys) -> Optional[int]: | |
| # for k in keys: | |
| # v = os.environ.get(k) | |
| # if v is not None: | |
| # try: | |
| # return int(v) | |
| # except ValueError: | |
| # return None | |
| # return None | |
| # info["rank"] = _get("RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK") | |
| # info["local_rank"] = _get("LOCAL_RANK", "SLURM_LOCALID", "MPI_LOCALRANKID") | |
| # info["node_rank"] = _get("NODE_RANK", "SLURM_NODEID") | |
| # info["world_size"] = _get("WORLD_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE") | |
| # info["hostname"] = socket.gethostname() | |
| # info["pid"] = os.getpid() | |
| # # Optional: torch.distributed status | |
| # try: | |
| # import torch.distributed as dist | |
| # info["dist_initialized"] = dist.is_available() and dist.is_initialized() | |
| # if info["dist_initialized"]: | |
| # info["rank"] = dist.get_rank() | |
| # info["world_size"] = dist.get_world_size() | |
| # info["backend"] = dist.get_backend() | |
| # except Exception: | |
| # info["dist_initialized"] = False | |
| # # Optional: DataLoader worker ID | |
| # try: | |
| # from torch.utils.data import get_worker_info | |
| # wi = get_worker_info() | |
| # info["worker_id"] = wi.id if wi is not None else None | |
| # except Exception: | |
| # info["worker_id"] = None | |
| # return info | |
| # def _whoami_str(): | |
| # i = _dist_identity() | |
| # return ( | |
| # f"[PROC] rank={i['rank']} local_rank={i['local_rank']} node_rank={i['node_rank']} " | |
| # f"world={i['world_size']} worker={i['worker_id']} " | |
| # f"host={i['hostname']} pid={i['pid']} " | |
| # f"{'(backend='+i['backend']+')' if i.get('backend') else ''}" | |
| # ) | |
| if __name__ == '__main__': | |
| pass | |