# mindeyev2old/src/cnd_prov/cnd_prov.py
# Provenance: uploaded by ckadirt via huggingface_hub (commit 626cbe5, verified).
# New version of the dataset-loading script.
import pandas as pd
import datasets
import os
import pickle
# Builder version reported to the `datasets` library; bump on data-layout changes.
_VERSION = datasets.Version("0.0.3")
# Dataset-card metadata — all still placeholders.
_DESCRIPTION = "TODO"
_HOMEPAGE = "TODO"
_LICENSE = "TODO"
_CITATION = "TODO"
# Schema of one example: four images (target/source/segmentation-heatmap/depth)
# plus a free-text prompt.
_FEATURES = datasets.Features(
{
"target": datasets.Image(),
"source": datasets.Image(),
"heatmap": datasets.Image(),
"depth": datasets.Image(),
"prompt": datasets.Value("string"),
},
)
# Absolute cluster paths (fsx mount) — this script only works on that filesystem.
# METADATA_DIR is a pickled list of dicts; the *_DIR constants point at image trees.
METADATA_DIR = "/fsx/proj-fmri/ckadirt/MindEyeV2/src/cnd_prov/data.pkl"
SOURCE_DIR = "/fsx/proj-fmri/shared/controlNetData/source"
TARGET_DIR = "/fsx/proj-fmri/shared/controlNetData/target"
HEATMAP_DIR = "/fsx/proj-fmri/shared/controlNetData/seg"
DEPTH_DIR = "/fsx/proj-fmri/shared/dinov2_depth"
# Leftover template code from the fill50k example this script was adapted from;
# kept for reference — the data here is read from local paths instead of the Hub.
# METADATA_URL = hf_hub_url(
# "fusing/fill50k",
# filename="train.jsonl",
# repo_type="dataset",
# )
# IMAGES_URL = hf_hub_url(
# "fusing/fill50k",
# filename="images.zip",
# repo_type="dataset",
# )
# CONDITIONING_IMAGES_URL = hf_hub_url(
# "fusing/fill50k",
# filename="conditioning_images.zip",
# repo_type="dataset",
# )
# Single default builder configuration.
_DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)
class CocoTest(datasets.GeneratorBasedBuilder):
    """Builder pairing target/source images with segmentation heatmaps,
    depth maps, and text prompts (ControlNet-style conditioning data).

    The metadata pickle is a list of dicts with keys ``source``,
    ``target``, ``h_map``, ``depth`` (image file paths) and ``prompt``.
    """

    BUILDER_CONFIGS = [_DEFAULT_CONFIG]
    DEFAULT_CONFIG_NAME = "default"

    def _info(self):
        """Return static dataset metadata (schema, citation, license)."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=_FEATURES,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Define the train and validation splits.

        NOTE(review): both splits slice the SAME pickled list from the
        front, so the 20,000 validation examples are a subset of the
        190,573 training examples — confirm this overlap is intentional.
        """
        # Shared kwargs for both splits; only ``num_examples`` differs.
        common_kwargs = {
            "metadata_path": METADATA_DIR,
            "target_dir": TARGET_DIR,
            "source_dir": SOURCE_DIR,
            "heatmap_dir": HEATMAP_DIR,
            "depth_dir": DEPTH_DIR,
        }
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={**common_kwargs, "num_examples": 190573},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                gen_kwargs={**common_kwargs, "num_examples": 20000},
            ),
        ]

    def _generate_examples(self, metadata_path, target_dir, source_dir,
                           heatmap_dir, depth_dir, num_examples):
        """Yield ``(key, example)`` pairs for the first ``num_examples``
        metadata records, where the key is the target image path.

        Every image file is opened via a context manager so its handle
        is closed immediately (the original leaked four descriptors per
        example over ~190k examples).
        """
        with open(metadata_path, "rb") as f:
            # NOTE(review): pickle.load executes arbitrary code if the
            # file is untrusted — acceptable here only because the path
            # is a fixed, project-controlled cluster location.
            records = pickle.load(f)

        def _read_bytes(path):
            # Read one image file fully as raw bytes, then close it.
            with open(path, "rb") as img:
                return img.read()

        for item in records[:num_examples]:
            yield item["target"], {
                "prompt": item["prompt"],
                "target": {
                    "path": item["target"],
                    "bytes": _read_bytes(item["target"]),
                },
                "source": {
                    "path": item["source"],
                    "bytes": _read_bytes(item["source"]),
                },
                "heatmap": {
                    "path": item["h_map"],
                    "bytes": _read_bytes(item["h_map"]),
                },
                "depth": {
                    "path": item["depth"],
                    "bytes": _read_bytes(item["depth"]),
                },
            }