# Spaces: Running on Zero (Hugging Face Space status banner captured together with the source)
| import os | |
| import random | |
| import uuid | |
| from typing import Tuple, Dict | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import spaces | |
| import torch | |
| from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler | |
# Markdown header rendered at the top of the Gradio page.
# The glyph after the title was mis-encoded ("β‘") in the original text; it is
# the UTF-8 byte sequence of the high-voltage sign decoded with the wrong
# code page, so it is restored here.
DESCRIPTIONz = """## SDXL-LoRA-DLC ⚡
Select a base model, choose a LoRA, and generate images!
"""

# --- Constants ---
# Largest value of a signed 32-bit integer; used as the inclusive upper bound
# for random seeds and for the seed slider.
MAX_SEED = np.iinfo(np.int32).max
# Name of the style preset used by default and as the fallback for unknown names.
DEFAULT_STYLE_NAME = "3840 x 2160"
# Set to True if you want to try torch compile (might be faster but requires
# compatible hardware/drivers).
USE_TORCH_COMPILE = False
# Set to True to offload parts of the model to CPU (saves VRAM but slower).
ENABLE_CPU_OFFLOAD = False
# --- Model Definitions ---
# (UI display name, Hugging Face model id) pairs for the selectable SDXL base
# checkpoints. Further SDXL base models can be appended as extra pairs, e.g.
# ("Another SDXL Model", "stabilityai/stable-diffusion-xl-base-1.0").
_BASE_MODEL_PAIRS = (
    ("RealVisXL V4.0 Lightning", "SG161222/RealVisXL_V4.0_Lightning"),
    ("RealVisXL V5.0 Lightning", "SG161222/RealVisXL_V5.0_Lightning"),
)

# Dictionary mapping user-friendly names to Hugging Face model IDs; insertion
# order matters because the first entry is used as the UI default.
pipelines_info = dict(_BASE_MODEL_PAIRS)
# Process-wide cache of already constructed pipelines, keyed by the display
# name used in `pipelines_info`; starts empty and is filled on first use.
loaded_pipelines: Dict[str, StableDiffusionXLPipeline] = dict()
# --- LoRA Definitions ---
# Maps the label shown in the LoRA dropdown to a
# (Hugging Face repo id, weight filename, adapter name) triple.
#
# The labels are kept byte-for-byte as they appear in this file (including
# their mis-encoded trailing glyphs) because other parts of the file refer to
# the default label by its literal value.
#
# Fix: the Interior Architecture weight filename contained "Ξ΄", which is the
# UTF-8 encoding of the Greek letter delta read through a single-byte Greek
# code page. A filename is resolved against the repo at load time, so the
# garbled name cannot match; it is restored to "δ".
LORA_OPTIONS = {
    # Name: (HuggingFace Repo ID, Weight Filename, Adapter Name)
    "Realism (face/character)π¦π»": (
        "prithivMLmods/Canopus-Realism-LoRA",
        "Canopus-Realism-LoRA.safetensors",
        "rlms",
    ),
    "Pixar (art/toons)π": (
        "prithivMLmods/Canopus-Pixar-Art",
        "Canopus-Pixar-Art.safetensors",
        "pixar",
    ),
    "Photoshoot (camera/film)πΈ": (
        "prithivMLmods/Canopus-Photo-Shoot-Mini-LoRA",
        "Canopus-Photo-Shoot-Mini-LoRA.safetensors",
        "photo",
    ),
    "Clothing (hoodies/pant/shirts)π": (
        "prithivMLmods/Canopus-Clothing-Adp-LoRA",
        "Canopus-Dress-Clothing-LoRA.safetensors",
        "clth",
    ),
    "Interior Architecture (house/hotel)π ": (
        "prithivMLmods/Canopus-Interior-Architecture-0.1",
        "Canopus-Interior-Architecture-0.1δ.safetensors",
        "arch",
    ),
    "Fashion Product (wearing/usable)π": (
        "prithivMLmods/Canopus-Fashion-Product-Dilation",
        "Canopus-Fashion-Product-Dilation.safetensors",
        "fashion",
    ),
    "Minimalistic Image (minimal/detailed)ποΈ": (
        "prithivMLmods/Pegasi-Minimalist-Image-Style",
        "Pegasi-Minimalist-Image-Style.safetensors",
        "minimalist",
    ),
    "Modern Clothing (trend/new)π": (
        "prithivMLmods/Canopus-Modern-Clothing-Design",
        "Canopus-Modern-Clothing-Design.safetensors",
        "mdrnclth",
    ),
    "Animaliea (farm/wild)π«": (
        "prithivMLmods/Canopus-Animaliea-Artism",
        "Canopus-Animaliea-Artism.safetensors",
        "Animaliea",
    ),
    "Liquid Wallpaper (minimal/illustration)πΌοΈ": (
        "prithivMLmods/Canopus-Liquid-Wallpaper-Art",
        "Canopus-Liquid-Wallpaper-Minimalize-LoRA.safetensors",
        "liquid",
    ),
    "Canes Cars (realistic/futurecars)π": (
        "prithivMLmods/Canes-Cars-Model-LoRA",
        "Canes-Cars-Model-LoRA.safetensors",
        "car",
    ),
    "Pencil Art (characteristic/creative)βοΈ": (
        "prithivMLmods/Canopus-Pencil-Art-LoRA",
        "Canopus-Pencil-Art-LoRA.safetensors",
        "Pencil Art",
    ),
    "Art Minimalistic (paint/semireal)π¨": (
        "prithivMLmods/Canopus-Art-Medium-LoRA",
        "Canopus-Art-Medium-LoRA.safetensors",
        "mdm",
    ),
}
# --- Style Definitions ---
# Negative prompt shared by the three "hyper-realistic" presets.
_REALISTIC_NEGATIVE = (
    "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly, "
    "bad anatomy, worst quality, low quality"
)
# Text that follows the "{prompt}" placeholder in the hyper-realistic presets.
_REALISTIC_SUFFIX = (
    ". ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic"
)
# (preset name, resolution tag inserted into the prompt template)
_REALISTIC_PRESETS = (
    ("3840 x 2160", "8K"),
    ("2560 x 1440", "4K"),
    ("HD+", "2K"),
)

# One dict per preset: "prompt" is a template whose literal "{prompt}" marker
# is later replaced by the user's text; "negative_prompt" is the preset's own
# negative prompt.
style_list = [
    {
        "name": preset_name,
        "prompt": "hyper-realistic " + resolution_tag + " image of {prompt}" + _REALISTIC_SUFFIX,
        "negative_prompt": _REALISTIC_NEGATIVE,
    }
    for preset_name, resolution_tag in _REALISTIC_PRESETS
] + [
    {
        # Pass-through preset: the user's prompt is used as-is.
        "name": "Style Zero",
        "prompt": "{prompt}",
        "negative_prompt": "worst quality, low quality",  # Added basic negative prompt
    },
]

# Preset name -> (prompt template, negative prompt).
styles = {
    preset["name"]: (preset["prompt"], preset["negative_prompt"])
    for preset in style_list
}
# Preset names in definition order, used as the radio-button choices.
STYLE_NAMES = [*styles]
# --- Utility Functions ---
def save_image(img):
    """Save *img* under a random ``<uuid4>.png`` name and return that name.

    The name is relative, so the file lands in the current working directory.
    *img* only needs a ``save(path)`` method.
    """
    target_name = f"{uuid.uuid4()}.png"
    img.save(target_name)
    return target_name
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return *seed* as given, or a fresh random seed when *randomize_seed* is set.

    A random seed is drawn uniformly from ``0`` to ``MAX_SEED`` inclusive.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    """Expand a style preset into the final (prompt, negative prompt) pair.

    Args:
        style_name: Key into ``styles``; unknown names fall back to the
            ``DEFAULT_STYLE_NAME`` preset.
        positive: User prompt substituted for every ``{prompt}`` marker in
            the preset's template.
        negative: Optional user negative prompt, appended after the preset's
            own negative prompt.

    Returns:
        ``(final_prompt, combined_negative_prompt)``.
    """
    template, preset_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])

    if not negative:
        # Nothing from the user: the preset's negative prompt stands alone.
        merged_negative = preset_negative
    elif preset_negative:
        # Both present: preset first, then the user's additions.
        merged_negative = f"{preset_negative}, {negative}"
    else:
        merged_negative = negative

    return template.replace("{prompt}", positive), merged_negative
def load_predefined_images():
    """Collect the paths of the bundled gallery images that exist on disk.

    Looks for ``1.png`` through ``9.png`` inside the ``assets`` directory
    (relative to the current working directory) and prints a warning for the
    directory or any file that is missing.

    Returns:
        The list of existing image paths in numeric order, or ``None`` when
        nothing was found, so the gallery receives no value instead of an
        empty list.
    """
    gallery_dir = "assets"
    if not os.path.exists(gallery_dir):
        print(f"Warning: Asset directory not found: {gallery_dir}")
        return None

    found_paths = []
    for index in range(1, 10):
        candidate = os.path.join(gallery_dir, f"{index}.png")
        if os.path.exists(candidate):
            found_paths.append(candidate)
        else:
            print(f"Warning: Predefined image not found: {candidate}")

    return found_paths or None
# --- Core Generation Function ---
def _load_pipeline(model_name: str) -> StableDiffusionXLPipeline:
    """Return the pipeline for *model_name*, building and caching it on first use.

    A newly built pipeline is moved to the GPU (or set up for CPU offload),
    switched to the Euler-ancestral scheduler, given every LoRA listed in
    ``LORA_OPTIONS`` and stored in ``loaded_pipelines``.
    """
    cached_pipe = loaded_pipelines.get(model_name)
    if cached_pipe is not None:
        print(f"Using cached pipeline: {model_name}")
        return cached_pipe

    print(f"Loading pipeline: {model_name}")
    pipe = StableDiffusionXLPipeline.from_pretrained(
        pipelines_info[model_name],
        torch_dtype=torch.float16,
        use_safetensors=True,
        # Use fp16 variant if available on GPU
        variant="fp16" if torch.cuda.is_available() else None,
    )

    # Apply optimizations based on flags
    if ENABLE_CPU_OFFLOAD:
        print("Enabling CPU Offload")
        pipe.enable_model_cpu_offload()
    else:
        pipe.to("cuda")  # Default: move entire pipeline to GPU

    # Configure scheduler (important for Lightning models)
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    # Load ALL LoRAs onto this newly loaded pipeline instance so that later
    # requests only have to switch the active adapter.
    print(f"Loading LoRAs for {model_name}...")
    for lora_name, (model_repo, weight_file, adapter_tag) in LORA_OPTIONS.items():
        try:
            print(f"  Loading LoRA: {lora_name} ({adapter_tag})")
            pipe.load_lora_weights(model_repo, weight_name=weight_file, adapter_name=adapter_tag)
        except Exception as e:
            # Deliberately best effort: one broken LoRA must not make the
            # whole base model unusable.
            # NOTE(review): selecting an adapter that failed here will fail
            # later in `set_adapters` - confirm whether that should be
            # reported to the user more explicitly.
            print(f"  Failed to load LoRA {lora_name}: {e}")

    if USE_TORCH_COMPILE:
        print("Attempting to compile UNet (may take time)...")
        try:
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            print("UNet compiled successfully.")
        except Exception as e:
            print(f"Torch compile failed: {e}. Running without compilation.")

    # Cache the fully loaded and configured pipeline
    loaded_pipelines[model_name] = pipe
    print(f"Pipeline {model_name} loaded and cached.")
    return pipe


# `spaces` is imported at the top of the file but was never applied. The GPU
# work below has to run inside a `spaces.GPU`-decorated function on ZeroGPU.
# NOTE(review): the decorator is documented as effect-free on other hardware -
# confirm against the installed `spaces` version.
@spaces.GPU
def generate(
    selected_base_model_name: str,  # New input for base model selection
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    num_inference_steps: int = 4,  # Lightning models use fewer steps
    randomize_seed: bool = False,
    style_name: str = DEFAULT_STYLE_NAME,
    lora_choice: str = "Realism (face/character)π¦π»",
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the chosen base model, LoRA and style preset.

    Args:
        selected_base_model_name: Key of ``pipelines_info``.
        prompt: User prompt, inserted into the style preset's template.
        negative_prompt: User negative prompt; ignored unless
            *use_negative_prompt* is true.
        use_negative_prompt: Whether *negative_prompt* is applied.
        seed: Seed to use when *randomize_seed* is false.
        width: Output width in pixels (cast to ``int``).
        height: Output height in pixels (cast to ``int``).
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps (cast to ``int``).
        randomize_seed: Draw a random seed instead of using *seed*.
        style_name: Key of ``styles``; unknown names use the default preset.
        lora_choice: Key of ``LORA_OPTIONS`` selecting the active adapter.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        ``(image_paths, seed)``: the saved image file names and the seed
        that was actually used.

    Raises:
        gr.Error: If no CUDA device is available, or the base model or LoRA
            selection is not a known option.
    """
    if not torch.cuda.is_available():
        raise gr.Error("GPU not available. This Space requires a GPU to run.")

    # Validate both selections up front, so a bad request fails before the
    # slow pipeline load and never surfaces as a raw KeyError.
    if selected_base_model_name not in pipelines_info:
        raise gr.Error(f"Selected base model '{selected_base_model_name}' not found in options.")
    if lora_choice not in LORA_OPTIONS:
        raise gr.Error(f"Selected LoRA '{lora_choice}' not found in options.")

    seed = int(randomize_seed_fn(seed, randomize_seed))
    torch.manual_seed(seed)  # Ensure reproducibility if seed is fixed

    # --- Pipeline Loading and Caching ---
    pipe = _load_pipeline(selected_base_model_name)

    # --- Prompt Styling ---
    positive_prompt, effective_negative_prompt = apply_style(
        style_name, prompt, negative_prompt if use_negative_prompt else ""
    )

    # --- LoRA Selection ---
    lora_adapter_name = LORA_OPTIONS[lora_choice][2]
    print(f"Activating LoRA: {lora_choice} (Adapter: {lora_adapter_name})")
    # No explicit LoRA scale is passed, so the adapter runs at its default
    # weight.
    pipe.set_adapters(lora_adapter_name)

    # --- Image Generation ---
    print("Starting image generation...")
    generator = torch.Generator("cuda").manual_seed(seed)
    images = pipe(
        prompt=positive_prompt,
        negative_prompt=effective_negative_prompt,
        # UI/API clients may deliver slider values as floats; cast them the
        # same way the seed is cast above.
        width=int(width),
        height=int(height),
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),  # Use steps suitable for Lightning
        generator=generator,
        num_images_per_prompt=1,
        output_type="pil",
    ).images

    image_paths = [save_image(img) for img in images]
    print("Image generation complete.")
    return image_paths, seed
# --- Gradio UI ---
# Stylesheet handed to `gr.Blocks(css=...)`.
# Raw string on purpose: the selectors contain `\[` and `\]`, which are not
# valid Python escape sequences. In a normal string they trigger an
# invalid-escape warning on recent Python versions; the resulting text is
# identical either way.
css = r'''
.gradio-container{max-width: 860px !important; margin: auto;}
h1{text-align:center}
.gr-prose { text-align: center; }
#model-select-row { justify-content: center; } /* Center dropdowns */
/* Make gallery taller */
#result_gallery .h-\[400px\] {
height: 600px !important; /* Adjust height as needed */
}
#predefined_gallery .h-\[400px\] {
height: 300px !important; /* Adjust height as needed */
}
footer { visibility: hidden }
'''
# Page layout and event wiring. Components must be created inside the
# `gr.Blocks` context, and the event listeners at the end are registered
# inside it as well.
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTIONz)
    if not torch.cuda.is_available():
        # Render the no-GPU notice while the page is being built; text that
        # is appended to DESCRIPTIONz after this point is never displayed.
        gr.Markdown(
            "<p>⚠️<b>WARNING: No GPU detected. Running on CPU is very slow and may not "
            "work reliably.</b> Consider using a GPU instance.</p>"
        )

    with gr.Row(elem_id="model-select-row"):
        model_selector = gr.Dropdown(
            label="Select Base Model",
            choices=list(pipelines_info.keys()),
            value=list(pipelines_info.keys())[0],  # Default to the first model
            scale=1,
        )
        model_choice = gr.Dropdown(
            label="Select LoRA Style",
            choices=list(LORA_OPTIONS.keys()),
            value="Realism (face/character)π¦π»",  # Default LoRA
            scale=1,
        )

    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,  # Allow slightly more room for prompt
                placeholder="Enter your prompt (e.g., 'Astronaut riding a horse')",
                container=False,
                scale=5,  # Make prompt input wider
            )
            run_button = gr.Button("Generate", scale=1, variant="primary")  # Make button stand out

    # Use Tabs for Main Result and Examples/Gallery
    with gr.Tabs():
        with gr.TabItem("Result", id="result_tab"):
            result = gr.Gallery(
                label="Generated Image",
                elem_id="result_gallery",
                columns=1,
                preview=True,
                show_label=False,
                height=600,  # Make gallery taller
            )
            # Display the seed used for the generated image
            used_seed = gr.Number(label="Seed Used", interactive=False)

        with gr.TabItem("Examples & Predefined Gallery", id="examples_tab"):
            gr.Markdown("### Prompt Examples")
            gr.Examples(
                examples=[
                    "cinematic photo, a man sitting on a chair in a dark room, realistic",  # Realism example
                    "pixar style 3d render of a cute cat astronaut exploring mars",  # Pixar example
                    "studio photography, high fashion model wearing a futuristic silver hoodie, dramatic lighting",  # Photoshoot/Clothing example
                    "minimalist vector art illustration of a mountain range at sunset, liquid style",  # Minimalist/Liquid example
                    "pencil sketch drawing of an old wise wizard with a long beard",  # Pencil Art example
                ],
                inputs=[prompt],  # Only update the prompt field from examples
                outputs=[result, used_seed],  # Define outputs for example generation
                # The examples only carry a prompt, so every other argument of
                # `generate` is pinned to a default here.
                fn=lambda p: generate(
                    selected_base_model_name=list(pipelines_info.keys())[0],  # Use default model for examples
                    prompt=p,
                    lora_choice="Realism (face/character)π¦π»",  # Use default LoRA for examples
                    negative_prompt="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
                    use_negative_prompt=True,
                    seed=0,
                    width=1024,
                    height=1024,
                    guidance_scale=3.0,
                    num_inference_steps=4,
                    randomize_seed=True,  # Randomize seed for examples
                    style_name=DEFAULT_STYLE_NAME,
                ),
                cache_examples=False,  # Recalculate examples if needed
                # Without caching, clicking an example only fills the prompt
                # box unless run_on_click is set. The label promises
                # generation, so run `fn` on click.
                run_on_click=True,
                label="Click an example to generate",
            )
            gr.Markdown("### Predefined Image Gallery")
            predefined_gallery = gr.Gallery(
                label="Image Gallery",
                elem_id="predefined_gallery",
                columns=3,
                show_label=False,
                value=load_predefined_images(),
                height=300,
            )

    with gr.Accordion("βοΈ Advanced Settings", open=False):
        style_selection = gr.Radio(
            show_label=True,
            container=True,
            interactive=True,
            choices=STYLE_NAMES,
            value=DEFAULT_STYLE_NAME,
            label="Image Quality Style",
        )
        with gr.Row():
            use_negative_prompt = gr.Checkbox(label="Use Negative Prompt", value=True)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
        negative_prompt = gr.Text(
            label="Negative Prompt",
            max_lines=2,
            value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, worst quality, low quality",
            placeholder="Enter concepts to avoid...",
            visible=True,  # Initially visible, controlled by checkbox change
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
            visible=True,
            # Start in the state the `randomize_seed.change` handler below
            # would produce: locked while "Randomize Seed" is ticked.
            interactive=not randomize_seed.value,
        )
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=512,
                maximum=1536,  # Adjusted max based on typical SDXL use
                step=64,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=512,
                maximum=1536,  # Adjusted max based on typical SDXL use
                step=64,
                value=1024,
            )
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance Scale (CFG)",
                minimum=0.0,
                maximum=10.0,  # Lightning models often use low CFG
                step=0.1,
                value=1.5,  # Default low CFG for Lightning
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps",
                minimum=1,
                maximum=20,  # Lightning models need very few steps
                step=1,
                value=4,  # Default steps for Lightning
            )

    # --- Event Listeners ---
    # Show/hide negative prompt input based on checkbox
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        api_name=False,
    )
    # Lock/unlock the seed slider based on the randomize checkbox
    randomize_seed.change(
        fn=lambda x: gr.update(interactive=not x),  # Make slider non-interactive if randomizing
        inputs=randomize_seed,
        outputs=seed,
        api_name=False,
    )

    # Main generation trigger. The order must match the positional
    # parameters of `generate`.
    inputs_list = [
        model_selector,  # Base model
        prompt,
        negative_prompt,
        use_negative_prompt,
        seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        randomize_seed,
        style_selection,
        model_choice,  # This is the LoRA choice dropdown
    ]
    outputs_list = [result, used_seed]  # Output gallery and the seed number

    prompt.submit(
        fn=generate,
        inputs=inputs_list,
        outputs=outputs_list,
        api_name="run_prompt_submit",  # Optional: Define API name
    )
    run_button.click(
        fn=generate,
        inputs=inputs_list,
        outputs=outputs_list,
        api_name="run_button_click",  # Optional: Define API name
    )
# --- Launch ---
if __name__ == "__main__":
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        print("Warning: No CUDA GPU detected. Running on CPU will be extremely slow or may fail.")
        # NOTE(review): the page header appears to be built from DESCRIPTIONz
        # earlier in the file, so this late addition presumably never reaches
        # the rendered UI - confirm and move it if the notice should be shown.
        DESCRIPTIONz = DESCRIPTIONz + "\n<p>β οΈ<b>WARNING: No GPU detected. Running on CPU is very slow and may not work reliably.</b> Consider using a GPU instance.</p>"

    # Informational only: a missing asset directory does not stop the launch.
    assets_present = os.path.exists("assets")
    if not assets_present:
        print("Warning: 'assets' directory not found. Predefined images will not load.")

    # Cap the request queue at 20 entries, then start the server.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(debug=False)  # Set debug=True for more logs if needed