# POET / app.py
# Author: xh365 · Commit: 1a23e90 ("update policy") · Size: 24.1 kB
import gradio as gr
from gradio.themes.base import Base
import numpy as np
import random
import spaces
import torch
import re
import open_clip
from optim_utils import optimize_prompt
from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache, get_personalize_message, clean_refined_prompt_response_gpt
from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS
import spaces #[uncomment to use ZeroGPU]
import transformers
import gspread
# OpenCLIP architecture and pretrained-weights tag; only used by the prompt-inversion path.
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"
# Text-to-image checkpoint loaded into `selected_pipe` below.
default_t2i_model = "black-forest-labs/FLUX.1-dev"
# NOTE(review): not referenced elsewhere in this file (prompt refinement goes through call_gpt_api / "gpt-4o").
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # alternative: "meta-llama/Meta-Llama-3-8B-Instruct"
# Upper bound for randomly drawn generation / API seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max
# NOTE(review): not referenced elsewhere in this file.
MAX_IMAGE_SIZE = 1024
# Number of images produced per generation round (one gallery row).
NUM_IMAGES=4
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only when a GPU is present.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Order matters: clean_cache() (from utils; presumably frees cached memory — confirm) runs
# before the text-to-image pipeline is created.
clean_cache()
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
# OpenCLIP is not loaded at import time; invert_prompt depends on these names.
# clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
# NOTE(review): llm_pipe is not referenced elsewhere in this file.
llm_pipe = None
torch.cuda.empty_cache()
# NOTE(review): invert_prompt assigns a *local* of the same name, so this global is never updated.
inverted_prompt = ""
VERBAL_MSG = "Please verbally describe why you are satisfied or unsatisfied at the generated images."
# NOTE(review): not referenced elsewhere in this file.
DEFAULT_SCENARIO = "Product advertisement"
# Method labels. METHODS[0] is the baseline: display_scenario maps it to IMAGES[...]["baseline"], and
# generate_image renders the personalized prompt directly for it. METHODS[1] uses GPT-refined prompts.
METHODS = ["Method 1", "Method 2"]
# Round at which redesign is disabled and submit is forced on.
MAX_ROUND = 5
# ---- Study state ----
# NOTE(review): these are module-level globals, so they are shared by every browser session served
# by this process; concurrent participants would overwrite each other's counters and task assignment.
# 1-based round numbers: counter1 tracks METHODS[0], counter2 tracks METHODS[1].
counter1, counter2 = 1, 1
# {participant: {method_label: {round_number: {"prompt": str, "sim_radio": rating}}}}
responses_memory = {}
# The two scenario names assigned to the current participant (see set_user).
assigned_scenarios = list(SCENARIOS.keys())[:2]
# Method label backing the "Task A" tab (current_task1) and the "Task B" tab (current_task2);
# shuffled in display_scenario.
current_task1, current_task2 = METHODS
# Whether Task A / Task B has been submitted for the current scenario.
task1_success, task2_success = False, False
########################################################################################################
# Generating images with two methods
########################################################################################################
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Render one image for *prompt* with the globally loaded ``selected_pipe``.

    If ``randomize_seed`` is true, the supplied ``seed`` is ignored and a fresh
    one is drawn from ``[0, MAX_SEED]``. ``progress`` is never read here; with
    ``track_tqdm=True`` it presumably lets Gradio mirror the pipeline's tqdm
    bar (confirm in the Gradio docs).

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    effective_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    rng = torch.Generator().manual_seed(effective_seed)
    pipeline_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "width": width,
        "height": height,
        "generator": rng,
    }
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = selected_pipe(**pipeline_kwargs)
    return output.images[0]
def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask GPT-4o for ``num_prompts`` refined variants of *prompt*.

    A new random seed is drawn for every call. The raw API output is passed
    through ``clean_response_gpt`` and its result is returned.
    """
    request_seed = random.randint(0, MAX_SEED)
    api_client = init_gpt_api()
    raw_output = call_gpt_api(
        get_refine_msg(prompt, num_prompts),
        api_client,
        "gpt-4o",
        request_seed,
        max_tokens,
        temperature,
        top_p,
    )
    return clean_response_gpt(raw_output)
# Cached (clip_model, preprocess) pair for invert_prompt. The module-level OpenCLIP
# load is commented out (presumably to save memory), so the original body
# referenced names that were never defined.
_clip_components = None


def _get_clip_components():
    """Create the OpenCLIP model and preprocessing transform on first use, then reuse them."""
    global _clip_components
    if _clip_components is None:
        model, _, transform = open_clip.create_model_and_transforms(
            CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device
        )
        _clip_components = (model, transform)
    return _clip_components


@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2):
    """Run ``optim_utils.optimize_prompt`` against *images* and *prompt* using OpenCLIP.

    Args:
        prompt: Passed to ``optimize_prompt`` as ``target_prompts``.
        images: Passed to ``optimize_prompt`` as ``target_images``.
        prompt_len, iter, lr, batch_size: Copied into the optimizer's
            ``text_params``. (``iter`` shadows the builtin; the name is kept
            for backward compatibility.)

    Returns:
        Whatever ``optimize_prompt`` returns (presumably the learned prompt;
        confirm in optim_utils). The original discarded this result by
        assigning it to an unused local.
    """
    text_params = {
        "iter": iter,
        "lr": lr,
        "batch_size": batch_size,
        "prompt_len": prompt_len,
        "weight_decay": 0.1,
        "prompt_bs": 1,
        "loss_weight": 1.0,
        "print_step": 100,
        "clip_model": CLIP_MODEL,
        "clip_pretrain": PRETRAINED_CLIP,
    }
    clip_model, preprocess = _get_clip_components()
    return optimize_prompt(
        clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt
    )
def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Ask GPT-4o to rewrite *prompt* using the participant's history and feedback.

    The request message comes from ``get_personalize_message``. A fresh random
    seed is drawn per call. The raw ``call_gpt_api`` output is returned
    uncleaned; callers run it through ``clean_refined_prompt_response_gpt``.
    """
    request_seed = random.randint(0, MAX_SEED)
    api_client = init_gpt_api()
    conversation = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
    return call_gpt_api(
        conversation,
        api_client,
        "gpt-4o",
        request_seed,
        max_tokens=2000,
        temperature=0.7,
        top_p=0.9,
    )
########################################################################################################
# Button-related functions
########################################################################################################
def reset_gallery():
    """Return a new, empty gallery value."""
    empty_gallery = list()
    return empty_gallery
def _show_toast(toast_cls, msg, duration):
    """Emit a Gradio toast of type *toast_cls* (e.g. gr.Warning / gr.Info)."""
    toast_cls(msg, duration=duration)


def display_error_message(msg, duration=5):
    """Show *msg* to the user as a Gradio warning toast for *duration* seconds."""
    _show_toast(gr.Warning, msg, duration)


def display_info_message(msg, duration=5):
    """Show *msg* to the user as a Gradio info toast for *duration* seconds."""
    _show_toast(gr.Info, msg, duration)
def switch_tab(active_tab):
    """Return a Tabs update that selects the tab other than *active_tab*.

    "Task A" switches to "Task B". Any other value switches to "Task A".
    """
    print("switching tab")
    target_tab = "Task B" if active_tab == "Task A" else "Task A"
    return gr.Tabs(selected=target_tab)
def check_satisfaction(sim_radio, active_tab):
    """Toggle the submit and redesign buttons after a satisfaction rating.

    Submit is enabled, and redesign disabled, when the rating is "Satisfied" or
    "Very Satisfied", or when the active tab's method has reached MAX_ROUND.
    Otherwise the reverse applies.

    Returns:
        ``(submit_update, redesign_update)``.
    """
    tab_method = current_task1 if active_tab == "Task A" else current_task2
    round_no = counter1 if tab_method == METHODS[0] else counter2
    is_satisfied = sim_radio in ("Satisfied", "Very Satisfied")
    allow_submit = is_satisfied or round_no >= MAX_ROUND
    return gr.update(interactive=allow_submit), gr.update(interactive=not allow_submit)
def check_participant(participant):
    """Return True if a participant ID was entered; otherwise warn the user and return False.

    ``None``, empty, and whitespace-only IDs are all treated as missing. The
    original only rejected ``""``, so an accidental space was accepted as an ID
    and then used as a ``responses_memory`` key and a spreadsheet value.
    """
    if not participant or not participant.strip():
        display_error_message("Please fill your participant id!")
        return False
    return True
def check_evaluation(sim_radio):
    """Return True if a satisfaction rating was selected; otherwise warn and return False."""
    if sim_radio:
        return True
    display_error_message("❌ Please fill all evaluations before change image or submit.")
    return False
def select_image(like_radio, images_method):
    """Return the gallery image chosen by an image-choice radio button.

    The first four entries of ``IMAGE_OPTIONS`` map, in order, to gallery
    positions 0-3. Each gallery entry is indexed with ``[0]`` (presumably the
    image of an ``(image, caption)`` pair; confirm against the Gallery value
    format).

    Returns ``None`` when no recognised option is selected. It also returns
    ``None``, instead of raising IndexError/TypeError, when the gallery has no
    entry at that position. This happens after ``redesign``, which clears the
    gallery to ``[]`` until images are regenerated.
    """
    for position, option in enumerate(IMAGE_OPTIONS[:4]):
        if like_radio == option:
            if images_method and position < len(images_method):
                return images_method[position][0]
            return None
    return None
def set_user(participant):
    """Start fresh response memory for *participant* and assign their scenarios.

    The first integer found in the ID decides the split. Even numbers, and IDs
    with no digits, get the first two scenarios. Odd numbers get the rest.

    Returns:
        The first assigned scenario name (used to populate the scenario dropdown).
    """
    global assigned_scenarios
    responses_memory[participant] = {METHODS[0]: {}, METHODS[1]: {}}
    digit_runs = re.findall(r'\d+', participant)
    use_first_half = not digit_runs or int(digit_runs[0]) % 2 == 0
    scenario_names = list(SCENARIOS.keys())
    assigned_scenarios = scenario_names[:2] if use_first_half else scenario_names[2:]
    return assigned_scenarios[0]
def display_scenario(participant, choice):
    """Load scenario *choice* into both task tabs and reset per-scenario state.

    Steps:
    - Reset both round counters to 1 and clear both task-success flags.
    - Re-create the participant's response memory, only if the ID passes
      ``check_participant``.
    - Randomly decide which method backs Task A and which backs Task B.

    Returns:
        A ``{component: value_or_update}`` dict that seeds both tabs with the
        scenario text, a locked prompt, the pre-rendered images for each
        method, and cleared selections, history and button states.
    """
    global counter1, counter2, current_task1, current_task2, task1_success, task2_success
    task1_success = task2_success = False
    counter1 = counter2 = 1
    if check_participant(participant):
        responses_memory[participant] = {METHODS[0]: {}, METHODS[1]: {}}
    current_task1, current_task2 = random.sample(METHODS, 2)

    scenario_images = IMAGES[choice]
    baseline_in_tab_a = current_task1 == METHODS[0]
    tab_a_images = scenario_images["baseline"] if baseline_in_tab_a else scenario_images["ours"]
    tab_b_images = scenario_images["ours"] if baseline_in_tab_a else scenario_images["baseline"]
    scenario_prompt = PROMPTS.get(choice, "")

    # A separate gr.update(...) is built per component, since update dicts must not be shared.
    return {
        scenario_content: SCENARIOS.get(choice, ""),
        prompt1: gr.update(value=scenario_prompt, interactive=False),
        prompt2: gr.update(value=scenario_prompt, interactive=False),
        images_method1: tab_a_images,
        images_method2: tab_b_images,
        **{picked: None for picked in (like_image1, dislike_image1, like_image2, dislike_image2)},
        **{gallery: [] for gallery in (history_images1, history_images2)},
        **{btn: gr.update(interactive=False) for btn in (next_btn1, next_btn2)},
        **{btn: gr.update(interactive=True) for btn in (redesign_btn1, redesign_btn2)},
        **{btn: gr.update(interactive=False) for btn in (submit_btn1, submit_btn2)},
    }
def generate_image(participant, scenario, prompt, active_tab, like_image, dislike_image):
    """Stream a new batch of NUM_IMAGES images for the active tab's method.

    Steps:
    1. Personalize *prompt* with GPT, using this method's earlier prompts and
       ratings plus the liked/disliked images.
    2. Fall back to the user's *prompt* if GPT refuses or returns nothing.
    3. METHODS[0] renders the personalized prompt NUM_IMAGES times.
    4. The other method asks GPT for refined variants and renders one image
       per variant. If GPT returns fewer variants than NUM_IMAGES, the
       personalized prompt fills the gap instead of raising IndexError.

    Yields:
        The growing list of images after each render, so the gallery updates
        progressively. Nothing is yielded if the participant ID is missing.
    """
    if not check_participant(participant):
        # This is a generator: a `return value` is never emitted as an output,
        # so the original `return [], []` was a silent no-op. Just stop.
        return

    method = current_task1 if active_tab == "Task A" else current_task2
    past_rounds = responses_memory[participant][method].values()
    history_prompts = [entry["prompt"] for entry in past_rounds]
    feedback = [entry["sim_radio"] for entry in past_rounds]

    personalized_prompt = clean_refined_prompt_response_gpt(
        personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
    )
    print(f"Personalized prompt: {personalized_prompt}, {type(personalized_prompt)}")
    if not personalized_prompt or "I'm sorry, I can't assist with" in personalized_prompt:
        print("error in gpt...")
        personalized_prompt = prompt

    if method == METHODS[0]:
        image_prompts = [personalized_prompt] * NUM_IMAGES
    else:
        refined_prompts = list(call_gpt_refine_prompt(personalized_prompt) or [])
        # Pad with the personalized prompt if GPT returned fewer variants than needed.
        image_prompts = (refined_prompts + [personalized_prompt] * NUM_IMAGES)[:NUM_IMAGES]

    gallery_images = []
    for image_prompt in image_prompts:
        gallery_images.append(infer(image_prompt))
        yield gallery_images
def redesign(participant, scenario, prompt, sim_radio, current_images, history_images, active_tab):
    """Close the current round for the active tab and open the next one.

    Validation runs ``check_evaluation`` first, then ``check_participant``.
    If either fails, only the tab's submit button is skipped (left unchanged).

    On success:
    - Increment the method's round counter.
    - Record the finished round's prompt and rating.
    - Move the current images into the history gallery.

    Returns:
        Updates in this order:
        - cleared radios (3);
        - an emptied current gallery;
        - the history gallery;
        - the prompt-history examples;
        - an editable prompt;
        - the next, redesign and submit buttons.
        Once MAX_ROUND is reached, redesign and next are disabled and submit
        is enabled.
    """
    global counter1, counter2
    method = current_task2 if active_tab != "Task A" else current_task1
    if not (check_evaluation(sim_radio) and check_participant(participant)):
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    if method == METHODS[0]:
        counter1 += 1
        round_no = counter1
    else:
        counter2 += 1
        round_no = counter2

    method_log = responses_memory[participant][method]
    method_log[round_no - 1] = {"prompt": prompt, "sim_radio": sim_radio}
    history_prompts = [[entry["prompt"]] for entry in method_log.values()]

    if not history_images:
        history_images = current_images
    elif current_images:
        history_images.extend(current_images)

    reached_limit = round_no >= MAX_ROUND
    return (
        None,
        None,
        None,
        [],
        history_images,
        gr.update(samples=history_prompts, visible=True),
        gr.update(interactive=True),
        gr.update(interactive=False) if reached_limit else gr.update(visible=True, interactive=True),
        gr.update(interactive=not reached_limit),
        gr.update(interactive=reached_limit),
    )
def save_response(participant, scenario, prompt, sim_radio, active_tab):
    """Record the final round for the active tab's method and upload it to Google Sheets.

    On success:
    - Store the current round in ``responses_memory``.
    - Append one row per recorded round (participant, scenario, method, round,
      prompt, rating) to the "DiverseGen-phase3" sheet in a single batched
      request.
    - Reset the method's round counter and mark the tab's task as done.
    - Return updates that clear the tab, choose the next scenario and switch
      to the other tab. The tuple order matches the ``outputs`` wired to the
      submit buttons.

    If validation or the upload fails, only the tab's submit button is skipped
    (left unchanged). Upload errors are shown to the user.
    """
    global task1_success, task2_success, counter1, counter2
    method = current_task1 if active_tab == "Task A" else current_task2
    no_change = {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    if not (check_evaluation(sim_radio) and check_participant(participant)):
        return no_change

    counter = counter1 if method == METHODS[0] else counter2
    recorded_rounds = responses_memory[participant][method]
    recorded_rounds[counter] = {"prompt": prompt, "sim_radio": sim_radio}
    # One row per recorded round, in insertion (round) order.
    rows = [
        [participant, scenario, method, round_no, entry["prompt"], entry["sim_radio"]]
        for round_no, entry in recorded_rounds.items()
    ]

    try:
        gc = gspread.service_account(filename='credentials.json')
        sheet = gc.open("DiverseGen-phase3").sheet1
        # A single request instead of one append_row call per round. This eases
        # API rate limits and avoids half-written submissions if a later call fails.
        sheet.append_rows(rows)
    except Exception as e:  # UI boundary: report any upload failure to the participant
        display_error_message(f"❌ Error saving response: {str(e)}")
        return no_change

    display_info_message("βœ… Your answer is saved!")
    # Reset the round counter for the method that was just submitted.
    if method == METHODS[0]:
        counter1 = 1
    else:
        counter2 = 1
    if active_tab == "Task A":
        task1_success = True
    else:
        task2_success = True
    # Move to the second assigned scenario only after both tabs are submitted.
    next_scenario = assigned_scenarios[1] if task1_success and task2_success else assigned_scenarios[0]
    return (
        None,
        None,
        None,
        None,
        None,
        [],
        [],
        gr.update(samples=[], visible=False),
        gr.update(interactive=False),
        gr.update(visible=False, interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=False),
        next_scenario,
        switch_tab(active_tab),
    )
########################################################################################################
# Interface
########################################################################################################
# Custom CSS passed to gr.Blocks below; the selectors target elem_id values assigned in the layout
# (col-container, col-container2, col-container3, button-container, compact-row).
css="""
#col-container {
margin: 0 auto;
max-width: 700px;
}
#col-container2 {
margin: 0 auto;
max-width: 1000px;
}
#col-container3 {
margin: 0 0 auto auto;
max-width: 300px;
}
#button-container {
display: flex;
justify-content: center; /* Centers the buttons horizontally */
}
#compact-row {
width:100%;
max-width: 1000px;
margin: 0px auto;
}
"""
# Build the two-tab user-study interface and wire its events.
# NOTE(review): indentation was lost in this copy of the file; the nesting below is a
# reconstruction — confirm the exact layout against the original source.
# NOTE(review): several user-facing labels contain what look like mis-encoded emoji
# (e.g. "πŸ“Œ", "βœ…"); confirm the file is saved/decoded as UTF-8.
# NOTE(review): elem_id "col-container" and "gallery" are each assigned to more than one
# component, producing duplicate DOM ids.
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
    # Header: participant ID, scenario selector (set by handlers, not editable), scenario background.
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # πŸ“Œ **Diverse Text-to-Image Generation**")
        with gr.Row():
            participant = gr.Textbox(
                label="πŸ§‘β€πŸ’Ό Participant ID", placeholder="Please enter you participant id"
            )
            scenario = gr.Dropdown(
                choices=list(SCENARIOS.keys()),
                value=None,
                label="πŸ“Œ Scenario",
                interactive=False,
            )
        scenario_content = gr.Textbox(
            label="πŸ“– Background",
            interactive=False,
        )
    # Name of the selected tab ("Task A"/"Task B"); handlers use it to choose current_task1 vs current_task2.
    active_tab = gr.State("Task A")
    # NOTE(review): `instruction` is not referenced again in this file.
    instruction = gr.Markdown(INSTRUCTION)
    with gr.Tabs() as tabs:
        # ---------------- Task A (method = current_task1) ----------------
        with gr.TabItem("Task A", id="Task A") as task1_tab:
            # Keep active_tab in sync when the user clicks this tab.
            task1_tab.select(lambda: "Task A", outputs=[active_tab])
            with gr.Row(elem_id="compact-row"):
                prompt1 = gr.Textbox(
                    label="🎨 Revise Prompt",
                    max_lines=5,
                    placeholder="Enter your prompt",
                    scale=4,
                    visible=True,
                )
                # Hidden/disabled until the first redesign enables it.
                next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
            with gr.Row(elem_id="compact-row"):
                # Populated with past prompts by redesign (via example1.dataset).
                example1 = gr.Examples([['']], prompt1, label="Revised Prompt History", visible=False)
            with gr.Row(elem_id="compact-row"):
                with gr.Column(elem_id="col-container"):
                    images_method1 = gr.Gallery(label="Images", columns=[4], rows=[1], height=200, elem_id="gallery")
                    history_images1 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery")
                with gr.Column(elem_id="col-container3"):
                    # Filled from the gallery by select_image when a like/dislike radio is chosen.
                    like_image1 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', type="pil")
                    dislike_image1 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', type="pil")
            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### πŸ“ Evaluation")
                sim_radio1 = gr.Radio(
                    OPTIONS,
                    label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                like_radio1 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select the image that you find MOST satisfactory. You may leave this section blank if you prefer the previous images.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                dislike_radio1 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Please choose the image that you find LEAST satisfactory. You may leave this section blank if you are more dislike previous images.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                # NOTE(review): read-only text not wired to any handler; its label and value describe different tasks.
                response1 = gr.Textbox(
                    label="Verbally describe key differences found in the image pair.",
                    max_lines=1,
                    interactive=False,
                    container=False,
                    value=VERBAL_MSG
                )
                with gr.Row(elem_id="button-container"):
                    redesign_btn1 = gr.Button("🎨 Redesign", variant="primary", scale=0)
                    submit_btn1 = gr.Button("βœ… Submit", variant="primary", interactive=False, scale=0)
        # ---------------- Task B (method = current_task2); mirrors Task A ----------------
        with gr.TabItem("Task B", id="Task B") as task2_tab:
            task2_tab.select(lambda: "Task B", outputs=[active_tab])
            with gr.Row(elem_id="compact-row"):
                prompt2 = gr.Textbox(
                    label="🎨 Revise Prompt",
                    max_lines=5,
                    placeholder="Enter your prompt",
                    scale=4,
                    visible=True,
                )
                next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
            with gr.Row(elem_id="compact-row"):
                example2 = gr.Examples([['']], prompt2, label="Revised Prompt History", visible=False)
            with gr.Row(elem_id="compact-row"):
                with gr.Column(elem_id="col-container"):
                    images_method2 = gr.Gallery(label="Images", columns=[4], rows=[1], height=200, elem_id="gallery")
                    history_images2 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery")
                with gr.Column(elem_id="col-container3"):
                    like_image2 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', type="pil")
                    dislike_image2 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', type="pil")
            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### πŸ“ Evaluation")
                sim_radio2 = gr.Radio(
                    OPTIONS,
                    label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                like_radio2 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select the image that you find MOST satisfactory. You may leave this section blank if you prefer the previous images.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                dislike_radio2 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Please choose the image that you find LEAST satisfactory. You may leave this section blank if you are more dislike previous images.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                response2 = gr.Textbox(
                    label="Verbally describe key differences found in the image pair.",
                    max_lines=1,
                    interactive=False,
                    container=False,
                    value=VERBAL_MSG
                )
                with gr.Row(elem_id="button-container"):
                    redesign_btn2 = gr.Button("🎨 Redesign", variant="primary", scale=0)
                    submit_btn2 = gr.Button("βœ… Submit", variant="primary", interactive=False, scale=0)
    ########################################################################################################
    # Button Function Setup
    ########################################################################################################
    # Entering an ID assigns scenarios; set_user's return value becomes the dropdown value,
    # which in turn triggers display_scenario.
    # NOTE(review): .change fires on every keystroke, so partial IDs also get memory entries.
    participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
    scenario.change(display_scenario,
        inputs=[participant, scenario],
        outputs=[scenario_content, prompt1, prompt2, images_method1, images_method2, like_image1, dislike_image1, like_image2, dislike_image2, history_images1, history_images2, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])
    # prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
    # prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
    # generate_image is a generator, so each gallery updates as images are rendered.
    next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, active_tab, like_image1, dislike_image1], outputs=[images_method1])
    next_btn2.click(fn=generate_image, inputs=[participant, scenario, prompt2, active_tab, like_image2, dislike_image2], outputs=[images_method2])
    # A rating toggles submit vs. redesign availability.
    sim_radio1.change(fn=check_satisfaction, inputs=[sim_radio1, active_tab], outputs=[submit_btn1, redesign_btn1])
    sim_radio2.change(fn=check_satisfaction, inputs=[sim_radio2, active_tab], outputs=[submit_btn2, redesign_btn2])
    # Picking a like/dislike radio copies that gallery image into the matching preview.
    dislike_radio1.select(fn=select_image, inputs=[dislike_radio1, images_method1], outputs=[dislike_image1])
    like_radio1.select(fn=select_image, inputs=[like_radio1, images_method1], outputs=[like_image1])
    dislike_radio2.select(fn=select_image, inputs=[dislike_radio2, images_method2], outputs=[dislike_image2])
    like_radio2.select(fn=select_image, inputs=[like_radio2, images_method2], outputs=[like_image2])
    # Output order must match the tuple returned by redesign.
    redesign_btn1.click(
        fn=redesign,
        inputs=[participant, scenario, prompt1, sim_radio1, images_method1, history_images1, active_tab],
        outputs=[sim_radio1, dislike_radio1, like_radio1, images_method1, history_images1, example1.dataset, prompt1, next_btn1, redesign_btn1, submit_btn1]
    )
    redesign_btn2.click(
        fn=redesign,
        inputs=[participant, scenario, prompt2, sim_radio2, images_method2, history_images2, active_tab],
        outputs=[sim_radio2, dislike_radio2, like_radio2, images_method2, history_images2, example2.dataset, prompt2, next_btn2, redesign_btn2, submit_btn2]
    )
    # Output order must match the tuple returned by save_response.
    submit_btn1.click(fn=save_response,
        inputs=[participant, scenario, prompt1, sim_radio1, active_tab],
        outputs=[sim_radio1, dislike_radio1, like_radio1, like_image1, dislike_image1, images_method1, history_images1, example1.dataset, prompt1, next_btn1, redesign_btn1, submit_btn1, scenario, tabs])
    submit_btn2.click(fn=save_response,
        inputs=[participant, scenario, prompt2, sim_radio2, active_tab],
        outputs=[sim_radio2, dislike_radio2, like_radio2, like_image2, dislike_image2, images_method2, history_images2, example2.dataset, prompt2, next_btn2, redesign_btn2, submit_btn2, scenario, tabs])
if __name__ == "__main__":
    demo.launch()