import gradio as gr
from gradio.themes.base import Base
import numpy as np
import random
import spaces
import torch
import re
import open_clip
from optim_utils import optimize_prompt
from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache
from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS
import spaces #[uncomment to use ZeroGPU]
import transformers
import gspread
import asyncio
from datetime import datetime

# CLIP architecture name and pretrained-weights tag. Both are forwarded to the
# prompt-inversion parameters in invert_prompt(); CLIP_MODEL also selects the
# tokenizer in eval().
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"
# Model identifiers: the text-to-image model loaded below by setup_model() and
# the LLM loaded on first use by call_llm_refine_prompt().
default_t2i_model = "black-forest-labs/FLUX.1-dev" # "black-forest-labs/FLUX.1-dev"
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # "meta-llama/Meta-Llama-3-8B-Instruct"
# Upper bound for randomly drawn seeds (largest int32 value).
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# Number of images generated per click of a "Generate" button.
NUM_IMAGES=4

# Run on the GPU in half precision when CUDA is available, otherwise on the CPU in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Import-time side effects. Statement order matters here: the cache is cleaned
# before the text-to-image pipeline is set up, and the CUDA cache is released
# afterwards.
clean_cache()
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
# NOTE(review): the CLIP model creation below is commented out, yet
# invert_prompt() still reads the module-level names `clip_model` and
# `preprocess`; calling it in this state would raise NameError.
# clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
# The LLM pipeline is created lazily by call_llm_refine_prompt().
llm_pipe = None
torch.cuda.empty_cache()
inverted_prompt = ""

# Text shown in the (non-interactive) response textbox of each task tab.
VERBAL_MSG = "Please verbally describe key differences found in the image pair."
DEFAULT_SCENARIO = "Product advertisement"
METHODS = ["Method 1", "Method 2"]
# Round count at which the callbacks force submission instead of another redesign.
MAX_ROUND = 5

# intermittent memory
# NOTE: these are process-wide globals, so all connected browser sessions share them.
counter1, counter2 = 1, 1  # current round for METHODS[0] / METHODS[1]
responses_memory = {}  # participant -> method -> round -> recorded answers
assigned_scenarios = list(SCENARIOS.keys())[:2]
current_task1, current_task2 = METHODS  # method shown in tab "Task A" / "Task B"
task1_success, task2_success = False, False


########################################################################################################
# Generating images with two methods
########################################################################################################
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image for ``prompt`` with the module-level ``selected_pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed for the torch generator; ignored when ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Requested image width, forwarded to the pipeline.
        height: Requested image height, forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        num_inference_steps: Forwarded to the pipeline.
        progress: Not referenced in the body; presumably lets Gradio mirror the
            pipeline's tqdm progress bar -- confirm against the Gradio docs.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)

    with torch.no_grad():
        image = selected_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

    return image


async def infer_async(prompt):
    """Coroutine wrapper around ``infer``; the call itself is blocking."""
    return infer(prompt)


# generate a batch of images
async def generate_batch(prompts):
    """Generate one image per prompt and return them in input order.

    ``infer_async`` never awaits, so each coroutine runs to completion before
    the next one starts: ``gather`` only collects the results here, the images
    are NOT produced in parallel.
    """
    tasks = [infer_async(p) for p in prompts]
    images = await asyncio.gather(*tasks)
    return images


@spaces.GPU
def call_llm_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask the local LLM for refined variants of ``prompt``.

    The transformers pipeline is created on first use and cached in the
    module-level ``llm_pipe``.

    Args:
        prompt: The user's prompt to refine.
        num_prompts: Number of variants requested via ``get_refine_msg``.
        max_tokens: Passed to the pipeline as ``max_new_tokens``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Returns:
        Whatever ``clean_response_gpt`` extracts from the last generated message.
    """
    global llm_pipe
    if llm_pipe is None:
        # Only announce loading when the model is actually being loaded.
        print(f"loading {default_llm_model}")
        llm_pipe = transformers.pipeline(
            "text-generation",
            model=default_llm_model,
            model_kwargs={"torch_dtype": torch_dtype},
            device_map="auto",
        )

    # Fixed: the original passed the undefined name `prmpt` (NameError on every call).
    messages = get_refine_msg(prompt, num_prompts)
    terminators = [
        llm_pipe.tokenizer.eos_token_id,
        llm_pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    outputs = llm_pipe(
        messages,
        max_new_tokens=max_tokens,
        eos_token_id=terminators,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # The pipeline returns the whole conversation; the last message is the model's reply.
    prompt_list = clean_response_gpt(outputs[0]["generated_text"][-1]["content"])
    return prompt_list


def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask the GPT API ("gpt-4o") for refined variants of ``prompt``.

    A fresh random seed in ``[0, MAX_SEED]`` is drawn for every call.

    Returns:
        Whatever ``clean_response_gpt`` extracts from the API response.
    """
    seed = random.randint(0, MAX_SEED)
    client = init_gpt_api()
    messages = get_refine_msg(prompt, num_prompts)
    outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens, temperature, top_p)
    prompt_list = clean_response_gpt(outputs)
    return prompt_list


def refine_prompt(gallery_state, prompt):
    """Return refined variants of ``prompt`` from the GPT backend.

    ``gallery_state`` is accepted for interface compatibility but is not used.
    """
    modified_prompts = call_gpt_refine_prompt(prompt)
    return modified_prompts
    # eval(prompt, inverted_prompt, gallery_state, clip_model, preprocess)


@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2):
    """Run ``optimize_prompt`` against ``images`` and return its result.

    NOTE(review): this reads the module-level names ``clip_model`` and
    ``preprocess``, whose creation is commented out in the setup section of
    this file; unless they are defined elsewhere, calling this function raises
    NameError.

    Args:
        prompt: Passed to ``optimize_prompt`` as ``target_prompts``.
        images: Passed to ``optimize_prompt`` as ``target_images``.
        prompt_len: Stored under ``"prompt_len"`` in the optimizer parameters.
        iter: Stored under ``"iter"``. The name shadows the builtin but is part
            of the public signature, so it is kept.
        lr: Stored under ``"lr"``.
        batch_size: Stored under ``"batch_size"``.

    Returns:
        The value returned by ``optimize_prompt`` (previously computed and discarded).
    """
    text_params = {
        "iter": iter,
        "lr": lr,
        "batch_size": batch_size,
        "prompt_len": prompt_len,
        "weight_decay": 0.1,
        "prompt_bs": 1,
        "loss_weight": 1.0,
        "print_step": 100,
        "clip_model": CLIP_MODEL,
        "clip_pretrain": PRETRAINED_CLIP,
    }
    inverted_prompt = optimize_prompt(
        clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt
    )
    # eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
    return inverted_prompt


def eval(prompt, optimized_prompt, optimized_images, clip_model, preprocess):
    """Compare two prompts against a set of images using CLIP features.

    Image and text features are L2-normalised, so each matrix product below is
    a cosine similarity.

    NOTE: the name shadows the builtin ``eval``; it is kept because it is this
    function's public name.

    Args:
        prompt: The original prompt.
        optimized_prompt: The prompt to compare against ``prompt``.
        optimized_images: Images, each passed through ``preprocess``.
        clip_model: Model providing ``encode_image`` / ``encode_text``.
        preprocess: Callable turning one image into a tensor.

    Returns:
        ``(similarity, similarity_optimized)`` as NumPy arrays: the similarity
        of ``prompt`` and of ``optimized_prompt`` to every image (previously
        computed and discarded).
    """
    torch.cuda.empty_cache()
    tokenizer = open_clip.get_tokenizer(CLIP_MODEL)

    images = [preprocess(i).unsqueeze(0) for i in optimized_images]
    images = torch.concatenate(images).to(device)

    with torch.no_grad():
        image_feat = clip_model.encode_image(images)
        text_feat = clip_model.encode_text(tokenizer([prompt]).to(device))
        optimized_text_feat = clip_model.encode_text(tokenizer([optimized_prompt]).to(device))

        image_feat /= image_feat.norm(dim=-1, keepdim=True)
        text_feat /= text_feat.norm(dim=-1, keepdim=True)
        optimized_text_feat /= optimized_text_feat.norm(dim=-1, keepdim=True)

    image_feat_np = image_feat.cpu().numpy()
    similarity = text_feat.cpu().numpy() @ image_feat_np.T
    similarity_optimized = optimized_text_feat.cpu().numpy() @ image_feat_np.T
    return similarity, similarity_optimized
########################################################################################################
# Button-related functions
########################################################################################################
def reset_gallery():
    """Return a fresh empty list, used to clear a gallery state."""
    return []


def display_error_message(msg, duration=5):
    """Show ``msg`` as a Gradio warning toast for ``duration`` seconds."""
    gr.Warning(msg, duration=duration)


def display_info_message(msg, duration=5):
    """Show ``msg`` as a Gradio info toast for ``duration`` seconds."""
    gr.Info(msg, duration=duration)


def switch_tab(active_tab):
    """Return a ``gr.Tabs`` update selecting the tab that is not ``active_tab``."""
    print("switching tab")
    if active_tab == "Task A":
        return gr.Tabs(selected="Task B")
    return gr.Tabs(selected="Task A")


def set_user(participant):
    """Register ``participant`` and pick the scenarios assigned to them.

    The first run of digits in the id decides the assignment: no digits or an
    even number selects the first two scenarios, an odd number selects the
    remaining ones.

    Fixed: ``assigned_scenarios`` is now declared global. The original bound a
    local variable, so the module-level list read by ``save_response`` never
    changed and odd-numbered participants were advanced through the wrong
    scenarios.

    Returns:
        The first assigned scenario name (the value for the scenario dropdown).
    """
    global responses_memory, assigned_scenarios
    responses_memory[participant] = {METHODS[0]: {}, METHODS[1]: {}}

    digit_runs = re.findall(r'\d+', participant)
    scenario_names = list(SCENARIOS.keys())
    if not digit_runs or int(digit_runs[0]) % 2 == 0:
        # name invalid or even id: assign first half scenarios
        assigned_scenarios = scenario_names[:2]
    else:
        assigned_scenarios = scenario_names[2:]
    return assigned_scenarios[0]


def display_scenario(participant, choice):
    """Reset per-scenario state and build the UI update for scenario ``choice``.

    The two methods are randomly assigned to the two task tabs, and each tab
    is pre-filled with the stored images of its method
    (``IMAGES[choice]["baseline"]`` for ``METHODS[0]``, ``["ours"]`` otherwise).

    Returns:
        A dict keyed by the Gradio components created in the interface section
        of this file, mapping each one to its new value or update.
    """
    # reset intermittent storage when scenario change
    global counter1, counter2, responses_memory, current_task1, current_task2, task1_success, task2_success
    task1_success, task2_success = False, False
    counter1, counter2 = 1, 1
    if check_participant(participant):
        responses_memory[participant] = {METHODS[0]: {}, METHODS[1]: {}}

    [current_task1, current_task2] = random.sample(METHODS, 2)
    if current_task1 == METHODS[0]:
        initial_images1 = IMAGES[choice]["baseline"]
        initial_images2 = IMAGES[choice]["ours"]
    else:
        initial_images1 = IMAGES[choice]["ours"]
        initial_images2 = IMAGES[choice]["baseline"]

    res = {
        scenario_content: SCENARIOS.get(choice, ""),
        prompt: PROMPTS.get(choice, ""),
        prompt1: "",
        prompt2: "",
        images_method1: initial_images1,
        images_method2: initial_images2,
        gallery_state1: initial_images1,
        gallery_state2: initial_images2,
        sim_radio1: None,
        sim_radio2: None,
        response1: VERBAL_MSG,
        response2: VERBAL_MSG,
        next_btn1: gr.update(interactive=False),
        next_btn2: gr.update(interactive=False),
        redesign_btn1: gr.update(interactive=True),
        redesign_btn2: gr.update(interactive=True),
        submit_btn1: gr.update(interactive=False),
        submit_btn2: gr.update(interactive=False),
    }
    return res


def generate_image(participant, scenario, prompt, gallery_state, active_tab):
    """Generate up to ``NUM_IMAGES`` images, yielding the gallery after each one.

    With ``METHODS[0]`` every image uses ``prompt`` unchanged; otherwise each
    image uses one of the refined prompts returned by ``refine_prompt``.
    ``gallery_state`` is extended in place.

    Fixed: the refined prompts are no longer indexed blindly. If the refiner
    returns fewer than ``NUM_IMAGES`` prompts, only that many images are
    produced instead of raising IndexError.
    """
    if not check_participant(participant):
        # This function is a generator, so a returned value would be ignored anyway.
        return

    method = current_task1 if active_tab == "Task A" else current_task2
    if method == METHODS[0]:
        prompts = [prompt] * NUM_IMAGES
    else:
        prompts = refine_prompt(gallery_state, prompt)[:NUM_IMAGES]
        if not prompts:
            display_error_message("❌ Could not refine the prompt, please try again.")

    for current_prompt in prompts:
        gallery_state.append(infer(current_prompt))
        yield gallery_state


def check_satisfaction(sim_radio, active_tab):
    """Decide whether the Submit or the Redesign button should be enabled.

    Submit is enabled when the rating is "Satisfied" or "Very Satisfied", or
    when the current method's round counter has reached ``MAX_ROUND``.

    Returns:
        ``(submit_update, redesign_update)``; exactly one of them is interactive.
    """
    method = current_task1 if active_tab == "Task A" else current_task2
    counter = counter1 if method == METHODS[0] else counter2

    fully_satisfied_option = ["Satisfied", "Very Satisfied"]  # The value to trigger submit
    enable_submit = sim_radio in fully_satisfied_option or counter >= MAX_ROUND
    return gr.update(interactive=enable_submit), gr.update(interactive=(not enable_submit))


def check_participant(participant):
    """Return True when a participant id was entered; otherwise warn and return False."""
    if participant == "":
        display_error_message("Please fill your participant id!")
        return False
    return True


def check_evaluation(sim_radio, response):
    """Return True when a rating was selected; otherwise warn and return False.

    ``response`` is accepted for interface compatibility but is not checked.
    """
    if not sim_radio:
        display_error_message("❌ Please fill all evaluations before change image or submit.")
        return False
    return True


def select_dislike(like_radio, images_method):
    """Return the gallery image matching the selected radio option, or None.

    Only the first four entries of ``IMAGE_OPTIONS`` are considered, as in the
    original if/elif chain.

    Fixed: returns None instead of raising IndexError when the gallery holds
    fewer images than the selected position (e.g. before anything has been
    generated).
    """
    for position in range(4):
        if like_radio == IMAGE_OPTIONS[position]:
            if images_method and position < len(images_method):
                return images_method[position]
            return None
    return None


def redesign(participant, scenario, prompt, sim_radio, response, images_method, active_tab):
    """Record the current round's answers and open the next redesign round.

    The current method's round counter is incremented and the answers are
    stored under the previous round number. Once the counter reaches
    ``MAX_ROUND``, Generate and Redesign are disabled and Submit is enabled.

    Returns:
        On success a 7-tuple matching the wired outputs (gallery state, rating,
        response text, prompt box, Generate, Redesign, Submit). When validation
        fails, a one-entry dict that leaves the active tab's submit button
        untouched.
    """
    global counter1, counter2
    method = current_task1 if active_tab == "Task A" else current_task2

    if not (check_evaluation(sim_radio, response) and check_participant(participant)):
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    if method == METHODS[0]:
        counter1 += 1
        counter = counter1
    else:
        counter2 += 1
        counter = counter2

    # The answers belong to the round that just finished, hence counter - 1.
    responses_memory[participant][method][counter - 1] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": response,
    }

    last_round = counter >= MAX_ROUND
    prompt_state = gr.update(visible=True)
    next_state = gr.update(interactive=False) if last_round else gr.update(visible=True, interactive=True)
    redesign_state = gr.update(interactive=not last_round)
    submit_state = gr.update(interactive=last_round)
    return [], None, VERBAL_MSG, prompt_state, next_state, redesign_state, submit_state


def show_message(selected_option):
    """Return the redesign hint when an option is selected, else an empty string."""
    if selected_option:
        return "Click \"Redesign\" and revise your prompt to create images that may more closely match your expectations."
    return ""


def save_response(participant, scenario, prompt, sim_radio, response, images_method, active_tab):
    """Store the final round and append all rounds of the current method to the sheet.

    On success the method's counter is reset, the active task is marked done,
    the other tab is selected, and the scenario advances to
    ``assigned_scenarios[1]`` once both tasks are done.

    Fixed: the success tuple now follows the order of the wired outputs
    (``..., sim_radio, prompt box, response box, ...``). The original returned
    ``VERBAL_MSG`` and the prompt-box update swapped, which wrote the message
    into the prompt textbox and hid the response textbox.

    Returns:
        On success a 10-tuple matching the wired outputs. On validation failure
        or a save error, a one-entry dict that leaves the active tab's submit
        button untouched.
    """
    global counter1, counter2, task1_success, task2_success
    method = current_task1 if active_tab == "Task A" else current_task2
    skip_update = {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    if not (check_evaluation(sim_radio, response) and check_participant(participant)):
        return skip_update

    counter = counter1 if method == METHODS[0] else counter2
    # image_paths = [save_image(img, "method", i) for i, img in enumerate(images_method)]
    responses_memory[participant][method][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": response,
    }

    try:
        gc = gspread.service_account(filename='credentials.json')
        sheet = gc.open("DiverseGen-phase3").sheet1
        for i, entry in responses_memory[participant][method].items():
            sheet.append_row([participant, scenario, method, i, entry["prompt"], entry["sim_radio"], entry["response"]])
    except Exception as e:
        # UI boundary: report the failure to the participant instead of crashing the callback.
        display_error_message(f"❌ Error saving response: {str(e)}")
        return skip_update

    display_info_message("✅ Your answer is saved!")

    # reset counter and update success indicator
    if method == METHODS[0]:
        counter1 = 1
    else:
        counter2 = 1
    if active_tab == "Task A":
        task1_success = True
    else:
        task2_success = True

    prompt_state = gr.update(visible=False)
    next_state = gr.update(visible=False, interactive=False)
    redesign_state = gr.update(interactive=False)
    submit_state = gr.update(interactive=False)
    tabs = switch_tab(active_tab)
    next_scenario = assigned_scenarios[1] if task1_success and task2_success else assigned_scenarios[0]
    return [], [], None, prompt_state, VERBAL_MSG, next_state, redesign_state, submit_state, tabs, next_scenario
########################################################################################################
# Interface
########################################################################################################
# NOTE(review): the original indentation of this section was lost, so the
# nesting of the layout containers below is reconstructed -- confirm it against
# the rendered page. The order in which components are created is preserved.

css="""
#col-container {
    margin: 0 auto;
    max-width: 700px;
}
#col-container2 {
    margin: 0 auto;
    max-width: 1000px;
}
#col-container3 {
    margin: 0 auto;
    max-width: 300px;
}
#button-container {
    display: flex;
    justify-content: center; /* Centers the buttons horizontally */
}
#compact-row {
    width:100%;
    max-width: 1000px;
    margin: 0px auto;
}
"""

with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
    # ---- Header: participant id, scenario (filled in by set_user) and its description ----
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # 📌 **Diverse Text-to-Image Generation**")

        with gr.Row():
            participant = gr.Textbox(
                label="🧑‍💼 Participant ID", placeholder="Please enter you participant id"
            )
            scenario = gr.Dropdown(
                choices=list(SCENARIOS.keys()),
                # value=DEFAULT_SCENARIO,
                value=None,
                label="📌 Scenario",
                interactive=False,
            )
        scenario_content = gr.Textbox(
            label="📖 Background",
            interactive=False,
            # value=SCENARIOS[DEFAULT_SCENARIO]
        )
        prompt = gr.Textbox(
            label="🎨 Prompt",
            max_lines=1,
            # value=PROMPTS[DEFAULT_SCENARIO],
            interactive=False
        )
        # Tracks which task tab is currently selected; updated by the tab select events below.
        active_tab = gr.State("Task A")
        instruction = gr.Markdown(INSTRUCTION)

    with gr.Tabs() as tabs:
        # ---------------------------------- Task A ----------------------------------
        with gr.TabItem("Task A", id="Task A") as task1_tab:
            task1_tab.select(lambda: "Task A", outputs=[active_tab])

            with gr.Column(elem_id="col-container"):
                # gr.Markdown("### Step 2: This is the prompt to generate images, you may modify the prompt after first round evaluation")
                with gr.Row():
                    # Hidden until a redesign round makes it visible.
                    prompt1 = gr.Textbox(
                        label="🎨 Revise Prompt",
                        max_lines=1,
                        placeholder="Enter your prompt",
                        # value=PROMPTS[DEFAULT_SCENARIO],
                        scale=4,
                        visible=False
                    )
                    next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)

            with gr.Row(elem_id="compact-row"):
                with gr.Column(elem_id="col-container"):
                    gallery_state1 = gr.State([])
                    images_method1 = gr.Gallery(show_label=False, columns=[4], rows=[1], height=420, elem_id="gallery")
                with gr.Column(elem_id="col-container3"):
                    like_image1 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload')
                    dislike_image1 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload')

            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### 📝 Evaluation")
                sim_radio1 = gr.Radio(
                    OPTIONS,
                    label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                like_radio1 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select the image you are most satisfied.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                dislike_radio1 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select the image you are most unsatisfied.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                response1 = gr.Textbox(
                    label="Verbally describe key differences found in the image pair.",
                    max_lines=1,
                    interactive=False,
                    container=False,
                    value=VERBAL_MSG
                )
                with gr.Row(elem_id="button-container"):
                    redesign_btn1 = gr.Button("🎨 Redesign", variant="primary", scale=0)
                    submit_btn1 = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)

        # ---------------------------------- Task B ----------------------------------
        with gr.TabItem("Task B", id="Task B") as task2_tab:
            task2_tab.select(lambda: "Task B", outputs=[active_tab])

            with gr.Column(elem_id="col-container"):
                # gr.Markdown("### Step 2: This is the prompt to generate images, you may modify the prompt after first round evaluation")
                with gr.Row():
                    # Hidden until a redesign round makes it visible.
                    prompt2 = gr.Textbox(
                        label="🎨 Revise Prompt",
                        max_lines=1,
                        placeholder="Enter your prompt",
                        # value=PROMPTS[DEFAULT_SCENARIO],
                        scale=4,
                        visible=False
                    )
                    next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)

            with gr.Row(elem_id="compact-row"):
                with gr.Column(elem_id="col-container"):
                    # Fixed: start empty like gallery_state1. The original seeded this state with
                    # IMAGES[DEFAULT_SCENARIO]["ours"] although no scenario is selected at start-up
                    # and the gallery itself is empty; display_scenario() fills both states once a
                    # scenario is chosen.
                    gallery_state2 = gr.State([])
                    images_method2 = gr.Gallery(height=420, show_label=False, columns=[4], rows=[1], elem_id="gallery")
                with gr.Column(elem_id="col-container3"):
                    like_image2 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload')
                    dislike_image2 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload')

            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### 📝 Evaluation")
                sim_radio2 = gr.Radio(
                    OPTIONS,
                    label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                like_radio2 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select the image you are most satisfied.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                dislike_radio2 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select the image you are most unsatisfied.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                response2 = gr.Textbox(
                    label="Verbally describe key differences found in the image pair.",
                    max_lines=1,
                    interactive=False,
                    container=False,
                    value=VERBAL_MSG
                )
                with gr.Row(elem_id="button-container"):
                    redesign_btn2 = gr.Button("🎨 Redesign", variant="primary", scale=0)
                    submit_btn2 = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)

    ########################################################################################################
    # Button Function Setup
    ########################################################################################################
    # Typing a participant id selects their first scenario; the scenario change
    # then resets both task tabs.
    participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
    scenario.change(display_scenario, inputs=[participant, scenario], outputs=[scenario_content, prompt, prompt1, prompt2, images_method1, images_method2, gallery_state1, gallery_state2, sim_radio1, sim_radio2, response1, response2, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])

    # Editing a revised prompt clears that tab's gallery state.
    prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
    prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
    next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, gallery_state1, active_tab], outputs=[images_method1])
    next_btn2.click(fn=generate_image, inputs=[participant, scenario, prompt2, gallery_state2, active_tab], outputs=[images_method2])

    # The rating decides which of Submit / Redesign is enabled.
    sim_radio1.change(fn=check_satisfaction, inputs=[sim_radio1, active_tab], outputs=[submit_btn1, redesign_btn1])
    sim_radio2.change(fn=check_satisfaction, inputs=[sim_radio2, active_tab], outputs=[submit_btn2, redesign_btn2])
    # select_dislike serves both the "most satisfied" and "most unsatisfied" radios.
    dislike_radio1.select(fn=select_dislike, inputs=[dislike_radio1, gallery_state1], outputs=[dislike_image1])
    like_radio1.select(fn=select_dislike, inputs=[like_radio1, gallery_state1], outputs=[like_image1])
    dislike_radio2.select(fn=select_dislike, inputs=[dislike_radio2, gallery_state2], outputs=[dislike_image2])
    like_radio2.select(fn=select_dislike, inputs=[like_radio2, gallery_state2], outputs=[like_image2])

    redesign_btn1.click(
        fn=redesign,
        inputs=[participant, scenario, prompt1, sim_radio1, response1, images_method1, active_tab],
        outputs=[gallery_state1, sim_radio1, response1, prompt1, next_btn1, redesign_btn1, submit_btn1]
    )
    redesign_btn2.click(
        fn=redesign,
        inputs=[participant, scenario, prompt2, sim_radio2, response2, images_method2, active_tab],
        outputs=[gallery_state2, sim_radio2, response2, prompt2, next_btn2, redesign_btn2, submit_btn2]
    )
    submit_btn1.click(fn=save_response,
                      inputs=[participant, scenario, prompt1, sim_radio1, response1, images_method1, active_tab],
                      outputs=[images_method1, gallery_state1, sim_radio1, prompt1, response1, next_btn1, redesign_btn1, submit_btn1, tabs, scenario])
    submit_btn2.click(fn=save_response,
                      inputs=[participant, scenario, prompt2, sim_radio2, response2, images_method2, active_tab],
                      outputs=[images_method2, gallery_state2, sim_radio2, prompt2, response2, next_btn2, redesign_btn2, submit_btn2, tabs, scenario])

if __name__ == "__main__":
    demo.launch()