import gradio as gr
from gradio.themes.base import Base
import numpy as np
import random
import spaces
import torch
import re
import open_clip
from optim_utils import optimize_prompt
from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache, get_personalize_message, clean_refined_prompt_response_gpt
from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS
import spaces #[uncomment to use ZeroGPU]
import transformers
import gspread

# CLIP architecture / checkpoint names; forwarded to the prompt-inversion
# optimizer as "clip_model" / "clip_pretrain".
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"

# Model identifiers. Only default_t2i_model is consumed below (by setup_model).
default_t2i_model = "black-forest-labs/FLUX.1-dev" # "black-forest-labs/FLUX.1-dev"
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # "meta-llama/Meta-Llama-3-8B-Instruct"

# Upper bound for randomly drawn seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# Number of images generated per round.
NUM_IMAGES=4

# Use the GPU with half precision when CUDA is available, otherwise CPU/float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# NOTE(review): statement order matters here - clean_cache() runs before the
# text-to-image pipeline is built (presumably to free memory/disk; confirm in utils).
clean_cache()
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
# The CLIP model is intentionally not loaded at import time:
# clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
llm_pipe = None
torch.cuda.empty_cache()
inverted_prompt = ""

# Read-only text shown in the "response" textboxes of both task tabs.
VERBAL_MSG = "Please verbally describe why you are satisfied or unsatisfied at the generated images."
DEFAULT_SCENARIO = "Product advertisement"
METHODS = ["Method 1", "Method 2"]
# Maximum number of redesign rounds per task before submission is forced.
MAX_ROUND = 5

# Intermittent (in-process) study state. These module-level globals are shared
# by every browser session served by this process.
counter1, counter2 = 1, 1  # current round for METHODS[0] / METHODS[1]
responses_memory = {}  # participant -> method -> round -> {"prompt", "sim_radio"}
assigned_scenarios = list(SCENARIOS.keys())[:2]
current_task1, current_task2 = METHODS  # method shown in tab "Task A" / "Task B"
task1_success, task2_success = False, False

# Lazily populated cache for the CLIP model used by invert_prompt().
_clip_components = {}


########################################################################################################
# Generating images with two methods
########################################################################################################

@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for ``prompt`` with the module-level pipeline.

    When ``randomize_seed`` is true, ``seed`` is replaced by a random value in
    ``[0, MAX_SEED]``. Returns the first image produced by ``selected_pipe``.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)

    with torch.no_grad():
        image = selected_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

    return image


def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask GPT-4o for ``num_prompts`` refinements of ``prompt``.

    Returns whatever ``clean_response_gpt`` extracts from the raw response.
    """
    seed = random.randint(0, MAX_SEED)
    client = init_gpt_api()
    messages = get_refine_msg(prompt, num_prompts)
    outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens, temperature, top_p)
    prompt_list = clean_response_gpt(outputs)
    return prompt_list


def _get_clip_components():
    """Return ``(clip_model, preprocess)``, creating them on first use.

    Module-level ``clip_model`` / ``preprocess`` objects are honoured if they
    exist; otherwise the model is built once with open_clip and cached.
    """
    if "model" not in _clip_components:
        model = globals().get("clip_model")
        transform = globals().get("preprocess")
        if model is None or transform is None:
            model, _, transform = open_clip.create_model_and_transforms(
                CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device
            )
        _clip_components["model"] = model
        _clip_components["preprocess"] = transform
    return _clip_components["model"], _clip_components["preprocess"]


@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2):
    """Optimise a text prompt towards ``images`` and return the result.

    The hyper-parameters are packed into the dict expected by
    ``optimize_prompt``; its return value is passed through unchanged.
    """
    text_params = {
        "iter": iter,
        "lr": lr,
        "batch_size": batch_size,
        "prompt_len": prompt_len,
        "weight_decay": 0.1,
        "prompt_bs": 1,
        "loss_weight": 1.0,
        "print_step": 100,
        "clip_model": CLIP_MODEL,
        "clip_pretrain": PRETRAINED_CLIP,
    }
    model, transform = _get_clip_components()
    # The original bound the result to one name and returned another
    # (undefined) one; return the optimisation result itself.
    learned_prompt = optimize_prompt(
        model, transform, text_params, device, target_images=images, target_prompts=prompt
    )
    return learned_prompt


def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Ask GPT-4o to personalise ``prompt`` from the participant's history.

    Returns the raw output of ``call_gpt_api`` (callers clean it up).
    """
    seed = random.randint(0, MAX_SEED)
    client = init_gpt_api()
    messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
    outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
    return outputs


########################################################################################################
# Button-related functions
########################################################################################################

def reset_gallery():
    """Return an empty gallery value."""
    return []


def display_error_message(msg, duration=5):
    """Show ``msg`` as a Gradio warning toast for ``duration`` seconds."""
    gr.Warning(msg, duration=duration)


def display_info_message(msg, duration=5):
    """Show ``msg`` as a Gradio info toast for ``duration`` seconds."""
    gr.Info(msg, duration=duration)


def switch_tab(active_tab):
    """Return a ``gr.Tabs`` update selecting the tab that is *not* active."""
    print("switching tab")
    if active_tab == "Task A":
        return gr.Tabs(selected="Task B")
    return gr.Tabs(selected="Task A")


def check_satisfaction(sim_radio, active_tab):
    """Enable Submit (and disable Redesign) once the participant may finish.

    Submission is allowed when the rating is "Satisfied"/"Very Satisfied" or
    when the method's round counter has reached ``MAX_ROUND``.
    Returns updates for ``(submit button, redesign button)``.
    """
    method = current_task1 if active_tab == "Task A" else current_task2
    counter = counter1 if method == METHODS[0] else counter2

    fully_satisfied_option = ["Satisfied", "Very Satisfied"]  # The value to trigger submit
    enable_submit = sim_radio in fully_satisfied_option or counter >= MAX_ROUND

    return gr.update(interactive=enable_submit), gr.update(interactive=(not enable_submit))


def check_participant(participant):
    """Return True when a participant id was entered; warn and return False otherwise."""
    # Also reject None and whitespace-only ids, not just the exact empty string.
    if participant is None or not str(participant).strip():
        display_error_message("Please fill your participant id!")
        return False
    return True


def check_evaluation(sim_radio):
    """Return True when the satisfaction rating is filled in; warn otherwise."""
    if not sim_radio:
        display_error_message("❌ Please fill all evaluations before change image or submit.")
        return False
    return True


def select_image(like_radio, images_method):
    """Return the gallery image matching the chosen radio option.

    ``like_radio`` is looked up in ``IMAGE_OPTIONS``; the image at the same
    position of ``images_method`` (first element of the gallery entry) is
    returned. Returns ``None`` when nothing is selected, the option is unknown,
    or the gallery does not hold that many images.
    """
    if not images_method or like_radio not in IMAGE_OPTIONS:
        return None
    position = IMAGE_OPTIONS.index(like_radio)
    if position >= len(images_method):
        return None
    return images_method[position][0]


def set_user(participant):
    """Register ``participant`` and pick the scenarios assigned to them.

    Ids whose first number is odd get the scenarios from index 2 onwards;
    even numbers and ids without digits get the first two. Returns the first
    assigned scenario (used as the Scenario dropdown value).
    """
    global assigned_scenarios
    responses_memory[participant] = {METHODS[0]: {}, METHODS[1]: {}}

    scenario_names = list(SCENARIOS.keys())
    digits = re.findall(r'\d+', participant)
    if not digits or int(digits[0]) % 2 == 0:
        # name invalid (or even id): assign first half scenarios
        assigned_scenarios = scenario_names[:2]
    else:
        assigned_scenarios = scenario_names[2:]
    return assigned_scenarios[0]


def display_scenario(participant, choice):
    """Reset the per-scenario state and return the initial UI for ``choice``.

    The two methods are randomly assigned to the two tabs, so the initial
    "baseline"/"ours" images follow that assignment. Returns a dict keyed by
    output component.
    """
    # reset intermittent storage when scenario change
    global counter1, counter2, current_task1, current_task2, task1_success, task2_success
    task1_success, task2_success = False, False
    counter1, counter2 = 1, 1

    if check_participant(participant):
        responses_memory[participant] = {METHODS[0]: {}, METHODS[1]: {}}

    current_task1, current_task2 = random.sample(METHODS, 2)

    # Tolerate an unknown/empty scenario like the SCENARIOS/PROMPTS lookups do.
    scenario_images = IMAGES.get(choice) or {}
    baseline_images = scenario_images.get("baseline", [])
    ours_images = scenario_images.get("ours", [])
    if current_task1 == METHODS[0]:
        initial_images1, initial_images2 = baseline_images, ours_images
    else:
        initial_images1, initial_images2 = ours_images, baseline_images

    res = {
        scenario_content: SCENARIOS.get(choice, ""),
        prompt1: gr.update(value=PROMPTS.get(choice, ""), interactive=False),
        prompt2: gr.update(value=PROMPTS.get(choice, ""), interactive=False),
        images_method1: initial_images1,
        images_method2: initial_images2,
        like_image1: None,
        dislike_image1: None,
        like_image2: None,
        dislike_image2: None,
        history_images1: [],
        history_images2: [],
        next_btn1: gr.update(interactive=False),
        next_btn2: gr.update(interactive=False),
        redesign_btn1: gr.update(interactive=True),
        redesign_btn2: gr.update(interactive=True),
        submit_btn1: gr.update(interactive=False),
        submit_btn2: gr.update(interactive=False),
    }
    return res


def generate_image(participant, scenario, prompt, active_tab, like_image, dislike_image):
    """Yield the growing gallery of ``NUM_IMAGES`` images for the next round.

    The prompt is first personalised with GPT from the participant's earlier
    prompts and ratings. METHODS[0] renders the personalised prompt
    ``NUM_IMAGES`` times; the other method renders one GPT-refined variant per
    image. This is a generator so the gallery updates after every image.
    """
    if not check_participant(participant):
        # A value returned from a generator is discarded, so just stop here.
        return

    method = current_task1 if active_tab == "Task A" else current_task2

    rounds = list(responses_memory.get(participant, {}).get(method, {}).values())
    history_prompts = [entry["prompt"] for entry in rounds]
    feedback = [entry["sim_radio"] for entry in rounds]

    personalized_prompt = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
    personalized_prompt = clean_refined_prompt_response_gpt(personalized_prompt)
    print(f"Personalized prompt: {personalized_prompt}, {type(personalized_prompt)}")
    if "I'm sorry, I can't assist with" in personalized_prompt:
        print("error in gpt...")
        personalized_prompt = prompt

    if method == METHODS[0]:
        prompts = [personalized_prompt] * NUM_IMAGES
    else:
        refined_prompts = call_gpt_refine_prompt(personalized_prompt) or []
        # GPT may return fewer prompts than requested; fall back to the
        # personalised prompt instead of raising IndexError.
        prompts = [
            refined_prompts[i] if i < len(refined_prompts) else personalized_prompt
            for i in range(NUM_IMAGES)
        ]

    gallery_images = []
    for image_prompt in prompts:
        gallery_images.append(infer(image_prompt))
        yield gallery_images


def redesign(participant, scenario, prompt, sim_radio, current_images, history_images, active_tab):
    """Record the current round and unlock the prompt for another one.

    Returns updates for (rating, dislike radio, like radio, current gallery,
    history gallery, prompt history, prompt box, generate button, redesign
    button, submit button). When validation fails, nothing is changed.
    """
    global counter1, counter2
    method = current_task1 if active_tab == "Task A" else current_task2

    if not (check_evaluation(sim_radio) and check_participant(participant)):
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    if method == METHODS[0]:
        counter1 += 1
        counter = counter1
    else:
        counter2 += 1
        counter = counter2

    # The round that just finished is counter - 1 (the counter was advanced above).
    responses_memory[participant][method][counter - 1] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
    }

    history_prompts = [[entry["prompt"]] for entry in responses_memory[participant][method].values()]
    if not history_images:
        history_images = current_images
    elif current_images:
        # Build a new list rather than mutating the incoming gallery value.
        history_images = [*history_images, *current_images]
    current_images = []

    rounds_exhausted = counter >= MAX_ROUND
    examples_state = gr.update(samples=history_prompts, visible=True)
    prompt_state = gr.update(interactive=True)
    next_state = gr.update(interactive=False) if rounds_exhausted else gr.update(visible=True, interactive=True)
    redesign_state = gr.update(interactive=not rounds_exhausted)
    submit_state = gr.update(interactive=rounds_exhausted)

    return None, None, None, current_images, history_images, examples_state, prompt_state, next_state, redesign_state, submit_state


def save_response(participant, scenario, prompt, sim_radio, active_tab):
    """Persist all rounds of the active task to Google Sheets and reset the tab.

    On success the task is marked done, its counter is reset, the other tab is
    selected, and the scenario advances once both tasks are done. On a
    validation or save failure a warning is shown and the UI is left as is.
    """
    global task1_success, task2_success, counter1, counter2
    method = current_task1 if active_tab == "Task A" else current_task2
    unchanged = {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    if not (check_evaluation(sim_radio) and check_participant(participant)):
        return unchanged

    counter = counter1 if method == METHODS[0] else counter2
    responses_memory[participant][method][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
    }

    rows = [
        [participant, scenario, method, round_id, entry["prompt"], entry["sim_radio"]]
        for round_id, entry in responses_memory[participant][method].items()
    ]
    try:
        gc = gspread.service_account(filename='credentials.json')
        sheet = gc.open("DiverseGen-phase3").sheet1
        # One batched request instead of one API call per row.
        sheet.append_rows(rows)
    except Exception as e:
        # UI boundary: report any failure to the participant and keep the form.
        display_error_message(f"❌ Error saving response: {str(e)}")
        return unchanged

    display_info_message("✅ Your answer is saved!")

    # reset global variables
    if method == METHODS[0]:
        counter1 = 1
    else:
        counter2 = 1
    if active_tab == "Task A":
        task1_success = True
    else:
        task2_success = True

    # decide if change scenario
    next_scenario = assigned_scenarios[1] if task1_success and task2_success else assigned_scenarios[0]

    # update buttons
    example_state = gr.update(samples=[], visible=False)
    prompt_state = gr.update(interactive=False)
    next_state = gr.update(visible=False, interactive=False)
    submit_state = gr.update(interactive=False)
    redesign_state = gr.update(interactive=False)
    tabs = switch_tab(active_tab)

    return None, None, None, None, None, [], [], example_state, prompt_state, next_state, redesign_state, submit_state, next_scenario, tabs


########################################################################################################
# Interface
########################################################################################################

css="""
#col-container {
    margin: 0 auto;
    max-width: 700px;
}
#col-container2 {
    margin: 0 auto;
    max-width: 1000px;
}
#col-container3 {
    margin: 0 0 auto auto;
    max-width: 300px;
}
#button-container {
    display: flex;
    justify-content: center; /* Centers the buttons horizontally */
}
#compact-row {
    width:100%;
    max-width: 1000px;
    margin: 0px auto;
}
"""

# NOTE(review): the container nesting below was reconstructed from a
# whitespace-collapsed source; verify the layout against the running app.
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:

    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # 📌 **Diverse Text-to-Image Generation**")

        with gr.Row():
            participant = gr.Textbox(
                label="🧑‍💼 Participant ID", placeholder="Please enter you participant id"
            )
            scenario = gr.Dropdown(
                choices=list(SCENARIOS.keys()),
                value=None,
                label="📌 Scenario",
                interactive=False,
            )
        scenario_content = gr.Textbox(
            label="📖 Background",
            interactive=False,
        )
        # Tracks which task tab is selected; the callbacks use it to pick the method.
        active_tab = gr.State("Task A")
        instruction = gr.Markdown(INSTRUCTION)

    with gr.Tabs() as tabs:
        with gr.TabItem("Task A", id="Task A") as task1_tab:
            task1_tab.select(lambda: "Task A", outputs=[active_tab])

            with gr.Row(elem_id="compact-row"):
                prompt1 = gr.Textbox(
                    label="🎨 Revise Prompt",
                    max_lines=5,
                    placeholder="Enter your prompt",
                    scale=4,
                    visible=True,
                )
                next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)

            with gr.Row(elem_id="compact-row"):
                example1 = gr.Examples([['']], prompt1, label="Revised Prompt History", visible=False)

            with gr.Row(elem_id="compact-row"):
                with gr.Column(elem_id="col-container"):
                    images_method1 = gr.Gallery(label="Images", columns=[4], rows=[1], height=200, elem_id="gallery")
                    history_images1 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery")
                with gr.Column(elem_id="col-container3"):
                    like_image1 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', type="pil")
                    dislike_image1 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', type="pil")

            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### 📝 Evaluation")
                sim_radio1 = gr.Radio(
                    OPTIONS,
                    label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                like_radio1 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select the image that you find MOST satisfactory. You may leave this section blank if you prefer the previous images.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                dislike_radio1 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Please choose the image that you find LEAST satisfactory. You may leave this section blank if you are more dislike previous images.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                response1 = gr.Textbox(
                    label="Verbally describe key differences found in the image pair.",
                    max_lines=1,
                    interactive=False,
                    container=False,
                    value=VERBAL_MSG
                )

                with gr.Row(elem_id="button-container"):
                    redesign_btn1 = gr.Button("🎨 Redesign", variant="primary", scale=0)
                    submit_btn1 = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)

        with gr.TabItem("Task B", id="Task B") as task2_tab:
            task2_tab.select(lambda: "Task B", outputs=[active_tab])

            with gr.Row(elem_id="compact-row"):
                prompt2 = gr.Textbox(
                    label="🎨 Revise Prompt",
                    max_lines=5,
                    placeholder="Enter your prompt",
                    scale=4,
                    visible=True,
                )
                next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)

            with gr.Row(elem_id="compact-row"):
                example2 = gr.Examples([['']], prompt2, label="Revised Prompt History", visible=False)

            with gr.Row(elem_id="compact-row"):
                with gr.Column(elem_id="col-container"):
                    images_method2 = gr.Gallery(label="Images", columns=[4], rows=[1], height=200, elem_id="gallery")
                    history_images2 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery")
                with gr.Column(elem_id="col-container3"):
                    like_image2 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', type="pil")
                    dislike_image2 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', type="pil")

            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### 📝 Evaluation")
                sim_radio2 = gr.Radio(
                    OPTIONS,
                    label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                like_radio2 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select the image that you find MOST satisfactory. You may leave this section blank if you prefer the previous images.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                dislike_radio2 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Please choose the image that you find LEAST satisfactory. You may leave this section blank if you are more dislike previous images.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                response2 = gr.Textbox(
                    label="Verbally describe key differences found in the image pair.",
                    max_lines=1,
                    interactive=False,
                    container=False,
                    value=VERBAL_MSG
                )

                with gr.Row(elem_id="button-container"):
                    redesign_btn2 = gr.Button("🎨 Redesign", variant="primary", scale=0)
                    submit_btn2 = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)

    ########################################################################################################
    # Button Function Setup
    ########################################################################################################

    # Entering a participant id selects their first scenario, which in turn
    # (through scenario.change) initialises both task tabs.
    participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
    scenario.change(
        display_scenario,
        inputs=[participant, scenario],
        outputs=[
            scenario_content, prompt1, prompt2, images_method1, images_method2,
            like_image1, dislike_image1, like_image2, dislike_image2,
            history_images1, history_images2, next_btn1, next_btn2,
            redesign_btn1, redesign_btn2, submit_btn1, submit_btn2,
        ],
    )

    next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, active_tab, like_image1, dislike_image1], outputs=[images_method1])
    next_btn2.click(fn=generate_image, inputs=[participant, scenario, prompt2, active_tab, like_image2, dislike_image2], outputs=[images_method2])

    sim_radio1.change(fn=check_satisfaction, inputs=[sim_radio1, active_tab], outputs=[submit_btn1, redesign_btn1])
    sim_radio2.change(fn=check_satisfaction, inputs=[sim_radio2, active_tab], outputs=[submit_btn2, redesign_btn2])

    dislike_radio1.select(fn=select_image, inputs=[dislike_radio1, images_method1], outputs=[dislike_image1])
    like_radio1.select(fn=select_image, inputs=[like_radio1, images_method1], outputs=[like_image1])
    dislike_radio2.select(fn=select_image, inputs=[dislike_radio2, images_method2], outputs=[dislike_image2])
    like_radio2.select(fn=select_image, inputs=[like_radio2, images_method2], outputs=[like_image2])

    redesign_btn1.click(
        fn=redesign,
        inputs=[participant, scenario, prompt1, sim_radio1, images_method1, history_images1, active_tab],
        outputs=[sim_radio1, dislike_radio1, like_radio1, images_method1, history_images1, example1.dataset, prompt1, next_btn1, redesign_btn1, submit_btn1]
    )
    redesign_btn2.click(
        fn=redesign,
        inputs=[participant, scenario, prompt2, sim_radio2, images_method2, history_images2, active_tab],
        outputs=[sim_radio2, dislike_radio2, like_radio2, images_method2, history_images2, example2.dataset, prompt2, next_btn2, redesign_btn2, submit_btn2]
    )

    submit_btn1.click(
        fn=save_response,
        inputs=[participant, scenario, prompt1, sim_radio1, active_tab],
        outputs=[sim_radio1, dislike_radio1, like_radio1, like_image1, dislike_image1, images_method1, history_images1, example1.dataset, prompt1, next_btn1, redesign_btn1, submit_btn1, scenario, tabs]
    )
    submit_btn2.click(
        fn=save_response,
        inputs=[participant, scenario, prompt2, sim_radio2, active_tab],
        outputs=[sim_radio2, dislike_radio2, like_radio2, like_image2, dislike_image2, images_method2, history_images2, example2.dataset, prompt2, next_btn2, redesign_btn2, submit_btn2, scenario, tabs]
    )


if __name__ == "__main__":
    demo.launch()