Spaces:
Runtime error
Runtime error
| import argparse | |
| import datetime | |
| import json | |
| import os | |
| import re | |
| import time | |
| import gradio as gr | |
| import requests | |
| from llava.conversation import (default_conversation, conv_templates, | |
| SeparatorStyle) | |
| from llava.constants import LOGDIR | |
| from llava.utils import (build_logger, server_error_msg, | |
| violates_moderation, moderation_msg) | |
| import hashlib | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| from copy import deepcopy | |
# Module-wide logger (arguments are presumably the logger name and the log file name).
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# HTTP headers attached to the streaming request sent to a model worker.
headers = {"User-Agent": "LLaVA Client"}

# Button presets returned by event handlers to leave / enable / disable
# the three "Generate" buttons.
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

# Sort keys used by get_model_list(): models listed here sort by the mapped
# value, every other model sorts by its own name.
priority = {"vicuna-13b": "aaaaaaa", "koala-13b": "aaaaaab"}
# Instruction prompt sent to the model for every layout request.
# Placeholders filled in by the *_add_text handlers via str.format():
#   {N}           number of foreground elements to place
#   {resolution}  background resolution, rendered as a two-item list
#   {domain_name} design domain; http_bot() later searches the prompt for
#                 these domain strings to pick the color palette
#   {cons_data}   free-text user constraints
#   {json_data}   JSON skeleton with one {"label": ..., "box": []} per element
# The handlers drop the newline after "<image>" and turn the remaining
# newlines into a literal backslash-n before formatting.
# NOTE(review): the wording (including its grammar) is part of the prompt the
# model receives; do not "fix" it without checking the model side.
prompt_template = '''
<image>
Hello! Could you please help me to place {N} foreground elements over the background image of resolution {resolution} to craft an aesthetically pleasing, harmonious, balanced, and visually appealing {domain_name}?
Finding semantic-meaningful objects or visual foci on the background image at first might help in designing, and you should avoid any unnecessary blocking of them.
For each layout, there are 3 additional user requirements and you are expected to generate a layout corresponding to them. Here is the user requirements: {cons_data}
Please return the result by completing the following JSON file. Each element's location and size should be represented by a bounding box described as [left, top, right, bottom], and each number is a continuous digit from 0 to 1.
Here is the initial JSON file: {json_data}
'''
# Element labels available per dataset. The order matters: init_json() pairs
# the i-th UI count with the i-th label of the selected dataset.
ELEM_CLASSES = {
    "QB-Poster": [
        "title",
        "decoration",
        "subtitle",
        "itemtitle",
        "itemlogo",
        "item",
        "text",
        "textbackground",
        "object",
        "frame",
    ],
    "CGL": [
        "text",
        "underlay",
        "embellishment",
    ],
    "Ad Banners": [
        "header",
        "preheader",
        "postheader",
        "body text",
        "disclaimer / footnote",
        "button",
        "callout",
        "logo",
    ],
}

# Drawing color per element label, per dataset. Each palette also carries a
# "false" entry that is not an element class of the dataset.
CLS2COLOR = {
    "QB-Poster": {
        "title": "red",
        "subtitle": "green",
        "itemlogo": "orange",
        "item": "blue",
        "itemtitle": "yellow",
        "object": "purple",
        "textbackground": "pink",
        "decoration": "brown",
        "frame": "gray",
        "text": "cyan",
        "false": "black",
    },
    "CGL": {
        "text": "red",
        "underlay": "green",
        "embellishment": "blue",
        "false": "black",
    },
    "Ad Banners": {
        "header": "red",
        "preheader": "green",
        "postheader": "blue",
        "body text": "orange",
        "disclaimer / footnote": "purple",
        "button": "pink",
        "callout": "brown",
        "logo": "gray",
        "false": "black",
    },
}
def get_json_response(response):
    """Extract the JSON layout embedded in a model reply.

    The payload is the text from the last ``[{`` up to and including the
    last ``}]`` in *response*. Single quotes are rewritten to double quotes
    before parsing (presumably to tolerate Python-style quoting in the
    model output).

    Args:
        response: Raw text returned by the model.

    Returns:
        The parsed JSON value, or ``None`` when *response* is not a string,
        contains no ``[{ ... }]`` span, or the span is not valid JSON.
    """
    if not isinstance(response, str):
        return None

    lo = response.rfind("[{")
    hi = response.rfind("}]")
    # Both markers must exist and the closing one must come after the opener.
    if lo == -1 or hi == -1 or hi < lo:
        return None

    payload = response[lo:hi + 2].replace("'", '"')
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        return None
def draw_box(img, elems, elems2, cls2color):
    """Overlay labelled bounding boxes on a copy of *img*.

    Each box is drawn twice: as a colored outline on one copy and as a
    solid rectangle on a second copy. The filled copy is made almost
    transparent and composited over the outlined copy.

    Args:
        img: PIL image used as background; it is not modified.
        elems: Iterable of ``(label, box)`` pairs where ``box`` is
            ``[left, top, right, bottom]`` expressed as fractions of the
            image width / height.
        elems2: Unused; kept so existing callers keep working.
        cls2color: Mapping from label to a PIL color.

    Returns:
        An RGBA image with the boxes drawn.
    """
    W, H = img.size
    outlined = img.copy()
    filled = img.copy()
    outline_pen = ImageDraw.Draw(outlined)
    fill_pen = ImageDraw.Draw(filled)

    # Outline thickness grows with the image size (10 px at 1242 x 1660).
    line_width = max(10 * (W + H) // (1242 + 1660), 1)
    # Labels missing from the palette fall back to its "false" color.
    # NOTE(review): assumes the "false" palette entry is meant for invalid labels - confirm.
    fallback_color = cls2color.get("false", "black")

    for cls, box in elems:
        try:
            left, top, right, bottom = (float(v) for v in box)
            color = cls2color.get(cls, fallback_color)
            # Order the corners: recent Pillow releases reject x1 < x0 / y1 < y0.
            x0, x1 = sorted((int(left * W), int(right * W)))
            y0, y1 = sorted((int(top * H), int(bottom * H)))
        except (TypeError, ValueError, OverflowError):
            # Model output is untrusted: skip boxes that are empty, have the
            # wrong arity, or hold non-numeric / non-finite values.
            continue
        pixel_box = (x0, y0, x1, y1)
        outline_pen.rectangle(pixel_box, fill=None, outline=color, width=line_width)
        fill_pen.rectangle(pixel_box, fill=color)

    outlined = outlined.convert("RGBA")
    filled = filled.convert("RGBA")
    filled.putalpha(int(256 * 0.1))  # roughly 10 % opacity for the filled layer
    return Image.alpha_composite(outlined, filled)
def draw_boxmap(json_response, background_image, cls2color):
    """Render a parsed layout on top of the background image.

    Args:
        json_response: Parsed model output; expected to be a list of dicts
            with ``"label"`` and ``"box"`` keys. Entries that are not dicts
            or lack either key are ignored.
        background_image: PIL image to draw on; it is not modified.
        cls2color: Mapping from label to a PIL color.

    Returns:
        An RGB image with the layout boxes drawn.
    """
    # Draw on an RGB copy so color names render as colors even when the
    # background is grayscale or palette based.
    pic = background_image.convert("RGB")
    cls_box = [
        (elem["label"], elem["box"])
        for elem in json_response
        if isinstance(elem, dict) and "label" in elem and "box" in elem
    ]
    logger.info(f"layout boxes: {cls_box}")
    drawn = draw_box(pic, cls_box, cls_box, cls2color)
    return drawn.convert("RGB")
def get_conv_log_filename():
    """Return the path of today's conversation log inside LOGDIR.

    The file name has the form ``YYYY-MM-DD-conv.json`` and is built from
    the current local date.
    """
    today = datetime.datetime.now()
    basename = "{}-{:02d}-{:02d}-conv.json".format(today.year, today.month, today.day)
    return os.path.join(LOGDIR, basename)
def get_model_list():
    """Ask the controller to refresh its workers and return the model names.

    Reads the controller address from the module-level ``args``.

    Returns:
        The list of model names reported by the controller, sorted with
        the keys from ``priority`` (unlisted models sort by their name).

    Raises:
        RuntimeError: If the controller answers either request with a
            status other than 200.
    """
    refresh = requests.post(args.controller_url + "/refresh_all_workers")
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if refresh.status_code != 200:
        raise RuntimeError(
            f"Controller refresh failed with HTTP status {refresh.status_code}")

    listing = requests.post(args.controller_url + "/list_models")
    if listing.status_code != 200:
        raise RuntimeError(
            f"Controller model listing failed with HTTP status {listing.status_code}")

    models = listing.json()["models"]
    models.sort(key=lambda name: priority.get(name, name))
    logger.info(f"Models: {models}")
    return models
# JavaScript snippet passed to demo.load() in build_demo(): it runs in the
# browser and returns the page's query-string parameters as a plain object,
# which load_demo() receives as ``url_params``.
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log(url_params);
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
    """Page-load handler used when the model list is fetched only once.

    Args:
        url_params: Query-string parameters of the page as a mapping.
        request: Gradio request, used for logging the client address.

    Returns:
        A fresh conversation state and a dropdown update. The dropdown
        preselects ``url_params["model"]`` when that name is in the
        module-level ``models`` list.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    dropdown_kwargs = {"visible": True}
    if "model" in url_params and url_params["model"] in models:
        dropdown_kwargs["value"] = url_params["model"]
    dropdown_update = gr.Dropdown(**dropdown_kwargs)

    return default_conversation.copy(), dropdown_update
def load_demo_refresh_model_list(request: gr.Request):
    """Page-load handler that re-queries the controller for the model list.

    Args:
        request: Gradio request, used for logging the client address.

    Returns:
        A fresh conversation state and a dropdown update listing the
        current models, with the first one selected (or an empty string
        when there are none).
    """
    logger.info(f"load_demo. ip: {request.client.host}")
    available = get_model_list()
    fresh_state = default_conversation.copy()
    selected = available[0] if len(available) > 0 else ""
    return fresh_state, gr.Dropdown(choices=available, value=selected)
def init_json(elem_list, dataset):
    """Build the empty layout skeleton handed to the model.

    Args:
        elem_list: Per-label element counts, in the same order as
            ``ELEM_CLASSES[dataset]``; each entry is converted with ``int``.
        dataset: Key into ``ELEM_CLASSES``.

    Returns:
        A list with ``count`` entries ``{"label": label, "box": []}`` for
        every label, in label order. Non-positive counts add nothing.
    """
    skeleton = []
    for position, label in enumerate(ELEM_CLASSES[dataset]):
        count = int(elem_list[position])
        # range() is empty for zero or negative counts.
        skeleton.extend({"label": label, "box": []} for _ in range(count))
    return skeleton
def init_conv(request: gr.Request):
    """Reset the conversation and enable the three submit buttons.

    Args:
        request: Gradio request, used for logging the client address.

    Returns:
        ``(state, chatbot_messages, enable_btn, enable_btn, enable_btn)``.
    """
    logger.info(f"init_conversation. ip: {request.client.host}")
    fresh_state = default_conversation.copy()
    chatbot_rows = fresh_state.to_gradio_chatbot()
    return (fresh_state, chatbot_rows, enable_btn, enable_btn, enable_btn)
def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Append a user turn (optionally with an image) to the conversation.

    When an image is supplied the message becomes the tuple
    ``(text, image, image_process_mode)``, and a conversation that already
    holds images is replaced by a fresh one first.

    Args:
        state: Current conversation state.
        text: Prompt text for the user turn.
        image: Image to attach, or ``None``.
        image_process_mode: Preprocessing mode stored alongside the image.
        request: Gradio request, used for logging the client address.

    Returns:
        ``(state, chatbot_messages)`` followed by three ``disable_btn``.
    """
    logger.info(f"add_text. ip: {request.client.host}.")

    if image is not None:
        text = (text, image, image_process_mode)
        earlier_images = state.get_images(return_pil=True)
        if len(earlier_images) > 0:
            state = default_conversation.copy()

    user_role, assistant_role = state.roles[0], state.roles[1]
    state.append_message(user_role, text)
    # Placeholder turn that http_bot() fills with the model reply.
    state.append_message(assistant_role, None)
    state.skip_next = False

    return (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn)
def _add_layout_request(state, elem_list, dataset, domain_name, fallback_resolution,
                        image, user_cons, image_process_mode, request):
    """Build the layout prompt for one dataset and append it to the conversation.

    Shared implementation of ``qb_add_text``, ``cgl_add_text`` and
    ``banners_add_text``.

    Args:
        state: Current conversation state.
        elem_list: Element counts in the order of ``ELEM_CLASSES[dataset]``.
        dataset: Key into ``ELEM_CLASSES``.
        domain_name: Text substituted for ``{domain_name}`` in the prompt.
        fallback_resolution: ``[width, height]`` used when neither a new
            image nor a previously stored image is available.
        image: Newly supplied background image, or ``None``.
        user_cons: Free-text user constraints.
        image_process_mode: Preprocessing mode forwarded to ``add_text``.
        request: Gradio request forwarded to ``add_text``.

    Returns:
        Whatever ``add_text`` returns.
    """
    json_data = init_json(elem_list, dataset=dataset)

    if image is not None:
        resolution = list(image.size)
    else:
        try:
            # Reuse the most recent image of the conversation.
            resolution = list(state.get_images(return_pil=True)[-1].size)
        except Exception:
            # Best effort: no usable state or no stored image. ``Exception``
            # (rather than a bare except) lets KeyboardInterrupt / SystemExit through.
            resolution = list(fallback_resolution)

    text = prompt_template.replace('<image>\n', '<image>').replace('\n', '\\n').format(
        N=len(json_data),
        resolution=resolution,
        domain_name=domain_name,
        cons_data=user_cons,
        json_data=json.dumps(json_data)
    )
    return add_text(state, text, image, image_process_mode, request)


def qb_add_text(state, title_num, decoration_num, subtitle_num, itemtitle_num, itemlogo_num, item_num, text_num,
                textbackground_num, object_num, frame_num, image, user_cons, image_process_mode, request: gr.Request):
    """Submit handler of the "QB-Poster" tab; see ``_add_layout_request``.

    The ``*_num`` arguments are the per-label element counts, passed in the
    order of ``ELEM_CLASSES["QB-Poster"]``.
    """
    elem_list = [title_num, decoration_num, subtitle_num, itemtitle_num, itemlogo_num,
                 item_num, text_num, textbackground_num, object_num, frame_num]
    # NOTE: the domain string is matched verbatim by http_bot(); keep the spelling.
    return _add_layout_request(state, elem_list, 'QB-Poster', "poster with xiaohonshu style",
                               [1242, 1660], image, user_cons, image_process_mode, request)


def cgl_add_text(state, text_num, underlay_num, embellishment_num, image, user_cons, image_process_mode, request: gr.Request):
    """Submit handler of the "CGL / PosterLayout" tab; see ``_add_layout_request``.

    The ``*_num`` arguments are the per-label element counts, passed in the
    order of ``ELEM_CLASSES["CGL"]``.
    """
    elem_list = [text_num, underlay_num, embellishment_num]
    return _add_layout_request(state, elem_list, 'CGL', "commercial poster",
                               [513, 750], image, user_cons, image_process_mode, request)


def banners_add_text(state, header_num, preheader_num, postheader_num, body_text, disclaimer_num, button_num,
                     callout_num, logo_num, image, user_cons, image_process_mode, request: gr.Request):
    """Submit handler of the "Ad Banners" tab; see ``_add_layout_request``.

    The count arguments are passed in the order of
    ``ELEM_CLASSES["Ad Banners"]`` (``body_text`` is the "body text" count).
    """
    elem_list = [header_num, preheader_num, postheader_num, body_text, disclaimer_num, button_num, callout_num, logo_num]
    return _add_layout_request(state, elem_list, 'Ad Banners', "commercial banner",
                               [1080, 1080], image, user_cons, image_process_mode, request)
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, repeat_times, request: gr.Request, progress=gr.Progress()):
    """Generate ``repeat_times`` layouts with a model worker and render them.

    Generator used as a Gradio event handler. Every yield is
    ``(state, chatbot_messages, boxmaps, btn, btn, btn)`` where ``boxmaps``
    is the list of rendered layout images collected so far (``None`` when
    the turn is skipped).

    Steps: resolve a worker through the controller, save the conversation
    images under ``LOGDIR/serve_images/<date>/``, stream each generation
    from the worker while updating the progress bar, parse and draw each
    reply, then append a record to the daily conversation log.

    Args:
        state: Conversation state holding the pending user turn.
        model_selector: Name of the model to query.
        temperature: Sampling temperature (converted with ``float``).
        top_p: Nucleus sampling parameter (converted with ``float``).
        max_new_tokens: Generation length limit (converted with ``int``).
        repeat_times: Number of generations to request.
        request: Gradio request, used for logging the client address.
        progress: Gradio progress tracker.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector
    repeat_times = int(repeat_times)  # sliders may deliver a float

    if state.skip_next:
        yield (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 3
        return

    if len(state.messages) == state.offset + 2:
        # Exactly one user/assistant pair so far: restart from the "llava_v1" template.
        template_name = "llava_v1"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), None, disable_btn, disable_btn, disable_btn)
        return

    # Construct prompt
    prompt = state.get_prompt()

    # The palette is chosen from the domain string embedded in the prompt.
    # It stays None when no marker matches, in which case nothing is drawn
    # (previously this ended in a NameError).
    current_dataset = None
    for marker, dataset in (("xiaohonshu", "QB-Poster"),
                            ("commercial poster", "CGL"),
                            ("commercial banner", "Ad Banners")):
        if marker in prompt:
            current_dataset = dataset
            break

    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, image_hash in zip(all_images, all_image_hash):
        now = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{now.year}-{now.month:02d}-{now.day:02d}", f"{image_hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Make requests
    encoded_images = state.get_images()
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": int(max_new_tokens),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        # Only a summary is logged; the real image payload is attached below.
        "images": f'List of {len(encoded_images)} images: {all_image_hash}',
    }
    logger.info(f"==== request ====\n{pload}")
    pload['images'] = encoded_images

    boxmaps, all_responses = [], []

    # Estimate the reply length from the JSON skeleton in the prompt; it is
    # only used to scale the progress bar. The skeleton is absent when every
    # element count is zero ("[]"), so fall back to an empty string.
    skeleton_match = re.search(r'\[\{.*?\}\]', prompt)
    initial_json = skeleton_match.group(0) if skeleton_match else ""
    elems_num = initial_json.count("label")
    total_length = elems_num * len('0.0000, 0.0000, 0.0000, 0.0000') + len(initial_json)\
        + len('Sure! Here is the design results: ')
    expected_chars = max(total_length * repeat_times, 1)

    output = ""
    for attempt in range(repeat_times):
        # Reset per attempt so an empty stream cannot reuse the previous reply.
        output = ""
        try:
            response = requests.post(worker_addr + "/worker_generate_stream",
                                     headers=headers, json=pload, stream=True, timeout=20)
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if not chunk:
                    continue
                data = json.loads(chunk.decode())
                if data["error_code"] != 0:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot(), boxmaps) + (enable_btn,) * 3
                    return
                # The worker echoes the prompt; keep only the generated part.
                output = data["text"][len(prompt):].strip()
                time.sleep(0.01)
                generated_chars = len(output) + sum(len(reply) for reply in all_responses)
                # Clamp: replies can be longer than the estimate.
                progress(min(generated_chars / expected_chars, 1.0),
                         desc=f'Generating the {attempt + 1}th output...')
        except requests.exceptions.RequestException:
            state.messages[-1][-1] = server_error_msg
            yield (state, state.to_gradio_chatbot(), boxmaps) + (enable_btn,) * 3
            return

        all_responses.append(output)
        json_response = get_json_response(output)
        if json_response is not None and current_dataset is not None and all_images:
            boxmaps.append(draw_boxmap(json_response, all_images[-1], CLS2COLOR[current_dataset]))

    # NOTE(review): the source indentation was lost; the summary is assumed to
    # be emitted once, after all generations have finished - confirm.
    state.messages[-1][-1] = "".join(
        f"Design Result {i}:\n" + reply + "\n\n" for i, reply in enumerate(all_responses))
    yield (state, state.to_gradio_chatbot(), boxmaps) + (enable_btn,) * 3

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),  # previously logged the start time
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
# Page heading rendered at the top of the demo unless embed mode is on.
title_markdown = ("""
# PosterLLaVA: Constructing a Unified Multi-modal Layout Generator with LLM
""")
# Footer markdown blocks rendered at the bottom of the demo unless embed
# mode is on; both currently contain only a newline.
tos_markdown = ("""
""")
learn_more_markdown = ("""
""")
# Extra CSS handed to gr.Blocks in build_demo().
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
def build_demo(embed_mode, cur_dir=None):
    """Assemble the Gradio Blocks UI and wire its event handlers.

    Relies on the module-level ``models`` (dropdown choices) and ``args``
    (``model_list_mode``) being set before it is called.

    Args:
        embed_mode: When true, the title and footer markdown are omitted.
        cur_dir: Currently unused (it was only referenced by example code
            that is commented out). Optional so that callers passing only
            ``embed_mode`` no longer fail with a TypeError.

    Returns:
        The configured ``gr.Blocks`` application.

    Raises:
        ValueError: If ``args.model_list_mode`` is neither "once" nor "reload".
    """
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from flattened source - confirm against the rendered UI.

    # Created up front and rendered later, inside the right-hand column.
    imagebox_boxmap = gr.Gallery(label='Result(结果)', show_label=True, preview=False, columns=2, allow_preview=True, height=550)

    with gr.Blocks(title="PosterLLaVA", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            # Left column: model choice, background image, sampling parameters.
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                with gr.Accordion("Parameters", open=True):
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=1024, step=64, interactive=True, label="Max output tokens",)
                    repeat_times = gr.Slider(minimum=1, maximum=64, value=4, step=1, interactive=True, label="Repeat times",)

            # Right column: result gallery, one tab per dataset, chat log.
            with gr.Column(scale=8):
                imagebox_boxmap.render()

                with gr.Tabs():
                    with gr.Tab("QB-Poster"):
                        with gr.Row(variant='compact'):
                            object_num = gr.Checkbox(label='object', value=True)
                            frame_num = gr.Checkbox(label='frame', value=True)
                            title_num = gr.Checkbox(label='title', value=True)
                            decoration_num = gr.Slider(minimum=0, maximum=10, value=2, step=1, label='decoration')
                            subtitle_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='subtitle')
                            itemtitle_num = gr.Slider(minimum=0, maximum=10, value=0, step=1, label='itemtitle')
                            itemlogo_num = gr.Slider(minimum=0, maximum=10, value=3, step=1, label='itemlogo')
                            item_num = gr.Slider(minimum=0, maximum=10, value=3, step=1, label='item')
                            text_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='text')
                            textbackground_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='textbackground')
                        # Order must match the parameters of qb_add_text.
                        qb_elem_list = [title_num, decoration_num, subtitle_num, itemtitle_num, itemlogo_num,
                                        item_num, text_num, textbackground_num, object_num, frame_num]
                        with gr.Row():
                            qb_textbox = gr.Textbox(label='User Constraint', show_label=True,
                                                    placeholder="Enter text and press ENTER", container=True)
                            with gr.Column(scale=1, min_width=50):
                                qb_submit_btn = gr.Button(value="Generate", variant="primary", interactive=False)

                    with gr.Tab("CGL / PosterLayout"):
                        with gr.Row(variant='compact'):
                            cgl_text_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='text')
                            underlay_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='underlay')
                            embellishment_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='embellishment')
                        # Order must match the parameters of cgl_add_text.
                        cgl_elem_list = [cgl_text_num, underlay_num, embellishment_num]
                        with gr.Row():
                            cgl_textbox = gr.Textbox(label='User Constraint', show_label=True,
                                                     placeholder="Enter text and press ENTER", container=True)
                            with gr.Column(scale=1, min_width=50):
                                cgl_submit_btn = gr.Button(value="Generate", variant="primary", interactive=False)

                    with gr.Tab("Ad Banners"):
                        with gr.Row(variant='compact'):
                            header_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='header')
                            preheader_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='pre-header')
                            postheader_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='post-header')
                            body_text = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='body text')
                            disclaimer_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='disclaimer / footnote')
                            button_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='button')
                            callout_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='callout')
                            logo_num = gr.Slider(minimum=0, maximum=10, value=1, step=1, label='logo')
                        # Order must match the parameters of banners_add_text.
                        banners_elem_list = [header_num, preheader_num, postheader_num, body_text, disclaimer_num,
                                             button_num, callout_num, logo_num]
                        with gr.Row():
                            banners_textbox = gr.Textbox(label='User Constraint', show_label=True,
                                                         placeholder="Enter your customized design requirements separated by ';' to control the sizes and positions of elements", container=True)
                            with gr.Column(scale=1, min_width=50):
                                banners_submit_btn = gr.Button(value="Generate", variant="primary", interactive=False)

                with gr.Accordion("Intermediate results", open=False):
                    gr.Markdown("The layout generation process with LLM")
                    chatbot = gr.Chatbot(elem_id="chatbot", label="LLM Conversations", height=550)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [qb_submit_btn, cgl_submit_btn, banners_submit_btn]

        # A new background image resets the conversation and enables the buttons.
        imagebox.change(
            init_conv,
            None,
            [state, chatbot] + btn_list,
            queue=False
        )

        # Every tab follows the same two-step chain: build the prompt, then
        # run the generation.
        submit_chains = (
            (qb_submit_btn, qb_add_text, qb_elem_list, qb_textbox),
            (cgl_submit_btn, cgl_add_text, cgl_elem_list, cgl_textbox),
            (banners_submit_btn, banners_add_text, banners_elem_list, banners_textbox),
        )
        for submit_btn, prompt_handler, elem_inputs, constraint_box in submit_chains:
            submit_btn.click(
                prompt_handler,
                [state, *elem_inputs, imagebox, constraint_box, image_process_mode],
                [state, chatbot] + btn_list,
                queue=False
            ).then(
                http_bot,
                [state, model_selector, temperature, top_p, max_output_tokens, repeat_times],
                [state, chatbot, imagebox_boxmap] + btn_list
            )

        if args.model_list_mode == "once":
            # NOTE(review): `_js` is the pre-4.0 Gradio keyword (newer releases
            # call it `js`) - confirm against the pinned Gradio version.
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                _js=get_window_url_params,
                queue=False
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
if __name__ == "__main__":
    # Command-line entry point: parse options, fetch the model list once,
    # build the UI and start the server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once",
                        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--embed", action="store_true")
    # `args` and `models` are module globals read by get_model_list(),
    # load_demo(), http_bot() and build_demo().
    args = parser.parse_args()
    logger.info(f"args: {args}")

    models = get_model_list()

    logger.info(args)
    # build_demo() takes the script directory as its second argument; the
    # call previously omitted it and failed with a TypeError at startup.
    demo = build_demo(args.embed, os.path.dirname(os.path.abspath(__file__)))
    # NOTE(review): `concurrency_count` is a pre-4.0 Gradio queue() option -
    # confirm against the pinned Gradio version.
    demo.queue(
        concurrency_count=args.concurrency_count,
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )