Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionPipeline | |
| from huggingface_hub import hf_hub_download | |
| import numpy as np | |
| from PIL import Image | |
| import os | |
| import gc | |
# Silence huggingface_hub's warning emitted when its download cache cannot
# use symlinks.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Style catalogue. Each key is a choice in the UI "Style" radio and maps to:
#   concept_url -- Hub repo holding the textual-inversion embedding
#   seed        -- a per-style seed value
#   token       -- placeholder token appended to the user's prompt
styles = {
    name: {"concept_url": f"sd-concepts-library/{repo}", "seed": seed, "token": token}
    for name, repo, seed, token in (
        ("glitch", "001glitch-core", 42, "<glitch-core>"),
        ("roth", "2814-roth", 123, "<2814-roth>"),
        ("night", "4tnght", 456, "<4tnght>"),
        ("anime80s", "80s-anime-ai", 789, "<80s-anime>"),
        ("animeai", "80s-anime-ai-being", 1024, "<80s-anime-being>"),
    )
}

# Where the pre-rendered image for each UI example lives on disk. These are
# only paths; the files are produced by the __main__ block below and served
# by generate_image when the inputs match an example.
example_images = {
    name: f"examples/{name}_example.jpg" for name in ("glitch", "anime80s", "night")
}
def load_pipeline():
    """Build the Stable Diffusion pipeline and register every style embedding.

    Uses CUDA with float16 weights when a GPU is available, otherwise CPU with
    float32. For each entry in ``styles`` the concept's ``learned_embeds.bin``
    is downloaded from the Hugging Face Hub and loaded as a textual-inversion
    embedding.

    Returns:
        The ``StableDiffusionPipeline``, already moved to the selected device.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision only on GPU; keep full precision on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    # NOTE: "runwayml/stable-diffusion-v1-5" was removed from the Hub; the
    # checkpoint is mirrored at "stable-diffusion-v1-5/stable-diffusion-v1-5".
    # v1-4 and v1-5 share the same architecture, so the CPU choice is a
    # different checkpoint, not a smaller model.
    model_id = (
        "stable-diffusion-v1-5/stable-diffusion-v1-5"
        if device == "cuda"
        else "CompVis/stable-diffusion-v1-4"
    )
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    ).to(device)

    for style_info in styles.values():
        embedding_path = hf_hub_download(
            repo_id=style_info["concept_url"],
            filename="learned_embeds.bin",
            repo_type="model",
        )
        # Register the embedding under the exact token that generate_image
        # appends to prompts. Without ``token=`` diffusers uses the name stored
        # in the file, which need not match ``style_info["token"]``. If it does
        # not match, the style token in the prompt would hit no learned
        # embedding.
        pipe.load_textual_inversion(embedding_path, token=style_info["token"])
    return pipe
def apply_purple_guidance(image, strength=0.5):
    """Blend pixels with strong red and blue components toward purple.

    Every pixel whose red and blue channels both exceed 100 is linearly mixed
    with RGB (128, 0, 128)::

        new = old * (1 - strength) + purple * strength

    All other pixels are left unchanged.

    Args:
        image: A PIL image of any mode (converted to RGB first), or an
            ``(H, W, 3)`` array-like of pixel values.
        strength: Blend weight, expected in [0, 1] (the UI slider's range).
            0 keeps the pixel; 1 replaces it with pure purple.

    Returns:
        A new ``PIL.Image`` built from the blended values clipped to [0, 255].
    """
    # Normalise PIL inputs to three channels. A grayscale ("L") image has no
    # channel axis, so [:, :, c] indexing fails. An RGBA image has four
    # channels, which do not broadcast against the RGB purple vector.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    pixels = np.array(image).astype(float)
    purple = np.array([128, 0, 128])
    mask = (pixels[:, :, 0] > 100) & (pixels[:, :, 2] > 100)
    pixels[mask] = pixels[mask] * (1 - strength) + purple * strength
    return Image.fromarray(np.uint8(pixels.clip(0, 255)))
# (prompt, style, seed) combinations that have a pre-rendered image in
# ``example_images``; they mirror the Gradio examples so clicking one returns
# instantly instead of running the pipeline.
_CACHED_EXAMPLES = {
    ("A serene mountain landscape with a lake at sunset", "glitch", 42): "glitch",
    ("A magical forest at twilight", "anime80s", 789): "anime80s",
    ("A cyberpunk city at night", "night", 456): "night",
}


def _load_cached_example(prompt, style, seed):
    """Return the pre-rendered image for a known example, or None if absent."""
    # gr.Number yields floats; 42.0 hashes and compares equal to 42, so
    # float seeds still match the integer keys.
    key = _CACHED_EXAMPLES.get((prompt, style, seed))
    if key is None or not os.path.exists(example_images[key]):
        return None
    # Image.open is lazy; convert() reads the pixels so the file handle can
    # be closed deterministically by the context manager.
    with Image.open(example_images[key]) as cached:
        return cached.convert("RGB")


def _run_pipeline(prompt, style_info, seed):
    """Render ``prompt`` with the style's token using the global ``pipe``."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device).manual_seed(int(seed))
    styled_prompt = f"{prompt} {style_info['token']}"
    # CPU inference is slow, so trade quality for speed there: fewer
    # denoising steps and a 256x256 canvas instead of the pipeline default.
    if device == "cpu":
        run_kwargs = {"num_inference_steps": 10, "height": 256, "width": 256}
    else:
        run_kwargs = {"num_inference_steps": 20}
    try:
        return pipe(
            styled_prompt,
            generator=generator,
            guidance_scale=7.5,
            **run_kwargs,
        ).images[0]
    finally:
        # Free intermediate tensors between requests, even if generation fails.
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()


def generate_image(prompt, style, seed, apply_guidance, guidance_strength=0.5):
    """Generate an image in the selected style, optionally tinted purple.

    Known example inputs are served from pre-rendered files in
    ``example_images`` when those files exist. Everything else is rendered
    with the global ``pipe``.

    Args:
        prompt: Text prompt; the style's placeholder token is appended to it.
        style: A key of ``styles``.
        seed: Generator seed (int or float). ``None``, which Gradio sends for
            a cleared Number field, falls back to the style's ``seed``.
        apply_guidance: Whether to run ``apply_purple_guidance`` on the result.
        guidance_strength: Blend weight passed to ``apply_purple_guidance``.

    Returns:
        A PIL image, or ``None`` if ``style`` is unknown.
    """
    if style not in styles:
        return None
    style_info = styles[style]
    if seed is None:
        seed = style_info["seed"]

    image = _load_cached_example(prompt, style, seed)
    if image is None:
        image = _run_pipeline(prompt, style_info, seed)

    if apply_guidance:
        image = apply_purple_guidance(image, guidance_strength)
    return image
# The pre-rendered example images are stored under this directory.
os.makedirs("examples", exist_ok=True)

# Build the pipeline once at import time; generate_image reads this global.
print("Loading pipeline and embeddings...")
pipe = load_pipeline()

# Gradio UI: the input components line up positionally with the parameters
# of generate_image (prompt, style, seed, apply_guidance, guidance_strength).
demo = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(value="A serene mountain landscape with a lake at sunset", label="Prompt"),
        gr.Radio(label="Style", choices=list(styles), value="glitch"),
        gr.Number(value=42, label="Seed"),
        gr.Checkbox(value=False, label="Apply Purple Guidance"),
        gr.Slider(label="Purple Guidance Strength", minimum=0.0, maximum=1.0, value=0.5),
    ],
    outputs=gr.Image(label="Generated Image"),
    # Each example row follows the same positional order as ``inputs``.
    examples=[
        ["A serene mountain landscape with a lake at sunset", "glitch", 42, True, 0.5],
        ["A magical forest at twilight", "anime80s", 789, True, 0.7],
        ["A cyberpunk city at night", "night", 456, False, 0.5],
    ],
    cache_examples=True,
    title="Style-Guided Image Generation with Purple Enhancement",
    description="""Generate images in different styles with optional purple color guidance.
Choose a style, enter a prompt, and optionally apply purple color enhancement.
Note: Generation may take a few minutes on CPU.""",
    # Hide the flag button. NOTE(review): newer Gradio releases rename this
    # option to ``flagging_mode`` -- confirm against the pinned version.
    allow_flagging="never",
)
if __name__ == "__main__":
    # Render any missing example image once and store it on disk. Purple
    # guidance is off here so the stored file is the raw generation;
    # generate_image applies guidance on top when serving it.
    pending_examples = [
        (path, args)
        for path, args in (
            (
                example_images["glitch"],
                ("A serene mountain landscape with a lake at sunset", "glitch", 42, False, 0.5),
            ),
            (
                example_images["anime80s"],
                ("A magical forest at twilight", "anime80s", 789, False, 0.7),
            ),
            (
                example_images["night"],
                ("A cyberpunk city at night", "night", 456, False, 0.5),
            ),
        )
        if not os.path.exists(path)
    ]
    if pending_examples:
        print("Pre-generating example images...")
        for path, args in pending_examples:
            generate_image(*args).save(path)

    demo.launch(share=False, show_error=True)