"""Gradio app for POET: prompt expansion and text-to-image generation.

The user enters a prompt, the prompt is expanded into several refined prompts
through the GPT helper functions, one image is generated per refined prompt,
and the user rates the result.  Rounds can be repeated ("Redesign") up to
``MAX_ROUND`` times before submitting.

NOTE(review): ``counter``, ``enable_submit`` and ``responses_memory`` are
module-level globals, so they are shared by everything running in this
process -- confirm that a single concurrent user is the intended deployment.
"""

# NOTE(review): import order is kept exactly as in the original file
# (`spaces` is imported before `torch`); confirm before regrouping.
import gradio as gr
import numpy as np
import random
import spaces
import torch
import re
import transformers

# Optional: keep these utilities if your pipeline depends on them
from optim_utils import optimize_prompt
from utils import (
    clean_response_gpt,
    setup_model,
    init_gpt_api,
    call_gpt_api,
    get_refine_msg,
    clean_cache,
    get_personalize_message,
    clean_refined_prompt_response_gpt,
    IMAGES,
    OPTIONS,
    T2I_MODELS,
    INSTRUCTION,
    IMAGE_OPTIONS,
    PROMPTS,
    SCENARIOS,  # some may be unused after simplification
)

# =========================
# Constants / Defaults
# =========================
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"
default_t2i_model = "black-forest-labs/FLUX.1-schnell"  # "black-forest-labs/FLUX.1-dev"
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
NUM_IMAGES = 4
MAX_ROUND = 5

# Ratings that unlock the Submit button before MAX_ROUND is reached.
SATISFIED_OPTIONS = ("Satisfied", "Very Satisfied")

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

clean_cache()
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
llm_pipe = None
torch.cuda.empty_cache()
inverted_prompt = ""

METHOD = "Experimental"  # keep ONLY experimental

# Global states for a single-task, single-method flow
counter = 1
enable_submit = False
responses_memory = {METHOD: {}}

example_data = [
    [
        "A futuristic city skyline at sunset",
        IMAGES["Tourist promotion"]["ours"],
    ],
    [
        "A fantasy castle in the clouds",
        IMAGES["Fictional character generation"]["ours"],
    ],
    [
        "A robot painting a portrait in a studio",
        IMAGES["Interior Design"]["ours"],
    ],
]
print(example_data)


# =========================
# Image Generation Helpers
# =========================
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for ``prompt`` with the module-level pipeline.

    When ``randomize_seed`` is true the ``seed`` argument is replaced by a
    random value in ``[0, MAX_SEED]``.  Returns the first image produced by
    ``selected_pipe``.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)

    with torch.no_grad():
        image = selected_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

    return image


def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask the GPT helper for ``num_prompts`` refinements of ``prompt``.

    Returns whatever ``clean_response_gpt`` extracts from the raw reply
    (presumably a list of prompt strings -- confirm against ``utils``).
    """
    seed = random.randint(0, MAX_SEED)
    client = init_gpt_api()
    messages = get_refine_msg(prompt, num_prompts)
    outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens, temperature, top_p)
    prompt_list = clean_response_gpt(outputs)
    return prompt_list


def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Ask the GPT helper to personalize ``prompt`` from history and feedback.

    Returns the raw output of ``call_gpt_api``.
    """
    seed = random.randint(0, MAX_SEED)
    client = init_gpt_api()
    print(like_image, dislike_image)
    messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
    outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
    return outputs


# =========================
# UI Helper Functions
# =========================
def reset_gallery():
    """Return an empty gallery value."""
    return []


def display_error_message(msg, duration=5):
    """Show ``msg`` as a Gradio warning toast."""
    gr.Warning(msg, duration=duration)


def display_info_message(msg, duration=5):
    """Show ``msg`` as a Gradio info toast."""
    gr.Info(msg, duration=duration)


def check_satisfaction(sim_radio):
    """Enable Submit when the rating is satisfied, a previous round already
    unlocked it, or the round limit has been exceeded."""
    can_submit = (sim_radio in SATISFIED_OPTIONS) or enable_submit or (counter > MAX_ROUND)
    return gr.update(interactive=can_submit)


def select_image(like_radio, images_method):
    """Return the gallery image matching the selected radio option.

    ``like_radio`` is matched against the first ``NUM_IMAGES`` entries of
    ``IMAGE_OPTIONS``; the image at the same position of ``images_method`` is
    returned (first element of the gallery item).  Returns ``None`` when the
    option is unknown or the gallery has no image at that position, instead of
    raising on an empty or partially filled gallery.
    """
    options = list(IMAGE_OPTIONS[:NUM_IMAGES])
    if like_radio not in options:
        return None

    index = options.index(like_radio)
    if not images_method or index >= len(images_method):
        return None
    return images_method[index][0]


def check_evaluation(sim_radio):
    """Return True when a satisfaction rating was given; warn otherwise."""
    if not sim_radio:
        display_error_message("❌ Please fill all evaluations before changing image or submitting.")
        return False
    return True


def generate_image(prompt, like_image, dislike_image):
    """Yield a growing list of images generated from refined prompts.

    The prompt is refined with ``call_gpt_refine_prompt`` and one image is
    generated per refined prompt, up to ``NUM_IMAGES``.  The gallery list is
    yielded after every image so the UI can update progressively.

    Raises:
        gr.Error: if no refined prompts were returned.
    """
    personalized = prompt
    # Personalization is currently disabled; re-enable as a whole:
    # history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
    # feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
    # personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
    # personalized = clean_refined_prompt_response_gpt(personalized)
    # if "I'm sorry, I can't assist with" in personalized:
    #     personalized = prompt

    refined_prompts = call_gpt_refine_prompt(personalized)
    if not refined_prompts:
        # Fail loudly so the chained `.success` handler does not lock the UI
        # with an empty gallery.
        raise gr.Error("No refined prompts were returned. Please try again.")

    gallery_images = []
    # The refinement step may return fewer prompts than NUM_IMAGES; never
    # index past what was actually returned.
    for i in range(min(NUM_IMAGES, len(refined_prompts))):
        gallery_images.append(infer(refined_prompts[i]))
        yield gallery_images


def _record_response(prompt, sim_radio, like_radio, dislike_radio):
    """Store the current round's evaluation under ``responses_memory[METHOD]``."""
    responses_memory[METHOD][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, {like_radio}",
        "unsatisfied_img": f"round {counter}, {dislike_radio}",
    }


def redesign(prompt, sim_radio, like_radio, dislike_radio, current_images, history_images, like_image, dislike_image):
    """Record the current round and prepare the UI for another one.

    Returns updates for, in order: the three radios (cleared), the current
    gallery (emptied), the history gallery, the prompt-history dataset, the
    prompt box, and the Generate / Redesign / Submit buttons.  When no rating
    was given nothing is recorded and only a skip for ``submit_btn`` is
    returned.
    """
    global counter, enable_submit

    if not check_evaluation(sim_radio):
        return {submit_btn: gr.skip()}

    _record_response(prompt, sim_radio, like_radio, dislike_radio)
    # Once unlocked, Submit stays unlocked for the remaining rounds.
    enable_submit = enable_submit or sim_radio in SATISFIED_OPTIONS

    history_prompts = [[v["prompt"]] for v in responses_memory[METHOD].values()]

    # Move the images of this round into the history gallery.
    if not history_images:
        history_images = current_images
    elif current_images:
        history_images.extend(current_images)
    current_images = []

    examples_state = gr.update(samples=history_prompts, visible=True)
    prompt_state = gr.update(interactive=True)
    next_state = gr.update(visible=True, interactive=True)
    redesign_state = gr.update(interactive=counter < MAX_ROUND)
    submit_state = gr.update(interactive=counter >= MAX_ROUND or enable_submit)

    counter += 1
    return (
        None,
        None,
        None,
        current_images,
        history_images,
        examples_state,
        prompt_state,
        next_state,
        redesign_state,
        submit_state,
    )


def save_response(prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image):
    """Record the final round, reset the round state and lock the controls.

    Returns updates for, in order: the three radios (cleared), the prompt box
    and the Generate / Redesign / Submit buttons.  When no rating was given
    nothing is recorded and only a skip for ``submit_btn`` is returned.

    NOTE(review): ``responses_memory`` is neither cleared nor written anywhere
    in the visible code -- confirm whether persistence/reset is expected here.
    """
    global counter, enable_submit

    if not check_evaluation(sim_radio):
        return {submit_btn: gr.skip()}

    # Save the final round entry
    _record_response(prompt, sim_radio, like_radio, dislike_radio)

    # Reset states
    counter = 1
    enable_submit = False

    # Reset buttons
    prompt_state = gr.update(interactive=False)
    next_state = gr.update(visible=False, interactive=False)
    submit_state = gr.update(interactive=False)
    redesign_state = gr.update(interactive=False)

    display_info_message("✅ Your answer is saved!")
    return None, None, None, prompt_state, next_state, redesign_state, submit_state


# =========================
# Interface (single tab, no participant/scenario/background)
# =========================
css = (
    " #col-container { margin: 0 auto; max-width: 700px; }"
    " #col-container2 { margin: 0 auto; max-width: 1000px; }"
    " #col-container3 { margin: 0 0 auto auto; max-width: 300px; }"
    " #button-container { display: flex; justify-content: center; }"
    " #compact-row { width:100%; max-width: 1000px; margin: 0px auto; }"
    " "
)

# NOTE(review): the nesting of the layout containers below was reconstructed
# from a whitespace-collapsed source -- confirm it matches the intended design.
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# 📌 **POET**")
        instruction = gr.Markdown(" Supporting Prompting Creativity and Personalization with Automated Expansion of Text-to-Image Generation")

    with gr.Tab(""):
        with gr.Row(elem_id="compact-row"):
            prompt = gr.Textbox(
                label="🎨 Revise Prompt",
                max_lines=5,
                placeholder="Enter your prompt",
                scale=3,
                visible=True,
            )
            next_btn = gr.Button("Generate", variant="primary", scale=1)

        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                images_method = gr.Gallery(label="Images", columns=[4], rows=[1], height=400, elem_id="gallery", format="png")
            with gr.Column(elem_id="col-container3"):
                like_image = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
                dislike_image = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")

        with gr.Column(elem_id="col-container2"):
            gr.Markdown("### 📝 Evaluation")
            sim_radio = gr.Radio(
                OPTIONS,
                label="How would you rate your satisfaction with the generated images?",
                type="value",
                elem_classes=["gradio-radio"],
            )
            like_radio = gr.Radio(
                IMAGE_OPTIONS,
                label="Select your all-time favorite image (optional).",
                type="value",
                elem_classes=["gradio-radio"],
            )
            dislike_radio = gr.Radio(
                IMAGE_OPTIONS,
                label="Select your all-time least satisfactory image (optional).",
                type="value",
                elem_classes=["gradio-radio"],
            )

        with gr.Column(elem_id="col-container2"):
            example = gr.Examples([['']], prompt, label="Revised Prompt History", visible=False)
            history_images = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png")

        with gr.Row(elem_id="button-container"):
            redesign_btn = gr.Button("🎨 Redesign", variant="primary", scale=0)
            submit_btn = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)

        with gr.Column(elem_id="col-container2"):
            gr.Markdown("### 🌟 Examples")
            # Hidden image slots that receive the example images.
            ex1 = gr.Image(label="Image 1", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex2 = gr.Image(label="Image 2", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex3 = gr.Image(label="Image 3", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex4 = gr.Image(label="Image 4", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)

            gr.Examples(
                examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
                inputs=[prompt, ex1, ex2, ex3, ex4],
            )

    # =========================
    # Wiring
    # =========================
    sim_radio.change(fn=check_satisfaction, inputs=[sim_radio], outputs=[submit_btn])
    dislike_radio.select(fn=select_image, inputs=[dislike_radio, images_method], outputs=[dislike_image])
    like_radio.select(fn=select_image, inputs=[like_radio, images_method], outputs=[like_image])

    # After a successful generation, lock the prompt until the round is rated.
    next_btn.click(
        fn=generate_image,
        inputs=[prompt, like_image, dislike_image],
        outputs=[images_method],
    ).success(lambda: [gr.update(interactive=False), gr.update(interactive=False)], outputs=[next_btn, prompt])

    redesign_btn.click(
        fn=redesign,
        inputs=[prompt, sim_radio, like_radio, dislike_radio, images_method, history_images, like_image, dislike_image],
        outputs=[sim_radio, dislike_radio, like_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn, submit_btn],
    )

    submit_btn.click(
        fn=save_response,
        inputs=[prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image],
        outputs=[sim_radio, dislike_radio, like_radio, prompt, next_btn, redesign_btn, submit_btn],
    )

if __name__ == "__main__":
    demo.launch()