Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from diffusers import StableDiffusionPipeline | |
| import torch | |
| from PIL import Image | |
| import re | |
| import uuid | |
| import gc | |
# Minimal substring blocklist screened by infer() before generation.
# A hard-coded list is used instead of loading a word-list dataset; extend it
# as needed, keeping entries lowercase so they match the lowercased prompt.
BLOCKED_WORDS = [
    "nsfw",
    "nude",
    "explicit",
]
def initialize_pipeline(
    model_id="stabilityai/stable-diffusion-2-1-base",
    device="cpu",
):
    """Load a Stable Diffusion pipeline and prepare it for inference.

    Args:
        model_id: Hub id (or local path) of the checkpoint passed to
            ``StableDiffusionPipeline.from_pretrained``. Defaults to
            Stable Diffusion 2.1 base.
        device: Device the pipeline is moved to. Defaults to ``"cpu"``.

    Returns:
        The loaded pipeline, in float32, with attention slicing enabled.
    """
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # float32 for CPU execution
    )
    pipeline = pipeline.to(device)
    # Compute attention in slices to reduce peak memory use.
    pipeline.enable_attention_slicing()
    return pipeline


# Loaded once at import time; infer() reuses this single instance.
pipe = initialize_pipeline()
def cleanup():
    """Run the garbage collector, then drop cached CUDA memory if CUDA exists."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def infer(prompt, negative, scale):
    """Generate one image for ``prompt`` and return it as gallery output.

    Args:
        prompt: Text prompt from the UI. ``None`` is treated as an empty string.
        negative: Negative prompt, forwarded unchanged to the pipeline.
        scale: Guidance scale, forwarded unchanged to the pipeline.

    Returns:
        A one-element list holding the path of the saved JPEG file.

    Raises:
        gr.Error: If the prompt contains an entry of ``BLOCKED_WORDS``, or if
            generating or saving the image fails.
    """
    prompt = prompt or ""
    # Screen a case-folded copy so the prompt handed to the model is not mutated.
    screened = prompt.casefold()
    if any(word.casefold() in screened for word in BLOCKED_WORDS):
        raise gr.Error("Unsafe content found. Please try again with different prompts.")

    try:
        # Generate only one image at a time to conserve memory.
        images = pipe(
            prompt=prompt,
            negative_prompt=negative,
            guidance_scale=scale,
            num_inference_steps=30,  # reduced steps for faster generation
            num_images_per_prompt=1,
        ).images
        # NOTE(review): images are written to the working directory and never
        # removed; consider periodic cleanup if the app runs for long.
        output_path = f"{uuid.uuid4()}.jpg"
        images[0].save(output_path)
        return [output_path]
    except Exception as e:
        # Chain the cause so the server-side traceback keeps the real error.
        raise gr.Error(f"Generation failed: {str(e)}") from e
    finally:
        # Free memory on both the success and the failure path.
        cleanup()
# Page styling handed to gr.Blocks: narrow container, black buttons,
# themed range sliders, and a minimum height for the results gallery.
css = """
.gradio-container {
max-width: 768px !important;
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
border-color: black;
background: black;
}
input[type='range'] {
accent-color: black;
}
.dark input[type='range'] {
accent-color: #dfdfdf;
}
#gallery {
min-height: 22rem;
margin-bottom: 15px;
}
#gallery>div>.h-full {
min-height: 20rem;
}
"""

# Example rows, in the order of gr.Examples' inputs:
# [prompt, negative prompt, guidance scale].
examples = [
    ["A small cabin on top of a snowy mountain, artstation style", "low quality, ugly", 9],
    ["A red apple on a wooden table, still life", "low quality", 9],
]
# NOTE(review): the component nesting below was reconstructed from a paste that
# lost its indentation; confirm the Group/Accordion boundaries against the
# deployed app.
with gr.Blocks(css=css) as block:
    # Page header.
    gr.HTML(
        """
<div style="text-align: center; margin: 0 auto;">
<h1 style="font-weight: 900; margin-bottom: 7px;">
Stable Diffusion 2.1 (CPU Version)
</h1>
<p style="margin-bottom: 10px; font-size: 94%;">
Optimized for CPU usage with 16GB RAM
</p>
</div>
"""
    )

    # Prompt boxes on the left, generate button on the right, gallery beneath.
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                )
                negative = gr.Textbox(
                    label="Enter your negative prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter a negative prompt",
                )
            with gr.Column(scale=1, min_width=150):
                btn = gr.Button("Generate image")
        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
        )

    with gr.Accordion("Advanced settings", open=False):
        guidance_scale = gr.Slider(
            label="Guidance Scale",
            minimum=1,
            maximum=20,
            value=9,
            step=0.1,
        )

    # Every generation entry point feeds the same three inputs to infer().
    generate_inputs = [text, negative, guidance_scale]

    # With caching on, Gradio precomputes the example outputs by calling infer;
    # on CPU this can noticeably lengthen startup.
    gr.Examples(
        examples=examples,
        fn=infer,
        inputs=generate_inputs,
        outputs=[gallery],
        cache_examples=True,
    )

    # Pressing Enter in either textbox, or clicking the button, generates.
    for trigger in (text.submit, negative.submit, btn.click):
        trigger(infer, inputs=generate_inputs, outputs=[gallery])

    # Footer.
    gr.HTML(
        """
<div style="text-align: center; margin-top: 20px;">
<p>Running on CPU - Please allow longer generation times</p>
</div>
"""
    )
# Enable Gradio's request queue (suited to long-running generations), then
# start the server with errors surfaced in the browser.
block.queue()
block.launch(show_error=True)