# safe_diffusion_guidance.py
import os
from typing import Optional, List
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from diffusers.utils import BaseOutput
import torch, torch.nn as nn, os
from typing import Optional
# Ordered labels produced by SafetyClassifier1280; index 3 ("safe") is the
# class the guidance steers toward (see safe_class_index default below).
CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
class SafetyClassifier1280(nn.Module):
    """Small CNN head that classifies UNet mid-block features.

    Input:  (B, 1280, H, W) feature maps (any spatial size; pooled to 8x8).
    Output: (B, num_classes) raw logits over CLASS_NAMES.
    """
    def __init__(self, num_classes: int = 5):
        super().__init__()
        # Normalize arbitrary spatial sizes to a fixed 8x8 grid for the convs.
        self.pre = nn.AdaptiveAvgPool2d((8, 8))
        self.model = nn.Sequential( # <--- use "model" to match checkpoint
            nn.Conv2d(1280, 512, 3, padding=1),
            nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
        self.apply(self._init_weights)
    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        # BUGFIX: the original called self.apply(self._init_weights) but never
        # defined _init_weights, so instantiation raised AttributeError.
        # These values are overwritten when a checkpoint is loaded.
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            nn.init.zeros_(m.bias)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pre(x)
        return self.model(x) # <--- forward through "model"
def _find_weights_path() -> str:
# 1) explicit env; 2) repo root file; 3) classifiers/ subdir
env_p = os.getenv("SDG_CLASSIFIER_WEIGHTS")
if env_p and os.path.exists(env_p): return env_p
for p in ["safety_classifier_1280.pth", os.path.join("classifiers","safety_classifier_1280.pth")]:
if os.path.exists(p): return p
# If running from HF cache, these paths are relative to the cached repo folder.
raise FileNotFoundError(
"Safety-classifier weights not found. Provide via env SDG_CLASSIFIER_WEIGHTS, "
"place 'safety_classifier_1280.pth' at repo root or 'classifiers/', "
"or pass `classifier_weights=...` to the pipeline call."
)
def load_classifier_1280(weights_path: str, device=None, dtype=torch.float32):
    """Instantiate SafetyClassifier1280 and load `weights_path` into it.

    Accepts either a bare state dict or a training checkpoint that wraps the
    weights under "model_state_dict". Returns the classifier in eval mode.
    """
    clf = SafetyClassifier1280().to(device or "cpu", dtype=dtype)
    # NOTE(security): weights_only=False runs the full pickle machinery —
    # only load checkpoints from trusted sources.
    payload = torch.load(weights_path, map_location="cpu", weights_only=False)
    if isinstance(payload, dict) and "model_state_dict" in payload:
        payload = payload["model_state_dict"]
    clf.load_state_dict(payload, strict=True)
    clf.eval()
    return clf
def _here(*paths: str) -> str:
return os.path.join(os.path.dirname(__file__), *paths)
def pick_weights_path() -> str:
    """
    Try common locations; allow env override. Raise if not found.
    """
    search_order = [
        os.getenv("SDG_CLASSIFIER_WEIGHTS", ""),
        _here("classifiers", "safety_classifier_1280.pth"),
        _here("safety_classifier_1280.pth"),
        "classifiers/safety_classifier_1280.pth",
        "safety_classifier_1280.pth",
    ]
    # First existing, non-empty candidate wins.
    found = next((p for p in search_order if p and os.path.exists(p)), None)
    if found is not None:
        return found
    raise FileNotFoundError(
        "Safety-classifier weights not found. Place 'safety_classifier_1280.pth' "
        "in repo root or 'classifiers/' (or set SDG_CLASSIFIER_WEIGHTS, or pass "
        "`classifier_weights=...` to the call())."
    )
# ----------------------------- Pipeline --------------------------------------
class SDGOutput(BaseOutput):
    """Return container for SafeDiffusionGuidance.__call__."""
    images: List # list of PIL Images
class SafeDiffusionGuidance(DiffusionPipeline):
    """
    Minimal custom pipeline that loads a base Stable Diffusion pipeline on demand
    and applies mid-UNet classifier-guided denoising for safety.

    Each guided step runs one extra UNet forward with gradients enabled, scores
    the captured mid-block features with SafetyClassifier1280, and nudges the
    latents down the gradient of (1 - P("safe")).
    """
    def __init__(self, **kwargs):
        # Extra kwargs are accepted but deliberately ignored so loading
        # machinery can pass components without breaking construction.
        super().__init__()
        self.base_pipe_ = None  # lazily-created, cached base SD pipeline
    def _ensure_base(
        self,
        base_pipe: Optional[StableDiffusionPipeline],
        base_model_id: str,
        torch_dtype: torch.dtype,
    ) -> StableDiffusionPipeline:
        """Return a usable base SD pipeline, preferring a caller-supplied one.

        Downloads `base_model_id` on first use (stock safety checker disabled —
        this pipeline applies its own classifier guidance) and caches it.
        """
        if base_pipe is not None:
            self.base_pipe_ = base_pipe
            return self.base_pipe_
        if self.base_pipe_ is None:
            self.base_pipe_ = StableDiffusionPipeline.from_pretrained(
                base_model_id,
                torch_dtype=torch_dtype,
                safety_checker=None,
                requires_safety_checker=False,
            ).to(self.device)
        return self.base_pipe_
    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        safety_scale: float = 5.0,
        mid_fraction: float = 1.0,   # 0..1 fraction of steps to guide
        safe_class_index: int = 3,   # "safe" in CLASS_NAMES
        classifier_weights: Optional[str] = None,
        base_pipe: Optional[StableDiffusionPipeline] = None,
        base_model_id: str = "runwayml/stable-diffusion-v1-5",
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ) -> SDGOutput:
        """Run safety-guided text-to-image generation.

        Args:
            prompt: text prompt.
            negative_prompt: CFG negative prompt; empty string when None.
            num_inference_steps: number of scheduler steps.
            guidance_scale: classifier-free guidance weight.
            safety_scale: weight of the safety loss; <= 0 disables guidance.
            mid_fraction: only the first `mid_fraction` of steps are guided.
            safe_class_index: index of the "safe" class in CLASS_NAMES.
            classifier_weights: explicit path to classifier weights; when None
                the default search locations are tried.
            base_pipe: pre-built StableDiffusionPipeline to reuse.
            base_model_id: HF model id loaded when `base_pipe` is None.
            generator: RNG for reproducible initial latents.
            **kwargs: supports `height` and `width` (default 512 each).

        Returns:
            SDGOutput whose `images` holds a single PIL image.
        """
        # 1) prepare base SD (fp16 to keep memory reasonable)
        base = self._ensure_base(base_pipe, base_model_id, torch_dtype=torch.float16)
        device = getattr(base, "_execution_device", base.device)
        dtype = base.unet.dtype
        # 2) text embeddings (classifier-free guidance pair: [uncond, cond])
        tok = base.tokenizer
        max_len = tok.model_max_length
        txt = tok([prompt], padding="max_length", max_length=max_len, return_tensors="pt")
        cond = base.text_encoder(txt.input_ids.to(device)).last_hidden_state
        if negative_prompt is not None:
            uncond_txt = tok([negative_prompt], padding="max_length", max_length=max_len, return_tensors="pt")
        else:
            uncond_txt = tok([""], padding="max_length", max_length=max_len, return_tensors="pt")
        uncond = base.text_encoder(uncond_txt.input_ids.to(device)).last_hidden_state
        cond_embeds = torch.cat([uncond, cond], dim=0)
        # 3) latents
        h = kwargs.pop("height", 512); w = kwargs.pop("width", 512)
        latents = torch.randn(
            (1, base.unet.in_channels, h // 8, w // 8),
            device=device, generator=generator, dtype=dtype
        )
        base.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = base.scheduler.timesteps
        # BUGFIX: scale initial noise for sigma-based schedulers; a no-op for
        # DDIM/PNDM where init_noise_sigma == 1.0.
        latents = latents * base.scheduler.init_noise_sigma
        # 4) classifier (run in fp32 for numerically stable gradients)
        # BUGFIX: the original called undefined `pick_weights_for_pipe(base)`.
        weights = classifier_weights or pick_weights_path()
        clf = load_classifier_1280(weights, device=device, dtype=torch.float32).eval()
        # 5) hook to capture UNet mid-block features for the classifier
        mid = {}
        def hook(_, __, out): mid["feat"] = out[0] if isinstance(out, tuple) else out
        handle = base.unet.mid_block.register_forward_hook(hook)
        base_alpha = 1e-3  # step size factor for safety update
        try:
            # 6) denoising loop
            for i, t in enumerate(timesteps):
                do_guide = (i / len(timesteps)) <= mid_fraction and safety_scale > 0
                if do_guide:
                    # safety gradient w.r.t. latents: one extra UNet forward in grad mode
                    with torch.enable_grad():
                        lg = latents.detach().clone().requires_grad_(True)
                        lin = base.scheduler.scale_model_input(lg, t)
                        lcat = torch.cat([lin, lin], dim=0)
                        _ = base.unet(lcat, t, encoder_hidden_states=cond_embeds).sample
                        # BUGFIX: do NOT detach here — detaching severed the graph
                        # between the feature and `lg`, leaving lg.grad == None
                        # and crashing the update below. Cast only.
                        feat = mid["feat"].to(torch.float32)
                        logits = clf(feat)
                        probs = torch.softmax(logits, dim=-1)
                        unsafe = 1.0 - probs[:, safe_class_index].mean()  # encourage "safe"
                        loss = safety_scale * unsafe
                        loss.backward()
                    alpha = base_alpha
                    if hasattr(base.scheduler, "sigmas"):  # sigma-based schedulers
                        idx = min(i, len(base.scheduler.sigmas) - 1)
                        alpha = base_alpha * float(base.scheduler.sigmas[idx])
                    latents = (lg - alpha * lg.grad).detach()
                # standard SD forward with the (possibly updated) latents
                lat_in = base.scheduler.scale_model_input(latents, t)
                lat_cat = torch.cat([lat_in, lat_in], dim=0)  # CFG pair
                noise_pred = base.unet(lat_cat, t, encoder_hidden_states=cond_embeds).sample
                n_uncond, n_text = noise_pred.chunk(2)
                noise = n_uncond + guidance_scale * (n_text - n_uncond)
                latents = base.scheduler.step(noise, t, latents).prev_sample
        finally:
            # Always detach the hook, even if the loop raises (original leaked it).
            handle.remove()
        # 7) decode
        img = base.decode_latents(latents)
        pil = base.image_processor.postprocess(img, output_type="pil")[0]
        return SDGOutput(images=[pil])
# Public API of this module; helper functions and classes stay module-private.
__all__ = ["SafeDiffusionGuidance"]