Spaces:
Runtime error
Runtime error
| from typing import Dict | |
| import numpy as np | |
| from omegaconf import DictConfig, ListConfig | |
| import torch | |
| from torch.utils.data import Dataset | |
| from pathlib import Path | |
| import json | |
| from PIL import Image | |
| from torchvision import transforms | |
| from einops import rearrange | |
| from ldm.util import instantiate_from_config | |
| # from datasets import load_dataset | |
| import os | |
| from collections import defaultdict | |
| from glob import glob | |
| import re | |
| from bisect import bisect_left, bisect_right | |
| import albumentations, cv2 | |
| import time | |
class SynWhiteBoardDataset(Dataset):
    """Synthetic text-image dataset that pairs rendered images with captions.

    Sample metadata lives in TSV files inside ``caption_folder``. A JSON index
    (``tsv_info_file``) maps each TSV file name to the list of byte offsets at
    which its lines start, so a single sample can be fetched with one
    ``seek`` + ``readline`` instead of scanning the file.

    Each TSV line holds five tab-separated fields: the text content, the font
    file name, the row arrangement (``"<first>_<second>"``), the alignment and
    the image file name (relative to ``img_folder/corpus_type``).

    ``__getitem__`` returns a dict with the processed image under
    ``first_stage_key`` and the generated caption under ``cond_stage_key``.
    """

    def __init__(self,
                 img_folder,
                 caption_folder,
                 tsv_info_file,
                 corpus_type="all_4gram",
                 image_transforms=None,
                 first_stage_key="jpg",
                 cond_stage_key="txt",
                 postprocess=None,
                 ext="png",
                 img_class="whiteboard",
                 caption_type="regular",  # "simple" or "regular" or "full"
                 lower_case=False,
                 max_num=None,
                 image_size=512,
                 do_padding=True,
                 explict_arrangement=False,
                 ) -> None:
        """Load the TSV offset index and build the image pipeline.

        Args:
            img_folder: Root image directory; images are read from
                ``img_folder/corpus_type``.
            caption_folder: Directory containing the caption TSV files.
            tsv_info_file: JSON file mapping TSV file name -> list of line
                byte offsets.
            corpus_type: Corpus name; used as image sub-folder and as the
                prefix of the TSV file names.
            image_transforms: Extra transforms applied before the built-in
                tensor conversion. Either a ``ListConfig`` of instantiable
                configs or an iterable of transform objects. ``None`` means
                no extra transforms.
            first_stage_key: Output dict key for the image.
            cond_stage_key: Output dict key for the caption.
            postprocess: Optional callable (or ``DictConfig`` describing one)
                applied to the output dict.
            ext: Image extension (stored, not used for lookup).
            img_class: Noun used in the caption ("A <img_class> that says").
            caption_type: ``"simple"``, ``"regular"`` or ``"full"``.
            lower_case: Lower-case the font weight/width/style words.
            max_num: Optional cap on the dataset length.
            image_size: Target side length for rescaling and padding.
            do_padding: Rescale the longest side to ``image_size`` and pad to
                a square with white before the transforms.
            explict_arrangement: For ``"full"`` captions, quote the actual
                words of each row instead of only counting them.
        """
        self.root_dir = os.path.join(Path(img_folder), corpus_type)
        self.caption_folder = caption_folder
        assert os.path.exists(self.caption_folder) and os.path.exists(tsv_info_file)
        with open(tsv_info_file, "r") as f:
            tsv_info_dict = json.load(f)

        # rank_list[i] is the cumulative number of samples in TSV files
        # 0..i, which lets __getitem__ map a flat index to a file by bisection.
        total_num = 0
        rank_list = []
        for offsets in tsv_info_dict.values():
            total_num += len(offsets)
            rank_list.append(total_num)
        self.rank_list = rank_list
        # Never advertise more samples than the index actually contains;
        # indices past the real total cannot be resolved to a TSV file.
        self.total_num = total_num if max_num is None else min(max_num, total_num)
        self.tsv_info_dict = tsv_info_dict
        self.corpus_type = corpus_type
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key

        # postprocess
        if isinstance(postprocess, DictConfig):
            postprocess = instantiate_from_config(postprocess)
        self.postprocess = postprocess

        # image transform
        # Always work on a fresh list: extending the argument in place would
        # mutate the caller's list (or a shared default) on every construction.
        if isinstance(image_transforms, ListConfig):
            tform_steps = [instantiate_from_config(tt) for tt in image_transforms]
        elif image_transforms is None:
            tform_steps = []
        else:
            tform_steps = list(image_transforms)
        tform_steps.extend([transforms.ToTensor(),  # to be checked
                            transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        self.tform = transforms.Compose(tform_steps)

        self.ext = ext
        # TSV files are named "<corpus>_<rank>_<num_rank>.tsv"; recover
        # num_rank from the first file name in the index.
        first_tsv_name = list(tsv_info_dict)[0]
        self.num_rank = int(first_tsv_name.split("_")[-1].split(".")[0])
        self.img_class = img_class
        self.caption_type = caption_type
        self.lower_case = lower_case
        self.do_padding = do_padding
        self.image_rescaler = albumentations.LongestMaxSize(max_size=image_size, interpolation=cv2.INTER_AREA)
        self.image_size = image_size
        self.pad = albumentations.PadIfNeeded(min_height=self.image_size, min_width=self.image_size,
                                              border_mode=cv2.BORDER_CONSTANT, value=(255, 255, 255),
                                              )
        self.explict_arrangement = explict_arrangement

    def __len__(self):
        """Return the number of samples exposed by the dataset."""
        return self.total_num

    def __getitem__(self, index):
        """Return ``{first_stage_key: image, cond_stage_key: caption}``.

        If the image cannot be opened, a different, randomly chosen sample is
        returned instead so that one bad file does not abort an epoch.
        """
        data = {}
        txt_content, font_file, arrange_, _align, imagename = self._read_caption_fields(index)

        filename = os.path.join(self.root_dir, imagename)
        try:
            im = Image.open(filename)
        except Exception as err:
            # Best-effort fallback (kept from the original behavior), but do
            # not swallow KeyboardInterrupt/SystemExit and do not stay silent.
            print("Failed to open image {}: {}".format(filename, err))
            return self.__getitem__(np.random.choice(self.__len__()))
        data[self.first_stage_key] = self.process_im(im)

        data[self.cond_stage_key] = self._build_caption(txt_content, font_file, arrange_)

        if self.postprocess is not None:
            data = self.postprocess(data)
        return data

    def _locate_caption(self, index):
        """Map a flat sample index to ``(tsv_name, index_inside_that_tsv)``."""
        rank = bisect_right(self.rank_list, index)
        index_in_tsv = index - (self.rank_list[rank - 1] if rank > 0 else 0)
        tsv_name = "{}_{}_{}.tsv".format(
            self.corpus_type, rank, self.num_rank
        )
        return tsv_name, index_in_tsv

    def _read_caption_fields(self, index):
        """Read the TSV line for ``index`` and return its five fields."""
        tsv_name, index_in_tsv = self._locate_caption(index)
        with open(os.path.join(self.caption_folder, tsv_name), "r") as f:
            # Jump straight to the line using the pre-computed byte offset.
            f.seek(self.tsv_info_dict[tsv_name][index_in_tsv])
            caption_info = f.readline().strip()
        info_list = caption_info.split("\t")
        assert len(info_list) == 5
        return info_list

    @staticmethod
    def _parse_font_file(font_file):
        """Split a font file name into ``(name, weight, width, style)``.

        File names are expected to look like ``"<Name>-<Type>.<ext>"`` with a
        three-letter extension, where ``<Type>`` is a CamelCase concatenation
        such as ``"BoldItalic"``. Unknown parts are returned as ``""``.
        """
        font_weight = ""
        font_style = ""
        font_width = ""
        font_file = re.sub(r'\[.*?\]', "", font_file)  # remove []
        # Drop the 4-character extension (".ttf" / ".otf") and split the
        # family name from the type suffix.
        font_list = font_file[:-4].split("-")
        if len(font_list) > 2:
            print("font file name outlier: {}".format(font_file))
            font_list = [
                "-".join(font_list[:-1]),
                font_list[-1]
            ]
        # str.split never returns an empty list, so after the merge above
        # font_list has exactly one or two elements.
        if len(font_list) == 2:
            font_name, font_type = font_list
            if font_type == "VF":
                font_style = "VF"
            else:
                font_tlist = re.findall("[A-Z][a-z]*", font_type)
                if "Regular" in font_tlist:
                    font_weight = "Regular"
                    font_style = "Regular"
                else:
                    # NOTE(review): width/weight handling is assumed to sit
                    # inside this non-"Regular" branch — confirm against the
                    # original, indented source.
                    # style: only the first matching keyword is consumed
                    for style_word in ("Italic", "Oblique", "Cursive", "Book"):
                        if style_word in font_tlist:
                            font_style = style_word
                            font_tlist.remove(style_word)
                            break
                    # width
                    if "Condensed" in font_tlist:
                        font_width = "Condensed"
                        font_tlist.remove("Condensed")
                    # weight: whatever tokens remain
                    if font_tlist:
                        font_weight = " ".join(font_tlist)
        else:
            font_name = font_list[0]
            if "Italic" in font_name:
                font_name = font_name.replace("Italic", "")
                font_style = "Italic"
            if "Bold" in font_name:
                font_name = font_name.replace("Bold", "")
                font_weight = "Bold"

        # Width encoded in the family name itself (e.g. "RobotoCondensed").
        if "Condensed" in font_name:
            if "Extra" in font_name or "Semi" in font_name or "Ultra" in font_name:
                font_name_list = re.findall("[A-Z][a-z]*", font_name)
                font_width = " ".join(font_name_list[-2:])
                font_name = "".join(font_name_list[:-2])
            else:
                # Remove the literal suffix. str.rstrip("Condensed") would
                # strip any trailing run of those *characters* and corrupt
                # names such as "RobotoCondensed" -> "Robot".
                if font_name.endswith("Condensed"):
                    font_name = font_name[:-len("Condensed")]
                font_width = "Condensed"
        return font_name, font_weight, font_width, font_style

    def _build_caption(self, txt_content, font_file, arrange_):
        """Compose the caption for one sample according to ``caption_type``."""
        if self.caption_type == "simple":
            return 'A {} that says {}'.format(
                self.img_class, txt_content,
            )

        font_name, font_weight, font_width, font_style = self._parse_font_file(font_file)
        caption = 'A {} that says {} written in the font of {}'.format(
            self.img_class, txt_content, font_name
        )
        # The first attribute is introduced with "with", later ones with "and".
        addition_cond = 0
        for value, template in (
            (font_weight, " {} {} stroke weight"),
            (font_width, " {} {} font width"),
            (font_style, " {} {} font style"),
        ):
            if value != "":
                value = value.lower() if self.lower_case else value
                caption += template.format(
                    "with" if addition_cond == 0 else "and", value
                )
                addition_cond += 1

        if self.caption_type == "full":
            words = txt_content.strip('"').split(" ")
            assert len(words) == 4
            # arrange_ is "<words in first row>_<words in second row>".
            frn, srn = arrange_.split("_")
            frn, srn = int(frn), int(srn)
            assert (frn + srn == 4)
            if frn == 0 or srn == 0:
                caption += '. All the words are written in the same row.'
            elif self.explict_arrangement:
                caption += '. "{}" is written in the first row while "{}" is in the second row.'.format(
                    ' '.join(words[:frn]),
                    ' '.join(words[frn:])
                )
            else:
                caption += '. The first {} written in the first row while the {} in the second row.'.format(
                    "{} words are".format(frn) if frn > 1 else "word is",
                    "other {} words are".format(srn) if srn > 1 else "last word is",
                )
        return caption

    def process_im(self, im):
        """Convert to RGB, optionally rescale+pad, then apply ``self.tform``."""
        im = im.convert("RGB")
        if self.do_padding:
            im = self.padding_image(im)
        return self.tform(im)

    def padding_image(self, im):
        """Rescale the longest side to ``image_size`` and pad to a square.

        Returns a ``uint8`` numpy array; padding uses white (255, 255, 255).
        """
        # resize
        im = np.array(im).astype(np.uint8)
        im_rescaled = self.image_rescaler(image=im)["image"]
        # padding
        im_padded = self.pad(image=im_rescaled)["image"]
        return im_padded