import os
import errno
import numpy as np
from torch.nn import init
import torch
import torch.nn as nn
import arabic_reshaper
from bidi.algorithm import get_display
from PIL import Image, ImageDraw, ImageFont
from copy import deepcopy
import skimage.transform

from miscc.config import cfg


# For visualization ################################################
# Background colour of each word column on the caption canvas (cycled when a
# caption has more words than entries).
COLOR_DIC = {0: [128, 64, 128], 1: [244, 35, 232],
             2: [70, 70, 70], 3: [102, 102, 156],
             4: [190, 153, 153], 5: [153, 153, 153],
             6: [250, 170, 30], 7: [220, 220, 0],
             8: [107, 142, 35], 9: [152, 251, 152],
             10: [70, 130, 180], 11: [220, 20, 60],
             12: [255, 0, 0], 13: [0, 0, 142],
             14: [119, 11, 32], 15: [0, 60, 100],
             16: [0, 80, 100], 17: [0, 0, 230],
             18: [0, 0, 70], 19: [0, 0, 0]}
# Height in pixels of one caption row on the text canvas.
FONT_MAX = 50
# Default TrueType font used to draw (Arabic) caption words. Override it via
# drawCaption(font_path=...) when running outside this Google Drive layout.
FONT_PATH = ('/content/drive/MyDrive/Graduation Project/GANS/AD-GAN/code/'
             'miscc/sahel-font/dist/Sahel-Light.ttf')


def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2,
                font_path=FONT_PATH, font_size=18):
    """Draw tokenised captions onto a canvas, one caption per FONT_MAX-px row.

    Args:
        convas: uint8 image array used as the background (the callers in this
            module pass an ``np.ones(..., dtype=np.uint8)`` canvas).
        captions: tensor of word indices with one caption per row; index 0
            terminates a caption.
        ixtoword: mapping from word index to word string.
        vis_size: width in pixels of one word cell.
        off1: number of cells to skip before the first word.
        off2: extra pixels added to each cell width.
        font_path: TrueType font file passed to ``ImageFont.truetype``.
        font_size: font size passed to ``ImageFont.truetype``.

    Returns:
        ``(img_txt, sentence_list)``: the PIL image with captions drawn, and
        for every caption the list of (cp1256-filtered) words.
    """
    num = captions.size(0)
    img_txt = Image.fromarray(convas)
    fnt = ImageFont.truetype(font_path, font_size)
    d = ImageDraw.Draw(img_txt)

    sentence_list = []
    for i in range(num):
        cap = captions[i].data.cpu().numpy()
        sentence = []
        for j in range(len(cap)):
            if cap[j] == 0:
                break
            # Drop characters that the cp1256 (Windows Arabic) codec cannot
            # represent.
            word = ixtoword[cap[j]].encode('cp1256', 'ignore').decode('cp1256')
            # Join letters into their contextual forms, then reorder the
            # right-to-left text into visual order for PIL.
            drawed_word = get_display(arabic_reshaper.reshape(word))
            # NOTE(review): [:6] is applied after bidi reordering, so for RTL
            # words it keeps the visually-leftmost characters; confirm intent.
            d.text(((j + off1) * (vis_size + off2), i * FONT_MAX),
                   '%d:%s' % (j, drawed_word[:6]),
                   font=fnt, fill=(255, 255, 255, 255))
            sentence.append(word)
        sentence_list.append(sentence)
    return img_txt, sentence_list


def _to_display_images(imgs, vis_size):
    """Resize an N x C x H x W batch to vis_size^2 and return N x H x W x C.

    Values are mapped with (x + 1) / 2 * 255, i.e. [-1, 1] -> [0, 255]; the
    [-1, 1] input range is assumed from the original comments.
    """
    imgs = nn.functional.interpolate(imgs, size=(vis_size, vis_size),
                                     mode='bilinear', align_corners=False)
    imgs = imgs.add(1).div(2).mul(255)
    # b x c x h x w --> b x h x w x c; detach/cpu so GPU tensors also work.
    return np.transpose(imgs.detach().cpu().numpy(), (0, 2, 3, 1))


def _expand_map(one_map, upscale):
    """Upsample an H x W x C attention map by ``upscale`` (Gaussian sigma=20)."""
    try:
        return skimage.transform.pyramid_expand(
            one_map, sigma=20, upscale=upscale, channel_axis=-1)
    except TypeError:
        # scikit-image < 0.19 has no `channel_axis`; `multichannel` was
        # removed in 0.21, so only fall back to it for old versions.
        return skimage.transform.pyramid_expand(
            one_map, sigma=20, upscale=upscale, multichannel=True)


def _normalize(x, lo, hi):
    """Linearly map [lo, hi] to [0, 1]; a degenerate range maps to zeros.

    Dividing by ``hi - lo`` unconditionally yields NaN for constant maps,
    which then turn into garbage pixels after the uint8 cast.
    """
    span = hi - lo
    if span > 0:
        return (x - lo) / span
    return np.zeros_like(x)


def _overlay(img, one_map, vis_size, alpha):
    """Paste ``one_map`` over ``img`` with constant opacity ``alpha`` (0-255).

    Returns a vis_size x vis_size x 3 uint8 array.
    """
    PIL_im = Image.fromarray(np.uint8(img))
    PIL_att = Image.fromarray(np.uint8(one_map))
    merged = Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0))
    mask = Image.new('L', (vis_size, vis_size), alpha)
    merged.paste(PIL_im, (0, 0))
    merged.paste(PIL_att, (0, 0), mask)
    return np.array(merged)[:, :, :3]


def build_super_images(real_imgs, captions, ixtoword,
                       attn_maps, att_sze, lr_imgs=None,
                       batch_size=cfg.TRAIN.BATCH_SIZE,
                       max_word_num=cfg.TEXT.WORDS_NUM):
    """Build a word-attention visualisation grid for up to 8 samples.

    Each sample contributes three strips stacked vertically:
      1. its caption (FONT_MAX px high),
      2. the low-res image (or the image itself) followed by the raw maps,
      3. the image followed by the maps blended over it.
    The first map of a row is the per-pixel maximum over all word maps, and
    all maps of a row share one normalisation range.

    Args:
        real_imgs: N x C x H x W image tensor (assumed in [-1, 1]).
        captions: caption index tensor, drawn with ``drawCaption``.
        ixtoword: word-index -> word mapping.
        attn_maps: per-sample attention tensors, each reshaped here to
            (n_words, 1, att_sze, att_sze).
        att_sze: attention map side length; 17 selects a 17 * 16 = 272 px
            display size, otherwise the image height is used.
        lr_imgs: optional images shown at the start of strip 2.
        batch_size: number of caption rows allocated on the text canvas.
        max_word_num: number of word columns.

    Returns:
        ``(uint8 image array, sentences)``, or ``None`` if a caption strip's
        width does not match the image rows.
    """
    nvis = 8
    real_imgs = real_imgs[:nvis]
    if lr_imgs is not None:
        lr_imgs = lr_imgs[:nvis]
    if att_sze == 17:
        vis_size = att_sze * 16
    else:
        vis_size = real_imgs.size(2)

    text_convas = np.ones([batch_size * FONT_MAX,
                           (max_word_num + 2) * (vis_size + 2), 3],
                          dtype=np.uint8)
    for i in range(max_word_num):
        istart = (i + 2) * (vis_size + 2)
        iend = (i + 3) * (vis_size + 2)
        # Cycle the palette so more than len(COLOR_DIC) words don't KeyError.
        text_convas[:, istart:iend, :] = COLOR_DIC[i % len(COLOR_DIC)]

    real_imgs = _to_display_images(real_imgs, vis_size)
    middle_pad = np.zeros([vis_size, 2, 3])
    post_pad = np.zeros([vis_size, vis_size, 3])
    if lr_imgs is not None:
        lr_imgs = _to_display_images(lr_imgs, vis_size)

    seq_len = max_word_num
    img_set = []
    # Never index past the images / attention maps actually supplied.
    num = min(nvis, len(attn_maps), real_imgs.shape[0])

    text_map, sentences = drawCaption(text_convas, captions, ixtoword, vis_size)
    text_map = np.asarray(text_map).astype(np.uint8)

    bUpdate = 1
    for i in range(num):
        attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze)
        # Prepend the per-pixel max over words: 1 x (1 + n_words) x h x w.
        attn_max = attn.max(dim=1, keepdim=True)
        attn = torch.cat([attn_max[0], attn], 1)
        # One map per leading index, replicated to 3 channels.
        attn = attn.view(-1, 1, att_sze, att_sze)
        attn = attn.repeat(1, 3, 1, 1).data.numpy()
        # n x c x h x w --> n x h x w x c
        attn = np.transpose(attn, (0, 2, 3, 1))
        num_attn = attn.shape[0]

        img = real_imgs[i]
        lrI = img if lr_imgs is None else lr_imgs[i]
        row = [lrI, middle_pad]
        row_merge = [img, middle_pad]

        row_beforeNorm = []
        for j in range(num_attn):
            one_map = attn[j]
            if (vis_size // att_sze) > 1:
                one_map = _expand_map(one_map, vis_size // att_sze)
            row_beforeNorm.append(one_map)
        # Shared range across the row so map intensities are comparable.
        minVglobal = min(m.min() for m in row_beforeNorm)
        maxVglobal = max(m.max() for m in row_beforeNorm)

        for j in range(seq_len + 1):
            if j < num_attn:
                one_map = _normalize(row_beforeNorm[j],
                                     minVglobal, maxVglobal) * 255
                merged = _overlay(img, one_map, vis_size, 210)
            else:
                # Blank cells for word slots without an attention map.
                one_map = post_pad
                merged = post_pad
            row.append(one_map)
            row.append(middle_pad)
            row_merge.append(merged)
            row_merge.append(middle_pad)
        row = np.concatenate(row, 1)
        row_merge = np.concatenate(row_merge, 1)
        txt = text_map[i * FONT_MAX: (i + 1) * FONT_MAX]
        if txt.shape[1] != row.shape[1]:
            print('txt', txt.shape, 'row', row.shape)
            bUpdate = 0
            break
        row = np.concatenate([txt, row, row_merge], 0)
        img_set.append(row)

    if bUpdate:
        img_set = np.concatenate(img_set, 0)
        img_set = img_set.astype(np.uint8)
        return img_set, sentences
    else:
        return None


def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
                        attn_maps, att_sze, vis_size=256, topK=5):
    """Build a grid showing each sample's topK most confident word maps.

    Each sample contributes its caption strip over a strip of the image with
    per-word attention blended on top. Word maps are thresholded at
    2 / caption_length, individually normalised, and ranked by the summed
    attention above twice that threshold (highest first).

    Args:
        real_imgs: N x C x H x W image tensor (assumed in [-1, 1]).
        captions: caption index tensor, drawn with ``drawCaption``.
        cap_lens: caption length per sample.
        ixtoword: word-index -> word mapping.
        attn_maps: per-sample attention tensors, each reshaped here to
            (n_words, 1, att_sze, att_sze).
        att_sze: attention map side length.
        vis_size: display size in pixels of each image / map cell.
        topK: number of highest-ranked words kept per sample.

    Returns:
        ``(uint8 image array, sentences)``, or ``None`` if a caption strip's
        width does not match the map strip.
    """
    batch_size = real_imgs.size(0)
    max_word_num = np.max(cap_lens)
    text_convas = np.ones([batch_size * FONT_MAX,
                           max_word_num * (vis_size + 2), 3],
                          dtype=np.uint8)

    real_imgs = _to_display_images(real_imgs, vis_size)
    middle_pad = np.zeros([vis_size, 2, 3])

    img_set = []
    num = len(attn_maps)

    text_map, sentences = drawCaption(text_convas, captions, ixtoword,
                                      vis_size, off1=0)
    text_map = np.asarray(text_map).astype(np.uint8)

    bUpdate = 1
    for i in range(num):
        # One map per word, replicated to 3 channels: n x h x w x 3.
        attn = attn_maps[i].cpu().view(-1, 1, att_sze, att_sze)
        attn = attn.repeat(1, 3, 1, 1).data.numpy()
        attn = np.transpose(attn, (0, 2, 3, 1))
        num_attn = cap_lens[i]
        thresh = 2. / float(num_attn)
        img = real_imgs[i]

        row = []
        row_merge = []
        row_txt = []
        row_beforeNorm = []
        conf_score = []
        for j in range(num_attn):
            one_map = attn[j]
            mask0 = one_map > (2. * thresh)
            conf_score.append(np.sum(one_map * mask0))
            one_map = one_map * (one_map > thresh)
            if (vis_size // att_sze) > 1:
                one_map = _expand_map(one_map, vis_size // att_sze)
            # Thresholding can zero a whole map; _normalize avoids 0/0 NaN.
            one_map = _normalize(one_map, one_map.min(), one_map.max())
            row_beforeNorm.append(one_map)
        sorted_indices = np.argsort(conf_score)[::-1]

        for j in range(num_attn):
            one_map = row_beforeNorm[j] * 255
            merged = _overlay(img, one_map, vis_size, 180)
            row.append(np.concatenate([one_map, middle_pad], 1))
            row_merge.append(np.concatenate([merged, middle_pad], 1))
            txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX,
                           j * (vis_size + 2):(j + 1) * (vis_size + 2), :]
            row_txt.append(txt)

        # Reorder by confidence (descending) and keep the topK words.
        top = sorted_indices[:topK]
        row = np.concatenate([row[idx] for idx in top], 1)
        row_merge = np.concatenate([row_merge[idx] for idx in top], 1)
        txt = np.concatenate([row_txt[idx] for idx in top], 1)
        if txt.shape[1] != row.shape[1]:
            # row_merge is an array here; the old code read .shape off a list.
            print('Warnings: txt', txt.shape, 'row', row.shape,
                  'row_merge_new', row_merge.shape)
            bUpdate = 0
            break
        row = np.concatenate([txt, row_merge], 0)
        img_set.append(row)

    if bUpdate:
        img_set = np.concatenate(img_set, 0)
        img_set = img_set.astype(np.uint8)
        return img_set, sentences
    else:
        return None


####################################################################
def weights_init(m):
    """Initialise one module's parameters; use as ``model.apply(weights_init)``.

    Matching is by substring of the class name:
      * ``Conv*``: orthogonal weight (gain 1.0); bias left untouched.
      * ``BatchNorm*``: weight ~ N(1.0, 0.02), bias = 0.
      * ``Linear``: orthogonal weight (gain 1.0), bias = 0 if present.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        nn.init.orthogonal_(m.weight.data, 1.0)
    elif 'BatchNorm' in classname:
        # affine=False BatchNorm layers have no weight/bias tensors.
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif 'Linear' in classname:
        nn.init.orthogonal_(m.weight.data, 1.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)


def load_params(model, new_param):
    """Copy tensors from ``new_param`` into ``model``'s parameters, in order."""
    for p, new_p in zip(model.parameters(), new_param):
        p.data.copy_(new_p)


def copy_G_params(model):
    """Return a deep copy of the data tensors of ``model``'s parameters."""
    flatten = deepcopy([p.data for p in model.parameters()])
    return flatten


def mkdir_p(path):
    """Create ``path`` and missing parents; a no-op if it is already a directory.

    Like ``mkdir -p``: still raises if ``path`` exists but is not a directory,
    or on any other OS error.
    """
    os.makedirs(path, exist_ok=True)