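# NOTE(review): header residue from a Hugging Face Hub file view — apparently
# the last commit was "init" (002bd9b) by xingzhikb; confirm against the repo.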
import json
import os
import datasets
# Module-level logger obtained through the `datasets` logging helper.
logger = datasets.logging.get_logger(__name__)

# Per-image columns: identifier, pixel dimensions, file name and the task tag.
_BASE_IMAGE_METADATA_FEATURES = dict(
    image_id=datasets.Value("int32"),
    width=datasets.Value("int32"),
    height=datasets.Value("int32"),
    file_name=datasets.Value("string"),
    task_type=datasets.Value("string"),
)

# Schema of a single annotated region: its own id, the id of the image it
# belongs to, a list of phrases, and integer x / y / width / height fields
# (presumably a bounding box — confirm against the data).
_BASE_REGION_FEATURES = dict(
    region_id=datasets.Value("int32"),
    image_id=datasets.Value("int32"),
    phrases=[datasets.Value("string")],
    x=datasets.Value("int32"),
    y=datasets.Value("int32"),
    width=datasets.Value("int32"),
    height=datasets.Value("int32"),
)

# Annotation columns: every example carries a list of regions.
_ANNOTATION_FEATURES = dict(regions=[_BASE_REGION_FEATURES])
class CustomDatasetConfig(datasets.BuilderConfig):
    """Builder configuration for the custom region-annotation dataset.

    Args:
        name: Configuration name, passed on to ``datasets.BuilderConfig``.
        splits: Split names kept on the config as ``self.splits``.
        data_dir: Dataset root directory, kept as ``self.data_dir``.
        with_image: Whether the ``features`` schema includes an ``image``
            column.
        task_type: Task label kept as ``self.task_type``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``datasets.BuilderConfig``.
    """

    def __init__(
        self,
        name,
        splits,
        data_dir: str = None,
        with_image: bool = True,
        task_type: str = "recognition",
        **kwargs,
    ):
        super().__init__(name, **kwargs)
        # Set after the base initialiser so these values are the ones that
        # end up on the instance.
        self.data_dir = data_dir
        self.task_type = task_type
        self.with_image = with_image
        self.splits = splits

    @property
    def features(self):
        """Return the ``datasets.Features`` schema for this configuration.

        Column order is: optional ``image``, then the image metadata columns,
        then the annotation columns.
        """
        schema = {}
        if self.with_image:
            schema["image"] = datasets.Image()
        schema.update(_BASE_IMAGE_METADATA_FEATURES)
        schema.update(_ANNOTATION_FEATURES)
        return datasets.Features(schema)
class CustomDataset(datasets.GeneratorBasedBuilder):
    """Dataset builder that reads image/region annotations from local JSON files.

    The builder expects ``config.data_dir`` to contain ``train.json`` (train
    split) and ``test.json`` (used for the validation split). Each JSON file
    is read as a mapping with an ``"images"`` list and an ``"annotations"``
    list whose entries are linked through ``"image_id"``.
    """

    VERSION = datasets.Version("0.0.1")
    BUILDER_CONFIG_CLASS = CustomDatasetConfig
    BUILDER_CONFIGS = [
        CustomDatasetConfig(name="custom", splits=["train", "validation"]),
    ]
    DEFAULT_CONFIG_NAME = "custom"
    config: CustomDatasetConfig

    def _info(self):
        """Return the dataset info, whose features come from the active config."""
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        """Return one ``SplitGenerator`` per supported name in ``config.splits``.

        Args:
            dl_manager: Download manager supplied by ``datasets``; unused
                because all files are read from the local ``data_dir``.

        Returns:
            A list of ``datasets.SplitGenerator`` objects.

        Raises:
            ValueError: If ``config.data_dir`` is ``None``.
        """
        data_dir = self.config.data_dir
        if data_dir is None:
            raise ValueError(
                "This script is supposed to work with a local dataset. The argument `data_dir` in `load_dataset()` is required."
            )

        splits = []
        for split in self.config.splits:
            if split == "train":
                dataset = datasets.SplitGenerator(
                    name=datasets.Split.TRAIN,
                    gen_kwargs={
                        "json_path": os.path.join(data_dir, "train.json"),
                    },
                )
            elif split in ("val", "valid", "validation", "dev"):
                dataset = datasets.SplitGenerator(
                    name=datasets.Split.VALIDATION,
                    gen_kwargs={
                        # test.json is used as the validation file.
                        "json_path": os.path.join(data_dir, "test.json"),
                    },
                )
            else:
                # Unsupported names are still skipped, but no longer silently.
                logger.warning("Skipping unsupported split name: %r", split)
                continue
            splits.append(dataset)
        return splits

    def _generate_examples(self, json_path):
        """Yield ``(index, example)`` pairs for every image listed in ``json_path``.

        Args:
            json_path: Path to a JSON file with an ``"images"`` list and an
                ``"annotations"`` list.

        Yields:
            Tuples of the image's position in ``"images"`` and an example dict
            holding the image metadata, its ``"regions"``, the configured
            ``"task_type"`` and, when ``config.with_image`` is true, the
            ``"image"`` path joined from ``config.data_dir`` and the file name.
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Group annotations by image once instead of rescanning the full
        # annotation list for every image. Per-image order is preserved.
        # A missing "annotations" key is treated as "no regions".
        annotations_by_image = {}
        for ann in data.get("annotations", []):
            annotations_by_image.setdefault(ann["image_id"], []).append(ann)

        with_image = self.config.with_image
        task_type = self.config.task_type

        for idx, image in enumerate(data["images"]):
            image_id = image["image_id"]
            regions = [
                {
                    "region_id": ann["region_id"],
                    "image_id": ann["image_id"],
                    "phrases": ann["phrases"],
                    "x": ann["x"],
                    "y": ann["y"],
                    "width": ann["width"],
                    "height": ann["height"],
                }
                for ann in annotations_by_image.get(image_id, [])
            ]

            example = {}
            if with_image:
                example["image"] = os.path.join(self.config.data_dir, image["file_name"])
            example.update(
                {
                    "file_name": image["file_name"],
                    "height": image["height"],
                    "width": image["width"],
                    "image_id": image_id,
                    "regions": regions,
                    "task_type": task_type,
                }
            )
            yield idx, example