# POET — Gradio demo app (app.py)
import gradio as gr
import numpy as np
import random
import spaces
import torch
import re
import transformers
import open_clip
# from Pilot-Phase3.optim_utils import optimize_prompt
# from Pi
from optim_utils import optimize_prompt
from utils import (
clean_response_gpt, setup_model, init_gpt_api, call_gpt_api,
get_refine_msg, clean_cache, get_personalize_message,
clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS,
INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS
)
# =========================
# Constants / Defaults
# =========================
CLIP_MODEL = "ViT-H-14"                # open_clip architecture used for prompt inversion
PRETRAINED_CLIP = "laion2b_s32b_b79k"  # pretrained weight tag matching CLIP_MODEL
default_t2i_model = "black-forest-labs/FLUX.1-dev"
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
MAX_SEED = np.iinfo(np.int32).max      # upper bound for randomized generator seeds
MAX_IMAGE_SIZE = 1024
NUM_IMAGES = 4                         # images generated per round
MAX_ROUND = 5                          # maximum redesign rounds before submission is forced on
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# One-time startup: clear stale cache, then load the text-to-image pipeline and
# the CLIP model/preprocess used for prompt inversion.
clean_cache()
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
llm_pipe = None        # placeholder; not used anywhere in this file — TODO confirm intent
inverted_prompt = ""   # last prompt recovered by invert_prompt (module-global)
torch.cuda.empty_cache()

# Mutable session state shared by the UI callbacks below.
# NOTE(review): module-level globals imply a single concurrent user — confirm
# this Space is not expected to serve simultaneous sessions.
METHOD = "Experimental"
counter = 1            # current round number (1-based)
enable_submit = False  # flips True once the user reports satisfaction
responses_memory = {METHOD: {}}  # per-round feedback, keyed by round counter

# Showcase rows for gr.Examples: [prompt, list of four example images] each.
example_data = [
    [
        PROMPTS["Tourist promotion"],
        IMAGES["Tourist promotion"]["ours"]
    ],
    [
        PROMPTS["Fictional character generation"],
        IMAGES["Fictional character generation"]["ours"]
    ],
    [
        PROMPTS["Interior Design"],
        IMAGES["Interior Design"]["ours"]
    ],
]
# =========================
# Image Generation Helpers
# =========================
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image from `prompt` with the module-level pipeline.

    When `randomize_seed` is True (the default) the given `seed` is replaced
    by a fresh random one, so repeated calls produce different images.
    Returns the first PIL image produced by `selected_pipe`.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    rng = torch.Generator().manual_seed(seed)
    with torch.no_grad():
        result = selected_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=rng,
        )
    return result.images[0]
def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask GPT-4o for `num_prompts` refined variants of `prompt`.

    Uses a fresh random seed per call and returns the cleaned list of
    refined prompt strings.
    """
    api_seed = random.randint(0, MAX_SEED)
    api_client = init_gpt_api()
    request_messages = get_refine_msg(prompt, num_prompts)
    raw_output = call_gpt_api(
        request_messages, api_client, "gpt-4o",
        api_seed, max_tokens, temperature, top_p,
    )
    return clean_response_gpt(raw_output)
def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Ask GPT-4o to rewrite `prompt` using past prompts, feedback, and the
    user's liked/disliked images. Returns the raw model response text.

    NOTE(review): currently not invoked anywhere in this file (the call in
    generate_image is commented out).
    """
    api_seed = random.randint(0, MAX_SEED)
    gpt_client = init_gpt_api()
    request_messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
    return call_gpt_api(
        request_messages, gpt_client, "gpt-4o", api_seed,
        max_tokens=2000, temperature=0.7, top_p=0.9,
    )
@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=500, lr=0.1, batch_size=2):
    """Recover a textual prompt matching `images`/`prompt` via CLIP inversion.

    The optimized prompt is stored in the module-global `inverted_prompt`
    (nothing is returned). `iter` shadows the builtin but is kept for
    keyword-argument compatibility with existing callers.
    """
    global inverted_prompt
    optimization_config = dict(
        iter=iter,
        lr=lr,
        batch_size=batch_size,
        prompt_len=prompt_len,
        weight_decay=0.1,
        prompt_bs=1,
        loss_weight=1.0,
        print_step=100,  # log progress every 100 optimization steps
        clip_model=CLIP_MODEL,
        clip_pretrain=PRETRAINED_CLIP,
    )
    inverted_prompt = optimize_prompt(
        clip_model, preprocess, optimization_config, device,
        target_images=images, target_prompts=prompt,
    )
# =========================
# UI Helper Functions
# =========================
def reset_gallery():
    """Return a fresh empty list, clearing a gallery component's contents."""
    return list()
def display_error_message(msg, duration=5):
    """Show a transient Gradio warning toast with `msg` for `duration` seconds."""
    gr.Warning(msg, duration=duration)
def display_info_message(msg, duration=5):
    """Show a transient Gradio info toast with `msg` for `duration` seconds."""
    gr.Info(msg, duration=duration)
def check_satisfaction(sim_radio):
    """Enable the submit button once the user is satisfied, has previously
    been satisfied (`enable_submit`), or the round limit is exceeded."""
    global enable_submit, counter
    allow_submit = (
        enable_submit
        or counter > MAX_ROUND
        or sim_radio in ("Satisfied", "Very Satisfied")
    )
    return gr.update(interactive=allow_submit)
def select_image(like_radio, images_method):
    """Map the selected radio option to the corresponding gallery image.

    Returns the first element of the matched gallery entry (presumably the
    image filepath of a (image, caption) pair — TODO confirm against the
    Gallery component's value format), or None when nothing matches.
    """
    for slot, option in enumerate(IMAGE_OPTIONS[:4]):
        if like_radio == option:
            return images_method[slot][0]
    return None
def check_evaluation(sim_radio):
    """Return True when the satisfaction rating is filled in; otherwise show
    a warning toast and return False."""
    if sim_radio:
        return True
    display_error_message("❌ Please fill all evaluations before changing image or submitting.")
    return False
def generate_image(prompt, like_image, dislike_image):
    """Yield a growing gallery of NUM_IMAGES images for `prompt`.

    Each refined prompt is rendered one at a time and the partial gallery is
    yielded after every image so the UI updates progressively. GPT-based
    personalization is currently disabled: the prompt is refined as-is and
    the collected history/feedback are unused.
    """
    global responses_memory
    past_rounds = responses_memory[METHOD].values()
    history_prompts = [entry["prompt"] for entry in past_rounds]
    feedback = [entry["sim_radio"] for entry in past_rounds]
    personalized = prompt
    refined_prompts = call_gpt_refine_prompt(personalized)
    gallery = []
    for idx in range(NUM_IMAGES):
        gallery.append(infer(refined_prompts[idx]))
        yield gallery
def redesign(prompt, sim_radio, like_radio, dislike_radio, current_images, history_images, like_image, dislike_image):
    """Record this round's feedback, archive the current images into the
    history gallery, advance the round counter, and refresh button states.

    Returns updates for (sim_radio, dislike_radio, like_radio, images_method,
    history_images, example.dataset, prompt, next_btn, redesign_btn,
    submit_btn); when the evaluation is missing, skips the update instead.
    """
    global counter, enable_submit, responses_memory
    if not check_evaluation(sim_radio):
        return {submit_btn: gr.skip()}

    responses_memory[METHOD][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, {like_radio}",
        "unsatisfied_img": f"round {counter}, {dislike_radio}",
    }
    # Once satisfied, the submit button stays enabled for the session.
    enable_submit = sim_radio in ["Satisfied", "Very Satisfied"] or enable_submit

    round_prompts = [[entry["prompt"]] for entry in responses_memory[METHOD].values()]

    # Move the freshly generated images into the history gallery.
    if not history_images:
        history_images = current_images
    elif current_images:
        history_images.extend(current_images)
    current_images = []

    limit_reached = counter >= MAX_ROUND
    counter += 1
    return (
        None,                                              # clear sim_radio
        None,                                              # clear dislike_radio
        None,                                              # clear like_radio
        current_images,                                    # empty the main gallery
        history_images,
        gr.update(samples=round_prompts, visible=True),    # prompt history examples
        gr.update(interactive=True),                       # prompt box editable again
        gr.update(visible=True, interactive=True),         # next_btn
        gr.update(interactive=not limit_reached),          # redesign_btn
        gr.update(interactive=limit_reached or enable_submit),  # submit_btn
    )
def save_response(prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image):
    """Store the final round's feedback, reset the session state, and lock
    the interaction controls.

    Returns updates for (sim_radio, dislike_radio, like_radio, prompt,
    next_btn, redesign_btn, submit_btn); when the evaluation is missing,
    skips the update instead.
    """
    global counter, enable_submit, responses_memory
    if not check_evaluation(sim_radio):
        return {submit_btn: gr.skip()}

    # Persist the final round entry.
    responses_memory[METHOD][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, {like_radio}",
        "unsatisfied_img": f"round {counter}, {dislike_radio}",
    }
    # Reset session state for the next run.
    counter = 1
    enable_submit = False
    display_info_message("✅ Your answer is saved!")
    return (
        None,                                          # clear sim_radio
        None,                                          # clear dislike_radio
        None,                                          # clear like_radio
        gr.update(interactive=False),                  # lock prompt box
        gr.update(visible=False, interactive=False),   # hide next_btn
        gr.update(interactive=False),                  # lock redesign_btn
        gr.update(interactive=False),                  # lock submit_btn
    )
# =========================
# Interface (single tab, no participant/scenario/background)
# =========================
# Page CSS injected into gr.Blocks below: centered column containers of
# varying widths, compact rows, and styling for the header/abstract/author
# landing section.
css = """
#col-container {
margin: 0 auto;
max-width: 700px;
}
#col-container2 {
margin: 0 auto;
max-width: 1000px;
}
#col-container3 {
margin: 0 0 auto auto;
max-width: 300px;
}
#button-container {
display: flex;
justify-content: center;
}
#compact-compact-row {
width:100%;
max-width: 800px;
margin: 0px auto;
}
#compact-row {
width:100%;
max-width: 1000px;
margin: 0px auto;
}
.header-section {
text-align: center;
margin-bottom: 2rem;
}
.abstract-text {
text-align: justify;
line-height: 1.6;
margin: 0.5rem 0;
padding: 0.5rem;
background-color: rgba(0, 0, 0, 0.05);
border-radius: 8px;
border-left: 4px solid #3498db;
}
.paper-link {
display: inline-block;
margin: 0rem 0;
padding: 0rem 0rem;
background-color: #3498db;
color: white;
text-decoration: none;
border-radius: 5px;
font-weight: 500;
}
.paper-link:hover {
background-color: #2980b9;
text-decoration: none;
}
.authors-section {
text-align: center;
margin: 0 0;
font-style: italic;
color: #666;
}
.authors-title {
font-weight: bold;
margin-bottom: 0rem;
color: #333;
}
"""
# UI layout + event wiring. Fixes: the second author link was missing '=' in
# its href attribute (href"https://..." -> href="https://...").
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
    # ---------- Header: title, icon, abstract, paper link, authors ----------
    with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
        gr.Markdown("# 📌 **POET**")
        # NOTE(review): relative image paths usually need Gradio's file serving
        # (e.g. /file= route or allowed_paths) — confirm the icon renders.
        gr.HTML('<div><img src="images/icon.png" width="200"></div>')
        gr.Markdown("### Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")
        gr.Markdown("""
            <div class="abstract-text">
            <strong>Abstract:</strong> Given that creative end-users often operate in diverse, context-specific ways that are often unpredictable, more variation and personalization are necessary. We introduce POET, a real-time interactive tool that (1) automatically discovers dimensions of homogeneity in text-to-image generative models, (2) expands these dimensions to diversify the output space of generated images, and (3) learns from user feedback to personalize expansions. Focusing on visual creativity, POET offers a first glimpse of how interaction techniques of future text-to-image generation tools may support and align with more pluralistic values and the needs of end-users during the ideation stages of their work.
            </div>
            """, elem_classes=["abstract-text"])
        # Paper link
        gr.HTML("""
            <div style="text-align: center;">
            <a href="https://arxiv.org/pdf/2504.13392" target="_blank" class="paper-link">
            📄 Read the Full Paper .
            </a>
            </div>
            """)
        # Authors (second href fixed: was `href"..."` with no '=')
        gr.Markdown("""
            <div class="authors-section">
            <a href="https://scholar.google.com/citations?user=HXED4kIAAAAJ&hl=en">Evans Han</a>, <a href="https://www.aliceqian.com/">Alice Qian Zhang</a>,
            <a href="https://haiyizhu.com/">Haiyi Zhu</a>, <a href="https://www.andrew.cmu.edu/user/hongs/">Hong Shen</a>,
            <a href="https://pliang279.github.io/">Paul Pu Liang</a>, <a href="https://janeon.github.io/">Jane Hsieh</a>
            </div>
            """, elem_classes=["authors-section"])

    # ---------- Main interaction tab ----------
    with gr.Tab(""):
        # Prompt entry + generate button.
        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                with gr.Row():
                    prompt = gr.Textbox(
                        label="🎨 Prompt",
                        max_lines=5,
                        placeholder="Enter your prompt",
                        visible=True,
                    )
            with gr.Column(elem_id="col-container3"):
                next_btn = gr.Button("Generate", variant="primary", scale=1)

        # Generated gallery plus hidden holders for the liked/disliked picks.
        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                images_method = gr.Gallery(label="Images", columns=[4], rows=[1], height=400, elem_id="gallery", format="png")
            with gr.Column(elem_id="col-container3"):
                like_image = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
                dislike_image = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)

        # Evaluation radios (container starts hidden).
        with gr.Column(elem_id="col-container2", visible=False):
            gr.Markdown("### 📝 Evaluation")
            sim_radio = gr.Radio(
                OPTIONS,
                label="How would you rate your satisfaction with the generated images?",
                type="value",
                elem_classes=["gradio-radio"]
            )
            like_radio = gr.Radio(
                IMAGE_OPTIONS,
                label="Select your all-time favorite image (optional).",
                type="value",
                elem_classes=["gradio-radio"]
            )
            dislike_radio = gr.Radio(
                IMAGE_OPTIONS,
                label="Select your all-time least satisfactory image (optional).",
                type="value",
                elem_classes=["gradio-radio"]
            )

        # Prompt history + archived images (container starts hidden).
        with gr.Column(elem_id="col-container2", visible=False):
            example = gr.Examples([['']], prompt, label="Revised Prompt History", visible=False)
            history_images = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png")

        # Action buttons. NOTE(review): placed outside the hidden column so
        # they remain clickable — nothing in the wiring ever toggles the
        # hidden containers visible; confirm intended layout.
        with gr.Row(elem_id="button-container"):
            redesign_btn = gr.Button("🎨 Redesign", variant="primary", scale=0)
            submit_btn = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)

        # Showcase examples: prompt + four precomputed images per scenario.
        with gr.Column(elem_id="col-container2"):
            gr.Markdown("### 🌟 Examples")
            ex1 = gr.Image(label="Image 1", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex2 = gr.Image(label="Image 2", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex3 = gr.Image(label="Image 3", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex4 = gr.Image(label="Image 4", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            gr.Examples(
                examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
                inputs=[prompt, ex1, ex2, ex3, ex4]
            )

    # =========================
    # Wiring
    # =========================
    sim_radio.change(fn=check_satisfaction, inputs=[sim_radio], outputs=[submit_btn])
    dislike_radio.select(fn=select_image, inputs=[dislike_radio, images_method], outputs=[dislike_image])
    like_radio.select(fn=select_image, inputs=[like_radio, images_method], outputs=[like_image])
    next_btn.click(
        fn=generate_image,
        inputs=[prompt, like_image, dislike_image],
        outputs=[images_method]
    ).success(lambda: [gr.update(interactive=True), gr.update(interactive=True)], outputs=[next_btn, prompt])
    redesign_btn.click(
        fn=redesign,
        inputs=[prompt, sim_radio, like_radio, dislike_radio, images_method, history_images, like_image, dislike_image],
        outputs=[sim_radio, dislike_radio, like_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn, submit_btn]
    )
    submit_btn.click(
        fn=save_response,
        inputs=[prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image],
        outputs=[sim_radio, dislike_radio, like_radio, prompt, next_btn, redesign_btn, submit_btn]
    )


if __name__ == "__main__":
    demo.launch()