ginipick commited on
Commit
ce178fe
·
verified ·
1 Parent(s): be831e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -101
app.py CHANGED
@@ -1,8 +1,3 @@
1
- # Configuration
2
- prod = False
3
- port = 8080
4
- show_options = True # Changed to True for better visibility
5
-
6
  import os
7
  import random
8
  import time
@@ -56,19 +51,17 @@ class Preprocessor:
56
 
57
  if hasattr(self.model, 'device'):
58
  if self.model.device.type != device:
59
- print(f"Moving preprocessor model to {device}")
60
- try:
61
- self.model.to(device)
62
- except Exception as e:
63
- print(f"Error moving preprocessor model to {device}: {e}")
64
- pass
65
  else:
66
  print("Warning: Preprocessor model has no .device attribute. Attempting to move to correct device.")
67
  try:
68
  self.model.to(device)
69
  except Exception as e:
70
  print(f"Error attempting to move preprocessor model without .device attribute: {e}")
71
- pass
72
 
73
  return self.model(image, **kwargs)
74
 
@@ -173,7 +166,6 @@ def get_prompt(prompt, additional_prompt):
173
  return ", ".join(filter(None, prompt_parts))
174
 
175
 
176
- # Enhanced style list with more diverse options
177
  style_list = [
178
  {"name": "None",
179
  "prompt": ""
@@ -223,7 +215,6 @@ style_list = [
223
  {"name": "Matrix",
224
  "prompt": "Futuristic cyberpunk interior,neon accent lighting,holographic plants,sleek black surfaces,advanced gaming setup,transparent screens,Blade Runner inspired decor,high-tech minimalist furniture"
225
  },
226
- # New added styles
227
  {"name": "Industrial Loft",
228
  "prompt": "Industrial loft interior,exposed brick walls,metal finishes,high ceilings with exposed pipes,concrete floors,vintage factory lights,open floor plan"
229
  },
@@ -247,7 +238,8 @@ style_list = [
247
  },
248
  {"name": "Penthouse",
249
  "prompt": "Luxury penthouse interior,floor-to-ceiling windows,city skyline views,modern furniture,high-end appliances,marble countertops,designer lighting fixtures"
250
- }]
 
251
  styles = {k["name"]: (k["prompt"]) for k in style_list}
252
  STYLE_NAMES = list(styles.keys())
253
 
@@ -256,7 +248,6 @@ def apply_style(style_name):
256
  return styles.get(style_name, "")
257
 
258
 
259
- # Enhanced CSS for Gradio UI
260
  css = """
261
  /* Global Styles */
262
  :root {
@@ -445,7 +436,6 @@ css = """
445
  }
446
  """
447
 
448
- # Load example images
449
  def load_examples():
450
  examples = []
451
  for i in range(1, 5):
@@ -461,13 +451,11 @@ def load_examples():
461
 
462
  example_images = load_examples()
463
 
464
- # Function to select example image
465
  def select_example(index):
466
  if 0 <= index < len(example_images):
467
  return example_images[index]
468
  return None
469
 
470
- # Gradio Interface Definition
471
  with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
472
  gr.Markdown("<h1>✨ Dream of IKEA ✨</h1>")
473
  gr.Markdown("<h3>Transform your space with AI-powered interior design</h3>")
@@ -482,22 +470,16 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
482
  mirror_webcam=True,
483
  type="pil",
484
  elem_id="input-image",
485
- value=example_images[0] if example_images else None # Set default image to in1.jpg
486
  )
487
 
488
- # Example images section with buttons instead of images with style
489
  with gr.Row(elem_classes="example-images"):
490
- # Create example buttons
491
  example_buttons = []
492
  for i in range(len(example_images)):
493
  if example_images[i]:
494
  btn = gr.Button(f"Example {i+1}", elem_classes="example-thumb")
495
  example_buttons.append(btn)
496
- # Add click event for each example button
497
- btn.click(
498
- fn=lambda idx=i: select_example(idx),
499
- outputs=image
500
- )
501
 
502
  with gr.Column(scale=1, min_width=300):
503
  result = gr.Image(
@@ -508,7 +490,6 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
508
  elem_id="output-image"
509
  )
510
 
511
- # Design input section
512
  with gr.Row():
513
  with gr.Column(scale=2):
514
  prompt = gr.Textbox(
@@ -520,7 +501,6 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
520
  run_button = gr.Button(value="🚀 Generate Design", size="lg")
521
  use_ai_button = gr.Button(value="♻️ Use Result as New Input", size="lg")
522
 
523
- # Grouped style selection with categories
524
  gr.Markdown("<h2>Design Style Selection</h2>")
525
  with gr.Tabs():
526
  with gr.TabItem("Modern Styles"):
@@ -556,7 +536,6 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
556
  elem_id="global-styles"
557
  )
558
 
559
- # Advanced options - now with a clearer separator and improved layout
560
  with gr.Accordion("⚙️ Advanced Options", open=False):
561
  with gr.Row():
562
  with gr.Column(scale=1):
@@ -626,7 +605,17 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
626
  elem_classes="helper-text"
627
  )
628
 
629
- # Function to handle style selection changes across tabs
 
 
 
 
 
 
 
 
 
 
630
  def update_style_selection(modern_value, classic_value, global_value):
631
  if modern_value is not None:
632
  return modern_value
@@ -637,7 +626,6 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
637
  else:
638
  return "None"
639
 
640
- # Style synchronization
641
  style_selection = gr.State("None")
642
 
643
  def clear_other_tabs(active_tab, value):
@@ -649,7 +637,6 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
649
  return None, None, value
650
  return None, None, None
651
 
652
- # Connect the tab radios to update each other
653
  style_selection_modern.change(
654
  fn=lambda x: clear_other_tabs("modern", x),
655
  inputs=[style_selection_modern],
@@ -668,7 +655,6 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
668
  outputs=[style_selection_modern, style_selection_classic, style_selection_global]
669
  )
670
 
671
- # Combine all style selections into one for processing
672
  def get_active_style(modern, classic, global_style):
673
  if modern is not None and modern != "":
674
  return modern
@@ -678,13 +664,11 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
678
  return global_style
679
  return "None"
680
 
681
- # Randomize seed function
682
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
683
  if randomize_seed:
684
  seed = random.randint(0, MAX_SEED)
685
  return seed
686
 
687
- # Configuration list for inputs - using function to get active style
688
  def get_config_inputs():
689
  return [
690
  image,
@@ -703,7 +687,6 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
703
  randomize_seed,
704
  ]
705
 
706
- # Gradio Event Handling Functions
707
  @gr.on(
708
  triggers=[image.upload, prompt.submit, run_button.click],
709
  inputs=get_config_inputs(),
@@ -715,18 +698,20 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
715
  num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale,
716
  seed, randomize_seed, progress=gr.Progress(track_tqdm=True)
717
  ):
718
- # Get the active style
719
  active_style = get_active_style(style_modern, style_classic, style_global)
720
-
721
- # Apply seed randomization
722
  processed_seed = randomize_seed_fn(seed, randomize_seed)
723
  print(f"Using processed seed: {processed_seed}")
724
- print(f"Active style: {active_style}")
725
-
726
- # Call the core processing function
727
  return process_image(
728
- image, active_style, prompt, a_prompt, n_prompt, num_images,
729
- image_resolution, preprocess_resolution, num_steps, guidance_scale,
 
 
 
 
 
 
 
 
730
  processed_seed
731
  )
732
 
@@ -741,27 +726,24 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
741
  a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution,
742
  num_steps, guidance_scale, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)
743
  ):
744
- # First, yield the previous result to update the input image immediately
745
  yield previous_result, gr.update()
746
-
747
- # Get active style
748
  active_style = get_active_style(style_modern, style_classic, style_global)
749
-
750
- # Apply seed randomization
751
  processed_seed = randomize_seed_fn(seed, randomize_seed)
752
- print(f"Using processed seed: {processed_seed}")
753
-
754
- # Then, process the new input image
755
  new_result = process_image(
756
- previous_result, active_style, prompt, a_prompt,
757
- n_prompt, num_images, image_resolution,
758
- preprocess_resolution, num_steps, guidance_scale,
 
 
 
 
 
 
 
759
  processed_seed
760
  )
761
- # Finally, yield the new result
762
  yield previous_result, new_result
763
 
764
- # Turn off buttons when processing
765
  @gr.on(
766
  triggers=[image.upload, use_ai_button.click, run_button.click],
767
  inputs=None,
@@ -771,7 +753,6 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
771
  def turn_buttons_off():
772
  return gr.update(interactive=False, value="Processing..."), gr.update(interactive=False)
773
 
774
- # Turn on buttons when processing is complete
775
  @gr.on(
776
  triggers=[result.change],
777
  inputs=None,
@@ -781,49 +762,24 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
781
  def turn_buttons_on():
782
  return gr.update(interactive=True, value="♻️ Use Result as New Input"), gr.update(interactive=True, value="🚀 Generate Design")
783
 
784
-
785
- # Core Image Processing Function
786
- @spaces.GPU(duration=12)
787
- @torch.inference_mode()
788
  def process_image(
789
- image,
790
- style_selection,
791
- prompt,
792
- a_prompt,
793
- n_prompt,
794
- num_images,
795
- image_resolution,
796
- preprocess_resolution,
797
- num_steps,
798
- guidance_scale,
799
- seed,
800
  ):
801
- """
802
- Processes an input image to generate a new image based on style and prompts.
803
-
804
- Args:
805
- image: Input PIL Image.
806
- style_selection: Name of the design style to apply.
807
- prompt: Custom design prompt.
808
- a_prompt: Additional positive prompt.
809
- n_prompt: Negative prompt.
810
- num_images: Number of images to generate (currently only 1 supported by pipeline).
811
- image_resolution: Resolution for the output image.
812
- preprocess_resolution: Resolution for the preprocessor.
813
- num_steps: Number of inference steps.
814
- guidance_scale: Guidance scale for the diffusion process.
815
- seed: Random seed for reproducibility.
816
-
817
- Returns:
818
- A PIL Image of the generated result.
819
- """
820
- # Use the seed passed from the event handler
821
  current_seed = seed
822
  generator = torch.cuda.manual_seed(current_seed) if torch.cuda.is_available() else torch.manual_seed(current_seed)
823
 
824
  if preprocessor.name != "NormalBae":
825
- preprocessor.load("NormalBae")
826
-
827
  preprocessor.model.to("cuda" if torch.cuda.is_available() else "cpu")
828
 
829
  control_image = preprocessor(
@@ -832,10 +788,13 @@ def process_image(
832
  detect_resolution=preprocess_resolution,
833
  )
834
 
835
- # Construct the full prompt
836
  if style_selection and style_selection != "None":
837
  style_prompt = apply_style(style_selection)
838
- prompt_parts = [f"Photo from Pinterest of {prompt}" if prompt else None, style_prompt if style_prompt else None, a_prompt if a_prompt else None]
 
 
 
 
839
  full_prompt = ", ".join(filter(None, prompt_parts))
840
  else:
841
  full_prompt = get_prompt(prompt, a_prompt)
@@ -858,7 +817,6 @@ def process_image(
858
  image=control_image,
859
  ).images[0]
860
 
861
- # Save and upload results (optional)
862
  try:
863
  timestamp = int(time.time())
864
  results_path = f"{timestamp}_output.jpg"
@@ -886,8 +844,9 @@ def process_image(
886
  return initial_result
887
 
888
 
889
- # Launch the Gradio app
 
890
  if prod:
891
  demo.queue(max_size=20).launch(server_name="localhost", server_port=port)
892
  else:
893
- demo.queue().launch(share=True, show_api=False)
 
 
 
 
 
 
1
  import os
2
  import random
3
  import time
 
51
 
52
  if hasattr(self.model, 'device'):
53
  if self.model.device.type != device:
54
+ print(f"Moving preprocessor model to {device}")
55
+ try:
56
+ self.model.to(device)
57
+ except Exception as e:
58
+ print(f"Error moving preprocessor model to {device}: {e}")
 
59
  else:
60
  print("Warning: Preprocessor model has no .device attribute. Attempting to move to correct device.")
61
  try:
62
  self.model.to(device)
63
  except Exception as e:
64
  print(f"Error attempting to move preprocessor model without .device attribute: {e}")
 
65
 
66
  return self.model(image, **kwargs)
67
 
 
166
  return ", ".join(filter(None, prompt_parts))
167
 
168
 
 
169
  style_list = [
170
  {"name": "None",
171
  "prompt": ""
 
215
  {"name": "Matrix",
216
  "prompt": "Futuristic cyberpunk interior,neon accent lighting,holographic plants,sleek black surfaces,advanced gaming setup,transparent screens,Blade Runner inspired decor,high-tech minimalist furniture"
217
  },
 
218
  {"name": "Industrial Loft",
219
  "prompt": "Industrial loft interior,exposed brick walls,metal finishes,high ceilings with exposed pipes,concrete floors,vintage factory lights,open floor plan"
220
  },
 
238
  },
239
  {"name": "Penthouse",
240
  "prompt": "Luxury penthouse interior,floor-to-ceiling windows,city skyline views,modern furniture,high-end appliances,marble countertops,designer lighting fixtures"
241
+ }
242
+ ]
243
  styles = {k["name"]: (k["prompt"]) for k in style_list}
244
  STYLE_NAMES = list(styles.keys())
245
 
 
248
  return styles.get(style_name, "")
249
 
250
 
 
251
  css = """
252
  /* Global Styles */
253
  :root {
 
436
  }
437
  """
438
 
 
439
  def load_examples():
440
  examples = []
441
  for i in range(1, 5):
 
451
 
452
  example_images = load_examples()
453
 
 
454
  def select_example(index):
455
  if 0 <= index < len(example_images):
456
  return example_images[index]
457
  return None
458
 
 
459
  with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
460
  gr.Markdown("<h1>✨ Dream of IKEA ✨</h1>")
461
  gr.Markdown("<h3>Transform your space with AI-powered interior design</h3>")
 
470
  mirror_webcam=True,
471
  type="pil",
472
  elem_id="input-image",
473
+ value=example_images[0] if example_images else None
474
  )
475
 
 
476
  with gr.Row(elem_classes="example-images"):
 
477
  example_buttons = []
478
  for i in range(len(example_images)):
479
  if example_images[i]:
480
  btn = gr.Button(f"Example {i+1}", elem_classes="example-thumb")
481
  example_buttons.append(btn)
482
+ btn.click(fn=lambda idx=i: select_example(idx), outputs=image)
 
 
 
 
483
 
484
  with gr.Column(scale=1, min_width=300):
485
  result = gr.Image(
 
490
  elem_id="output-image"
491
  )
492
 
 
493
  with gr.Row():
494
  with gr.Column(scale=2):
495
  prompt = gr.Textbox(
 
501
  run_button = gr.Button(value="🚀 Generate Design", size="lg")
502
  use_ai_button = gr.Button(value="♻️ Use Result as New Input", size="lg")
503
 
 
504
  gr.Markdown("<h2>Design Style Selection</h2>")
505
  with gr.Tabs():
506
  with gr.TabItem("Modern Styles"):
 
536
  elem_id="global-styles"
537
  )
538
 
 
539
  with gr.Accordion("⚙️ Advanced Options", open=False):
540
  with gr.Row():
541
  with gr.Column(scale=1):
 
605
  elem_classes="helper-text"
606
  )
607
 
608
+ # 화면 아래에 Discord 배지를 삽입
609
+ gr.Markdown(
610
+ """
611
+ <p style="text-align: center;">
612
+ <a href="https://discord.gg/openfreeai" target="_blank">
613
+ <img src="https://img.shields.io/static/v1?label=Discord&message=Join%20our%20community&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="badge">
614
+ </a>
615
+ </p>
616
+ """
617
+ )
618
+
619
  def update_style_selection(modern_value, classic_value, global_value):
620
  if modern_value is not None:
621
  return modern_value
 
626
  else:
627
  return "None"
628
 
 
629
  style_selection = gr.State("None")
630
 
631
  def clear_other_tabs(active_tab, value):
 
637
  return None, None, value
638
  return None, None, None
639
 
 
640
  style_selection_modern.change(
641
  fn=lambda x: clear_other_tabs("modern", x),
642
  inputs=[style_selection_modern],
 
655
  outputs=[style_selection_modern, style_selection_classic, style_selection_global]
656
  )
657
 
 
658
  def get_active_style(modern, classic, global_style):
659
  if modern is not None and modern != "":
660
  return modern
 
664
  return global_style
665
  return "None"
666
 
 
667
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
668
  if randomize_seed:
669
  seed = random.randint(0, MAX_SEED)
670
  return seed
671
 
 
672
  def get_config_inputs():
673
  return [
674
  image,
 
687
  randomize_seed,
688
  ]
689
 
 
690
  @gr.on(
691
  triggers=[image.upload, prompt.submit, run_button.click],
692
  inputs=get_config_inputs(),
 
698
  num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale,
699
  seed, randomize_seed, progress=gr.Progress(track_tqdm=True)
700
  ):
 
701
  active_style = get_active_style(style_modern, style_classic, style_global)
 
 
702
  processed_seed = randomize_seed_fn(seed, randomize_seed)
703
  print(f"Using processed seed: {processed_seed}")
 
 
 
704
  return process_image(
705
+ image,
706
+ active_style,
707
+ prompt,
708
+ a_prompt,
709
+ n_prompt,
710
+ num_images,
711
+ image_resolution,
712
+ preprocess_resolution,
713
+ num_steps,
714
+ guidance_scale,
715
  processed_seed
716
  )
717
 
 
726
  a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution,
727
  num_steps, guidance_scale, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)
728
  ):
 
729
  yield previous_result, gr.update()
 
 
730
  active_style = get_active_style(style_modern, style_classic, style_global)
 
 
731
  processed_seed = randomize_seed_fn(seed, randomize_seed)
 
 
 
732
  new_result = process_image(
733
+ previous_result,
734
+ active_style,
735
+ prompt,
736
+ a_prompt,
737
+ n_prompt,
738
+ num_images,
739
+ image_resolution,
740
+ preprocess_resolution,
741
+ num_steps,
742
+ guidance_scale,
743
  processed_seed
744
  )
 
745
  yield previous_result, new_result
746
 
 
747
  @gr.on(
748
  triggers=[image.upload, use_ai_button.click, run_button.click],
749
  inputs=None,
 
753
  def turn_buttons_off():
754
  return gr.update(interactive=False, value="Processing..."), gr.update(interactive=False)
755
 
 
756
  @gr.on(
757
  triggers=[result.change],
758
  inputs=None,
 
762
  def turn_buttons_on():
763
  return gr.update(interactive=True, value="♻️ Use Result as New Input"), gr.update(interactive=True, value="🚀 Generate Design")
764
 
 
 
 
 
765
  def process_image(
766
+ image,
767
+ style_selection,
768
+ prompt,
769
+ a_prompt,
770
+ n_prompt,
771
+ num_images,
772
+ image_resolution,
773
+ preprocess_resolution,
774
+ num_steps,
775
+ guidance_scale,
776
+ seed,
777
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
778
  current_seed = seed
779
  generator = torch.cuda.manual_seed(current_seed) if torch.cuda.is_available() else torch.manual_seed(current_seed)
780
 
781
  if preprocessor.name != "NormalBae":
782
+ preprocessor.load("NormalBae")
 
783
  preprocessor.model.to("cuda" if torch.cuda.is_available() else "cpu")
784
 
785
  control_image = preprocessor(
 
788
  detect_resolution=preprocess_resolution,
789
  )
790
 
 
791
  if style_selection and style_selection != "None":
792
  style_prompt = apply_style(style_selection)
793
+ prompt_parts = [
794
+ f"Photo from Pinterest of {prompt}" if prompt else None,
795
+ style_prompt if style_prompt else None,
796
+ a_prompt if a_prompt else None
797
+ ]
798
  full_prompt = ", ".join(filter(None, prompt_parts))
799
  else:
800
  full_prompt = get_prompt(prompt, a_prompt)
 
817
  image=control_image,
818
  ).images[0]
819
 
 
820
  try:
821
  timestamp = int(time.time())
822
  results_path = f"{timestamp}_output.jpg"
 
844
  return initial_result
845
 
846
 
847
+ prod = False
848
+ port = 8080
849
  if prod:
850
  demo.queue(max_size=20).launch(server_name="localhost", server_port=port)
851
  else:
852
+ demo.queue().launch(share=True, show_api=False)