ginipick commited on
Commit
7b6c687
·
verified ·
1 Parent(s): ce178fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -60
app.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import os
2
  import random
3
  import time
@@ -51,17 +56,19 @@ class Preprocessor:
51
 
52
  if hasattr(self.model, 'device'):
53
  if self.model.device.type != device:
54
- print(f"Moving preprocessor model to {device}")
55
- try:
56
- self.model.to(device)
57
- except Exception as e:
58
- print(f"Error moving preprocessor model to {device}: {e}")
 
59
  else:
60
  print("Warning: Preprocessor model has no .device attribute. Attempting to move to correct device.")
61
  try:
62
  self.model.to(device)
63
  except Exception as e:
64
  print(f"Error attempting to move preprocessor model without .device attribute: {e}")
 
65
 
66
  return self.model(image, **kwargs)
67
 
@@ -166,6 +173,7 @@ def get_prompt(prompt, additional_prompt):
166
  return ", ".join(filter(None, prompt_parts))
167
 
168
 
 
169
  style_list = [
170
  {"name": "None",
171
  "prompt": ""
@@ -215,6 +223,7 @@ style_list = [
215
  {"name": "Matrix",
216
  "prompt": "Futuristic cyberpunk interior,neon accent lighting,holographic plants,sleek black surfaces,advanced gaming setup,transparent screens,Blade Runner inspired decor,high-tech minimalist furniture"
217
  },
 
218
  {"name": "Industrial Loft",
219
  "prompt": "Industrial loft interior,exposed brick walls,metal finishes,high ceilings with exposed pipes,concrete floors,vintage factory lights,open floor plan"
220
  },
@@ -238,8 +247,7 @@ style_list = [
238
  },
239
  {"name": "Penthouse",
240
  "prompt": "Luxury penthouse interior,floor-to-ceiling windows,city skyline views,modern furniture,high-end appliances,marble countertops,designer lighting fixtures"
241
- }
242
- ]
243
  styles = {k["name"]: (k["prompt"]) for k in style_list}
244
  STYLE_NAMES = list(styles.keys())
245
 
@@ -248,6 +256,7 @@ def apply_style(style_name):
248
  return styles.get(style_name, "")
249
 
250
 
 
251
  css = """
252
  /* Global Styles */
253
  :root {
@@ -436,6 +445,7 @@ css = """
436
  }
437
  """
438
 
 
439
  def load_examples():
440
  examples = []
441
  for i in range(1, 5):
@@ -451,11 +461,13 @@ def load_examples():
451
 
452
  example_images = load_examples()
453
 
 
454
  def select_example(index):
455
  if 0 <= index < len(example_images):
456
  return example_images[index]
457
  return None
458
 
 
459
  with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
460
  gr.Markdown("<h1>✨ Dream of IKEA ✨</h1>")
461
  gr.Markdown("<h3>Transform your space with AI-powered interior design</h3>")
@@ -470,16 +482,22 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
470
  mirror_webcam=True,
471
  type="pil",
472
  elem_id="input-image",
473
- value=example_images[0] if example_images else None
474
  )
475
 
 
476
  with gr.Row(elem_classes="example-images"):
 
477
  example_buttons = []
478
  for i in range(len(example_images)):
479
  if example_images[i]:
480
  btn = gr.Button(f"Example {i+1}", elem_classes="example-thumb")
481
  example_buttons.append(btn)
482
- btn.click(fn=lambda idx=i: select_example(idx), outputs=image)
 
 
 
 
483
 
484
  with gr.Column(scale=1, min_width=300):
485
  result = gr.Image(
@@ -490,6 +508,7 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
490
  elem_id="output-image"
491
  )
492
 
 
493
  with gr.Row():
494
  with gr.Column(scale=2):
495
  prompt = gr.Textbox(
@@ -501,6 +520,7 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
501
  run_button = gr.Button(value="🚀 Generate Design", size="lg")
502
  use_ai_button = gr.Button(value="♻️ Use Result as New Input", size="lg")
503
 
 
504
  gr.Markdown("<h2>Design Style Selection</h2>")
505
  with gr.Tabs():
506
  with gr.TabItem("Modern Styles"):
@@ -536,6 +556,7 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
536
  elem_id="global-styles"
537
  )
538
 
 
539
  with gr.Accordion("⚙️ Advanced Options", open=False):
540
  with gr.Row():
541
  with gr.Column(scale=1):
@@ -605,17 +626,7 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
605
  elem_classes="helper-text"
606
  )
607
 
608
- # 화면 아래에 Discord 배지를 삽입
609
- gr.Markdown(
610
- """
611
- <p style="text-align: center;">
612
- <a href="https://discord.gg/openfreeai" target="_blank">
613
- <img src="https://img.shields.io/static/v1?label=Discord&message=Join%20our%20community&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="badge">
614
- </a>
615
- </p>
616
- """
617
- )
618
-
619
  def update_style_selection(modern_value, classic_value, global_value):
620
  if modern_value is not None:
621
  return modern_value
@@ -626,6 +637,7 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
626
  else:
627
  return "None"
628
 
 
629
  style_selection = gr.State("None")
630
 
631
  def clear_other_tabs(active_tab, value):
@@ -637,6 +649,7 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
637
  return None, None, value
638
  return None, None, None
639
 
 
640
  style_selection_modern.change(
641
  fn=lambda x: clear_other_tabs("modern", x),
642
  inputs=[style_selection_modern],
@@ -655,6 +668,7 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
655
  outputs=[style_selection_modern, style_selection_classic, style_selection_global]
656
  )
657
 
 
658
  def get_active_style(modern, classic, global_style):
659
  if modern is not None and modern != "":
660
  return modern
@@ -664,11 +678,13 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
664
  return global_style
665
  return "None"
666
 
 
667
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
668
  if randomize_seed:
669
  seed = random.randint(0, MAX_SEED)
670
  return seed
671
 
 
672
  def get_config_inputs():
673
  return [
674
  image,
@@ -687,6 +703,7 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
687
  randomize_seed,
688
  ]
689
 
 
690
  @gr.on(
691
  triggers=[image.upload, prompt.submit, run_button.click],
692
  inputs=get_config_inputs(),
@@ -698,20 +715,18 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
698
  num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale,
699
  seed, randomize_seed, progress=gr.Progress(track_tqdm=True)
700
  ):
 
701
  active_style = get_active_style(style_modern, style_classic, style_global)
 
 
702
  processed_seed = randomize_seed_fn(seed, randomize_seed)
703
  print(f"Using processed seed: {processed_seed}")
 
 
 
704
  return process_image(
705
- image,
706
- active_style,
707
- prompt,
708
- a_prompt,
709
- n_prompt,
710
- num_images,
711
- image_resolution,
712
- preprocess_resolution,
713
- num_steps,
714
- guidance_scale,
715
  processed_seed
716
  )
717
 
@@ -726,24 +741,27 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
726
  a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution,
727
  num_steps, guidance_scale, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)
728
  ):
 
729
  yield previous_result, gr.update()
 
 
730
  active_style = get_active_style(style_modern, style_classic, style_global)
 
 
731
  processed_seed = randomize_seed_fn(seed, randomize_seed)
 
 
 
732
  new_result = process_image(
733
- previous_result,
734
- active_style,
735
- prompt,
736
- a_prompt,
737
- n_prompt,
738
- num_images,
739
- image_resolution,
740
- preprocess_resolution,
741
- num_steps,
742
- guidance_scale,
743
  processed_seed
744
  )
 
745
  yield previous_result, new_result
746
 
 
747
  @gr.on(
748
  triggers=[image.upload, use_ai_button.click, run_button.click],
749
  inputs=None,
@@ -753,6 +771,7 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
753
  def turn_buttons_off():
754
  return gr.update(interactive=False, value="Processing..."), gr.update(interactive=False)
755
 
 
756
  @gr.on(
757
  triggers=[result.change],
758
  inputs=None,
@@ -762,24 +781,49 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
762
  def turn_buttons_on():
763
  return gr.update(interactive=True, value="♻️ Use Result as New Input"), gr.update(interactive=True, value="🚀 Generate Design")
764
 
 
 
 
 
765
  def process_image(
766
- image,
767
- style_selection,
768
- prompt,
769
- a_prompt,
770
- n_prompt,
771
- num_images,
772
- image_resolution,
773
- preprocess_resolution,
774
- num_steps,
775
- guidance_scale,
776
- seed,
777
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
778
  current_seed = seed
779
  generator = torch.cuda.manual_seed(current_seed) if torch.cuda.is_available() else torch.manual_seed(current_seed)
780
 
781
  if preprocessor.name != "NormalBae":
782
- preprocessor.load("NormalBae")
 
783
  preprocessor.model.to("cuda" if torch.cuda.is_available() else "cpu")
784
 
785
  control_image = preprocessor(
@@ -788,13 +832,10 @@ def process_image(
788
  detect_resolution=preprocess_resolution,
789
  )
790
 
 
791
  if style_selection and style_selection != "None":
792
  style_prompt = apply_style(style_selection)
793
- prompt_parts = [
794
- f"Photo from Pinterest of {prompt}" if prompt else None,
795
- style_prompt if style_prompt else None,
796
- a_prompt if a_prompt else None
797
- ]
798
  full_prompt = ", ".join(filter(None, prompt_parts))
799
  else:
800
  full_prompt = get_prompt(prompt, a_prompt)
@@ -817,6 +858,7 @@ def process_image(
817
  image=control_image,
818
  ).images[0]
819
 
 
820
  try:
821
  timestamp = int(time.time())
822
  results_path = f"{timestamp}_output.jpg"
@@ -844,9 +886,8 @@ def process_image(
844
  return initial_result
845
 
846
 
847
- prod = False
848
- port = 8080
849
  if prod:
850
  demo.queue(max_size=20).launch(server_name="localhost", server_port=port)
851
  else:
852
- demo.queue().launch(share=True, show_api=False)
 
1
+ # Configuration
2
+ prod = False
3
+ port = 8080
4
+ show_options = True # Changed to True for better visibility
5
+
6
  import os
7
  import random
8
  import time
 
56
 
57
  if hasattr(self.model, 'device'):
58
  if self.model.device.type != device:
59
+ print(f"Moving preprocessor model to {device}")
60
+ try:
61
+ self.model.to(device)
62
+ except Exception as e:
63
+ print(f"Error moving preprocessor model to {device}: {e}")
64
+ pass
65
  else:
66
  print("Warning: Preprocessor model has no .device attribute. Attempting to move to correct device.")
67
  try:
68
  self.model.to(device)
69
  except Exception as e:
70
  print(f"Error attempting to move preprocessor model without .device attribute: {e}")
71
+ pass
72
 
73
  return self.model(image, **kwargs)
74
 
 
173
  return ", ".join(filter(None, prompt_parts))
174
 
175
 
176
+ # Enhanced style list with more diverse options
177
  style_list = [
178
  {"name": "None",
179
  "prompt": ""
 
223
  {"name": "Matrix",
224
  "prompt": "Futuristic cyberpunk interior,neon accent lighting,holographic plants,sleek black surfaces,advanced gaming setup,transparent screens,Blade Runner inspired decor,high-tech minimalist furniture"
225
  },
226
+ # New added styles
227
  {"name": "Industrial Loft",
228
  "prompt": "Industrial loft interior,exposed brick walls,metal finishes,high ceilings with exposed pipes,concrete floors,vintage factory lights,open floor plan"
229
  },
 
247
  },
248
  {"name": "Penthouse",
249
  "prompt": "Luxury penthouse interior,floor-to-ceiling windows,city skyline views,modern furniture,high-end appliances,marble countertops,designer lighting fixtures"
250
+ }]
 
251
  styles = {k["name"]: (k["prompt"]) for k in style_list}
252
  STYLE_NAMES = list(styles.keys())
253
 
 
256
  return styles.get(style_name, "")
257
 
258
 
259
+ # Enhanced CSS for Gradio UI
260
  css = """
261
  /* Global Styles */
262
  :root {
 
445
  }
446
  """
447
 
448
+ # Load example images
449
  def load_examples():
450
  examples = []
451
  for i in range(1, 5):
 
461
 
462
  example_images = load_examples()
463
 
464
+ # Function to select example image
465
  def select_example(index):
466
  if 0 <= index < len(example_images):
467
  return example_images[index]
468
  return None
469
 
470
+ # Gradio Interface Definition
471
  with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
472
  gr.Markdown("<h1>✨ Dream of IKEA ✨</h1>")
473
  gr.Markdown("<h3>Transform your space with AI-powered interior design</h3>")
 
482
  mirror_webcam=True,
483
  type="pil",
484
  elem_id="input-image",
485
+ value=example_images[0] if example_images else None # Set default image to in1.jpg
486
  )
487
 
488
+ # Example images section with buttons instead of images with style
489
  with gr.Row(elem_classes="example-images"):
490
+ # Create example buttons
491
  example_buttons = []
492
  for i in range(len(example_images)):
493
  if example_images[i]:
494
  btn = gr.Button(f"Example {i+1}", elem_classes="example-thumb")
495
  example_buttons.append(btn)
496
+ # Add click event for each example button
497
+ btn.click(
498
+ fn=lambda idx=i: select_example(idx),
499
+ outputs=image
500
+ )
501
 
502
  with gr.Column(scale=1, min_width=300):
503
  result = gr.Image(
 
508
  elem_id="output-image"
509
  )
510
 
511
+ # Design input section
512
  with gr.Row():
513
  with gr.Column(scale=2):
514
  prompt = gr.Textbox(
 
520
  run_button = gr.Button(value="🚀 Generate Design", size="lg")
521
  use_ai_button = gr.Button(value="♻️ Use Result as New Input", size="lg")
522
 
523
+ # Grouped style selection with categories
524
  gr.Markdown("<h2>Design Style Selection</h2>")
525
  with gr.Tabs():
526
  with gr.TabItem("Modern Styles"):
 
556
  elem_id="global-styles"
557
  )
558
 
559
+ # Advanced options - now with a clearer separator and improved layout
560
  with gr.Accordion("⚙️ Advanced Options", open=False):
561
  with gr.Row():
562
  with gr.Column(scale=1):
 
626
  elem_classes="helper-text"
627
  )
628
 
629
+ # Function to handle style selection changes across tabs
 
 
 
 
 
 
 
 
 
 
630
  def update_style_selection(modern_value, classic_value, global_value):
631
  if modern_value is not None:
632
  return modern_value
 
637
  else:
638
  return "None"
639
 
640
+ # Style synchronization
641
  style_selection = gr.State("None")
642
 
643
  def clear_other_tabs(active_tab, value):
 
649
  return None, None, value
650
  return None, None, None
651
 
652
+ # Connect the tab radios to update each other
653
  style_selection_modern.change(
654
  fn=lambda x: clear_other_tabs("modern", x),
655
  inputs=[style_selection_modern],
 
668
  outputs=[style_selection_modern, style_selection_classic, style_selection_global]
669
  )
670
 
671
+ # Combine all style selections into one for processing
672
  def get_active_style(modern, classic, global_style):
673
  if modern is not None and modern != "":
674
  return modern
 
678
  return global_style
679
  return "None"
680
 
681
+ # Randomize seed function
682
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
683
  if randomize_seed:
684
  seed = random.randint(0, MAX_SEED)
685
  return seed
686
 
687
+ # Configuration list for inputs - using function to get active style
688
  def get_config_inputs():
689
  return [
690
  image,
 
703
  randomize_seed,
704
  ]
705
 
706
+ # Gradio Event Handling Functions
707
  @gr.on(
708
  triggers=[image.upload, prompt.submit, run_button.click],
709
  inputs=get_config_inputs(),
 
715
  num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale,
716
  seed, randomize_seed, progress=gr.Progress(track_tqdm=True)
717
  ):
718
+ # Get the active style
719
  active_style = get_active_style(style_modern, style_classic, style_global)
720
+
721
+ # Apply seed randomization
722
  processed_seed = randomize_seed_fn(seed, randomize_seed)
723
  print(f"Using processed seed: {processed_seed}")
724
+ print(f"Active style: {active_style}")
725
+
726
+ # Call the core processing function
727
  return process_image(
728
+ image, active_style, prompt, a_prompt, n_prompt, num_images,
729
+ image_resolution, preprocess_resolution, num_steps, guidance_scale,
 
 
 
 
 
 
 
 
730
  processed_seed
731
  )
732
 
 
741
  a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution,
742
  num_steps, guidance_scale, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)
743
  ):
744
+ # First, yield the previous result to update the input image immediately
745
  yield previous_result, gr.update()
746
+
747
+ # Get active style
748
  active_style = get_active_style(style_modern, style_classic, style_global)
749
+
750
+ # Apply seed randomization
751
  processed_seed = randomize_seed_fn(seed, randomize_seed)
752
+ print(f"Using processed seed: {processed_seed}")
753
+
754
+ # Then, process the new input image
755
  new_result = process_image(
756
+ previous_result, active_style, prompt, a_prompt,
757
+ n_prompt, num_images, image_resolution,
758
+ preprocess_resolution, num_steps, guidance_scale,
 
 
 
 
 
 
 
759
  processed_seed
760
  )
761
+ # Finally, yield the new result
762
  yield previous_result, new_result
763
 
764
+ # Turn off buttons when processing
765
  @gr.on(
766
  triggers=[image.upload, use_ai_button.click, run_button.click],
767
  inputs=None,
 
771
  def turn_buttons_off():
772
  return gr.update(interactive=False, value="Processing..."), gr.update(interactive=False)
773
 
774
+ # Turn on buttons when processing is complete
775
  @gr.on(
776
  triggers=[result.change],
777
  inputs=None,
 
781
  def turn_buttons_on():
782
  return gr.update(interactive=True, value="♻️ Use Result as New Input"), gr.update(interactive=True, value="🚀 Generate Design")
783
 
784
+
785
+ # Core Image Processing Function
786
+ @spaces.GPU(duration=12)
787
+ @torch.inference_mode()
788
  def process_image(
789
+ image,
790
+ style_selection,
791
+ prompt,
792
+ a_prompt,
793
+ n_prompt,
794
+ num_images,
795
+ image_resolution,
796
+ preprocess_resolution,
797
+ num_steps,
798
+ guidance_scale,
799
+ seed,
800
  ):
801
+ """
802
+ Processes an input image to generate a new image based on style and prompts.
803
+
804
+ Args:
805
+ image: Input PIL Image.
806
+ style_selection: Name of the design style to apply.
807
+ prompt: Custom design prompt.
808
+ a_prompt: Additional positive prompt.
809
+ n_prompt: Negative prompt.
810
+ num_images: Number of images to generate (currently only 1 supported by pipeline).
811
+ image_resolution: Resolution for the output image.
812
+ preprocess_resolution: Resolution for the preprocessor.
813
+ num_steps: Number of inference steps.
814
+ guidance_scale: Guidance scale for the diffusion process.
815
+ seed: Random seed for reproducibility.
816
+
817
+ Returns:
818
+ A PIL Image of the generated result.
819
+ """
820
+ # Use the seed passed from the event handler
821
  current_seed = seed
822
  generator = torch.cuda.manual_seed(current_seed) if torch.cuda.is_available() else torch.manual_seed(current_seed)
823
 
824
  if preprocessor.name != "NormalBae":
825
+ preprocessor.load("NormalBae")
826
+
827
  preprocessor.model.to("cuda" if torch.cuda.is_available() else "cpu")
828
 
829
  control_image = preprocessor(
 
832
  detect_resolution=preprocess_resolution,
833
  )
834
 
835
+ # Construct the full prompt
836
  if style_selection and style_selection != "None":
837
  style_prompt = apply_style(style_selection)
838
+ prompt_parts = [f"Photo from Pinterest of {prompt}" if prompt else None, style_prompt if style_prompt else None, a_prompt if a_prompt else None]
 
 
 
 
839
  full_prompt = ", ".join(filter(None, prompt_parts))
840
  else:
841
  full_prompt = get_prompt(prompt, a_prompt)
 
858
  image=control_image,
859
  ).images[0]
860
 
861
+ # Save and upload results (optional)
862
  try:
863
  timestamp = int(time.time())
864
  results_path = f"{timestamp}_output.jpg"
 
886
  return initial_result
887
 
888
 
889
+ # Launch the Gradio app
 
890
  if prod:
891
  demo.queue(max_size=20).launch(server_name="localhost", server_port=port)
892
  else:
893
+ demo.queue().launch(share=True, show_api=False)