| from typing import List, Optional, Tuple | |
| import os | |
| import torch | |
| from torch.utils.data import Dataset | |
| def load_data_from_dir( | |
| data_folder: str, limit: int = 200 | |
| ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[Optional[torch.Tensor]], List[Optional[torch.Tensor]]]: | |
| latents, targets, conditions, unconditions = [], [], [], [] | |
| pt_files = [f for f in os.listdir(data_folder) if f.endswith('pt')] | |
| for file_name in sorted(pt_files)[:limit]: | |
| file_path = os.path.join(data_folder, file_name) | |
| data = torch.load(file_path) | |
| latents.append(data["latent"]) | |
| targets.append(data["img"]) | |
| conditions.append(data.get("c", None)) | |
| unconditions.append(data.get("uc", None)) | |
| return latents, targets, conditions, unconditions | |
| class LD3Dataset(Dataset): | |
| def __init__( | |
| self, | |
| ori_latent: List[torch.Tensor], | |
| latent: List[torch.Tensor], | |
| target: List[torch.Tensor], | |
| condition: List[Optional[torch.Tensor]], | |
| uncondition: List[Optional[torch.Tensor]], | |
| ): | |
| self.ori_latent = ori_latent | |
| self.latent = latent | |
| self.target = target | |
| self.condition = condition | |
| self.uncondition = uncondition | |
| def __len__(self) -> int: | |
| return len(self.ori_latent) | |
| def __getitem__(self, idx: int): | |
| img = self.target[idx] | |
| latent = self.latent[idx] | |
| ori_latent = self.ori_latent[idx] | |
| condition = self.condition[idx] | |
| uncondition = self.uncondition[idx] | |
| return img, latent, ori_latent, condition, uncondition |