Spaces:
Build error
Build error
| import os | |
| import subprocess | |
| import sys | |
| import streamlit as st | |
| import numpy as np | |
| from PIL import Image, ImageOps | |
| from io import BytesIO | |
def install(package):
    """Install *package* into the running interpreter's environment via pip.

    pip is invoked as ``<current python> -m pip install <package>`` so the
    package lands in the same environment this script runs in.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.run(pip_command, check=True)
import importlib

# Packages the app needs, mapped from the pip distribution name (what
# `install` receives) to the top-level module that distribution provides.
# The two differ for some packages ("pillow" installs "PIL"). Probing with
# the pip name would therefore always fail and force a reinstall on every run.
_IMPORT_NAMES = {
    "streamlit": "streamlit",
    "diffusers": "diffusers",
    "imageio": "imageio",
    "pillow": "PIL",
    "transformers": "transformers",
    "torch": "torch",
    "streamlit-drawable-canvas": "streamlit_drawable_canvas",
}
# Kept as a list of pip names for anything that still reads it.
required_packages = list(_IMPORT_NAMES)

# Install every package whose module cannot be imported.
# NOTE(review): streamlit and PIL are imported at the top of this file,
# before this loop runs. If either were missing, the script would fail
# before reaching here, so auto-installing them cannot help.
for package in required_packages:
    try:
        __import__(_IMPORT_NAMES[package])
    except ImportError:
        install(package)
        # Make the import system rescan for modules installed after start-up.
        importlib.invalidate_caches()
# Imported only after the install loop, so freshly installed packages are available.
from diffusers import DiffusionPipeline
from streamlit_drawable_canvas import st_canvas

# Hugging Face access token, read from the HF_TOKEN_SD environment variable
# (None if unset).
MY_SECRET_TOKEN = os.environ.get('HF_TOKEN_SD')
YOUR_TOKEN = MY_SECRET_TOKEN
device = "cpu"


@st.cache_resource
def _load_pipeline(model_id, token, device_name):
    """Load the inpainting pipeline and move it to *device_name*, once per process.

    Streamlit re-executes this whole script on every widget interaction.
    Without ``st.cache_resource`` the model would be reloaded from disk on
    every slider move or keystroke. The cache returns the same pipeline
    object on later reruns.
    """
    # NOTE(review): `use_auth_token` is kept for compatibility with older
    # diffusers releases; newer releases prefer `token=` — confirm against the
    # installed version. Also confirm the runwayml model id is still hosted.
    pipeline = DiffusionPipeline.from_pretrained(model_id, use_auth_token=token)
    pipeline.to(device_name)
    return pipeline


pipe = _load_pipeline("runwayml/stable-diffusion-inpainting", YOUR_TOKEN, device)
def resize(height, img):
    """Return a copy of *img* scaled to *height* pixels tall, keeping its aspect ratio.

    The new width is ``original_width * height / original_height``,
    truncated to an int. Resampling uses the Lanczos filter.
    """
    width, original_height = img.size
    scale = height / float(original_height)
    new_width = int(float(width) * scale)
    return img.resize((new_width, height), Image.Resampling.LANCZOS)
def predict(source_img, mask_img, prompt, brush_size):
    """Inpaint the masked region of an uploaded image with the diffusion pipeline.

    Args:
        source_img: Encoded bytes (PNG/JPEG) of the image to edit.
        mask_img: Encoded bytes of the mask image. In the diffusers inpainting
            convention, white pixels are repainted and black pixels are kept.
        prompt: Text prompt describing what to paint into the masked region.
        brush_size: Unused; kept so existing callers keep working. The brush
            size only affects the drawing canvas and is not a pipeline argument.

    Returns:
        A list of PIL images, one per generated sample. Any sample flagged by
        the pipeline's NSFW check is replaced by the image in "unsafe.png".

    Side effects:
        Writes the resized inputs to "src.png" and "mask.png" in the current
        working directory.
    """
    # Apply EXIF orientation, then force RGB. PNG uploads may be RGBA or
    # palette images, while the pipeline expects three-channel input.
    source = ImageOps.exif_transpose(Image.open(BytesIO(source_img))).convert("RGB")
    src = resize(512, source)
    src.save("src.png")

    # The mask is drawn on a fixed-size canvas. Stretch it to exactly the
    # source's resized dimensions so image and mask stay pixel-aligned, even
    # for non-square uploads.
    mask = Image.open(BytesIO(mask_img)).convert("L")
    mask = mask.resize(src.size, Image.Resampling.LANCZOS)
    mask.save("mask.png")

    # brush_size is deliberately not forwarded: it is not a pipeline parameter.
    output = pipe([prompt], image=src, mask_image=mask, strength=0.75)
    generated = output["images"]

    # The NSFW flags may be absent (e.g. when the pipeline has no safety
    # checker); treat every sample as unflagged in that case.
    flags = output.get("nsfw_content_detected") or [False] * len(generated)

    images = []
    placeholder = None  # loaded lazily, only when some sample is flagged
    for image, flagged in zip(generated, flags):
        if flagged:
            if placeholder is None:
                placeholder = Image.open("unsafe.png")
            images.append(placeholder)
        else:
            images.append(image)
    return images
st.title("InPainting Stable Diffusion CPU")
st.write("Inpainting Stable Diffusion example using CPU and HF token. Warning: Slow process... ~5/10 min inference time. NSFW filter enabled. Please use 512*512 square image as input to avoid memory error!")
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
brush_size = st.slider("Brush Size", min_value=1, max_value=20, value=5)
prompt = st.text_input("Prompt")

if uploaded_file:
    # Read the upload once; the raw bytes are reused for display, for the
    # canvas background and for inference.
    uploaded_file_bytes = uploaded_file.read()
    st.image(uploaded_file_bytes, caption="Uploaded Image", use_column_width=True)

    st.write("Draw your mask on the canvas below:")
    canvas_result = st_canvas(
        fill_color="rgba(255, 255, 255, 0)",  # transparent fill
        stroke_width=brush_size,
        stroke_color="black",
        background_color="rgba(0, 0, 0, 0)",  # transparent: untouched pixels get alpha 0
        # Apply EXIF orientation exactly as predict() does for the source
        # image, so the strokes line up with the image that gets inpainted.
        background_image=ImageOps.exif_transpose(Image.open(BytesIO(uploaded_file_bytes))),
        update_streamlit=True,
        height=512,
        width=512,
        drawing_mode="freedraw",
        key="canvas",
    )

    if st.button("Generate"):
        mask_data = canvas_result.image_data
        # Build the mask from the alpha channel: untouched pixels are fully
        # transparent and strokes are opaque, so strokes become white (repaint)
        # and everything else black (keep). Converting the RGBA data straight
        # to "L" maps both the black strokes and the transparent black
        # background to 0, giving an all-black, empty mask.
        # NOTE(review): assumes image_data is an (H, W, 4) RGBA array of the
        # strokes only, without the background image — confirm against the
        # installed streamlit-drawable-canvas version.
        alpha = None if mask_data is None else np.uint8(mask_data[:, :, 3])
        if alpha is None or not alpha.any():
            st.write("Please draw a mask on the canvas.")
        else:
            mask_img = Image.fromarray(alpha)
            mask_bytes = BytesIO()
            mask_img.save(mask_bytes, format="PNG")
            mask_bytes = mask_bytes.getvalue()
            images = predict(uploaded_file_bytes, mask_bytes, prompt, brush_size)
            st.image(images, caption='Generated images', use_column_width=True)