update policy
Browse files- __pycache__/live_preview_helpers.cpython-310.pyc +0 -0
- __pycache__/optim_utils.cpython-310.pyc +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- app.py +15 -4
- utils.py +41 -33
__pycache__/live_preview_helpers.cpython-310.pyc
CHANGED
|
Binary files a/__pycache__/live_preview_helpers.cpython-310.pyc and b/__pycache__/live_preview_helpers.cpython-310.pyc differ
|
|
|
__pycache__/optim_utils.cpython-310.pyc
CHANGED
|
Binary files a/__pycache__/optim_utils.cpython-310.pyc and b/__pycache__/optim_utils.cpython-310.pyc differ
|
|
|
__pycache__/utils.cpython-310.pyc
CHANGED
|
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
|
|
|
app.py
CHANGED
|
@@ -7,7 +7,7 @@ import torch
|
|
| 7 |
import re
|
| 8 |
import open_clip
|
| 9 |
from optim_utils import optimize_prompt
|
| 10 |
-
from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache, get_personalize_message
|
| 11 |
from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS
|
| 12 |
import spaces #[uncomment to use ZeroGPU]
|
| 13 |
import transformers
|
|
@@ -108,7 +108,6 @@ def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
|
|
| 108 |
client = init_gpt_api()
|
| 109 |
messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
|
| 110 |
outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
|
| 111 |
-
print(outputs)
|
| 112 |
# prompt_list = clean_response_gpt(outputs)
|
| 113 |
# print(prompt_list)
|
| 114 |
return outputs
|
|
@@ -203,6 +202,10 @@ def display_scenario(participant, choice):
|
|
| 203 |
prompt2: gr.update(value=PROMPTS.get(choice, ""), interactive=False),
|
| 204 |
images_method1: initial_images1,
|
| 205 |
images_method2: initial_images2,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
history_images1: [],
|
| 207 |
history_images2: [],
|
| 208 |
next_btn1: gr.update(interactive=False),
|
|
@@ -221,9 +224,16 @@ def generate_image(participant, scenario, prompt, active_tab, like_image, dislik
|
|
| 221 |
|
| 222 |
history_prompts = [v["prompt"] for v in responses_memory[participant][method].values()]
|
| 223 |
feedback = [v["sim_radio"] for v in responses_memory[participant][method].values()]
|
| 224 |
-
|
| 225 |
personalized_prompt = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
gallery_images = []
|
| 228 |
if method == METHODS[0]:
|
| 229 |
for i in range(NUM_IMAGES):
|
|
@@ -491,7 +501,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
|
|
| 491 |
participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
|
| 492 |
scenario.change(display_scenario,
|
| 493 |
inputs=[participant, scenario],
|
| 494 |
-
outputs=[scenario_content, prompt1, prompt2, images_method1, images_method2, history_images1, history_images2, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])
|
|
|
|
| 495 |
# prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
|
| 496 |
# prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
|
| 497 |
next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, active_tab, like_image1, dislike_image1], outputs=[images_method1])
|
|
|
|
| 7 |
import re
|
| 8 |
import open_clip
|
| 9 |
from optim_utils import optimize_prompt
|
| 10 |
+
from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache, get_personalize_message, clean_refined_prompt_response_gpt
|
| 11 |
from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS
|
| 12 |
import spaces #[uncomment to use ZeroGPU]
|
| 13 |
import transformers
|
|
|
|
| 108 |
client = init_gpt_api()
|
| 109 |
messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
|
| 110 |
outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
|
|
|
|
| 111 |
# prompt_list = clean_response_gpt(outputs)
|
| 112 |
# print(prompt_list)
|
| 113 |
return outputs
|
|
|
|
| 202 |
prompt2: gr.update(value=PROMPTS.get(choice, ""), interactive=False),
|
| 203 |
images_method1: initial_images1,
|
| 204 |
images_method2: initial_images2,
|
| 205 |
+
like_image1: None,
|
| 206 |
+
dislike_image1: None,
|
| 207 |
+
like_image2: None,
|
| 208 |
+
dislike_image2: None,
|
| 209 |
history_images1: [],
|
| 210 |
history_images2: [],
|
| 211 |
next_btn1: gr.update(interactive=False),
|
|
|
|
| 224 |
|
| 225 |
history_prompts = [v["prompt"] for v in responses_memory[participant][method].values()]
|
| 226 |
feedback = [v["sim_radio"] for v in responses_memory[participant][method].values()]
|
| 227 |
+
|
| 228 |
personalized_prompt = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
|
| 229 |
|
| 230 |
+
personalized_prompt = clean_refined_prompt_response_gpt(personalized_prompt)
|
| 231 |
+
print(f"Personalized prompt: {personalized_prompt}, {type(personalized_prompt)}")
|
| 232 |
+
|
| 233 |
+
if "I'm sorry, I can't assist with" in personalized_prompt:
|
| 234 |
+
print("error in gpt...")
|
| 235 |
+
personalized_prompt = prompt
|
| 236 |
+
|
| 237 |
gallery_images = []
|
| 238 |
if method == METHODS[0]:
|
| 239 |
for i in range(NUM_IMAGES):
|
|
|
|
| 501 |
participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
|
| 502 |
scenario.change(display_scenario,
|
| 503 |
inputs=[participant, scenario],
|
| 504 |
+
outputs=[scenario_content, prompt1, prompt2, images_method1, images_method2, like_image1, dislike_image1, like_image2, dislike_image2, history_images1, history_images2, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])
|
| 505 |
+
|
| 506 |
# prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
|
| 507 |
# prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
|
| 508 |
next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, active_tab, like_image1, dislike_image1], outputs=[images_method1])
|
utils.py
CHANGED
|
@@ -78,6 +78,16 @@ def clean_response_gpt(res: str):
|
|
| 78 |
return prompts
|
| 79 |
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
def get_refine_msg(prompt, num_prompts):
|
| 82 |
messages = [{"role": "system", "content": f"You are a helpful, respectful and precise assistant. You will be asked to generate {num_prompts} refined prompts. Only respond with those refined prompts"}]
|
| 83 |
|
|
@@ -108,37 +118,37 @@ def encode_image(image):
|
|
| 108 |
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 109 |
|
| 110 |
def get_personalize_message(prompt, history_prompts, history_feedback, like_image, dislike_image):
|
| 111 |
-
messages = [
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
-
|
| 128 |
-
-
|
| 129 |
-
-
|
| 130 |
-
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
"""
|
| 135 |
-
|
|
|
|
| 136 |
for his_prompt, feedback in zip(history_prompts, history_feedback):
|
| 137 |
-
message += f"
|
|
|
|
|
|
|
| 138 |
|
| 139 |
-
message += f"\nWe also provide the user's preferred image during this process as the first image provided and the disliked image as the second image\n"
|
| 140 |
-
message += "Now, please optimize current prompt and only output the modified prompt: '{prompt}'"""
|
| 141 |
-
|
| 142 |
messages.append({
|
| 143 |
"role": "user",
|
| 144 |
"content": [
|
|
@@ -150,7 +160,7 @@ def get_personalize_message(prompt, history_prompts, history_feedback, like_imag
|
|
| 150 |
messages[-1]["content"].append({
|
| 151 |
"type": "image_url",
|
| 152 |
"image_url": {
|
| 153 |
-
"url": f"data:image/
|
| 154 |
},
|
| 155 |
})
|
| 156 |
if dislike_image:
|
|
@@ -158,11 +168,9 @@ def get_personalize_message(prompt, history_prompts, history_feedback, like_imag
|
|
| 158 |
messages[-1]["content"].append({
|
| 159 |
"type": "image_url",
|
| 160 |
"image_url": {
|
| 161 |
-
"url": f"data:image/
|
| 162 |
},
|
| 163 |
})
|
| 164 |
-
|
| 165 |
-
print(messages)
|
| 166 |
|
| 167 |
return messages
|
| 168 |
|
|
|
|
| 78 |
return prompts
|
| 79 |
|
| 80 |
|
| 81 |
+
def clean_refined_prompt_response_gpt(res: str):
|
| 82 |
+
# Using regex to extract the refined prompt
|
| 83 |
+
match = re.search(r"\*\*Refined Prompt:\*\*\n\n(.+)", res, re.DOTALL)
|
| 84 |
+
if match:
|
| 85 |
+
refined_prompt = match.group(1).strip()
|
| 86 |
+
else:
|
| 87 |
+
refined_prompt = res.strip() # Fallback: Use full text if no match found
|
| 88 |
+
return refined_prompt
|
| 89 |
+
|
| 90 |
+
|
| 91 |
def get_refine_msg(prompt, num_prompts):
|
| 92 |
messages = [{"role": "system", "content": f"You are a helpful, respectful and precise assistant. You will be asked to generate {num_prompts} refined prompts. Only respond with those refined prompts"}]
|
| 93 |
|
|
|
|
| 118 |
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 119 |
|
| 120 |
def get_personalize_message(prompt, history_prompts, history_feedback, like_image, dislike_image):
|
| 121 |
+
messages = [
|
| 122 |
+
{"role": "system", "content": f"You are a prompt refinement assistant. Your task is to improve a user’s text prompt based on their prompt revision history, satisfaction ratings, and preferences inferred from selected images. Your goal is to refine the prompt while maintaining its original meaning, improving clarity, specificity, and alignment with user preferences."}
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
+
message = f"""The refinement should preserve the core meaning of the current prompt while improving its clarity, specificity, and style based on user feedback.
|
| 126 |
+
|
| 127 |
+
### **Input Format:**
|
| 128 |
+
1. **Prompt History**: A list of previously revised prompts and their corresponding satisfaction ratings.
|
| 129 |
+
2. **Rating Scale**: Very Unsatisfied, Unsatisfied, Slightly Unsatisfied, Neutral, Slightly Satisfied, Satisfied, Very Satisfied
|
| 130 |
+
3. **User-Selected Image Preferences**:
|
| 131 |
+
- **Preferred Image**: The image the user found most satisfactory.
|
| 132 |
+
- **Disliked Image**: The image the user found least satisfactory.
|
| 133 |
+
*Note: These images are for reference only and should be used to infer stylistic preferences rather than directly modifying prompt content.*
|
| 134 |
+
4. **Current Prompt**: The latest prompt from the user, which requires refinement.
|
| 135 |
+
|
| 136 |
+
### **Refinement Guidelines:**
|
| 137 |
+
- Identify and retain/expand patterns/elements in past revisions and correlate them with satisfaction ratings.
|
| 138 |
+
- Avoid or adjust features that led to lower ratings.
|
| 139 |
+
- Improve clarity, specificity, and descriptive quality while ensuring the prompt remains faithful to its current prompt's meaning.
|
| 140 |
+
- The preferred image reflects desirable attributes; the disliked image indicates elements to avoid. Use these for reference but **do not describe them.**
|
| 141 |
+
- Output only the refined prompt, no explanations, disclaimers, or formatting.
|
| 142 |
+
|
| 143 |
+
The first provided image is the user's preferred image, and the second is the disliked image.
|
| 144 |
+
Now, refine the following current prompt based on the given user history and preferences:\n"""
|
| 145 |
+
|
| 146 |
+
message += "Prompt History\n"
|
| 147 |
for his_prompt, feedback in zip(history_prompts, history_feedback):
|
| 148 |
+
message += f"{his_prompt}: {feedback}\n"
|
| 149 |
+
|
| 150 |
+
message += f"Current Prompt: '{prompt}'\n Refined Prompt:"
|
| 151 |
|
|
|
|
|
|
|
|
|
|
| 152 |
messages.append({
|
| 153 |
"role": "user",
|
| 154 |
"content": [
|
|
|
|
| 160 |
messages[-1]["content"].append({
|
| 161 |
"type": "image_url",
|
| 162 |
"image_url": {
|
| 163 |
+
"url": f"data:image/png;base64,{like_image_base64}",
|
| 164 |
},
|
| 165 |
})
|
| 166 |
if dislike_image:
|
|
|
|
| 168 |
messages[-1]["content"].append({
|
| 169 |
"type": "image_url",
|
| 170 |
"image_url": {
|
| 171 |
+
"url": f"data:image/png;base64,{dislike_image_base64}",
|
| 172 |
},
|
| 173 |
})
|
|
|
|
|
|
|
| 174 |
|
| 175 |
return messages
|
| 176 |
|