xh365 commited on
Commit
185470f
·
verified ·
1 Parent(s): bdd3f68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -462
app.py CHANGED
@@ -1,60 +1,53 @@
 
1
  import gradio as gr
2
- from gradio.themes.base import Base
3
  import numpy as np
4
  import random
5
  import spaces
6
  import torch
7
  import re
8
- import open_clip
9
- from optim_utils import optimize_prompt
10
- from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache, get_personalize_message, clean_refined_prompt_response_gpt
11
- from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS
12
- import spaces #[uncomment to use ZeroGPU]
13
  import transformers
14
- import gspread
15
- from googleapiclient.discovery import build
16
- from googleapiclient.http import MediaFileUpload
17
- from googleapiclient.errors import HttpError
18
- from google.oauth2.service_account import Credentials
19
-
20
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  CLIP_MODEL = "ViT-H-14"
22
  PRETRAINED_CLIP = "laion2b_s32b_b79k"
23
- default_t2i_model = "black-forest-labs/FLUX.1-dev" # "black-forest-labs/FLUX.1-dev"
24
- default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # "meta-llama/Meta-Llama-3-8B-Instruct"
25
  MAX_SEED = np.iinfo(np.int32).max
26
  MAX_IMAGE_SIZE = 1024
27
- NUM_IMAGES=4
 
28
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
31
- clean_cache()
32
 
33
  selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
34
- # clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
35
  llm_pipe = None
36
  torch.cuda.empty_cache()
37
  inverted_prompt = ""
38
 
39
  VERBAL_MSG = "Please explain your rating of satisfaction in few words or sentences."
40
- DEFAULT_SCENARIO = "Product advertisement"
41
- METHODS = ["Baseline", "Experimental"]
42
- MAX_ROUND = 5
43
-
44
- counter1, counter2 = 1, 1
45
- responses_memory = {}
46
- assigned_scenarios = list(SCENARIOS.keys())[:2]
47
- current_task1, current_task2 = METHODS # current task 1 (tab 1)
48
- task1_success, task2_success = False, False
49
- enable_submit1, enable_submit2 = False, False
50
- scopes = ['https://www.googleapis.com/auth/spreadsheets', 'https://www.googleapis.com/auth/drive']
51
-
52
-
53
- ########################################################################################################
54
- # Generating images with two methods
55
- ########################################################################################################
56
 
 
 
 
 
57
 
 
 
 
58
  @spaces.GPU(duration=65)
59
  def infer(
60
  prompt,
@@ -92,38 +85,16 @@ def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0
92
  prompt_list = clean_response_gpt(outputs)
93
  return prompt_list
94
 
95
- @spaces.GPU(duration=100)
96
- def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2):
97
- text_params = {
98
- "iter": iter,
99
- "lr": lr,
100
- "batch_size": batch_size,
101
- "prompt_len": prompt_len,
102
- "weight_decay": 0.1,
103
- "prompt_bs": 1,
104
- "loss_weight": 1.0,
105
- "print_step": 100,
106
- "clip_model": CLIP_MODEL,
107
- "clip_pretrain": PRETRAINED_CLIP,
108
- }
109
- inverted_prompt = optimize_prompt(clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt)
110
-
111
- # eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
112
- # return learned_prompt
113
-
114
  def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
115
  seed = random.randint(0, MAX_SEED)
116
  client = init_gpt_api()
117
  messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
118
  outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
119
- # prompt_list = clean_response_gpt(outputs)
120
- # print(prompt_list)
121
  return outputs
122
 
123
- ########################################################################################################
124
- # Button-related functions
125
- ########################################################################################################
126
-
127
  def reset_gallery():
128
  return []
129
 
@@ -133,35 +104,12 @@ def display_error_message(msg, duration=5):
133
  def display_info_message(msg, duration=5):
134
  gr.Info(msg, duration=duration)
135
 
136
- def switch_tab(active_tab):
137
- if active_tab == "Task A":
138
- return gr.Tabs(selected="Task B")
139
- else:
140
- return gr.Tabs(selected="Task A")
141
-
142
- def check_satisfaction(sim_radio, active_tab):
143
- global enable_submit1, enable_submit2, counter1, counter2
144
- method = current_task1 if active_tab == "Task A" else current_task2
145
- enable_submit = enable_submit1 if method == METHODS[0] else enable_submit2
146
- counter = counter1 if method == METHODS[0] else counter2
147
-
148
- fully_satisfied_option = ["Satisfied", "Very Satisfied"] # The value to trigger submit
149
- if_submit = sim_radio in fully_satisfied_option or enable_submit or counter > MAX_ROUND
150
  return gr.update(interactive=if_submit)
151
 
152
- def check_participant(participant):
153
- if participant == "":
154
- display_error_message("Please fill your participant id!")
155
- return False
156
- return True
157
-
158
- def check_evaluation(sim_radio):
159
- if not sim_radio :
160
- display_error_message("❌ Please fill all evaluations before change image or submit.")
161
- return False
162
-
163
- return True
164
-
165
  def select_image(like_radio, images_method):
166
  if like_radio == IMAGE_OPTIONS[0]:
167
  return images_method[0][0]
@@ -174,249 +122,113 @@ def select_image(like_radio, images_method):
174
  else:
175
  return None
176
 
177
- def set_user(participant):
178
- global responses_memory, assigned_scenarios
179
-
180
- responses_memory[participant] = {METHODS[0]:{}, METHODS[1]:{}}
181
-
182
- # id = re.findall(r'\d+', participant)
183
- # if len(id) == 0 or int(id[0]) % 2 == 0: # name invalid, assign first half scenarios
184
- # assigned_scenarios = list(SCENARIOS.keys())[:2]
185
- # else:
186
- # assigned_scenarios = list(SCENARIOS.keys())[2:]
187
- # return assigned_scenarios[0]
188
-
189
- def assign_tasks(participant):
190
- id = re.findall(r'\d+', participant)
191
- if len(id) == 0 or int(id[0]) % 4 == 1 or int(id[0]) % 4 == 2:
192
- return METHODS[1], METHODS[0]
193
- else:
194
- return METHODS[0], METHODS[1]
195
-
196
- def display_scenario(participant, choice):
197
- # reset intermittent storage when scenario change
198
- global counter1, counter2, responses_memory, current_task1, current_task2, task1_success, task2_success, enable_submit1, enable_submit2
199
-
200
- task1_success, task2_success = False, False
201
- enable_submit1, enable_submit2 = False, False
202
- counter1, counter2 = 1, 1
203
-
204
- if check_participant(participant):
205
- responses_memory[participant] = {METHODS[0]:{}, METHODS[1]:{}}
206
-
207
- # [current_task1, current_task2] = random.sample(METHODS, 2)
208
- current_task1, current_task2 = assign_tasks(participant)
209
-
210
- if current_task1 == METHODS[0]:
211
- initial_images1 = IMAGES[choice]["baseline"]
212
- initial_images2 = IMAGES[choice]["ours"]
213
- else:
214
- initial_images1 = IMAGES[choice]["ours"]
215
- initial_images2 = IMAGES[choice]["baseline"]
216
-
217
- res = {
218
- scenario_content: SCENARIOS.get(choice, ""),
219
- prompt1: gr.update(value=PROMPTS.get(choice, ""), interactive=False),
220
- prompt2: gr.update(value=PROMPTS.get(choice, ""), interactive=False),
221
- images_method1: initial_images1,
222
- images_method2: initial_images2,
223
- like_image1: None,
224
- dislike_image1: None,
225
- like_image2: None,
226
- dislike_image2: None,
227
- history_images1: [],
228
- history_images2: [],
229
- example1.dataset: gr.update(samples=[], visible=False),
230
- example2.dataset: gr.update(samples=[], visible=False),
231
- next_btn1: gr.update(interactive=False),
232
- next_btn2: gr.update(interactive=False),
233
- redesign_btn1: gr.update(interactive=True),
234
- redesign_btn2: gr.update(interactive=True),
235
- submit_btn1: gr.update(interactive=False),
236
- submit_btn2: gr.update(interactive=False),
237
- }
238
- return res
239
-
240
- def generate_image(participant, scenario, prompt, active_tab, like_image, dislike_image):
241
- if not check_participant(participant): return [], []
242
- global current_task1, current_task2
243
- method = current_task1 if active_tab == "Task A" else current_task2
244
-
245
- history_prompts = [v["prompt"] for v in responses_memory[participant][method].values()]
246
- feedback = [v["sim_radio"] for v in responses_memory[participant][method].values()]
247
-
248
- personalized_prompt = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
249
-
250
- personalized_prompt = clean_refined_prompt_response_gpt(personalized_prompt)
251
- print(f"Personalized prompt: {personalized_prompt}, {type(personalized_prompt)}")
252
-
253
- if "I'm sorry, I can't assist with" in personalized_prompt:
254
- print("error in gpt...")
255
- personalized_prompt = prompt
256
-
257
- gallery_images = []
258
- if method == METHODS[0]:
259
- for i in range(NUM_IMAGES):
260
- img = infer(personalized_prompt)
261
- gallery_images.append(img)
262
- yield gallery_images
263
- else:
264
- refined_prompts = call_gpt_refine_prompt(personalized_prompt)
265
- for i in range(NUM_IMAGES):
266
- img = infer(refined_prompts[i])
267
- gallery_images.append(img)
268
- yield gallery_images
269
 
270
- def save_response_to_sheet(participant, method, scenario, active_tab, round, like_image, dislike_image):
 
 
 
271
  global responses_memory
272
- gc = gspread.service_account(filename='credentials.json')
273
- sheet = gc.open("DiverseGen-phase3").sheet1
274
-
275
- entry = responses_memory[participant][method][round]
276
- print(entry)
277
- sheet.append_row([participant, scenario, f"{active_tab}, {method}", round, entry["prompt"], entry["sim_radio"], entry["response"], entry["satisfied_img"], entry["unsatisfied_img"]])
278
-
279
- # save images in google drive
280
- creds = Credentials.from_service_account_file('credentials.json',scopes=scopes)
281
- save_image(creds, like_image, dislike_image, f"{participant}_{scenario}_{active_tab}_{method}_round{round}")
282
-
283
- display_info_message("✅ Your answer is saved!")
284
-
285
- def redesign(participant, scenario, prompt, sim_radio, like_radio, dislike_radio, current_images, history_images, active_tab, like_image, dislike_image):
286
- global counter1, counter2, responses_memory, current_task1, current_task2, enable_submit1, enable_submit2
287
- method = current_task1 if active_tab == "Task A" else current_task2
288
-
289
- if check_evaluation(sim_radio) and check_participant(participant):
290
- counter = counter1 if method == METHODS[0] else counter2
291
- enable_submit = enable_submit1 if method == METHODS[0] else enable_submit2
292
-
293
- responses_memory[participant][method][counter] = {}
294
- responses_memory[participant][method][counter]["prompt"] = prompt
295
- responses_memory[participant][method][counter]["sim_radio"] = sim_radio
296
- responses_memory[participant][method][counter]["response"] = ""
297
- responses_memory[participant][method][counter]["satisfied_img"] = f"round {counter}, {like_radio}"
298
- responses_memory[participant][method][counter]["unsatisfied_img"] = f"round {counter}, {dislike_radio}"
299
-
300
- save_response_to_sheet(participant, method, scenario, active_tab, counter, like_image, dislike_image)
301
 
302
  enable_submit = True if sim_radio in ["Satisfied", "Very Satisfied"] or enable_submit else False
303
 
304
- history_prompts = [[v["prompt"]] for v in responses_memory[participant][method].values()]
305
- if not history_images:
306
  history_images = current_images
307
  elif current_images:
308
  history_images.extend(current_images)
309
  current_images = []
310
-
311
  examples_state = gr.update(samples=history_prompts, visible=True)
312
  prompt_state = gr.update(interactive=True)
313
  next_state = gr.update(visible=True, interactive=True)
314
  redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
315
  submit_state = gr.update(interactive=True) if counter >= MAX_ROUND or enable_submit else gr.update(interactive=False)
316
 
317
- # update counter
318
- if method == METHODS[0]:
319
- counter1 += 1
320
- enable_submit1 = enable_submit
321
- else:
322
- counter2 += 1
323
- enable_submit2 = enable_submit
324
 
325
  return None, None, None, current_images, history_images, examples_state, prompt_state, next_state, redesign_state, submit_state
326
  else:
327
- return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
328
-
329
- def save_image(creds, like_image, dislike_image, name):
330
- try:
331
- service = build("drive", "v3", credentials=creds)
332
- for image_path, suffix in zip([like_image, dislike_image], ["satisfied", "unsatisfied"]):
333
- filename = f"{name}_{suffix}"
334
- file_metadata = {"name": filename, "parents": ["1ru3-QbbzyVSk-1kBfVv4nhElFqYh3ITj"]}
335
- media = MediaFileUpload(image_path, mimetype="image/png")
336
- uploaded_file = service.files().create(body=file_metadata, media_body=media, fields="id").execute()
337
-
338
- except HttpError as error:
339
- print(f"An error occurred: {error}")
340
-
341
- def save_response(participant, scenario, prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image, active_tab):
342
- global current_task1, current_task2, scopes # not change
343
- global task1_success, task2_success, counter1, counter2, enable_submit1, enable_submit2, responses_memory, assigned_scenarios # will change
344
-
345
- method = current_task1 if active_tab == "Task A" else current_task2
346
- if check_evaluation(sim_radio) and check_participant(participant):
347
- counter = counter1 if method == METHODS[0] else counter2
348
-
349
- responses_memory[participant][method][counter] = {}
350
- responses_memory[participant][method][counter]["prompt"] = prompt
351
- responses_memory[participant][method][counter]["sim_radio"] = sim_radio
352
- responses_memory[participant][method][counter]["response"] = ""
353
- responses_memory[participant][method][counter]["satisfied_img"] = f"round {counter}, {like_radio}"
354
- responses_memory[participant][method][counter]["unsatisfied_img"] = f"round {counter}, {dislike_radio}"
355
-
356
- try:
357
- save_response_to_sheet(participant, method, scenario, active_tab, counter, like_image, dislike_image)
358
-
359
- # reset global variables
360
- if method == METHODS[0]:
361
- counter1 = 1
362
- enable_submit1 = False
363
- else:
364
- counter2 = 1
365
- enable_submit2 = False
366
- if active_tab == "Task A":
367
- task1_success = True
368
- else:
369
- task2_success = True
370
-
371
- # decide if change scenario
372
- # if scenario == assigned_scenarios[0]:
373
- # next_scenario = assigned_scenarios[1] if task1_success and task2_success else assigned_scenarios[0]
374
- # else:
375
- # if task1_success and task2_success:
376
- # display_info_message("You have finished all scenarios, thank you!")
377
- # next_scenario = assigned_scenarios[0]
378
- # else:
379
- # next_scenario = assigned_scenarios[1]
380
-
381
- # reset buttons
382
- prompt_state = gr.update(interactive=False)
383
- next_state = gr.update(visible=False, interactive=False)
384
- submit_state = gr.update(interactive=False)
385
- redesign_state = gr.update(interactive=False)
386
- tabs = switch_tab(active_tab)
387
-
388
- return None, None, None, prompt_state, next_state, redesign_state, submit_state, tabs
389
-
390
- except Exception as e:
391
- display_error_message(f"❌ Error saving response: {str(e)}")
392
- return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
393
  else:
394
- return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
395
-
396
 
397
- ########################################################################################################
398
- # Interface
399
- ########################################################################################################
400
 
401
- css="""
402
  #col-container {
403
  margin: 0 auto;
404
  max-width: 700px;
405
  }
406
-
407
  #col-container2 {
408
  margin: 0 auto;
409
  max-width: 1000px;
410
  }
411
-
412
  #col-container3 {
413
  margin: 0 0 auto auto;
414
  max-width: 300px;
415
  }
416
-
417
  #button-container {
418
  display: flex;
419
- justify-content: center; /* Centers the buttons horizontally */
420
  }
421
  #compact-row {
422
  width:100%;
@@ -427,180 +239,90 @@ css="""
427
 
428
  with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
429
  with gr.Column(elem_id="col-container"):
430
- gr.Markdown(" # 📌 **PAI-GEN**")
 
431
 
432
- with gr.Row():
433
- participant = gr.Textbox(
434
- label="🧑‍💼 Participant ID", placeholder="Please enter you participant id"
 
 
 
 
 
435
  )
436
- scenario = gr.Dropdown(
437
- choices=list(SCENARIOS.keys()),
438
- value=None,
439
- label="📌 Scenario",
440
- # interactive=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  )
442
- scenario_content = gr.Textbox(
443
- label="📖 Background",
444
- interactive=False,
445
- )
446
- active_tab = gr.State("Task A")
447
- instruction = gr.Markdown(INSTRUCTION)
448
 
449
- with gr.Tabs() as tabs:
450
- with gr.TabItem("Task A", id="Task A") as task1_tab:
451
- task1_tab.select(lambda: "Task A", outputs=[active_tab])
452
- with gr.Row(elem_id="compact-row"):
453
- prompt1 = gr.Textbox(
454
- label="🎨 Revise Prompt",
455
- max_lines=5,
456
- placeholder="Enter your prompt",
457
- scale=4,
458
- visible=True,
459
- )
460
- next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
461
-
462
- with gr.Row(elem_id="compact-row"):
463
- with gr.Column(elem_id="col-container"):
464
- images_method1 = gr.Gallery(label="Images", columns=[4], rows=[1], height=400, elem_id="gallery", format="png")
465
-
466
- with gr.Column(elem_id="col-container3"):
467
- like_image1 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
468
- dislike_image1 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
469
- with gr.Column(elem_id="col-container2"):
470
- gr.Markdown("### 📝 Evaluation")
471
- sim_radio1 = gr.Radio(
472
- OPTIONS,
473
- label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
474
- type="value",
475
- elem_classes=["gradio-radio"]
476
- )
477
- like_radio1 = gr.Radio(
478
- IMAGE_OPTIONS,
479
- label="Select your all-time favorite image that you fnd MOST satisfactory in this task. You may leave this section blank if you prefer the previous images.",
480
- type="value",
481
- elem_classes=["gradio-radio"]
482
- )
483
- dislike_radio1 = gr.Radio(
484
- IMAGE_OPTIONS,
485
- label="Select your all-time disliked image that you fnd LEAST satisfactory in this task. You may leave this section blank if you are more dislike previous images.",
486
- type="value",
487
- elem_classes=["gradio-radio"]
488
- )
489
-
490
- response1 = gr.Textbox(
491
- label="Verbally describe key differences found in the image pair.",
492
- max_lines=1,
493
- interactive=False,
494
- container=False,
495
- value=VERBAL_MSG
496
- )
497
-
498
- with gr.Column(elem_id="col-container2"):
499
- example1 = gr.Examples([['']], prompt1, label="Revised Prompt History", visible=False)
500
- history_images1 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png")
501
-
502
- with gr.Row(elem_id="button-container"):
503
- redesign_btn1 = gr.Button("🎨 Redesign", variant="primary", scale=0)
504
- submit_btn1 = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)
505
-
506
-
507
- with gr.TabItem("Task B", id="Task B") as task2_tab:
508
- task2_tab.select(lambda: "Task B", outputs=[active_tab])
509
- with gr.Row(elem_id="compact-row"):
510
- prompt2 = gr.Textbox(
511
- label="🎨 Revise Prompt",
512
- max_lines=5,
513
- placeholder="Enter your prompt",
514
- scale=4,
515
- visible=True,
516
- )
517
- next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
518
-
519
- with gr.Row(elem_id="compact-row"):
520
- with gr.Column(elem_id="col-container"):
521
- images_method2 = gr.Gallery(label="Images", columns=[4], rows=[1], height=200, elem_id="gallery", format="png")
522
-
523
- with gr.Column(elem_id="col-container3"):
524
- like_image2 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
525
- dislike_image2 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
526
-
527
- with gr.Column(elem_id="col-container2"):
528
- gr.Markdown("### 📝 Evaluation")
529
- sim_radio2 = gr.Radio(
530
- OPTIONS,
531
- label="How would you rate your satisfaction with the generated images, based on your expectations for the specified scenario?",
532
- type="value",
533
- elem_classes=["gradio-radio"]
534
- )
535
- like_radio2 = gr.Radio(
536
- IMAGE_OPTIONS,
537
- label="Select your all-time favorite image that you fnd MOST satisfactory in this task. You may leave this section blank if you prefer the previous images.",
538
- type="value",
539
- elem_classes=["gradio-radio"]
540
- )
541
- dislike_radio2 = gr.Radio(
542
- IMAGE_OPTIONS,
543
- label="Select your all-time disliked image that you fnd LEAST satisfactory in this task. You may leave this section blank if you are more dislike previous images.",
544
- type="value",
545
- elem_classes=["gradio-radio"]
546
- )
547
-
548
- response2 = gr.Textbox(
549
- label="Verbally describe key differences found in the image pair.",
550
- max_lines=1,
551
- interactive=False,
552
- container=False,
553
- value=VERBAL_MSG
554
- )
555
-
556
- with gr.Column(elem_id="col-container2"):
557
- example2 = gr.Examples([['']], prompt2, label="Revised Prompt History", visible=False)
558
- history_images2 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png")
559
-
560
- with gr.Row(elem_id="button-container"):
561
- redesign_btn2 = gr.Button("🎨 Redesign", variant="primary", scale=0)
562
- submit_btn2 = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)
563
-
564
-
565
- ########################################################################################################
566
- # Button Function Setup
567
- ########################################################################################################
568
-
569
- # participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
570
- participant.change(fn=set_user, inputs=[participant])
571
- scenario.change(display_scenario,
572
- inputs=[participant, scenario],
573
- outputs=[scenario_content, prompt1, prompt2, images_method1, images_method2, like_image1, dislike_image1, like_image2, dislike_image2, history_images1, history_images2, example1.dataset, example2.dataset, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])
574
-
575
- # prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
576
- # prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
577
- next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, active_tab, like_image1, dislike_image1], outputs=[images_method1]).success(lambda: [gr.update(interactive=False),gr.update(interactive=False)], outputs=[next_btn1, prompt1])
578
- next_btn2.click(fn=generate_image, inputs=[participant, scenario, prompt2, active_tab, like_image2, dislike_image2], outputs=[images_method2]).success(lambda: [gr.update(interactive=False),gr.update(interactive=False)], outputs=[next_btn2, prompt2])
579
- sim_radio1.change(fn=check_satisfaction, inputs=[sim_radio1, active_tab], outputs=[submit_btn1])
580
- sim_radio2.change(fn=check_satisfaction, inputs=[sim_radio2, active_tab], outputs=[submit_btn2])
581
- dislike_radio1.select(fn=select_image, inputs=[dislike_radio1, images_method1], outputs=[dislike_image1])
582
- like_radio1.select(fn=select_image, inputs=[like_radio1, images_method1], outputs=[like_image1])
583
- dislike_radio2.select(fn=select_image, inputs=[dislike_radio2, images_method2], outputs=[dislike_image2])
584
- like_radio2.select(fn=select_image, inputs=[like_radio2, images_method2], outputs=[like_image2])
585
-
586
- redesign_btn1.click(
587
- fn=redesign,
588
- inputs=[participant, scenario, prompt1, sim_radio1, like_radio1, dislike_radio1, images_method1, history_images1, active_tab, like_image1, dislike_image1],
589
- outputs=[sim_radio1, dislike_radio1, like_radio1, images_method1, history_images1, example1.dataset, prompt1, next_btn1, redesign_btn1, submit_btn1]
590
- )
591
- redesign_btn2.click(
592
- fn=redesign,
593
- inputs=[participant, scenario, prompt2, sim_radio2, like_radio2, dislike_radio2, images_method2, history_images2, active_tab, like_image2, dislike_image2],
594
- outputs=[sim_radio2, dislike_radio2, like_radio2, images_method2, history_images2, example2.dataset, prompt2, next_btn2, redesign_btn2, submit_btn2]
595
  )
596
- submit_btn1.click(fn=save_response,
597
- inputs=[participant, scenario, prompt1, sim_radio1, like_radio1, dislike_radio1, like_image1, dislike_image1, active_tab],
598
- outputs=[sim_radio1, dislike_radio1, like_radio1, prompt1, next_btn1, redesign_btn1, submit_btn1, tabs])
599
-
600
- submit_btn2.click(fn=save_response,
601
- inputs=[participant, scenario, prompt2, sim_radio2, like_radio2, dislike_radio2, like_image2, dislike_image2, active_tab],
602
- outputs=[sim_radio2, dislike_radio2, like_radio2, prompt2, next_btn2, redesign_btn2, submit_btn2, tabs])
603
 
 
 
 
 
 
604
 
605
  if __name__ == "__main__":
606
- demo.launch()
 
1
+
2
  import gradio as gr
 
3
  import numpy as np
4
  import random
5
  import spaces
6
  import torch
7
  import re
 
 
 
 
 
8
  import transformers
 
 
 
 
 
 
9
 
10
+ # Optional: keep these utilities if your pipeline depends on them
11
+ from optim_utils import optimize_prompt
12
+ from utils import (
13
+ clean_response_gpt, setup_model, init_gpt_api, call_gpt_api,
14
+ get_refine_msg, clean_cache, get_personalize_message,
15
+ clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS,
16
+ INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS # some may be unused after simplification
17
+ )
18
+
19
+ # =========================
20
+ # Constants / Defaults
21
+ # =========================
22
  CLIP_MODEL = "ViT-H-14"
23
  PRETRAINED_CLIP = "laion2b_s32b_b79k"
24
+ default_t2i_model = "black-forest-labs/FLUX.1-dev"
25
+ default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
26
  MAX_SEED = np.iinfo(np.int32).max
27
  MAX_IMAGE_SIZE = 1024
28
+ NUM_IMAGES = 4
29
+ MAX_ROUND = 5
30
 
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
33
+ clean_cache()
34
 
35
  selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
 
36
  llm_pipe = None
37
  torch.cuda.empty_cache()
38
  inverted_prompt = ""
39
 
40
  VERBAL_MSG = "Please explain your rating of satisfaction in few words or sentences."
41
+ METHOD = "Experimental" # keep ONLY experimental
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # Global states for a single-task, single-method flow
44
+ counter = 1
45
+ enable_submit = False
46
+ responses_memory = {METHOD: {}}
47
 
48
+ # =========================
49
+ # Image Generation Helpers
50
+ # =========================
51
  @spaces.GPU(duration=65)
52
  def infer(
53
  prompt,
 
85
  prompt_list = clean_response_gpt(outputs)
86
  return prompt_list
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
89
  seed = random.randint(0, MAX_SEED)
90
  client = init_gpt_api()
91
  messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
92
  outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
 
 
93
  return outputs
94
 
95
+ # =========================
96
+ # UI Helper Functions
97
+ # =========================
 
98
  def reset_gallery():
99
  return []
100
 
 
104
  def display_info_message(msg, duration=5):
105
  gr.Info(msg, duration=duration)
106
 
107
+ def check_satisfaction(sim_radio):
108
+ global enable_submit, counter
109
+ fully_satisfied_option = ["Satisfied", "Very Satisfied"]
110
+ if_submit = (sim_radio in fully_satisfied_option) or enable_submit or (counter > MAX_ROUND)
 
 
 
 
 
 
 
 
 
 
111
  return gr.update(interactive=if_submit)
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def select_image(like_radio, images_method):
114
  if like_radio == IMAGE_OPTIONS[0]:
115
  return images_method[0][0]
 
122
  else:
123
  return None
124
 
125
+ def check_evaluation(sim_radio):
126
+ if not sim_radio:
127
+ display_error_message("❌ Please fill all evaluations before changing image or submitting.")
128
+ return False
129
+ return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ # =========================
132
+ # Core Actions (single method)
133
+ # =========================
134
+ def generate_image(prompt, like_image, dislike_image):
135
  global responses_memory
136
+ history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
137
+ feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
138
+
139
+ personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
140
+ personalized = clean_refined_prompt_response_gpt(personalized)
141
+ if "I'm sorry, I can't assist with" in personalized:
142
+ personalized = prompt
143
+
144
+ gallery_images = []
145
+ # Experimental method refines prompts first
146
+ refined_prompts = call_gpt_refine_prompt(personalized)
147
+ for i in range(NUM_IMAGES):
148
+ img = infer(refined_prompts[i])
149
+ gallery_images.append(img)
150
+ yield gallery_images
151
+
152
+ def redesign(prompt, sim_radio, like_radio, dislike_radio, current_images, history_images, like_image, dislike_image):
153
+ global counter, enable_submit, responses_memory
154
+ if check_evaluation(sim_radio):
155
+ responses_memory[METHOD][counter] = {
156
+ "prompt": prompt,
157
+ "sim_radio": sim_radio,
158
+ "response": "",
159
+ "satisfied_img": f"round {counter}, {like_radio}",
160
+ "unsatisfied_img": f"round {counter}, {dislike_radio}",
161
+ }
 
 
 
162
 
163
  enable_submit = True if sim_radio in ["Satisfied", "Very Satisfied"] or enable_submit else False
164
 
165
+ history_prompts = [[v["prompt"]] for v in responses_memory[METHOD].values()]
166
+ if not history_images:
167
  history_images = current_images
168
  elif current_images:
169
  history_images.extend(current_images)
170
  current_images = []
171
+
172
  examples_state = gr.update(samples=history_prompts, visible=True)
173
  prompt_state = gr.update(interactive=True)
174
  next_state = gr.update(visible=True, interactive=True)
175
  redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
176
  submit_state = gr.update(interactive=True) if counter >= MAX_ROUND or enable_submit else gr.update(interactive=False)
177
 
178
+ counter += 1
 
 
 
 
 
 
179
 
180
  return None, None, None, current_images, history_images, examples_state, prompt_state, next_state, redesign_state, submit_state
181
  else:
182
+ return {submit_btn: gr.skip()}
183
+
184
+ def save_response(prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image):
185
+ global counter, enable_submit, responses_memory
186
+
187
+ if check_evaluation(sim_radio):
188
+ # Save the final round entry
189
+ responses_memory[METHOD][counter] = {
190
+ "prompt": prompt,
191
+ "sim_radio": sim_radio,
192
+ "response": "",
193
+ "satisfied_img": f"round {counter}, {like_radio}",
194
+ "unsatisfied_img": f"round {counter}, {dislike_radio}",
195
+ }
196
+
197
+ # Reset states
198
+ counter = 1
199
+ enable_submit = False
200
+
201
+ # Reset buttons
202
+ prompt_state = gr.update(interactive=False)
203
+ next_state = gr.update(visible=False, interactive=False)
204
+ submit_state = gr.update(interactive=False)
205
+ redesign_state = gr.update(interactive=False)
206
+
207
+ display_info_message(" Your answer is saved!")
208
+ return None, None, None, prompt_state, next_state, redesign_state, submit_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  else:
210
+ return {submit_btn: gr.skip()}
 
211
 
212
+ # =========================
213
+ # Interface (single tab, no participant/scenario/background)
214
+ # =========================
215
 
216
+ css = """
217
  #col-container {
218
  margin: 0 auto;
219
  max-width: 700px;
220
  }
 
221
  #col-container2 {
222
  margin: 0 auto;
223
  max-width: 1000px;
224
  }
 
225
  #col-container3 {
226
  margin: 0 0 auto auto;
227
  max-width: 300px;
228
  }
 
229
  #button-container {
230
  display: flex;
231
+ justify-content: center;
232
  }
233
  #compact-row {
234
  width:100%;
 
239
 
240
  with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
241
  with gr.Column(elem_id="col-container"):
242
+ gr.Markdown("# 📌 **PAI-GEN — Experimental Only**")
243
+ instruction = gr.Markdown(INSTRUCTION)
244
 
245
+ with gr.Tab("Task"):
246
+ with gr.Row(elem_id="compact-row"):
247
+ prompt = gr.Textbox(
248
+ label="🎨 Revise Prompt",
249
+ max_lines=5,
250
+ placeholder="Enter your prompt",
251
+ scale=4,
252
+ visible=True,
253
  )
254
+ next_btn = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
255
+
256
+ with gr.Row(elem_id="compact-row"):
257
+ with gr.Column(elem_id="col-container"):
258
+ images_method = gr.Gallery(label="Images", columns=[4], rows=[1], height=400, elem_id="gallery", format="png")
259
+
260
+ with gr.Column(elem_id="col-container3"):
261
+ like_image = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
262
+ dislike_image = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
263
+
264
+ with gr.Column(elem_id="col-container2"):
265
+ gr.Markdown("### 📝 Evaluation")
266
+ sim_radio = gr.Radio(
267
+ OPTIONS,
268
+ label="How would you rate your satisfaction with the generated images?",
269
+ type="value",
270
+ elem_classes=["gradio-radio"]
271
+ )
272
+ like_radio = gr.Radio(
273
+ IMAGE_OPTIONS,
274
+ label="Select your all-time favorite image (optional).",
275
+ type="value",
276
+ elem_classes=["gradio-radio"]
277
+ )
278
+ dislike_radio = gr.Radio(
279
+ IMAGE_OPTIONS,
280
+ label="Select your all-time least satisfactory image (optional).",
281
+ type="value",
282
+ elem_classes=["gradio-radio"]
283
  )
 
 
 
 
 
 
284
 
285
+ response = gr.Textbox(
286
+ label="Briefly explain your rating.",
287
+ max_lines=1,
288
+ interactive=False,
289
+ container=False,
290
+ value=VERBAL_MSG
291
+ )
292
+
293
+ with gr.Column(elem_id="col-container2"):
294
+ example = gr.Examples([['']], prompt, label="Revised Prompt History", visible=False)
295
+ history_images = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png")
296
+
297
+ with gr.Row(elem_id="button-container"):
298
+ redesign_btn = gr.Button("🎨 Redesign", variant="primary", scale=0)
299
+ submit_btn = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)
300
+
301
+ # =========================
302
+ # Wiring
303
+ # =========================
304
+ sim_radio.change(fn=check_satisfaction, inputs=[sim_radio], outputs=[submit_btn])
305
+
306
+ dislike_radio.select(fn=select_image, inputs=[dislike_radio, images_method], outputs=[dislike_image])
307
+ like_radio.select(fn=select_image, inputs=[like_radio, images_method], outputs=[like_image])
308
+
309
+ next_btn.click(
310
+ fn=generate_image,
311
+ inputs=[prompt, like_image, dislike_image],
312
+ outputs=[images_method]
313
+ ).success(lambda: [gr.update(interactive=False), gr.update(interactive=False)], outputs=[next_btn, prompt])
314
+
315
+ redesign_btn.click(
316
+ fn=redesign,
317
+ inputs=[prompt, sim_radio, like_radio, dislike_radio, images_method, history_images, like_image, dislike_image],
318
+ outputs=[sim_radio, dislike_radio, like_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn, submit_btn]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  )
 
 
 
 
 
 
 
320
 
321
+ submit_btn.click(
322
+ fn=save_response,
323
+ inputs=[prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image],
324
+ outputs=[sim_radio, dislike_radio, like_radio, prompt, next_btn, redesign_btn, submit_btn]
325
+ )
326
 
327
  if __name__ == "__main__":
328
+ demo.launch()