basimazam commited on
Commit
febc264
·
verified ·
1 Parent(s): 01836d7

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. README.md +54 -48
  2. safe_diffusion_guidance.py +45 -33
README.md CHANGED
@@ -1,48 +1,54 @@
1
- ---
2
- library_name: diffusers
3
- pipeline_tag: text-to-image
4
- tags:
5
- - safety
6
- - classifier-guidance
7
- - stable-diffusion
8
- - plug-and-play
9
- license: apache-2.0
10
- ---
11
-
12
- # Safe Diffusion Guidance (SDG) —
13
-
14
- **Safe Diffusion Guidance (SDG)** is a *classifier-guided denoising* layer that steers the sampling trajectory away from unsafe content **without retraining** the base model.
15
- It works **standalone** with SD 1.4 / 1.5 / 2.1.
16
-
17
-
18
- ## Quickstart (SD 1.5)
19
-
20
- ```python
21
- import torch
22
- from diffusers import StableDiffusionPipeline
23
-
24
- # 1) Load base SD pipeline (disable default safety checker)
25
- base = StableDiffusionPipeline.from_pretrained(
26
- "runwayml/stable-diffusion-v1-5",
27
- torch_dtype=torch.float16,
28
- safety_checker=None
29
- ).to("cuda")
30
-
31
- # 2) Load SDG custom pipeline from Hub (this repo)
32
- sdg = StableDiffusionPipeline.from_pretrained(
33
- "basimazam/safe-diffusion-guidance",
34
- custom_pipeline="safe_diffusion_guidance",
35
- torch_dtype=torch.float16
36
- ).to("cuda")
37
-
38
- img = sdg(
39
- base_pipe=base,
40
- prompt="portrait photograph, studio light, 85mm, realistic",
41
- num_inference_steps=50,
42
- guidance_scale=7.5,
43
- safety_scale=5.0,
44
- mid_fraction=1.0,
45
- safe_class_index=3
46
- ).images[0]
47
-
48
- img.save("sdg_safe_output.png")
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: diffusers
3
+ pipeline_tag: text-to-image
4
+ tags:
5
+ - safety
6
+ - classifier-guidance
7
+ - stable-diffusion
8
+ - plug-and-play
9
+ license: apache-2.0
10
+ ---
11
+
12
+ # Safe Diffusion Guidance (SDG) — plug-and-play safety layer for Stable Diffusion
13
+
14
+ **Safe Diffusion Guidance (SDG)** is a *classifier-guided denoising* layer that steers the sampling trajectory away from unsafe content **without retraining** the base model.
15
+ It works **standalone** with SD 1.4 / 1.5 / 2.1 and **composes** cleanly with ESD/UCE/SLD.
16
+
17
+ - **Safety signal:** a 5-class mid-UNet feature classifier (classes: `gore, hate, medical, safe, sexual`) trained on (1280×8×8) features.
18
+ - **Controls:** `safety_scale` (strength), `mid_fraction` (fraction of steps guided).
19
+ - **Plug-in:** drop into any SD pipeline, or stack on top of ESD/UCE/SLD.
20
+ - **No retraining:** small gradient nudges to latents during denoising.
21
+
22
+ > **Note on metrics** (matching our paper): FID/KID are computed vs. _baseline model outputs_ rather than real images; baseline FID/KID are ≈0 by construction.
23
+
24
+ ## Quickstart (SD 1.5)
25
+
26
+ ```python
27
+ import torch
28
+ from diffusers import StableDiffusionPipeline
29
+
30
+ # 1) Load base SD pipeline (disable default safety checker)
31
+ base = StableDiffusionPipeline.from_pretrained(
32
+ "runwayml/stable-diffusion-v1-5",
33
+ torch_dtype=torch.float16,
34
+ safety_checker=None
35
+ ).to("cuda")
36
+
37
+ # 2) Load SDG custom pipeline from Hub (this repo)
38
+ sdg = StableDiffusionPipeline.from_pretrained(
39
+ "your-org/safe-diffusion-guidance",
40
+ custom_pipeline="safe_diffusion_guidance",
41
+ torch_dtype=torch.float16
42
+ ).to("cuda")
43
+
44
+ img = sdg(
45
+ base_pipe=base,
46
+ prompt="portrait photograph, studio light, 85mm, realistic",
47
+ num_inference_steps=50,
48
+ guidance_scale=7.5,
49
+ safety_scale=5.0, # strength: ~2–8 (Light→Strong)
50
+ mid_fraction=1.0, # guide fraction of steps: 0.5, 0.8, 1.0
51
+ safe_class_index=3 # index of 'safe' in [gore,hate,medical,safe,sexual]
52
+ ).images[0]
53
+
54
+ img.save("sdg_safe_output.png")
safe_diffusion_guidance.py CHANGED
@@ -1,30 +1,25 @@
1
  # safe_diffusion_guidance.py
 
2
  import torch
3
  from typing import Optional, List
4
  from diffusers import DiffusionPipeline, StableDiffusionPipeline
5
  from diffusers.utils import BaseOutput
6
 
7
- # utils/adaptive_classifiers.py
8
- import torch
9
  import torch.nn as nn
10
- from typing import Optional
11
 
12
  CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
13
 
14
  class SafetyClassifier1280(nn.Module):
15
- """
16
- Unified safety classifier for mid-UNet features of shape (B, 1280, H, W).
17
- Robust to variable HxW via AdaptiveAvgPool2d((8,8)) before the head.
18
- """
19
  def __init__(self, num_classes: int = 5):
20
  super().__init__()
21
  self.pre = nn.AdaptiveAvgPool2d((8, 8))
22
  self.net = nn.Sequential(
23
  nn.Conv2d(1280, 512, 3, padding=1),
24
- nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(2), # 512 x 4 x 4
25
  nn.Conv2d(512, 256, 3, padding=1),
26
- nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2), # 256 x 2 x 2
27
- nn.AdaptiveAvgPool2d(1), nn.Flatten(), # 256
28
  nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.3),
29
  nn.Linear(128, num_classes)
30
  )
@@ -41,9 +36,32 @@ class SafetyClassifier1280(nn.Module):
41
  nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
42
 
43
  def forward(self, x: torch.Tensor) -> torch.Tensor:
44
- x = self.pre(x) # (B, 1280, 8, 8)
45
  return self.net(x)
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def load_classifier_1280(
48
  weights_path: str,
49
  device: Optional[torch.device] = None,
@@ -58,28 +76,23 @@ def load_classifier_1280(
58
  return model
59
 
60
  def pick_weights_for_pipe(pipe) -> str:
61
- """
62
- Optional helper: return a default weights file based on the base SD pipeline id.
63
- You can also use a single shared file 'classifiers/safety_classifier_1280.pth'.
64
- """
65
- name = str(getattr(pipe, "_internal_dict", {}).get("_name_or_path", "")).lower()
66
- # Adjust logic as you like — default to a single shared file:
67
- return "classifiers/safety_classifier_1280.pth"
68
-
69
 
70
  class SDGOutput(BaseOutput):
71
  images: List
72
 
73
  class SafeDiffusionGuidance(DiffusionPipeline):
74
- """
75
- Pure custom pipeline. No pre-saved components in the repo.
76
- It auto-loads a base SD pipeline if `base_pipe` is None.
77
- """
78
 
79
- def __init__(self, **kwargs):
80
- # Accept any extra kwargs Diffusers might pass; we ignore them.
81
  super().__init__()
82
- self.base_pipe_ = None # lazy cache
 
 
 
 
 
83
 
84
  def _ensure_base(self, base_pipe, base_model_id, torch_dtype):
85
  if base_pipe is not None:
@@ -87,9 +100,7 @@ class SafeDiffusionGuidance(DiffusionPipeline):
87
  return self.base_pipe_
88
  if self.base_pipe_ is None:
89
  self.base_pipe_ = StableDiffusionPipeline.from_pretrained(
90
- base_model_id,
91
- torch_dtype=torch_dtype,
92
- safety_checker=None
93
  ).to(self.device)
94
  return self.base_pipe_
95
 
@@ -108,7 +119,6 @@ class SafeDiffusionGuidance(DiffusionPipeline):
108
  generator: Optional[torch.Generator] = None,
109
  **kwargs
110
  ) -> SDGOutput:
111
-
112
  base = self._ensure_base(base_pipe, base_model_id, torch_dtype=torch.float16)
113
  device = getattr(base, "_execution_device", base.device)
114
  dtype = base.unet.dtype
@@ -133,9 +143,9 @@ class SafeDiffusionGuidance(DiffusionPipeline):
133
  base.scheduler.set_timesteps(num_inference_steps, device=device)
134
  timesteps = base.scheduler.timesteps
135
 
136
- # classifier (fp32)
137
- weights = classifier_weights or pick_weights_for_pipe(base)
138
- clf = load_classifier_1280(weights, device=device, dtype=torch.float32).eval()
139
 
140
  # mid-block hook
141
  mid = {}
@@ -186,3 +196,5 @@ class SafeDiffusionGuidance(DiffusionPipeline):
186
  img = base.decode_latents(latents)
187
  img = base.image_processor.postprocess(img, output_type="pil")[0]
188
  return SDGOutput(images=[img])
 
 
 
1
  # safe_diffusion_guidance.py
2
+ import os
3
  import torch
4
  from typing import Optional, List
5
  from diffusers import DiffusionPipeline, StableDiffusionPipeline
6
  from diffusers.utils import BaseOutput
7
 
8
+ # ---- Classifier (unchanged) ----
 
9
  import torch.nn as nn
 
10
 
11
  CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
12
 
13
  class SafetyClassifier1280(nn.Module):
 
 
 
 
14
  def __init__(self, num_classes: int = 5):
15
  super().__init__()
16
  self.pre = nn.AdaptiveAvgPool2d((8, 8))
17
  self.net = nn.Sequential(
18
  nn.Conv2d(1280, 512, 3, padding=1),
19
+ nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(2),
20
  nn.Conv2d(512, 256, 3, padding=1),
21
+ nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2),
22
+ nn.AdaptiveAvgPool2d(1), nn.Flatten(),
23
  nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.3),
24
  nn.Linear(128, num_classes)
25
  )
 
36
  nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
37
 
38
  def forward(self, x: torch.Tensor) -> torch.Tensor:
39
+ x = self.pre(x)
40
  return self.net(x)
41
 
42
+ # ---- NEW: robust path resolution for weights ----
43
+ def _resolve_repo_path(rel_path: str) -> str:
44
+ """Return an absolute path inside the cached repo; fallback to hf_hub_download."""
45
+ here = os.path.dirname(__file__)
46
+ local_path = os.path.join(here, rel_path)
47
+ if os.path.exists(local_path):
48
+ return local_path
49
+ # Fallback: try hub download (works even if code is executed outside repo root)
50
+ try:
51
+ from huggingface_hub import hf_hub_download
52
+ # Best effort to get repo id; default to your public repo if unknown
53
+ repo_id = getattr(_resolve_repo_path, "_repo_id", None)
54
+ if repo_id is None:
55
+ # Diffusers stores name or path in internal dict sometimes:
56
+ repo_id = getattr(SafeDiffusionGuidance, "__repo_id__", "basimazam/safe-diffusion-guidance")
57
+ return hf_hub_download(repo_id=repo_id, filename=rel_path)
58
+ except Exception as e:
59
+ raise FileNotFoundError(
60
+ f"Could not find classifier weights at '{rel_path}'. "
61
+ f"Make sure the file exists in the repo, or pass `classifier_weights=...`. "
62
+ f"Original error: {e}"
63
+ )
64
+
65
  def load_classifier_1280(
66
  weights_path: str,
67
  device: Optional[torch.device] = None,
 
76
  return model
77
 
78
  def pick_weights_for_pipe(pipe) -> str:
79
+ # Always use the standard path inside the repo
80
+ return _resolve_repo_path("classifiers/safety_classifier_1280.pth")
 
 
 
 
 
 
81
 
82
  class SDGOutput(BaseOutput):
83
  images: List
84
 
85
  class SafeDiffusionGuidance(DiffusionPipeline):
86
+ """Pure custom pipeline; loads base SD lazily at runtime."""
 
 
 
87
 
88
+ def __init__(self): # <-- IMPORTANT: no **kwargs
 
89
  super().__init__()
90
+ self.base_pipe_ = None
91
+ # Hint for the fallback downloader (optional)
92
+ try:
93
+ SafeDiffusionGuidance.__repo_id__ = self.config._name_or_path # diffusers sets this sometimes
94
+ except Exception:
95
+ pass
96
 
97
  def _ensure_base(self, base_pipe, base_model_id, torch_dtype):
98
  if base_pipe is not None:
 
100
  return self.base_pipe_
101
  if self.base_pipe_ is None:
102
  self.base_pipe_ = StableDiffusionPipeline.from_pretrained(
103
+ base_model_id, torch_dtype=torch_dtype, safety_checker=None
 
 
104
  ).to(self.device)
105
  return self.base_pipe_
106
 
 
119
  generator: Optional[torch.Generator] = None,
120
  **kwargs
121
  ) -> SDGOutput:
 
122
  base = self._ensure_base(base_pipe, base_model_id, torch_dtype=torch.float16)
123
  device = getattr(base, "_execution_device", base.device)
124
  dtype = base.unet.dtype
 
143
  base.scheduler.set_timesteps(num_inference_steps, device=device)
144
  timesteps = base.scheduler.timesteps
145
 
146
+ # classifier (fp32) — use provided path or default resolved path
147
+ weights_file = classifier_weights or pick_weights_for_pipe(base)
148
+ clf = load_classifier_1280(weights_file, device=device, dtype=torch.float32).eval()
149
 
150
  # mid-block hook
151
  mid = {}
 
196
  img = base.decode_latents(latents)
197
  img = base.image_processor.postprocess(img, output_type="pil")[0]
198
  return SDGOutput(images=[img])
199
+
200
+ __all__ = ["SafeDiffusionGuidance"]