File size: 1,590 Bytes
d382778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from typing import List, Optional, Tuple
import os
import torch
from torch.utils.data import Dataset


def load_data_from_dir(
    data_folder: str, limit: Optional[int] = 200, map_location=None
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[Optional[torch.Tensor]], List[Optional[torch.Tensor]]]:
    """Load up to ``limit`` ``.pt`` sample files from ``data_folder``.

    Files are visited in sorted filename order, so the selection is
    deterministic. Each loaded object is indexed with ``"latent"`` and
    ``"img"``. The ``"c"`` and ``"uc"`` entries are read with ``.get`` and
    become ``None`` when absent.

    Args:
        data_folder: Directory containing the saved sample files.
        limit: Maximum number of files to load, applied after sorting.
            ``None`` loads every matching file.
        map_location: Forwarded unchanged to ``torch.load``. The default
            ``None`` keeps torch's own behaviour.

    Returns:
        Four parallel lists: latents, targets, conditions, unconditions.

    Raises:
        KeyError: if a loaded mapping has no ``"latent"`` or ``"img"`` entry.
    """
    latents, targets, conditions, unconditions = [], [], [], []
    # Require the real ".pt" extension. A bare trailing "pt" would also match
    # names such as "script". Directories named "*.pt" are skipped as well.
    pt_files = [
        name
        for name in os.listdir(data_folder)
        if name.endswith(".pt") and os.path.isfile(os.path.join(data_folder, name))
    ]
    for file_name in sorted(pt_files)[:limit]:
        file_path = os.path.join(data_folder, file_name)
        # NOTE(review): torch.load may unpickle arbitrary objects, depending on
        # the torch version and its weights_only default. Only point this at
        # trusted data.
        data = torch.load(file_path, map_location=map_location)
        latents.append(data["latent"])
        targets.append(data["img"])
        conditions.append(data.get("c", None))
        unconditions.append(data.get("uc", None))
    return latents, targets, conditions, unconditions


class LD3Dataset(Dataset):
    """Map-style dataset that zips five index-aligned per-sample lists.

    Item ``i`` is assembled from position ``i`` of every list. The lists are
    kept by reference, not copied.
    """

    def __init__(
        self,
        ori_latent: List[torch.Tensor],
        latent: List[torch.Tensor],
        target: List[torch.Tensor],
        condition: List[Optional[torch.Tensor]],
        uncondition: List[Optional[torch.Tensor]],
    ):
        # Store each parallel list as-is. No length consistency check is done.
        self.target = target
        self.latent = latent
        self.ori_latent = ori_latent
        self.condition = condition
        self.uncondition = uncondition

    def __len__(self) -> int:
        """Number of samples, taken from ``ori_latent``."""
        return len(self.ori_latent)

    def __getitem__(self, idx: int):
        """Return ``(img, latent, ori_latent, condition, uncondition)`` at ``idx``."""
        return (
            self.target[idx],
            self.latent[idx],
            self.ori_latent[idx],
            self.condition[idx],
            self.uncondition[idx],
        )