# POET / app.py
# Author: xh365 -- last commit a92b2e7 ("update saving to each iteration"), 28 kB.
import gradio as gr
from gradio.themes.base import Base
import numpy as np
import random
import spaces
import torch
import re
import open_clip
from optim_utils import optimize_prompt
from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache, get_personalize_message, clean_refined_prompt_response_gpt
from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS
import spaces #[uncomment to use ZeroGPU]
import transformers
import gspread
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload
from googleapiclient.errors import HttpError
from google.oauth2.service_account import Credentials
# --- Model / runtime configuration --------------------------------------------
CLIP_MODEL = "ViT-H-14"  # open_clip architecture name used by invert_prompt
PRETRAINED_CLIP = "laion2b_s32b_b79k"  # open_clip pretrained-weights tag used by invert_prompt
default_t2i_model = "black-forest-labs/FLUX.1-dev" # text-to-image model loaded below via setup_model
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # not referenced in the code shown here; alternative: "meta-llama/Meta-Llama-3-8B-Instruct"
MAX_SEED = np.iinfo(np.int32).max  # upper bound for random seeds (largest int32)
MAX_IMAGE_SIZE = 1024  # not referenced in the code shown here
NUM_IMAGES=4  # images rendered per generation round (see generate_image)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32  # fp16 only when CUDA is available
clean_cache()  # utils helper; presumably frees cached GPU memory before loading -- confirm
# The text-to-image pipeline is loaded once at import time and used by infer().
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
# NOTE(review): the CLIP model is not loaded at import time, although
# invert_prompt needs clip_model / preprocess.
# clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
llm_pipe = None  # not referenced elsewhere in the code shown here
torch.cuda.empty_cache()
inverted_prompt = ""  # NOTE(review): no function in this file assigns this global
VERBAL_MSG = "Please explain your rating of satisfaction in few words or sentences."  # default text of the response1/response2 boxes
DEFAULT_SCENARIO = "Product advertisement"  # not referenced in the code shown here
METHODS = ["Baseline", "Experimental"]  # compared methods; index 0 is treated as the baseline throughout
MAX_ROUND = 5  # round limit used by redesign / check_satisfaction to disable Redesign and unlock Submit
# --- Study state ----------------------------------------------------------------
# NOTE(review): these are plain module globals, so every browser session served
# by this process shares them; concurrent participants would overwrite each
# other's counters and task assignment.
counter1, counter2 = 1, 1  # current round number for METHODS[0] / METHODS[1]
responses_memory = {}  # participant -> method -> round -> recorded answers
assigned_scenarios = list(SCENARIOS.keys())[:2]  # only referenced by commented-out code and global declarations
current_task1, current_task2 = METHODS # method shown on Task A / Task B; reassigned per participant by assign_tasks
task1_success, task2_success = False, False  # set True when Task A / Task B is submitted successfully
enable_submit1, enable_submit2 = False, False  # sticky "some round was rated satisfied" flag per method
scopes = ['https://www.googleapis.com/auth/spreadsheets', 'https://www.googleapis.com/auth/drive']  # OAuth scopes for the service-account credentials used for Drive uploads
########################################################################################################
# Generating images with two methods
########################################################################################################
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Render one image for ``prompt`` with the module-level ``selected_pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed for the torch generator; ignored when ``randomize_seed`` is true.
        randomize_seed: When true, a fresh seed is drawn from ``[0, MAX_SEED]``.
        width: Output width forwarded to the pipeline.
        height: Output height forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Step count forwarded to the pipeline.
        progress: Gradio progress tracker created with tqdm tracking enabled.

    Returns:
        The first element of the pipeline output's ``images``.
    """
    effective_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    rng = torch.Generator().manual_seed(effective_seed)
    pipe_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=rng,
    )
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output = selected_pipe(**pipe_kwargs)
    return output.images[0]
def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask GPT-4o for ``num_prompts`` refined variants of ``prompt``.

    A random seed in ``[0, MAX_SEED]`` is sent with the request.

    Returns:
        Whatever ``clean_response_gpt`` extracts from the raw completion;
        callers index into it to get one prompt per image.
    """
    gpt_seed = random.randint(0, MAX_SEED)
    api_client = init_gpt_api()
    request = get_refine_msg(prompt, num_prompts)
    raw_completion = call_gpt_api(request, api_client, "gpt-4o", gpt_seed, max_tokens, temperature, top_p)
    return clean_response_gpt(raw_completion)
@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2):
    """Optimise a text prompt against target images via ``optimize_prompt``.

    Args:
        prompt: Forwarded to ``optimize_prompt`` as ``target_prompts``.
        images: Forwarded to ``optimize_prompt`` as ``target_images``.
        prompt_len: Forwarded as the ``prompt_len`` optimisation parameter.
        iter: Forwarded as the ``iter`` optimisation parameter. The name shadows
            the builtin but is kept so keyword callers keep working.
        lr: Forwarded as the ``lr`` optimisation parameter.
        batch_size: Forwarded as the ``batch_size`` optimisation parameter.

    Returns:
        The value returned by ``optimize_prompt`` (presumably the learned
        prompt -- confirm against optim_utils). Previously it was discarded
        and the function returned None.
    """
    global clip_model, preprocess
    # The module-level open_clip load is commented out, so clip_model /
    # preprocess may not exist. Load them once on first use instead of failing
    # with NameError; later calls reuse the cached globals.
    if globals().get("clip_model") is None or globals().get("preprocess") is None:
        clip_model, _, preprocess = open_clip.create_model_and_transforms(
            CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device
        )
    text_params = {
        "iter": iter,
        "lr": lr,
        "batch_size": batch_size,
        "prompt_len": prompt_len,
        "weight_decay": 0.1,
        "prompt_bs": 1,
        "loss_weight": 1.0,
        "print_step": 100,
        "clip_model": CLIP_MODEL,
        "clip_pretrain": PRETRAINED_CLIP,
    }
    return optimize_prompt(
        clip_model, preprocess, text_params, device,
        target_images=images, target_prompts=prompt,
    )
def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Ask GPT-4o to rewrite ``prompt`` using the participant's earlier rounds.

    Args:
        prompt: The participant's current prompt.
        history: Earlier prompts for this method, as collected by the caller.
        feedback: Satisfaction ratings matching ``history``.
        like_image: Value of the "Satisfied Image" component (a file path or None).
        dislike_image: Value of the "Unsatisfied Image" component (a file path or None).

    Returns:
        The raw ``call_gpt_api`` response; cleaning is left to the caller.
    """
    request_seed = random.randint(0, MAX_SEED)
    gpt_client = init_gpt_api()
    chat_messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
    return call_gpt_api(
        chat_messages,
        gpt_client,
        "gpt-4o",
        request_seed,
        max_tokens=2000,
        temperature=0.7,
        top_p=0.9,
    )
########################################################################################################
# Button-related functions
########################################################################################################
def reset_gallery():
    """Return a fresh empty list, the value that clears a ``gr.Gallery``."""
    return list()
def display_error_message(msg, duration=5):
    """Pop up ``msg`` as a Gradio warning toast shown for ``duration`` seconds."""
    warn = gr.Warning
    warn(msg, duration=duration)


def display_info_message(msg, duration=5):
    """Pop up ``msg`` as a Gradio info toast shown for ``duration`` seconds."""
    notify = gr.Info
    notify(msg, duration=duration)
def switch_tab(active_tab):
    """Return a ``gr.Tabs`` update that selects the other task tab.

    "Task A" switches to "Task B"; any other value switches to "Task A".
    """
    target = "Task B" if active_tab == "Task A" else "Task A"
    return gr.Tabs(selected=target)
def check_satisfaction(sim_radio, active_tab):
    """Decide whether the active tab's Submit button should be clickable.

    Submit is enabled when any of these holds:
    - the rating is "Satisfied" or "Very Satisfied";
    - the method's sticky ``enable_submit`` flag is already set;
    - the method's round counter has passed ``MAX_ROUND``.
    """
    method = current_task1 if active_tab == "Task A" else current_task2
    # Counters and submit flags are tracked per method, not per tab.
    if method == METHODS[0]:
        enable_submit, counter = enable_submit1, counter1
    else:
        enable_submit, counter = enable_submit2, counter2
    fully_satisfied_option = ["Satisfied", "Very Satisfied"]  # ratings that unlock Submit
    if_submit = (sim_radio in fully_satisfied_option) or enable_submit or (counter > MAX_ROUND)
    return gr.update(interactive=if_submit)
def check_participant(participant):
    """Return True if a participant ID was entered; otherwise warn and return False.

    None, "" and whitespace-only IDs are rejected. A blank-looking ID would
    otherwise be accepted and written as an empty participant column in the
    results sheet.
    """
    if participant is None or not str(participant).strip():
        display_error_message("Please fill your participant id!")
        return False
    return True
def check_evaluation(sim_radio):
    """Return True when a satisfaction rating was chosen; otherwise warn and return False."""
    if sim_radio:
        return True
    display_error_message("❌ Please fill all evaluations before change image or submit.")
    return False
def select_image(like_radio, images_method):
    """Return the gallery image picked in a like/dislike radio.

    Args:
        like_radio: The chosen ``IMAGE_OPTIONS`` entry, or None.
        images_method: Current ``gr.Gallery`` value; each entry is indexed
            with ``[0]`` to get the image.

    Returns:
        ``images_method[i][0]``, where ``i`` is the position of ``like_radio``
        in ``IMAGE_OPTIONS``. Returns None when nothing valid is selected or
        the gallery does not (yet) hold an image at that position.
    """
    options = list(IMAGE_OPTIONS)
    if like_radio not in options:
        return None
    position = options.index(like_radio)
    # The gallery can be empty/None or only partly filled while images are
    # still streaming in; a missing slot must not raise IndexError/TypeError.
    if not images_method or position >= len(images_method):
        return None
    return images_method[position][0]
def set_user(participant):
    """Give ``participant`` an empty response store with one dict per method.

    This handler is wired to the participant textbox's change event, so any
    rounds previously stored under this exact ID are discarded.
    """
    global responses_memory
    fresh_store = {}
    for method_name in (METHODS[0], METHODS[1]):
        fresh_store[method_name] = {}
    responses_memory[participant] = fresh_store
    # Scenario assignment by participant-ID parity (first vs. second half of
    # SCENARIOS) used to happen here; it is currently disabled.
def assign_tasks(participant):
    """Counterbalance which method is shown on Task A and which on Task B.

    The first run of digits in ``participant`` is read as an integer ``n``.
    IDs without digits, or with ``n % 4`` equal to 1 or 2, get
    ``(METHODS[1], METHODS[0])``. All other IDs get ``(METHODS[0], METHODS[1])``.
    """
    digit_runs = re.findall(r'\d+', participant)
    if digit_runs and int(digit_runs[0]) % 4 not in (1, 2):
        return METHODS[0], METHODS[1]
    return METHODS[1], METHODS[0]
def display_scenario(participant, choice):
    """Fill both task tabs for scenario ``choice`` and reset per-scenario state.

    Round counters, submit flags and task-success flags are reset on every call.
    When ``participant`` passes ``check_participant``, that participant's stored
    responses are cleared and the method shown on each tab is re-assigned with
    ``assign_tasks``.

    Returns:
        A dict mapping UI components to their new values or ``gr.update``
        objects:
        - the scenario background text;
        - the scenario's default prompt, locked, in both tabs;
        - each tab's initial images;
        - cleared image slots and history galleries;
        - hidden prompt-history examples;
        - disabled Generate/Submit buttons and enabled Redesign buttons.
    """
    global counter1, counter2, responses_memory, current_task1, current_task2
    global task1_success, task2_success, enable_submit1, enable_submit2

    # A new scenario starts at round 1 with nothing rated or submitted.
    task1_success = task2_success = False
    enable_submit1 = enable_submit2 = False
    counter1 = counter2 = 1

    if check_participant(participant):
        responses_memory[participant] = {METHODS[0]: {}, METHODS[1]: {}}
        current_task1, current_task2 = assign_tasks(participant)

    # Each tab starts from the pre-generated images of the method it shows.
    scenario_images = IMAGES[choice]
    if current_task1 == METHODS[0]:
        key_tab1, key_tab2 = "baseline", "ours"
    else:
        key_tab1, key_tab2 = "ours", "baseline"
    start_prompt = PROMPTS.get(choice, "")

    return {
        scenario_content: SCENARIOS.get(choice, ""),
        prompt1: gr.update(value=start_prompt, interactive=False),
        prompt2: gr.update(value=start_prompt, interactive=False),
        images_method1: scenario_images[key_tab1],
        images_method2: scenario_images[key_tab2],
        like_image1: None,
        dislike_image1: None,
        like_image2: None,
        dislike_image2: None,
        history_images1: [],
        history_images2: [],
        example1.dataset: gr.update(samples=[], visible=False),
        example2.dataset: gr.update(samples=[], visible=False),
        next_btn1: gr.update(interactive=False),
        next_btn2: gr.update(interactive=False),
        redesign_btn1: gr.update(interactive=True),
        redesign_btn2: gr.update(interactive=True),
        submit_btn1: gr.update(interactive=False),
        submit_btn2: gr.update(interactive=False),
    }
def generate_image(participant, scenario, prompt, active_tab, like_image, dislike_image):
    """Personalise ``prompt`` and render NUM_IMAGES images for the active tab.

    This is a generator: the gallery list is yielded after every image, so the
    UI fills in progressively. The prompt is first rewritten by GPT-4o using
    the prompts and ratings stored for this method. The Baseline method then
    renders every image from that single prompt. The other method asks GPT-4o
    for refined variants and renders one image per variant.

    Args:
        participant: Participant ID (validated with ``check_participant``).
        scenario: Selected scenario; not used here, kept for the UI wiring.
        prompt: The participant's prompt for this round.
        active_tab: "Task A" or "Task B"; picks the method through
            ``current_task1`` / ``current_task2``.
        like_image: Satisfied-image value forwarded to ``personalize_prompt``.
        dislike_image: Unsatisfied-image value forwarded to ``personalize_prompt``.

    Yields:
        The growing list of generated images.
    """
    if not check_participant(participant):
        # Nothing to render: end the generator without yielding.
        return
    method = current_task1 if active_tab == "Task A" else current_task2
    past_rounds = list(responses_memory[participant][method].values())
    history_prompts = [entry["prompt"] for entry in past_rounds]
    feedback = [entry["sim_radio"] for entry in past_rounds]

    raw_response = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
    personalized_prompt = clean_refined_prompt_response_gpt(raw_response)
    print(f"Personalized prompt: {personalized_prompt}, {type(personalized_prompt)}")
    # Fall back to the participant's own prompt when GPT refused, or when the
    # cleaned response is empty/None (the substring test would raise on None).
    if not personalized_prompt or "I'm sorry, I can't assist with" in personalized_prompt:
        print("error in gpt...")
        personalized_prompt = prompt

    if method == METHODS[0]:
        # Baseline: every image comes from the same personalised prompt.
        render_prompts = [personalized_prompt] * NUM_IMAGES
    else:
        # Experimental: one GPT-refined variant per image.
        refined_prompts = call_gpt_refine_prompt(personalized_prompt)
        if not refined_prompts:
            print("error in gpt refinement...")
            refined_prompts = [personalized_prompt]
        # GPT may return fewer variants than NUM_IMAGES. Reuse them
        # round-robin instead of failing with IndexError mid-gallery.
        render_prompts = [refined_prompts[i % len(refined_prompts)] for i in range(NUM_IMAGES)]

    gallery_images = []
    for render_prompt in render_prompts:
        gallery_images.append(infer(render_prompt))
        yield gallery_images
def save_response_to_sheet(participant, method, scenario, active_tab, round, like_image, dislike_image):
    """Append one recorded round to the results sheet and upload its images.

    Writes ``responses_memory[participant][method][round]`` as one row of the
    first worksheet of the "DiverseGen-phase3" spreadsheet. Then uploads the
    satisfied/unsatisfied images via ``save_image`` and shows a confirmation toast.

    Args:
        participant: Participant ID (key into ``responses_memory``).
        method: ``METHODS`` entry the round belongs to.
        scenario: Scenario name written to the sheet and used in the image names.
        active_tab: Tab name written to the sheet and used in the image names.
        round: Round number (key into ``responses_memory[participant][method]``).
        like_image: Satisfied-image path forwarded to ``save_image``.
        dislike_image: Unsatisfied-image path forwarded to ``save_image``.

    Raises:
        Errors are not caught here. They propagate from loading
        ``credentials.json``, opening the sheet, ``append_row``, or a missing
        ``responses_memory`` entry.
    """
    gc = gspread.service_account(filename='credentials.json')
    sheet = gc.open("DiverseGen-phase3").sheet1
    entry = responses_memory[participant][method][round]
    print(entry)
    sheet.append_row([
        participant,
        scenario,
        f"{active_tab}, {method}",
        round,
        entry["prompt"],
        entry["sim_radio"],
        entry["response"],
        entry["satisfied_img"],
        entry["unsatisfied_img"],
    ])
    # Save the selected images to Google Drive.
    creds = Credentials.from_service_account_file('credentials.json', scopes=scopes)
    save_image(creds, like_image, dislike_image, f"{participant}_{scenario}_{active_tab}_{method}_round{round}")
    # The literal was mis-encoded ("βœ…"); this is the intended check mark.
    display_info_message("✅ Your answer is saved!")
def redesign(participant, scenario, prompt, sim_radio, like_radio, dislike_radio, current_images, history_images, active_tab, like_image, dislike_image):
    """Record and save the current round, then open the next round of the active tab.

    On success:
    - the current images move into the history gallery;
    - the prompt history is shown as examples;
    - the prompt and Generate button are unlocked;
    - the method's round counter advances.
    Redesign is disabled once the counter reaches ``MAX_ROUND``. Submit is
    enabled from then on, or once any round of this method was rated
    "Satisfied"/"Very Satisfied".

    Returns:
        On success, a 10-tuple matching the click handler's outputs: cleared
        sim/dislike/like radios, the emptied current gallery, the updated
        history gallery, the prompt-history examples update, and
        prompt/Generate/Redesign/Submit updates.
        If validation or saving fails, a dict that leaves the tab's Submit
        button unchanged via ``gr.skip()``. A save failure also shows an error toast.
    """
    global counter1, counter2, responses_memory, current_task1, current_task2, enable_submit1, enable_submit2
    method = current_task1 if active_tab == "Task A" else current_task2
    if not (check_evaluation(sim_radio) and check_participant(participant)):
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    # Counters and submit flags are tracked per method, not per tab.
    is_baseline = method == METHODS[0]
    counter = counter1 if is_baseline else counter2
    enable_submit = enable_submit1 if is_baseline else enable_submit2
    responses_memory[participant][method][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, {like_radio}",
        "unsatisfied_img": f"round {counter}, {dislike_radio}",
    }
    try:
        save_response_to_sheet(participant, method, scenario, active_tab, counter, like_image, dislike_image)
    except Exception as e:
        # Same handling as save_response: report the failure and leave the
        # round counter and UI untouched so the participant can retry.
        display_error_message(f"❌ Error saving response: {str(e)}")
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    # Once any round of this method is rated satisfied, Submit stays unlocked.
    enable_submit = bool(sim_radio in ["Satisfied", "Very Satisfied"] or enable_submit)
    history_prompts = [[v["prompt"]] for v in responses_memory[participant][method].values()]
    if not history_images:
        history_images = current_images
    elif current_images:
        history_images.extend(current_images)
    current_images = []
    examples_state = gr.update(samples=history_prompts, visible=True)
    prompt_state = gr.update(interactive=True)
    next_state = gr.update(visible=True, interactive=True)
    redesign_state = gr.update(interactive=counter < MAX_ROUND)
    submit_state = gr.update(interactive=counter >= MAX_ROUND or enable_submit)
    # Advance this method's round counter and persist its submit flag.
    if is_baseline:
        counter1 += 1
        enable_submit1 = enable_submit
    else:
        counter2 += 1
        enable_submit2 = enable_submit
    return None, None, None, current_images, history_images, examples_state, prompt_state, next_state, redesign_state, submit_state
def save_image(creds, like_image, dislike_image, name):
    """Upload the satisfied/unsatisfied images to the study's Google Drive folder.

    Each provided path is uploaded as a PNG named ``{name}_satisfied`` or
    ``{name}_unsatisfied``. None/empty paths are skipped, because the
    like/dislike selections are optional in the UI and ``MediaFileUpload``
    cannot open a missing path. If no image is provided, no Drive client is
    built at all. Drive ``HttpError``s are printed and swallowed so a failed
    upload does not block the study flow.
    """
    uploads = [
        (image_path, suffix)
        for image_path, suffix in zip([like_image, dislike_image], ["satisfied", "unsatisfied"])
        if image_path
    ]
    if not uploads:
        return
    try:
        service = build("drive", "v3", credentials=creds)
        for image_path, suffix in uploads:
            file_metadata = {"name": f"{name}_{suffix}", "parents": ["1ru3-QbbzyVSk-1kBfVv4nhElFqYh3ITj"]}
            media = MediaFileUpload(image_path, mimetype="image/png")
            service.files().create(body=file_metadata, media_body=media, fields="id").execute()
    except HttpError as error:
        print(f"An error occurred: {error}")
def save_response(participant, scenario, prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image, active_tab):
    """Record and save the final round of the active tab, then switch tabs.

    Stores the current rating in ``responses_memory`` under the method's round
    counter and saves it with ``save_response_to_sheet``. Only if that
    succeeds does it do the following:
    - reset the method's counter and submit flag;
    - mark the tab's task as done;
    - lock the tab's controls;
    - switch to the other tab.

    Returns:
        On success, an 8-tuple for the submit handler's outputs: cleared
        radios, prompt/Generate/Redesign/Submit updates, and the ``gr.Tabs``
        update.
        If validation or saving fails, a dict that leaves the tab's Submit
        button unchanged via ``gr.skip()``. A save failure also shows an error toast.
    """
    global current_task1, current_task2, scopes # not change
    global task1_success, task2_success, counter1, counter2, enable_submit1, enable_submit2, responses_memory, assigned_scenarios # will change
    # Tabs map to methods per participant (see assign_tasks).
    method = current_task1 if active_tab == "Task A" else current_task2
    if check_evaluation(sim_radio) and check_participant(participant):
        # Counters are tracked per method, not per tab.
        counter = counter1 if method == METHODS[0] else counter2
        responses_memory[participant][method][counter] = {}
        responses_memory[participant][method][counter]["prompt"] = prompt
        responses_memory[participant][method][counter]["sim_radio"] = sim_radio
        responses_memory[participant][method][counter]["response"] = ""
        responses_memory[participant][method][counter]["satisfied_img"] = f"round {counter}, {like_radio}"
        responses_memory[participant][method][counter]["unsatisfied_img"] = f"round {counter}, {dislike_radio}"
        try:
            save_response_to_sheet(participant, method, scenario, active_tab, counter, like_image, dislike_image)
            # reset global variables (only reached when the save succeeded)
            if method == METHODS[0]:
                counter1 = 1
                enable_submit1 = False
            else:
                counter2 = 1
                enable_submit2 = False
            # Task success is tracked per tab, not per method.
            if active_tab == "Task A":
                task1_success = True
            else:
                task2_success = True
            # decide if change scenario (disabled; kept for reference)
            # if scenario == assigned_scenarios[0]:
            #     next_scenario = assigned_scenarios[1] if task1_success and task2_success else assigned_scenarios[0]
            # else:
            #     if task1_success and task2_success:
            #         display_info_message("You have finished all scenarios, thank you!")
            #         next_scenario = assigned_scenarios[0]
            #     else:
            #         next_scenario = assigned_scenarios[1]
            # reset buttons: the finished tab stays locked until a new scenario is chosen
            prompt_state = gr.update(interactive=False)
            next_state = gr.update(visible=False, interactive=False)
            submit_state = gr.update(interactive=False)
            redesign_state = gr.update(interactive=False)
            # Local name; holds the gr.Tabs update that selects the other tab.
            tabs = switch_tab(active_tab)
            return None, None, None, prompt_state, next_state, redesign_state, submit_state, tabs
        except Exception as e:
            display_error_message(f"❌ Error saving response: {str(e)}")
            return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
    else:
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
########################################################################################################
# Interface
########################################################################################################
# Custom CSS for the elem_id containers used by the layout:
# - max widths for #col-container (700px), #col-container2 (1000px),
#   #col-container3 (300px) and #compact-row (1000px);
# - #button-container centers its buttons horizontally.
css="""
#col-container {
margin: 0 auto;
max-width: 700px;
}
#col-container2 {
margin: 0 auto;
max-width: 1000px;
}
#col-container3 {
margin: 0 0 auto auto;
max-width: 300px;
}
#button-container {
display: flex;
justify-content: center; /* Centers the buttons horizontally */
}
#compact-row {
width:100%;
max-width: 1000px;
margin: 0px auto;
}
"""
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy whose indentation had been lost -- confirm it against the
# deployed layout.
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
    # ---- Header: participant ID, scenario picker, scenario background ----
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # 📌 **PAI-GEN**")
        with gr.Row():
            participant = gr.Textbox(
                label="🧑‍💼 Participant ID", placeholder="Please enter your participant id"
            )
            scenario = gr.Dropdown(
                choices=list(SCENARIOS.keys()),
                value=None,
                label="📌 Scenario",
                # interactive=False,
            )
        scenario_content = gr.Textbox(
            label="📖 Background",
            interactive=False,
        )
        # Name of the selected tab ("Task A"/"Task B"); updated by the tab .select handlers.
        active_tab = gr.State("Task A")
        instruction = gr.Markdown(INSTRUCTION)

    # ---- Two task tabs; which method each shows is decided by assign_tasks ----
    with gr.Tabs() as tabs:
        with gr.TabItem("Task A", id="Task A") as task1_tab:
            task1_tab.select(lambda: "Task A", outputs=[active_tab])
            with gr.Row(elem_id="compact-row"):
                prompt1 = gr.Textbox(
                    label="🎨 Revise Prompt",
                    max_lines=5,
                    placeholder="Enter your prompt",
                    scale=4,
                    visible=True,
                )
                next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
            with gr.Row(elem_id="compact-row"):
                with gr.Column(elem_id="col-container"):
                    images_method1 = gr.Gallery(label="Images", columns=[4], rows=[1], height=400, elem_id="gallery", format="png")
                with gr.Column(elem_id="col-container3"):
                    like_image1 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
                    dislike_image1 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### 📝 Evaluation")
                sim_radio1 = gr.Radio(
                    OPTIONS,
                    label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                like_radio1 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select your all-time favorite image that you find MOST satisfactory in this task. You may leave this section blank if you prefer the previous images.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                dislike_radio1 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select your all-time disliked image that you find LEAST satisfactory in this task. You may leave this section blank if you dislike the previous images more.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                response1 = gr.Textbox(
                    label="Verbally describe key differences found in the image pair.",
                    max_lines=1,
                    interactive=False,
                    container=False,
                    value=VERBAL_MSG
                )
            with gr.Column(elem_id="col-container2"):
                example1 = gr.Examples([['']], prompt1, label="Revised Prompt History", visible=False)
                history_images1 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png")
            with gr.Row(elem_id="button-container"):
                redesign_btn1 = gr.Button("🎨 Redesign", variant="primary", scale=0)
                submit_btn1 = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)
        with gr.TabItem("Task B", id="Task B") as task2_tab:
            task2_tab.select(lambda: "Task B", outputs=[active_tab])
            with gr.Row(elem_id="compact-row"):
                prompt2 = gr.Textbox(
                    label="🎨 Revise Prompt",
                    max_lines=5,
                    placeholder="Enter your prompt",
                    scale=4,
                    visible=True,
                )
                next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
            with gr.Row(elem_id="compact-row"):
                with gr.Column(elem_id="col-container"):
                    # NOTE(review): height=200 here vs. height=400 in Task A --
                    # confirm the two methods are meant to be shown at different sizes.
                    images_method2 = gr.Gallery(label="Images", columns=[4], rows=[1], height=200, elem_id="gallery", format="png")
                with gr.Column(elem_id="col-container3"):
                    like_image2 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
                    dislike_image2 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### 📝 Evaluation")
                sim_radio2 = gr.Radio(
                    OPTIONS,
                    label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                like_radio2 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select your all-time favorite image that you find MOST satisfactory in this task. You may leave this section blank if you prefer the previous images.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                dislike_radio2 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select your all-time disliked image that you find LEAST satisfactory in this task. You may leave this section blank if you dislike the previous images more.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                response2 = gr.Textbox(
                    label="Verbally describe key differences found in the image pair.",
                    max_lines=1,
                    interactive=False,
                    container=False,
                    value=VERBAL_MSG
                )
            with gr.Column(elem_id="col-container2"):
                example2 = gr.Examples([['']], prompt2, label="Revised Prompt History", visible=False)
                history_images2 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png")
            with gr.Row(elem_id="button-container"):
                redesign_btn2 = gr.Button("🎨 Redesign", variant="primary", scale=0)
                submit_btn2 = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)

    # ---- Event wiring (must be inside the Blocks context) ----
    participant.change(fn=set_user, inputs=[participant])
    scenario.change(
        display_scenario,
        inputs=[participant, scenario],
        outputs=[scenario_content, prompt1, prompt2, images_method1, images_method2, like_image1, dislike_image1, like_image2, dislike_image2, history_images1, history_images2, example1.dataset, example2.dataset, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2],
    )
    # Generate streams images into the tab's gallery, then locks the Generate
    # button and the prompt until the next Redesign unlocks them.
    next_btn1.click(
        fn=generate_image,
        inputs=[participant, scenario, prompt1, active_tab, like_image1, dislike_image1],
        outputs=[images_method1],
    ).success(lambda: [gr.update(interactive=False), gr.update(interactive=False)], outputs=[next_btn1, prompt1])
    next_btn2.click(
        fn=generate_image,
        inputs=[participant, scenario, prompt2, active_tab, like_image2, dislike_image2],
        outputs=[images_method2],
    ).success(lambda: [gr.update(interactive=False), gr.update(interactive=False)], outputs=[next_btn2, prompt2])
    sim_radio1.change(fn=check_satisfaction, inputs=[sim_radio1, active_tab], outputs=[submit_btn1])
    sim_radio2.change(fn=check_satisfaction, inputs=[sim_radio2, active_tab], outputs=[submit_btn2])
    # Choosing a like/dislike option copies that gallery image into the matching image slot.
    dislike_radio1.select(fn=select_image, inputs=[dislike_radio1, images_method1], outputs=[dislike_image1])
    like_radio1.select(fn=select_image, inputs=[like_radio1, images_method1], outputs=[like_image1])
    dislike_radio2.select(fn=select_image, inputs=[dislike_radio2, images_method2], outputs=[dislike_image2])
    like_radio2.select(fn=select_image, inputs=[like_radio2, images_method2], outputs=[like_image2])
    redesign_btn1.click(
        fn=redesign,
        inputs=[participant, scenario, prompt1, sim_radio1, like_radio1, dislike_radio1, images_method1, history_images1, active_tab, like_image1, dislike_image1],
        outputs=[sim_radio1, dislike_radio1, like_radio1, images_method1, history_images1, example1.dataset, prompt1, next_btn1, redesign_btn1, submit_btn1],
    )
    redesign_btn2.click(
        fn=redesign,
        inputs=[participant, scenario, prompt2, sim_radio2, like_radio2, dislike_radio2, images_method2, history_images2, active_tab, like_image2, dislike_image2],
        outputs=[sim_radio2, dislike_radio2, like_radio2, images_method2, history_images2, example2.dataset, prompt2, next_btn2, redesign_btn2, submit_btn2],
    )
    submit_btn1.click(
        fn=save_response,
        inputs=[participant, scenario, prompt1, sim_radio1, like_radio1, dislike_radio1, like_image1, dislike_image1, active_tab],
        outputs=[sim_radio1, dislike_radio1, like_radio1, prompt1, next_btn1, redesign_btn1, submit_btn1, tabs],
    )
    submit_btn2.click(
        fn=save_response,
        inputs=[participant, scenario, prompt2, sim_radio2, like_radio2, dislike_radio2, like_image2, dislike_image2, active_tab],
        outputs=[sim_radio2, dislike_radio2, like_radio2, prompt2, next_btn2, redesign_btn2, submit_btn2, tabs],
    )

if __name__ == "__main__":
    demo.launch()