File size: 12,740 Bytes
185470f
4c26858
9cf98ec
 
 
 
 
 
4c26858
185470f
 
 
 
 
 
 
 
 
 
 
 
9cf98ec
 
a65a087
185470f
9cf98ec
 
185470f
 
4c26858
9cf98ec
 
185470f
9cf98ec
 
 
 
 
 
185470f
9cf98ec
185470f
 
 
 
9cf98ec
a65a087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185470f
 
 
9cf98ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64dd181
 
 
a65a087
64dd181
 
 
9cf98ec
185470f
 
 
9cf98ec
 
 
 
 
 
 
 
 
185470f
 
 
 
73570ff
64dd181
 
 
 
 
 
 
 
 
 
 
 
 
185470f
 
 
 
 
a92b2e7
185470f
a92b2e7
185470f
 
a65a087
 
 
 
 
185470f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a92b2e7
73570ff
 
185470f
 
64dd181
 
 
 
185470f
64dd181
 
d827699
9cf98ec
d827699
9cf98ec
185470f
a8549bf
64dd181
9cf98ec
185470f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cf98ec
185470f
9cf98ec
185470f
 
 
9cf98ec
185470f
9cf98ec
 
 
 
 
 
 
 
de22521
64dd181
de22521
 
9cf98ec
 
185470f
9cf98ec
195aa1e
 
2e7f19e
 
195aa1e
9cf98ec
 
 
 
a65a087
 
9cf98ec
a65a087
185470f
 
 
 
 
a65a087
185470f
9cf98ec
a65a087
185470f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cf98ec
 
185470f
 
 
 
 
 
 
 
a65a087
 
 
 
 
 
 
 
 
 
 
185470f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cf98ec
 
185470f
 
 
 
 
9cf98ec
 
185470f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343

import gradio as gr
import numpy as np
import random
import spaces
import torch
import re
import transformers

# Optional: keep these utilities if your pipeline depends on them
from optim_utils import optimize_prompt
from utils import (
    clean_response_gpt, setup_model, init_gpt_api, call_gpt_api,
    get_refine_msg, clean_cache, get_personalize_message,
    clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS,
    INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS  # some may be unused after simplification
)

# =========================
# Constants / Defaults
# =========================
# Model identifiers. CLIP_MODEL, PRETRAINED_CLIP and default_llm_model are not
# referenced in the visible code -- NOTE(review): confirm they are still needed.
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"
default_t2i_model = "black-forest-labs/FLUX.1-schnell" # "black-forest-labs/FLUX.1-dev"
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
# Inclusive upper bound for the random seeds drawn by infer() and the GPT helpers.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024  # not referenced in the visible code
NUM_IMAGES = 4  # images rendered per round (one per refined prompt)
# Round cap: once reached, Redesign is disabled and Submit is enabled regardless of rating.
MAX_ROUND = 5

# Half precision on CUDA, full precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Helper from utils; presumably frees cached GPU/host memory before loading -- confirm.
clean_cache()

# Text-to-image pipeline used by infer(); built once, at import time.
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
llm_pipe = None  # not referenced in the visible code
torch.cuda.empty_cache()
inverted_prompt = ""  # not referenced in the visible code

METHOD = "Experimental"  # keep ONLY experimental

# Global states for a single-task, single-method flow.
# NOTE(review): these are module-level, so every request served by this process
# shares (and can overwrite) the same round state.
counter = 1  # current round number, starting at 1
enable_submit = False  # latched True once any round is rated (Very) Satisfied
# {METHOD: {round_number: {"prompt", "sim_radio", "response",
#                          "satisfied_img", "unsatisfied_img"}}}
responses_memory = {METHOD: {}}

# Each entry is [example prompt, image list]; the UI indexes image list [0]..[3],
# so every list must contain at least four images (presumably file paths).
example_data = [
    [
        "A futuristic city skyline at sunset",
        IMAGES["Tourist promotion"]["ours"]
    ],
    [
        "A fantasy castle in the clouds",
        IMAGES["Fictional character generation"]["ours"]
    ],
    [
        "A robot painting a portrait in a studio",
        IMAGES["Interior Design"]["ours"]
    ],
]
print(example_data)  # startup debug output of the example table

# =========================
# Image Generation Helpers
# =========================
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Render one image for *prompt* with the module-level ``selected_pipe``.

    Decorated with ``spaces.GPU(duration=65)`` (Hugging Face Spaces GPU
    allocation; the duration is presumably in seconds).

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed for the torch generator; ignored when *randomize_seed* is truthy.
        randomize_seed: When truthy, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width forwarded to the pipeline.
        height: Output height forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        progress: Gradio progress tracker (``track_tqdm=True``); not used
            directly in the body.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    effective_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    rng = torch.Generator().manual_seed(effective_seed)

    pipe_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=rng,
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = selected_pipe(**pipe_kwargs)
    return output.images[0]

def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask "gpt-4o" for *num_prompts* refined variants of *prompt*.

    A fresh random seed in ``[0, MAX_SEED]`` is drawn for every call. The
    request messages are built by ``get_refine_msg`` and the raw reply is
    post-processed by ``clean_response_gpt``.

    Args:
        prompt: The user's prompt to expand.
        num_prompts: How many variants to request.
        max_tokens: Completion token limit forwarded to ``call_gpt_api``.
        temperature: Sampling temperature forwarded to ``call_gpt_api``.
        top_p: Nucleus-sampling parameter forwarded to ``call_gpt_api``.

    Returns:
        Whatever ``clean_response_gpt`` produces -- presumably a list of
        prompt strings (callers index into it).
    """
    request_seed = random.randint(0, MAX_SEED)
    gpt_client = init_gpt_api()
    refine_messages = get_refine_msg(prompt, num_prompts)
    raw_reply = call_gpt_api(
        refine_messages, gpt_client, "gpt-4o", request_seed, max_tokens, temperature, top_p
    )
    return clean_response_gpt(raw_reply)

def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Ask "gpt-4o" to personalize *prompt* using the user's round history.

    Args:
        prompt: The current prompt.
        history: Prompts from earlier rounds.
        feedback: Satisfaction ratings from earlier rounds.
        like_image: The image the user marked as satisfying (may be None).
        dislike_image: The image the user marked as unsatisfying (may be None).

    Returns:
        The raw, uncleaned reply from ``call_gpt_api``.
    """
    request_seed = random.randint(0, MAX_SEED)
    gpt_client = init_gpt_api()
    print(like_image, dislike_image)  # debug trace of the reference images
    personalize_messages = get_personalize_message(
        prompt, history, feedback, like_image, dislike_image
    )
    return call_gpt_api(
        personalize_messages,
        gpt_client,
        "gpt-4o",
        request_seed,
        max_tokens=2000,
        temperature=0.7,
        top_p=0.9,
    )

# =========================
# UI Helper Functions
# =========================
def reset_gallery():
    """Return a fresh, empty gallery value."""
    return list()

def display_error_message(msg, duration=5):
    """Show *msg* to the user as a Gradio warning toast for *duration* seconds."""
    gr.Warning(message=msg, duration=duration)

def display_info_message(msg, duration=5):
    """Show *msg* to the user as a Gradio info toast for *duration* seconds."""
    gr.Info(message=msg, duration=duration)

def check_satisfaction(sim_radio):
    """Return a ``gr.update`` toggling whether the Submit button is clickable.

    Submit is enabled when any of the following holds:
    - the current rating is "Satisfied" or "Very Satisfied";
    - an earlier round already latched ``enable_submit``;
    - the round counter has passed ``MAX_ROUND``.

    Args:
        sim_radio: The selected satisfaction rating (may be None).
    """
    rated_satisfied = sim_radio in ("Satisfied", "Very Satisfied")
    rounds_exhausted = counter > MAX_ROUND
    return gr.update(interactive=rated_satisfied or enable_submit or rounds_exhausted)

def select_image(like_radio, images_method):
    """Return the gallery image that corresponds to the radio choice *like_radio*.

    Args:
        like_radio: The selected option. Expected to be one of IMAGE_OPTIONS,
            or None when the radio is cleared.
        images_method: The value of the "Images" gallery. Its entries are
            indexable with element 0 being the image (presumably Gradio's
            ``(image, caption)`` pairs). It is None or empty before anything
            has been generated.

    Returns:
        The chosen image, or None when:
        - nothing is selected or the option is unknown;
        - the gallery is empty;
        - that image has not been generated yet.
    """
    if not images_method or like_radio not in IMAGE_OPTIONS:
        return None
    position = IMAGE_OPTIONS.index(like_radio)
    # While generation is still streaming, the gallery can hold fewer images
    # than there are options; indexing blindly would raise IndexError.
    if position >= len(images_method):
        return None
    entry = images_method[position]
    # Gallery entries are normally (image, caption) pairs. Tolerate a bare
    # image/path too, instead of returning its first character.
    return entry[0] if isinstance(entry, (list, tuple)) else entry

def check_evaluation(sim_radio):
    """Return True when a satisfaction rating was given.

    When no rating was given, warn the user and return False.

    Args:
        sim_radio: The selected satisfaction rating (falsy when unset).
    """
    if sim_radio:
        return True
    display_error_message("❌ Please fill all evaluations before changing image or submitting.")
    return False

def generate_image(prompt, like_image, dislike_image):
    """Generate NUM_IMAGES candidate images for *prompt*, streaming them to the gallery.

    The prompt is expanded by GPT (``call_gpt_refine_prompt``) into several
    refined prompts. One image is rendered per refined prompt, and the
    growing list is yielded after each render so the gallery fills in
    progressively.

    Args:
        prompt: The user's current prompt.
        like_image: Image the user marked as satisfying. Currently unused;
            kept for the event signature and the disabled personalization step.
        dislike_image: Image the user marked as unsatisfying. Currently unused,
            for the same reason.

    Yields:
        list: The images generated so far (one more per yield).
    """
    # Personalization (personalize_prompt over the round history stored in
    # responses_memory) is currently disabled, so the raw prompt is refined directly.
    personalized = prompt
    refined_prompts = list(call_gpt_refine_prompt(personalized) or [])[:NUM_IMAGES]
    # The LLM may return fewer variants than requested. Fall back to the
    # unrefined prompt rather than failing with IndexError mid-stream.
    refined_prompts += [personalized] * (NUM_IMAGES - len(refined_prompts))

    gallery_images = []
    for refined in refined_prompts:
        gallery_images.append(infer(refined))
        yield gallery_images

def redesign(prompt, sim_radio, like_radio, dislike_radio, current_images, history_images, like_image, dislike_image):
    """Record the current round's evaluation and reset the UI for another round.

    If no satisfaction rating was given, the user is warned and only a
    ``gr.skip()`` for ``submit_btn`` is returned, so no component changes.

    Args:
        prompt: Prompt used in this round.
        sim_radio: Satisfaction rating for this round.
        like_radio: Option chosen as favorite image (may be None).
        dislike_radio: Option chosen as least satisfactory image (may be None).
        current_images: Value of the current-round gallery (may be None).
        history_images: Value of the history gallery (may be None).
        like_image: Unused; part of the event's inputs.
        dislike_image: Unused; part of the event's inputs.

    Returns:
        Updates, in order, for:
        - sim_radio, dislike_radio and like_radio (all cleared);
        - the current gallery (emptied);
        - the history gallery;
        - the prompt-history dataset;
        - prompt, next_btn, redesign_btn and submit_btn.
    """
    global counter, enable_submit
    if not check_evaluation(sim_radio):
        return {submit_btn: gr.skip()}

    responses_memory[METHOD][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, {like_radio}",
        "unsatisfied_img": f"round {counter}, {dislike_radio}",
    }

    # Once any round is rated satisfied, submitting stays allowed.
    enable_submit = enable_submit or sim_radio in ("Satisfied", "Very Satisfied")

    history_prompts = [[entry["prompt"]] for entry in responses_memory[METHOD].values()]
    # Build a new list instead of extending the one Gradio passed in; tolerate None.
    history_images = list(history_images or []) + list(current_images or [])

    last_round = counter >= MAX_ROUND
    counter += 1

    return (
        None,  # sim_radio
        None,  # dislike_radio
        None,  # like_radio
        [],  # current gallery
        history_images,
        gr.update(samples=history_prompts, visible=True),  # prompt-history dataset
        gr.update(interactive=True),  # prompt
        gr.update(visible=True, interactive=True),  # next_btn
        gr.update(interactive=not last_round),  # redesign_btn
        gr.update(interactive=last_round or enable_submit),  # submit_btn
    )

def save_response(prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image):
    """Record the final round, reset the round state and lock the controls.

    If no satisfaction rating was given, the user is warned and only a
    ``gr.skip()`` for ``submit_btn`` is returned, so no component changes.

    NOTE(review): the entry is only stored in the in-process
    ``responses_memory`` dict. Nothing in this function persists it elsewhere.

    Args:
        prompt: Prompt used in the final round.
        sim_radio: Satisfaction rating for the final round.
        like_radio: Option chosen as favorite image (may be None).
        dislike_radio: Option chosen as least satisfactory image (may be None).
        like_image: Unused; part of the event's inputs.
        dislike_image: Unused; part of the event's inputs.

    Returns:
        Updates, in order, for:
        - sim_radio, dislike_radio and like_radio (all cleared);
        - prompt, next_btn, redesign_btn and submit_btn (all disabled).
    """
    global counter, enable_submit
    if not check_evaluation(sim_radio):
        return {submit_btn: gr.skip()}

    # Save the final round entry.
    responses_memory[METHOD][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, {like_radio}",
        "unsatisfied_img": f"round {counter}, {dislike_radio}",
    }

    # Reset the round bookkeeping.
    counter = 1
    enable_submit = False

    display_info_message("✅ Your answer is saved!")
    return (
        None,  # sim_radio
        None,  # dislike_radio
        None,  # like_radio
        gr.update(interactive=False),  # prompt
        gr.update(visible=False, interactive=False),  # next_btn
        gr.update(interactive=False),  # redesign_btn
        gr.update(interactive=False),  # submit_btn
    )

# =========================
# Interface (single tab, no participant/scenario/background)
# =========================

# Custom CSS passed to gr.Blocks below. Each selector targets an `elem_id`
# assigned to a layout component:
#   - col-container: the header and the current gallery column;
#   - col-container2: evaluation, history and examples columns;
#   - col-container3: the selected-image previews;
#   - button-container and compact-row: the button row and the prompt/gallery rows.
css = """
#col-container {
    margin: 0 auto;
    max-width: 700px;
}
#col-container2 {
    margin: 0 auto;
    max-width: 1000px;
}
#col-container3 {
    margin: 0 0 auto auto;
    max-width: 300px;
}
#button-container {
    display: flex;
    justify-content: center;
}
#compact-row {
    width:100%;
    max-width: 1000px;
    margin: 0px auto;
}
"""

with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
    # Header.
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# 📌 **POET**")
        instruction = gr.Markdown(" Supporting Prompting Creativity and Personalization with Automated Expansion of Text-to-Image Generation")

    with gr.Tab(""):
        # Prompt entry and the Generate button.
        with gr.Row(elem_id="compact-row"):
            prompt = gr.Textbox(
                label="🎨 Revise Prompt",
                max_lines=5,
                placeholder="Enter your prompt",
                scale=3,
                visible=True,
            )
            next_btn = gr.Button("Generate", variant="primary", scale=1)

        # Current round's images, plus previews of the images picked via the radios below.
        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                images_method = gr.Gallery(label="Images", columns=[4], rows=[1], height=400, elem_id="gallery", format="png")

            with gr.Column(elem_id="col-container3"):
                like_image = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
                dislike_image = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")

        # Per-round evaluation.
        with gr.Column(elem_id="col-container2"):
            gr.Markdown("### 📝 Evaluation")
            sim_radio = gr.Radio(
                OPTIONS,
                label="How would you rate your satisfaction with the generated images?",
                type="value",
                elem_classes=["gradio-radio"]
            )
            like_radio = gr.Radio(
                IMAGE_OPTIONS,
                label="Select your all-time favorite image (optional).",
                type="value",
                elem_classes=["gradio-radio"]
            )
            dislike_radio = gr.Radio(
                IMAGE_OPTIONS,
                label="Select your all-time least satisfactory image (optional).",
                type="value",
                elem_classes=["gradio-radio"]
            )

        # Earlier rounds and the round-control buttons.
        with gr.Column(elem_id="col-container2"):
            # Hidden until redesign() fills it with the earlier prompts.
            example = gr.Examples([['']], prompt, label="Revised Prompt History", visible=False)
            # Distinct elem_id: "gallery" is already used by the current-round gallery,
            # and HTML ids must be unique.
            history_images = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="history-gallery", format="png")

            with gr.Row(elem_id="button-container"):
                redesign_btn = gr.Button("🎨 Redesign", variant="primary", scale=0)
                submit_btn = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)

        # Example prompts. ex1-ex4 are hidden holders for each example's four images.
        with gr.Column(elem_id="col-container2"):
            gr.Markdown("### 🌟 Examples")
            ex1 = gr.Image(label="Image 1", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex2 = gr.Image(label="Image 2", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex3 = gr.Image(label="Image 3", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex4 = gr.Image(label="Image 4", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)

            gr.Examples(
                examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
                inputs=[prompt, ex1, ex2, ex3, ex4]
            )

    # =========================
    # Wiring
    # =========================
    # A rating change re-evaluates whether Submit may be enabled.
    sim_radio.change(fn=check_satisfaction, inputs=[sim_radio], outputs=[submit_btn])

    # Picking a favorite / least-favorite option copies that gallery image into its preview.
    dislike_radio.select(fn=select_image, inputs=[dislike_radio, images_method], outputs=[dislike_image])
    like_radio.select(fn=select_image, inputs=[like_radio, images_method], outputs=[like_image])

    # Generate streams images into the gallery. On success it locks the button
    # and the prompt until redesign() re-enables them.
    next_btn.click(
        fn=generate_image,
        inputs=[prompt, like_image, dislike_image],
        outputs=[images_method]
    ).success(lambda: [gr.update(interactive=False), gr.update(interactive=False)], outputs=[next_btn, prompt])

    redesign_btn.click(
        fn=redesign,
        inputs=[prompt, sim_radio, like_radio, dislike_radio, images_method, history_images, like_image, dislike_image],
        outputs=[sim_radio, dislike_radio, like_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn, submit_btn]
    )

    submit_btn.click(
        fn=save_response,
        inputs=[prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image],
        outputs=[sim_radio, dislike_radio, like_radio, prompt, next_btn, redesign_btn, submit_btn]
    )

# Start the Gradio server only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()