Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,475 Bytes
f06aba5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
from typing import Any, Callable, Dict, List, Optional, Union
import os
import random
import traceback
import math
import json
import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer
from PIL import Image
from tqdm import tqdm
from longcat_image.dataset import MULTI_RESOLUTION_MAP
from longcat_image.utils import encode_prompt
from longcat_image.dataset import MultiResolutionDistributedSampler
# Raise PIL's decompression-bomb guard so very large training images can be
# decoded without raising DecompressionBombError (default limit is ~179 MP).
Image.MAX_IMAGE_PIXELS = 2000000000
# Upper bound on retries when a sample fails to load; each failed attempt
# substitutes a randomly chosen index (see __getitem__ below).
MAX_RETRY_NUMS = 100
class Text2ImageLoraDataSet(torch.utils.data.Dataset):
    """Text-to-image dataset for LoRA fine-tuning.

    Reads a JSONL manifest (one JSON object per line with at least
    ``img_path``, ``prompt``, ``height`` and ``width``), records per-sample
    resolutions for aspect-ratio bucketing, and yields normalized image
    tensors together with tokenized prompts.
    """

    def __init__(self,
                 cfg: dict,
                 txt_root: str,
                 tokenizer: AutoTokenizer,
                 resolution: tuple = (1024, 1024),
                 repeats: int = 1):
        """
        Args:
            cfg: config object; must expose text_tokenizer_max_length,
                null_text_ratio, aspect_ratio_type, repeats and the
                prompt_template_encode_* fields read below.
            txt_root: path to the JSONL manifest file.
            tokenizer: HuggingFace tokenizer used by encode_prompt.
            resolution: nominal target resolution (stored, not used directly;
                actual crop sizes come from the aspect-ratio buckets).
            repeats: NOTE(review): this parameter is currently ignored —
                cfg.repeats is used instead. Kept for interface compatibility.
        """
        super().__init__()
        self.resolution = resolution
        self.text_tokenizer_max_length = cfg.text_tokenizer_max_length
        # Probability of dropping the prompt (classifier-free guidance training).
        self.null_text_ratio = cfg.null_text_ratio
        self.aspect_ratio_type = cfg.aspect_ratio_type
        self.aspect_ratio = MULTI_RESOLUTION_MAP[self.aspect_ratio_type]
        self.tokenizer = tokenizer
        self.prompt_template_encode_prefix = cfg.prompt_template_encode_prefix
        self.prompt_template_encode_suffix = cfg.prompt_template_encode_suffix
        self.prompt_template_encode_start_idx = cfg.prompt_template_encode_start_idx
        self.prompt_template_encode_end_idx = cfg.prompt_template_encode_end_idx
        self.total_datas = []
        # (height, width) per sample, consumed by the bucket sampler.
        self.data_resolution_infos = []
        with open(txt_root, 'r') as f:
            lines = f.readlines()
        # Duplicate the manifest to lengthen an epoch on small datasets.
        lines *= cfg.repeats
        for line in tqdm(lines):
            data = json.loads(line.strip())
            try:
                # Fixed typo: was `widht`.
                height, width = int(data['height']), int(data['width'])
                self.data_resolution_infos.append((height, width))
                self.total_datas.append(data)
            except Exception as e:
                # Skip malformed records but keep going; log for visibility.
                print(f'get error {e}, data {data}.')
                continue
        self.data_nums = len(self.total_datas)
        print(f'get sampler {len(self.total_datas)}, from {txt_root}!!!')

    def transform_img(self, image, original_size, target_size):
        """Resize so the image covers the target box, then center-crop.

        Args:
            image: PIL image.
            original_size: (height, width) of the source image.
            target_size: (height, width) of the bucket crop.
        Returns:
            Tensor in [-1, 1] of shape (3, target_height, target_width).
        """
        img_h, img_w = original_size
        target_height, target_width = target_size
        original_aspect = img_h / img_w  # height/width
        crop_aspect = target_height / target_width
        if original_aspect >= crop_aspect:
            # Source is taller than the crop: match width, let height overflow.
            resize_width = target_width
            resize_height = math.ceil(img_h * (target_width / img_w))
        else:
            # Source is wider than the crop: match height, let width overflow.
            resize_width = math.ceil(img_w * (target_height / img_h))
            resize_height = target_height
        image = T.Compose([
            T.Resize((resize_height, resize_width), interpolation=InterpolationMode.BICUBIC),  # Image.LANCZOS
            T.CenterCrop((target_height, target_width)),
            T.ToTensor(),
            T.Normalize([.5], [.5]),
        ])(image)
        return image

    def __getitem__(self, index_tuple):
        """Load one sample; on failure retry with a random substitute index.

        Args:
            index_tuple: (index, target_size) pair supplied by the
                multi-resolution sampler.
        Raises:
            RuntimeError: if MAX_RETRY_NUMS consecutive loads fail.
        """
        index, target_size = index_tuple
        for _ in range(MAX_RETRY_NUMS):
            try:
                item = self.total_datas[index]
                img_path = item["img_path"]
                prompt = item['prompt']
                if random.random() < self.null_text_ratio:
                    prompt = ''
                raw_image = Image.open(img_path).convert('RGB')
                assert raw_image is not None
                img_w, img_h = raw_image.size
                raw_image = self.transform_img(raw_image, original_size=(img_h, img_w), target_size=target_size)
                input_ids, attention_mask = encode_prompt(prompt, self.tokenizer, self.text_tokenizer_max_length, self.prompt_template_encode_prefix, self.prompt_template_encode_suffix)
                return {"image": raw_image, "prompt": prompt, 'input_ids': input_ids, 'attention_mask': attention_mask}
            except Exception as e:
                traceback.print_exc()
                print(f"failed read data {e}!!!")
                index = random.randint(0, self.data_nums - 1)
        # Previously fell through and returned None, which later crashed
        # collate_fn with an opaque error; fail loudly at the source instead.
        raise RuntimeError(f'failed to load a valid sample after {MAX_RETRY_NUMS} retries.')

    def __len__(self):
        """Number of (possibly repeated) usable samples."""
        return self.data_nums

    def collate_fn(self, batchs):
        """Stack per-sample dicts into batched tensors plus the prompt list."""
        images = torch.stack([example["image"] for example in batchs])
        input_ids = torch.stack([example["input_ids"] for example in batchs])
        attention_mask = torch.stack([example["attention_mask"] for example in batchs])
        prompts = [example['prompt'] for example in batchs]
        batch_dict = {
            "images": images,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompts": prompts,
        }
        return batch_dict
def build_dataloader(cfg: dict,
                     csv_root: str,
                     tokenizer: AutoTokenizer,
                     resolution: tuple = (1024, 1024)):
    """Construct the training DataLoader with aspect-ratio bucket sampling.

    Args:
        cfg: config exposing train_batch_size and dataloader_num_workers
            (plus everything Text2ImageLoraDataSet reads).
        csv_root: path to the JSONL manifest file.
        tokenizer: HuggingFace tokenizer forwarded to the dataset.
        resolution: nominal target resolution forwarded to the dataset.
    Returns:
        torch.utils.data.DataLoader yielding collated batches.
    """
    train_set = Text2ImageLoraDataSet(cfg, csv_root, tokenizer, resolution)
    # The sampler groups samples into same-resolution buckets so every batch
    # shares one crop size; it also handles the distributed sharding.
    bucket_sampler = MultiResolutionDistributedSampler(
        batch_size=cfg.train_batch_size,
        dataset=train_set,
        data_resolution_infos=train_set.data_resolution_infos,
        bucket_info=train_set.aspect_ratio,
        epoch=0,
        num_replicas=None,
        rank=None,
    )
    return torch.utils.data.DataLoader(
        train_set,
        collate_fn=train_set.collate_fn,
        batch_size=cfg.train_batch_size,
        num_workers=cfg.dataloader_num_workers,
        sampler=bucket_sampler,
        shuffle=None,
    )
if __name__ == '__main__':
    # Smoke test: build the dataloader, iterate a few batches, and dump the
    # de-normalized images to disk for visual inspection. Replace the 'xxx'
    # placeholders with real paths before running.
    import argparse
    from torchvision.transforms.functional import to_pil_image
    txt_root = 'xxx'
    cfg = argparse.Namespace(
        txt_root=txt_root,
        text_tokenizer_max_length=512,
        resolution=1024,
        text_encoder_path="xxx",
        center_crop=True,
        dataloader_num_workers=0,
        null_text_ratio=0.1,
        train_batch_size=16,
        seed=0,
        aspect_ratio_type='mar_1024',
        # BUG FIX: Text2ImageLoraDataSet.__init__ reads cfg.repeats and the
        # prompt_template_encode_* fields; they were missing here, so the
        # script crashed with AttributeError before loading a single sample.
        repeats=1,
        prompt_template_encode_prefix='',   # placeholder — match training config
        prompt_template_encode_suffix='',   # placeholder — match training config
        prompt_template_encode_start_idx=0,
        prompt_template_encode_end_idx=0,
        revision=None)
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.text_encoder_path, trust_remote_code=True)
    # BUG FIX: the Namespace defines txt_root, not csv_root (was AttributeError).
    data_loader = build_dataloader(cfg, cfg.txt_root, tokenizer, cfg.resolution)
    _oroot = f'./debug_data_example_show'
    os.makedirs(_oroot, exist_ok=True)
    cnt = 0
    for epoch in range(1):
        print(f"Start, epoch {epoch}!!!")
        for i_batch, batch in enumerate(data_loader):
            print(batch['attention_mask'].shape)
            print(batch['images'].shape)
            batch_prompts = batch['prompts']
            for idx, per_img in enumerate(batch['images']):
                # Invert T.Normalize([.5], [.5]): x -> x * 0.5 + 0.5.
                re_transforms = T.Compose([
                    T.Normalize(mean=[-0.5/0.5], std=[1.0/0.5])
                ])
                prompt = batch_prompts[idx]
                img = to_pil_image(re_transforms(per_img))
                # Truncate the prompt so it stays a legal filename component.
                prompt = prompt[:min(30, len(prompt))]
                oname = _oroot + f'/{str(i_batch)}_{str(idx)}_{prompt}.png'
                img.save(oname)
            if cnt > 100:
                break
            cnt += 1
|