Spaces:
Paused
Paused
# Configuration
# prod=True serves the app on localhost:<port> with a bounded queue;
# prod=False launches it with a public share link (see the launch code at the end).
prod: bool = False
port: int = 8080
show_options: bool = True  # Changed to True for better visibility
| import os | |
| import random | |
| import time | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import imageio | |
| from huggingface_hub import HfApi | |
| import gc | |
| import torch | |
| import cv2 | |
| from PIL import Image | |
| from diffusers import ( | |
| ControlNetModel, | |
| DPMSolverMultistepScheduler, | |
| StableDiffusionControlNetPipeline, | |
| ) | |
| from controlnet_aux_local import NormalBaeDetector | |
# Upper bound for seeds: used as the seed slider maximum and for random seeds.
MAX_SEED = np.iinfo(np.int32).max
# Hugging Face token used to upload generated images; None when the variable is unset.
API_KEY = os.environ.get("API_KEY")
print("CUDA version:", torch.version.cuda)
print("loading everything")
compiled = False  # NOTE(review): not referenced anywhere in the visible code
api = HfApi()  # Hub client used by process_image to upload results when API_KEY is set
class Preprocessor:
    """Lazily loads and runs the annotator that turns a photo into a ControlNet input.

    Only the ``"NormalBae"`` annotator (``NormalBaeDetector`` from
    ``MODEL_ID``) is supported. Call :meth:`load` before calling the instance.
    """

    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        self.model = None
        self.name = ""
        # Device this class last moved the model to. Used when the wrapped
        # model does not expose a ``.device`` attribute, so we do not have to
        # blindly re-move it (and print a warning) on every call.
        self._device = None

    @staticmethod
    def _target_device() -> str:
        """Return ``"cuda"`` when a GPU is available, otherwise ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _current_device(self):
        """Best-effort device name of the loaded model, or None if unknown."""
        model_device = getattr(self.model, "device", None)
        if model_device is not None:
            return getattr(model_device, "type", str(model_device))
        return self._device

    def load(self, name: str) -> None:
        """Load the annotator called ``name`` (no-op if it is already loaded).

        Raises:
            ValueError: if ``name`` is not a supported annotator.
        """
        if name == self.name:
            return
        if name != "NormalBae":
            raise ValueError(
                f"Unsupported preprocessor {name!r}; only 'NormalBae' is available."
            )
        print("Loading NormalBae")
        device = self._target_device()
        self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to(device)
        self._device = device
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: "Image.Image", **kwargs) -> "Image.Image":
        """Run the loaded annotator on ``image``, forwarding ``kwargs`` to it.

        The model is moved to the preferred device first if it is not already
        there. A failed move is reported and the call continues, as before.

        Raises:
            RuntimeError: if :meth:`load` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("Preprocessor model is not loaded; call load() first.")
        device = self._target_device()
        if self._current_device() != device:
            print(f"Moving preprocessor model to {device}")
            try:
                self.model.to(device)
                self._device = device
            except Exception as e:
                print(f"Error moving preprocessor model to {device}: {e}")
        return self.model(image, **kwargs)
# Load models and preprocessor when the script starts.
# All weights use float16 on GPU and float32 on CPU.
model_id = "lllyasviel/control_v11p_sd15_normalbae"
print("initializing controlnet")
device = "cuda" if torch.cuda.is_available() else "cpu"
_on_gpu = device == "cuda"
_weights_dtype = torch.float16 if _on_gpu else torch.float32

# NOTE(review): `attn_implementation` is a transformers-style kwarg; confirm the
# installed diffusers version honours it for ControlNetModel instead of ignoring it.
controlnet = ControlNetModel.from_pretrained(
    model_id,
    torch_dtype=_weights_dtype,
    attn_implementation="flash_attention_2" if _on_gpu else None,
).to(device)

# Scheduler: SDE DPM-Solver++ (2nd order, Karras sigmas), with its base config
# read from the SD 1.5 archive repo's "scheduler" subfolder.
# NOTE(review): `denoise_final` and `torch_dtype` may not be recognised by
# DPMSolverMultistepScheduler - verify they are not silently dropped.
_scheduler_options = dict(
    subfolder="scheduler",
    solver_order=2,
    use_karras_sigmas=True,
    final_sigmas_type="sigma_min",
    algorithm_type="sde-dpmsolver++",
    prediction_type="epsilon",
    thresholding=False,
    denoise_final=True,
    torch_dtype=_weights_dtype,
)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    "ashllay/stable-diffusion-v1-5-archive", **_scheduler_options
)

# Stable Diffusion Pipeline URL (single-file checkpoint on the Hub)
base_model_url = "https://huggingface.co/Lykon/AbsoluteReality/blob/main/AbsoluteReality_1.8.1_pruned.safetensors"
print('loading pipe')
pipe = StableDiffusionControlNetPipeline.from_single_file(
    base_model_url,
    safety_checker=None,
    controlnet=controlnet,
    scheduler=scheduler,
    torch_dtype=_weights_dtype,
).to(device)

print("loading preprocessor")
preprocessor = Preprocessor()
preprocessor.load("NormalBae")
# Load textual inversions: prompt token -> embedding file in the Hub repo below.
# Best effort: an embedding that fails to load is reported and skipped.
_TEXTUAL_INVERSION_REPO = "broyang/hentaidigitalart_v20"
textual_inversions = {
    "EasyNegativeV2": "EasyNegativeV2.safetensors",
    "badhandv4": "badhandv4.pt",
    "fcNeg-neg": "fcNeg-neg.pt",
    "HDA_Ahegao": "HDA_Ahegao.pt",
    "HDA_Bondage": "HDA_Bondage.pt",
    "HDA_pet_play": "HDA_pet_play.pt",
    "HDA_unconventional_maid": "HDA_unconventional maid.pt",
    "HDA_NakedHoodie": "HDA_NakedHoodie.pt",
    "HDA_NunDress": "HDA_NunDress.pt",
    "HDA_Shibari": "HDA_Shibari.pt",
}
try:
    for ti_token, ti_file in textual_inversions.items():
        try:
            pipe.load_textual_inversion(
                _TEXTUAL_INVERSION_REPO, weight_name=ti_file, token=ti_token,
            )
        except Exception as e:
            print(f"Warning: Could not load textual inversion {ti_file}: {e}")
        else:
            print(f"Loaded textual inversion: {ti_token}")
except Exception as e:
    print(f"Error during textual inversions loading process: {e}")

print("---------------Loaded controlnet pipeline---------------")
if torch.cuda.is_available():
    # Release cached GPU blocks left over from loading, then report peak usage.
    torch.cuda.empty_cache()
    gc.collect()
    print(f"CUDA memory allocated: {torch.cuda.max_memory_allocated(device='cuda') / 1e9:.2f} GB")
def get_additional_prompt():
    """Build a randomised outfit prompt suffix.

    Picks one top, then one bottom, then one accessory with ``random.choice``
    and joins them after a fixed quality prompt, ending with ``score_9``.
    """
    quality = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
    tops = ["tank top", "blouse", "button up shirt", "sweater", "corset top"]
    # NOTE(review): "short skirt" is listed twice, so it is drawn twice as
    # often as the other bottoms - confirm that is intended.
    bottoms = ["short skirt", "athletic shorts", "jean shorts", "pleated skirt", "short skirt",
               "leggings", "high-waisted shorts"]
    accessories = ["knee-high boots", "gloves", "Thigh-high stockings", "Garter belt", "choker",
                   "necklace", "headband", "headphones"]
    picks = [random.choice(options) for options in (tops, bottoms, accessories)]
    return ", ".join([quality, *picks, "score_9"])
def get_prompt(prompt, additional_prompt):
    """Compose the default generation prompt (used when no style is selected).

    The result is "Photo from Pinterest of <prompt or 'interior space'>",
    then a fixed interior/camera description, then ``additional_prompt``
    if it is truthy, all joined with ", ".
    """
    interior = "design-style interior designed (interior space),tungsten white balance,captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length"
    subject = prompt if prompt else "interior space"
    pieces = [f"Photo from Pinterest of {subject}", interior]
    if additional_prompt:
        pieces.append(additional_prompt)
    return ", ".join(pieces)
# Enhanced style list with more diverse options.
# Each entry is (display name, prompt fragment); "None" adds no fragment.
_STYLE_PROMPTS = (
    ("None", ""),
    ("Minimalistic", "Minimalist interior design,clean lines,neutral colors,uncluttered space,functional furniture,lots of natural light"),
    ("Boho", "Bohemian chic interior,eclectic mix of patterns and textures,vintage furniture,plants,woven textiles,warm earthy colors"),
    ("Farmhouse", "Modern farmhouse interior,rustic wood elements,shiplap walls,neutral color palette,industrial accents,cozy textiles"),
    ("Saudi Prince", "Opulent gold interior,luxurious ornate furniture,crystal chandeliers,rich fabrics,marble floors,intricate Arabic patterns"),
    ("Neoclassical", "Neoclassical interior design,elegant columns,ornate moldings,symmetrical layout,refined furniture,muted color palette"),
    ("Eclectic", "Eclectic interior design,mix of styles and eras,bold color combinations,diverse furniture pieces,unique art objects"),
    ("Parisian", "Parisian apartment interior,all-white color scheme,ornate moldings,herringbone wood floors,elegant furniture,large windows"),
    ("Hollywood", "Hollywood Regency interior,glamorous and luxurious,bold colors,mirrored surfaces,velvet upholstery,gold accents"),
    ("Scandinavian", "Scandinavian interior design,light wood tones,white walls,minimalist furniture,cozy textiles,hygge atmosphere"),
    ("Beach", "Coastal beach house interior,light blue and white color scheme,weathered wood,nautical accents,sheer curtains,ocean view"),
    ("Japanese", "Traditional Japanese interior,tatami mats,shoji screens,low furniture,zen garden view,minimalist decor,natural materials"),
    ("Midcentury Modern", "Mid-century modern interior,1950s-60s style furniture,organic shapes,warm wood tones,bold accent colors,large windows"),
    ("Retro Futurism", "Neon (atompunk world) retro cyberpunk background"),
    ("Texan", "Western cowboy interior,rustic wood beams,leather furniture,cowhide rugs,antler chandeliers,southwestern patterns"),
    ("Matrix", "Futuristic cyberpunk interior,neon accent lighting,holographic plants,sleek black surfaces,advanced gaming setup,transparent screens,Blade Runner inspired decor,high-tech minimalist furniture"),
    # New added styles
    ("Industrial Loft", "Industrial loft interior,exposed brick walls,metal finishes,high ceilings with exposed pipes,concrete floors,vintage factory lights,open floor plan"),
    ("Art Deco", "Art Deco interior design,geometric patterns,bold colors,luxurious materials,symmetrical designs,metallic accents,sophisticated lighting"),
    ("Contemporary", "Contemporary interior design,sleek finishes,neutral palette with bold accents,clean lines,minimal ornamentation,statement lighting,open concept"),
    ("Tropical Villa", "Tropical villa interior,palm leaf patterns,natural materials,indoor plants,rattan furniture,light and airy spaces,ocean view,swimming pool"),
    ("Mediterranean", "Mediterranean interior design,terracotta tiles,arched doorways,wrought iron details,warm color palette,hand-painted ceramics,indoor-outdoor living"),
    ("Gothic Victorian", "Gothic Victorian interior,dark wood paneling,ornate furniture,velvet drapery,crystal chandeliers,rich jewel tones,antique decorative elements"),
    ("Rustic Cabin", "Rustic mountain cabin interior,log walls,stone fireplace,wooden beams,cozy textiles,leather furniture,forest views,warm lighting"),
    ("Penthouse", "Luxury penthouse interior,floor-to-ceiling windows,city skyline views,modern furniture,high-end appliances,marble countertops,designer lighting fixtures"),
)

# List-of-dicts view ({"name", "prompt"}), name -> prompt lookup, and ordered names.
style_list = [{"name": style_name, "prompt": fragment} for style_name, fragment in _STYLE_PROMPTS]
styles = dict(_STYLE_PROMPTS)
STYLE_NAMES = [style_name for style_name, _ in _STYLE_PROMPTS]


def apply_style(style_name):
    """Return the prompt fragment for ``style_name``, or "" for unknown styles."""
    try:
        return styles[style_name]
    except KeyError:
        return ""
# Enhanced CSS for Gradio UI
# Custom stylesheet passed to gr.Blocks(css=...). It declares shared colour,
# radius and shadow variables in :root, styles headings, image panels, buttons,
# the example-button row and the helper-text box, and hides the page footer.
# NOTE(review): selectors such as .gr-button / .gr-radio-group follow older
# Gradio class names - confirm they still match the installed Gradio version.
css = """
/* Global Styles */
:root {
--primary-color: #3498db;
--secondary-color: #2ecc71;
--accent-color: #e74c3c;
--text-color: #333;
--light-bg: #f8f9fa;
--border-radius: 10px;
--box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
color: var(--text-color);
background-color: var(--light-bg);
}
/* Typography */
h1, h2, h3 {
text-align: center;
display: block;
color: var(--primary-color);
margin-bottom: 1rem;
}
h1 {
font-size: 2.5rem;
margin-top: 1rem;
font-weight: 700;
}
h2 {
font-size: 1.8rem;
color: var(--accent-color);
}
/* Layout */
footer {
visibility: hidden;
}
.gradio-container {
max-width: 1200px !important;
margin: 0 auto;
padding: 20px;
}
/* Image Containers */
.gr-image {
display: flex;
justify-content: center;
align-items: center;
width: 100%;
height: 512px;
overflow: hidden;
border-radius: var(--border-radius);
box-shadow: var(--box-shadow);
transition: all 0.3s ease;
}
.gr-image:hover {
box-shadow: 0 8px 15px rgba(0, 0, 0, 0.2);
}
.gr-image img {
width: 100%;
height: 100%;
object-fit: cover;
object-position: center;
border-radius: var(--border-radius);
}
/* Radio buttons styling */
.gr-radio-group {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(150px, 1fr));
gap: 10px;
padding: 15px;
background-color: white;
border-radius: var(--border-radius);
box-shadow: var(--box-shadow);
}
/* Buttons */
button.gr-button {
background-color: var(--primary-color) !important;
color: white !important;
border: none !important;
padding: 10px 20px !important;
border-radius: var(--border-radius) !important;
font-weight: bold !important;
transition: all 0.3s ease !important;
box-shadow: var(--box-shadow) !important;
}
button.gr-button:hover {
background-color: var(--secondary-color) !important;
transform: translateY(-2px) !important;
box-shadow: 0 6px 12px rgba(0, 0, 0, 0.15) !important;
}
/* Slider customization */
.gr-slider {
margin-top: 10px !important;
}
/* Accordion styling */
.gr-accordion {
margin-top: 20px;
border-radius: var(--border-radius);
overflow: hidden;
box-shadow: var(--box-shadow);
}
/* Helper text */
.helper-text {
background-color: #f0f8ff;
padding: 10px;
border-left: 4px solid var(--primary-color);
margin: 15px 0;
border-radius: 0 var(--border-radius) var(--border-radius) 0;
}
/* Style categories */
.style-category {
font-weight: bold;
margin-top: 10px;
color: var(--accent-color);
}
/* Progress bar */
.gr-progress {
height: 10px !important;
border-radius: 5px !important;
background-color: #e0e0e0 !important;
}
.gr-progress-bar {
background-color: var(--secondary-color) !important;
border-radius: 5px !important;
}
/* Main image section highlight */
.main-images {
border: 2px solid var(--primary-color);
border-radius: var(--border-radius);
padding: 15px;
background-color: white;
margin-bottom: 20px;
}
/* Example images section */
.example-images {
display: flex;
justify-content: center;
gap: 10px;
margin-top: 10px;
}
.example-image {
width: 120px;
height: 90px;
object-fit: cover;
border-radius: 5px;
cursor: pointer;
transition: all 0.2s ease;
border: 2px solid transparent;
}
.example-image:hover {
transform: scale(1.05);
border-color: var(--primary-color);
}
.example-thumb {
border: 2px solid #ddd;
border-radius: 8px;
cursor: pointer;
transition: all 0.3s ease;
}
.example-thumb:hover {
border-color: var(--primary-color);
transform: translateY(-2px);
}
"""
# Load example images
def load_examples(count: int = 4):
    """Load the example room photos ``in1.jpg`` ... ``in{count}.jpg`` from the CWD.

    Missing or unreadable files are reported and skipped, so the returned
    list may be shorter than ``count``.

    Args:
        count: Number of example slots to try (default 4, as before).
    Returns:
        A list of fully decoded PIL images, in file-number order.
    """
    examples = []
    for i in range(1, count + 1):
        img_path = f"in{i}.jpg"
        if not os.path.exists(img_path):
            print(f"Warning: Example image {img_path} not found")
            continue
        try:
            # Image.open is lazy and keeps the file handle open; copy() forces a
            # full decode so the handle is closed here and decode errors surface
            # now instead of later inside Gradio.
            with Image.open(img_path) as img:
                examples.append(img.copy())
        except Exception as e:
            print(f"Error loading example image {img_path}: {e}")
    return examples


example_images = load_examples()


# Function to select example image
def select_example(index):
    """Return the example image at ``index``, or None if it is out of range."""
    if 0 <= index < len(example_images):
        return example_images[index]
    return None
# Gradio Interface Definition
with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
    gr.Markdown("<h1>✨ Dream of IKEA ✨</h1>")
    gr.Markdown("<h3>Transform your space with AI-powered interior design</h3>")

    # MAIN SECTION - input photo (left) and generated design (right)
    with gr.Row(equal_height=True, elem_classes="main-images"):
        with gr.Column(scale=1, min_width=300):
            image = gr.Image(
                label="📸 Upload Room Photo",
                sources=["upload"],
                show_label=True,
                mirror_webcam=True,
                type="pil",
                elem_id="input-image",
                value=example_images[0] if example_images else None,  # Set default image to in1.jpg
            )
            # Example images section: one button per loaded example photo.
            with gr.Row(elem_classes="example-images"):
                example_buttons = []
                for i, example_image in enumerate(example_images):
                    if example_image:
                        btn = gr.Button(f"Example {i+1}", elem_classes="example-thumb")
                        example_buttons.append(btn)
                        # `idx=i` binds the index now; a plain closure would
                        # see only the final value of `i` (late binding).
                        btn.click(
                            fn=lambda idx=i: select_example(idx),
                            outputs=image,
                        )
        with gr.Column(scale=1, min_width=300):
            result = gr.Image(
                label="🎨 AI Redesigned Room",
                interactive=False,
                type="pil",
                show_share_button=True,
                elem_id="output-image",
            )

    # Design input section
    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="💭 Describe Your Dream Space",
                placeholder="E.g., 'A cozy bedroom with mountain view' or 'Modern kitchen with island'",
                elem_id="prompt-input",
            )
        with gr.Column(scale=1):
            run_button = gr.Button(value="🚀 Generate Design", size="lg")
            use_ai_button = gr.Button(value="♻️ Use Result as New Input", size="lg")

    # Grouped style selection: one radio per tab.
    gr.Markdown("<h2>Design Style Selection</h2>")
    with gr.Tabs():
        with gr.TabItem("Modern Styles"):
            modern_styles = ["None", "Minimalistic", "Contemporary", "Scandinavian", "Industrial Loft", "Midcentury Modern", "Matrix", "Retro Futurism", "Penthouse"]
            style_selection_modern = gr.Radio(
                show_label=False,
                container=True,
                interactive=True,
                choices=modern_styles,
                value="None",
                elem_id="modern-styles",
            )
        with gr.TabItem("Classic & Traditional"):
            classic_styles = ["Farmhouse", "Neoclassical", "Rustic Cabin", "Mediterranean", "Gothic Victorian", "Art Deco", "Parisian", "Texan"]
            style_selection_classic = gr.Radio(
                show_label=False,
                container=True,
                interactive=True,
                choices=classic_styles,
                value=None,
                elem_id="classic-styles",
            )
        with gr.TabItem("Global & Eclectic"):
            global_styles = ["Boho", "Eclectic", "Japanese", "Tropical Villa", "Beach", "Hollywood", "Saudi Prince"]
            style_selection_global = gr.Radio(
                show_label=False,
                container=True,
                interactive=True,
                choices=global_styles,
                value=None,
                elem_id="global-styles",
            )

    # Advanced options - collapsed by default
    with gr.Accordion("⚙️ Advanced Options", open=False):
        with gr.Row():
            with gr.Column(scale=1):
                num_images = gr.Slider(
                    label="Number of Images",
                    minimum=1,
                    maximum=4,
                    value=1,
                    step=1,
                )
                image_resolution = gr.Slider(
                    label="Image Resolution",
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=256,
                )
                preprocess_resolution = gr.Slider(
                    label="Preprocess Resolution",
                    minimum=128,
                    maximum=1024,
                    value=512,
                    step=1,
                )
            with gr.Column(scale=1):
                num_steps = gr.Slider(
                    label="Number of Steps",
                    minimum=1,
                    maximum=100,
                    value=15,
                    step=1,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0.1,
                    maximum=30.0,
                    value=5.5,
                    step=0.1,
                )
        with gr.Row():
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(
                label="Randomize Seed",
                value=True,
            )
        with gr.Row():
            a_prompt = gr.Textbox(
                label="Additional Prompt",
                value="design-style interior designed (interior space), tungsten white balance, captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length",
            )
            # Textual-inversion trigger words must match the tokens passed to
            # pipe.load_textual_inversion ("EasyNegativeV2", "fcNeg-neg",
            # "badhandv4"); a bare "fcNeg" never activates that embedding.
            n_prompt = gr.Textbox(
                label="Negative Prompt",
                value="EasyNegativeV2, fcNeg-neg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
            )

    with gr.Row():
        helper_text = gr.Markdown(
            "### 💡 Community: https://discord.gg/openfreeai ",
            elem_classes="helper-text",
        )
| # Function to handle style selection changes across tabs | |
| def update_style_selection(modern_value, classic_value, global_value): | |
| if modern_value is not None: | |
| return modern_value | |
| elif classic_value is not None: | |
| return classic_value | |
| elif global_value is not None: | |
| return global_value | |
| else: | |
| return "None" | |
| # Style synchronization | |
| style_selection = gr.State("None") | |
| def clear_other_tabs(active_tab, value): | |
| if active_tab == "modern" and value is not None: | |
| return value, None, None | |
| elif active_tab == "classic" and value is not None: | |
| return None, value, None | |
| elif active_tab == "global" and value is not None: | |
| return None, None, value | |
| return None, None, None | |
| # Connect the tab radios to update each other | |
| style_selection_modern.change( | |
| fn=lambda x: clear_other_tabs("modern", x), | |
| inputs=[style_selection_modern], | |
| outputs=[style_selection_modern, style_selection_classic, style_selection_global] | |
| ) | |
| style_selection_classic.change( | |
| fn=lambda x: clear_other_tabs("classic", x), | |
| inputs=[style_selection_classic], | |
| outputs=[style_selection_modern, style_selection_classic, style_selection_global] | |
| ) | |
| style_selection_global.change( | |
| fn=lambda x: clear_other_tabs("global", x), | |
| inputs=[style_selection_global], | |
| outputs=[style_selection_modern, style_selection_classic, style_selection_global] | |
| ) | |
| # Combine all style selections into one for processing | |
| def get_active_style(modern, classic, global_style): | |
| if modern is not None and modern != "": | |
| return modern | |
| elif classic is not None and classic != "": | |
| return classic | |
| elif global_style is not None and global_style != "": | |
| return global_style | |
| return "None" | |
| # Randomize seed function | |
| def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: | |
| if randomize_seed: | |
| seed = random.randint(0, MAX_SEED) | |
| return seed | |
| # Configuration list for inputs - using function to get active style | |
| def get_config_inputs(): | |
| return [ | |
| image, | |
| style_selection_modern, | |
| style_selection_classic, | |
| style_selection_global, | |
| prompt, | |
| a_prompt, | |
| n_prompt, | |
| num_images, | |
| image_resolution, | |
| preprocess_resolution, | |
| num_steps, | |
| guidance_scale, | |
| seed, | |
| randomize_seed, | |
| ] | |
| # Gradio Event Handling Functions | |
| def auto_process_image( | |
| image, style_modern, style_classic, style_global, prompt, a_prompt, n_prompt, | |
| num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, | |
| seed, randomize_seed, progress=gr.Progress(track_tqdm=True) | |
| ): | |
| # Get the active style | |
| active_style = get_active_style(style_modern, style_classic, style_global) | |
| # Apply seed randomization | |
| processed_seed = randomize_seed_fn(seed, randomize_seed) | |
| print(f"Using processed seed: {processed_seed}") | |
| print(f"Active style: {active_style}") | |
| # Call the core processing function | |
| return process_image( | |
| image, active_style, prompt, a_prompt, n_prompt, num_images, | |
| image_resolution, preprocess_resolution, num_steps, guidance_scale, | |
| processed_seed | |
| ) | |
| def submit( | |
| previous_result, image, style_modern, style_classic, style_global, prompt, | |
| a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, | |
| num_steps, guidance_scale, seed, randomize_seed, progress=gr.Progress(track_tqdm=True) | |
| ): | |
| # First, yield the previous result to update the input image immediately | |
| yield previous_result, gr.update() | |
| # Get active style | |
| active_style = get_active_style(style_modern, style_classic, style_global) | |
| # Apply seed randomization | |
| processed_seed = randomize_seed_fn(seed, randomize_seed) | |
| print(f"Using processed seed: {processed_seed}") | |
| # Then, process the new input image | |
| new_result = process_image( | |
| previous_result, active_style, prompt, a_prompt, | |
| n_prompt, num_images, image_resolution, | |
| preprocess_resolution, num_steps, guidance_scale, | |
| processed_seed | |
| ) | |
| # Finally, yield the new result | |
| yield previous_result, new_result | |
| # Turn off buttons when processing | |
| def turn_buttons_off(): | |
| return gr.update(interactive=False, value="Processing..."), gr.update(interactive=False) | |
| # Turn on buttons when processing is complete | |
| def turn_buttons_on(): | |
| return gr.update(interactive=True, value="♻️ Use Result as New Input"), gr.update(interactive=True, value="🚀 Generate Design") | |
| # Core Image Processing Function | |
| def process_image( | |
| image, | |
| style_selection, | |
| prompt, | |
| a_prompt, | |
| n_prompt, | |
| num_images, | |
| image_resolution, | |
| preprocess_resolution, | |
| num_steps, | |
| guidance_scale, | |
| seed, | |
| ): | |
| """ | |
| Processes an input image to generate a new image based on style and prompts. | |
| Args: | |
| image: Input PIL Image. | |
| style_selection: Name of the design style to apply. | |
| prompt: Custom design prompt. | |
| a_prompt: Additional positive prompt. | |
| n_prompt: Negative prompt. | |
| num_images: Number of images to generate (currently only 1 supported by pipeline). | |
| image_resolution: Resolution for the output image. | |
| preprocess_resolution: Resolution for the preprocessor. | |
| num_steps: Number of inference steps. | |
| guidance_scale: Guidance scale for the diffusion process. | |
| seed: Random seed for reproducibility. | |
| Returns: | |
| A PIL Image of the generated result. | |
| """ | |
| # Use the seed passed from the event handler | |
| current_seed = seed | |
| generator = torch.cuda.manual_seed(current_seed) if torch.cuda.is_available() else torch.manual_seed(current_seed) | |
| if preprocessor.name != "NormalBae": | |
| preprocessor.load("NormalBae") | |
| preprocessor.model.to("cuda" if torch.cuda.is_available() else "cpu") | |
| control_image = preprocessor( | |
| image=image, | |
| image_resolution=image_resolution, | |
| detect_resolution=preprocess_resolution, | |
| ) | |
| # Construct the full prompt | |
| if style_selection and style_selection != "None": | |
| style_prompt = apply_style(style_selection) | |
| prompt_parts = [f"Photo from Pinterest of {prompt}" if prompt else None, style_prompt if style_prompt else None, a_prompt if a_prompt else None] | |
| full_prompt = ", ".join(filter(None, prompt_parts)) | |
| else: | |
| full_prompt = get_prompt(prompt, a_prompt) | |
| negative_prompt = str(n_prompt) | |
| print(f"Using prompt: {full_prompt}") | |
| print(f"Using negative prompt: {negative_prompt}") | |
| print(f"Using seed: {current_seed}") | |
| pipe.to("cuda" if torch.cuda.is_available() else "cpu") | |
| with torch.no_grad(): | |
| initial_result = pipe( | |
| prompt=full_prompt, | |
| negative_prompt=negative_prompt, | |
| guidance_scale=guidance_scale, | |
| num_images_per_prompt=1, | |
| num_inference_steps=num_steps, | |
| generator=generator, | |
| image=control_image, | |
| ).images[0] | |
| # Save and upload results (optional) | |
| try: | |
| timestamp = int(time.time()) | |
| results_path = f"{timestamp}_output.jpg" | |
| imageio.imsave(results_path, initial_result) | |
| if API_KEY: | |
| print(f"Uploading result image to broyang/interior-ai-outputs/{results_path}") | |
| try: | |
| api.upload_file( | |
| path_or_fileobj=results_path, | |
| path_in_repo=results_path, | |
| repo_id="broyang/interior-ai-outputs", | |
| repo_type="dataset", | |
| token=API_KEY, | |
| run_as_future=True, | |
| ) | |
| except Exception as e: | |
| print(f"Error uploading file to Hugging Face Hub: {e}") | |
| else: | |
| print("Hugging Face API Key not found, skipping file upload.") | |
| except Exception as e: | |
| print(f"Error saving or uploading image: {e}") | |
| return initial_result | |
# Launch the Gradio app
if not prod:
    # Development: default queue, public share link, API page hidden.
    demo.queue()
    demo.launch(share=True, show_api=False)
else:
    # Production: queue capped at 20 pending requests, served on localhost:<port>.
    demo.queue(max_size=20)
    demo.launch(server_name="localhost", server_port=port)