| | ''' |
| | AnyText: Multilingual Visual Text Generation And Editing |
| | Paper: https://arxiv.org/abs/2311.03054 |
| | Code: https://github.com/tyxsspa/AnyText |
| | Copyright (c) Alibaba, Inc. and its affiliates. |
| | ''' |
| | import os |
| | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' |
| | import torch |
| | import random |
| | import re |
| | import numpy as np |
| | import cv2 |
| | import einops |
| | import time |
| | from PIL import ImageFont |
| | from cldm.model import create_model, load_state_dict |
| | from cldm.ddim_hacked import DDIMSampler |
| | from t3_dataset import draw_glyph, draw_glyph2 |
| | from util import check_channels, resize_image, save_images |
| | from pytorch_lightning import seed_everything |
| | from modelscope.pipelines import pipeline |
| | from modelscope.utils.constant import Tasks |
| | from modelscope.models.base import TorchModel |
| | from modelscope.preprocessors.base import Preprocessor |
| | from modelscope.pipelines.base import Model, Pipeline |
| | from modelscope.utils.config import Config |
| | from modelscope.pipelines.builder import PIPELINES |
| | from modelscope.preprocessors.builder import PREPROCESSORS |
| | from modelscope.models.builder import MODELS |
| | from bert_tokenizer import BasicTokenizer |
# Tokenizer instance shared by the module; is_chinese() uses its
# _clean_text / _is_chinese_char helpers.
checker = BasicTokenizer()

# Not referenced in the code visible in this module section.
BBOX_MAX_NUM: int = 8
# Token substituted for every double-quoted text span of a prompt.
PLACE_HOLDER: str = '*'
# Texts longer than this are truncated before glyph rendering.
max_chars: int = 20
| |
|
| |
|
@MODELS.register_module('my-anytext-task', module_name='my-custom-model')
class MyCustomModel(TorchModel):
    """ModelScope model wrapping AnyText for visual text generation/editing.

    ``forward`` renders each double-quoted string of the prompt into one of
    the regions marked in a position image, either on a blank canvas
    ('text-generation') or inside a reference image ('text-editing').
    """

    def __init__(self, model_dir, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.init_model(**kwargs)

    def forward(self, input_tensor, **forward_params):
        """Generate or edit images containing the quoted texts of the prompt.

        Args:
            input_tensor: dict with keys
                ``prompt`` (str, required): texts in double quotes are rendered;
                ``seed`` (int, optional): -1 (default) picks a random seed;
                ``draw_pos`` (path, numpy.ndarray or torch.Tensor): position
                    image. Image files mark positions in black (they are
                    inverted here); arrays/tensors are used as-is;
                ``ori_image`` (path, numpy.ndarray or torch.Tensor): reference
                    image, required for editing.
            **forward_params: ``mode`` ('text-generation'/'gen' or
                'text-editing'/'edit') plus optional settings: sort_priority,
                show_debug, revise_pos, image_count, ddim_steps, image_width,
                image_height, strength, cfg_scale, eta, a_prompt, n_prompt.

        Returns:
            tuple ``(result, rst_code, rst_info, debug_info)``:
                result: list of images in numpy.ndarray format (None on error);
                    with show_debug=True a glyph preview image is appended.
                rst_code: 0: normal, -1: error, 1: warning.
                rst_info: error message, or newline-separated warnings.
                debug_info: HTML debug string, only non-empty if show_debug=True.

        Raises:
            TypeError: if ``draw_pos`` or ``ori_image`` has an unsupported type.
        """
        tic = time.time()
        warnings = []

        seed = input_tensor.get('seed', -1)
        if seed == -1:
            seed = random.randint(0, 99999999)
        seed_everything(seed)
        prompt = input_tensor.get('prompt')
        draw_pos = input_tensor.get('draw_pos')
        ori_image = input_tensor.get('ori_image')

        mode = forward_params.get('mode')
        sort_priority = forward_params.get('sort_priority', '↕')
        show_debug = forward_params.get('show_debug', False)
        revise_pos = forward_params.get('revise_pos', False)
        img_count = forward_params.get('image_count', 4)
        ddim_steps = forward_params.get('ddim_steps', 20)
        w = forward_params.get('image_width', 512)
        h = forward_params.get('image_height', 512)
        strength = forward_params.get('strength', 1.0)
        cfg_scale = forward_params.get('cfg_scale', 9.0)
        eta = forward_params.get('eta', 0.0)
        a_prompt = forward_params.get('a_prompt', 'best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks')
        n_prompt = forward_params.get('n_prompt', 'low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture')

        gen_modes = ('text-generation', 'gen')
        edit_modes = ('text-editing', 'edit')
        if mode not in gen_modes + edit_modes:
            return None, -1, f'Unknown mode: {mode}, expected "text-generation" or "text-editing"!', ''
        if prompt is None:
            return None, -1, 'A prompt is required!', ''

        prompt, texts = self.modify_prompt(prompt)
        n_lines = len(texts)

        # ---- base image -------------------------------------------------
        if mode in gen_modes:
            # uniform 127.5 becomes 0 after the [-1, 1] normalisation below
            edit_image = np.ones((h, w, 3)) * 127.5
        else:
            if draw_pos is None or ori_image is None:
                return None, -1, "Reference image and position image are needed for text editing!", ""
            if isinstance(ori_image, str):
                ori_path = ori_image
                ori_image = self._read_rgb(ori_path)
                if ori_image is None:
                    return None, -1, f"Can't read ori_image image from {ori_path}!", ''
            elif isinstance(ori_image, torch.Tensor):
                ori_image = ori_image.cpu().numpy()
            elif not isinstance(ori_image, np.ndarray):
                raise TypeError(f'Unknown format of ori_image: {type(ori_image)}')
            edit_image = ori_image.clip(1, 255)
            edit_image = check_channels(edit_image)
            edit_image = resize_image(edit_image, max_length=768)
            h, w = edit_image.shape[:2]

        # ---- position masks (regions to fill become 255) ------------------
        if draw_pos is None:
            pos_imgs = np.zeros((h, w, 1))
        elif isinstance(draw_pos, str):
            pos_path = draw_pos
            draw_pos = self._read_rgb(pos_path)
            if draw_pos is None:
                return None, -1, f"Can't read draw_pos image from {pos_path}!", ''
            # position files mark regions in black on white: invert them
            pos_imgs = 255 - draw_pos
        elif isinstance(draw_pos, torch.Tensor):
            pos_imgs = draw_pos.cpu().numpy()
        elif isinstance(draw_pos, np.ndarray):
            # handled like the tensor branch (no inversion)
            pos_imgs = draw_pos
        else:
            raise TypeError(f'Unknown format of draw_pos: {type(draw_pos)}')
        if pos_imgs.ndim == 2:
            pos_imgs = pos_imgs[..., np.newaxis]
        pos_imgs = pos_imgs[..., 0:1]
        pos_imgs = cv2.convertScaleAbs(pos_imgs)
        _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)

        pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
        if len(pos_imgs) == 0:
            pos_imgs = [np.zeros((h, w, 1))]
        if len(pos_imgs) < n_lines:
            # a prompt without quotes yields texts == [' '], which needs no position
            if not (n_lines == 1 and texts[0] == ' '):
                return None, -1, f'Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!', ''
        elif len(pos_imgs) > n_lines:
            warnings.append(f'Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt.')

        pre_pos = []
        poly_list = []
        for input_pos in pos_imgs:
            if input_pos.mean() != 0:
                if input_pos.ndim == 2:
                    input_pos = input_pos[..., np.newaxis]
                poly, pos_img = self.find_polygon(input_pos)
                pre_pos.append(pos_img / 255.)
                poly_list.append(poly)
            else:
                pre_pos.append(np.zeros((h, w, 1)))
                poly_list.append(None)
        # NOTE(review): np_hint is computed before revise_pos may replace
        # pre_pos[i] below, so hint/masked_x use the unrevised positions -
        # confirm this is intended.
        np_hint = np.sum(pre_pos, axis=0).clip(0, 1)

        # ---- per-line glyph conditioning ---------------------------------
        info = {
            'glyphs': [],
            'gly_line': [],
            'positions': [],
            'n_lines': [len(texts)] * img_count,
        }
        gly_pos_imgs = []
        gly_scale = 2  # glyph images are rendered at 2x the output size
        for i, text in enumerate(texts):
            if len(text) > max_chars:
                warnings.append(f'"{text}" length > max_chars: {max_chars}, will be cut off...')
                text = text[:max_chars]
            if pre_pos[i].mean() != 0:
                gly_line = draw_glyph(self.font, text)
                glyphs = draw_glyph2(self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False)
                gly_pos_img = cv2.drawContours(glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1)
                if revise_pos:
                    resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]))
                    kernel = np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8)
                    new_pos = cv2.morphologyEx((resize_gly * 255).astype(np.uint8), cv2.MORPH_CLOSE, kernel=kernel, iterations=1)
                    if new_pos.ndim == 2:
                        new_pos = new_pos[..., np.newaxis]
                    contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
                    if len(contours) != 1:
                        warnings.append(f'Fail to revise position {i} to bounding rect, remain position unchanged...')
                    else:
                        rect = cv2.minAreaRect(contours[0])
                        # np.int0 was removed in NumPy 2.0; astype(np.intp) is equivalent
                        poly = cv2.boxPoints(rect).astype(np.intp)
                        pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.
                        gly_pos_img = cv2.drawContours(glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1)
                gly_pos_imgs.append(gly_pos_img)
            else:
                glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
                gly_line = np.zeros((80, 512, 1))
                gly_pos_imgs.append(np.zeros((h * gly_scale, w * gly_scale, 1)))
            info['glyphs'].append(self.arr2tensor(glyphs, img_count))
            info['gly_line'].append(self.arr2tensor(gly_line, img_count))
            info['positions'].append(self.arr2tensor(pre_pos[i], img_count))

        # ---- latent of the base image with the text regions masked out ---
        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
        masked_img = np.transpose(masked_img, (2, 0, 1))
        masked_img = torch.from_numpy(masked_img.copy()).float().cuda()
        encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
        masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
        info['masked_x'] = torch.cat([masked_x for _ in range(img_count)], dim=0)

        hint = self.arr2tensor(np_hint, img_count)

        # ---- sampling ----------------------------------------------------
        cond = self.model.get_learned_conditioning(dict(c_concat=[hint], c_crossattn=[[prompt + ' , ' + a_prompt] * img_count], text_info=info))
        un_cond = self.model.get_learned_conditioning(dict(c_concat=[hint], c_crossattn=[[n_prompt] * img_count], text_info=info))
        shape = (4, h // 8, w // 8)
        self.model.control_scales = [strength] * 13
        samples, _ = self.ddim_sampler.sample(ddim_steps, img_count,
                                              shape, cond, verbose=False, eta=eta,
                                              unconditional_guidance_scale=cfg_scale,
                                              unconditional_conditioning=un_cond)
        x_samples = self.model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        results = [x_samples[i] for i in range(img_count)]
        if gly_pos_imgs and show_debug:
            glyph_bs = np.stack(gly_pos_imgs, axis=2)
            glyph_img = np.sum(glyph_bs, axis=2) * 255
            glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
            results.append(np.repeat(glyph_img, 3, axis=2))

        str_warning = '\n'.join(warnings)
        if show_debug:
            # restore the quoted texts in place of the placeholders
            input_prompt = prompt
            for t in texts:
                input_prompt = input_prompt.replace(PLACE_HOLDER, f'"{t}"', 1)
            label = '<span style="color:black;font-size:18px">'
            debug_info = (
                f'{label}Prompt: </span>{input_prompt}<br> '
                f'{label}Size: </span>{w}x{h}<br> '
                f'{label}Image Count: </span>{img_count}<br> '
                f'{label}Seed: </span>{seed}<br> '
                f'{label}Cost Time: </span>{(time.time() - tic):.2f}s'
            )
        else:
            debug_info = ''
        rst_code = 1 if str_warning else 0
        return results, rst_code, str_warning, debug_info

    @staticmethod
    def _read_rgb(path):
        """Read an image file with OpenCV; return it in RGB order, or None if unreadable."""
        img = cv2.imread(path)
        # cv2.imread returns BGR (or None on failure); flip channels to RGB
        return None if img is None else img[..., ::-1]

    def init_model(self, **kwargs):
        """Load the glyph font, AnyText model, DDIM sampler and zh->en translator.

        Keyword Args:
            font_path: TrueType font for glyph rendering
                (default 'font/Arial_Unicode.ttf').
            cfg_path: model config YAML (default 'models_yaml/anytext_sd15.yaml').
            model_path: checkpoint (default '<model_dir>/anytext_v1.1.ckpt').

        The CLIP text encoder and the translation model are loaded from the
        'clip-vit-large-patch14' and 'nlp_csanmt_translation_zh2en'
        sub-directories of ``self.model_dir``.
        """
        font_path = kwargs.get('font_path', 'font/Arial_Unicode.ttf')
        self.font = ImageFont.truetype(font_path, size=60)
        cfg_path = kwargs.get('cfg_path', 'models_yaml/anytext_sd15.yaml')
        ckpt_path = kwargs.get('model_path', os.path.join(self.model_dir, 'anytext_v1.1.ckpt'))
        clip_path = os.path.join(self.model_dir, 'clip-vit-large-patch14')
        self.model = create_model(cfg_path, cond_stage_path=clip_path).cuda().eval()
        self.model.load_state_dict(load_state_dict(ckpt_path, location='cuda'), strict=False)
        self.ddim_sampler = DDIMSampler(self.model)
        self.trans_pipe = pipeline(task=Tasks.translation, model=os.path.join(self.model_dir, 'nlp_csanmt_translation_zh2en'))

    def modify_prompt(self, prompt):
        """Pull the quoted texts out of ``prompt`` and replace them by placeholders.

        Curly double quotes are normalised to '"' first. Each "..." span is
        replaced (in order) by ' PLACE_HOLDER '. If the rewritten prompt
        contains Chinese characters it is translated to English.

        Returns:
            (prompt, texts): the rewritten prompt and the list of quoted
            strings; ``texts`` is [' '] when the prompt contains no quotes.
        """
        prompt = prompt.replace('“', '"').replace('”', '"')
        strs = re.findall(r'"(.*?)"', prompt)
        if not strs:
            strs = [' ']
        else:
            for s in strs:
                prompt = prompt.replace(f'"{s}"', f' {PLACE_HOLDER} ', 1)
        if self.is_chinese(prompt):
            old_prompt = prompt
            # ' .' is appended before translation and the last output
            # character dropped - presumably the translated full stop
            prompt = self.trans_pipe(input=prompt + ' .')['translation'][:-1]
            print(f'Translate: {old_prompt} --> {prompt}')
        return prompt, strs

    def is_chinese(self, text):
        """Return True if the BERT BasicTokenizer classifies any char of ``text`` as Chinese."""
        cleaned = checker._clean_text(text)
        return any(checker._is_chinese_char(ord(ch)) for ch in cleaned)

    def separate_pos_imgs(self, img, sort_priority, gap=102):
        """Split a binary mask into one mask per connected component, sorted.

        Args:
            img: single-channel binary mask (255 = position).
            sort_priority: '↕' sorts top-to-bottom, then left-to-right;
                '↔' sorts left-to-right, then top-to-bottom.
            gap: bucket size in pixels; centroids falling in the same bucket
                along the primary axis count as the same row/column.

        Returns:
            list of masks shaped like ``img``, 255 inside each component.

        Raises:
            ValueError: if ``sort_priority`` is neither '↕' nor '↔'.
        """
        # centroids from OpenCV are (x, y): index 1 is vertical, 0 horizontal
        axes = {'↕': (1, 0), '↔': (0, 1)}
        if sort_priority not in axes:
            raise ValueError(f"sort_priority must be '↕' or '↔', got {sort_priority!r}")
        fir, sec = axes[sort_priority]
        num_labels, labels, _, centroids = cv2.connectedComponentsWithStats(img)
        components = []
        for label in range(1, num_labels):  # label 0 is the background
            component = np.zeros_like(img)
            component[labels == label] = 255
            components.append((component, centroids[label]))
        components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
        return [c[0] for c in components]

    def find_polygon(self, image, min_rect=False):
        """Approximate the largest external contour of ``image`` by a polygon.

        The polygon is also filled with 255 into ``image`` in place.

        Args:
            image: single-channel uint8 mask; must contain at least one
                non-zero region (``max`` over contours fails otherwise).
            min_rect: use the minimum-area rotated rectangle instead of a
                polygon approximation.

        Returns:
            (poly, image): ``poly`` is an (N, 2) integer vertex array and
            ``image`` is the same (modified) array that was passed in.
        """
        contours, _ = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        max_contour = max(contours, key=cv2.contourArea)
        if min_rect:
            rect = cv2.minAreaRect(max_contour)
            poly = cv2.boxPoints(rect).astype(np.intp)
        else:
            # tolerance: 1% of the contour perimeter
            epsilon = 0.01 * cv2.arcLength(max_contour, True)
            poly = cv2.approxPolyDP(max_contour, epsilon, True)
        # approxPolyDP yields (N, 1, 2) and boxPoints (4, 2): normalise to (N, 2)
        poly = poly.reshape(-1, 2)
        cv2.drawContours(image, [poly], -1, 255, -1)
        return poly, image

    def arr2tensor(self, arr, bs):
        """Turn an (H, W, C) array into a (bs, C, H, W) float CUDA tensor of bs copies."""
        arr = np.transpose(arr, (2, 0, 1))
        _arr = torch.from_numpy(arr.copy()).float().cuda()
        _arr = torch.stack([_arr for _ in range(bs)], dim=0)
        return _arr
| |
|
| |
|
@PREPROCESSORS.register_module('my-anytext-task', module_name='my-custom-preprocessor')
class MyCustomPreprocessor(Preprocessor):
    """Preprocessor whose default transform passes inputs through unchanged."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # attribute name (including its spelling) kept for compatibility
        self.trainsforms = self.init_preprocessor(**kwargs)

    def __call__(self, results):
        transform = self.trainsforms
        return transform(results)

    def init_preprocessor(self, **kwarg):
        """Build the transform applied by ``__call__``.

        Subclasses may override this to add real preprocessing; the default
        returns an identity function.
        """
        def _identity(x):
            return x
        return _identity
| |
|
| |
|
@PIPELINES.register_module('my-anytext-task', module_name='my-custom-pipeline')
class MyCustomPipeline(Pipeline):
    """Pipeline that hands its input dict straight to ``MyCustomModel``.

    Every call-time keyword argument is forwarded to the model's ``forward``
    (see ``_sanitize_parameters``) and the model output is returned as-is,
    i.e. ``(results, rst_code, rst_info, debug_info)``.

    Example::

        from modelscope.pipelines import pipeline
        inference = pipeline('my-anytext-task', model='models')
        results, rtn_code, rtn_warning, debug_info = inference(
            {'prompt': 'a sign that reads "Hello"', 'draw_pos': 'pos.png'},
            mode='text-generation')
    """

    def __init__(self, model, preprocessor=None, **kwargs):
        # validate before the (expensive) model loading in the base class
        if not isinstance(model, (str, Model)):
            raise TypeError('model must be a single str or Model')
        # first init is expected to make self.model a model instance
        # (eval() is called on it below)
        super().__init__(model=model, auto_collate=False)
        pipe_model = self.model
        pipe_model.eval()
        if preprocessor is None:
            preprocessor = MyCustomPreprocessor()
        # NOTE(review): auto_collate is not forwarded to this second init, so
        # the Pipeline default applies - confirm intended (the model's forward
        # accepts torch.Tensor inputs).
        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

    def _sanitize_parameters(self, **pipeline_parameters):
        # (preprocess, forward, postprocess) params: all go to forward
        return {}, pipeline_parameters, {}

    def _check_input(self, inputs):
        # deliberately a no-op: overrides the base-class input check
        pass

    def _check_output(self, outputs):
        # deliberately a no-op: overrides the base-class output check
        pass

    def forward(self, inputs, **forward_params):
        return super().forward(inputs, **forward_params)

    def postprocess(self, inputs):
        # model output is already the final result
        return inputs
| |
|
| |
|
# Directory passed as ``model`` to pipeline() in the __main__ demo.
usr_config_path = 'models'
# Configuration naming the task, model and pipeline registered in this module.
config = Config(dict(
    framework='pytorch',
    task='my-anytext-task',
    model={'type': 'my-custom-model'},
    pipeline={'type': 'my-custom-pipeline'},
    allow_remote=True,
))
| | |
| |
|
if __name__ == "__main__":
    img_save_folder = "SaveImages"
    inference = pipeline('my-anytext-task', model=usr_config_path)
    params = {
        "show_debug": True,
        "image_count": 2,
        "ddim_steps": 20,
    }

    # (mode, input_data) pairs, run in order; all results go to img_save_folder
    examples = [
        ('text-generation', {
            "prompt": 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream',
            "seed": 66273235,
            "draw_pos": 'example_images/gen9.png',
        }),
        ('text-editing', {
            "prompt": 'A cake with colorful characters that reads "EVERYDAY"',
            "seed": 8943410,
            "draw_pos": 'example_images/edit7.png',
            "ori_image": 'example_images/ref7.jpg',
        }),
    ]
    for mode, input_data in examples:
        results, rtn_code, rtn_warning, debug_info = inference(input_data, mode=mode, **params)
        if rtn_code >= 0:
            save_images(results, img_save_folder)
        # rtn_code: 0 ok, 1 warning, -1 error - report the message instead of dropping it
        if rtn_warning:
            print(f'[{mode}] {rtn_warning}')
    print(f'Done, result images are saved in: {img_save_folder}')
| |
|