import gradio as gr from diffusers import StableDiffusionPipeline import torch from PIL import Image import re import uuid import gc # Load word list for safety checking (using a simple list instead of loading dataset) BLOCKED_WORDS = ["nsfw", "nude", "explicit"] # Add more as needed # Initialize the pipeline with CPU optimization def initialize_pipeline(): pipe = StableDiffusionPipeline.from_pretrained( "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float32 # Use float32 for CPU ) pipe = pipe.to("cpu") # Enable memory efficient attention pipe.enable_attention_slicing() return pipe pipe = initialize_pipeline() def cleanup(): """Force garbage collection to free memory""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def infer(prompt, negative, scale): # Safety check prompt = prompt.lower() for word in BLOCKED_WORDS: if word in prompt: raise gr.Error("Unsafe content found. Please try again with different prompts.") try: # Generate only one image at a time to conserve memory images = pipe( prompt=prompt, negative_prompt=negative, guidance_scale=scale, num_inference_steps=30, # Reduced steps for faster generation num_images_per_prompt=1 ).images # Save image output_path = f"{uuid.uuid4()}.jpg" images[0].save(output_path) # Cleanup to free memory cleanup() return [output_path] except Exception as e: cleanup() raise gr.Error(f"Generation failed: {str(e)}") css = """ .gradio-container { max-width: 768px !important; font-family: 'IBM Plex Sans', sans-serif; } .gr-button { color: white; border-color: black; background: black; } input[type='range'] { accent-color: black; } .dark input[type='range'] { accent-color: #dfdfdf; } #gallery { min-height: 22rem; margin-bottom: 15px; } #gallery>div>.h-full { min-height: 20rem; } """ examples = [ [ 'A small cabin on top of a snowy mountain, artstation style', 'low quality, ugly', 9 ], [ 'A red apple on a wooden table, still life', 'low quality', 9 ], ] with gr.Blocks(css=css) as block: gr.HTML( """
Optimized for CPU usage with 16GB RAM
Running on CPU - Please allow longer generation times