POET / app.py
xh365's picture
center aligned
2e7f19e
raw
history blame
24.6 kB
import gradio as gr
from gradio.themes.base import Base
import numpy as np
import random
import spaces
import torch
import re
import open_clip
from optim_utils import optimize_prompt
from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache
from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS
import spaces #[uncomment to use ZeroGPU]
import transformers
import gspread
import asyncio
from datetime import datetime
# ---- Model configuration ----
# open_clip architecture and pretrained tag for the (optional) CLIP-based prompt inversion/evaluation.
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"
default_t2i_model = "black-forest-labs/FLUX.1-dev" # text-to-image model passed to setup_model below
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # alternative: "meta-llama/Meta-Llama-3-8B-Instruct"
# Largest int32; upper bound for randomly drawn generation seeds.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# Images generated per round in each task tab.
NUM_IMAGES=4
# Use the GPU with fp16 when available, otherwise the CPU with fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Load the text-to-image pipeline once, at import time.
clean_cache()
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
# NOTE: CLIP loading is disabled, so `clip_model` / `preprocess` do not exist at import time.
# clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
# LLM text-generation pipeline; created lazily by call_llm_refine_prompt on first use.
llm_pipe = None
torch.cuda.empty_cache()
# Module-level placeholder; invert_prompt never assigns to this global.
inverted_prompt = ""
# Default text of each tab's response box.
VERBAL_MSG = "Please verbally describe key differences found in the image pair."
DEFAULT_SCENARIO = "Product advertisement"
# The two compared methods: in generate_image, METHODS[0] renders the user's prompt
# as-is and METHODS[1] renders prompts produced by refine_prompt.
METHODS = ["Method 1", "Method 2"]
# Round limit per task: once reached, Redesign is disabled and Submit is enabled.
MAX_ROUND = 5
# ---- Interim study state ----
# NOTE(review): these are module-level globals, so presumably shared by every
# concurrent session of the app; confirm the deployment is single-user.
# Round counters: counter1 tracks METHODS[0], counter2 tracks METHODS[1].
counter1, counter2 = 1, 1
# participant -> method -> round number -> {"prompt", "sim_radio", "response"}
responses_memory = {}
# Scenarios assigned to the current participant (defaults to the first two).
assigned_scenarios = list(SCENARIOS.keys())[:2]
current_task1, current_task2 = METHODS # methods shown in Task A / Task B (reshuffled per scenario)
# Whether Task A / Task B has been submitted for the current scenario.
task1_success, task2_success = False, False
########################################################################################################
# Generating images with two methods
########################################################################################################
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Render one image for ``prompt`` with the module-level ``selected_pipe``.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Forwarded unchanged to the pipeline.
        seed: Seed for the torch generator; ignored when ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed from ``[0, MAX_SEED]``.
        width, height: Requested output size, forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        num_inference_steps: Forwarded to the pipeline.
        progress: Gradio progress tracker (tracks tqdm bars); not used directly.

    Returns:
        The first element of the pipeline output's ``images``.
    """
    chosen_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    # Default (CPU) generator, seeded for reproducibility of a given seed.
    rng = torch.Generator().manual_seed(chosen_seed)
    pipe_args = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "width": width,
        "height": height,
        "generator": rng,
    }
    with torch.no_grad():
        output = selected_pipe(**pipe_args)
    return output.images[0]
async def infer_async(prompt):
    """Coroutine wrapper around ``infer`` using its default settings.

    ``infer`` is called synchronously, so awaiting this coroutine blocks the
    event loop until the image is ready.
    """
    image = infer(prompt)
    return image
# Generate one image per prompt via infer_async (see the docstring on concurrency).
async def generate_batch(prompts):
    """Generate one image per entry of ``prompts`` and return them in input order.

    NOTE: ``infer_async`` (as written) calls the blocking ``infer`` without
    ever yielding to the event loop, so ``asyncio.gather`` ends up running the
    coroutines one after another, not in parallel.
    ``asyncio.gather`` still returns the results in the same order as ``prompts``.

    Args:
        prompts: Iterable of text prompts.

    Returns:
        List of generated images, aligned with ``prompts``.
    """
    pending = [infer_async(p) for p in prompts]
    images = await asyncio.gather(*pending)
    return images
@spaces.GPU
def call_llm_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask the local LLM (``default_llm_model``) for refined versions of ``prompt``.

    The transformers text-generation pipeline is created on the first call and
    cached in the module-level ``llm_pipe``.

    Args:
        prompt: The user prompt to refine.
        num_prompts: Number of refined prompts requested through ``get_refine_msg``.
        max_tokens: Passed as ``max_new_tokens`` to generation.
        temperature, top_p: Sampling parameters (sampling is always enabled).

    Returns:
        ``clean_response_gpt`` applied to the content of the last generated
        chat message.
    """
    global llm_pipe
    if llm_pipe is None:
        print(f"loading {default_llm_model}")
        llm_pipe = transformers.pipeline(
            "text-generation",
            model=default_llm_model,
            model_kwargs={"torch_dtype": torch_dtype},
            device_map="auto",
        )
    messages = get_refine_msg(prompt, num_prompts)
    # Stop on the tokenizer's EOS and, if the vocabulary has it, Llama-3's
    # "<|eot_id|>". Either lookup may come back as None (e.g. the token is not
    # in the vocabulary and there is no unk token); a None entry in
    # eos_token_id would likely break generation, so drop missing ids.
    candidate_ids = (
        llm_pipe.tokenizer.eos_token_id,
        llm_pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    )
    terminators = [token_id for token_id in candidate_ids if token_id is not None]
    outputs = llm_pipe(
        messages,
        max_new_tokens=max_tokens,
        eos_token_id=terminators or None,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    return clean_response_gpt(outputs[0]["generated_text"][-1]["content"])
def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask GPT-4o for ``num_prompts`` refined variants of ``prompt``.

    A fresh random seed in ``[0, MAX_SEED]`` is drawn for every request.

    Returns:
        ``clean_response_gpt`` applied to the raw API reply.
    """
    request_seed = random.randint(0, MAX_SEED)
    gpt_client = init_gpt_api()
    raw_reply = call_gpt_api(
        get_refine_msg(prompt, num_prompts),
        gpt_client,
        "gpt-4o",
        request_seed,
        max_tokens,
        temperature,
        top_p,
    )
    return clean_response_gpt(raw_reply)
def refine_prompt(gallery_state, prompt):
    """Return GPT-refined variants of ``prompt`` (see ``call_gpt_refine_prompt``).

    ``gallery_state`` is accepted for call-site compatibility but is not used.
    """
    return call_gpt_refine_prompt(prompt)
# eval(prompt, inverted_prompt, gallery_state, clip_model, preprocess)
@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2):
    """Optimize a text prompt against ``images`` with CLIP and return the result.

    The CLIP model is loaded on first use (see ``_get_clip_model``) because it
    is not created at import time.

    Args:
        prompt: Target prompt(s), forwarded as ``target_prompts``.
        images: Target images, forwarded as ``target_images``.
        prompt_len: Length setting passed to ``optimize_prompt``.
        iter: Number of optimization iterations (name kept for compatibility;
            it shadows the builtin inside this function only).
        lr: Learning rate.
        batch_size: Batch size used by the optimizer.

    Returns:
        Whatever ``optimize_prompt`` returns; presumably the learned prompt
        string (TODO confirm against optim_utils).
    """
    model, preprocess_fn = _get_clip_model()
    text_params = {
        "iter": iter,
        "lr": lr,
        "batch_size": batch_size,
        "prompt_len": prompt_len,
        "weight_decay": 0.1,
        "prompt_bs": 1,
        "loss_weight": 1.0,
        "print_step": 100,
        "clip_model": CLIP_MODEL,
        "clip_pretrain": PRETRAINED_CLIP,
    }
    return optimize_prompt(model, preprocess_fn, text_params, device, target_images=images, target_prompts=prompt)


def _get_clip_model():
    """Return the module-level CLIP ``(clip_model, preprocess)`` pair, creating it on first call."""
    global clip_model, preprocess
    # The names may not exist at module level yet, so read them via globals().
    if globals().get("clip_model") is None or globals().get("preprocess") is None:
        clip_model, _, preprocess = open_clip.create_model_and_transforms(
            CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device
        )
    return clip_model, preprocess
def eval(prompt, optimized_prompt, optimized_images, clip_model, preprocess):
    """Compare how well the original and optimized prompts match the images under CLIP.

    NOTE: this function's name shadows the builtin ``eval`` within this module.

    Args:
        prompt: Original text prompt.
        optimized_prompt: Prompt produced by prompt inversion.
        optimized_images: Non-empty sequence of images accepted by ``preprocess``.
        clip_model: open_clip model exposing ``encode_image`` / ``encode_text``.
        preprocess: Image transform paired with ``clip_model``.

    Returns:
        ``(similarity, similarity_optimized)``: numpy arrays of shape
        ``(1, len(optimized_images))`` holding the cosine similarity between
        each prompt and every image.

    Raises:
        ValueError: if ``optimized_images`` is empty.
    """
    torch.cuda.empty_cache()
    tokenizer = open_clip.get_tokenizer(CLIP_MODEL)
    batches = [preprocess(img).unsqueeze(0) for img in optimized_images]
    if not batches:
        raise ValueError("optimized_images must contain at least one image")
    images = torch.cat(batches).to(device)
    with torch.no_grad():
        image_feat = clip_model.encode_image(images)
        text_feat = clip_model.encode_text(tokenizer([prompt]).to(device))
        optimized_text_feat = clip_model.encode_text(tokenizer([optimized_prompt]).to(device))
        # L2-normalize so the dot products below are cosine similarities.
        image_feat /= image_feat.norm(dim=-1, keepdim=True)
        text_feat /= text_feat.norm(dim=-1, keepdim=True)
        optimized_text_feat /= optimized_text_feat.norm(dim=-1, keepdim=True)
    image_np = image_feat.cpu().numpy()
    similarity = text_feat.cpu().numpy() @ image_np.T
    similarity_optimized = optimized_text_feat.cpu().numpy() @ image_np.T
    return similarity, similarity_optimized
########################################################################################################
# Button-related functions
########################################################################################################
def reset_gallery():
    """Return a brand-new empty gallery state (a fresh list on every call)."""
    empty_state = list()
    return empty_state
def display_error_message(msg, duration=5):
    """Show ``msg`` to the user as a Gradio warning for ``duration`` seconds."""
    toast_options = {"duration": duration}
    gr.Warning(msg, **toast_options)
def display_info_message(msg, duration=5):
    """Show ``msg`` to the user as a Gradio info message for ``duration`` seconds."""
    toast_options = {"duration": duration}
    gr.Info(msg, **toast_options)
def switch_tab(active_tab):
    """Return a ``gr.Tabs`` update that selects the other task tab.

    "Task A" switches to "Task B"; any other value switches to "Task A".
    """
    print("switching tab")
    target_tab = "Task B" if active_tab == "Task A" else "Task A"
    return gr.Tabs(selected=target_tab)
def set_user(participant):
    """Register ``participant`` and choose which scenarios they will see.

    Resets the participant's stored responses for both methods. IDs with no
    number, or whose first number is even, get the first two scenarios; IDs
    whose first number is odd get the remaining scenarios. The choice is
    stored in the module-level ``assigned_scenarios`` so that
    ``save_response`` advances through the same set.

    Returns:
        The first assigned scenario name (written to the scenario dropdown).
    """
    global responses_memory, assigned_scenarios
    responses_memory[participant] = {METHODS[0]: {}, METHODS[1]: {}}
    digits = re.findall(r'\d+', participant)
    scenario_names = list(SCENARIOS.keys())
    if not digits or int(digits[0]) % 2 == 0:
        # No number in the ID, or an even one: first two scenarios.
        assigned_scenarios = scenario_names[:2]
    else:
        assigned_scenarios = scenario_names[2:]
    return assigned_scenarios[0]
def display_scenario(participant, choice):
    """Load scenario ``choice`` and reset both task tabs to their starting state.

    Side effects: round counters go back to 1, both task-success flags are
    cleared, the participant's stored responses are reset when
    ``check_participant`` accepts the ID, and the method shown in each tab is
    re-drawn at random.

    Returns:
        A dict mapping every ``scenario.change`` output component to its new value.
    """
    global counter1, counter2, responses_memory, current_task1, current_task2, task1_success, task2_success
    task1_success = task2_success = False
    counter1 = counter2 = 1
    if check_participant(participant):
        responses_memory[participant] = {METHODS[0]: {}, METHODS[1]: {}}
    # Randomly decide which method each tab shows for this scenario.
    current_task1, current_task2 = random.sample(METHODS, 2)

    scenario_images = IMAGES[choice]
    if current_task1 != METHODS[0]:
        initial_images1, initial_images2 = scenario_images["ours"], scenario_images["baseline"]
    else:
        initial_images1, initial_images2 = scenario_images["baseline"], scenario_images["ours"]

    updates = {
        scenario_content: SCENARIOS.get(choice, ""),
        prompt: PROMPTS.get(choice, ""),
    }
    per_tab = (
        (prompt1, images_method1, gallery_state1, sim_radio1, response1, next_btn1, redesign_btn1, submit_btn1, initial_images1),
        (prompt2, images_method2, gallery_state2, sim_radio2, response2, next_btn2, redesign_btn2, submit_btn2, initial_images2),
    )
    for prompt_box, gallery, state, rating, reply, next_button, redesign_button, submit_button, images in per_tab:
        # Each tab starts with its method's stock images, no rating, and only
        # the Redesign button active. Each update dict is created separately.
        updates[prompt_box] = ""
        updates[gallery] = images
        updates[state] = images
        updates[rating] = None
        updates[reply] = VERBAL_MSG
        updates[next_button] = gr.update(interactive=False)
        updates[redesign_button] = gr.update(interactive=True)
        updates[submit_button] = gr.update(interactive=False)
    return updates
def generate_image(participant, scenario, prompt, gallery_state, active_tab):
    """Generate ``NUM_IMAGES`` images for the active tab, streaming the gallery.

    METHODS[0] renders ``prompt`` unchanged ``NUM_IMAGES`` times. The other
    method asks ``refine_prompt`` for variants and renders one image per
    variant. If fewer variants come back than needed, the user's ``prompt``
    fills the missing slots instead of raising IndexError.

    ``gallery_state`` is appended to in place.

    Yields:
        The growing ``gallery_state`` list after each new image. Nothing is
        yielded when the participant ID is missing.
    """
    if not check_participant(participant):
        # Generator: a bare return ends it without emitting any update.
        return
    method = current_task1 if active_tab == "Task A" else current_task2
    if method == METHODS[0]:
        prompts = [prompt] * NUM_IMAGES
    else:
        refined = refine_prompt(gallery_state, prompt) or []
        prompts = [refined[i] if i < len(refined) else prompt for i in range(NUM_IMAGES)]
    for text in prompts:
        gallery_state.append(infer(text))
        yield gallery_state
def check_satisfaction(sim_radio, active_tab):
    """Toggle Submit vs. Redesign for the active tab based on its rating.

    Submit is enabled when the rating is "Satisfied" or "Very Satisfied", or
    when the tab's method has reached ``MAX_ROUND`` rounds. Redesign is enabled
    otherwise.

    Returns:
        ``(submit_update, redesign_update)`` with opposite ``interactive`` flags.
    """
    tab_method = current_task2 if active_tab != "Task A" else current_task1
    rounds = counter1 if tab_method == METHODS[0] else counter2
    satisfied = sim_radio in ("Satisfied", "Very Satisfied")
    can_submit = satisfied or rounds >= MAX_ROUND
    return gr.update(interactive=can_submit), gr.update(interactive=not can_submit)
def check_participant(participant):
    """Return True if ``participant`` is a usable ID; otherwise warn and return False.

    Empty, None, and whitespace-only IDs are rejected. This keeps blank IDs
    from being used as ``responses_memory`` keys or saved by ``save_response``.
    """
    if participant is None or not str(participant).strip():
        display_error_message("Please fill your participant id!")
        return False
    return True
def check_evaluation(sim_radio, response):
    """Return True when a satisfaction rating was chosen; otherwise warn and return False.

    ``response`` is accepted for call-site symmetry but is not validated.
    """
    if sim_radio:
        return True
    display_error_message("❌ Please fill all evaluations before change image or submit.")
    return False
def select_dislike(like_radio, images_method, options=None):
    """Return the gallery item matching the radio choice ``like_radio``.

    Used for both the "satisfied" and "unsatisfied" radios. The position of
    ``like_radio`` within ``options`` (defaults to ``IMAGE_OPTIONS``) selects
    the item at that same index in ``images_method``.

    Args:
        like_radio: The selected radio label (may be None).
        images_method: The tab's gallery state list (may be None or short).
        options: Radio labels in gallery order; defaults to ``IMAGE_OPTIONS``.

    Returns:
        The selected gallery item. Returns None when the label is not an
        option, or when the gallery has no item at that position (previously
        this raised IndexError).
    """
    if options is None:
        options = IMAGE_OPTIONS
    options = list(options)
    if images_method is None or like_radio not in options:
        return None
    index = options.index(like_radio)
    return images_method[index] if index < len(images_method) else None
def redesign(participant, scenario, prompt, sim_radio, response, images_method, active_tab):
    """Store the finished round for the active tab and open the next design round.

    On valid input, the method's round counter is incremented and the
    just-finished round's prompt, rating and response are stored under
    ``responses_memory[participant][method][counter - 1]``.

    Returns:
        On success, values for ``[gallery_state, sim_radio, response, prompt,
        next_btn, redesign_btn, submit_btn]``. Once the counter reaches
        ``MAX_ROUND``, Generate and Redesign are locked and Submit is enabled.
        On validation failure, a dict that skips the tab's submit button.

    NOTE(review): at MAX_ROUND the Generate button is disabled, so the last
    round cannot produce new images; confirm this is the intended flow.
    """
    global counter1, counter2, responses_memory, current_task1, current_task2
    method = current_task1 if active_tab == "Task A" else current_task2
    if not (check_evaluation(sim_radio, response) and check_participant(participant)):
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    if method == METHODS[0]:
        counter1 += 1
        finished_round = counter1 - 1
    else:
        counter2 += 1
        finished_round = counter2 - 1
    responses_memory[participant][method][finished_round] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": response,
    }

    prompt_state = gr.update(visible=True)
    if finished_round + 1 >= MAX_ROUND:
        next_state = gr.update(interactive=False)
        redesign_state = gr.update(interactive=False)
        submit_state = gr.update(interactive=True)
    else:
        next_state = gr.update(visible=True, interactive=True)
        redesign_state = gr.update(interactive=True)
        submit_state = gr.update(interactive=False)
    return [], None, VERBAL_MSG, prompt_state, next_state, redesign_state, submit_state
def show_message(selected_option):
    """Return the redesign hint when an option is selected, else an empty string."""
    hint = "Click \"Redesign\" and revise your prompt to create images that may more closely match your expectations."
    return hint if selected_option else ""
def save_response(participant, scenario, prompt, sim_radio, response, images_method, active_tab):
    """Record the final round of the active task, upload it, and move on.

    After validating the rating and participant ID, the current round is added
    to ``responses_memory``. Every stored round for this participant and method
    is then appended to the "DiverseGen-phase3" Google Sheet in one request.
    On success:
      * the method's round counter is reset to 1;
      * the active task is marked done;
      * the UI switches to the other tab;
      * once both tasks are done, the scenario advances to the next entry of
        ``assigned_scenarios``; until then it stays on ``scenario``.

    Returns:
        On success, values in the submit button's output order:
        ``[images, gallery_state, sim_radio, prompt, response, next_btn,
        redesign_btn, submit_btn, tabs, scenario]``. On validation or upload
        failure, a dict that skips the active tab's submit button.
    """
    global current_task1, current_task2, counter1, counter2, responses_memory, task1_success, task2_success, assigned_scenarios
    method = current_task1 if active_tab == "Task A" else current_task2
    if not (check_evaluation(sim_radio, response) and check_participant(participant)):
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    counter = counter1 if method == METHODS[0] else counter2
    rounds = responses_memory[participant][method]
    rounds[counter] = {"prompt": prompt, "sim_radio": sim_radio, "response": response}

    try:
        gc = gspread.service_account(filename='credentials.json')
        sheet = gc.open("DiverseGen-phase3").sheet1
        rows = [
            [participant, scenario, method, i, entry["prompt"], entry["sim_radio"], entry["response"]]
            for i, entry in rounds.items()
        ]
        # One request instead of one per round: fewer API calls, and a
        # failure can no longer leave only some rounds written.
        sheet.append_rows(rows)
    except Exception as e:  # UI boundary: report any upload failure to the participant
        display_error_message(f"❌ Error saving response: {str(e)}")
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

    display_info_message("✅ Your answer is saved!")
    # Reset this method's round counter and mark the active task as done.
    if method == METHODS[0]:
        counter1 = 1
    else:
        counter2 = 1
    if active_tab == "Task A":
        task1_success = True
    else:
        task2_success = True

    tab_update = switch_tab(active_tab)
    # Stay on the current scenario until both tasks are done, then advance
    # (or stay, if this is the last assigned scenario or it is not in the set).
    next_scenario = scenario
    if task1_success and task2_success and scenario in assigned_scenarios:
        position = assigned_scenarios.index(scenario)
        next_scenario = assigned_scenarios[min(position + 1, len(assigned_scenarios) - 1)]

    hide_prompt = gr.update(visible=False)
    hide_next = gr.update(visible=False, interactive=False)
    lock_redesign = gr.update(interactive=False)
    lock_submit = gr.update(interactive=False)
    return [], [], None, hide_prompt, VERBAL_MSG, hide_next, lock_redesign, lock_submit, tab_update, next_scenario
########################################################################################################
# Interface
########################################################################################################
# Layout CSS for the Blocks UI below: horizontally centred containers of
# different maximum widths, selected through each component's elem_id.
css="""
#col-container {
margin: 0 auto;
max-width: 700px;
}
#col-container2 {
margin: 0 auto;
max-width: 1000px;
}
#col-container3 {
margin: 0 auto;
max-width: 300px;
}
#button-container {
display: flex;
justify-content: center; /* Centers the buttons horizontally */
}
#compact-row {
width:100%;
max-width: 1000px;
margin: 0px auto;
}
"""
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
    # Shared header: participant ID, scenario picker (set by set_user), scenario text and base prompt.
    # NOTE(review): the original indentation was lost; this nesting is a best-effort reconstruction.
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # πŸ“Œ **Diverse Text-to-Image Generation**")
        with gr.Row():
            participant = gr.Textbox(
                label="πŸ§‘β€πŸ’Ό Participant ID", placeholder="Please enter you participant id"
            )
            scenario = gr.Dropdown(
                choices=list(SCENARIOS.keys()),
                # value=DEFAULT_SCENARIO,
                value=None,
                label="πŸ“Œ Scenario",
                interactive=False,
            )
        scenario_content = gr.Textbox(
            label="πŸ“– Background",
            interactive=False,
            # value=SCENARIOS[DEFAULT_SCENARIO]
        )
        prompt = gr.Textbox(
            label="🎨 Prompt",
            max_lines=1,
            # value=PROMPTS[DEFAULT_SCENARIO],
            interactive=False
        )
        # Label of the selected tab ("Task A"/"Task B"); handlers use it to
        # choose between current_task1 and current_task2.
        active_tab = gr.State("Task A")
        instruction = gr.Markdown(INSTRUCTION)
    with gr.Tabs() as tabs:
        # ---- Task A: shows the method stored in current_task1 ----
        with gr.TabItem("Task A", id="Task A") as task1_tab:
            # Keep active_tab in sync with the selected tab.
            task1_tab.select(lambda: "Task A", outputs=[active_tab])
            with gr.Column(elem_id="col-container"):
                # gr.Markdown("### Step 2: This is the prompt to generate images, you may modify the prompt after first round evaluation")
                # Revised-prompt box and Generate button; hidden until Redesign is clicked.
                with gr.Row():
                    prompt1 = gr.Textbox(
                        label="🎨 Revise Prompt",
                        max_lines=1,
                        placeholder="Enter your prompt",
                        # value=PROMPTS[DEFAULT_SCENARIO],
                        scale=4,
                        visible=False
                    )
                    next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
            with gr.Row(elem_id="compact-row"):
                with gr.Column(elem_id="col-container"):
                    # Images of the current round; generate_image appends to this list.
                    gallery_state1 = gr.State([])
                    images_method1 = gr.Gallery(show_label=False, columns=[4], rows=[1], height=420, elem_id="gallery")
                with gr.Column(elem_id="col-container3"):
                    # Previews of the images picked in the like/dislike radios.
                    like_image1 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload')
                    dislike_image1 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload')
            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### πŸ“ Evaluation")
                sim_radio1 = gr.Radio(
                    OPTIONS,
                    label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                like_radio1 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select the image you are most satisfied.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                dislike_radio1 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select the image you are most unsatisfied.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                response1 = gr.Textbox(
                    label="Verbally describe key differences found in the image pair.",
                    max_lines=1,
                    interactive=False,
                    container=False,
                    value=VERBAL_MSG
                )
                with gr.Row(elem_id="button-container"):
                    redesign_btn1 = gr.Button("🎨 Redesign", variant="primary", scale=0)
                    submit_btn1 = gr.Button("βœ… Submit", variant="primary", interactive=False, scale=0)
        # ---- Task B: mirror of Task A for the method stored in current_task2 ----
        with gr.TabItem("Task B", id="Task B") as task2_tab:
            task2_tab.select(lambda: "Task B", outputs=[active_tab])
            with gr.Column(elem_id="col-container"):
                # gr.Markdown("### Step 2: This is the prompt to generate images, you may modify the prompt after first round evaluation")
                with gr.Row():
                    prompt2 = gr.Textbox(
                        label="🎨 Revise Prompt",
                        max_lines=1,
                        placeholder="Enter your prompt",
                        # value=PROMPTS[DEFAULT_SCENARIO],
                        scale=4,
                        visible=False
                    )
                    next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
            with gr.Row(elem_id="compact-row"):
                with gr.Column(elem_id="col-container"):
                    # NOTE(review): unlike gallery_state1 this starts pre-filled; display_scenario overwrites both.
                    gallery_state2 = gr.State(IMAGES[DEFAULT_SCENARIO]["ours"])
                    images_method2 = gr.Gallery(height=420, show_label=False, columns=[4], rows=[1], elem_id="gallery")
                with gr.Column(elem_id="col-container3"):
                    like_image2 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload')
                    dislike_image2 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload')
            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### πŸ“ Evaluation")
                sim_radio2 = gr.Radio(
                    OPTIONS,
                    label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                like_radio2 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select the image you are most satisfied.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                dislike_radio2 = gr.Radio(
                    IMAGE_OPTIONS,
                    label="Select the image you are most unsatisfied.",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                response2 = gr.Textbox(
                    label="Verbally describe key differences found in the image pair.",
                    max_lines=1,
                    interactive=False,
                    container=False,
                    value=VERBAL_MSG
                )
                with gr.Row(elem_id="button-container"):
                    redesign_btn2 = gr.Button("🎨 Redesign", variant="primary", scale=0)
                    submit_btn2 = gr.Button("βœ… Submit", variant="primary", interactive=False, scale=0)
    ########################################################################################################
    # Button Function Setup
    ########################################################################################################
    # Entering an ID assigns the participant's scenario set and selects its first scenario.
    participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
    # A new scenario resets both tabs (texts, galleries, ratings, buttons).
    scenario.change(display_scenario, inputs=[participant, scenario], outputs=[scenario_content, prompt, prompt1, prompt2, images_method1, images_method2, gallery_state1, gallery_state2, sim_radio1, sim_radio2, response1, response2, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])
    # Editing a revised prompt clears that tab's accumulated gallery state.
    prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
    prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
    # Generate streams the new images into the tab's gallery.
    next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, gallery_state1, active_tab], outputs=[images_method1])
    next_btn2.click(fn=generate_image, inputs=[participant, scenario, prompt2, gallery_state2, active_tab], outputs=[images_method2])
    # A rating toggles which of Submit / Redesign is enabled.
    sim_radio1.change(fn=check_satisfaction, inputs=[sim_radio1, active_tab], outputs=[submit_btn1, redesign_btn1])
    sim_radio2.change(fn=check_satisfaction, inputs=[sim_radio2, active_tab], outputs=[submit_btn2, redesign_btn2])
    # Picking an image in either radio previews it (select_dislike serves both radios).
    dislike_radio1.select(fn=select_dislike, inputs=[dislike_radio1, gallery_state1], outputs=[dislike_image1])
    like_radio1.select(fn=select_dislike, inputs=[like_radio1, gallery_state1], outputs=[like_image1])
    dislike_radio2.select(fn=select_dislike, inputs=[dislike_radio2, gallery_state2], outputs=[dislike_image2])
    like_radio2.select(fn=select_dislike, inputs=[like_radio2, gallery_state2], outputs=[like_image2])
    # Redesign stores the finished round and reveals the revised-prompt box.
    redesign_btn1.click(
        fn=redesign,
        inputs=[participant, scenario, prompt1, sim_radio1, response1, images_method1, active_tab],
        outputs=[gallery_state1, sim_radio1, response1, prompt1, next_btn1, redesign_btn1, submit_btn1]
    )
    redesign_btn2.click(
        fn=redesign,
        inputs=[participant, scenario, prompt2, sim_radio2, response2, images_method2, active_tab],
        outputs=[gallery_state2, sim_radio2, response2, prompt2, next_btn2, redesign_btn2, submit_btn2]
    )
    # Submit saves all rounds of the tab's method. save_response's success
    # return must follow exactly this output order.
    submit_btn1.click(fn=save_response,
        inputs=[participant, scenario, prompt1, sim_radio1, response1, images_method1, active_tab],
        outputs=[images_method1, gallery_state1, sim_radio1, prompt1, response1, next_btn1, redesign_btn1, submit_btn1, tabs, scenario])
    submit_btn2.click(fn=save_response,
        inputs=[participant, scenario, prompt2, sim_radio2, response2, images_method2, active_tab],
        outputs=[images_method2, gallery_state2, sim_radio2, prompt2, response2, next_btn2, redesign_btn2, submit_btn2, tabs, scenario])
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()