import torch
import numpy as np
from diffusers import StableDiffusionInpaintPipeline, DDIMScheduler
from PIL import Image
class LEDITSModel:
    """
    LEDITS++-style localized image editing backed by a Stable Diffusion
    inpainting pipeline.

    The diffusers pipeline is loaded lazily on first use so that
    constructing this object stays cheap and memory-free.
    """

    def __init__(self, model_id="runwayml/stable-diffusion-inpainting", device=None):
        """
        Initialize the LEDITS++ model.

        Args:
            model_id (str): Hugging Face model ID for the Stable Diffusion
                inpainting model.
            device (str, optional): Device to run the model on ('cuda' or
                'cpu'). Defaults to CUDA when available.
        """
        self.model_id = model_id
        # Prefer CUDA when available unless the caller pins a device.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        # Populated by load_model() on first use to save memory.
        self.pipe = None

    def load_model(self):
        """
        Load the Stable Diffusion inpainting pipeline (idempotent: a
        second call is a no-op).
        """
        if self.pipe is not None:
            return
        # DDIM scheduler is used for better output quality.
        scheduler = DDIMScheduler.from_pretrained(
            self.model_id,
            subfolder="scheduler",
        )
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            self.model_id,
            scheduler=scheduler,
            safety_checker=None,  # Safety checker intentionally disabled per user request.
        )
        self.pipe = self.pipe.to(self.device)
        if self.device == "cuda":
            # Trades a little speed for a much smaller VRAM footprint.
            self.pipe.enable_attention_slicing()

    @staticmethod
    def _to_pil(array_or_image):
        """Convert a numpy array (float in [0, 1] or uint8) to a PIL Image.

        PIL inputs are passed through unchanged. Float arrays are clipped
        to [0, 1] before scaling so out-of-range values cannot wrap when
        cast to uint8.
        """
        if not isinstance(array_or_image, np.ndarray):
            return array_or_image
        if array_or_image.dtype in (np.float32, np.float64):
            scaled = (np.clip(array_or_image, 0.0, 1.0) * 255).astype(np.uint8)
            return Image.fromarray(scaled)
        return Image.fromarray(array_or_image)

    def edit_image(self, image, mask, prompt, intensity=0.5, guidance_scale=7.5, num_inference_steps=30):
        """
        Edit an image using the LEDITS++ approach (mask-guided inpainting).

        Args:
            image (numpy.ndarray or PIL.Image.Image): Input image
                (float arrays normalized to [0, 1]).
            mask (numpy.ndarray or PIL.Image.Image): Mask indicating the
                region to edit (values in [0, 1] for float arrays).
            prompt (str): Text prompt describing the desired edit.
            intensity (float): Strength of the edit (0.0 to 1.0).
            guidance_scale (float): Guidance scale for the diffusion model.
            num_inference_steps (int): Number of denoising steps.

        Returns:
            numpy.ndarray: Edited image, float64 normalized to [0, 1].
        """
        # Load model if not already loaded.
        self.load_model()

        image_pil = self._to_pil(image)
        mask_pil = self._to_pil(mask)
        # The pipeline expects a single-channel mask regardless of how the
        # caller supplied it (numpy array OR an RGB/RGBA PIL image).
        if mask_pil.mode != 'L':
            mask_pil = mask_pil.convert('L')

        # Stable Diffusion requires dimensions that are multiples of 8;
        # round DOWN so we never upscale.
        width, height = image_pil.size
        new_width = width - (width % 8)
        new_height = height - (height % 8)
        if (new_width, new_height) != image_pil.size:
            image_pil = image_pil.resize((new_width, new_height), Image.LANCZOS)
            # NEAREST for the mask: LANCZOS overshoots on hard edges and
            # would blur/ring the edit region boundary.
            mask_pil = mask_pil.resize((new_width, new_height), Image.NEAREST)

        # Run the inpainting pipeline without tracking gradients.
        with torch.no_grad():
            output = self.pipe(
                prompt=prompt,
                image=image_pil,
                mask_image=mask_pil,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                strength=intensity,
            ).images[0]

        # Convert back to a [0, 1] float numpy array.
        return np.array(output) / 255.0

    def __del__(self):
        """
        Best-effort cleanup when the object is garbage-collected.

        Guarded broadly because at interpreter shutdown module globals
        (torch) may already have been torn down.
        """
        try:
            if self.pipe is not None and self.device == "cuda":
                torch.cuda.empty_cache()
        except Exception:
            pass
class StableDiffusionModel:
    """
    Text-to-image generation backed by a Stable Diffusion pipeline.

    The diffusers pipeline is loaded lazily on first use so that
    constructing this object stays cheap and memory-free.
    """

    def __init__(self, model_id="runwayml/stable-diffusion-v1-5", device=None):
        """
        Initialize the Stable Diffusion model.

        Args:
            model_id (str): Hugging Face model ID for the Stable Diffusion
                model.
            device (str, optional): Device to run the model on ('cuda' or
                'cpu'). Defaults to CUDA when available.
        """
        self.model_id = model_id
        # Prefer CUDA when available unless the caller pins a device.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        # Populated by load_model() on first use to save memory.
        self.pipe = None

    def load_model(self):
        """
        Load the Stable Diffusion pipeline (idempotent: a second call is
        a no-op).
        """
        if self.pipe is not None:
            return
        # Imported here so the module stays importable even when this
        # heavier pipeline class is never used.
        from diffusers import StableDiffusionPipeline
        self.pipe = StableDiffusionPipeline.from_pretrained(
            self.model_id,
            safety_checker=None,  # Safety checker intentionally disabled per user request.
        )
        self.pipe = self.pipe.to(self.device)
        if self.device == "cuda":
            # Trades a little speed for a much smaller VRAM footprint.
            self.pipe.enable_attention_slicing()

    def generate_image(self, prompt, negative_prompt="", width=512, height=512, guidance_scale=7.5, num_inference_steps=30):
        """
        Generate an image using Stable Diffusion.

        Args:
            prompt (str): Text prompt describing the desired image.
            negative_prompt (str): Text prompt describing what to avoid.
            width (int): Width of the generated image (should be a
                multiple of 8).
            height (int): Height of the generated image (should be a
                multiple of 8).
            guidance_scale (float): Guidance scale for the diffusion model.
            num_inference_steps (int): Number of denoising steps.

        Returns:
            numpy.ndarray: Generated image, float64 normalized to [0, 1].
        """
        # Load model if not already loaded.
        self.load_model()

        # Run the pipeline without tracking gradients.
        with torch.no_grad():
            output = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
            ).images[0]

        # Convert to a [0, 1] float numpy array.
        return np.array(output) / 255.0

    def __del__(self):
        """
        Best-effort cleanup when the object is garbage-collected.

        Guarded broadly because at interpreter shutdown module globals
        (torch) may already have been torn down.
        """
        try:
            if self.pipe is not None and self.device == "cuda":
                torch.cuda.empty_cache()
        except Exception:
            pass
|