Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| import cv2 | |
| import random | |
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| from zhipuai import ZhipuAI | |
| from pytorch_lightning import seed_everything | |
| from pprint import pprint | |
| from PIL import Image, ImageDraw, ImageFont | |
| from diffusers import ( | |
| ControlNetModel, | |
| StableDiffusionControlNetPipeline, | |
| ) | |
| from diffusers import ( | |
| DDIMScheduler, | |
| PNDMScheduler, | |
| EulerAncestralDiscreteScheduler, | |
| DPMSolverMultistepScheduler, | |
| EulerDiscreteScheduler, | |
| LMSDiscreteScheduler, | |
| HeunDiscreteScheduler | |
| ) | |
| from controlnet_aux import ( | |
| PidiNetDetector, | |
| HEDdetector | |
| ) | |
# --- UI limits ---------------------------------------------------------------
# Number of text boxes the interface can handle at most, and the initial count.
BBOX_MAX_NUM = 8
BBOX_INI_NUM = 0
# Texts longer than this many characters are cropped before rendering.
MAX_LENGTH = 20

# --- Runtime state -----------------------------------------------------------
device = 'cuda'
# Cached diffusion pipeline and the identifier it was loaded for.
pipeline = None
pre_pipeline = None

# --- Model locations ---------------------------------------------------------
# NOTE: when REPO_ROOT is unset, model_root is None and scheduler_root becomes
# the literal string 'None/Scheduler'.
model_root = os.getenv('REPO_ROOT')
scheduler_root = f'{model_root}/Scheduler'

# Display names offered for the base model, in presentation order.
model_list = [
    'JoyType.v1.0',
    'RevAnimated-animation-动漫',
    'GhostMix-animation-动漫',
    'rpg.v5-fantasy_realism-奇幻写实',
    'midjourneyPapercut-origami-折纸版画',
    'dvarchExterior-architecture-建筑',
    'awpainting.v13-portrait-人物肖像',
]

# Built-in Chinese example prompts mapped to their fixed English translations.
chn_example_dict = {
    '漂亮的风景照,很多山峰,清澈的湖水':
        'beautiful landscape, many peaks, clear lake',
    '画有玫瑰的卡片,明亮的背景':
        'a card with roses, bright background',
    '一张关于健康教育的卡片,上面有一些文字,有一些食物图标,背景里有一些水果喝饮料的图标,且背景是模糊的': (
        'a card for health education, with some writings on it, '
        'food icons on the card, some fruits and drinking in the background, blur background '
    ),
}

# Display name -> directory name of the model under ``model_root``.
match_dict = {
    'JoyType.v1.0': 'JoyType-v1-1M',
    'RevAnimated-animation-动漫': 'rev-animated-v1-2-2',
    'GhostMix-animation-动漫': 'GhostMix_V2.0',
    'rpg.v5-fantasy_realism-奇幻写实': 'rpg_v5',
    'midjourneyPapercut-origami-折纸版画': 'midjourneyPapercut_v1',
    'dvarchExterior-architecture-建筑': 'dvarchExterior',
    'awpainting.v13-portrait-人物肖像': 'awpainting_v13',
}

# Selectable fonts, written as '<script prefix>-<font file stem>'.
font_list = [
    # Chinese
    'CHN-华文行楷',
    'CHN-华文新魏',
    'CHN-清松手写体',
    'CHN-巴蜀墨迹',
    'CHN-雷盖体',
    'CHN-演示夏行楷',
    'CHN-鸿雷板书简体',
    'CHN-斑马字类',
    'CHN-青柳隶书',
    'CHN-辰宇落雁体',
    'CHN-宅家麦克笔',
    # English
    'ENG-Playwrite',
    'ENG-Okesip',
    'ENG-Shrikhand',
    'ENG-Nextstep',
    'ENG-Filthyrich',
    'ENG-BebasNeue',
    'ENG-Gloock',
    'ENG-Lemon',
    # Russian
    'RUS-Automatons',
    'RUS-MKyrill',
    'RUS-Alice',
    'RUS-Caveat',
    # Korean
    'KOR-ChosunGs',
    'KOR-Dongle',
    'KOR-GodoMaum',
    'KOR-UnDotum',
    # Japanese
    'JPN-GlTsukiji',
    'JPN-Aoyagireisyosimo',
    'JPN-KouzanMouhitu',
    'JPN-Otomanopee',
]
def change_settings(base_model):
    """Return the preset generation settings for the selected base model.

    Args:
        base_model: Display name of the model; compared against the entries
            of ``model_list``.

    Returns:
        A 3-tuple of ``gr.update`` objects carrying two numeric values and a
        scheduler name (presumably inference steps and guidance scale --
        confirm against the UI wiring), or ``None`` when ``base_model``
        matches no entry of ``model_list``.
    """
    # One preset per entry of ``model_list``, in the same order.
    presets = (
        (20, 7.5, 'PNDM'),
        (30, 8.5, 'Euler'),
        (32, 8.5, 'Euler'),
        (20, 7.5, 'DPM'),
        (25, 6.5, 'Euler'),
        (25, 8.5, 'Euler'),
        (25, 7, 'DPM'),
    )
    for known_model, (first_value, second_value, scheduler) in zip(model_list, presets):
        if base_model == known_model:
            return (
                gr.update(value=first_value),
                gr.update(value=second_value),
                gr.update(value=scheduler),
            )
    return None
def update_box_num(choice):
    """Build the component updates that enable the first ``choice`` boxes.

    Boxes with an index below ``choice`` get their checkbox ticked and their
    font/text inputs shown; the remaining boxes are unticked, hidden and reset
    to their default font, empty text and default position. The bounding-box
    controls are kept hidden in both cases.

    Args:
        choice: Number of boxes to enable.

    Returns:
        A flat tuple of ``gr.update`` objects: ``BBOX_MAX_NUM`` checkbox
        updates, then as many font updates, then as many text updates, then
        ``4 * BBOX_MAX_NUM`` bounding-box updates.
    """
    checkbox_updates = []
    font_updates = []
    text_updates = []
    bbox_updates = []

    for index in range(BBOX_MAX_NUM):
        if index < choice:
            checkbox_updates.append(gr.update(value=True))
            font_updates.append(gr.update(visible=True))
            text_updates.append(gr.update(visible=True))
            bbox_updates.extend(gr.update(visible=False) for _ in range(4))
            continue

        checkbox_updates.append(gr.update(value=False))
        font_updates.append(gr.update(visible=False, value='CHN-华文行楷'))
        text_updates.append(gr.update(visible=False, value=''))
        bbox_updates.extend(
            gr.update(visible=False, value=default)
            for default in (0.4, 0.4, 0.2, 0.2)
        )

    return (*checkbox_updates, *font_updates, *text_updates, *bbox_updates)
def load_box_list(example_id, choice):
    """Load a saved box layout from ``templates/<example_id>.json``.

    The JSON file is expected to hold the keys ``'visible'``, ``'pos'``,
    ``'font'`` and ``'text'``; ``'pos'`` is a flat list with four values per
    box.

    Args:
        example_id: Stem of the template file to read.
        choice: Unused; kept for the callback signature.

    Returns:
        A flat tuple of ``gr.update`` objects: ``BBOX_MAX_NUM`` checkbox
        updates, font updates, text updates, ``4 * BBOX_MAX_NUM`` position
        updates, and finally one update setting a value of ``-1``.
    """
    template_path = f'templates/{example_id}.json'
    with open(template_path, 'r') as handle:
        template = json.load(handle)

    checkbox_updates = []
    font_updates = []
    text_updates = []
    position_updates = []

    for index in range(BBOX_MAX_NUM):
        shown = template['visible'][index]
        box = template['pos'][index * 4: (index + 1) * 4]

        checkbox_updates.append(gr.update(value=shown))
        font_updates.append(gr.update(value=template['font'][index], visible=shown))
        text_updates.append(gr.update(value=template['text'][index], visible=shown))
        position_updates.extend(gr.update(value=box[k]) for k in range(4))

    return (
        *checkbox_updates,
        *font_updates,
        *text_updates,
        *position_updates,
        gr.update(value=-1),
    )
def re_edit():
    """Reset every box position and the drawing canvas to their defaults.

    Returns:
        A flat tuple with ``4 * BBOX_MAX_NUM`` ``gr.update`` objects (each box
        reset to ``0.4, 0.4, 0.2, 0.2``), followed by a ``gr.Image`` showing a
        fresh canvas and two ``gr.Slider`` components set to 512.
    """
    default_box = (0.4, 0.4, 0.2, 0.2)
    position_resets = [
        gr.update(value=coordinate)
        for _ in range(BBOX_MAX_NUM)
        for coordinate in default_box
    ]

    blank_canvas = gr.Image(
        value=create_canvas(),
        label='Rect Position',
        elem_id='MD-bbox-rect-t2i',
        show_label=False,
        visible=True,
    )
    return (*position_resets, blank_canvas, gr.Slider(value=512), gr.Slider(value=512))
def resize_w(w, img):
    """Return ``img`` resized to width ``w`` while keeping its current height."""
    current_height = img.shape[0]
    target_size = (w, current_height)  # cv2 expects (width, height)
    return cv2.resize(img, target_size)
def resize_h(h, img):
    """Return ``img`` resized to height ``h`` while keeping its current width."""
    current_width = img.shape[1]
    target_size = (current_width, h)  # cv2 expects (width, height)
    return cv2.resize(img, target_size)
def create_canvas(w=512, h=512, c=3, line=5):
    """Create a light-grey canvas with a darker grid and a centre marker.

    Args:
        w: Canvas width in pixels.
        h: Canvas height in pixels.
        c: Number of channels; the centre marker is written as three values,
            so this is expected to be 3.
        line: Positive number of grid divisions across the width.

    Returns:
        A ``uint8`` array of shape ``(h, w, c)`` filled with 200, with grid
        lines of value 150 and a 16x16 marker of ``[200, 0, 0]`` centred on
        the canvas.
    """
    image = np.full((h, w, c), 200, dtype=np.uint8)

    # Rows and columns share one spacing derived from the width, so grid cells
    # are square. Clamp to 1 so a canvas narrower than ``line`` pixels does not
    # end up dividing by zero.
    step = max(1, w // line)
    image[::step, :, :] = 150
    image[:, ::step, :] = 150

    # Clamp the marker's start to the canvas; a negative start index would
    # otherwise wrap around and misplace the marker on very small canvases.
    top = max(0, h // 2 - 8)
    left = max(0, w // 2 - 8)
    image[top:h // 2 + 8, left:w // 2 + 8, :] = [200, 0, 0]
    return image
def canny(img):
    """Run Canny edge detection on ``img`` and return a 3-channel PIL image.

    The single-channel edge map produced by ``cv2.Canny`` (thresholds 64 and
    100) is replicated across three channels before conversion.
    """
    low_threshold, high_threshold = 64, 100
    edges = cv2.Canny(img, low_threshold, high_threshold)
    # Copy the edge map into three identical channels: (H, W) -> (H, W, 3).
    three_channel = np.repeat(edges[:, :, None], 3, axis=2)
    return Image.fromarray(three_channel)
def judge_overlap(coord_list1, coord_list2):
    """Tell whether two axis-aligned boxes overlap.

    Each argument is indexed as ``[left, top, right, bottom]``. Boxes that
    merely touch along an edge do not count as overlapping because the
    comparisons are strict.
    """
    # Horizontal extents must intersect ...
    if not (coord_list1[0] < coord_list2[2] and coord_list1[2] > coord_list2[0]):
        return False
    # ... and so must the vertical ones.
    return coord_list1[1] < coord_list2[3] and coord_list1[3] > coord_list2[1]
def parse_render_list(box_list, shape, box_num):
    """Validate the flat list of box inputs and turn it into render items.

    ``box_list`` is read as four consecutive sections: ``box_num`` enabled
    flags, ``4 * box_num`` positions (x, y, width, height per box, expressed
    as fractions of the image size), ``box_num`` font names and ``box_num``
    texts.

    Args:
        box_list: Flat sequence of box values laid out as described above.
        shape: ``(width, height)`` of the target image in pixels.
        box_num: Number of boxes encoded in ``box_list``.

    Returns:
        ``(render_list, True)`` on success, where every item is a dict with
        the keys ``'text'``, ``'polygon'`` (``[x, y, w, h]`` in pixels) and
        ``'font_name'`` (the part of the font entry after its last ``-``).
        ``([], False)`` after emitting a ``gr.Warning`` when an enabled box
        has zero area, names a font outside ``font_list``, or overlaps
        another enabled box.
    """
    width = shape[0]
    height = shape[1]

    valid_list = box_list[:box_num]
    pos_list = box_list[box_num: 5 * box_num]
    font_name_list = box_list[5 * box_num: 6 * box_num]
    text_list = box_list[6 * box_num: 7 * box_num]
    print(font_name_list, text_list)

    render_list = []
    empty_flag = False
    for i, valid in enumerate(valid_list):
        if not valid:
            continue

        pos = pos_list[i * 4: (i + 1) * 4]
        top_left_x = int(pos[0] * width)
        top_left_y = int(pos[1] * height)
        w = int(pos[2] * width)
        h = int(pos[3] * height)
        font_name = str(font_name_list[i])
        text = str(text_list[i])

        if text == '':
            # Empty strings get a placeholder; the user is told once below.
            empty_flag = True
            text = 'JoyType'

        if w <= 0 or h <= 0:
            gr.Warning(f'Area of the box{i + 1} cannot be zero!')
            return [], False

        # Explicit membership test: an ``assert`` would be stripped when Python
        # runs with -O and silently let unknown fonts through.
        if font_name not in font_list:
            gr.Warning('Please choose a correct font!')
            return [], False

        render_list.append({
            'text': text,
            'polygon': [top_left_x, top_left_y, w, h],
            'font_name': font_name.split('-')[-1],
        })

    if empty_flag:
        gr.Warning('Null strings will be filled automatically!')

    # Convert [x, y, w, h] to [left, top, right, bottom] once, then compare
    # every unordered pair of boxes.
    corner_boxes = [
        [x, y, x + box_w, y + box_h]
        for x, y, box_w, box_h in (item['polygon'] for item in render_list)
    ]
    for i, first in enumerate(corner_boxes):
        for second in corner_boxes[i + 1:]:
            if judge_overlap(first, second):
                gr.Warning('Find overlapping boxes!')
                return [], False

    return render_list, True
def _load_render_font(font_name, size=50):
    """Load ``./font/<font_name>.ttf``, falling back to the ``.otf`` file."""
    try:
        return ImageFont.truetype(f'./font/{font_name}.ttf', encoding='utf-8', size=size)
    except OSError:
        # Pillow documents OSError for unreadable font files;
        # FileNotFoundError is a subclass, so both cases trigger the fallback.
        return ImageFont.truetype(f'./font/{font_name}.otf', encoding='utf-8', size=size)


def render_all_text(render_list, shape, threshold=512):
    """Render every text item in white onto a black image of size ``shape``.

    Each text is first drawn on a 1024x1024 scratch image (horizontally, or
    one character per line when its box is taller than wide), cropped to its
    extent, scaled to fit its box and pasted at the box position.

    Args:
        render_list: Items with the keys ``'text'``, ``'polygon'``
            (``[x, y, w, h]`` in pixels) and ``'font_name'``.
        shape: ``(width, height)`` of the output image.
        threshold: Unused; kept for interface compatibility.

    Returns:
        The rendered ``PIL.Image`` in RGB mode.
    """
    width = shape[0]
    height = shape[1]
    board = Image.new('RGB', (width, height), 'black')

    for text_dict in render_list:
        text = text_dict['text']
        polygon = text_dict['polygon']
        font_name = text_dict['font_name']

        if len(text) > MAX_LENGTH:
            text = text[:MAX_LENGTH]
            gr.Warning(f'{text}... exceeds the maximum length {MAX_LENGTH} and has been cropped.')

        w, h = polygon[2:]
        vert = w < h
        image4ratio = Image.new('RGB', (1024, 1024), 'black')
        draw = ImageDraw.Draw(image4ratio)
        font = _load_render_font(font_name)

        if not vert:
            draw.text(xy=(0, 0), text=text, font=font, fill='white')
            _, _, text_w, text_h = draw.textbbox(xy=(0, 0), text=text, font=font)
            text_h += 1
        else:
            # Vertical layout: stack the characters top to bottom.
            text_w, y_cursor = 0, 0
            for ch in text:
                draw.text(xy=(0, y_cursor), text=ch, font=font, fill='white')
                _l, _t, _r, _b = font.getbbox(ch)
                text_w = max(text_w, _r - _l)
                y_cursor += _b
            text_h = y_cursor + 1

        # Nothing measurable to place (empty glyph extent or degenerate box):
        # skip it instead of dividing by zero below.
        if text_w <= 0 or text_h <= 0 or w <= 0 or h <= 0:
            continue

        # ratio == 1 means the text and the box have the same aspect ratio.
        ratio = (text_h * w) / (text_w * h)
        text_img = image4ratio.crop((0, 0, text_w, text_h))
        x_offset, y_offset = 0, 0
        if 0.8 <= ratio <= 1.2:
            # Close enough: stretch the text to fill the box.
            text_img = text_img.resize((w, h))
        elif ratio < 0.8:
            # Text is relatively wider than the box: fit the width and centre
            # vertically. The bound matches the band above so that no ratio
            # falls through to the fit-height branch and overflows the box.
            fitted_h = max(1, int(text_h * (w / text_w)))
            text_img = text_img.resize((w, fitted_h))
            y_offset = (h - fitted_h) // 2
        else:
            # Text is relatively taller than the box: fit the height and
            # centre horizontally.
            fitted_w = max(1, int(text_w * (h / text_h)))
            text_img = text_img.resize((fitted_w, h))
            x_offset = (w - fitted_w) // 2

        board.paste(text_img, (polygon[0] + x_offset, polygon[1] + y_offset))

    return board
def load_pipeline(model_name, scheduler_name):
    """Build a ControlNet Stable Diffusion pipeline for the given base model.

    The ControlNet weights always come from the ``JoyType.v1.0`` entry of
    ``match_dict``; the base model and the scheduler configuration are read
    from ``model_root`` and ``scheduler_root`` respectively.

    Args:
        model_name: Directory name of the base model under ``model_root``.
        scheduler_name: Case-insensitive scheduler key: one of ``pndm``,
            ``lms``, ``euler``, ``dpm``, ``ddim``, ``heun`` or
            ``euler-ancestral``.

    Returns:
        A ``StableDiffusionControlNetPipeline`` moved to ``device``.

    Raises:
        ValueError: If ``scheduler_name`` is not one of the supported keys.
    """
    # Each key doubles as the subfolder name below ``scheduler_root``.
    scheduler_classes = {
        'pndm': PNDMScheduler,
        'lms': LMSDiscreteScheduler,
        'euler': EulerDiscreteScheduler,
        'dpm': DPMSolverMultistepScheduler,
        'ddim': DDIMScheduler,
        'heun': HeunDiscreteScheduler,
        'euler-ancestral': EulerAncestralDiscreteScheduler,
    }

    controlnet_path = os.path.join(model_root, f'{match_dict["JoyType.v1.0"]}')
    model_path = os.path.join(model_root, model_name)

    scheduler_key = scheduler_name.lower()
    scheduler_cls = scheduler_classes.get(scheduler_key)
    if scheduler_cls is None:
        # Fail with a clear message rather than an UnboundLocalError later on.
        raise ValueError(
            f'Unknown scheduler {scheduler_name!r}; '
            f'expected one of {sorted(scheduler_classes)}'
        )
    scheduler = scheduler_cls.from_pretrained(scheduler_root, subfolder=scheduler_key)

    controlnet = ControlNetModel.from_pretrained(
        controlnet_path,
        subfolder='controlnet',
        torch_dtype=torch.float32,
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        controlnet=controlnet,
        torch_dtype=torch.float32,
    ).to(device)
    return pipe
def preprocess_prompt(prompt):
    """Ask the ZhipuAI chat model to rewrite ``prompt`` as a Stable Diffusion prompt.

    The system message (written in Chinese) instructs the model to answer with
    a comma-separated English prompt prefixed by a fixed set of quality tags.
    The API key is read from the ``ZHIPU_API_KEY`` environment variable.

    Args:
        prompt: The user's prompt text.

    Returns:
        ``{'flag': 1, 'data': [content, ...]}`` with one entry per returned
        choice on success, or ``{'flag': 0, 'data': {}}`` when the request
        fails or yields no usable content.
    """
    system_prompt = '''
Stable Diffusion是一款利用深度学习的文生图模型,支持通过使用提示词来产生新的图像,描述要包含或省略的元素。
我在这里引入Stable Diffusion算法中的Prompt概念,又被称为提示符。这里的Prompt通常可以用来描述图像,
他由普通常见的单词构成,最好是可以在数据集来源站点找到的著名标签(比如Ddanbooru)。
下面我将说明Prompt的生出步骤,这里的Prompt主要用于描述人物。在Prompt的生成中,你需要通过提示词来描述 人物属性,主题,外表,情绪,衣服,姿势,视角,动作,背景。
用英语单词或短语甚至自然语言的标签来描述,并不局限于我给你的单词。然后将你想要的相似的提示词组合在一起,请使用英文半角,做分隔符,每个提示词不要带引号,并将这些按从最重要到最不重要的顺序 排列。
另外请您注意,永远在每个 Prompt的前面加上引号里的内容,
“(((best quality))),(((ultra detailed))),(((masterpiece))),illustration,” 这是高质量的标志。
人物属性中,1girl表示你生成了一个女孩,2girls表示生成了两个女孩,一次。另外再注意,Prompt中不能带有-和_。
可以有空格和自然语言,但不要太多,单词不能重复。只返回Prompt。
'''
    failure = {'flag': 0, 'data': {}}

    try:
        client = ZhipuAI(api_key=os.getenv('ZHIPU_API_KEY'))
        response = client.chat.completions.create(
            model="glm-4-0520",
            messages=[
                {'role': 'system', 'content': system_prompt},
                {'role': 'user', 'content': prompt},
            ],
            temperature=0.5,
            max_tokens=2048,
            top_p=1,
            stream=False,
        )
    except Exception as err:
        # Service boundary: report the failure through the flag protocol the
        # caller already handles instead of letting it escape as an exception.
        print(f'Prompt preprocessing request failed: {err}')
        return failure

    if not response:
        return failure

    # The caller indexes data[0], so an empty result must count as a failure.
    contents = [
        item.message.content
        for item in (response.choices or [])
        if item.message.content
    ]
    if not contents:
        return failure
    return {'flag': 1, 'data': contents}
def process(
    num_samples,
    a_prompt,
    n_prompt,
    conditioning_scale,
    cfg_scale,
    inference_steps,
    seed,
    usr_prompt,
    rect_img,
    base_model,
    scheduler_name,
    box_num,
    *box_list
):
    """Generate images whose text layout follows the user-defined boxes.

    Steps: validate the prompt, seed the RNGs, translate prompts containing
    Chinese characters, render the requested texts, extract Canny edges as the
    ControlNet condition and run the (cached) diffusion pipeline.

    Args:
        num_samples: Number of images to generate.
        a_prompt: Text appended to the user prompt.
        n_prompt: Negative prompt.
        conditioning_scale: ControlNet conditioning scale.
        cfg_scale: Guidance scale.
        inference_steps: Number of inference steps.
        seed: RNG seed; ``-1`` picks a random one.
        usr_prompt: The user's prompt.
        rect_img: Canvas array; only its height and width are used.
        base_model: Display name, a key of ``match_dict``.
        scheduler_name: Scheduler key passed to ``load_pipeline``.
        box_num: Number of boxes encoded in ``box_list``.
        *box_list: Flat box values consumed by ``parse_render_list``.

    Returns:
        ``(images, gr.Markdown)`` on success, where the markdown holds the
        seed, the final prompt and the box values; ``(None,
        gr.Markdown('error'))`` when validation or translation fails.
    """
    global pipeline, pre_pipeline

    if usr_prompt == '':
        gr.Warning('Must input a prompt!')
        return None, gr.Markdown('error')
    if rect_img is None:
        # Without a canvas the output size is unknown.
        gr.Warning('Must provide a canvas image!')
        return None, gr.Markdown('error')

    if seed == -1:
        seed = random.randint(0, 2147483647)
    seed_everything(seed)

    # Support Chinese input: built-in examples have fixed translations, any
    # other prompt containing CJK ideographs is rewritten by the chat model.
    if usr_prompt in chn_example_dict:
        usr_prompt = chn_example_dict[usr_prompt]
    elif any('\u4e00' <= ch <= '\u9fff' for ch in usr_prompt):
        data = preprocess_prompt(usr_prompt)
        if data['flag'] != 1:
            gr.Warning('Something went wrong while translating your prompt, please try again.')
            return None, gr.Markdown('error')
        # Remove surrounding quotes only when present; blindly slicing off the
        # first and last character would damage an unquoted answer.
        usr_prompt = data['data'][0].strip().strip('"\'“”')

    shape = (rect_img.shape[1], rect_img.shape[0])
    render_list, flag = parse_render_list(box_list, shape, int(box_num))
    if not flag:
        return None, gr.Markdown('error')
    render_img = render_all_text(render_list, shape)

    model_name = match_dict[base_model]
    render_img = canny(np.array(render_img))
    w, h = render_img.size

    # The cache key includes the scheduler so that switching scheduler alone
    # also rebuilds the pipeline. NOTE: ``pre_pipeline`` now stores a
    # (model, scheduler) tuple instead of the bare model name.
    cache_key = (model_name, str(scheduler_name).lower())
    if pipeline is None or pre_pipeline != cache_key:
        # Drop the old pipeline first: this releases its memory before the new
        # one is loaded and guarantees a failed load is retried next time
        # instead of leaving a stale pipeline marked as current.
        pipeline = None
        pipeline = load_pipeline(model_name, scheduler_name)
        pre_pipeline = cache_key

    batch_size = int(num_samples)
    batch_render_img = [render_img] * batch_size
    batch_prompt = [f'{usr_prompt}, {a_prompt}'] * batch_size
    batch_n_prompt = [n_prompt] * batch_size

    images = pipeline(
        batch_prompt,
        negative_prompt=batch_n_prompt,
        image=batch_render_img,
        controlnet_conditioning_scale=float(conditioning_scale),
        guidance_scale=float(cfg_scale),
        width=w,
        height=h,
        num_inference_steps=int(inference_steps),
    ).images

    return images, gr.Markdown(f'{seed}, {usr_prompt}, {box_list}')
def draw_example(box_list, color, id):
    """Save a box layout as a JSON template and as a preview image.

    ``box_list`` is split into ``BBOX_MAX_NUM`` visibility flags,
    ``4 * BBOX_MAX_NUM`` positions, ``BBOX_MAX_NUM`` font names and the
    remaining texts. The split values are written to ``templates/<id>.json``;
    every box whose flag is exactly ``True`` is drawn on a fresh canvas that
    is saved to ``./examples/<id>.png``.

    Args:
        box_list: Flat sequence of box values as described above.
        color: Per-box ``(outline, fill)`` pairs, indexed by box number.
        id: Stem of the output file names.
    """
    canvas = Image.fromarray(create_canvas())
    canvas_w, canvas_h = canvas.size
    painter = ImageDraw.Draw(canvas, mode='RGBA')

    count = BBOX_MAX_NUM
    visible = box_list[:count]
    pos = box_list[count: 5 * count]
    font = box_list[5 * count: 6 * count]
    text = box_list[6 * count:]

    with open(f'templates/{id}.json', 'w') as f:
        json.dump(
            {
                'visible': list(visible),
                'pos': list(pos),
                'font': list(font),
                'text': list(text),
            },
            f,
        )

    for i in range(count):
        if visible[i] is not True:
            continue
        polygon = pos[i * 4: (i + 1) * 4]
        print(polygon)
        # Positions are fractions of the canvas size: x, y, width, height.
        left = canvas_w * polygon[0]
        top = canvas_h * polygon[1]
        right = left + canvas_w * polygon[2]
        bottom = top + canvas_h * polygon[3]
        outline_color, fill_color = color[i][0], color[i][1]
        painter.rectangle(
            [left, top, right, bottom],
            outline=outline_color,
            fill=fill_color,
            width=3,
        )

    canvas.save(f'./examples/{id}.png')
if __name__ == '__main__':
    # Running the module directly currently does nothing.
    pass