File size: 16,792 Bytes
185470f
4c26858
9cf98ec
 
 
 
 
 
70f494f
4c26858
b18de3e
 
185470f
 
 
 
 
b18de3e
185470f
 
 
 
 
9cf98ec
 
13ebcb9
185470f
9cf98ec
 
185470f
 
4c26858
9cf98ec
 
185470f
9cf98ec
 
70f494f
9cf98ec
 
70f494f
9cf98ec
b18de3e
185470f
 
 
a65a087
 
13ebcb9
a65a087
 
 
13ebcb9
a65a087
 
 
13ebcb9
a65a087
 
 
 
185470f
 
 
9cf98ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64dd181
 
 
 
 
 
9cf98ec
70f494f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185470f
 
 
9cf98ec
 
 
 
 
 
 
 
 
185470f
 
 
 
73570ff
64dd181
 
 
 
 
 
 
 
 
 
 
 
 
185470f
 
 
 
 
a92b2e7
185470f
a92b2e7
185470f
 
a65a087
 
 
 
 
185470f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a92b2e7
73570ff
 
185470f
 
64dd181
 
 
 
185470f
64dd181
 
d827699
9cf98ec
d827699
9cf98ec
185470f
a8549bf
64dd181
9cf98ec
185470f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cf98ec
185470f
9cf98ec
185470f
 
 
9cf98ec
185470f
9cf98ec
 
 
 
 
 
 
 
de22521
64dd181
de22521
 
9cf98ec
 
185470f
9cf98ec
70f494f
 
 
 
 
195aa1e
 
2e7f19e
 
195aa1e
70f494f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cf98ec
 
 
70f494f
a65a087
b18de3e
 
70f494f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b18de3e
 
 
70f494f
 
 
 
9cf98ec
a65a087
185470f
70f494f
 
 
 
 
 
 
 
 
 
185470f
 
 
 
 
 
13ebcb9
 
185470f
13ebcb9
185470f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cf98ec
 
13ebcb9
185470f
 
 
 
 
 
 
a65a087
 
 
 
 
 
 
 
 
 
 
70f494f
185470f
 
 
 
 
 
 
 
 
 
 
 
13ebcb9
185470f
 
 
 
 
9cf98ec
 
185470f
 
 
 
 
9cf98ec
 
185470f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435

import gradio as gr
import numpy as np
import random
import spaces
import torch
import re
import transformers
import open_clip

# from Pilot-Phase3.optim_utils import optimize_prompt
# from Pi
from optim_utils import optimize_prompt
from utils import (
    clean_response_gpt, setup_model, init_gpt_api, call_gpt_api,
    get_refine_msg, clean_cache, get_personalize_message,
    clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS,
    INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS  
)

# =========================
# Constants / Defaults
# =========================
# CLIP architecture / checkpoint names handed to open_clip below and forwarded
# to the prompt-inversion routine via ``text_params`` in ``invert_prompt``.
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"
# Text-to-image model loaded at import time by ``setup_model``.
default_t2i_model = "black-forest-labs/FLUX.1-dev"
# NOTE(review): not referenced anywhere in the visible part of this file.
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
# Upper bound for randomly drawn seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# NOTE(review): not referenced anywhere in the visible part of this file.
MAX_IMAGE_SIZE = 1024
# Number of images generated per round in ``generate_image``.
NUM_IMAGES = 4
# Round limit: once ``counter`` reaches it, redesign is disabled and submit is enabled.
MAX_ROUND = 5

# Use the GPU with half precision when CUDA is available, else CPU with float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Project helper from ``utils``; presumably frees cached memory before loading models — TODO confirm.
clean_cache()

# Heavy module-level side effects: both models are created once at import time
# and shared by every request.
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
# NOTE(review): ``llm_pipe`` is never assigned again in the visible code.
llm_pipe = None
# Written by ``invert_prompt`` with the result of ``optimize_prompt``.
inverted_prompt = ""
torch.cuda.empty_cache()

# Key under which all rounds are stored in ``responses_memory``.
METHOD = "Experimental" 
# Current round number (1-based); incremented by ``redesign``, reset by ``save_response``.
counter = 1
# Becomes True once any round was rated "Satisfied" / "Very Satisfied".
enable_submit = False
# Per-round records: responses_memory[METHOD][round] -> dict with prompt, rating
# and the selected liked/disliked image labels.
# NOTE(review): these are process-wide globals, so state is shared by all
# concurrent users of the app — confirm this is intended.
responses_memory = {METHOD: {}}
# Rows for the "Examples" widget: [prompt, images]; the UI indexes the first
# four entries of the images value.
example_data = [
    [
        PROMPTS["Tourist promotion"],
        IMAGES["Tourist promotion"]["ours"]
    ],
    [
        PROMPTS["Fictional character generation"],
        IMAGES["Fictional character generation"]["ours"]
    ],
    [
        PROMPTS["Interior Design"],
        IMAGES["Interior Design"]["ours"]
    ],
]

# =========================
# Image Generation Helpers
# =========================
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image for ``prompt`` with the module-level pipeline.

    Args:
        prompt: Text prompt passed to ``selected_pipe``.
        negative_prompt: Negative prompt forwarded unchanged to the pipeline.
        seed: Seed used when ``randomize_seed`` is False.
        randomize_seed: When True, ``seed`` is ignored and a fresh value in
            ``[0, MAX_SEED]`` is drawn.
        width: Output width forwarded to the pipeline.
        height: Output height forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        progress: Not read in the body; presumably lets Gradio mirror tqdm
            progress bars in the UI — TODO confirm.

    Returns:
        The first entry of the pipeline result's ``images`` attribute.
    """
    effective_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    rng = torch.Generator().manual_seed(effective_seed)

    pipe_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "width": width,
        "height": height,
        "generator": rng,
    }

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        first_image = selected_pipe(**pipe_kwargs).images[0]

    return first_image

def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask the GPT API ("gpt-4o") for refined variants of ``prompt``.

    Args:
        prompt: The user's prompt.
        num_prompts: Passed to ``get_refine_msg`` when building the request.
        max_tokens: Forwarded to ``call_gpt_api``.
        temperature: Forwarded to ``call_gpt_api``.
        top_p: Forwarded to ``call_gpt_api``.

    Returns:
        Whatever ``clean_response_gpt`` produces from the raw API output
        (callers index it like a sequence of prompts).
    """
    # A fresh random seed is drawn for every request.
    request_seed = random.randint(0, MAX_SEED)
    gpt_client = init_gpt_api()
    refine_messages = get_refine_msg(prompt, num_prompts)
    raw_output = call_gpt_api(
        refine_messages, gpt_client, "gpt-4o", request_seed, max_tokens, temperature, top_p
    )
    return clean_response_gpt(raw_output)

def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Request a personalized rewrite of ``prompt`` from the GPT API ("gpt-4o").

    All five arguments are handed to ``get_personalize_message`` to build the
    request messages.

    Returns:
        The raw output of ``call_gpt_api`` (not cleaned here).
    """
    # A fresh random seed is drawn for every request.
    request_seed = random.randint(0, MAX_SEED)
    gpt_client = init_gpt_api()
    personalize_messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
    return call_gpt_api(
        personalize_messages,
        gpt_client,
        "gpt-4o",
        request_seed,
        max_tokens=2000,
        temperature=0.7,
        top_p=0.9,
    )

@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=500, lr=0.1, batch_size=2):
    """Run ``optimize_prompt`` on ``images`` and store the result globally.

    Args:
        prompt: Passed to ``optimize_prompt`` as ``target_prompts``.
        images: Passed to ``optimize_prompt`` as ``target_images``.
        prompt_len: Stored under ``"prompt_len"`` in the optimizer settings.
        iter: Stored under ``"iter"`` in the optimizer settings.
        lr: Stored under ``"lr"`` in the optimizer settings.
        batch_size: Stored under ``"batch_size"`` in the optimizer settings.

    Returns:
        None. The result is written to the module-level ``inverted_prompt``.
    """
    global inverted_prompt

    # Caller-tunable settings first, then the fixed ones.
    optimizer_settings = dict(
        iter=iter,
        lr=lr,
        batch_size=batch_size,
        prompt_len=prompt_len,
        weight_decay=0.1,
        prompt_bs=1,
        loss_weight=1.0,
        print_step=100,
        clip_model=CLIP_MODEL,
        clip_pretrain=PRETRAINED_CLIP,
    )
    inverted_prompt = optimize_prompt(
        clip_model,
        preprocess,
        optimizer_settings,
        device,
        target_images=images,
        target_prompts=prompt,
    )

    # eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
    # return learned_prompt

# =========================
# UI Helper Functions
# =========================
def reset_gallery():
    """Return a new empty list, i.e. the value that clears a gallery."""
    return list()

def display_error_message(msg, duration=5):
    """Raise a Gradio warning alert showing ``msg`` with the given ``duration``."""
    gr.Warning(message=msg, duration=duration)

def display_info_message(msg, duration=5):
    """Raise a Gradio info alert showing ``msg`` with the given ``duration``."""
    gr.Info(message=msg, duration=duration)

def check_satisfaction(sim_radio):
    """Decide whether the submit button should be clickable.

    Submission is allowed when the current rating is "Satisfied" or
    "Very Satisfied", when ``enable_submit`` is already set, or when
    ``counter`` has gone past ``MAX_ROUND``.

    Returns:
        A ``gr.update`` carrying the resulting ``interactive`` flag.
    """
    # The module globals are only read here, so no ``global`` statement is needed.
    rated_satisfied = sim_radio in ("Satisfied", "Very Satisfied")
    rounds_exhausted = counter > MAX_ROUND
    can_submit = rated_satisfied or enable_submit or rounds_exhausted
    return gr.update(interactive=can_submit)

def select_image(like_radio, images_method):
    """Return the gallery image that corresponds to the chosen radio option.

    Args:
        like_radio: The selected radio value; compared against the first four
            entries of ``IMAGE_OPTIONS``.
        images_method: Gallery value, indexed as ``images_method[i][0]`` (the
            first element of the i-th gallery entry). May be ``None`` or
            shorter than expected when no or few images have been generated.

    Returns:
        The first element of the matching gallery entry, or ``None`` when the
        radio value matches none of the four options or the gallery has no
        entry at that position.
    """
    # Only the first four options map to gallery slots.
    selectable = list(IMAGE_OPTIONS[:4])
    if like_radio not in selectable:
        return None

    slot = selectable.index(like_radio)
    # Guard against an empty/None gallery or one with fewer images than
    # options; indexing blindly would raise TypeError/IndexError in the UI.
    if not images_method or slot >= len(images_method):
        return None
    return images_method[slot][0]

def check_evaluation(sim_radio):
    """Check that the satisfaction rating has been filled in.

    Args:
        sim_radio: Current value of the satisfaction radio group.

    Returns:
        True when ``sim_radio`` is truthy. Otherwise a warning is shown to the
        user via ``display_error_message`` and False is returned.
    """
    if sim_radio:
        return True
    # The leading emoji was mis-encoded (mojibake) in the original literal;
    # it is restored to the intended cross mark here.
    display_error_message("❌ Please fill all evaluations before changing image or submitting.")
    return False

def generate_image(prompt, like_image, dislike_image):
    """Stream up to ``NUM_IMAGES`` generated images for ``prompt``.

    The prompt is expanded into refined prompts by ``call_gpt_refine_prompt``;
    one image is generated per refined prompt with ``infer``.

    Args:
        prompt: The user's prompt.
        like_image: Currently unused (only needed by the disabled
            personalization step below).
        dislike_image: Currently unused (same reason).

    Yields:
        The growing list of generated images after each new image, so the
        gallery fills in progressively.
    """
    personalized = prompt
    # Personalization is currently disabled; kept for reference:
    # history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
    # feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
    # personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
    # personalized = clean_refined_prompt_response_gpt(personalized)
    # if "I'm sorry, I can't assist with" in personalized:
    #     personalized = prompt
    refined_prompts = call_gpt_refine_prompt(personalized)

    gallery_images = []
    # Slice instead of indexing by range(NUM_IMAGES): the model may return
    # fewer prompts than requested, which previously raised IndexError.
    for refined_prompt in (refined_prompts or [])[:NUM_IMAGES]:
        gallery_images.append(infer(refined_prompt))
        yield gallery_images

def redesign(prompt, sim_radio, like_radio, dislike_radio, current_images, history_images, like_image, dislike_image):
    """Record the current round and prepare the UI for another one.

    When the satisfaction rating is missing, a warning is shown (by
    ``check_evaluation``) and only a no-op update for ``submit_btn`` is
    returned. Otherwise the round is stored in ``responses_memory``, the
    current images are appended to the history, and ``counter`` advances.

    Returns:
        On success, a 10-tuple matching the ``redesign_btn.click`` outputs:
        three ``None`` values clearing the radios, the emptied current gallery,
        the history gallery, and updates for the prompt-history examples, the
        prompt box, the generate button, the redesign button and the submit
        button.
    """
    global counter, enable_submit, responses_memory

    # Guard clause: nothing is recorded until the round has been rated.
    if not check_evaluation(sim_radio):
        return {submit_btn: gr.skip()}

    round_record = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, {like_radio}",
        "unsatisfied_img": f"round {counter}, {dislike_radio}",
    }
    responses_memory[METHOD][counter] = round_record

    # Once unlocked, submission stays unlocked for the rest of the session.
    enable_submit = bool(sim_radio in ("Satisfied", "Very Satisfied") or enable_submit)

    prompt_rows = [[entry["prompt"]] for entry in responses_memory[METHOD].values()]

    # Move this round's images into the history gallery.
    if history_images and current_images:
        history_images.extend(current_images)
    elif not history_images:
        history_images = current_images

    # Button states are computed with the pre-increment round number.
    last_round_reached = counter >= MAX_ROUND
    examples_state = gr.update(samples=prompt_rows, visible=True)
    prompt_state = gr.update(interactive=True)
    next_state = gr.update(visible=True, interactive=True)
    redesign_state = gr.update(interactive=not last_round_reached)
    submit_state = gr.update(interactive=bool(last_round_reached or enable_submit))

    counter += 1

    return (
        None,
        None,
        None,
        [],
        history_images,
        examples_state,
        prompt_state,
        next_state,
        redesign_state,
        submit_state,
    )

def save_response(prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image):
    """Record the final round, reset the session state and lock the controls.

    When the satisfaction rating is missing, a warning is shown (by
    ``check_evaluation``) and only a no-op update for ``submit_btn`` is
    returned.

    Returns:
        On success, a 7-tuple matching the ``submit_btn.click`` outputs: three
        ``None`` values clearing the radios, followed by updates that disable
        the prompt box, hide/disable the generate button, and disable the
        redesign and submit buttons.
    """
    global counter, enable_submit, responses_memory

    # Guard clause: nothing is saved until the round has been rated.
    if not check_evaluation(sim_radio):
        return {submit_btn: gr.skip()}

    # Save the final round entry
    responses_memory[METHOD][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, {like_radio}",
        "unsatisfied_img": f"round {counter}, {dislike_radio}",
    }

    # Reset states
    # NOTE(review): ``responses_memory`` is not cleared here, so entries from
    # earlier rounds remain after the counter restarts at 1 — confirm intended.
    counter = 1
    enable_submit = False

    # Reset buttons
    prompt_state = gr.update(interactive=False)
    next_state = gr.update(visible=False, interactive=False)
    submit_state = gr.update(interactive=False)
    redesign_state = gr.update(interactive=False)

    # The leading emoji was mis-encoded (mojibake) in the original literal;
    # it is restored to the intended check mark here.
    display_info_message("✅ Your answer is saved!")
    return None, None, None, prompt_state, next_state, redesign_state, submit_state

# =========================
# Interface (single tab, no participant/scenario/background)
# =========================

# Custom stylesheet passed to ``gr.Blocks(css=...)``. The ``#...`` selectors
# target the ``elem_id`` values used in the layout below; the ``.`` selectors
# target ``elem_classes`` and classes used inside the raw HTML/Markdown.
# NOTE(review): ``#compact-compact-row`` is not used by any element in the
# visible layout.
css = """
#col-container {
    margin: 0 auto;
    max-width: 700px;
}
#col-container2 {
    margin: 0 auto;
    max-width: 1000px;
}
#col-container3 {
    margin: 0 0 auto auto;
    max-width: 300px;
}
#button-container {
    display: flex;
    justify-content: center;
}
#compact-compact-row {
    width:100%;
    max-width: 800px;
    margin: 0px auto;
}
#compact-row {
    width:100%;
    max-width: 1000px;
    margin: 0px auto;
}
.header-section {
    text-align: center;
    margin-bottom: 2rem;
}
.abstract-text {
    text-align: justify;
    line-height: 1.6;
    margin: 0.5rem 0;
    padding: 0.5rem;
    background-color: rgba(0, 0, 0, 0.05);
    border-radius: 8px;
    border-left: 4px solid #3498db;
}
.paper-link {
    display: inline-block;
    margin: 0rem 0;
    padding: 0rem 0rem;
    background-color: #3498db;
    color: white;
    text-decoration: none;
    border-radius: 5px;
    font-weight: 500;
}
.paper-link:hover {
    background-color: #2980b9;
    text-decoration: none;
}
.authors-section {
    text-align: center;
    margin: 0 0;
    font-style: italic;
    color: #666;
}
.authors-title {
    font-weight: bold;
    margin-bottom: 0rem;
    color: #333;
}
"""

# Single-page Gradio app: header (title, abstract, paper link, authors), one
# tab with the prompt box, the image gallery, the (initially hidden) evaluation
# and history sections, an examples section, and the event wiring at the end.
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
    with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
        gr.Markdown("# 📌 **POET**")
        gr.HTML('<div><img src="images/icon.png" width="200"></div>')
        gr.Markdown("### Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")

        # <strong>Abstract:</strong> State-of-the-art visual generative AI tools hold immense potential to assist users in the early ideation stages of creative tasks — offering the ability to generate (rather than search for) novel and unprecedented (instead of existing) images of considerable quality that also adhere to boundless combinations of user specifications. However, many large-scale text-to-image systems are designed for broad applicability, yielding conventional output that may limit creative exploration. They also employ interaction methods that may be difficult for beginners.
        gr.Markdown("""
        <div class="abstract-text">
        <strong>Abstract:</strong> Given that creative end-users often operate in diverse, context-specific ways that are often unpredictable, more variation and personalization are necessary. We introduce POET, a real-time interactive tool that (1) automatically discovers dimensions of homogeneity in text-to-image generative models, (2) expands these dimensions to diversify the output space of generated images, and (3) learns from user feedback to personalize expansions. Focusing on visual creativity, POET offers a first glimpse of how interaction techniques of future text-to-image generation tools may support and align with more pluralistic values and the needs of end-users during the ideation stages of their work.
        </div>
        """, elem_classes=["abstract-text"])

        # Paper Link
        gr.HTML("""
        <div style="text-align: center;">
            <a href="https://arxiv.org/pdf/2504.13392" target="_blank" class="paper-link">
                📄 Read the Full Paper .
            </a>
        </div>
        """)

        # Authors (the second link previously lacked the "=" after href,
        # which made the anchor non-functional)
        gr.Markdown("""
        <div class="authors-section">
            <a href="https://scholar.google.com/citations?user=HXED4kIAAAAJ&hl=en">Evans Han</a>, <a href="https://www.aliceqian.com/">Alice Qian Zhang</a>, 
            <a href="https://haiyizhu.com/">Haiyi Zhu</a>, <a href="https://www.andrew.cmu.edu/user/hongs/">Hong Shen</a>, 
            <a href="https://pliang279.github.io/">Paul Pu Liang</a>, <a href="https://janeon.github.io/">Jane Hsieh</a>
        </div>
        """, elem_classes=["authors-section"])

        # gr.Markdown("---")

    with gr.Tab(""):
        # Prompt input and the "Generate" button.
        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                with gr.Row():
                    prompt = gr.Textbox(
                        label="🎨 Prompt",
                        max_lines=5,
                        placeholder="Enter your prompt",
                        visible=True,
                    )
            with gr.Column(elem_id="col-container3"):
                next_btn = gr.Button("Generate", variant="primary", scale=1)

        # Gallery of the current round plus hidden holders for the liked /
        # disliked image chosen through the radios.
        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                images_method = gr.Gallery(label="Images", columns=[4], rows=[1], height=400, elem_id="gallery", format="png")

            with gr.Column(elem_id="col-container3"):
                like_image = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
                dislike_image = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)

        # Evaluation section (created hidden).
        with gr.Column(elem_id="col-container2", visible=False):
            gr.Markdown("### 📝 Evaluation")
            sim_radio = gr.Radio(
                OPTIONS,
                label="How would you rate your satisfaction with the generated images?",
                type="value",
                elem_classes=["gradio-radio"]
            )
            like_radio = gr.Radio(
                IMAGE_OPTIONS,
                label="Select your all-time favorite image (optional).",
                type="value",
                elem_classes=["gradio-radio"]
            )
            dislike_radio = gr.Radio(
                IMAGE_OPTIONS,
                label="Select your all-time least satisfactory image (optional).",
                type="value",
                elem_classes=["gradio-radio"]
            )

        # Prompt/image history and the redesign / submit buttons (created hidden).
        with gr.Column(elem_id="col-container2", visible=False):
            example = gr.Examples([['']], prompt, label="Revised Prompt History", visible=False)
            history_images = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png")

            with gr.Row(elem_id="button-container"):
                redesign_btn = gr.Button("🎨 Redesign", variant="primary", scale=0)
                submit_btn = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)

        # Static examples: each row fills the prompt and four hidden image slots.
        with gr.Column(elem_id="col-container2"):
            gr.Markdown("### 🌟 Examples")
            ex1 = gr.Image(label="Image 1", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex2 = gr.Image(label="Image 2", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex3 = gr.Image(label="Image 3", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex4 = gr.Image(label="Image 4", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)

            gr.Examples(
                examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
                inputs=[prompt, ex1, ex2, ex3, ex4]
            )

# =========================
# Wiring
# =========================
    # Changing the rating re-evaluates whether submit is allowed.
    sim_radio.change(fn=check_satisfaction, inputs=[sim_radio], outputs=[submit_btn])

    # Picking a radio option copies the matching gallery image into the holder.
    dislike_radio.select(fn=select_image, inputs=[dislike_radio, images_method], outputs=[dislike_image])
    like_radio.select(fn=select_image, inputs=[like_radio, images_method], outputs=[like_image])

    # Generate images, then re-enable the button and prompt box on success.
    next_btn.click(
        fn=generate_image,
        inputs=[prompt, like_image, dislike_image],
        outputs=[images_method]
    ).success(lambda: [gr.update(interactive=True), gr.update(interactive=True)], outputs=[next_btn, prompt])

    redesign_btn.click(
        fn=redesign,
        inputs=[prompt, sim_radio, like_radio, dislike_radio, images_method, history_images, like_image, dislike_image],
        outputs=[sim_radio, dislike_radio, like_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn, submit_btn]
    )

    submit_btn.click(
        fn=save_response,
        inputs=[prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image],
        outputs=[sim_radio, dislike_radio, like_radio, prompt, next_btn, redesign_btn, submit_btn]
    )

# Start the Gradio server only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()