# dineth554's picture
# Update app.py
# 34616b4 verified
import os
import subprocess
import sys
import streamlit as st
import numpy as np
from PIL import Image, ImageOps
from io import BytesIO
# Function to install packages
def install(package):
    """Install *package* with pip into the interpreter running this script.

    Runs ``sys.executable -m pip install <package>`` so the package lands in
    the same environment this app is executing in.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    import importlib

    subprocess.check_call([sys.executable, "-m", "pip", "install", package])
    # The import system caches directory listings; a package installed after
    # start-up may not be visible to later imports until the caches are reset.
    importlib.invalidate_caches()
# List of required packages (pip distribution names).
required_packages = [
    "streamlit",
    "diffusers",
    "imageio",
    "pillow",
    "transformers",
    "torch",
    "streamlit-drawable-canvas",
]
# Some distributions are imported under a different module name than the one
# pip installs them by. Probing the pip name for these would always raise
# ImportError and trigger a needless reinstall on every run.
_IMPORT_NAMES = {
    "pillow": "PIL",
    "streamlit-drawable-canvas": "streamlit_drawable_canvas",
}
# Install required packages that cannot currently be imported.
for package in required_packages:
    try:
        __import__(_IMPORT_NAMES.get(package, package))
    except ImportError:
        install(package)
# Import after installation to ensure packages are available
from diffusers import DiffusionPipeline
from streamlit_drawable_canvas import st_canvas
# Hugging Face access token read from the HF_TOKEN_SD environment variable
# (None when the variable is unset).
MY_SECRET_TOKEN = os.environ.get('HF_TOKEN_SD')
YOUR_TOKEN = MY_SECRET_TOKEN
device = "cpu"

# Streamlit re-executes this whole script on every widget interaction. Without
# caching, the inpainting pipeline would be reloaded on every rerun.
# st.cache_resource keeps one shared instance per (token, device). On Streamlit
# versions that predate it, fall back to the original uncached behavior.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_pipeline(token, target_device):
    """Load the inpainting pipeline and move it to *target_device*."""
    pipeline = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting", use_auth_token=token
    )
    pipeline.to(target_device)
    return pipeline


pipe = _load_pipeline(YOUR_TOKEN, device)
def resize(height, img):
    """Return *img* scaled to exactly *height* pixels tall.

    The aspect ratio is preserved: the new width is the original width times
    ``height / original_height``, truncated to an int. Lanczos resampling
    is used.
    """
    orig_width, orig_height = img.size
    scale = height / float(orig_height)
    new_width = int(float(orig_width) * float(scale))
    return img.resize((new_width, height), Image.Resampling.LANCZOS)
def predict(source_img, mask_img, prompt, brush_size):
    """Run Stable Diffusion inpainting on an uploaded image.

    Args:
        source_img: encoded bytes of the image to edit (as uploaded).
        mask_img: encoded bytes of the mask image, passed to the pipeline as
            ``mask_image`` (diffusers convention: white pixels are repainted,
            black pixels are kept).
        prompt: text describing what to paint into the masked region.
        brush_size: accepted for backward compatibility only. The stroke
            width is already baked into the mask drawn on the canvas, and it
            is not a pipeline argument, so it is not forwarded.

    Returns:
        A list of PIL images, one per generated sample. Any sample the
        pipeline's safety checker flags is replaced by ``unsafe.png``.

    Side effects:
        Writes the resized inputs to ``src.png`` and ``mask.png`` in the
        current working directory.
    """
    # Honor the camera's EXIF orientation so the photo is upright.
    source = ImageOps.exif_transpose(Image.open(BytesIO(source_img)))
    mask_source = Image.open(BytesIO(mask_img))

    src = resize(512, source)
    src.save("src.png")
    mask = resize(512, mask_source)
    mask.save("mask.png")

    output = pipe([prompt], image=src, mask_image=mask, strength=0.75)
    generated = output.images
    # nsfw_content_detected is None when the pipeline runs without a safety
    # checker; treat that as "nothing flagged" rather than crashing.
    flags = getattr(output, "nsfw_content_detected", None) or [False] * len(generated)

    safe_image = None
    images = []
    for image, flagged in zip(generated, flags):
        if flagged:
            if safe_image is None:
                # Loaded lazily so the placeholder file is only required
                # when an output actually has to be replaced.
                safe_image = Image.open(r"unsafe.png")
            images.append(safe_image)
        else:
            images.append(image)
    return images
st.title("InPainting Stable Diffusion CPU")
st.write("Inpainting Stable Diffusion example using CPU and HF token. Warning: Slow process... ~5/10 min inference time. NSFW filter enabled. Please use 512*512 square image as input to avoid memory error!")
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
brush_size = st.slider("Brush Size", min_value=1, max_value=20, value=5)
prompt = st.text_input("Prompt")

if uploaded_file:
    # Read the upload once; the bytes are reused for preview, canvas and inference.
    uploaded_file_bytes = uploaded_file.read()
    # Display the uploaded image
    st.image(uploaded_file_bytes, caption="Uploaded Image", use_column_width=True)

    # predict() applies the EXIF orientation to the source image, so the
    # canvas background must be oriented the same way. Otherwise the strokes
    # drawn here would not line up with the pixels that get inpainted.
    background = ImageOps.exif_transpose(Image.open(BytesIO(uploaded_file_bytes)))

    st.write("Draw your mask on the canvas below:")
    canvas_result = st_canvas(
        fill_color="rgba(255, 255, 255, 0)",  # Fixed fill color with alpha
        stroke_width=brush_size,
        stroke_color="black",
        background_color="rgba(0, 0, 0, 0)",
        background_image=background,
        update_streamlit=True,
        height=512,
        width=512,
        drawing_mode="freedraw",
        key="canvas",
    )

    if st.button("Generate"):
        # Assumes image_data is an (H, W, 4) RGBA array, per the
        # streamlit-drawable-canvas docs. Strokes are opaque black on a
        # fully transparent black background, so their RGB equals the
        # background's and only the alpha channel tells them apart. An
        # RGBA->L conversion drops alpha and gives an all-black mask (nothing
        # repainted). Use alpha directly so the drawn strokes become white.
        mask_data = canvas_result.image_data
        if mask_data is not None and np.any(mask_data[:, :, 3]):
            mask_img = Image.fromarray(mask_data[:, :, 3].astype(np.uint8))
            mask_buffer = BytesIO()
            mask_img.save(mask_buffer, format="PNG")
            mask_bytes = mask_buffer.getvalue()
            with st.spinner("Generating... this can take several minutes on CPU."):
                images = predict(uploaded_file_bytes, mask_bytes, prompt, brush_size)
            st.image(images, caption='Generated images', use_column_width=True)
        else:
            st.write("Please draw a mask on the canvas.")