# --- START OF FILE data/n2h_dataset.py (Fix for AttributeError) ---
import os
from data.pix2pix_dataset import Pix2pixDataset
from data.image_folder import make_dataset
from PIL import Image
import random
from data.base_dataset import get_params, get_transform
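

# Illustrative sketch (hypothetical helper, not used by the dataset code):
# N2HDataset.modify_commandline_options changes option *defaults* with
# parser.set_defaults(), so values given explicitly on the command line still
# take precedence. It overrides num_upsampling_layers only when
# parse_known_args() shows that the option has been registered. Without that
# guard, set_defaults() would add the attribute to the parsed namespace even
# if no such option exists. The parser and its placeholder defaults below are
# stand-ins, not the project's real option definitions.
import argparse


def _sketch_dataset_defaults(argv, register_upsampling=True):
    parser = argparse.ArgumentParser()
    parser.add_argument('--load_size', type=int, default=1024)  # placeholder default
    parser.add_argument('--crop_size', type=int, default=512)  # placeholder default
    if register_upsampling:
        parser.add_argument('--num_upsampling_layers', default='normal')

    # Same pattern as N2HDataset.modify_commandline_options.
    parser.set_defaults(load_size=286, crop_size=256)
    opt, _ = parser.parse_known_args(argv)
    if hasattr(opt, 'num_upsampling_layers'):
        parser.set_defaults(num_upsampling_layers='more')

    # Final parse: explicit command-line values override the new defaults.
    return parser.parse_args(argv)


# Usage: _sketch_dataset_defaults([]).load_size is 286, while
# _sketch_dataset_defaults(['--load_size', '512']).load_size is 512.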


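# Illustrative sketch (hypothetical and simplified): the contract N2HDataset
# relies on from Pix2pixDataset.initialize(), modelled on the upstream SPADE
# version. The local data/pix2pix_dataset.py may differ (for example in
# sorting, max_dataset_size truncation or the no_pairing_check flag), so
# treat this as an assumption. initialize() asks the subclass for paths via
# get_paths(), checks each zipped (label, image) pair with paths_match(), and
# takes dataset_size from the label (content) list. N2H content and style
# images are unpaired, which is why N2HDataset.paths_match() returns True.
import os


class _SketchPairedBase:
    def initialize(self, opt):
        self.opt = opt
        label_paths, image_paths, instance_paths = self.get_paths(opt)
        for path1, path2 in zip(label_paths, image_paths):
            assert self.paths_match(path1, path2), (path1, path2)
        self.label_paths = label_paths
        self.image_paths = image_paths
        self.instance_paths = instance_paths
        self.dataset_size = len(label_paths)

    def get_paths(self, opt):
        raise NotImplementedError

    def paths_match(self, path1, path2):
        # Paired default: file names must agree once the directory and the
        # extension are stripped.
        name1 = os.path.splitext(os.path.basename(path1))[0]
        name2 = os.path.splitext(os.path.basename(path2))[0]
        return name1 == name2


class _SketchUnpaired(_SketchPairedBase):
    # Hypothetical stand-in for N2HDataset: file names differ across domains.
    def get_paths(self, opt):
        return ['trainA/day_001.jpg', 'trainA/day_002.jpg'], ['trainB/night_17.jpg'], []

    def paths_match(self, path1, path2):
        return True

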
class N2HDataset(Pix2pixDataset):

    def __init__(self, opt):
        """Initialize this dataset class.

        A_paths and B_paths are defined here, and we call the initialize
        method of the parent class (Pix2pixDataset) to set up the rest.
        """
        # Gọi __init__ của lớp cha gần nhất (Pix2pixDataset)
        # Pix2pixDataset không có __init__, nên nó sẽ gọi BaseDataset.__init__(self, opt)
        # Điều này là đúng với bản sửa lỗi trước của chúng ta.
        super().__init__(opt)

        # Gọi hàm initialize của lớp cha để thiết lập self.label_paths, self.image_paths, và self.dataset_size
        self.initialize(opt)

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
        parser.set_defaults(preprocess_mode='resize_and_crop')
        parser.set_defaults(load_size=286)
        parser.set_defaults(crop_size=256)
        parser.set_defaults(display_winsize=256)
        parser.set_defaults(aspect_ratio=1.0)
        opt, _ = parser.parse_known_args()
        if hasattr(opt, 'num_upsampling_layers'):
            parser.set_defaults(num_upsampling_layers='more')
        return parser

    def get_paths(self, opt):
        croot = opt.croot
        sroot = opt.sroot

        # This logic assumes the directory layout croot/<phase>A and
        # sroot/<phase>B, e.g. croot/trainA and sroot/trainB.
        c_image_dir = os.path.join(croot, opt.phase + 'A')
        s_image_dir = os.path.join(sroot, opt.phase + 'B')

        if not os.path.isdir(c_image_dir):
            raise FileNotFoundError(f"Content directory not found: {c_image_dir}")
        if not os.path.isdir(s_image_dir):
            raise FileNotFoundError(f"Style directory not found: {s_image_dir}")

        c_image_paths = sorted(make_dataset(c_image_dir, recursive=True))
        s_image_paths = sorted(make_dataset(s_image_dir, recursive=True))

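        # In training, repeat the shorter list so that it is at least as long
        # as the longer one (it may end up longer; it is not trimmed).
        if opt.phase == 'train' and len(c_image_paths) > 0 and len(s_image_paths) > 0: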
            if len(c_image_paths) > len(s_image_paths):
                s_image_paths = s_image_paths * (len(c_image_paths) // len(s_image_paths) + 1)
            elif len(s_image_paths) > len(c_image_paths):
                c_image_paths = c_image_paths * (len(s_image_paths) // len(c_image_paths) + 1)

        instance_paths = []

        return c_image_paths, s_image_paths, instance_paths

    def __getitem__(self, index):
        # Get the Day image (image A, content).
        # self.label_paths is set to c_image_paths in Pix2pixDataset.initialize().
        day_path = self.label_paths[index % len(self.label_paths)]

        # Pick a random Night image (image B, style).
        # self.image_paths is set to s_image_paths in Pix2pixDataset.initialize().
        night_path = self.image_paths[random.randint(0, len(self.image_paths) - 1)]

        day_img = Image.open(day_path).convert('RGB')
        night_img = Image.open(night_path).convert('RGB')

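        # Share one set of random crop/flip parameters between both images.
        # Assuming get_params/get_transform behave as in upstream SPADE,
        # preprocess_mode='resize_and_crop' first resizes every image to
        # load_size x load_size, so a crop position computed from the day
        # image is also valid for the night image.
        params = get_params(self.opt, day_img.size)
        transform = get_transform(self.opt, params)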

        day_tensor = transform(day_img)
        night_tensor = transform(night_img)

        return {'day': day_tensor, 'night': night_tensor, 'cpath': day_path, 'spath_night': night_path}

    def paths_match(self, path1, path2):
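        # Content and style images are unpaired, so always report a match and
        # let the filename pairing check in Pix2pixDataset.initialize() pass.
        return True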
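

# Illustrative sketch (hypothetical helpers, not called by N2HDataset): the
# path balancing in get_paths() and the unpaired sampling in __getitem__(),
# restated without file I/O so they can be checked in isolation.
#
# 1) Balancing: in the training phase the shorter list is repeated
#    len(longer) // len(shorter) + 1 times, so it ends up at least as long as
#    the longer list (often longer; it is not trimmed). Assuming
#    Pix2pixDataset.initialize() takes dataset_size from the content list, as
#    upstream SPADE does, repeating content paths lengthens each epoch.
# 2) Sampling: item `index` always uses content path index % len(content),
#    while the style path is drawn uniformly at random on every call, so the
#    day/night pairs change from epoch to epoch.
import random  # also imported at the top of this module; repeated for self-containment


def _sketch_balance(c_paths, s_paths):
    """Repeat the shorter path list, mirroring get_paths() in the train phase."""
    if c_paths and s_paths:
        if len(c_paths) > len(s_paths):
            s_paths = s_paths * (len(c_paths) // len(s_paths) + 1)
        elif len(s_paths) > len(c_paths):
            c_paths = c_paths * (len(s_paths) // len(c_paths) + 1)
    return c_paths, s_paths


def _sketch_sample(c_paths, s_paths, index, rng):
    """Pick (content, style) paths the way __getitem__() does; rng is a random.Random."""
    day_path = c_paths[index % len(c_paths)]
    # Same draw as random.randint(0, n - 1) in __getitem__; rng.choice(s_paths)
    # samples from the same uniform distribution.
    night_path = s_paths[rng.randint(0, len(s_paths) - 1)]
    return day_path, night_path


# Usage: with 5 content and 2 style paths, _sketch_balance returns the 5
# content paths unchanged and the style list repeated 3 times (6 entries).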