|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import random |
|
|
import spaces |
|
|
import torch |
|
|
import re |
|
|
import transformers |
|
|
import open_clip |
|
|
|
|
|
from optim_utils import optimize_prompt |
|
|
from utils import ( |
|
|
clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, |
|
|
get_refine_msg, clean_cache, get_personalize_message, get_personalized_simplified, |
|
|
clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS, |
|
|
INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# CLIP backbone + pretrained weights used by invert_prompt / optimize_prompt.
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"

# Hub identifiers of the default models. default_llm_model is not referenced
# elsewhere in this file.
default_t2i_model = "black-forest-labs/FLUX.1-dev"
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

MAX_SEED = np.iinfo(np.int32).max  # upper bound for random seeds (2**31 - 1)
MAX_IMAGE_SIZE = 1024  # not referenced elsewhere in this file
NUM_IMAGES = 4  # images produced per "Generate" click
MAX_ROUND = 5  # redesign rounds before the save button is disabled

# Half precision on GPU, full precision on CPU.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
|
|
# Housekeeping before the large model loads. clean_cache comes from utils;
# presumably it frees cached accelerator memory — confirm in utils.
clean_cache()

# Placeholders: llm_pipe is never loaded in this file; inverted_prompt is
# overwritten by invert_prompt().
llm_pipe = None
inverted_prompt = ""

# Text-to-image pipeline shared by every call to infer().
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)

# open_clip returns (model, train transform, eval transform); only the model
# and the eval-time transform are kept for prompt inversion.
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    CLIP_MODEL,
    pretrained=PRETRAINED_CLIP,
    device=device,
)

# Release memory held by PyTorch's CUDA caching allocator after loading.
torch.cuda.empty_cache()
|
|
|
|
|
# Mutable app state read/written by the Gradio callbacks via `global`.
# NOTE(review): these are module globals, so every browser session connected to
# this app shares them; concurrent users will overwrite each other's rounds.
# Consider moving them into gr.State.
METHOD = "Experimental"  # key under which rounds are stored in responses_memory
counter = 1  # 1-based number of the redesign round being recorded
enable_submit = False  # not read anywhere in this file
redesign_flag = False  # set by redesign/save_response, never read in this file
responses_memory = {METHOD: {}}  # {METHOD: {round_number: round_record_dict}}

# [prompt, images] pairs for the static "Examples" panel; the "ours" entry of
# each scenario is expected to hold (at least) four images.
example_data = [
    [PROMPTS[scenario], IMAGES[scenario]["ours"]]
    for scenario in (
        "Tourist promotion",
        "Fictional character generation",
        "Interior Design",
    )
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image from *prompt* with the shared pipeline.

    Runs under the Hugging Face ``spaces.GPU`` decorator (duration=65).

    Args:
        prompt: Text prompt forwarded to ``selected_pipe``.
        negative_prompt: Forwarded unchanged to the pipeline.
        seed: Generator seed; ignored when *randomize_seed* is true.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Output size in pixels, forwarded to the pipeline.
        guidance_scale, num_inference_steps: Forwarded to the pipeline.
        progress: Gradio progress tracker (``track_tqdm=True``); not
            referenced in the body.

    Returns:
        ``.images[0]`` of the pipeline output. The seed actually used is not
        returned.
    """
    effective_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    # Generator created without a device argument, i.e. the default CPU generator.
    rng = torch.Generator().manual_seed(effective_seed)

    pipe_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=rng,
    )
    with torch.no_grad():
        output = selected_pipe(**pipe_kwargs)
    return output.images[0]
|
|
|
|
|
def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask GPT-4o for refined variants of *prompt* and return the cleaned list.

    Args:
        prompt: Base prompt to expand.
        num_prompts: Number of variants requested via ``get_refine_msg``.
        max_tokens, temperature, top_p: Sampling settings for ``call_gpt_api``.

    Returns:
        Whatever ``clean_response_gpt`` extracts from the raw reply; callers
        index it like a list of prompt strings.
    """
    # Order preserved: seed draw, client creation, message construction.
    request_seed = random.randint(0, MAX_SEED)
    gpt_client = init_gpt_api()
    refine_messages = get_refine_msg(prompt, num_prompts)

    raw_reply = call_gpt_api(
        refine_messages,
        gpt_client,
        "gpt-4o",
        request_seed,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return clean_response_gpt(raw_reply)
|
|
|
|
|
def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Ask GPT-4o to rewrite *prompt* using the liked/disliked images.

    Args:
        prompt: Current user prompt.
        history: Previous prompts. Accepted for interface compatibility but
            not used.
        feedback: Previous satisfaction ratings. Accepted but not used.
        like_image, dislike_image: Image references passed to
            ``get_personalized_simplified``.

    Returns:
        The raw ``call_gpt_api`` result, without any cleaning step.
    """
    gpt_seed = random.randint(0, MAX_SEED)
    gpt_client = init_gpt_api()
    return call_gpt_api(
        get_personalized_simplified(prompt, like_image, dislike_image),
        gpt_client,
        "gpt-4o",
        gpt_seed,
        max_tokens=2000,
        temperature=0.7,
        top_p=0.9,
    )
|
|
|
|
|
@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=500, lr=0.1, batch_size=2):
    """Optimize a text prompt toward *images* with CLIP and store it globally.

    The result of ``optimize_prompt`` is assigned to the module-level
    ``inverted_prompt``; the function itself returns None. It is not called
    anywhere in this file.

    Args:
        prompt: Passed to ``optimize_prompt`` as ``target_prompts``.
        images: Passed as ``target_images``.
        prompt_len, iter, lr, batch_size: Optimizer settings. (``iter`` shadows
            the builtin but is part of the public keyword interface.)
    """
    global inverted_prompt

    optim_config = dict(
        iter=iter,
        lr=lr,
        batch_size=batch_size,
        prompt_len=prompt_len,
        weight_decay=0.1,
        prompt_bs=1,
        loss_weight=1.0,
        print_step=100,
        clip_model=CLIP_MODEL,
        clip_pretrain=PRETRAINED_CLIP,
    )
    inverted_prompt = optimize_prompt(
        clip_model,
        preprocess,
        optim_config,
        device,
        target_images=images,
        target_prompts=prompt,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Images from the most recent generate_image() run; read by on_gallery_select.
current_generated_images = list()


def reset_gallery():
    """Return a fresh empty list, suitable for clearing a Gallery component."""
    return list()
|
|
|
|
|
def display_error_message(msg, duration=5):
    """Show *msg* as a Gradio warning toast for *duration* seconds."""
    gr.Warning(message=msg, duration=duration)
|
|
|
|
|
def display_info_message(msg, duration=5):
    """Show *msg* as a Gradio info toast for *duration* seconds."""
    gr.Info(message=msg, duration=duration)
|
|
|
|
|
def check_evaluation(sim_radio, like_image, dislike_image):
    """Return True when the rating and both image choices are all filled in.

    If any of the three is falsy, a warning toast is shown via
    ``display_error_message`` and False is returned.
    """
    if all((sim_radio, like_image, dislike_image)):
        return True
    display_error_message("β Please fill all evaluations before changing image or submitting.")
    return False
|
|
|
|
|
def generate_image(prompt, like_image, dislike_image):
    """Generate NUM_IMAGES images for *prompt*, yielding the gallery as it fills.

    If a liked image, a disliked image and at least one saved rating all exist,
    the prompt is first personalized through GPT; otherwise it is used as is.
    The (possibly personalized) prompt is then expanded into refined variants,
    and one image is rendered per variant.

    Args:
        prompt: User prompt from the textbox.
        like_image, dislike_image: Current values of the like/dislike slots.

    Yields:
        The growing list of generated images after each render. The same
        images are mirrored into the module-level ``current_generated_images``
        so gallery clicks can be resolved.
    """
    global responses_memory, current_generated_images

    saved_rounds = list(responses_memory[METHOD].values())
    history_prompts = [v["prompt"] for v in saved_rounds]
    feedback = [v["sim_radio"] for v in saved_rounds]
    print(feedback, like_image, dislike_image)  # debug trace

    if like_image and dislike_image and feedback:
        personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
    else:
        personalized = prompt

    gallery_images = []
    current_generated_images = []

    # The GPT reply is parsed heuristically and may yield fewer than
    # NUM_IMAGES prompts (or none). Previously that raised IndexError halfway
    # through the stream. Pad with the base prompt instead; infer() draws a
    # fresh random seed per call, so repeated prompts still vary.
    refined_prompts = list(call_gpt_refine_prompt(personalized) or [])[:NUM_IMAGES]
    refined_prompts += [personalized] * (NUM_IMAGES - len(refined_prompts))

    for refined in refined_prompts:
        img = infer(refined)
        gallery_images.append(img)
        current_generated_images.append(img)
        yield gallery_images
|
|
|
|
|
def on_gallery_select(evt: gr.SelectData):
    """Return the generated image the user clicked, or None if it is unavailable.

    ``evt.index`` is the clicked position in the gallery. It is looked up in
    the module-level ``current_generated_images`` (only read here, so no
    ``global`` statement is needed).
    """
    images = current_generated_images
    if not images:
        return None
    idx = evt.index
    return images[idx] if idx < len(images) else None
|
|
|
|
|
def handle_like_drag(selected_image):
    """Pass the selected gallery image through so it lands in the "liked" slot."""
    liked = selected_image
    return liked
|
|
|
|
|
def handle_dislike_drag(selected_image):
    """Pass the selected gallery image through so it lands in the "disliked" slot."""
    disliked = selected_image
    return disliked
|
|
|
|
|
def redesign(prompt, sim_radio, current_images, history_images, like_image, dislike_image):
    """Record the current round's feedback and archive its images.

    When ``check_evaluation`` passes, the round is stored under
    ``responses_memory[METHOD][counter]``. The current images are moved into
    the history gallery and ``counter`` advances by one.

    Args:
        prompt: Prompt used for this round.
        sim_radio: Selected satisfaction option.
        current_images: Value of the generated-images gallery (list or None).
        history_images: Value of the history gallery (list or None); extended
            in place when it is already non-empty.
        like_image, dislike_image: Only checked for presence; the stored
            record holds placeholder strings rather than the images.

    Returns:
        A 7-tuple for (rating radio, current gallery, history gallery,
        prompt-history dataset, prompt box, Generate button, this button).
        If the evaluation is incomplete, every element is ``gr.skip()``.
    """
    global counter, responses_memory, redesign_flag

    if not check_evaluation(sim_radio, like_image, dislike_image):
        # check_evaluation already warned the user; leave all outputs untouched.
        return tuple(gr.skip() for _ in range(7))

    responses_memory[METHOD][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, liked image",
        "unsatisfied_img": f"round {counter}, disliked image",
    }
    # Examples datasets take one list per sample row.
    history_prompts = [[v["prompt"]] for v in responses_memory[METHOD].values()]

    if not history_images:
        history_images = current_images.copy() if current_images else []
    elif current_images:
        history_images.extend(current_images)

    examples_state = gr.update(samples=history_prompts, visible=True)
    prompt_state = gr.update(interactive=True)
    next_state = gr.update(visible=True, interactive=True)
    # Disable further saves once round MAX_ROUND has been recorded.
    redesign_state = gr.update(interactive=counter < MAX_ROUND)

    counter += 1
    redesign_flag = True

    # The emoji here was previously mis-encoded, which split this literal across
    # two lines (an unterminated string); restored to a single-line message.
    display_info_message(f"✅ Round {counter - 1} feedback saved! You can continue redesigning or restart.")

    # None clears the rating radio; [] clears the current gallery.
    return None, [], history_images, examples_state, prompt_state, next_state, redesign_state
|
|
|
|
|
def save_response(prompt, sim_radio, like_image, dislike_image):
    """Restart the personalization session (the "Restart Session" handler).

    Despite its name, nothing is saved. The function clears the recorded
    rounds, resets the round counter, flag and generated-image cache, and
    returns updates that blank the UI.

    Args:
        prompt, sim_radio, like_image, dislike_image: Current component values.
            The click handler passes them in, but they are not used.

    Returns:
        A 9-tuple for (rating radio, prompt box, Generate button, redesign
        button, like image, dislike image, current gallery, history gallery,
        prompt-history dataset).
    """
    global counter, responses_memory, redesign_flag, current_generated_images

    responses_memory[METHOD] = {}
    counter = 1
    redesign_flag = False
    current_generated_images = []

    sim_radio_state = gr.update(value=None)
    prompt_state = gr.update(value="", interactive=True)
    next_state = gr.update(visible=True, interactive=True)
    redesign_state = gr.update(interactive=False)
    like_image_state = gr.update(value=None)
    dislike_image_state = gr.update(value=None)
    gallery_state = []
    history_gallery_state = []
    examples_state = gr.update(samples=[['']], visible=True)
    # The restart button itself is not among this handler's outputs, so no
    # update for it is built here.

    # NOTE(review): the leading glyph looks like a mis-decoded emoji; confirm
    # this file is stored as UTF-8.
    display_info_message("π Session restarted! You can begin with a new prompt.")

    return (
        sim_radio_state,
        prompt_state,
        next_state,
        redesign_state,
        like_image_state,
        dislike_image_state,
        gallery_state,
        history_gallery_state,
        examples_state,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...). "#..." selectors target
# elem_id values and ".…" selectors target elem_classes / class attributes
# used in the UI markup of this file.
# NOTE(review): "#compact-compact-row", ".authors-title" and
# ".personalization-header" are not referenced anywhere in this file and may
# be dead rules.
css = """
#col-container {
margin: 0 auto;
max-width: 700px;
}
#col-container2 {
margin: 0 auto;
max-width: 1000px;
}
#col-container3 {
margin: 0 0 auto auto;
max-width: 300px;
}
#button-container {
display: flex;
justify-content: center;
gap: 10px;
}
#compact-compact-row {
width:100%;
max-width: 800px;
margin: 0px auto;
}
#compact-row {
width:100%;
max-width: 1000px;
margin: 0px auto;
}
.header-section {
text-align: center;
margin-bottom: 2rem;
}
.abstract-text {
text-align: justify;
line-height: 1.5;
margin: 0rem 0;
padding: 0 0.5rem;
background-color: rgba(0, 0, 0, 0.05);
border-radius: 8px;
border-left: 4px solid #3498db;
}
.paper-link {
display: inline-block;
margin: 0rem 0;
padding: 0rem 0rem;
background-color: #3498db;
color: white;
text-decoration: none;
border-radius: 5px;
font-weight: 500;
}
.paper-link:hover {
background-color: #2980b9;
text-decoration: none;
}
.authors-section {
text-align: center;
margin: 0 0;
font-style: italic;
color: #666;
}
.authors-title {
font-weight: bold;
margin-bottom: 0rem;
color: #333;
}
.logo-container {
text-align: center;
margin: 0.5rem 0 1rem 0;
}
.logo-container img {
height: 60px;
width: auto;
max-width: 150px;
display: inline-block;
}
.instruction-box {
background: linear-gradient(135deg, #e8f4fd 0%, #f0f8ff 100%);
border: 2px solid #3498db;
border-radius: 12px;
padding: 20px;
margin: 15px 0;
color: #2c3e50;
}
.instruction-title {
font-size: 1.2em;
font-weight: bold;
margin-bottom: 15px;
color: #2c3e50;
display: flex;
align-items: center;
gap: 8px;
}
.step-list {
list-style: none;
padding: 0;
margin: 0;
}
.step-item {
background: rgba(52, 152, 219, 0.1);
border-radius: 8px;
padding: 12px 16px;
margin: 8px 0;
border-left: 4px solid #3498db;
}
.step-number {
font-weight: bold;
color: #3498db;
margin-right: 8px;
}
.personalization-header {
background: linear-gradient(135deg, #ff6b6b, #ee5a24);
color: white;
padding: 15px;
border-radius: 10px 10px 0 0;
margin: -10px -10px 15px -10px;
text-align: center;
font-weight: bold;
font-size: 1.1em;
}
"""
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
    # Layout and event wiring for the POET demo.
    # NOTE(review): several emoji-looking label prefixes below (e.g. "π¨",
    # "π―", "πΎ") appear mis-decoded in this copy of the file; confirm the
    # source is stored and read as UTF-8.

    # Gallery image most recently clicked by the user (set by on_gallery_select).
    selected_image = gr.State(None)

    # ------------------------------------------------------------------
    # Header: logo, title, paper link, abstract and authors.
    # ------------------------------------------------------------------
    with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
        gr.HTML('<div class="logo-container"><img src="https://huggingface.co/spaces/PAI-GEN/POET/resolve/main/images/icon.png" alt="POET Logo"></div>')
        gr.Markdown("### Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")

        gr.HTML("""
<div style="text-align: center;">
<a href="https://arxiv.org/pdf/2504.13392" target="_blank" class="paper-link">
π Read the Full Paper
</a>
</div>
""")
        gr.Markdown("""
<div class="abstract-text">
<strong>Abstract:</strong> Given that creative end-users often operate in diverse, context-specific ways that are often unpredictable, more variation and personalization are necessary. We introduce POET, a real-time interactive tool that (1) automatically discovers dimensions of homogeneity in text-to-image generative models, (2) expands these dimensions to diversify the output space of generated images, and (3) learns from user feedback to personalize expansions. Focusing on visual creativity, POET offers a first glimpse of how interaction techniques of future text-to-image generation tools may support and align with more pluralistic values and the needs of end-users during the ideation stages of their work.
</div>
""", elem_classes=["abstract-text"])

        gr.Markdown("""
<div class="authors-section">
<a href="https://scholar.google.com/citations?user=HXED4kIAAAAJ&hl=en">Evans Han</a>,
<a href="https://www.aliceqian.com/">Alice Qian Zhang</a>,
<a href="https://haiyizhu.com/">Haiyi Zhu</a>,
<a href="https://www.andrew.cmu.edu/user/hongs/">Hong Shen</a>,
<a href="https://pliang279.github.io/">Paul Pu Liang</a>,
<a href="https://janeon.github.io/">Jane Hsieh</a>
</div>
""", elem_classes=["authors-section"])

    with gr.Tab(""):
        # Prompt entry with the Generate button beside it.
        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                with gr.Row():
                    prompt = gr.Textbox(
                        label="π¨ Prompt",
                        max_lines=5,
                        placeholder="Enter your prompt",
                        visible=True,
                    )
            with gr.Column(elem_id="col-container3"):
                next_btn = gr.Button("Generate", variant="primary", scale=1)

        # Generated images plus the liked / disliked image slots.
        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                images_method = gr.Gallery(
                    label="Generated Images (Click to select, then set to Like/Dislike image)",
                    columns=[4],
                    rows=[1],
                    height=400,
                    interactive=False,
                    elem_id="gallery",
                    format="png"
                )

            with gr.Column(elem_id="col-container3"):
                like_btn = gr.Button("π Set as Liked (Optional for personalization)", size="sm", variant="secondary")
                like_image = gr.Image(
                    label="Satisfied Image",
                    width=150,
                    height=150,
                    interactive=False,
                    format="png",
                    type="filepath"
                )
                dislike_btn = gr.Button("π Set as Disliked (Optional for personalization)", size="sm", variant="secondary")
                dislike_image = gr.Image(
                    label="Unsatisfied Image",
                    width=150,
                    height=150,
                    interactive=False,
                    format="png",
                    type="filepath"
                )

        # Optional personalization loop: rate, pick like/dislike, save, regenerate.
        with gr.Accordion("π― Advanced: Personalized Image Redesign", open=False, elem_id="col-container2"):
            # Instructions name the buttons exactly as labelled below.
            gr.HTML("""
<div class="instruction-box">
<div class="instruction-title">
π How to Use Personalized Redesign
</div>
<div class="step-list">
<div class="step-item">
<span class="step-number">1.</span>
<strong>Rate Your Satisfaction:</strong> Provide a satisfaction score for the current generated images
</div>
<div class="step-item">
<span class="step-number">2.</span>
<strong>Select Preferences:</strong> Choose your most liked and disliked images
</div>
<div class="step-item">
<span class="step-number">3.</span>
<strong>Save & Iterate:</strong> Click "Save Personalization Data" before redesigning your prompt and clicking "Generate"
</div>
<div class="step-item">
<span class="step-number">4.</span>
<strong>Restart Anytime:</strong> Use the "Restart Session" button to begin a fresh session
</div>
</div>
</div>
""")

            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### π Rate Current Generation")
                with gr.Row():
                    sim_radio = gr.Radio(
                        OPTIONS,
                        label="How satisfied are you with the current generated images?",
                        type="value",
                        show_label=True,
                        container=True,
                        scale=1
                    )

                with gr.Row(elem_id="button-container"):
                    with gr.Column(scale=1):
                        redesign_btn = gr.Button("πΎ Save Personalization Data", variant="primary", size="lg")
                    with gr.Column(scale=1):
                        submit_btn = gr.Button("π Restart Session", variant="secondary", size="lg")

            with gr.Column(elem_id="col-container2"):
                # Prompt history; its dataset samples are replaced by redesign/save_response.
                example = gr.Examples([['']], prompt, label="π Prompt History", visible=True)
                history_images = gr.Gallery(
                    label="ποΈ Generation History",
                    columns=[4],
                    rows=[1],
                    # Must differ from images_method's "gallery": DOM ids have to be unique.
                    elem_id="history-gallery",
                    format="png",
                    interactive=False,
                )

        # Static examples; the four hidden Image inputs only carry example images.
        with gr.Column(elem_id="col-container2"):
            gr.Markdown("### π Examples")
            ex1 = gr.Image(label="Image 1", width=200, height=200, format="png", type="filepath", visible=False)
            ex2 = gr.Image(label="Image 2", width=200, height=200, format="png", type="filepath", visible=False)
            ex3 = gr.Image(label="Image 3", width=200, height=200, format="png", type="filepath", visible=False)
            ex4 = gr.Image(label="Image 4", width=200, height=200, format="png", type="filepath", visible=False)

            gr.Examples(
                examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
                inputs=[prompt, ex1, ex2, ex3, ex4]
            )

    # ------------------------------------------------------------------
    # Event wiring
    # ------------------------------------------------------------------
    # Clicking a generated image stores it in selected_image.
    images_method.select(
        fn=on_gallery_select,
        inputs=[],
        outputs=[selected_image]
    )

    # Like / dislike buttons copy the selected image into their slot.
    like_btn.click(
        fn=handle_like_drag,
        inputs=[selected_image],
        outputs=[like_image]
    )

    dislike_btn.click(
        fn=handle_dislike_drag,
        inputs=[selected_image],
        outputs=[dislike_image]
    )

    # Generate streams images into the gallery; on success re-enable the controls.
    next_btn.click(
        fn=generate_image,
        inputs=[prompt, like_image, dislike_image],
        outputs=[images_method]
    ).success(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)],
              outputs=[next_btn, prompt, redesign_btn, submit_btn])

    redesign_btn.click(
        fn=redesign,
        inputs=[prompt, sim_radio, images_method, history_images, like_image, dislike_image],
        outputs=[sim_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn]
    )

    # "Restart Session" resets all state (see save_response).
    submit_btn.click(
        fn=save_response,
        inputs=[prompt, sim_radio, like_image, dislike_image],
        outputs=[sim_radio, prompt, next_btn, redesign_btn, like_image, dislike_image, images_method, history_images, example.dataset]
    )
|
|
|
|
|
def _main():
    """Start the Gradio server for the demo defined above."""
    demo.launch()


if __name__ == "__main__":
    _main()