| |
| import pandas as pd |
| import datasets |
| import os |
| import pickle |
|
|
# Version attached to the single builder configuration below.
_VERSION = datasets.Version("0.0.3")

# DatasetInfo metadata; all four are still placeholders.
_DESCRIPTION = "TODO"
_HOMEPAGE = "TODO"
_LICENSE = "TODO"
_CITATION = "TODO"

# Each example carries these four image columns (in this order), followed by
# a free-text prompt column.
_IMAGE_COLUMNS = ("target", "source", "heatmap", "depth")
_FEATURES = datasets.Features(
    {
        **{column: datasets.Image() for column in _IMAGE_COLUMNS},
        "prompt": datasets.Value("string"),
    },
)

# Despite its name, METADATA_DIR points at a single pickle file, not a
# directory.
METADATA_DIR = "/fsx/proj-fmri/ckadirt/MindEyeV2/src/cnd_prov/data.pkl"

# Image roots on the shared filesystem.
SOURCE_DIR = "/fsx/proj-fmri/shared/controlNetData/source"
TARGET_DIR = "/fsx/proj-fmri/shared/controlNetData/target"
HEATMAP_DIR = "/fsx/proj-fmri/shared/controlNetData/seg"
DEPTH_DIR = "/fsx/proj-fmri/shared/dinov2_depth"

# The only builder configuration; it is registered under the name "default".
_DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)
|
|
|
|
class CocoTest(datasets.GeneratorBasedBuilder):
    """Builder yielding four images (target/source/heatmap/depth) plus a prompt.

    The image files belonging to each example are listed in a pickled
    metadata file (``METADATA_DIR``). Every image is read from disk as raw
    bytes and emitted together with its path.
    """

    BUILDER_CONFIGS = [_DEFAULT_CONFIG]
    DEFAULT_CONFIG_NAME = "default"

    def _info(self):
        """Return the dataset description, feature schema and metadata."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=_FEATURES,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Declare the TRAIN and VALIDATION splits.

        Nothing is downloaded. ``dl_manager`` is unused, and every input path
        comes from the module-level constants.

        NOTE(review): both splits read the same metadata file and each keeps
        the first ``num_examples`` records. As a result, VALIDATION (first
        20000) is entirely contained in TRAIN (first 190573). If disjoint
        splits are intended, the validation slice needs an offset; confirm
        this against the metadata length.
        """
        # Arguments shared by both splits; only the record count differs.
        shared_kwargs = {
            "metadata_path": METADATA_DIR,
            "target_dir": TARGET_DIR,
            "source_dir": SOURCE_DIR,
            "heatmap_dir": HEATMAP_DIR,
            "depth_dir": DEPTH_DIR,
        }
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={**shared_kwargs, "num_examples": 190573},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                gen_kwargs={**shared_kwargs, "num_examples": 20000},
            ),
        ]

    @staticmethod
    def _read_image(path):
        """Return ``{"path": path, "bytes": <file contents>}``, closing the file."""
        with open(path, "rb") as handle:
            return {"path": path, "bytes": handle.read()}

    def _generate_examples(self, metadata_path, target_dir, source_dir, heatmap_dir, depth_dir, num_examples):
        """Yield ``(key, example)`` for the first ``num_examples`` metadata records.

        Args:
            metadata_path: Pickle file holding an indexable sequence of dicts,
                each with the keys ``'source'``, ``'target'``, ``'h_map'``,
                ``'depth'`` and ``'prompt'``.
            target_dir, source_dir, heatmap_dir, depth_dir: Currently unused;
                image paths are taken verbatim from each record.
            num_examples: Number of records to emit, counted from the start of
                the metadata sequence.

        Yields:
            The record's ``'target'`` path as the key, and a dict with a
            ``prompt`` string plus ``{"path", "bytes"}`` entries for the
            target, source, heatmap and depth images.
        """
        # pickle.load can execute arbitrary code; metadata_path must point at
        # a trusted file.
        with open(metadata_path, "rb") as f:
            records = pickle.load(f)[:num_examples]

        for item in records:
            # NOTE(review): the target path doubles as the example key.
            # Presumably keys must be unique within a split, so confirm that
            # no two records share a target.
            yield item["target"], {
                "prompt": item["prompt"],
                "target": self._read_image(item["target"]),
                "source": self._read_image(item["source"]),
                "heatmap": self._read_image(item["h_map"]),
                "depth": self._read_image(item["depth"]),
            }
|
|