# stable_image / app.py
# Gradio demo: Stable Diffusion 2.1 running on CPU.
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from PIL import Image
import re
import uuid
import gc
# Load word list for safety checking (using a simple list instead of loading dataset)
# Prompts containing any of these terms are rejected by infer() before generation.
BLOCKED_WORDS = ["nsfw", "nude", "explicit"] # Add more as needed
def initialize_pipeline():
    """Build and return a CPU-only Stable Diffusion 2.1 pipeline.

    Uses float32 weights (CPU has no practical half-precision path) and
    enables attention slicing to reduce peak memory during generation.
    """
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base",
        torch_dtype=torch.float32,  # Use float32 for CPU
    ).to("cpu")
    # Attention slicing trades a little speed for a much smaller memory peak.
    sd_pipe.enable_attention_slicing()
    return sd_pipe


# Module-level pipeline shared by every request.
pipe = initialize_pipeline()
def cleanup():
    """Force garbage collection to free memory after a generation pass."""
    gc.collect()
    # No-op on a CPU-only deployment, but harmless when CUDA is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def infer(prompt, negative, scale):
    """Generate one image for ``prompt`` and return ``[saved_jpg_path]``.

    Args:
        prompt: Text prompt describing the desired image.
        negative: Negative prompt (features to avoid).
        scale: Classifier-free guidance scale.

    Returns:
        A single-element list with the path of the saved JPEG (Gallery input).

    Raises:
        gr.Error: if the prompt contains a blocked term or generation fails.
    """
    # Safety check: match blocked terms as whole words, case-insensitively.
    # The original code lowercased the prompt in place (changing what the
    # model saw) and used a bare substring test, which false-positives on
    # words like "denuded". A word-boundary regex on a lowered copy fixes both.
    lowered = prompt.lower()
    for word in BLOCKED_WORDS:
        if re.search(rf"\b{re.escape(word)}\b", lowered):
            raise gr.Error("Unsafe content found. Please try again with different prompts.")
    try:
        # Generate only one image at a time to conserve memory.
        images = pipe(
            prompt=prompt,
            negative_prompt=negative,
            guidance_scale=scale,
            num_inference_steps=30,  # Reduced steps for faster generation
            num_images_per_prompt=1,
        ).images
        # Unique filename so concurrent/queued requests never collide.
        output_path = f"{uuid.uuid4()}.jpg"
        images[0].save(output_path)
        return [output_path]
    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}")
    finally:
        # Free memory on success and failure alike (was duplicated per branch).
        cleanup()
# Custom CSS: narrow centered container, black buttons/sliders (light theme),
# and a minimum height for the output gallery.
css = """
.gradio-container {
max-width: 768px !important;
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
border-color: black;
background: black;
}
input[type='range'] {
accent-color: black;
}
.dark input[type='range'] {
accent-color: #dfdfdf;
}
#gallery {
min-height: 22rem;
margin-bottom: 15px;
}
#gallery>div>.h-full {
min-height: 20rem;
}
"""
# Example rows for gr.Examples: [prompt, negative_prompt, guidance_scale].
examples = [
    [
        'A small cabin on top of a snowy mountain, artstation style',
        'low quality, ugly',
        9
    ],
    [
        'A red apple on a wooden table, still life',
        'low quality',
        9
    ],
]
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(css=css) as block:
    # Page header.
    gr.HTML(
        """
        <div style="text-align: center; margin: 0 auto;">
        <h1 style="font-weight: 900; margin-bottom: 7px;">
        Stable Diffusion 2.1 (CPU Version)
        </h1>
        <p style="margin-bottom: 10px; font-size: 94%;">
        Optimized for CPU usage with 16GB RAM
        </p>
        </div>
        """
    )
    # Prompt inputs (left, wide) and the generate button (right, narrow).
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt"
                )
                negative = gr.Textbox(
                    label="Enter your negative prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter a negative prompt"
                )
            with gr.Column(scale=1, min_width=150):
                btn = gr.Button("Generate image")
    # Output gallery (styled via the #gallery CSS rules above).
    gallery = gr.Gallery(
        label="Generated images",
        show_label=False,
        elem_id="gallery"
    )
    with gr.Accordion("Advanced settings", open=False):
        guidance_scale = gr.Slider(
            label="Guidance Scale",
            minimum=1,
            maximum=20,
            value=9,
            step=0.1
        )
    # Clickable examples; cache_examples=True pre-renders them at startup.
    gr.Examples(
        examples=examples,
        fn=infer,
        inputs=[text, negative, guidance_scale],
        outputs=[gallery],
        cache_examples=True
    )
    # All three triggers (Enter in either textbox, button click) run infer.
    text.submit(infer, inputs=[text, negative, guidance_scale], outputs=[gallery])
    negative.submit(infer, inputs=[text, negative, guidance_scale], outputs=[gallery])
    btn.click(infer, inputs=[text, negative, guidance_scale], outputs=[gallery])
    gr.HTML(
        """
        <div style="text-align: center; margin-top: 20px;">
        <p>Running on CPU - Please allow longer generation times</p>
        </div>
        """
    )

# Queue requests so the single CPU pipeline serves them one at a time.
block.queue().launch(show_error=True)