from PIL import Image
from datasets import load_dataset
from torchvision import transforms
import random
import torch
import os
from .pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS
import numpy as np
from src.condition.edge_extraction import (
    CannyDetector, PidiNetDetector, TEDDetector, LineartStandardDetector, HEDdetector,
    AnyLinePreprocessor, LineartDetector, InformativeDetector
)
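
# CSV-driven dataset pipeline for edge-conditioned Kontext control training:
# column 0 of the CSV is the image path (relative to the root used in
# load_image_safely), column 2 is the caption, and the control image is an
# edge map produced by a randomly chosen detector.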

# Disable PIL's decompression-bomb check so very large images can be opened.
Image.MAX_IMAGE_PIXELS = None

def multiple_16(num: float):
    """Round to the nearest multiple of 16."""
    return int(round(num / 16) * 16)

def load_image_safely(image_path, size, root="/mnt/robby-b1/common/datasets/"):
    """Load an RGB image; on failure, log the path and return a white placeholder."""
    image_path = os.path.join(root, image_path)
    try:
        return Image.open(image_path).convert("RGB")
    except Exception as e:
        print(f"file error: {image_path} ({e})")
        with open("failed_images.txt", "a") as f:
            f.write(f"{image_path}\n")
        return Image.new("RGB", (size, size), (255, 255, 255))
    
def choose_kontext_resolution_from_wh(width: int, height: int):
    """Return the (width, height) from PREFERRED_KONTEXT_RESOLUTIONS whose
    aspect ratio is closest to that of the input image."""
    aspect_ratio = width / max(1, height)
    _, best_w, best_h = min(
        (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
    )
    return best_w, best_h

class EdgeExtractorManager:
    """Process-wide singleton that lazily builds the edge detectors, so each
    DataLoader worker constructs them at most once."""
    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(EdgeExtractorManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if not self._initialized:
            self.edge_extractors = None
            self.device = None
            self._initialized = True

    def set_device(self, device):
        self.device = device

    def get_edge_extractors(self, device=None):
        # Force initialization on CPU so DataLoader worker processes never
        # trigger CUDA initialization; any requested device is deliberately
        # ignored and only recorded as "cpu".
        current_device = "cpu"
        if device is not None:
            self.set_device(current_device)

        if self.edge_extractors is None or len(self.edge_extractors) == 0:
            self.edge_extractors = [
                ("canny", CannyDetector()),
                ("pidinet", PidiNetDetector.from_pretrained().to(current_device)),
                ("ted", TEDDetector.from_pretrained().to(current_device)),
                # ("lineart_standard", LineartStandardDetector()),
                ("hed",  HEDdetector.from_pretrained().to(current_device)),
                ("anyline", AnyLinePreprocessor.from_pretrained().to(current_device)),
                # ("lineart", LineartDetector.from_pretrained().to(current_device)),
                ("informative", InformativeDetector.from_pretrained().to(current_device)),
            ]

        return self.edge_extractors

edge_extractor_manager = EdgeExtractorManager()

def collate_fn(examples):
    if examples[0].get("cond_pixel_values") is not None:
        cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
        cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
    else:
        cond_pixel_values = None
    # No source image in this edge-conditioned setup; the key is still emitted
    # for downstream consumers.
    source_pixel_values = None

    target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
    target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
    token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
    token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])

    return {
        "cond_pixel_values": cond_pixel_values,
        "source_pixel_values": source_pixel_values,
        "pixel_values": target_pixel_values,
        "text_ids_1": token_ids_clip,
        "text_ids_2": token_ids_t5,
    }


def make_train_dataset_inpaint_mask(args, tokenizers, accelerator=None):
    # Load the CSV dataset: three columns, where column 0 is the image's
    # relative path and column 2 is the caption.
    if args.train_data_dir is None:
        raise ValueError("make_train_dataset_inpaint_mask requires args.train_data_dir to point to a CSV file")
    dataset = load_dataset('csv', data_files=args.train_data_dir)

    # Column-name compatibility: use column 0 as the image path and column 2
    # as the caption (fall back to the last column for narrower files).
    column_names = dataset["train"].column_names
    image_col = column_names[0]
    caption_col = column_names[2] if len(column_names) >= 3 else column_names[-1]

    size = args.cond_size

    # Device setup (when running distributed, detectors could be placed on the
    # corresponding GPU; get_edge_extractors currently forces CPU regardless).
    if accelerator is not None:
        device = accelerator.device
        edge_extractor_manager.set_device(device)
    else:
        device = "cpu"

    # Transforms
    to_tensor_and_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # Match jsonl_datasets_edge.py: Resize -> CenterCrop -> ToTensor -> Normalize
    cond_train_transforms = transforms.Compose([
        transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop((size, size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    tokenizer_clip = tokenizers[0]
    tokenizer_t5 = tokenizers[1]

    def tokenize_prompt_clip_t5(examples):
        captions_raw = examples[caption_col]
        captions = []
        for c in captions_raw:
            if isinstance(c, str):
                # Randomly drop ~25% of captions to empty strings.
                if random.random() < 0.25:
                    captions.append("")
                else:
                    captions.append(c)
            else:
                captions.append("")

        text_inputs_clip = tokenizer_clip(
            captions,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_1 = text_inputs_clip.input_ids

        text_inputs_t5 = tokenizer_t5(
            captions,
            padding="max_length",
            max_length=128,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_2 = text_inputs_t5.input_ids
        return text_input_ids_1, text_input_ids_2

    def preprocess_train(examples):
        batch = {}

        img_paths = examples[image_col]

        target_tensors = []
        cond_tensors = []

        for p in img_paths:
            # load_image_safely joins the path with its dataset root and
            # already returns an RGB image, so no extra convert is needed.
            img = load_image_safely(p, size)

            # Resize to Kontext preferred resolution for target
            w, h = img.size
            best_w, best_h = choose_kontext_resolution_from_wh(w, h)
            img_rs = img.resize((best_w, best_h), resample=Image.BILINEAR)
            target_tensor = to_tensor_and_norm(img_rs)

            # Build the edge condition with a randomly chosen detector
            extractor_name, extractor = random.choice(edge_extractor_manager.get_edge_extractors())
            img_np = np.array(img)
            if extractor_name == "informative":
                edge = extractor(img_np, style="contour")
            else:
                edge = extractor(img_np)

            # Binarization threshold; TED outputs need a higher cutoff
            th = 128 if extractor_name == "ted" else 32

            edge_np = np.array(edge) if isinstance(edge, Image.Image) else edge
            if edge_np.ndim == 3:
                edge_np = edge_np[..., 0]
            edge_bin = (edge_np > th).astype(np.float32)
            edge_pil = Image.fromarray((edge_bin * 255).astype(np.uint8))
            edge_tensor = cond_train_transforms(edge_pil)
            # Replicate the single-channel edge map to 3 channels.
            edge_tensor = edge_tensor.repeat(3, 1, 1)

            target_tensors.append(target_tensor)
            cond_tensors.append(edge_tensor)

        batch["pixel_values"] = target_tensors
        batch["cond_pixel_values"] = cond_tensors

        batch["token_ids_clip"], batch["token_ids_t5"] = tokenize_prompt_clip_t5(examples)
        return batch

    if accelerator is not None:
        with accelerator.main_process_first():
            train_dataset = dataset["train"].with_transform(preprocess_train)
    else:
        train_dataset = dataset["train"].with_transform(preprocess_train)

    return train_dataset
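

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the training pipeline).
# It assumes an `args` namespace exposing `train_data_dir` and `cond_size`,
# plus CLIP/T5 tokenizers loaded elsewhere in the training script; run it via
# `python -m <package>.<module>` because of the relative import above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace
    from transformers import CLIPTokenizer, T5TokenizerFast

    # Hypothetical paths and model names; adjust to your setup.
    args = SimpleNamespace(train_data_dir="train.csv", cond_size=512)
    tokenizers = [
        CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"),
        T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"),
    ]

    train_dataset = make_train_dataset_inpaint_mask(args, tokenizers)
    # batch_size=1 because the target resolution varies with aspect ratio.
    loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn
    )
    batch = next(iter(loader))
    print(batch["pixel_values"].shape, batch["cond_pixel_values"].shape)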