| import math |
| import os |
| import json |
| import re |
| import cv2 |
| from dataclasses import dataclass, field |
|
|
| import pytorch_lightning as pl |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from step1x3d_geometry import register |
| from step1x3d_geometry.utils.typing import * |
| from step1x3d_geometry.utils.config import parse_structured |
|
|
| from streaming import StreamingDataLoader |
| from .base import BaseDataModuleConfig, BaseDataset |
|
|
|
|
@dataclass
class ObjaverseDataModuleConfig(BaseDataModuleConfig):
    """Configuration for ``ObjaverseDataModule``.

    Declares no fields of its own; every option is inherited from
    ``BaseDataModuleConfig``. The data module in this file reads at least
    ``batch_size`` and ``num_workers`` from it when building the train loader.
    """
|
|
|
|
class ObjaverseDataset(BaseDataset):
    """Objaverse dataset split; all behaviour is inherited from ``BaseDataset``.

    Instantiated in this module as ``ObjaverseDataset(cfg, split)`` with
    ``split`` being ``"train"``, ``"val"`` or ``"test"``. The training loader
    reads its ``collate`` attribute, which therefore has to be provided by
    ``BaseDataset``.
    """
|
|
|
|
@register("Objaverse-datamodule")
class ObjaverseDataModule(pl.LightningDataModule):
    """Lightning data module serving the train/val/test splits of ObjaverseDataset.

    Registered under the name ``"Objaverse-datamodule"``. The training loader
    uses ``cfg.batch_size``, ``cfg.num_workers`` and the dataset's own
    ``collate``; the validation, test and predict loaders yield one sample per
    batch with the DataLoader's default collation and no worker processes.
    """

    cfg: ObjaverseDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        """Parse ``cfg`` (dict, DictConfig or None) into an ObjaverseDataModuleConfig."""
        super().__init__()
        parsed_cfg = parse_structured(ObjaverseDataModuleConfig, cfg)
        self.cfg = parsed_cfg

    def setup(self, stage=None) -> None:
        """Instantiate the dataset splits needed for ``stage``.

        ``None`` builds all three splits, ``"fit"`` builds train and val,
        ``"validate"`` builds val only, and ``"test"`` / ``"predict"`` build the
        test split. Splits are always constructed in train -> val -> test order.
        """
        # (attribute to set, split name for ObjaverseDataset, stages needing it)
        split_plan = (
            ("train_dataset", "train", (None, "fit")),
            ("val_dataset", "val", (None, "fit", "validate")),
            ("test_dataset", "test", (None, "test", "predict")),
        )
        for attr_name, split, stages in split_plan:
            if stage in stages:
                setattr(self, attr_name, ObjaverseDataset(self.cfg, split))

    def prepare_data(self):
        """Intentionally a no-op: this module performs no one-off data preparation."""

    def general_loader(
        self, dataset, batch_size, collate_fn=None, num_workers=0
    ) -> DataLoader:
        """Wrap ``dataset`` in a plain ``DataLoader``.

        Only the batch size, collate function and worker count are configured.
        NOTE(review): no ``shuffle``/``sampler`` is passed, so a map-style
        dataset is iterated in index order unless the trainer swaps the
        sampler -- confirm this is intended for training.
        """
        loader_options = dict(
            batch_size=batch_size,
            collate_fn=collate_fn,
            num_workers=num_workers,
        )
        return DataLoader(dataset, **loader_options)

    def train_dataloader(self) -> DataLoader:
        """Training loader using the configured batch size, worker count and ``collate``."""
        dataset = self.train_dataset
        cfg = self.cfg
        return self.general_loader(
            dataset,
            batch_size=cfg.batch_size,
            collate_fn=dataset.collate,
            num_workers=cfg.num_workers,
        )

    def _single_sample_loader(self, dataset) -> DataLoader:
        """Loader with one sample per batch, default collation and num_workers=0."""
        return self.general_loader(dataset, batch_size=1)

    def val_dataloader(self) -> DataLoader:
        """Validation loader over the ``"val"`` split, one sample per batch."""
        return self._single_sample_loader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        """Test loader over the ``"test"`` split, one sample per batch."""
        return self._single_sample_loader(self.test_dataset)

    def predict_dataloader(self) -> DataLoader:
        """Predict loader; reuses the ``"test"`` split, one sample per batch."""
        return self._single_sample_loader(self.test_dataset)
|
|