# NOTE: Hugging Face Spaces page residue from extraction ("Spaces: Runtime error")
# converted to a comment so the file stays parseable.
from cldm.model import load_state_dict
from cldm.ddim_hacked import DDIMSampler
from ldm.util import instantiate_from_config
import os
from omegaconf import OmegaConf
import argparse, os
from torchvision.transforms import ToTensor
from torch import autocast
from contextlib import nullcontext
from scripts.rendertext_tool import Render_Text, load_model_from_config
# def load_model_from_config(cfg, ckpt, verbose=False, not_use_ckpt=False):
#     sd = load_state_dict(ckpt, location='cpu')
#     if "model_ema.input_blocks10in_layers0weight" not in sd:
#         cfg.model.params.use_ema = False
#     model = instantiate_from_config(cfg.model)
#     if not not_use_ckpt:
#         m, u = model.load_state_dict(sd, strict=False)
#         if len(m) > 0 and verbose:
#             print("missing keys: {}".format(len(m)))
#             print(m)
#         if len(u) > 0 and verbose:
#             print("unexpected keys: {}".format(len(u)))
#             print(u)
#     model.cuda()
#     model.eval()
#     return model
def parse_args():
    """Build the command-line interface for the glyph-rendering script.

    Returns the configured ``argparse.ArgumentParser`` itself (not a parsed
    namespace); callers invoke ``.parse_args()`` on the result.
    """
    parser = argparse.ArgumentParser()
    # (flag, keyword-arguments) specs, registered in the original order so the
    # generated --help output is unchanged.
    arg_specs = [
        ("--cfg", dict(
            type=str,
            default="configs/stable-diffusion/textcaps_cldm_v20.yaml",
            help="path to config which constructs model")),
        ("--ckpt", dict(
            type=str,
            help="path to checkpoint of model")),
        ("--hint_range_m11", dict(
            action="store_true",
            help="the range of the hint image ([-1, 1])")),
        ("--precision", dict(
            type=str,
            help="evaluate at this precision",
            choices=["full", "autocast"],
            default="full")),
        ("--not_use_ckpt", dict(
            action="store_true",
            help="not to use the ckpt")),
        ("--build_demo", dict(
            action="store_true",
            help="whether to build the demo")),
        ("--sep_prompt", dict(
            action="store_true",
            help="whether to sep the prompt")),
        ("--spell_prompt_type", dict(
            type=int,
            default=1,
            help="1: A sign with the word 'xxx' written on it; 2: A sign that says 'xxx'")),
        ("--max_num_prompts", dict(
            type=int,
            default=None,
            help="max num of the used prompts")),
        ("--grams", dict(
            type=int,
            default=1,
            help="How many grams (words or symbols) to form the to-be-rendered text (used for DrawSpelling Benchmark)")),
        ("--num_samples", dict(
            type=int,
            default=1,
            help="how many samples to produce for each given prompt. A.k.a batch size")),
        ("--from-file", dict(
            type=str,
            help="if specified, load prompts from this file, separated by newlines")),
        ("--prompt", dict(
            type=str,
            nargs="?",
            default="a sign that says 'Stable Diffusion'",
            help="the prompt")),
        ("--rendered_txt", dict(
            type=str,
            nargs="?",
            default="Stable Diffusion",
            help="the text to render")),
        ("--uncond_glycon_img", dict(
            action="store_true",
            help="whether to set glyph embedding as None while using unconditional conditioning")),
        ("--deepspeed_ckpt", dict(
            action="store_true",
            help="whether to use deepspeed while training")),
        ("--glyph_img_size", dict(
            type=int,
            default=256,
            help="the size of input images of the glyph image encoder")),
        ("--uncond_glyph_image_type", dict(
            type=str,
            default="white",
            help="the type of rendered glyph images as unconditional conditions while using classifier-free guidance")),
        ("--remove_txt_in_prompt", dict(
            action="store_true",
            help="whether to remove text in the prompt")),
        ("--replace_token", dict(
            type=str,
            default="",
            help="the token used to replace")),
    ]
    for flag, kwargs in arg_specs:
        parser.add_argument(flag, **kwargs)
    return parser
# --- Script setup: resolve working directory, parse CLI, load model and tools ---
# NOTE(review): indentation in this file was reconstructed from a table-mangled
# paste — confirm against the original script.

# The repo's relative config/ckpt paths assume the "stablediffusion" directory
# is the CWD; chdir into it when launched from the repository root.
if not os.path.basename(os.getcwd()) == "stablediffusion":
    os.chdir(os.path.join(os.getcwd(), "stablediffusion"))
print(os.getcwd())

parser = parse_args()
opt = parser.parse_args()

# A DeepSpeed checkpoint is a directory; resolve it to the actual
# model-states file before loading.
if opt.deepspeed_ckpt:
    assert os.path.isdir(opt.ckpt)
    opt.ckpt = os.path.join(opt.ckpt, "checkpoint", "mp_rank_00_model_states.pt")
    assert os.path.exists(opt.ckpt)

cfg = OmegaConf.load(f"{opt.cfg}")
model = load_model_from_config(cfg, f"{opt.ckpt}", verbose=True, not_use_ckpt=opt.not_use_ckpt)

hint_range_m11 = opt.hint_range_m11
sep_prompt = opt.sep_prompt
ddim_sampler = DDIMSampler(model)
# "full" precision runs under a no-op context; "autocast" wraps sampling in
# torch.autocast.
precision_scope = autocast if opt.precision == "autocast" else nullcontext
trans = ToTensor()

# Build the rendering helper. Values present in the YAML config take priority
# over the CLI flags (hasattr works for top-level OmegaConf keys).
# NOTE(review): "glycon" is presumably a typo for "glyph", but it matches the
# Render_Text keyword argument — do not rename here.
render_tool = Render_Text(
    model, precision_scope,
    trans,
    hint_range_m11,
    sep_prompt,
    uncond_glycon_img= cfg.uncond_glycon_img if hasattr(cfg, "uncond_glycon_img") else opt.uncond_glycon_img,
    glyph_control_proc_config= cfg.glyph_control_proc_config if hasattr(cfg, "glyph_control_proc_config") else None,
    glyph_img_size = opt.glyph_img_size,
    uncond_glyph_image_type = cfg.uncond_glyph_image_type if hasattr(cfg, "uncond_glyph_image_type") else opt.uncond_glyph_image_type,
    remove_txt_in_prompt = cfg.remove_txt_in_prompt if hasattr(cfg, "remove_txt_in_prompt") else opt.remove_txt_in_prompt,
    replace_token = cfg.replace_token if hasattr(cfg, "replace_token") else opt.replace_token,
)
# --- Main driver: interactive Gradio demo, or batch generation + OCR check ---
# NOTE(review): indentation reconstructed from a table-mangled paste — confirm
# the nesting against the original script.
if opt.build_demo:
    # Demo mode: build a Gradio UI whose "Run" button calls render_tool.process.
    import gradio as gr
    block = gr.Blocks().queue()
    with block:
        with gr.Row():
            gr.Markdown("## Control Stable Diffusion with Glyph Images")
        with gr.Row():
            with gr.Column():
                # input_image = gr.Image(source='upload', type="numpy")
                rendered_txt = gr.Textbox(label="rendered_txt")
                prompt = gr.Textbox(label="Prompt")
                if sep_prompt:
                    # Separate prompt fed to the ControlNet branch.
                    prompt_2 = gr.Textbox(label="Prompt_ControlNet")
                else:
                    # Hidden placeholder so the inputs list keeps the same arity.
                    prompt_2 = gr.Number(value = 0, visible = False) #None #""
                run_button = gr.Button(label="Run")
                with gr.Accordion("Advanced options", open=False):
                    # Bounding-box controls for where the text is rendered
                    # (fractions of the image, except yaw in degrees).
                    width = gr.Slider(label="bbox_width", minimum=0., maximum=1, value=0.3, step=0.01)
                    # height = gr.Slider(label="bbox_height", minimum=0., maximum=1, value=0.2, step=0.01)
                    ratio = gr.Slider(label="bbox_width_height_ratio", minimum=0., maximum=5, value=0., step=0.02)
                    top_left_x = gr.Slider(label="bbox_top_left_x", minimum=0., maximum=1, value=0.5, step=0.01)
                    top_left_y = gr.Slider(label="bbox_top_left_y", minimum=0., maximum=1, value=0.5, step=0.01)
                    yaw = gr.Slider(label="bbox_yaw", minimum=-180, maximum=180, value=0, step=5)
                    num_rows = gr.Slider(label="num_rows", minimum=1, maximum=4, value=1, step=1)
                    # Sampler / guidance controls.
                    num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                    image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                    strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                    guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                    # low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
                    # high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
                    ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                    scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                    eta = gr.Number(label="eta (DDIM)", value=0.0)
                    a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                    n_prompt = gr.Textbox(label="Negative Prompt",
                                          value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
            with gr.Column():
                result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
        # Positional inputs — order must match the Render_Text.process
        # signature (TODO confirm against rendertext_tool).
        ips = [
            rendered_txt, prompt,
            width, ratio, # height,
            top_left_x, top_left_y, yaw, num_rows,
            a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta,
            prompt_2
        ]
        run_button.click(fn=render_tool.process, inputs=ips, outputs=[result_gallery])
        # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
    block.launch(server_name='0.0.0.0', share=True)
else:
    # Batch mode: generate images for a prompt list, then OCR the outputs
    # with easyocr so the rendered text can be checked.
    import easyocr
    reader = easyocr.Reader(['en'])
    # num_samples = 1
    # rendered_txt = "happy"
    # prompt = "A sign that says 'happy'"
    num_samples = opt.num_samples
    print("the num of samples is {}".format(num_samples))
    if not opt.from_file:
        # Single prompt / rendered text taken straight from the CLI.
        prompts = [opt.prompt]
        data = [opt.rendered_txt]
        print("the prompt is {}".format(prompts))
        print("the rendered_txt is {}".format(data))
        assert prompts is not None
    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
        # "gram" benchmark files are tab-separated; keep only the first field.
        if "gram" in os.path.basename(opt.from_file):
            data = [item.split("\t")[0] for item in data]
            # Join consecutive entries into opt.grams-word chunks.
            if opt.grams > 1:
                data = [" ".join(data[i:i + opt.grams]) for i in range(0, len(data), opt.grams)]
        # For spelling benchmarks, embed each word into a prompt template.
        if "DrawText_Spelling" in os.path.basename(opt.from_file) or "gram" in os.path.basename(opt.from_file):
            if opt.spell_prompt_type == 1:
                prompts = ['A sign with the word "{}" written on it'.format(line.strip()) for line in data]
            elif opt.spell_prompt_type == 2:
                prompts = ["A sign that says '{}'".format(line.strip()) for line in data]
            elif opt.spell_prompt_type == 20:
                prompts = ['A sign that says "{}"'.format(line.strip()) for line in data]
            elif opt.spell_prompt_type == 3:
                prompts = ["A whiteboard that says '{}'".format(line.strip()) for line in data]
            elif opt.spell_prompt_type == 30:
                prompts = ['A whiteboard that says "{}"'.format(line.strip()) for line in data]
            else:
                print("Only five types of prompt templates are supported currently")
                raise ValueError
        # if opt.verbose_all_prompts:
        #     show_num = opt.max_num_prompts if (opt.max_num_prompts is not None and opt.max_num_prompts >0) else 10
        #     for i in range(show_num):
        #         print("embed the word into the prompt template for {} Benchmark: {}".format(
        #             os.path.basename(opt.from_file), data[i])
        #         )
        # else:
        #     print("embed the word into the prompt template for {} Benchmark: e.g., {}".format(
        #         os.path.basename(opt.from_file), data[0])
        #     )
        # Optionally truncate to the first max_num_prompts entries.
        if opt.max_num_prompts is not None and opt.max_num_prompts >0:
            print("only use {} prompts to test the model".format(opt.max_num_prompts))
            data = data[:opt.max_num_prompts]
            prompts = prompts[:opt.max_num_prompts]
    # Fixed generation settings for batch mode (mirror the demo defaults,
    # except the pinned seed for reproducibility).
    width, ratio, top_left_x, top_left_y, yaw, num_rows = 0.3, 0, 0.5, 0.5, 0, 1
    image_resolution = 512
    strength = 1
    guess_mode = False
    ddim_steps = 20
    scale = 9.0
    seed = 1945923867
    eta = 0
    a_prompt = 'best quality, extremely detailed'
    n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
    all_results_list = []
    for i in range(len(data)):
        ips = (
            data[i], prompts[i],
            width, ratio, top_left_x, top_left_y, yaw, num_rows,
            a_prompt, n_prompt,
            num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta
        )
        all_results = render_tool.process(*ips) #process(*ips)
        # When text was rendered, the first returned item is presumably the
        # glyph/hint image rather than a sample — skipped here (TODO confirm).
        all_results_list.extend(all_results[1:] if data[i] != "" else all_results)
    # OCR every generated image; results are collected but not otherwise
    # used within this file.
    all_ocr_info = []
    for image_array in all_results_list:
        ocr_result = reader.readtext(image_array)
        all_ocr_info.append(ocr_result)