Files changed (4) hide show
  1. README.md +1 -2
  2. app.py +22 -43
  3. arguments.py +10 -8
  4. requirements.txt +9 -16
README.md CHANGED
@@ -3,9 +3,8 @@ title: ReNO
3
  emoji: 🦌
4
  colorFrom: pink
5
  colorTo: indigo
6
- python_version: '3.10'
7
  sdk: gradio
8
- sdk_version: 6.14.0
9
  app_file: app.py
10
  pinned: false
11
  short_description: Reward-based Noise Optimization for 1-step t2i models
 
3
  emoji: 🦌
4
  colorFrom: pink
5
  colorTo: indigo
 
6
  sdk: gradio
7
+ sdk_version: 5.1.0
8
  app_file: app.py
9
  pinned: false
10
  short_description: Reward-based Noise Optimization for 1-step t2i models
app.py CHANGED
@@ -78,36 +78,25 @@ def setup_model(loaded_model_setup, prompt, model, seed, num_iterations, enable_
78
  args.save_all_images = True
79
 
80
  if enable_hps is True:
81
- args.enable_hps = True
82
  args.hps_weighting = hps_w
83
- else:
84
- args.enable_hps = False
85
 
86
  if enable_imagereward is True:
87
- args.enable_imagereward = True
88
  args.imagereward_weighting = imgrw_w
89
- else:
90
- args.enable_imagereward = False
91
 
92
  if enable_pickscore is True:
93
- args.enable_pickscore = True
94
  args.pickscore_weighting = pcks_w
95
- else:
96
- args.enable_pickscore = False
97
 
98
  if enable_clip is True:
99
- args.enable_clip = True
100
  args.clip_weighting = clip_w
101
- else:
102
- args.enable_clip = False
103
 
104
  if model == "flux":
105
  args.cpu_offloading = True
106
  args.enable_multi_apply = True
107
  args.multi_step_model = "flux"
108
-
109
- if model == "hyper-sd":
110
- args.cpu_offloading = True
111
 
112
  # Check if args are the same as the loaded_model_setup except for the prompt
113
  if loaded_model_setup and hasattr(loaded_model_setup[0], '__dict__'):
@@ -275,12 +264,7 @@ def combined_function(gallery_state, loaded_model_setup, prompt, chosen_model, s
275
 
276
  # Create Gradio interface
277
  title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
278
- description = "Enter a prompt to generate an image using ReNO. The method enhances text-to-image generation by optimizing \
279
- the initial noise using reward models as detailed in the paper. The demo uses a lower learning rate (2.5) compared to the paper's default (5.0) \
280
- for smoother trajectories - if you are looking for more drastic changes, you can increase this value. You can also \
281
- adjust the reward weights to e.g. prioritize either prompt following (increase ImageReward) or aesthetic quality \
282
- (increase HPS/PickScore) based on your preferences.\n\nThe first time you load this demo, it will take a bit \
283
- to download and initialize the required model. Once loaded, each optimization run takes about 25-60 seconds."
284
 
285
  css="""
286
  #model-status-id{
@@ -315,28 +299,28 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
315
  with gr.Column():
316
  prompt = gr.Textbox(label="Prompt")
317
  with gr.Row():
318
- chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sdxl-turbo")
319
  seed = gr.Number(label="seed", value=0)
320
 
321
  model_status = gr.Textbox(label="model status", visible=True, elem_id="model-status-id")
322
 
323
  with gr.Row():
324
- n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
325
- learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=2.5, label="Learning Rate")
326
 
327
  with gr.Accordion("Advanced Settings", open=True):
328
  with gr.Column():
329
  with gr.Row():
330
- enable_hps = gr.Checkbox(label="HPS ON", value=True, scale=1)
331
  hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0, interactive=False, scale=3)
332
  with gr.Row():
333
- enable_imagereward = gr.Checkbox(label="ImageReward ON", value=True, scale=1)
334
  imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0, interactive=False, scale=3)
335
  with gr.Row():
336
- enable_pickscore = gr.Checkbox(label="PickScore ON", value=True, scale=1)
337
- pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=0.5, value=0.05, interactive=False, scale=3)
338
  with gr.Row():
339
- enable_clip = gr.Checkbox(label="CLIP ON", value=True, scale=1)
340
  clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01, interactive=False, scale=3)
341
 
342
  submit_btn = gr.Button("Submit")
@@ -344,11 +328,11 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
344
  gr.Examples(
345
  examples = [
346
  "A red dog and a green cat",
 
347
  "A toaster riding a bike",
348
- "A blue scooter is parked near a curb in front of a green vintage car",
349
  "A curious, orange fox and a fluffy, white rabbit, playing together in a lush, green meadow filled with yellow dandelions",
350
- "An orange chair to the right of a black airplane",
351
- "A brain riding a rocketship towards the moon",
352
  ],
353
  inputs = [prompt]
354
  )
@@ -368,29 +352,25 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
368
  fn = allow_weighting,
369
  inputs = [enable_hps],
370
  outputs = [hps_w],
371
- queue = False,
372
- api_visibility="private"
373
  )
374
  enable_imagereward.change(
375
  fn = allow_weighting,
376
  inputs = [enable_imagereward],
377
  outputs = [imgrw_w],
378
- queue = False,
379
- api_visibility="private"
380
  )
381
  enable_pickscore.change(
382
  fn = allow_weighting,
383
  inputs = [enable_pickscore],
384
  outputs = [pcks_w],
385
- queue = False,
386
- api_visibility="private"
387
  )
388
  enable_clip.change(
389
  fn = allow_weighting,
390
  inputs = [enable_clip],
391
  outputs = [clip_w],
392
- queue = False,
393
- api_visibility="private"
394
  )
395
 
396
  submit_btn.click(
@@ -402,8 +382,7 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
402
  ],
403
  outputs = [
404
  gallery_state, output_image, status, iter_gallery, loaded_model_setup, model_status # Ensure `model_status` is included in the outputs
405
- ],
406
- api_visibility="private"
407
  )
408
 
409
  """
@@ -427,4 +406,4 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
427
  """
428
 
429
  # Launch the app
430
- demo.queue().launch(show_error=True)
 
78
  args.save_all_images = True
79
 
80
  if enable_hps is True:
81
+ args.disable_hps = False
82
  args.hps_weighting = hps_w
 
 
83
 
84
  if enable_imagereward is True:
85
+ args.disable_imagereward = False
86
  args.imagereward_weighting = imgrw_w
 
 
87
 
88
  if enable_pickscore is True:
89
+ args.disable_pickscore = False
90
  args.pickscore_weighting = pcks_w
 
 
91
 
92
  if enable_clip is True:
93
+ args.disable_clip = False
94
  args.clip_weighting = clip_w
 
 
95
 
96
  if model == "flux":
97
  args.cpu_offloading = True
98
  args.enable_multi_apply = True
99
  args.multi_step_model = "flux"
 
 
 
100
 
101
  # Check if args are the same as the loaded_model_setup except for the prompt
102
  if loaded_model_setup and hasattr(loaded_model_setup[0], '__dict__'):
 
264
 
265
  # Create Gradio interface
266
  title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
267
+ description="Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."
 
 
 
 
 
268
 
269
  css="""
270
  #model-status-id{
 
299
  with gr.Column():
300
  prompt = gr.Textbox(label="Prompt")
301
  with gr.Row():
302
+ chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sd-turbo")
303
  seed = gr.Number(label="seed", value=0)
304
 
305
  model_status = gr.Textbox(label="model status", visible=True, elem_id="model-status-id")
306
 
307
  with gr.Row():
308
+ n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=10, label="Number of Iterations")
309
+ learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
310
 
311
  with gr.Accordion("Advanced Settings", open=True):
312
  with gr.Column():
313
  with gr.Row():
314
+ enable_hps = gr.Checkbox(label="HPS ON", value=False, scale=1)
315
  hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0, interactive=False, scale=3)
316
  with gr.Row():
317
+ enable_imagereward = gr.Checkbox(label="ImageReward ON", value=False, scale=1)
318
  imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0, interactive=False, scale=3)
319
  with gr.Row():
320
+ enable_pickscore = gr.Checkbox(label="PickScore ON", value=False, scale=1)
321
+ pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=5.0, value=0.05, interactive=False, scale=3)
322
  with gr.Row():
323
+ enable_clip = gr.Checkbox(label="CLIP ON", value=False, scale=1)
324
  clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01, interactive=False, scale=3)
325
 
326
  submit_btn = gr.Button("Submit")
 
328
  gr.Examples(
329
  examples = [
330
  "A red dog and a green cat",
331
+ "A pink elephant and a grey cow",
332
  "A toaster riding a bike",
333
+ "Dwayne Johnson depicted as a philosopher king in an academic painting by Greg Rutkowski",
334
  "A curious, orange fox and a fluffy, white rabbit, playing together in a lush, green meadow filled with yellow dandelions",
335
+ "An epic oil painting: a red portal infront of a cityscape, a solitary figure, and a colorful sky over snowy mountains"
 
336
  ],
337
  inputs = [prompt]
338
  )
 
352
  fn = allow_weighting,
353
  inputs = [enable_hps],
354
  outputs = [hps_w],
355
+ queue = False
 
356
  )
357
  enable_imagereward.change(
358
  fn = allow_weighting,
359
  inputs = [enable_imagereward],
360
  outputs = [imgrw_w],
361
+ queue = False
 
362
  )
363
  enable_pickscore.change(
364
  fn = allow_weighting,
365
  inputs = [enable_pickscore],
366
  outputs = [pcks_w],
367
+ queue = False
 
368
  )
369
  enable_clip.change(
370
  fn = allow_weighting,
371
  inputs = [enable_clip],
372
  outputs = [clip_w],
373
+ queue = False
 
374
  )
375
 
376
  submit_btn.click(
 
382
  ],
383
  outputs = [
384
  gallery_state, output_image, status, iter_gallery, loaded_model_setup, model_status # Ensure `model_status` is included in the outputs
385
+ ]
 
386
  )
387
 
388
  """
 
406
  """
407
 
408
  # Launch the app
409
+ demo.queue().launch(show_error=True, show_api=False)
arguments.py CHANGED
@@ -39,15 +39,16 @@ def parse_args():
39
 
40
  # reward losses
41
  parser.add_argument(
42
- "--enable_hps", default=False, action="store_true",
43
  )
44
  parser.add_argument(
45
  "--hps_weighting", type=float, help="Weighting for HPS", default=5.0
46
  )
47
  parser.add_argument(
48
- "--enable_imagereward",
49
- default=False,
50
- action="store_true",
 
51
  )
52
  parser.add_argument(
53
  "--imagereward_weighting",
@@ -56,15 +57,16 @@ def parse_args():
56
  default=1.0,
57
  )
58
  parser.add_argument(
59
- "--enable_clip", default=False, action="store_true"
60
  )
61
  parser.add_argument(
62
  "--clip_weighting", type=float, help="Weighting for CLIP", default=0.01
63
  )
64
  parser.add_argument(
65
- "--enable_pickscore",
66
- default=False,
67
- action="store_true",
 
68
  )
69
  parser.add_argument(
70
  "--pickscore_weighting",
 
39
 
40
  # reward losses
41
  parser.add_argument(
42
+ "--disable_hps", default=True, action="store_false", dest="enable_hps"
43
  )
44
  parser.add_argument(
45
  "--hps_weighting", type=float, help="Weighting for HPS", default=5.0
46
  )
47
  parser.add_argument(
48
+ "--disable_imagereward",
49
+ default=True,
50
+ action="store_false",
51
+ dest="enable_imagereward",
52
  )
53
  parser.add_argument(
54
  "--imagereward_weighting",
 
57
  default=1.0,
58
  )
59
  parser.add_argument(
60
+ "--disable_clip", default=True, action="store_false", dest="enable_clip"
61
  )
62
  parser.add_argument(
63
  "--clip_weighting", type=float, help="Weighting for CLIP", default=0.01
64
  )
65
  parser.add_argument(
66
+ "--disable_pickscore",
67
+ default=True,
68
+ action="store_false",
69
+ dest="enable_pickscore",
70
  )
71
  parser.add_argument(
72
  "--pickscore_weighting",
requirements.txt CHANGED
@@ -1,22 +1,15 @@
1
- torch==2.5.1
2
- torchvision==0.20.1
3
- pytorch-lightning==2.2.0
4
- datasets==2.18.0
5
-
6
- transformers==4.55.4
7
- diffusers==0.35.1
8
- accelerate==1.8.1
9
- huggingface_hub==0.34.4
10
- safetensors>=0.4.3
11
-
12
  hpsv2==1.2
13
  hpsv2x==1.2.0
14
  image-reward==1.5
15
- open-clip-torch==2.24.0
16
  blobfile
17
  openai-clip
 
18
  optimum
19
- xformers
20
- hf-xet==1.1.8
21
-
22
- setuptools>=68
 
1
+ torch==2.3
2
+ torchvision==0.18.0
3
+ pytorch-lightning==2.2
4
+ datasets==2.18
5
+ transformers==4.38.2
6
+ diffusers
 
 
 
 
 
7
  hpsv2==1.2
8
  hpsv2x==1.2.0
9
  image-reward==1.5
10
+ open-clip-torch==2.24
11
  blobfile
12
  openai-clip
13
+ setuptools==60.2
14
  optimum
15
+ xformers