# NOTE: import order is kept as in the original; `spaces` (HF ZeroGPU) is generally
# expected to be imported before CUDA is initialised — confirm before reordering.
import gradio as gr
import numpy as np
import random
import spaces
import torch
import re
import transformers
import open_clip

from optim_utils import optimize_prompt
from utils import (
    clean_response_gpt,
    setup_model,
    init_gpt_api,
    call_gpt_api,
    get_refine_msg,
    clean_cache,
    get_personalize_message,
    clean_refined_prompt_response_gpt,
    IMAGES,
    OPTIONS,
    T2I_MODELS,
    INSTRUCTION,
    IMAGE_OPTIONS,
    PROMPTS,
    SCENARIOS,
)

# =========================
# Constants / Defaults
# =========================
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"
default_t2i_model = "black-forest-labs/FLUX.1-dev"
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
NUM_IMAGES = 4
MAX_ROUND = 5
# Satisfaction ratings that unlock the Submit button.
SATISFIED_OPTIONS = ("Satisfied", "Very Satisfied")

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Heavy model loading happens at import time.
clean_cache()
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device
)
llm_pipe = None
inverted_prompt = ""  # written by invert_prompt()
torch.cuda.empty_cache()

# NOTE(review): the session state below is module-level, so every browser session
# served by this process shares one counter / submit flag / response memory.
METHOD = "Experimental"
counter = 1  # current evaluation round (1-based)
enable_submit = False  # latched True once any round was rated as satisfied
responses_memory = {METHOD: {}}  # METHOD -> {round number -> recorded answers}

example_data = [
    [PROMPTS["Tourist promotion"], IMAGES["Tourist promotion"]["ours"]],
    [PROMPTS["Fictional character generation"], IMAGES["Fictional character generation"]["ours"]],
    [PROMPTS["Interior Design"], IMAGES["Interior Design"]["ours"]],
]


# =========================
# Image Generation Helpers
# =========================
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Render one image for *prompt* with the module-level text-to-image pipeline.

    When ``randomize_seed`` is true (the default) the given ``seed`` is ignored and
    a new one is drawn from ``[0, MAX_SEED]``. ``progress`` lets Gradio track the
    pipeline's tqdm bars. Returns the first image of the pipeline output.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    with torch.no_grad():
        image = selected_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
    return image


def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask GPT-4o for ``num_prompts`` refined variants of *prompt*.

    Returns whatever ``clean_response_gpt`` parses from the reply (presumably a
    list of prompt strings; its length is not guaranteed to equal ``num_prompts``).
    """
    seed = random.randint(0, MAX_SEED)
    client = init_gpt_api()
    messages = get_refine_msg(prompt, num_prompts)
    outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens, temperature, top_p)
    prompt_list = clean_response_gpt(outputs)
    return prompt_list


def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Ask GPT-4o to rewrite *prompt* using past prompts, ratings and liked/disliked images.

    Returns the raw API output from ``call_gpt_api``.
    """
    seed = random.randint(0, MAX_SEED)
    client = init_gpt_api()
    messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
    outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
    return outputs


@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=500, lr=0.1, batch_size=2):
    """Optimise a text prompt toward *images* with CLIP via ``optimize_prompt``.

    The result is stored in the module-level ``inverted_prompt``; nothing is
    returned. (``iter`` shadows the builtin but is kept for caller compatibility.)
    """
    global inverted_prompt
    text_params = {
        "iter": iter,
        "lr": lr,
        "batch_size": batch_size,
        "prompt_len": prompt_len,
        "weight_decay": 0.1,
        "prompt_bs": 1,
        "loss_weight": 1.0,
        "print_step": 100,
        "clip_model": CLIP_MODEL,
        "clip_pretrain": PRETRAINED_CLIP,
    }
    inverted_prompt = optimize_prompt(
        clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt
    )


# =========================
# UI Helper Functions
# =========================
def reset_gallery():
    """Return an empty gallery value."""
    return []


def display_error_message(msg, duration=5):
    """Show *msg* as a Gradio warning toast for ``duration`` seconds."""
    gr.Warning(msg, duration=duration)


def display_info_message(msg, duration=5):
    """Show *msg* as a Gradio info toast for ``duration`` seconds."""
    gr.Info(msg, duration=duration)


def check_satisfaction(sim_radio):
    """Return a ``gr.update`` enabling Submit when the user may finish.

    Submit is enabled if the current rating is satisfied, a previous round was
    satisfied (``enable_submit``), or the round counter has passed ``MAX_ROUND``.
    """
    if_submit = (sim_radio in SATISFIED_OPTIONS) or enable_submit or (counter > MAX_ROUND)
    return gr.update(interactive=if_submit)


def select_image(like_radio, images_method):
    """Return the gallery image matching the chosen radio option, or None.

    Only the first ``NUM_IMAGES`` entries of ``IMAGE_OPTIONS`` map to images;
    option *k* selects ``images_method[k][0]`` (gallery entries presumably are
    ``(image, caption)`` pairs — verify against the Gradio version in use).
    Returns None when nothing matches, or when the gallery is empty or has no
    image at that position (e.g. an option is picked before generating).
    """
    choices = list(IMAGE_OPTIONS[:NUM_IMAGES])
    if not images_method or like_radio not in choices:
        return None
    idx = choices.index(like_radio)
    if idx >= len(images_method):
        return None
    return images_method[idx][0]


def check_evaluation(sim_radio):
    """Return True if a satisfaction rating was given; otherwise warn and return False."""
    if not sim_radio:
        display_error_message("❌ Please fill all evaluations before changing image or submitting.")
        return False
    return True


def generate_image(prompt, like_image, dislike_image):
    """Generate ``NUM_IMAGES`` images for *prompt*, yielding the growing gallery.

    The prompt is expanded into refined variants by ``call_gpt_refine_prompt`` and
    each variant is rendered with ``infer``. If the LLM returns fewer variants than
    ``NUM_IMAGES``, the remaining slots reuse the (unrefined) prompt instead of
    failing. ``like_image`` / ``dislike_image`` are only needed by the disabled
    personalization step below.
    """
    personalized = prompt
    # Personalization from earlier rounds is currently disabled:
    # history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
    # feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
    # personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
    # personalized = clean_refined_prompt_response_gpt(personalized)
    # if "I'm sorry, I can't assist with" in personalized:
    #     personalized = prompt

    refined_prompts = list(call_gpt_refine_prompt(personalized) or [])
    # Pad so the loop never indexes past the end of a short LLM reply.
    if len(refined_prompts) < NUM_IMAGES:
        refined_prompts.extend([personalized] * (NUM_IMAGES - len(refined_prompts)))

    gallery_images = []
    for refined in refined_prompts[:NUM_IMAGES]:
        gallery_images.append(infer(refined))
        # Yield after each image so the gallery fills in progressively.
        yield gallery_images


def _record_round(prompt, sim_radio, like_radio, dislike_radio):
    """Store the current round's prompt and ratings in ``responses_memory[METHOD][counter]``."""
    responses_memory[METHOD][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, {like_radio}",
        "unsatisfied_img": f"round {counter}, {dislike_radio}",
    }


def redesign(prompt, sim_radio, like_radio, dislike_radio, current_images, history_images, like_image, dislike_image):
    """Record the current round and prepare the UI for the next one.

    Returns values for (sim_radio, dislike_radio, like_radio, images_method,
    history_images, example.dataset, prompt, next_btn, redesign_btn, submit_btn):
    the three radios are cleared and the current images are moved into the history
    gallery. Redesign is disabled, and Submit enabled, once ``MAX_ROUND`` is reached.
    Submit is also enabled once any round was rated satisfied. If no rating was
    given, only Submit is returned, with ``gr.skip()``, so nothing changes.
    """
    global counter, enable_submit
    if not check_evaluation(sim_radio):
        return {submit_btn: gr.skip()}

    _record_round(prompt, sim_radio, like_radio, dislike_radio)
    # Once satisfied in any round, Submit stays available for the rest of the session.
    enable_submit = enable_submit or sim_radio in SATISFIED_OPTIONS
    history_prompts = [[entry["prompt"]] for entry in responses_memory[METHOD].values()]

    if not history_images:
        history_images = current_images
    elif current_images:
        history_images.extend(current_images)
    current_images = []

    last_round = counter >= MAX_ROUND
    examples_state = gr.update(samples=history_prompts, visible=True)
    prompt_state = gr.update(interactive=True)
    next_state = gr.update(visible=True, interactive=True)
    redesign_state = gr.update(interactive=not last_round)
    submit_state = gr.update(interactive=last_round or enable_submit)
    counter += 1
    return (
        None,
        None,
        None,
        current_images,
        history_images,
        examples_state,
        prompt_state,
        next_state,
        redesign_state,
        submit_state,
    )


def save_response(prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image):
    """Record the final round, reset the round state and lock the controls.

    Returns values for (sim_radio, dislike_radio, like_radio, prompt, next_btn,
    redesign_btn, submit_btn). If no rating was given, only Submit is returned,
    with ``gr.skip()``.
    """
    global counter, enable_submit
    if not check_evaluation(sim_radio):
        return {submit_btn: gr.skip()}

    _record_round(prompt, sim_radio, like_radio, dislike_radio)

    # NOTE(review): responses_memory is not cleared here, so rounds above 1 from a
    # previous submission survive into the next one's history — confirm intent.
    counter = 1
    enable_submit = False

    prompt_state = gr.update(interactive=False)
    next_state = gr.update(visible=False, interactive=False)
    submit_state = gr.update(interactive=False)
    redesign_state = gr.update(interactive=False)
    display_info_message("✅ Your answer is saved!")
    return None, None, None, prompt_state, next_state, redesign_state, submit_state


# =========================
# Interface (single tab, no participant/scenario/background)
# =========================
css = """
#col-container { margin: 0 auto; max-width: 700px; }
#col-container2 { margin: 0 auto; max-width: 1000px; }
#col-container3 { margin: 0 0 auto auto; max-width: 300px; }
#button-container { display: flex; justify-content: center; }
#compact-compact-row { width:100%; max-width: 800px; margin: 0px auto; }
#compact-row { width:100%; max-width: 1000px; margin: 0px auto; }
.header-section { text-align: center; margin-bottom: 2rem; }
.abstract-text { text-align: justify; line-height: 1.6; margin: 0.5rem 0; padding: 0.5rem; background-color: rgba(0, 0, 0, 0.05); border-radius: 8px; border-left: 4px solid #3498db; }
.paper-link { display: inline-block; margin: 0rem 0; padding: 0rem 0rem; background-color: #3498db; color: white; text-decoration: none; border-radius: 5px; font-weight: 500; }
.paper-link:hover { background-color: #2980b9; text-decoration: none; }
.authors-section { text-align: center; margin: 0 0; font-style: italic; color: #666; }
.authors-title { font-weight: bold; margin-bottom: 0rem; color: #333; }
"""

with gr.Blocks(
    theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]),
    css=css,
) as demo:
    with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
        gr.Markdown("# 📌 **POET**")
        # NOTE(review): this element's HTML payload is missing from this copy of the
        # file (only an empty, line-broken literal survived) — restore it if known.
        gr.HTML("")
        gr.Markdown("### Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")
        # Abstract: State-of-the-art visual generative AI tools hold immense potential to assist
        # users in the early ideation stages of creative tasks — offering the ability to generate
        # (rather than search for) novel and unprecedented (instead of existing) images of
        # considerable quality that also adhere to boundless combinations of user specifications.
        # However, many large-scale text-to-image systems are designed for broad applicability,
        # yielding conventional output that may limit creative exploration. They also employ
        # interaction methods that may be difficult for beginners.
        gr.Markdown("""
Abstract: Given that creative end-users often operate in diverse, context-specific ways that are often unpredictable, more variation and personalization are necessary. We introduce POET, a real-time interactive tool that (1) automatically discovers dimensions of homogeneity in text-to-image generative models, (2) expands these dimensions to diversify the output space of generated images, and (3) learns from user feedback to personalize expansions. Focusing on visual creativity, POET offers a first glimpse of how interaction techniques of future text-to-image generation tools may support and align with more pluralistic values and the needs of end-users during the ideation stages of their work.
""", elem_classes=["abstract-text"])
        # Paper Link
        # NOTE(review): the link markup/target appears to be missing; only the label survived.
        gr.HTML("""
📄 Read the Full Paper .
""")
        # Authors
        gr.Markdown("""
Evans Han, Alice Qian Zhang, Haiyi Zhu, Hong Shen, Paul Pu Liang, Jane Hsieh
""", elem_classes=["authors-section"])

    # gr.Markdown("---")
    with gr.Tab(""):
        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                with gr.Row():
                    prompt = gr.Textbox(
                        label="🎨 Prompt",
                        max_lines=5,
                        placeholder="Enter your prompt",
                        visible=True,
                    )
            with gr.Column(elem_id="col-container3"):
                next_btn = gr.Button("Generate", variant="primary", scale=1)

        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                images_method = gr.Gallery(
                    label="Images", columns=[4], rows=[1], height=400, elem_id="gallery", format="png"
                )
            with gr.Column(elem_id="col-container3"):
                like_image = gr.Image(
                    label="Satisfied Image", width=200, height=200, sources='upload',
                    format="png", type="filepath", visible=False,
                )
                dislike_image = gr.Image(
                    label="Unsatisfied Image", width=200, height=200, sources='upload',
                    format="png", type="filepath", visible=False,
                )

        # The evaluation panel is hidden in this single-tab demo.
        with gr.Column(elem_id="col-container2", visible=False):
            gr.Markdown("### 📝 Evaluation")
            sim_radio = gr.Radio(
                OPTIONS,
                label="How would you rate your satisfaction with the generated images?",
                type="value",
                elem_classes=["gradio-radio"],
            )
            like_radio = gr.Radio(
                IMAGE_OPTIONS,
                label="Select your all-time favorite image (optional).",
                type="value",
                elem_classes=["gradio-radio"],
            )
            dislike_radio = gr.Radio(
                IMAGE_OPTIONS,
                label="Select your all-time least satisfactory image (optional).",
                type="value",
                elem_classes=["gradio-radio"],
            )

        with gr.Column(elem_id="col-container2", visible=False):
            example = gr.Examples([['']], prompt, label="Revised Prompt History", visible=False)
            history_images = gr.Gallery(
                label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png"
            )
            # NOTE(review): nesting reconstructed from a copy without indentation; the
            # buttons are placed with the (hidden) evaluation controls they depend on.
            with gr.Row(elem_id="button-container"):
                redesign_btn = gr.Button("🎨 Redesign", variant="primary", scale=0)
                submit_btn = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)

        with gr.Column(elem_id="col-container2"):
            gr.Markdown("### 🌟 Examples")
            # Hidden image slots that receive the example images when an example is picked.
            ex1 = gr.Image(label="Image 1", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex2 = gr.Image(label="Image 2", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex3 = gr.Image(label="Image 3", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex4 = gr.Image(label="Image 4", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            gr.Examples(
                examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
                inputs=[prompt, ex1, ex2, ex3, ex4],
            )

    # =========================
    # Wiring
    # =========================
    sim_radio.change(fn=check_satisfaction, inputs=[sim_radio], outputs=[submit_btn])
    dislike_radio.select(fn=select_image, inputs=[dislike_radio, images_method], outputs=[dislike_image])
    like_radio.select(fn=select_image, inputs=[like_radio, images_method], outputs=[like_image])

    next_btn.click(
        fn=generate_image,
        inputs=[prompt, like_image, dislike_image],
        outputs=[images_method],
    ).success(
        lambda: [gr.update(interactive=True), gr.update(interactive=True)],
        outputs=[next_btn, prompt],
    )

    redesign_btn.click(
        fn=redesign,
        inputs=[prompt, sim_radio, like_radio, dislike_radio, images_method, history_images, like_image, dislike_image],
        outputs=[sim_radio, dislike_radio, like_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn, submit_btn],
    )

    submit_btn.click(
        fn=save_response,
        inputs=[prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image],
        outputs=[sim_radio, dislike_radio, like_radio, prompt, next_btn, redesign_btn, submit_btn],
    )


if __name__ == "__main__":
    demo.launch()