# safe_diffusion_guidance.py
"""Safety-steered Stable Diffusion sampling packaged as a diffusers custom pipeline.

``SafetyClassifier1280`` scores the UNet mid-block activations captured by a
forward hook. For the first ``mid_fraction`` of the denoising steps the latents
take a small gradient step that raises the classifier's "safe" probability,
before the regular classifier-free-guidance update.
"""
import inspect
import os
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from diffusers.utils import BaseOutput

# Output order of the classifier head; "safe" is index 3, which is the
# pipeline's default ``safe_class_index``.
CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']


class SafetyClassifier1280(nn.Module):
    """Small CNN over 1280-channel feature maps that returns class logits.

    The input must have 1280 channels (the first convolution's ``in_channels``).
    It is adaptive-average-pooled to 8x8 first, so any spatial size is accepted.
    ``forward`` returns logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, num_classes: int = 5):
        super().__init__()
        self.pre = nn.AdaptiveAvgPool2d((8, 8))
        # The attribute name "model" (and the layer order) must match the
        # checkpoint's state-dict keys, e.g. "model.0.weight".
        self.model = nn.Sequential(
            nn.Conv2d(1280, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.Conv2d(512, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        """Give a freshly constructed layer a sensible starting point.

        This only matters when training from scratch: ``load_classifier_1280``
        loads the checkpoint with ``strict=True``, so every parameter and buffer
        is overwritten.
        """
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` to 8x8 and return class logits."""
        x = self.pre(x)
        return self.model(x)


def _find_weights_path() -> str:
    """Locate the classifier weights (env var, then CWD, then ``classifiers/``).

    Note: the pipeline itself uses ``pick_weights_path``, which also searches
    next to this file.

    Raises:
        FileNotFoundError: if no candidate path exists.
    """
    # 1) explicit env; 2) repo root file; 3) classifiers/ subdir
    env_p = os.getenv("SDG_CLASSIFIER_WEIGHTS")
    if env_p and os.path.exists(env_p):
        return env_p
    for p in ["safety_classifier_1280.pth", os.path.join("classifiers","safety_classifier_1280.pth")]:
        if os.path.exists(p):
            return p
    # Relative paths resolve against the current working directory.
    raise FileNotFoundError(
        "Safety-classifier weights not found. Provide via env SDG_CLASSIFIER_WEIGHTS, "
        "place 'safety_classifier_1280.pth' at repo root or 'classifiers/', "
        "or pass `classifier_weights=...` to the pipeline call."
    )


def load_classifier_1280(weights_path: str, device=None, dtype=torch.float32) -> SafetyClassifier1280:
    """Build a ``SafetyClassifier1280`` and strictly load ``weights_path`` into it.

    Accepts either a bare state dict or a checkpoint dict that stores it under
    ``"model_state_dict"``. Returns the model in eval mode on ``device`` (CPU if
    ``None``), with floating-point parameters cast to ``dtype``.
    """
    model = SafetyClassifier1280().to(device or "cpu", dtype=dtype)
    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code. Only load checkpoints from trusted sources. Consider weights_only=True
    # if the checkpoint holds only tensors and plain containers.
    state = torch.load(weights_path, map_location="cpu", weights_only=False)
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]
    model.load_state_dict(state, strict=True)
    model.eval()
    return model


def _here(*paths: str) -> str:
    """Join ``paths`` onto the directory containing this file."""
    return os.path.join(os.path.dirname(__file__), *paths)


def pick_weights_path() -> str:
    """
    Try common locations; allow env override. Raise if not found.

    Search order: ``$SDG_CLASSIFIER_WEIGHTS``, then ``classifiers/`` and the
    directory next to this file, then the same two locations relative to the CWD.
    """
    candidates = [
        os.getenv("SDG_CLASSIFIER_WEIGHTS", ""),
        _here("classifiers", "safety_classifier_1280.pth"),
        _here("safety_classifier_1280.pth"),
        "classifiers/safety_classifier_1280.pth",
        "safety_classifier_1280.pth",
    ]
    for p in candidates:
        if p and os.path.exists(p):
            return p
    raise FileNotFoundError(
        "Safety-classifier weights not found. Place 'safety_classifier_1280.pth' "
        "in repo root or 'classifiers/' (or set SDG_CLASSIFIER_WEIGHTS, or pass "
        "`classifier_weights=...` to the call())."
    )


# ----------------------------- Pipeline --------------------------------------
@dataclass
class SDGOutput(BaseOutput):
    """Output of ``SafeDiffusionGuidance.__call__``.

    Attributes:
        images: generated PIL images (this pipeline produces exactly one).
    """

    images: List  # list of PIL Images


class SafeDiffusionGuidance(DiffusionPipeline):
    """
    Minimal custom pipeline that loads a base Stable Diffusion pipeline on demand
    and applies mid-UNet classifier-guided denoising for safety.

    The base pipeline is cached after first use. Later calls reuse it and ignore
    a different ``base_model_id``, unless an explicit ``base_pipe`` is passed.
    """

    def __init__(self, **kwargs):
        # NOTE(review): diffusers inspects this signature to discover pipeline
        # components. Confirm that ``**kwargs`` does not make ``from_pretrained``
        # expect a component called "kwargs". Extra keyword arguments are ignored.
        super().__init__()
        self.base_pipe_ = None  # lazily loaded or user-supplied base pipeline

    def _cache_base(self, pipe: StableDiffusionPipeline) -> StableDiffusionPipeline:
        """Store ``pipe`` as the cached base pipeline and return it."""
        # Bypass DiffusionPipeline.__setattr__. When rebinding an existing
        # attribute it reads ``self.config``, which is never registered when this
        # component-less pipeline is constructed directly.
        # NOTE(review): re-check against the installed diffusers version.
        object.__setattr__(self, "base_pipe_", pipe)
        return pipe

    def _ensure_base(
        self,
        base_pipe: Optional[StableDiffusionPipeline],
        base_model_id: str,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ) -> StableDiffusionPipeline:
        """Return the base SD pipeline, loading and caching it on first use.

        Args:
            base_pipe: if given, it becomes the cached pipeline and is returned as is.
            base_model_id: hub id or local path used only for the first load.
            torch_dtype: load dtype. Defaults to fp16 on CUDA and fp32 otherwise.
            device: target device. Defaults to CUDA when available, else CPU.
                This pipeline registers no modules, so ``self.device`` cannot
                reflect a user's ``.to(...)`` call and is not used.
        """
        if base_pipe is not None:
            return self._cache_base(base_pipe)
        if self.base_pipe_ is None:
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            device = torch.device(device)
            if torch_dtype is None:
                # Half-precision inference is a GPU feature; use fp32 on CPU.
                torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
            pipe = StableDiffusionPipeline.from_pretrained(
                base_model_id,
                torch_dtype=torch_dtype,
                safety_checker=None,
                requires_safety_checker=False,
            ).to(device)
            self._cache_base(pipe)
        return self.base_pipe_

    @staticmethod
    def _encode_prompt_pair(base, prompt: str, negative_prompt: Optional[str], device) -> torch.Tensor:
        """Return ``cat([uncond, cond])`` text embeddings for classifier-free guidance.

        The unconditional branch encodes ``negative_prompt``, or "" if it is None.
        """
        tok = base.tokenizer

        def embed(text: str) -> torch.Tensor:
            ids = tok(
                [text],
                padding="max_length",
                max_length=tok.model_max_length,
                truncation=True,  # longer prompts would overflow the encoder's positions
                return_tensors="pt",
            ).input_ids
            return base.text_encoder(ids.to(device)).last_hidden_state

        uncond = embed(negative_prompt if negative_prompt is not None else "")
        cond = embed(prompt)
        return torch.cat([uncond, cond], dim=0)

    @staticmethod
    def _safety_update(
        base,
        clf: SafetyClassifier1280,
        mid: dict,
        latents: torch.Tensor,
        t,
        step_index: int,
        cond_embeds: torch.Tensor,
        safety_scale: float,
        safe_class_index: int,
        base_alpha: float,
    ) -> torch.Tensor:
        """Take one gradient step on ``latents`` that raises the classifier's p(safe).

        Steps:
        1. Run the UNet (CFG-batched) on a grad-enabled copy of ``latents``.
        2. The mid-block hook stores the features in ``mid["feat"]`` and ``clf``
           scores them.
        3. Move the latents against the gradient of
           ``safety_scale * (1 - mean p(safe))``.

        The step size is ``base_alpha``, multiplied by the current sigma when
        the scheduler exposes ``sigmas``.
        """
        with torch.enable_grad():
            lg = latents.detach().clone().requires_grad_(True)
            lin = base.scheduler.scale_model_input(lg, t)
            base.unet(torch.cat([lin, lin], dim=0), t, encoder_hidden_states=cond_embeds)
            # Keep the autograd graph: detaching the features would disconnect
            # the loss from ``lg`` and leave no gradient to follow.
            feat = mid["feat"].to(torch.float32)
            probs = torch.softmax(clf(feat), dim=-1)
            unsafe = 1.0 - probs[:, safe_class_index].mean()  # encourage "safe"
            # autograd.grad computes only d(loss)/d(lg). It does not accumulate
            # .grad on UNet/classifier parameters as loss.backward() would.
            (grad,) = torch.autograd.grad(safety_scale * unsafe, lg)

        alpha = base_alpha
        sigmas = getattr(base.scheduler, "sigmas", None)
        if sigmas is not None:  # sigma-based schedulers: scale step with noise level
            alpha = base_alpha * float(sigmas[min(step_index, len(sigmas) - 1)])
        return (lg - alpha * grad).detach()

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        safety_scale: float = 5.0,
        mid_fraction: float = 1.0,          # 0..1 fraction of steps to guide
        safe_class_index: int = 3,          # "safe" in CLASS_NAMES
        classifier_weights: Optional[str] = None,
        base_pipe: Optional[StableDiffusionPipeline] = None,
        base_model_id: str = "runwayml/stable-diffusion-v1-5",
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ) -> SDGOutput:
        """Generate one image with classifier-guided safety steering.

        Args:
            prompt: text prompt.
            negative_prompt: prompt for the unconditional CFG branch ("" if None).
            num_inference_steps: number of scheduler steps.
            guidance_scale: classifier-free guidance weight.
            safety_scale: weight of the safety loss; 0 disables safety steering.
            mid_fraction: fraction of the leading steps that get a safety update
                (``i < mid_fraction * num_steps``). 0 guides no step, 1 guides all.
            safe_class_index: index of the "safe" class in the classifier output.
            classifier_weights: path to the classifier checkpoint. If None,
                ``pick_weights_path()`` searches the default locations.
            base_pipe: optional pre-built SD pipeline to use (and cache).
            base_model_id: model to load when no base pipeline is cached yet.
            generator: RNG for the initial latents (and stochastic schedulers).
            **kwargs: ``height``/``width`` (default 512), ``device`` and
                ``torch_dtype`` for the first base-pipeline load; others ignored.

        Returns:
            ``SDGOutput`` holding a single PIL image.

        Raises:
            FileNotFoundError: if no classifier weights can be located.
        """
        height = kwargs.pop("height", 512)
        width = kwargs.pop("width", 512)

        # 1) prepare base SD
        base = self._ensure_base(
            base_pipe,
            base_model_id,
            torch_dtype=kwargs.pop("torch_dtype", None),
            device=kwargs.pop("device", None),
        )
        device = getattr(base, "_execution_device", base.device)
        dtype = base.unet.dtype

        # 2) text embeddings (classifier-free guidance): [uncond, cond]
        cond_embeds = self._encode_prompt_pair(base, prompt, negative_prompt, device)

        # 3) timesteps first: init_noise_sigma may depend on them
        base.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = base.scheduler.timesteps

        # 4) initial latents
        scale = getattr(base, "vae_scale_factor", 8)
        shape = (1, base.unet.config.in_channels, height // scale, width // scale)
        # Sample on the generator's device so a CPU generator also works with a
        # CUDA pipeline.
        noise_device = generator.device if generator is not None else device
        latents = torch.randn(shape, generator=generator, device=noise_device, dtype=dtype).to(device)
        # Scale the initial noise as the stock Stable Diffusion pipeline does.
        latents = latents * base.scheduler.init_noise_sigma

        # 5) classifier (run in fp32). Its parameters are never optimized here.
        weights = classifier_weights or pick_weights_path()
        clf = load_classifier_1280(weights, device=device, dtype=torch.float32)
        clf.requires_grad_(False)

        # Forward stochastic schedulers the generator, if their step() accepts one.
        step_kwargs = (
            {"generator": generator}
            if "generator" in inspect.signature(base.scheduler.step).parameters
            else {}
        )

        # 6) mid-block hook: captures the latest UNet mid-block output
        mid: dict = {}

        def capture_mid(_module, _inputs, output):
            mid["feat"] = output[0] if isinstance(output, tuple) else output

        handle = base.unet.mid_block.register_forward_hook(capture_mid)
        base_alpha = 1e-3  # step size factor for safety update
        n_guided = mid_fraction * len(timesteps)

        # 7) denoising loop (hook removed even if a step raises)
        try:
            for i, t in enumerate(timesteps):
                if safety_scale > 0 and i < n_guided:
                    latents = self._safety_update(
                        base, clf, mid, latents, t, i, cond_embeds,
                        safety_scale, safe_class_index, base_alpha,
                    )

                lat_in = base.scheduler.scale_model_input(latents, t)
                lat_cat = torch.cat([lat_in, lat_in], dim=0)  # for CFG
                noise_pred = base.unet(lat_cat, t, encoder_hidden_states=cond_embeds).sample
                n_uncond, n_text = noise_pred.chunk(2)
                noise = n_uncond + guidance_scale * (n_text - n_uncond)
                latents = base.scheduler.step(noise, t, latents, **step_kwargs).prev_sample
        finally:
            handle.remove()

        # 8) decode like StableDiffusionPipeline. image_processor.postprocess
        # takes the VAE's tensor output (range [-1, 1]), not a numpy array.
        image = base.vae.decode(latents / base.vae.config.scaling_factor, return_dict=False)[0]
        pil = base.image_processor.postprocess(image, output_type="pil")[0]
        return SDGOutput(images=[pil])


__all__ = ["SafeDiffusionGuidance"]