# Hugging Face file-viewer page residue (not part of the program):
# POET / app.py
# xh365's picture
# update instruction wording
# 4db812e
# raw
# history blame
# 25.6 kB
import gradio as gr
from gradio.themes.base import Base
import numpy as np
import random
import spaces
import torch
import re
import open_clip
from optim_utils import optimize_prompt
from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache, get_personalize_message, clean_refined_prompt_response_gpt
from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS
import spaces #[uncomment to use ZeroGPU]
import transformers
import gspread
# CLIP checkpoint identifiers: forwarded to optimize_prompt through text_params in
# invert_prompt, and used by the (commented-out) open_clip loader below.
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"
# Default model identifiers. default_t2i_model is loaded below via setup_model;
# default_llm_model is not referenced anywhere else in the visible code.
default_t2i_model = "black-forest-labs/FLUX.1-dev" # "black-forest-labs/FLUX.1-dev"
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # "meta-llama/Meta-Llama-3-8B-Instruct"
# Upper bound for the random seeds drawn with random.randint(0, MAX_SEED).
MAX_SEED = np.iinfo(np.int32).max
# Not referenced in the visible code.
MAX_IMAGE_SIZE = 1024
# Number of images generate_image produces per round.
NUM_IMAGES=4
# Use CUDA with half precision when available, otherwise CPU with float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Import-time side effects: the text-to-image pipeline is built once, when the module loads.
# NOTE(review): clean_cache comes from utils; presumably it frees cached models/memory - confirm.
clean_cache()
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
# NOTE(review): invert_prompt references clip_model and preprocess, which are only defined by
# the commented-out line below, so calling invert_prompt raises NameError while this stays disabled.
# clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
# Not referenced in the visible code.
llm_pipe = None
torch.cuda.empty_cache()
# Module-level placeholder; invert_prompt assigns a *local* of the same name (no `global`),
# so this value is never updated by the visible code.
inverted_prompt = ""
# Initial value of the read-only response1/response2 textboxes.
VERBAL_MSG = "Please briefly, verbally describe the key differences found in the image pair using a few words or sentences."
# Not referenced in the visible code.
DEFAULT_SCENARIO = "Product advertisement"
# METHODS[0] is the baseline condition, METHODS[1] the experimental one.
METHODS = ["Baseline", "Experimental"]
# Round limit used by check_satisfaction and redesign to unlock submit / lock redesign.
MAX_ROUND = 5
# intermittent memory
# NOTE(review): everything below is module-level state, i.e. shared by all sessions served
# by this process rather than kept per user.
# Round counters: counter1 is used when the active method is METHODS[0], counter2 otherwise.
counter1, counter2 = 1, 1
# participant id -> method name -> round number -> recorded answers for that round.
responses_memory = {}
# Only referenced by commented-out scenario-assignment code in set_user / save_response.
assigned_scenarios = list(SCENARIOS.keys())[:2]
# Method shown in the "Task A" tab (current_task1) and the "Task B" tab (current_task2);
# reshuffled by display_scenario.
current_task1, current_task2 = METHODS # current task 1 (tab 1)
# Set to True by save_response after a successful save from the Task A / Task B tab.
task1_success, task2_success = False, False
# Sticky "submit unlocked" flags, selected by method in the same way as the counters.
enable_submit1, enable_submit2 = False, False
########################################################################################################
# Generating images with two methods
########################################################################################################
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image for ``prompt`` with the module-level ``selected_pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, ``seed`` is replaced by a random value in
            ``[0, MAX_SEED]``.
        width: Requested image width.
        height: Requested image height.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps forwarded to the pipeline.
        progress: Gradio progress tracker (created with ``track_tqdm=True``); it is
            not used inside the body.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    chosen_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    rng = torch.Generator().manual_seed(chosen_seed)

    pipe_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "width": width,
        "height": height,
        "generator": rng,
    }
    # Inference only: gradient tracking is disabled for the pipeline call.
    with torch.no_grad():
        first_image = selected_pipe(**pipe_kwargs).images[0]
    return first_image
def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask the "gpt-4o" model for refined variants of ``prompt``.

    Args:
        prompt: Prompt to refine.
        num_prompts: Passed to ``get_refine_msg`` together with ``prompt``.
        max_tokens: Forwarded to ``call_gpt_api``.
        temperature: Forwarded to ``call_gpt_api``.
        top_p: Forwarded to ``call_gpt_api``.

    Returns:
        The output of ``clean_response_gpt`` applied to the raw API response
        (generate_image indexes it like a sequence of prompts).
    """
    request_seed = random.randint(0, MAX_SEED)
    gpt_client = init_gpt_api()
    chat_messages = get_refine_msg(prompt, num_prompts)
    raw_reply = call_gpt_api(
        chat_messages, gpt_client, "gpt-4o", request_seed, max_tokens, temperature, top_p
    )
    return clean_response_gpt(raw_reply)
# Lazily filled cache so the CLIP weights are loaded at most once per process.
_CLIP_COMPONENTS = {}


def _get_clip_components():
    """Return ``(clip_model, preprocess)``, loading them on first use.

    Module-level ``clip_model`` / ``preprocess`` objects are reused when they exist
    (i.e. when the eager loader at the top of the file is enabled); otherwise the
    pair is created with the same ``open_clip`` call and cached.
    """
    eager_model = globals().get("clip_model")
    eager_preprocess = globals().get("preprocess")
    if eager_model is not None and eager_preprocess is not None:
        return eager_model, eager_preprocess

    if not _CLIP_COMPONENTS:
        model, _, transform = open_clip.create_model_and_transforms(
            CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device
        )
        _CLIP_COMPONENTS["model"] = model
        _CLIP_COMPONENTS["preprocess"] = transform
    return _CLIP_COMPONENTS["model"], _CLIP_COMPONENTS["preprocess"]


@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2):
    """Run ``optimize_prompt`` against ``images`` and ``prompt`` and return its result.

    Previously this function referenced ``clip_model`` and ``preprocess`` although
    their module-level definition is commented out (NameError on every call), and
    it discarded the optimisation result. The CLIP pair is now loaded on demand and
    the result is returned.

    Args:
        prompt: Passed to ``optimize_prompt`` as ``target_prompts``.
        images: Passed to ``optimize_prompt`` as ``target_images``.
        prompt_len: Stored as ``text_params["prompt_len"]``.
        iter: Stored as ``text_params["iter"]`` (name kept for caller compatibility
            even though it shadows the builtin).
        lr: Stored as ``text_params["lr"]``.
        batch_size: Stored as ``text_params["batch_size"]``.

    Returns:
        Whatever ``optimize_prompt`` returns (presumably the learned prompt -
        confirm against optim_utils).
    """
    text_params = {
        "iter": iter,
        "lr": lr,
        "batch_size": batch_size,
        "prompt_len": prompt_len,
        "weight_decay": 0.1,
        "prompt_bs": 1,
        "loss_weight": 1.0,
        "print_step": 100,
        "clip_model": CLIP_MODEL,
        "clip_pretrain": PRETRAINED_CLIP,
    }
    clip, clip_preprocess = _get_clip_components()
    return optimize_prompt(
        clip, clip_preprocess, text_params, device, target_images=images, target_prompts=prompt
    )
def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Ask the "gpt-4o" model to personalise ``prompt`` from the user's history.

    Args:
        prompt: Current prompt text.
        history: Earlier prompts for this participant/method.
        feedback: Earlier satisfaction ratings matching ``history``.
        like_image: Image the user marked as satisfactory (may be ``None``).
        dislike_image: Image the user marked as unsatisfactory (may be ``None``).

    Returns:
        The raw output of ``call_gpt_api``; no cleaning is applied here.
    """
    request_seed = random.randint(0, MAX_SEED)
    gpt_client = init_gpt_api()
    chat_messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
    raw_reply = call_gpt_api(
        chat_messages,
        gpt_client,
        "gpt-4o",
        request_seed,
        max_tokens=2000,
        temperature=0.7,
        top_p=0.9,
    )
    return raw_reply
########################################################################################################
# Button-related functions
########################################################################################################
def reset_gallery():
    """Return a new empty list.

    Only referenced by commented-out event wiring in this file.
    """
    return list()
def display_error_message(msg, duration=5):
    """Show ``msg`` through ``gr.Warning``; ``duration`` is forwarded unchanged."""
    gr.Warning(message=msg, duration=duration)
def display_info_message(msg, duration=5):
    """Show ``msg`` through ``gr.Info``; ``duration`` is forwarded unchanged."""
    gr.Info(message=msg, duration=duration)
def switch_tab(active_tab):
    """Return a ``gr.Tabs`` update selecting the tab that is not ``active_tab``.

    "Task A" maps to "Task B"; any other value maps to "Task A".
    """
    other_tab = "Task B" if active_tab == "Task A" else "Task A"
    return gr.Tabs(selected=other_tab)
def check_satisfaction(sim_radio, active_tab):
    """Return an update that enables the submit button when submission is allowed.

    Submission is allowed when the current rating is "Satisfied" or
    "Very Satisfied", when submit was already unlocked for the active method, or
    when that method's round counter exceeds ``MAX_ROUND``.
    """
    method = current_task1 if active_tab == "Task A" else current_task2

    # Counter and sticky flag are tracked per method, not per tab.
    if method == METHODS[0]:
        already_unlocked, round_counter = enable_submit1, counter1
    else:
        already_unlocked, round_counter = enable_submit2, counter2

    can_submit = (
        sim_radio in ("Satisfied", "Very Satisfied")
        or already_unlocked
        or round_counter > MAX_ROUND
    )
    return gr.update(interactive=can_submit)
def check_participant(participant):
    """Return True when ``participant`` is not the empty string.

    For the empty string a warning is shown and False is returned.
    """
    if participant != "":
        return True
    display_error_message("Please fill your participant id!")
    return False
def check_evaluation(sim_radio):
    """Return True when a satisfaction rating was chosen (``sim_radio`` is truthy).

    Otherwise a warning is shown and False is returned.
    """
    if sim_radio:
        return True
    display_error_message("❌ Please fill all evaluations before change image or submit.")
    return False
def select_image(like_radio, images_method):
    """Return the gallery image that corresponds to the chosen radio option.

    The position of ``like_radio`` inside ``IMAGE_OPTIONS`` selects the gallery
    entry at the same position; the first element of that entry is returned.

    Args:
        like_radio: Selected radio value, or ``None`` when nothing is selected.
        images_method: Gallery value; entries are indexed as ``entry[0]``
            (presumably ``(image, caption)`` pairs - confirm against the Gradio
            version in use). May be ``None`` or empty.

    Returns:
        The matching image, or ``None`` when the option is unknown or the gallery
        has no image at that position. Previously an empty/``None`` gallery (e.g.
        right after ``redesign`` cleared it) raised IndexError/TypeError.
    """
    if not images_method:
        return None
    for position, option in enumerate(IMAGE_OPTIONS):
        if like_radio == option:
            if position < len(images_method):
                return images_method[position][0]
            return None
    return None
def set_user(participant):
    """Create (or reset) the in-memory response store for ``participant``.

    The store gets one empty dict per entry of ``METHODS``. Nothing is returned.
    """
    global responses_memory, assigned_scenarios
    baseline_name, experimental_name = METHODS[0], METHODS[1]
    responses_memory[participant] = {baseline_name: {}, experimental_name: {}}
    # A disabled earlier variant derived `assigned_scenarios` from the digits in the
    # participant id (first two scenarios for an even or missing number, the rest
    # otherwise) and returned the first assigned scenario.
def display_scenario(participant, choice):
    """Reset the per-scenario state and build the UI updates for ``choice``.

    The round counters, success flags and submit flags are reset, the
    participant's response store is cleared (when a participant id is given) and
    the two methods are randomly assigned to the Task A / Task B tabs.

    Args:
        participant: Participant id from the textbox.
        choice: Selected scenario key; used to index ``SCENARIOS``, ``PROMPTS``
            and ``IMAGES``.

    Returns:
        A dict mapping output components to their new values/updates.
    """
    global counter1, counter2, responses_memory, current_task1, current_task2
    global task1_success, task2_success, enable_submit1, enable_submit2

    # Reset intermittent storage whenever the scenario changes.
    counter1 = counter2 = 1
    task1_success = task2_success = False
    enable_submit1 = enable_submit2 = False

    if check_participant(participant):
        responses_memory[participant] = {METHODS[0]: {}, METHODS[1]: {}}

    # Randomise which method is shown in which tab.
    current_task1, current_task2 = random.sample(METHODS, 2)
    if current_task1 == METHODS[0]:
        key_for_tab_a, key_for_tab_b = "baseline", "ours"
    else:
        key_for_tab_a, key_for_tab_b = "ours", "baseline"
    initial_images1 = IMAGES[choice][key_for_tab_a]
    initial_images2 = IMAGES[choice][key_for_tab_b]

    res = {scenario_content: SCENARIOS.get(choice, "")}
    for prompt_box in (prompt1, prompt2):
        res[prompt_box] = gr.update(value=PROMPTS.get(choice, ""), interactive=False)
    res[images_method1] = initial_images1
    res[images_method2] = initial_images2
    for picked_image in (like_image1, dislike_image1, like_image2, dislike_image2):
        res[picked_image] = None
    for history_gallery in (history_images1, history_images2):
        res[history_gallery] = []
    for prompt_history in (example1.dataset, example2.dataset):
        res[prompt_history] = gr.update(samples=[], visible=False)
    for generate_button in (next_btn1, next_btn2):
        res[generate_button] = gr.update(interactive=False)
    for redesign_button in (redesign_btn1, redesign_btn2):
        res[redesign_button] = gr.update(interactive=True)
    for submit_button in (submit_btn1, submit_btn2):
        res[submit_button] = gr.update(interactive=False)
    return res
def generate_image(participant, scenario, prompt, active_tab, like_image, dislike_image):
    """Stream ``NUM_IMAGES`` generated images for the active task tab.

    This is a generator: after every generated image the growing list of images
    is yielded so the gallery fills progressively.

    The prompt is first personalised with GPT from the participant's earlier
    rounds. For the baseline method (``METHODS[0]``) every image uses that prompt;
    for the other method GPT proposes refined prompts and each image uses one.

    Args:
        participant: Participant id; nothing is generated when it is empty.
        scenario: Unused; kept because the event wiring passes it.
        prompt: Prompt typed by the user.
        active_tab: "Task A" or "Task B"; selects the method for this request.
        like_image: Image marked as satisfactory (may be ``None``).
        dislike_image: Image marked as unsatisfactory (may be ``None``).

    Yields:
        The list of images generated so far.
    """
    if not check_participant(participant):
        # Inside a generator a bare return ends the stream without emitting a
        # value (the former `return [], []` did the same, with a misleading shape).
        return

    method = current_task1 if active_tab == "Task A" else current_task2
    past_rounds = responses_memory[participant][method].values()
    history_prompts = [record["prompt"] for record in past_rounds]
    feedback = [record["sim_radio"] for record in past_rounds]

    personalized_prompt = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
    personalized_prompt = clean_refined_prompt_response_gpt(personalized_prompt)
    print(f"Personalized prompt: {personalized_prompt}, {type(personalized_prompt)}")
    # Fall back to the user's own prompt when GPT refused or nothing usable came back.
    if not personalized_prompt or "I'm sorry, I can't assist with" in personalized_prompt:
        print("error in gpt...")
        personalized_prompt = prompt

    if method == METHODS[0]:
        image_prompts = [personalized_prompt] * NUM_IMAGES
    else:
        # GPT may return fewer prompts than images are needed; pad with the
        # personalised prompt instead of failing with IndexError mid-stream.
        refined_prompts = list(call_gpt_refine_prompt(personalized_prompt) or [])[:NUM_IMAGES]
        missing = NUM_IMAGES - len(refined_prompts)
        image_prompts = refined_prompts + [personalized_prompt] * missing

    gallery_images = []
    for image_prompt in image_prompts:
        gallery_images.append(infer(image_prompt))
        yield gallery_images
def redesign(participant, scenario, prompt, sim_radio, like_radio, dislike_radio, current_images, history_images, active_tab):
    """Record the current round's answers and prepare the UI for the next round.

    Args:
        participant: Participant id.
        scenario: Unused; kept because the event wiring passes it.
        prompt: Prompt used in the round being closed.
        sim_radio: Satisfaction rating; required.
        like_radio: Option chosen as most satisfactory (may be ``None``).
        dislike_radio: Option chosen as least satisfactory (may be ``None``).
        current_images: Images currently shown in the gallery.
        history_images: Images accumulated from earlier rounds.
        active_tab: "Task A" or "Task B"; selects the method for this request.

    Returns:
        On success a 10-tuple matching the ``outputs`` of the redesign buttons
        (three cleared radios, cleared gallery, extended history, prompt-history
        update, prompt/next/redesign/submit updates). When validation fails, a
        dict that only skips the submit button of the active tab.
    """
    global counter1, counter2, responses_memory, current_task1, current_task2, enable_submit1, enable_submit2
    method = current_task1 if active_tab == "Task A" else current_task2

    if not (check_evaluation(sim_radio) and check_participant(participant)):
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    uses_first_slot = method == METHODS[0]
    round_no = counter1 if uses_first_slot else counter2
    submit_unlocked = enable_submit1 if uses_first_slot else enable_submit2

    method_log = responses_memory[participant][method]
    method_log[round_no] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {round_no}, {like_radio}",
        "unsatisfied_img": f"round {round_no}, {dislike_radio}",
    }

    # Once unlocked, submit stays unlocked for the rest of this task.
    submit_unlocked = bool(sim_radio in ["Satisfied", "Very Satisfied"] or submit_unlocked)
    history_prompts = [[record["prompt"]] for record in method_log.values()]

    # Move the images of the finished round into the history gallery.
    if not history_images:
        history_images = current_images
    elif current_images:
        history_images.extend(current_images)
    current_images = []

    reached_limit = round_no >= MAX_ROUND
    examples_state = gr.update(samples=history_prompts, visible=True)
    prompt_state = gr.update(interactive=True)
    next_state = gr.update(visible=True, interactive=True)
    redesign_state = gr.update(interactive=not reached_limit)
    submit_state = gr.update(interactive=bool(reached_limit or submit_unlocked))

    # update counter
    if uses_first_slot:
        counter1 += 1
        enable_submit1 = submit_unlocked
    else:
        counter2 += 1
        enable_submit2 = submit_unlocked

    return (
        None,
        None,
        None,
        current_images,
        history_images,
        examples_state,
        prompt_state,
        next_state,
        redesign_state,
        submit_state,
    )
def save_response(participant, scenario, prompt, sim_radio, like_radio, dislike_radio, active_tab):
    """Record the final round and upload all rounds of this task to Google Sheets.

    All rows are written with a single ``append_rows`` request, so a failure can
    no longer leave a partially written task that is duplicated on retry.

    Args:
        participant: Participant id.
        scenario: Scenario name stored with every row.
        prompt: Prompt used in the final round.
        sim_radio: Satisfaction rating; required.
        like_radio: Option chosen as most satisfactory (may be ``None``).
        dislike_radio: Option chosen as least satisfactory (may be ``None``).
        active_tab: "Task A" or "Task B"; selects the method for this request.

    Returns:
        On success an 8-tuple matching the ``outputs`` of the submit buttons
        (three cleared radios, prompt/next/redesign/submit updates and the tab
        switch). When validation or the upload fails, a dict that only skips the
        submit button of the active tab.
    """
    global task1_success, task2_success, counter1, counter2, enable_submit1, enable_submit2

    def skip_update():
        # Leave the UI untouched: only the active tab's submit button is addressed.
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    method = current_task1 if active_tab == "Task A" else current_task2
    if not (check_evaluation(sim_radio) and check_participant(participant)):
        return skip_update()

    counter = counter1 if method == METHODS[0] else counter2
    method_log = responses_memory[participant][method]
    method_log[counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, {like_radio}",
        "unsatisfied_img": f"round {counter}, {dislike_radio}",
    }

    try:
        rows = [
            [
                participant,
                scenario,
                f"{active_tab}, {method}",
                round_no,
                entry["prompt"],
                entry["sim_radio"],
                entry["response"],
                entry["satisfied_img"],
                entry["unsatisfied_img"],
            ]
            for round_no, entry in method_log.items()
        ]
        gc = gspread.service_account(filename='credentials.json')
        sheet = gc.open("DiverseGen-phase3").sheet1
        sheet.append_rows(rows)
    except Exception as e:
        # UI boundary: report any upload problem to the user and keep the state
        # so the submission can be retried.
        display_error_message(f"❌ Error saving response: {str(e)}")
        return skip_update()

    display_info_message("✅ Your answer is saved!")

    # reset global variables
    if method == METHODS[0]:
        counter1 = 1
        enable_submit1 = False
    else:
        counter2 = 1
        enable_submit2 = False
    if active_tab == "Task A":
        task1_success = True
    else:
        task2_success = True

    # reset buttons
    prompt_state = gr.update(interactive=False)
    next_state = gr.update(visible=False, interactive=False)
    submit_state = gr.update(interactive=False)
    redesign_state = gr.update(interactive=False)
    tabs = switch_tab(active_tab)
    return None, None, None, prompt_state, next_state, redesign_state, submit_state, tabs
########################################################################################################
# Interface
########################################################################################################
# Custom CSS passed to gr.Blocks below. The selectors match the elem_id values used
# in the layout: col-container, col-container2, col-container3, button-container
# and compact-row.
css="""
#col-container {
margin: 0 auto;
max-width: 700px;
}
#col-container2 {
margin: 0 auto;
max-width: 1000px;
}
#col-container3 {
margin: 0 0 auto auto;
max-width: 300px;
}
#button-container {
display: flex;
justify-content: center; /* Centers the buttons horizontally */
}
#compact-row {
width:100%;
max-width: 1000px;
margin: 0px auto;
}
"""
# Gradio interface: a header (participant id, scenario picker, background text and
# instructions), two task tabs with identical widgets ("Task A" uses the *1 components,
# "Task B" the *2 components), and the event wiring that connects the widgets to the
# callbacks defined above.
# NOTE(review): the leading indentation of this block is missing in this copy of the file;
# the nesting of the `with` contexts has to be restored before the file can run. The code
# lines are left untouched here because the intended nesting cannot be derived from them.
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown(" # πŸ“Œ **PAI-GEN**")
with gr.Row():
participant = gr.Textbox(
label="πŸ§‘β€πŸ’Ό Participant ID", placeholder="Please enter you participant id"
)
scenario = gr.Dropdown(
choices=list(SCENARIOS.keys()),
value=None,
label="πŸ“Œ Scenario",
# interactive=False,
)
scenario_content = gr.Textbox(
label="πŸ“– Background",
interactive=False,
)
# Name of the currently selected tab; updated by the tab `select` handlers below and
# passed to the callbacks so they can pick the method assigned to that tab.
active_tab = gr.State("Task A")
instruction = gr.Markdown(INSTRUCTION)
with gr.Tabs() as tabs:
# ---- Task A tab (components suffixed with 1) ----
with gr.TabItem("Task A", id="Task A") as task1_tab:
task1_tab.select(lambda: "Task A", outputs=[active_tab])
with gr.Row(elem_id="compact-row"):
prompt1 = gr.Textbox(
label="🎨 Revise Prompt",
max_lines=5,
placeholder="Enter your prompt",
scale=4,
visible=True,
)
# Hidden and disabled until redesign() makes it visible.
next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
with gr.Row(elem_id="compact-row"):
# Prompt history; its dataset is updated by display_scenario() and redesign().
example1 = gr.Examples([['']], prompt1, label="Revised Prompt History", visible=False)
with gr.Row(elem_id="compact-row"):
with gr.Column(elem_id="col-container"):
images_method1 = gr.Gallery(label="Images", columns=[4], rows=[1], height=200, elem_id="gallery")
history_images1 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery")
with gr.Column(elem_id="col-container3"):
# Filled by select_image() when the matching radio option is chosen.
like_image1 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', type="pil")
dislike_image1 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', type="pil")
with gr.Column(elem_id="col-container2"):
gr.Markdown("### πŸ“ Evaluation")
sim_radio1 = gr.Radio(
OPTIONS,
label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
type="value",
elem_classes=["gradio-radio"]
)
like_radio1 = gr.Radio(
IMAGE_OPTIONS,
label="Select the image that you find MOST satisfactory. You may leave this section blank if you prefer the previous images.",
type="value",
elem_classes=["gradio-radio"]
)
dislike_radio1 = gr.Radio(
IMAGE_OPTIONS,
label="Please choose the image that you find LEAST satisfactory. You may leave this section blank if you are more dislike previous images.",
type="value",
elem_classes=["gradio-radio"]
)
# Read-only textbox showing VERBAL_MSG; it is not wired to any callback below.
response1 = gr.Textbox(
label="Verbally describe key differences found in the image pair.",
max_lines=1,
interactive=False,
container=False,
value=VERBAL_MSG
)
with gr.Row(elem_id="button-container"):
redesign_btn1 = gr.Button("🎨 Redesign", variant="primary", scale=0)
submit_btn1 = gr.Button("βœ… Submit", variant="primary", interactive=False, scale=0)
# ---- Task B tab (components suffixed with 2; mirrors Task A) ----
with gr.TabItem("Task B", id="Task B") as task2_tab:
task2_tab.select(lambda: "Task B", outputs=[active_tab])
with gr.Row(elem_id="compact-row"):
prompt2 = gr.Textbox(
label="🎨 Revise Prompt",
max_lines=5,
placeholder="Enter your prompt",
scale=4,
visible=True,
)
next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
with gr.Row(elem_id="compact-row"):
example2 = gr.Examples([['']], prompt2, label="Revised Prompt History", visible=False)
with gr.Row(elem_id="compact-row"):
with gr.Column(elem_id="col-container"):
images_method2 = gr.Gallery(label="Images", columns=[4], rows=[1], height=200, elem_id="gallery")
history_images2 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery")
with gr.Column(elem_id="col-container3"):
like_image2 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', type="pil")
dislike_image2 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', type="pil")
with gr.Column(elem_id="col-container2"):
gr.Markdown("### πŸ“ Evaluation")
sim_radio2 = gr.Radio(
OPTIONS,
label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
type="value",
elem_classes=["gradio-radio"]
)
like_radio2 = gr.Radio(
IMAGE_OPTIONS,
label="Select the image that you find MOST satisfactory. You may leave this section blank if you prefer the previous images.",
type="value",
elem_classes=["gradio-radio"]
)
dislike_radio2 = gr.Radio(
IMAGE_OPTIONS,
label="Please choose the image that you find LEAST satisfactory. You may leave this section blank if you are more dislike previous images.",
type="value",
elem_classes=["gradio-radio"]
)
response2 = gr.Textbox(
label="Verbally describe key differences found in the image pair.",
max_lines=1,
interactive=False,
container=False,
value=VERBAL_MSG
)
with gr.Row(elem_id="button-container"):
redesign_btn2 = gr.Button("🎨 Redesign", variant="primary", scale=0)
submit_btn2 = gr.Button("βœ… Submit", variant="primary", interactive=False, scale=0)
########################################################################################################
# Button Function Setup
########################################################################################################
# participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
# Every change of the participant textbox (re)creates that id's response store.
participant.change(fn=set_user, inputs=[participant])
# Choosing a scenario resets the round state and repopulates both tabs. display_scenario
# returns a dict keyed by these output components.
scenario.change(display_scenario,
inputs=[participant, scenario],
outputs=[scenario_content, prompt1, prompt2, images_method1, images_method2, like_image1, dislike_image1, like_image2, dislike_image2, history_images1, history_images2, example1.dataset, example2.dataset, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])
# prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
# prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
# "Generate": generate_image is a generator that yields the growing image list into the gallery.
next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, active_tab, like_image1, dislike_image1], outputs=[images_method1])
next_btn2.click(fn=generate_image, inputs=[participant, scenario, prompt2, active_tab, like_image2, dislike_image2], outputs=[images_method2])
# Changing the satisfaction rating re-evaluates whether submit is enabled.
sim_radio1.change(fn=check_satisfaction, inputs=[sim_radio1, active_tab], outputs=[submit_btn1])
sim_radio2.change(fn=check_satisfaction, inputs=[sim_radio2, active_tab], outputs=[submit_btn2])
# Selecting a most/least satisfactory option copies the matching gallery image into the
# corresponding image component.
dislike_radio1.select(fn=select_image, inputs=[dislike_radio1, images_method1], outputs=[dislike_image1])
like_radio1.select(fn=select_image, inputs=[like_radio1, images_method1], outputs=[like_image1])
dislike_radio2.select(fn=select_image, inputs=[dislike_radio2, images_method2], outputs=[dislike_image2])
like_radio2.select(fn=select_image, inputs=[like_radio2, images_method2], outputs=[like_image2])
# "Redesign": the order of `outputs` must match the tuple returned by redesign().
redesign_btn1.click(
fn=redesign,
inputs=[participant, scenario, prompt1, sim_radio1, like_radio1, dislike_radio1, images_method1, history_images1, active_tab],
outputs=[sim_radio1, dislike_radio1, like_radio1, images_method1, history_images1, example1.dataset, prompt1, next_btn1, redesign_btn1, submit_btn1]
)
redesign_btn2.click(
fn=redesign,
inputs=[participant, scenario, prompt2, sim_radio2, like_radio2, dislike_radio2, images_method2, history_images2, active_tab],
outputs=[sim_radio2, dislike_radio2, like_radio2, images_method2, history_images2, example2.dataset, prompt2, next_btn2, redesign_btn2, submit_btn2]
)
# "Submit": the order of `outputs` must match the tuple returned by save_response(); the
# last output switches to the other tab.
submit_btn1.click(fn=save_response,
inputs=[participant, scenario, prompt1, sim_radio1, like_radio1, dislike_radio1, active_tab],
outputs=[sim_radio1, dislike_radio1, like_radio1, prompt1, next_btn1, redesign_btn1, submit_btn1, tabs])
submit_btn2.click(fn=save_response,
inputs=[participant, scenario, prompt2, sim_radio2, like_radio2, dislike_radio2, active_tab],
outputs=[sim_radio2, dislike_radio2, like_radio2, prompt2, next_btn2, redesign_btn2, submit_btn2, tabs])
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()