| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import os |
| | from abc import abstractmethod |
| | from pathlib import Path |
| |
|
| | import json5 |
| | import torch |
| | import yaml |
| |
|
| |
|
| | |
class BaseDataset(torch.utils.data.Dataset):
    r"""Common parent for the training and validation datasets.

    This class is a placeholder: its constructor accepts the usual
    arguments but performs no work and stores nothing on the instance.
    """

    def __init__(self, args, cfg, is_valid=False):
        """Accept the construction arguments without using them.

        Args:
            args: Ignored by this base implementation.
            cfg: Ignored by this base implementation.
            is_valid: Ignored by this base implementation.
        """
| |
|
| |
|
class BaseTestDataset(torch.utils.data.Dataset):
    r"""Test dataset for inference.

    Stores the run arguments, the configuration and the inference mode, and
    offers a helper that loads a metadata file named by ``args.source``.

    NOTE(review): ``__len__`` reads ``self.metadata``, which this class never
    assigns. Subclasses presumably set it (for example from
    ``get_metadata()``) before ``len()`` is called -- confirm against the
    concrete datasets.
    """

    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        """Record the constructor arguments on the instance.

        Args:
            args: Run arguments; ``get_metadata`` reads ``args.source``.
            cfg: Configuration object, stored as-is.
            infer_type: Either ``"from_dataset"`` or ``"from_file"``.

        Raises:
            AssertionError: If ``infer_type`` is not one of the two accepted
                values (only checked when assertions are enabled).
        """
        assert infer_type in ["from_dataset", "from_file"]

        self.args = args
        self.cfg = cfg
        self.infer_type = infer_type

    # NOTE(review): unless a base class uses ABCMeta, this decorator is only
    # a marker and does not prevent instantiation -- confirm.
    @abstractmethod
    def __getitem__(self, index):
        """Return the sample at ``index``; subclasses are expected to override."""
        pass

    def __len__(self):
        """Return the number of entries in ``self.metadata``."""
        return len(self.metadata)

    def get_metadata(self):
        """Load and return the metadata stored in ``self.args.source``.

        The parser is chosen from the file suffix: ``.json``/``.jsonc`` files
        are read with ``json5`` and ``.yaml``/``.yml`` files with ``yaml``.

        Returns:
            The parsed content of the file.

        Raises:
            ValueError: If the file suffix is not one of the supported ones.
        """
        path = Path(self.args.source)
        suffix = path.suffix

        if suffix in (".json", ".jsonc"):
            # Context manager so the handle is closed even if parsing fails.
            with path.open("r") as f:
                return json5.load(f)

        if suffix in (".yaml", ".yml"):
            # SECURITY NOTE: yaml.full_load is not meant for untrusted input;
            # consider yaml.safe_load if the source file may be untrusted.
            with path.open("r") as f:
                return yaml.full_load(f)

        raise ValueError(f"Unsupported file type: {path.suffix}")
| |
|