Instructions to use SrinivasMudiraj/Baaz with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use SrinivasMudiraj/Baaz with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("SrinivasMudiraj/Baaz", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import errno | |
| import numpy as np | |
| from torch.nn import init | |
| import torch | |
| import torch.nn as nn | |
| import arabic_reshaper | |
| from bidi.algorithm import get_display | |
| from PIL import Image, ImageDraw, ImageFont | |
| from copy import deepcopy | |
| import skimage.transform | |
| from miscc.config import cfg | |
# For visualization ################################################
# RGB palette keyed by word position (0..19). build_super_images uses
# COLOR_DIC[i] as the background fill of the i-th word column of the
# caption strip.
COLOR_DIC = dict(enumerate([
    [128, 64, 128],
    [244, 35, 232],
    [70, 70, 70],
    [102, 102, 156],
    [190, 153, 153],
    [153, 153, 153],
    [250, 170, 30],
    [220, 220, 0],
    [107, 142, 35],
    [152, 251, 152],
    [70, 130, 180],
    [220, 20, 60],
    [255, 0, 0],
    [0, 0, 142],
    [119, 11, 32],
    [0, 60, 100],
    [0, 80, 100],
    [0, 0, 230],
    [0, 0, 70],
    [0, 0, 0],
]))
# Height in pixels reserved for each caption row in the text canvas.
FONT_MAX = 50
def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2,
                font_path='/content/drive/MyDrive/Graduation Project/GANS/AD-GAN/code/miscc/sahel-font/dist/Sahel-Light.ttf',
                font_size=18):
    """Render every caption onto the text canvas, one row per caption.

    Args:
        convas: uint8 image array used as the drawing surface; it is wrapped
            with ``Image.fromarray``.
        captions: batch of token-index sequences exposing ``.size(0)`` and
            per-row ``.data.cpu().numpy()`` (presumably a torch tensor of
            shape (batch, seq_len) -- confirm against callers).
        ixtoword: mapping from token index to word string.
        vis_size: width in pixels of one visualisation cell; word ``j`` is
            drawn at x = ``(j + off1) * (vis_size + off2)``.
        off1: number of leading cells to skip before the first word.
        off2: padding in pixels added to each cell width.
        font_path: TrueType font file used for the text. Defaults to the
            path that was previously hard-coded, so existing callers are
            unaffected.
        font_size: point size passed to ``ImageFont.truetype``.

    Returns:
        ``(img_txt, sentence_list)``: the PIL image with the text drawn on
        it, and a list holding, per caption, the list of its words.
    """
    num = captions.size(0)
    img_txt = Image.fromarray(convas)
    fnt = ImageFont.truetype(font_path, font_size)
    # get a drawing context
    d = ImageDraw.Draw(img_txt)
    sentence_list = []
    for i in range(num):
        cap = captions[i].data.cpu().numpy()
        sentence = []
        for j, token in enumerate(cap):
            # Index 0 terminates the caption (end/padding token).
            if token == 0:
                break
            # Round-trip through cp1256, silently dropping any character
            # that this code page cannot represent.
            word = ixtoword[token].encode('cp1256', 'ignore').decode('cp1256')
            drawed_word = arabic_reshaper.reshape(word)  # correct its shape
            drawed_word = get_display(drawed_word)  # correct its direction
            # Only the first 6 characters are drawn so the label fits a cell.
            d.text(((j + off1) * (vis_size + off2), i * FONT_MAX),
                   '%d:%s' % (j, drawed_word[:6]),
                   font=fnt, fill=(255, 255, 255, 255))
            sentence.append(word)
        sentence_list.append(sentence)
    return img_txt, sentence_list
def build_super_images(real_imgs, captions, ixtoword,
                       attn_maps, att_sze, lr_imgs=None,
                       batch_size=cfg.TRAIN.BATCH_SIZE,
                       max_word_num=cfg.TEXT.WORDS_NUM):
    """Build one large image showing captions, images and attention maps.

    For each of the first (at most) 8 samples three strips are stacked
    vertically: the caption text, a row of normalised attention maps, and a
    row with each attention map pasted over the real image.

    Args:
        real_imgs: image batch, b x c x h x w, rescaled here as if its values
            were in [-1, 1] (presumably generator output -- confirm).
        captions: token-index batch forwarded to ``drawCaption``.
        ixtoword: mapping from token index to word, forwarded to
            ``drawCaption``.
        attn_maps: per-sample attention maps; ``attn_maps[i]`` is viewed as
            ``(1, -1, att_sze, att_sze)``.
        att_sze: side length of one attention map.
        lr_imgs: optional second image batch shown in the first cell of the
            attention row; when ``None`` the real image is used instead.
        batch_size: number of caption rows allocated in the text canvas.
        max_word_num: number of word cells per row.

    Returns:
        ``(img_set, sentences)`` where ``img_set`` is a uint8 array, or a
        bare ``None`` when the text strip and the image row end up with
        different widths (kept as-is for interface compatibility, so
        callers must check before unpacking).
    """
    nvis = 8
    real_imgs = real_imgs[:nvis]
    if lr_imgs is not None:
        lr_imgs = lr_imgs[:nvis]
    if att_sze == 17:
        vis_size = att_sze * 16
    else:
        vis_size = real_imgs.size(2)
    cell = vis_size + 2  # one visualisation cell plus its 2px separator
    upscale = vis_size // att_sze

    def to_display(imgs):
        # Resize to the cell size and convert to a b x h x w x c array.
        imgs = nn.functional.interpolate(imgs, size=(vis_size, vis_size),
                                         mode='bilinear',
                                         align_corners=False)
        # [-1, 1] --> [0, 255]
        imgs.add_(1).div_(2).mul_(255)
        # .cpu() so that batches living on an accelerator also work.
        # b x c x h x w --> b x h x w x c
        return np.transpose(imgs.data.cpu().numpy(), (0, 2, 3, 1))

    text_convas = np.ones([batch_size * FONT_MAX,
                           (max_word_num + 2) * cell, 3],
                          dtype=np.uint8)
    for i in range(max_word_num):
        istart = (i + 2) * cell
        iend = (i + 3) * cell
        # Cycle through the palette so that more word cells than palette
        # entries no longer raise KeyError.
        text_convas[:, istart:iend, :] = COLOR_DIC[i % len(COLOR_DIC)]

    real_imgs = to_display(real_imgs)
    pad_sze = real_imgs.shape
    middle_pad = np.zeros([pad_sze[2], 2, 3])
    post_pad = np.zeros([pad_sze[1], pad_sze[2], 3])
    if lr_imgs is not None:
        lr_imgs = to_display(lr_imgs)

    seq_len = max_word_num
    img_set = []
    # Never index past the samples that are actually available.
    num = min(nvis, real_imgs.shape[0], len(attn_maps))

    text_map, sentences = \
        drawCaption(text_convas, captions, ixtoword, vis_size)
    text_map = np.asarray(text_map).astype(np.uint8)

    bUpdate = 1
    for i in range(num):
        attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze)
        # Prepend the element-wise maximum over all maps as an extra map.
        attn_max = attn.max(dim=1, keepdim=True)
        attn = torch.cat([attn_max[0], attn], 1)
        # One single-channel map per entry, replicated to 3 channels.
        attn = attn.view(-1, 1, att_sze, att_sze)
        attn = attn.repeat(1, 3, 1, 1).data.numpy()
        # n x c x h x w --> n x h x w x c
        attn = np.transpose(attn, (0, 2, 3, 1))
        num_attn = attn.shape[0]

        img = real_imgs[i]
        lrI = img if lr_imgs is None else lr_imgs[i]
        row = [lrI, middle_pad]
        row_merge = [img, middle_pad]
        row_beforeNorm = []
        minVglobal, maxVglobal = 1, 0
        for j in range(num_attn):
            one_map = attn[j]
            if upscale > 1:
                try:
                    # Newer scikit-image releases take ``channel_axis``.
                    one_map = skimage.transform.pyramid_expand(
                        one_map, sigma=20, upscale=upscale,
                        channel_axis=-1)
                except TypeError:
                    # Older releases only know ``multichannel``.
                    one_map = skimage.transform.pyramid_expand(
                        one_map, sigma=20, upscale=upscale,
                        multichannel=True)
            row_beforeNorm.append(one_map)
            minV = one_map.min()
            maxV = one_map.max()
            if minVglobal > minV:
                minVglobal = minV
            if maxVglobal < maxV:
                maxVglobal = maxV
        value_range = maxVglobal - minVglobal
        for j in range(seq_len + 1):
            if j < num_attn:
                one_map = row_beforeNorm[j]
                if value_range > 0:
                    one_map = (one_map - minVglobal) / value_range
                else:
                    # Constant attention: avoid 0/0 and show a black map.
                    one_map = np.zeros_like(one_map)
                one_map *= 255
                # Paste the attention map over the image using a constant
                # mask value of 210.
                PIL_im = Image.fromarray(np.uint8(img))
                PIL_att = Image.fromarray(np.uint8(one_map))
                merged = \
                    Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0))
                mask = Image.new('L', (vis_size, vis_size), (210))
                merged.paste(PIL_im, (0, 0))
                merged.paste(PIL_att, (0, 0), mask)
                merged = np.array(merged)[:, :, :3]
            else:
                # Fewer maps than word cells: fill the rest with black.
                one_map = post_pad
                merged = post_pad
            row.append(one_map)
            row.append(middle_pad)
            row_merge.append(merged)
            row_merge.append(middle_pad)
        row = np.concatenate(row, 1)
        row_merge = np.concatenate(row_merge, 1)
        txt = text_map[i * FONT_MAX: (i + 1) * FONT_MAX]
        if txt.shape[1] != row.shape[1]:
            print('txt', txt.shape, 'row', row.shape)
            bUpdate = 0
            break
        row = np.concatenate([txt, row, row_merge], 0)
        img_set.append(row)
    if bUpdate:
        img_set = np.concatenate(img_set, 0)
        img_set = img_set.astype(np.uint8)
        return img_set, sentences
    else:
        return None
def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
                        attn_maps, att_sze, vis_size=256, topK=5):
    """Build an image showing, per sample, the ``topK`` strongest word maps.

    For every sample each word's attention map is thresholded, scored,
    normalised to [0, 255] and pasted over the real image. The words are
    then ordered by decreasing score and the first ``topK`` are kept,
    together with the matching slice of the caption text strip.

    Args:
        real_imgs: image batch, b x c x h x w, rescaled here as if its values
            were in [-1, 1] (presumably generator output -- confirm).
        captions: token-index batch forwarded to ``drawCaption``.
        cap_lens: per-sample caption lengths; ``cap_lens[i]`` is the number
            of attention maps used for sample ``i``.
        ixtoword: mapping from token index to word, forwarded to
            ``drawCaption``.
        attn_maps: per-sample attention maps; ``attn_maps[i]`` is viewed as
            ``(1, -1, att_sze, att_sze)``.
        att_sze: side length of one attention map.
        vis_size: side length in pixels of each visualisation cell.
        topK: maximum number of words shown per sample.

    Returns:
        ``(img_set, sentences)`` where ``img_set`` is a uint8 array, or a
        bare ``None`` when the text strip and the image row end up with
        different widths (kept as-is for interface compatibility).
    """
    batch_size = real_imgs.size(0)
    max_word_num = np.max(cap_lens)
    cell = vis_size + 2  # one visualisation cell plus its 2px separator
    upscale = vis_size // att_sze
    text_convas = np.ones([batch_size * FONT_MAX,
                           max_word_num * cell, 3],
                          dtype=np.uint8)

    real_imgs = \
        nn.functional.interpolate(real_imgs, size=(vis_size, vis_size),
                                  mode='bilinear', align_corners=False)
    # [-1, 1] --> [0, 255]
    real_imgs.add_(1).div_(2).mul_(255)
    # .cpu() so that batches living on an accelerator also work.
    real_imgs = real_imgs.data.cpu().numpy()
    # b x c x h x w --> b x h x w x c
    real_imgs = np.transpose(real_imgs, (0, 2, 3, 1))
    pad_sze = real_imgs.shape
    middle_pad = np.zeros([pad_sze[2], 2, 3])

    img_set = []
    num = len(attn_maps)

    text_map, sentences = \
        drawCaption(text_convas, captions, ixtoword, vis_size, off1=0)
    text_map = np.asarray(text_map).astype(np.uint8)

    bUpdate = 1
    for i in range(num):
        attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze)
        # One single-channel map per entry, replicated to 3 channels.
        attn = attn.view(-1, 1, att_sze, att_sze)
        attn = attn.repeat(1, 3, 1, 1).data.numpy()
        # n x c x h x w --> n x h x w x c
        attn = np.transpose(attn, (0, 2, 3, 1))
        num_attn = int(cap_lens[i])
        thresh = 2. / float(num_attn)

        img = real_imgs[i]
        row = []
        row_merge = []
        row_txt = []
        row_beforeNorm = []
        conf_score = []
        for j in range(num_attn):
            one_map = attn[j]
            # Score = total attention mass above twice the threshold.
            mask0 = one_map > (2. * thresh)
            conf_score.append(np.sum(one_map * mask0))
            # Zero out everything at or below the threshold.
            mask = one_map > thresh
            one_map = one_map * mask
            if upscale > 1:
                try:
                    # Newer scikit-image releases take ``channel_axis``.
                    one_map = skimage.transform.pyramid_expand(
                        one_map, sigma=20, upscale=upscale,
                        channel_axis=-1)
                except TypeError:
                    # Older releases only know ``multichannel``.
                    one_map = skimage.transform.pyramid_expand(
                        one_map, sigma=20, upscale=upscale,
                        multichannel=True)
            minV = one_map.min()
            maxV = one_map.max()
            value_range = maxV - minV
            if value_range > 0:
                one_map = (one_map - minV) / value_range
            else:
                # Fully masked (constant) map: avoid 0/0 producing NaNs.
                one_map = np.zeros_like(one_map)
            row_beforeNorm.append(one_map)
        # Word indices ordered from highest to lowest score.
        sorted_indices = np.argsort(conf_score)[::-1]

        for j in range(num_attn):
            one_map = row_beforeNorm[j]
            one_map *= 255
            # Paste the attention map over the image using a constant
            # mask value of 180.
            PIL_im = Image.fromarray(np.uint8(img))
            PIL_att = Image.fromarray(np.uint8(one_map))
            merged = \
                Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0))
            mask = Image.new('L', (vis_size, vis_size), (180))  # (210)
            merged.paste(PIL_im, (0, 0))
            merged.paste(PIL_att, (0, 0), mask)
            merged = np.array(merged)[:, :, :3]

            row.append(np.concatenate([one_map, middle_pad], 1))
            row_merge.append(np.concatenate([merged, middle_pad], 1))
            # Slice of the caption strip holding word j of sample i.
            txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX,
                           j * cell:(j + 1) * cell, :]
            row_txt.append(txt)

        # Reorder by score and keep the topK best words.
        keep = sorted_indices[:topK]
        row = np.concatenate([row[idx] for idx in keep], 1)
        row_merge = np.concatenate([row_merge[idx] for idx in keep], 1)
        txt = np.concatenate([row_txt[idx] for idx in keep], 1)
        if txt.shape[1] != row.shape[1]:
            # Report the concatenated array; the per-word list has no
            # ``.shape`` attribute.
            print('Warnings: txt', txt.shape, 'row', row.shape,
                  'row_merge', row_merge.shape)
            bUpdate = 0
            break
        row = np.concatenate([txt, row_merge], 0)
        img_set.append(row)
    if bUpdate:
        img_set = np.concatenate(img_set, 0)
        img_set = img_set.astype(np.uint8)
        return img_set, sentences
    else:
        return None
| #################################################################### | |
def weights_init(m):
    """Initialise a module's parameters based on its class name.

    Intended for use with ``module.apply(weights_init)``.

    * class name contains ``Conv``: orthogonal weight (gain 1.0)
    * class name contains ``BatchNorm``: weight ~ N(1.0, 0.02), bias = 0
    * class name contains ``Linear``: orthogonal weight (gain 1.0), bias = 0

    Parameters that are missing or ``None`` (for example
    ``BatchNorm(affine=False)`` or a layer built with ``bias=False``) are
    skipped instead of raising ``AttributeError``.

    Args:
        m: the module to initialise in place.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.orthogonal_(weight.data, 1.0)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
    elif 'Linear' in classname:
        if weight is not None:
            nn.init.orthogonal_(weight.data, 1.0)
        if bias is not None:
            bias.data.fill_(0.0)
def load_params(model, new_param):
    """Copy the values in ``new_param`` into ``model``'s parameters in place.

    Parameters are paired positionally, in ``model.parameters()`` order;
    pairing stops at the shorter of the two sequences.
    """
    pairs = zip(model.parameters(), new_param)
    for target, source in pairs:
        target.data.copy_(source)
def copy_G_params(model):
    """Return a deep-copied list of the data tensors of ``model``'s parameters.

    The copies are independent of the model, so later updates to the model
    do not change the returned snapshot.
    """
    snapshot = [param.data for param in model.parameters()]
    return deepcopy(snapshot)
def mkdir_p(path):
    """Create ``path`` and any missing parents, like ``mkdir -p``.

    An already existing directory is not an error. Any other failure
    (for example ``path`` existing as a regular file) propagates as
    ``OSError``.
    """
    os.makedirs(path, exist_ok=True)