# safe-diffusion-guidance / safe_diffusion_guidance.py
# Uploaded by basimazam: "Upload SDG pipeline + classifier weights" (rev 82a327a, verified)
# safe_diffusion_guidance.py
import os
from typing import Optional, List
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from diffusers.utils import BaseOutput
import torch, torch.nn as nn, os
from typing import Optional
# Classifier output label order; the guidance loss pushes probability mass
# toward index 3 ("safe") — see `safe_class_index` in the pipeline call.
CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
class SafetyClassifier1280(nn.Module):
    """Small CNN head that classifies UNet mid-block features (1280 channels)
    into the 5 safety classes of CLASS_NAMES.

    The trunk is named ``model`` so its parameter names match the checkpoint
    keys in 'safety_classifier_1280.pth'.
    """

    def __init__(self, num_classes: int = 5):
        super().__init__()
        # Normalize any spatial size down to 8x8 before the conv stack.
        self.pre = nn.AdaptiveAvgPool2d((8, 8))
        self.model = nn.Sequential(  # <--- use "model" to match checkpoint
            nn.Conv2d(1280, 512, 3, padding=1),
            nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        """Initialize conv/linear layers (Kaiming) and BN (identity).

        FIX: the original class called ``self.apply(self._init_weights)`` but
        never defined ``_init_weights``, so every instantiation raised
        AttributeError. (Loaded checkpoints overwrite these values anyway.)
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (unsoftmaxed) logits of shape (batch, num_classes)."""
        x = self.pre(x)
        return self.model(x)  # <--- forward through "model"
def _find_weights_path() -> str:
    """Locate the safety-classifier checkpoint.

    Search order: the SDG_CLASSIFIER_WEIGHTS env var, then the repo root,
    then the classifiers/ subdirectory. Relative paths resolve against the
    cwd (i.e. the cached repo folder when running from the HF cache).

    Raises:
        FileNotFoundError: when no candidate exists.
    """
    candidates = [
        os.getenv("SDG_CLASSIFIER_WEIGHTS"),
        "safety_classifier_1280.pth",
        os.path.join("classifiers", "safety_classifier_1280.pth"),
    ]
    for candidate in candidates:
        if candidate and os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(
        "Safety-classifier weights not found. Provide via env SDG_CLASSIFIER_WEIGHTS, "
        "place 'safety_classifier_1280.pth' at repo root or 'classifiers/', "
        "or pass `classifier_weights=...` to the pipeline call."
    )
def load_classifier_1280(weights_path: str, device=None, dtype=torch.float32):
    """Build a SafetyClassifier1280 and load checkpoint weights into it.

    Accepts either a bare state_dict or a training checkpoint that wraps
    one under the 'model_state_dict' key. Returns the model in eval mode
    on `device` (default CPU) with parameters cast to `dtype`.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects — only load checkpoints from a trusted source.
    clf = SafetyClassifier1280().to(device or "cpu", dtype=dtype)
    checkpoint = torch.load(weights_path, map_location="cpu", weights_only=False)
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        checkpoint = checkpoint["model_state_dict"]
    clf.load_state_dict(checkpoint, strict=True)
    clf.eval()
    return clf
def _here(*paths: str) -> str:
    """Join *paths* onto this module's directory (not the cwd)."""
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, *paths)
def pick_weights_path() -> str:
    """Resolve the classifier checkpoint path.

    Checks the SDG_CLASSIFIER_WEIGHTS override first, then locations next to
    this module, then cwd-relative fallbacks. Raises FileNotFoundError when
    nothing exists.
    """
    module_dir = os.path.dirname(__file__)
    search = (
        os.getenv("SDG_CLASSIFIER_WEIGHTS", ""),
        os.path.join(module_dir, "classifiers", "safety_classifier_1280.pth"),
        os.path.join(module_dir, "safety_classifier_1280.pth"),
        "classifiers/safety_classifier_1280.pth",
        "safety_classifier_1280.pth",
    )
    found = next((p for p in search if p and os.path.exists(p)), None)
    if found is None:
        raise FileNotFoundError(
            "Safety-classifier weights not found. Place 'safety_classifier_1280.pth' "
            "in repo root or 'classifiers/' (or set SDG_CLASSIFIER_WEIGHTS, or pass "
            "`classifier_weights=...` to the call())."
        )
    return found
# ----------------------------- Pipeline --------------------------------------
class SDGOutput(BaseOutput):
    """Output of SafeDiffusionGuidance.__call__; `images` holds the results."""
    images: List  # list of PIL Images
class SafeDiffusionGuidance(DiffusionPipeline):
    """
    Minimal custom pipeline that loads a base Stable Diffusion pipeline on
    demand and applies mid-UNet classifier-guided denoising for safety.

    At each guided step, the UNet mid-block activation (1280 channels) is fed
    to SafetyClassifier1280; the gradient of an "unsafeness" loss w.r.t. the
    latents nudges them toward the "safe" class before the regular scheduler
    step runs.
    """

    def __init__(self, **kwargs):
        # Extra kwargs (e.g. components injected by diffusers' custom-pipeline
        # loader) are accepted and deliberately ignored. (The original comment
        # claimed there was no **kwargs, contradicting the actual signature.)
        super().__init__()
        self.base_pipe_ = None  # lazily-created/cached StableDiffusionPipeline

    def _ensure_base(
        self,
        base_pipe: Optional[StableDiffusionPipeline],
        base_model_id: str,
        torch_dtype: torch.dtype,
    ) -> StableDiffusionPipeline:
        """Return (and cache) the base SD pipeline, loading it on first use."""
        if base_pipe is not None:
            self.base_pipe_ = base_pipe
            return self.base_pipe_
        if self.base_pipe_ is None:
            self.base_pipe_ = StableDiffusionPipeline.from_pretrained(
                base_model_id,
                torch_dtype=torch_dtype,
                safety_checker=None,  # safety comes from our guidance instead
                requires_safety_checker=False,
            ).to(self.device)
        return self.base_pipe_

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        safety_scale: float = 5.0,
        mid_fraction: float = 1.0,  # 0..1 fraction of steps to guide
        safe_class_index: int = 3,  # "safe" in CLASS_NAMES
        classifier_weights: Optional[str] = None,
        base_pipe: Optional[StableDiffusionPipeline] = None,
        base_model_id: str = "runwayml/stable-diffusion-v1-5",
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ) -> SDGOutput:
        """Generate one image with classifier-free guidance plus safety guidance.

        Args:
            prompt / negative_prompt: text conditioning (CFG pair).
            num_inference_steps: scheduler steps.
            guidance_scale: classifier-free guidance weight.
            safety_scale: weight of the safety loss; 0 disables guidance.
            mid_fraction: guide only the first `mid_fraction` of the steps.
            safe_class_index: classifier index treated as "safe".
            classifier_weights: explicit checkpoint path; otherwise resolved
                via pick_weights_path().
            base_pipe / base_model_id: reuse an SD pipeline or load one by id.
            generator: torch RNG for reproducible initial latents.
            **kwargs: `height` / `width` (default 512) are honored; anything
                else is ignored.

        Returns:
            SDGOutput with a single PIL image.
        """
        # 1) prepare base SD (fp16 weights when loaded from the hub)
        base = self._ensure_base(base_pipe, base_model_id, torch_dtype=torch.float16)
        device = getattr(base, "_execution_device", base.device)
        dtype = base.unet.dtype

        # 2) text embeddings (classifier-free guidance: [uncond; cond])
        tok = base.tokenizer
        max_len = tok.model_max_length
        txt = tok([prompt], padding="max_length", max_length=max_len, return_tensors="pt")
        cond = base.text_encoder(txt.input_ids.to(device)).last_hidden_state
        neg = negative_prompt if negative_prompt is not None else ""
        uncond_txt = tok([neg], padding="max_length", max_length=max_len, return_tensors="pt")
        uncond = base.text_encoder(uncond_txt.input_ids.to(device)).last_hidden_state
        cond_embeds = torch.cat([uncond, cond], dim=0)

        # 3) latents (SD VAE downsamples by 8)
        h = kwargs.pop("height", 512); w = kwargs.pop("width", 512)
        latents = torch.randn(
            (1, base.unet.config.in_channels, h // 8, w // 8),  # config access: direct .in_channels is deprecated
            device=device, generator=generator, dtype=dtype
        )
        base.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = base.scheduler.timesteps

        # 4) classifier (run in fp32)
        # FIX: the original called undefined `pick_weights_for_pipe(base)`,
        # which raised NameError whenever classifier_weights was not given.
        weights = classifier_weights or pick_weights_path()
        clf = load_classifier_1280(weights, device=device, dtype=torch.float32)
        # Freeze the classifier so backward() produces gradients only w.r.t.
        # the latents and does not accumulate grads on classifier weights.
        clf.requires_grad_(False)

        # 5) mid-block hook: capture the mid-UNet feature map each forward
        mid = {}

        def hook(_, __, out):
            mid["feat"] = out[0] if isinstance(out, tuple) else out

        handle = base.unet.mid_block.register_forward_hook(hook)
        base_alpha = 1e-3  # step size factor for safety update

        try:
            # 6) denoising loop
            for i, t in enumerate(timesteps):
                do_guide = (i / len(timesteps)) <= mid_fraction and safety_scale > 0
                if do_guide:
                    # safety gradient w.r.t latents
                    with torch.enable_grad():
                        lg = latents.detach().clone().requires_grad_(True)
                        lin = base.scheduler.scale_model_input(lg, t)
                        lcat = torch.cat([lin, lin], dim=0)
                        _ = base.unet(lcat, t, encoder_hidden_states=cond_embeds).sample
                        # FIX: do NOT .detach() the hooked feature — detaching
                        # severed the autograd graph back to `lg`, leaving
                        # lg.grad as None and crashing the update below.
                        # Casting to fp32 preserves the graph.
                        feat = mid["feat"].to(torch.float32)
                        logits = clf(feat)
                        probs = torch.softmax(logits, dim=-1)
                        unsafe = 1.0 - probs[:, safe_class_index].mean()  # encourage "safe"
                        loss = safety_scale * unsafe
                        loss.backward()
                    alpha = base_alpha
                    if hasattr(base.scheduler, "sigmas"):  # sigma-based schedulers
                        idx = min(i, len(base.scheduler.sigmas) - 1)
                        alpha = base_alpha * float(base.scheduler.sigmas[idx])
                    latents = (lg - alpha * lg.grad).detach()

                # standard SD forward on the (possibly nudged) latents.
                # (The original also computed lat_in/lat_cat before the
                # guidance branch and discarded them — dead work removed.)
                lat_in = base.scheduler.scale_model_input(latents, t)
                lat_cat = torch.cat([lat_in, lat_in], dim=0)  # batch of 2 for CFG
                noise_pred = base.unet(lat_cat, t, encoder_hidden_states=cond_embeds).sample
                n_uncond, n_text = noise_pred.chunk(2)
                noise = n_uncond + guidance_scale * (n_text - n_uncond)
                latents = base.scheduler.step(noise, t, latents).prev_sample
        finally:
            # FIX: remove the hook even if the loop raises, so a failed call
            # does not leave a stale hook on a cached/shared base pipeline.
            handle.remove()

        # 7) decode
        # NOTE(review): decode_latents is deprecated in newer diffusers and
        # returns a numpy image; postprocess accepts it here — confirm against
        # the diffusers version this repo pins.
        img = base.decode_latents(latents)
        pil = base.image_processor.postprocess(img, output_type="pil")[0]
        return SDGOutput(images=[pil])
__all__ = ["SafeDiffusionGuidance"]  # public API surface of this module