# iptng / app.py — "remove noise" (commit 9aed7c0), author: Arnold Manzano
import gradio as gr
import torch
from diffusers import StableDiffusionInpaintPipeline, LCMScheduler
from PIL import Image
# --- Model setup ---------------------------------------------------------
# Load the most broadly compatible SD inpainting checkpoint for CPU use.
model_id = "runwayml/stable-diffusion-inpainting"
print("Loading Stable Diffusion Inpainting to CPU...")
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float32,  # full precision — CPU has no fp16 kernels
    low_cpu_mem_usage=True,     # stream weights in to keep peak RAM down
)

# Disable the built-in NSFW safety checker completely ("off switch").
pipe.safety_checker = None
pipe.requires_safety_checker = False

# CPU optimizations: place the pipeline on CPU and enable attention
# slicing, trading a little speed for a smaller memory peak.
pipe = pipe.to("cpu")
pipe.enable_attention_slicing()
def predict(image_data, prompt):
    """Inpaint the masked region of the uploaded image according to ``prompt``.

    Parameters
    ----------
    image_data : dict | None
        Gradio ``ImageMask`` payload: a ``"background"`` PIL image plus a
        list of ``"layers"`` whose last entry carries the drawn mask in its
        alpha channel (assumes the layer is RGBA — that is what Gradio's
        ImageMask emits; verify if the component version changes).
    prompt : str
        Text description of what to paint into the masked area.

    Returns
    -------
    PIL.Image.Image | None
        Inpainted image at the original resolution, or ``None`` when either
        input is missing.
    """
    if image_data is None or not prompt:
        return None

    # 1. Prep images: background as RGB; mask = alpha channel of the last
    # drawn layer, binarized so any painted pixel becomes fully masked.
    raw_bg = image_data["background"].convert("RGB")
    raw_layer = image_data["layers"][-1]
    raw_mask = raw_layer.split()[-1]
    raw_mask = raw_mask.point(lambda x: 255 if x > 0 else 0)

    # 2. Scale so the long side is ~512 px and both sides are multiples of 8
    # (required by the SD VAE). Clamp each side to at least 8 px so extreme
    # aspect ratios cannot collapse a dimension to 0 and crash the resize.
    orig_w, orig_h = raw_bg.size
    scale = 512 / max(orig_w, orig_h)
    new_w = max(8, int((orig_w * scale) // 8) * 8)
    new_h = max(8, int((orig_h * scale) // 8) * 8)
    base_image = raw_bg.resize((new_w, new_h), Image.LANCZOS)
    mask_image = raw_mask.resize((new_w, new_h), Image.NEAREST)

    # 3. Run the model.
    result = pipe(
        prompt=prompt,
        image=base_image,
        mask_image=mask_image,
        num_inference_steps=20,  # modest step count keeps CPU latency tolerable
        guidance_scale=8.0,      # stronger guidance so the prompt reliably appears
    ).images[0]

    # 4. Composite the generated pixels back over the original through the
    # mask, so unmasked areas are preserved exactly, then restore the
    # original resolution.
    result = result.resize(base_image.size, Image.LANCZOS)
    final_segmented = Image.composite(result, base_image, mask_image)
    return final_segmented.resize((orig_w, orig_h), Image.LANCZOS)
# --- Gradio UI -----------------------------------------------------------
# Layout: masked-image input and result side by side, prompt box and a
# single "Generate" button below.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Fast CPU Inpainter (Unrestricted)")
    gr.Markdown("Draw on the image and describe what to add. **CPU Time: ~45 seconds.**")
    with gr.Row():
        input_img = gr.ImageMask(label="Upload & Mask", type="pil")
        output_img = gr.Image(label="Result")
    prompt_text = gr.Textbox(label="Prompt")
    run_btn = gr.Button("Generate", variant="primary")
    # Wire the button to the inference function.
    run_btn.click(fn=predict, inputs=[input_img, prompt_text], outputs=output_img)

demo.launch()