POET / utils.py
xh365's picture
update refine policy
64dd181
raw
history blame
12.2 kB
import re
from diffusers import DiffusionPipeline, FluxPipeline
from live_preview_helpers import FLUXPipelineWithIntermediateOutputs
import torch
import os
from openai import OpenAI
import subprocess
import spaces #[uncomment to use ZeroGPU]
import base64
from io import BytesIO
# Selectable text-to-image models: UI display name -> repo id handed to setup_model().
T2I_MODELS = {
    "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
    "SDXL-Turbo": "stabilityai/sdxl-turbo",
    "Stable Diffusion v3.5-medium": "stabilityai/stable-diffusion-3.5-medium",  # Default
    "Flux.1-dev": "black-forest-labs/FLUX.1-dev",
}
# Background story shown to the participant for each study scenario (keyed by scenario name).
SCENARIOS = {
    "Product advertisement": "You are designing an advertising campaign for a new line of coffee machines. To ensure the campaign resonates with a wider audience, you use generative models to create marketing images that showcase a variety of users interacting with the product.",
    "Tourist promotion": "You are creating a travel campaign to attract a variety of visitors to a specific destination. To make the promotional materials more engaging, you use generative models to design posters that highlight a broader array of experiences.",
    "Fictional character generation": "You are creating a superhero video game that’s fun and relatable to a range of users. You decide to use generative models to help visualize a new character.",
    "Interior Design": "You are helping design the furniture layout for a model one-bedroom rental apartment. To make the apartment appealing to different potential tenants, you try to visualize different furniture placements before setting everything up.",
}
# Initial prompt offered for each scenario (same keys as SCENARIOS).
# NOTE: the trailing space in the "Fictional character generation" prompt is part of the value.
PROMPTS = {
    "Product advertisement": "Design an advertisement image showcasing a range of users operating coffee machines.",
    "Tourist promotion": "Design a promotional poster to attract a variety of visitors to a tourist destination.",
    "Fictional character generation": "Design a video game superhero character that is relatable. ",
    "Interior Design": "Design an apartment that’s appealing to potential tenants.",
}
# Pre-rendered example image paths per scenario, grouped under the "baseline" and "ours"
# variants. Every list follows images/scenario{N}_{base|ours}{1..4}.png, where N is the
# scenario's position in this dict.
IMAGES = {
    "Product advertisement": {
        "baseline": ["images/scenario1_base1.png", "images/scenario1_base2.png", "images/scenario1_base3.png", "images/scenario1_base4.png"],
        "ours": ["images/scenario1_ours1.png", "images/scenario1_ours2.png", "images/scenario1_ours3.png", "images/scenario1_ours4.png"],
    },
    "Tourist promotion": {
        "baseline": ["images/scenario2_base1.png", "images/scenario2_base2.png", "images/scenario2_base3.png", "images/scenario2_base4.png"],
        "ours": ["images/scenario2_ours1.png", "images/scenario2_ours2.png", "images/scenario2_ours3.png", "images/scenario2_ours4.png"],
    },
    "Fictional character generation": {
        "baseline": ["images/scenario3_base1.png", "images/scenario3_base2.png", "images/scenario3_base3.png", "images/scenario3_base4.png"],
        "ours": ["images/scenario3_ours1.png", "images/scenario3_ours2.png", "images/scenario3_ours3.png", "images/scenario3_ours4.png"],
    },
    "Interior Design": {
        # The third baseline entry used to point at scenario3_base4.png (a copy-paste slip
        # from the previous scenario); it now follows the scenario-4 naming pattern.
        "baseline": ["images/scenario4_base1.png", "images/scenario4_base2.png", "images/scenario4_base3.png", "images/scenario4_base4.png"],
        "ours": ["images/scenario4_ours1.png", "images/scenario4_ours2.png", "images/scenario4_ours3.png", "images/scenario4_ours4.png"],
    },
}
# Seven-point satisfaction scale, ordered from least to most satisfied.
OPTIONS = [
    "Very Unsatisfied",
    "Unsatisfied",
    "Slightly Unsatisfied",
    "Neutral",
    "Slightly Satisfied",
    "Satisfied",
    "Very Satisfied",
]
# Labels for picking one of the four images generated per round.
IMAGE_OPTIONS = [
    "First Image",
    "Second Image",
    "Third Image",
    "Fourth Image",
]
# Markdown/HTML instruction text displayed to participants (uses <br /> line breaks).
INSTRUCTION = "📌 **Instruction**: Now, we want to understand your satisfaction with the images generated. <br /> 📌 Step 1: You will start from evaluating the following images based on the given prompt. <br /> 📌 Step 2: Then please modify the prompt according to your expectations for the given scenario background, and answer the evaluation question **until you are satisfied** with at least one of the images generated below. If you are not satisfied with the generated images, you can repeatedly modify the prompts for at most **5 times**."
def clean_cache():
    """Wipe the ZeroGPU offload directory and release PyTorch's cached CUDA memory.

    The directory contents are removed with ``rm -rf`` through the shell (the shell
    expands the ``*`` glob). The command runs with an empty environment, and its exit
    status is not checked.
    """
    wipe_command = "rm -rf /data-nvme/zerogpu-offload/*"
    subprocess.run(wipe_command, env={}, shell=True)
    if not torch.cuda.is_available():
        return
    # Only touch the CUDA caching allocator when a CUDA device is available.
    torch.cuda.empty_cache()
def setup_model(t2i_model_repo, torch_dtype, device):
    """Load the text-to-image pipeline for ``t2i_model_repo`` and move it to ``device``.

    Args:
        t2i_model_repo: Repo id of the model (one of the values in ``T2I_MODELS``).
        torch_dtype: dtype forwarded to ``from_pretrained``.
        device: Target passed to the pipeline's ``.to()``.

    Returns:
        The loaded pipeline.

    Raises:
        ValueError: If ``t2i_model_repo`` is not a supported model. Previously an unknown
            repo fell through both branches and crashed with UnboundLocalError on ``pipe``.
    """
    generic_repos = {
        "stabilityai/sdxl-turbo",
        "stabilityai/stable-diffusion-3.5-medium",
        "stabilityai/stable-diffusion-2-1",
    }
    if t2i_model_repo in generic_repos:
        pipeline_cls = DiffusionPipeline
    elif t2i_model_repo == "black-forest-labs/FLUX.1-dev":
        # FLUX is loaded through the project's live-preview pipeline class
        # (live_preview_helpers) instead of the stock diffusers pipeline.
        pipeline_cls = FLUXPipelineWithIntermediateOutputs
    else:
        raise ValueError(f"Unsupported text-to-image model repo: {t2i_model_repo!r}")
    pipe = pipeline_cls.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
    torch.cuda.empty_cache()
    return pipe
def init_gpt_api():
    """Create an OpenAI client using the key from the OPENAI_API_KEY environment variable.

    If the variable is unset, ``api_key=None`` is passed through to the client unchanged.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    return OpenAI(api_key=api_key)
def call_gpt_api(messages, client, model, seed, max_tokens, temperature, top_p):
    """Send ``messages`` to the chat-completions endpoint and return the first reply's text.

    All sampling arguments are forwarded unchanged to ``client.chat.completions.create``.
    """
    request = dict(
        model=model,
        messages=messages,
        seed=seed,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    response = client.chat.completions.create(**request)
    first_choice = response.choices[0]
    return first_choice.message.content
# Matches a numbered item ("3. text" or '3. "text"') and captures the text up to the end
# of its line, without the optional surrounding double quotes.
_NUMBERED_ITEM_RE = re.compile(r'\d+\.\s"?(.*?)"?(?=\n|$)')


def clean_response_gpt(res: str):
    """Extract the items of a numbered list from an LLM reply.

    Args:
        res: Raw reply text, expected to contain lines like ``1. "some prompt"``.

    Returns:
        The captured item texts in order of appearance (empty list if none match).
    """
    return [match.group(1) for match in _NUMBERED_ITEM_RE.finditer(res)]
def get_refine_msg(prompt, num_prompts):
    """Build chat messages asking an LLM for ``num_prompts`` refined variants of ``prompt``.

    Returns:
        A two-element list: a system message restricting the reply to the refined prompts,
        then a user message with the refinement rules, a worked example, and the request.
    """
    system_text = f"You are a helpful, respectful and precise assistant. You will be asked to generate {num_prompts} refined prompts. Only respond with those refined prompts"
    # The continuation lines of this literal are deliberately flush-left: any indentation
    # would become part of the prompt text sent to the model.
    user_text = f"""Given a prompt, modify the prompt for me to explore variations in subject attributes, actions, and contextual details, while retaining the semantic consistency of the original description.
Follow the following refinement instruction:
1. Subject: refine broad terms into specific subsets, focusing on but not restricted on ethinity, gender, age of human.
2. Object: modify the brand, color of object(s) only if it's not specified in the prompt.
3. Setting: add details to the background environment, such as change of temporal or spatial details (e.g., day to night, indoor to outdoor).
4. Action: add more details to the action or specify the object or goal of the action.
For example, given this prompt: a person is drinking a coffee in a coffee shop, the refined prompts could be:
'an elderly woman is drinking a coffee in a coffee shop' (subject adjective)
'an asian young woman is drinking a coffee in a coffee shop' (subject adjective)
'a young woman is drinking a hot coffee with her left hand in a coffee shop' (action details)
'a woman is drinking a coffee in an outdoor coffee shop in the garden' (setting details)
If there is no human in the sentence, you do not need to add person intentionally.
If you use adjectives, they should be visual. So don't use something like 'interesting'. Please also vary the number of modifications but do not change the number of subjects/objects that have been specified in the prompt. Remember not to change the predefined concepts that have been specified in the prompt. e.g. don't change a boy to several boys.
Can you give me {num_prompts} modified prompts for the prompt '{prompt}' please."""
    return [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
def encode_image(image):
    """Serialize ``image`` as PNG and return the bytes base64-encoded as a str.

    ``image`` only needs a ``save(fp, format=...)`` method, as PIL images provide.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
def get_personalize_message(prompt, history_prompts, history_feedback, like_image, dislike_image):
    """Build multimodal chat messages asking an LLM to personalize ``prompt`` from feedback.

    Args:
        prompt: The current prompt to optimize; it is quoted at the end of the user text.
        history_prompts: The user's revised prompts from earlier rounds.
        history_feedback: The rating for each entry of ``history_prompts``. The two are
            paired positionally with ``zip``, so surplus entries in the longer one are ignored.
        like_image: Image the user preferred; attached first when truthy.
        dislike_image: Image the user disliked; attached after ``like_image`` when truthy.

    Returns:
        A two-element list: a system message and a user message whose ``content`` is a
        list holding one text part followed by zero, one or two ``image_url`` parts.
    """
    messages = [{"role": "system", "content": f"You will act as a prompt optimization assistant that helps refine an original prompt based on user feedback over multiple rounds of image generation. The goal is to dynamically adjust the prompt to better align with user preferences while preserving the original intent."}]
    message = f"""The process consists of a maximum of 5 rounds.
Users start with an initial prompt and generate 4 images per round. After reviewing the images, users will modify the prompt based on their preferences. Then we will generate new images based on the modified prompt.
Users will rate the generated images on a scale from ["Very Unsatisfied", "Unsatisfied", "Slightly Unsatisfied", "Neutral", "Slightly Satisfied", "Satisfied", "Very Satisfied"], indicating how satisfied they are with the results.
Your task is to analyze the sequence of modified prompts and corresponding ratings to refine the prompt dynamically, ensuring improved results in the next rounds. For each new round, you should:
Incorporate the user's modifications: Use the latest user-revised prompt as a reference but retain essential details from previous rounds if they contributed positively.
Analyze user ratings:
If the rating is high ("Satisfied", "Very Satisfied") → Maintain key aspects of the most recent prompt since it aligns well with user preferences.
If the rating is medium ("Slightly Unsatisfied", "Neutral", "Slightly Satisfied") → Adjust minor details that could improve alignment with the user’s preferences, considering the changes from previous rounds.
If the rating is low ("Very Unsatisfied", "Unsatisfied") → Identify aspects that might be causing dissatisfaction (e.g., unwanted elements, style mismatch) and rework the prompt while keeping the user’s core intent intact.
Refine the prompt intelligently and ensure the following:
- The updated prompt reflects user feedback without unnecessary repetition.
- Unwanted elements (if any) from previous rounds are removed.
- Preferred elements are retained and enhanced.
- The modifications remain subtle but progressive to ensure smooth refinement over multiple rounds.
- Maintain coherence: Avoid drastic changes that might deviate from the original intent unless the user explicitly requests them.
Now given the following revised prompts and ratings from user\n:
"""
    for his_prompt, feedback in zip(history_prompts, history_feedback):
        message += f"Revised prompt: {his_prompt}; Rating: {feedback}\n"
    message += f"\nWe also provide the user's preferred image during this process as the first image provided and the disliked image as the second image\n"
    # Bug fix: this line used to be a plain (non-f) string, so the model received the
    # literal text "{prompt}" and the prompt to optimize was never included.
    message += f"Now, please optimize current prompt and only output the modified prompt: '{prompt}'"
    messages.append({
        "role": "user",
        "content": [
            {"type": "text", "text": message},
        ],
    })
    # Attach the liked image first, then the disliked one, matching the text above.
    for image in (like_image, dislike_image):
        if image:
            image_base64 = encode_image(image)
            messages[-1]["content"].append({
                "type": "image_url",
                "image_url": {
                    # encode_image serializes as PNG, so label the data URL accordingly
                    # (it was previously mislabelled as image/jpeg).
                    "url": f"data:image/png;base64,{image_base64}",
                },
            })
    print(messages)
    return messages
@spaces.GPU
def call_llm_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask a locally loaded chat LLM for ``num_prompts`` refined variants of ``prompt``.

    The text-generation pipeline is created on the first call and kept in the module-level
    ``llm_pipe`` global for reuse.

    NOTE(review): ``default_llm_model``, ``torch_dtype``, ``transformers`` and the initial
    module-level ``llm_pipe`` are not defined or imported in the visible part of this
    module. Unless they are provided elsewhere, this function raises NameError — TODO
    confirm where they are meant to come from.

    Args:
        prompt: Prompt to refine.
        num_prompts: Number of refined prompts to request.
        max_tokens: Forwarded as ``max_new_tokens`` to the pipeline.
        temperature: Sampling temperature (sampling is always enabled).
        top_p: Nucleus-sampling threshold.

    Returns:
        The refined prompts parsed from the reply by ``clean_response_gpt``.
    """
    print(f"loading {default_llm_model}")
    global llm_pipe
    if not llm_pipe:
        llm_pipe = transformers.pipeline("text-generation", model=default_llm_model, model_kwargs={"torch_dtype": torch_dtype}, device_map="auto")
    # Bug fix: this previously passed the undefined name ``prmpt``, which always raised NameError.
    messages = get_refine_msg(prompt, num_prompts)
    # Stop on either the tokenizer's EOS id or "<|eot_id|>" — presumably the end-of-turn
    # token of a Llama-3-style chat template; confirm for the configured model.
    terminators = [
        llm_pipe.tokenizer.eos_token_id,
        llm_pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    outputs = llm_pipe(
        messages,
        max_new_tokens=max_tokens,
        eos_token_id=terminators,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # The last entry of "generated_text" is read as the assistant reply; this assumes the
    # chat-input pipeline returns the whole conversation — confirm for the transformers version.
    prompt_list = clean_response_gpt(outputs[0]["generated_text"][-1]["content"])
    return prompt_list