|
|
| import os
|
| from typing import Optional, List
|
|
|
| import torch
|
| import torch.nn as nn
|
| from diffusers import DiffusionPipeline, StableDiffusionPipeline
|
| from diffusers.utils import BaseOutput
|
|
|
| import torch, torch.nn as nn, os
|
| from typing import Optional
|
|
|
# Safety-classifier labels in logits order; index 3 ('safe') matches the
# pipeline's default `safe_class_index`.
CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
|
|
|
class AdaptiveClassifier1280(nn.Module):
    """
    Same CNN topology you trained (keys start with 'model.*').
    Input (B,1280,H,W) -> AdaptiveAvgPool2d(8,8) -> conv stack -> head
    """

    def __init__(self, num_classes: int = 5):
        super().__init__()
        # Normalize any spatial size to 8x8 before the conv stack.
        self.pre = nn.AdaptiveAvgPool2d((8, 8))

        # NOTE: the layer order below is fixed — checkpoint keys are
        # positional ('model.<idx>.*'), so reordering would break loading.
        layers = [
            nn.Conv2d(1280, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.1),

            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.1),

            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        ]
        self.model = nn.Sequential(*layers)
        self.apply(self._init)

    @staticmethod
    def _init(module):
        """Xavier init for linears, Kaiming for convs, unit/zero for BN."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Map a (B,1280,H,W) feature map to (B,num_classes) logits."""
        pooled = self.pre(x)
        return self.model(pooled)
|
|
|
def _find_weights_path() -> str:
    """
    Locate the safety-classifier checkpoint.

    Search order: SDG_CLASSIFIER_WEIGHTS env var, repo root, 'classifiers/'.
    Only existing files are accepted; raises FileNotFoundError otherwise.
    """
    candidates = [
        os.getenv("SDG_CLASSIFIER_WEIGHTS"),
        "safety_classifier_1280.pth",
        os.path.join("classifiers", "safety_classifier_1280.pth"),
    ]
    for candidate in candidates:
        if candidate and os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(
        "Safety-classifier weights not found. Provide via env SDG_CLASSIFIER_WEIGHTS, "
        "place 'safety_classifier_1280.pth' at repo root or 'classifiers/', "
        "or pass `classifier_weights=...` to the pipeline call."
    )
|
|
|
def load_classifier_1280(
    weights_path: Optional[str],
    device: torch.device,
    dtype: torch.dtype = torch.float32
) -> AdaptiveClassifier1280:
    """
    Load the mid-UNet safety classifier and move it to `device` in eval mode.

    Args:
        weights_path: explicit checkpoint path, or None to auto-discover
            via `_find_weights_path()`.
        device: target device for the model.
        dtype: target parameter dtype (default float32, matching training).

    Returns:
        An `AdaptiveClassifier1280` in eval mode with the checkpoint loaded
        (non-strict, so partial checkpoints are tolerated with a warning).

    Raises:
        FileNotFoundError: if no checkpoint can be located.
    """
    path = weights_path or _find_weights_path()
    # SECURITY: weights_only=False unpickles arbitrary Python objects — only
    # load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)

    # Accept either a full training checkpoint ({'model_state_dict': ...})
    # or a bare state dict; the previous elif/else branches were identical,
    # so they collapse into one.
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state = ckpt["model_state_dict"]
    else:
        state = ckpt

    # BUG FIX: `dtype` was accepted but silently ignored (always float32).
    # The default keeps the old behavior for existing callers.
    model = AdaptiveClassifier1280().to(device=device, dtype=dtype)
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        print(f"[SDG] load_state_dict: missing={missing[:4]}... ({len(missing)}), unexpected={unexpected[:4]}... ({len(unexpected)})")
    model.eval()
    return model
|
|
|
def _here(*paths: str) -> str:
    """Resolve *paths* relative to the directory containing this module."""
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, *paths)
|
|
|
|
|
def pick_weights_path() -> str:
    """
    Try common locations; allow env override. Raise if not found.
    """
    env_override = os.getenv("SDG_CLASSIFIER_WEIGHTS", "")
    search_order = (
        env_override,
        _here("classifiers", "safety_classifier_1280.pth"),
        _here("safety_classifier_1280.pth"),
        "classifiers/safety_classifier_1280.pth",
        "safety_classifier_1280.pth",
    )
    # First existing, non-empty candidate wins.
    found = next((p for p in search_order if p and os.path.exists(p)), None)
    if found is None:
        raise FileNotFoundError(
            "Safety-classifier weights not found. Place 'safety_classifier_1280.pth' "
            "in repo root or 'classifiers/' (or set SDG_CLASSIFIER_WEIGHTS, or pass "
            "`classifier_weights=...` to the call())."
        )
    return found
|
|
|
|
|
|
|
class SDGOutput(BaseOutput):
    """Pipeline output container: `images` holds the decoded PIL images."""
    images: List
|
|
|
|
|
class SafeDiffusionGuidance(DiffusionPipeline):
    """
    Minimal custom pipeline that loads a base Stable Diffusion pipeline on demand
    and applies mid-UNet classifier-guided denoising for safety.
    """

    def __init__(self):
        super().__init__()
        # Lazily-created base StableDiffusionPipeline (see _ensure_base).
        self.base_pipe_ = None

    def _ensure_base(
        self,
        base_pipe: Optional[StableDiffusionPipeline],
        base_model_id: str,
        torch_dtype: torch.dtype,
    ) -> StableDiffusionPipeline:
        """
        Return a usable base pipeline: prefer the caller-supplied one,
        otherwise load `base_model_id` once and cache it on `self`.
        """
        if base_pipe is not None:
            self.base_pipe_ = base_pipe
            return self.base_pipe_
        if self.base_pipe_ is None:
            self.base_pipe_ = StableDiffusionPipeline.from_pretrained(
                base_model_id,
                torch_dtype=torch_dtype,
                safety_checker=None,  # replaced by mid-UNet classifier guidance
                requires_safety_checker=False,
            ).to(self.device)
        return self.base_pipe_

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        safety_scale: float = 5.0,
        mid_fraction: float = 1.0,
        safe_class_index: int = 3,
        classifier_weights: Optional[str] = None,
        base_pipe: Optional[StableDiffusionPipeline] = None,
        base_model_id: str = "runwayml/stable-diffusion-v1-5",
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ) -> SDGOutput:
        """
        Generate one image with classifier-free guidance plus a safety
        gradient step on the latents, driven by a classifier over the
        UNet mid-block features.

        Args:
            prompt: text prompt.
            negative_prompt: optional negative prompt (empty string if None).
            num_inference_steps: scheduler steps.
            guidance_scale: classifier-free guidance weight.
            safety_scale: weight of the safety loss; 0 disables guidance.
            mid_fraction: fraction of steps (from the start) that apply
                safety guidance.
            safe_class_index: index of the 'safe' class in the classifier
                logits (3 in CLASS_NAMES).
            classifier_weights: explicit checkpoint path; auto-discovered
                if None.
            base_pipe: optional pre-built StableDiffusionPipeline.
            base_model_id: hub id used when no base_pipe is supplied.
            generator: optional RNG for reproducible latents.
            **kwargs: accepts `height`/`width` (default 512).

        Returns:
            SDGOutput with a single PIL image.
        """
        base = self._ensure_base(base_pipe, base_model_id, torch_dtype=torch.float16)
        device = getattr(base, "_execution_device", base.device)
        dtype = base.unet.dtype

        # --- text embeddings in CFG layout: [uncond, cond] ----------------
        tok = base.tokenizer
        max_len = tok.model_max_length
        # BUG FIX: truncation=True — long prompts previously produced token
        # sequences longer than the text encoder accepts.
        txt = tok([prompt], padding="max_length", max_length=max_len,
                  truncation=True, return_tensors="pt")
        cond = base.text_encoder(txt.input_ids.to(device)).last_hidden_state
        uncond_prompt = negative_prompt if negative_prompt is not None else ""
        uncond_txt = tok([uncond_prompt], padding="max_length", max_length=max_len,
                         truncation=True, return_tensors="pt")
        uncond = base.text_encoder(uncond_txt.input_ids.to(device)).last_hidden_state
        cond_embeds = torch.cat([uncond, cond], dim=0)

        # --- initial latents ----------------------------------------------
        h = kwargs.pop("height", 512); w = kwargs.pop("width", 512)
        latents = torch.randn(
            # `unet.in_channels` is deprecated in recent diffusers; the
            # config attribute is the supported accessor.
            (1, base.unet.config.in_channels, h // 8, w // 8),
            device=device, generator=generator, dtype=dtype
        )

        base.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = base.scheduler.timesteps
        # BUG FIX: sigma-based schedulers (Euler, etc.) expect the initial
        # noise scaled by init_noise_sigma (it is 1.0 for DDIM/PNDM, so this
        # is a no-op there).
        latents = latents * base.scheduler.init_noise_sigma

        # --- safety classifier over mid-block features --------------------
        # BUG FIX: the original called the undefined `pick_weights_for_pipe`;
        # the discovery helper defined in this module is `pick_weights_path`.
        weights = classifier_weights or pick_weights_path()
        clf = load_classifier_1280(weights, device=device, dtype=torch.float32).eval()

        # Capture the UNet mid-block output on every forward pass.
        mid = {}
        def hook(_, __, out): mid["feat"] = out[0] if isinstance(out, tuple) else out
        handle = base.unet.mid_block.register_forward_hook(hook)

        base_alpha = 1e-3  # base step size for the safety gradient update

        try:
            for i, t in enumerate(timesteps):
                do_guide = (i / len(timesteps)) <= mid_fraction and safety_scale > 0

                if do_guide:
                    # Gradient step on the latents to increase P(safe).
                    with torch.enable_grad():
                        lg = latents.detach().clone().requires_grad_(True)
                        lin = base.scheduler.scale_model_input(lg, t)
                        lcat = torch.cat([lin, lin], dim=0)

                        _ = base.unet(lcat, t, encoder_hidden_states=cond_embeds).sample
                        # BUG FIX: do NOT .detach() here — detaching severed
                        # the autograd graph, so loss.backward() could never
                        # reach `lg` and the guidance step failed.
                        feat = mid["feat"].to(torch.float32)
                        logits = clf(feat)
                        probs = torch.softmax(logits, dim=-1)
                        unsafe = 1.0 - probs[:, safe_class_index].mean()

                        loss = safety_scale * unsafe
                        loss.backward()

                    alpha = base_alpha
                    if hasattr(base.scheduler, "sigmas"):
                        # Scale the update by the current noise level.
                        idx = min(i, len(base.scheduler.sigmas) - 1)
                        alpha = base_alpha * float(base.scheduler.sigmas[idx])

                    if lg.grad is not None:
                        latents = (lg - alpha * lg.grad).detach()

                # Standard classifier-free-guidance denoising step.
                # (The original also computed lat_in/lat_cat before the
                # guidance branch and immediately discarded them.)
                lat_in = base.scheduler.scale_model_input(latents, t)
                lat_cat = torch.cat([lat_in, lat_in], dim=0)

                noise_pred = base.unet(lat_cat, t, encoder_hidden_states=cond_embeds).sample
                n_uncond, n_text = noise_pred.chunk(2)
                noise = n_uncond + guidance_scale * (n_text - n_uncond)
                latents = base.scheduler.step(noise, t, latents).prev_sample
        finally:
            # Always release the forward hook, even if a step raises.
            handle.remove()

        img = base.decode_latents(latents)
        pil = base.image_processor.postprocess(img, output_type="pil")[0]
        return SDGOutput(images=[pil])
|
|
|
|
|
# Public API: only the pipeline class; helpers and the classifier are internal.
__all__ = ["SafeDiffusionGuidance"]
|
|
|