basimazam commited on
Commit
b8877ca
·
verified ·
1 Parent(s): f96b6c3

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -1,35 +1,3 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
 
2
  *.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.pyc
4
+ .venv/
5
+ .env
6
+ .envrc
7
+
8
+ # OS
9
+ .DS_Store
10
+ Thumbs.db
11
+
12
+ # HF cache ("~" is not expanded by gitignore; match a repo-local cache dir)
13
+ .cache/huggingface/
14
+ *.lock
15
+
16
+ # Models & artifacts (tracked via Git LFS — see .gitattributes)
17
+ *.pt
18
+ *.pth
19
+ *.bin
20
+ *.safetensors
21
+ *.onnx
+ # ...but keep the classifier weights shipped with this repo committable
+ !classifiers/*.pth
22
+
23
+ # Misc
24
+ *.log
25
+ outputs/
README.md ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: diffusers
3
+ pipeline_tag: text-to-image
4
+ tags:
5
+ - safety
6
+ - classifier-guidance
7
+ - stable-diffusion
8
+ - plug-and-play
9
+ license: apache-2.0
10
+ ---
11
+
12
+ # Safe Diffusion Guidance (SDG) — plug-and-play safety layer for Stable Diffusion
13
+
14
+ **Safe Diffusion Guidance (SDG)** is a *classifier-guided denoising* layer that steers the sampling trajectory away from unsafe content **without retraining** the base model.
15
+ It works **standalone** with SD 1.4 / 1.5 / 2.1 and **composes** cleanly with ESD/UCE/SLD.
16
+
17
+ - **Safety signal:** a 5-class mid-UNet feature classifier (classes: `gore, hate, medical, safe, sexual`) trained on (1280×8×8) features.
18
+ - **Controls:** `safety_scale` (strength), `mid_fraction` (fraction of steps guided).
19
+ - **Plug-in:** drop into any SD pipeline, or stack on top of ESD/UCE/SLD.
20
+ - **No retraining:** small gradient nudges to latents during denoising.
21
+
22
+ > **Note on metrics** (matching our paper): FID/KID are computed vs. _baseline model outputs_ rather than real images; baseline FID/KID are ≈0 by construction.
23
+
24
+ ## Quickstart (SD 1.5)
25
+
26
+ ```python
27
+ import torch
28
+ from diffusers import StableDiffusionPipeline, DiffusionPipeline
29
+
30
+ # 1) Load base SD pipeline (disable default safety checker)
31
+ base = StableDiffusionPipeline.from_pretrained(
32
+     "runwayml/stable-diffusion-v1-5",
33
+     torch_dtype=torch.float16,
34
+     safety_checker=None
35
+ ).to("cuda")
36
+
37
+ # 2) Load SDG custom pipeline from Hub (this repo) — custom pipelines are
+ #    loaded via DiffusionPipeline.from_pretrained
38
+ sdg = DiffusionPipeline.from_pretrained(
39
+     "basimazam/safe-diffusion-guidance",
40
+ custom_pipeline="safe_diffusion_guidance",
41
+ torch_dtype=torch.float16
42
+ ).to("cuda")
43
+
44
+ img = sdg(
45
+ base_pipe=base,
46
+ prompt="portrait photograph, studio light, 85mm, realistic",
47
+ num_inference_steps=50,
48
+ guidance_scale=7.5,
49
+ safety_scale=5.0, # strength: ~2–8 (Light→Strong)
50
+ mid_fraction=1.0, # guide fraction of steps: 0.5, 0.8, 1.0
51
+ safe_class_index=3 # index of 'safe' in [gore,hate,medical,safe,sexual]
52
+ ).images[0]
53
+
54
+ img.save("sdg_safe_output.png")
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
# Makes the repo importable if someone clones it as a package.
# Only the custom-pipeline module name is exported; the pipeline class itself
# lives in safe_diffusion_guidance.py.
__all__ = ["safe_diffusion_guidance"]
model_index.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "library_name": "diffusers",
3
+ "pipeline": "StableDiffusionPipeline",
4
+ "tags": [
5
+ "safety",
6
+ "classifier-guidance",
7
+ "stable-diffusion",
8
+ "plug-and-play"
9
+ ],
10
+ "inference": "Use `custom_pipeline='safe_diffusion_guidance'` and pass your base SD pipeline via `base_pipe=...`."
11
+ }
safe_diffusion_guidance.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # safe_diffusion_guidance.py
2
+ import torch
3
+ from typing import Optional, List
4
+ from diffusers import StableDiffusionPipeline, DiffusionPipeline
5
+ from diffusers.utils import BaseOutput
6
+
7
+ from utils.adaptive_classifiers import (
8
+ load_classifier_1280, pick_weights_for_pipe
9
+ )
10
+
11
class SDGOutput(BaseOutput):
    """
    Output container returned by ``SafeDiffusionGuidancePipeline.__call__``.

    Declared dataclass-style, as diffusers' ``BaseOutput`` expects.
    """
    # Generated PIL images (length 1 for a single prompt in this pipeline).
    images: List
13
+
14
class SafeDiffusionGuidancePipeline(StableDiffusionPipeline):
    """
    Plug-and-play safety guidance for Stable Diffusion.

    - If `base_pipe` is None, we auto-load the base SD checkpoint specified by
      `base_model_id`.
    - Classifier weights are shipped in this repo; no extra installs required.

    Typical use:
        sdg = DiffusionPipeline.from_pretrained("basimazam/safe-diffusion-guidance",
                                                custom_pipeline="safe_diffusion_guidance",
                                                torch_dtype=torch.float16).to("cuda")
        out = sdg(prompt="...", base_model_id="runwayml/stable-diffusion-v1-5",
                  safety_scale=5.0, mid_fraction=1.0, safe_class_index=3)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Lazy-loaded base pipeline cache, populated by _ensure_base().
        self.base_pipe_ = None

    def _ensure_base(self, base_pipe, base_model_id, torch_dtype):
        """Return the base SD pipeline, loading ``base_model_id`` on first use.

        An explicitly supplied ``base_pipe`` always wins and replaces the cache.
        """
        if base_pipe is not None:
            self.base_pipe_ = base_pipe
            return self.base_pipe_
        if self.base_pipe_ is None:
            # Lazy-load the chosen SD checkpoint (default safety checker off:
            # SDG itself provides the safety signal).
            self.base_pipe_ = StableDiffusionPipeline.from_pretrained(
                base_model_id,
                torch_dtype=torch_dtype,
                safety_checker=None,
            ).to(self.device)
        return self.base_pipe_

    def __call__(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        safety_scale: float = 5.0,
        mid_fraction: float = 1.0,
        safe_class_index: int = 3,
        classifier_weights: Optional[str] = None,
        # new convenience arg (optional):
        base_pipe: Optional[StableDiffusionPipeline] = None,
        base_model_id: str = "runwayml/stable-diffusion-v1-5",
        generator: Optional[torch.Generator] = None,
        **kwargs
    ) -> SDGOutput:
        """Generate one image with classifier-guided safety steering.

        Parameters
        ----------
        prompt / negative_prompt : text conditioning for CFG.
        num_inference_steps, guidance_scale : standard SD sampling knobs.
        safety_scale : strength of the safety gradient nudge (0 disables).
        mid_fraction : fraction of denoising steps that receive guidance.
        safe_class_index : index of the 'safe' class in the classifier logits.
        classifier_weights : override path for the classifier checkpoint.
        base_pipe / base_model_id : existing pipeline, or checkpoint id to load.
        generator : optional RNG for reproducible latents.
        kwargs : only ``height`` and ``width`` (default 512) are consumed.

        Returns
        -------
        SDGOutput with a single PIL image.
        """
        # 0) choose / load base SD
        base = self._ensure_base(base_pipe, base_model_id, torch_dtype=self.unet.dtype)
        device = base._execution_device if hasattr(base, "_execution_device") else base.device
        dtype = base.unet.dtype

        # 1) Text embeddings (CFG): [uncond, cond] stacked along batch.
        tok = base.tokenizer
        max_len = tok.model_max_length
        text_inputs = tok([prompt], padding="max_length", max_length=max_len, return_tensors="pt")
        text_embeds = base.text_encoder(text_inputs.input_ids.to(device)).last_hidden_state

        if negative_prompt is not None:
            uncond_inputs = tok([negative_prompt], padding="max_length", max_length=max_len, return_tensors="pt")
        else:
            uncond_inputs = tok([""], padding="max_length", max_length=max_len, return_tensors="pt")
        uncond_embeds = base.text_encoder(uncond_inputs.input_ids.to(device)).last_hidden_state
        cond_embeds = torch.cat([uncond_embeds, text_embeds], dim=0)

        # 2) Latent init (SD VAE downsamples by 8).
        height = kwargs.pop("height", 512)
        width = kwargs.pop("width", 512)
        latents = torch.randn(
            # .config.in_channels: direct attribute access on the UNet is
            # deprecated in recent diffusers releases.
            (1, base.unet.config.in_channels, height // 8, width // 8),
            device=device, generator=generator, dtype=dtype
        )

        base.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = base.scheduler.timesteps

        # 3) Load classifier (fp32) from this repo. The loader already puts it
        # in eval mode; we additionally freeze its parameters because we only
        # ever need gradients w.r.t. the latents.
        weights_path = classifier_weights or pick_weights_for_pipe(base)
        classifier = load_classifier_1280(weights_path, device=device, dtype=torch.float32)
        for p in classifier.parameters():
            p.requires_grad_(False)

        # 4) Hook mid-block features (B, 1280, h, w) out of the UNet forward.
        mid_cache = {}

        def mid_hook(module, inputs, output):
            mid_cache["feat"] = output[0] if isinstance(output, tuple) else output

        hook_handle = base.unet.mid_block.register_forward_hook(mid_hook)

        base_alpha = 1e-3  # small step for the latent nudge

        try:
            with torch.no_grad():
                for i, t in enumerate(timesteps):
                    latent_in = base.scheduler.scale_model_input(latents, t)
                    latent_pair = torch.cat([latent_in, latent_in], dim=0)

                    do_guide = (i / len(timesteps)) <= mid_fraction and safety_scale > 0
                    if do_guide:
                        with torch.enable_grad():
                            latents_g = latents.detach().clone().requires_grad_(True)
                            latent_in_g = base.scheduler.scale_model_input(latents_g, t)
                            latent_pair_g = torch.cat([latent_in_g, latent_in_g], dim=0)

                            _ = base.unet(latent_pair_g, t, encoder_hidden_states=cond_embeds).sample
                            # BUGFIX: the cached feature must stay attached to
                            # the graph — a .detach() here severs the path back
                            # to latents_g, so no gradient could be computed.
                            feat = mid_cache["feat"].float()
                            logits = classifier(feat)
                            probs = torch.softmax(logits, dim=-1)
                            unsafe_prob = 1.0 - probs[:, safe_class_index].mean()

                            loss = safety_scale * unsafe_prob
                            # autograd.grad targets latents_g only — unlike
                            # loss.backward(), it does not accumulate grads
                            # into the UNet weights.
                            grad = torch.autograd.grad(loss, latents_g)[0]

                        # Scale the step by the current noise level when the
                        # scheduler exposes sigmas (e.g. Euler-style samplers).
                        alpha = base_alpha
                        if hasattr(base.scheduler, "sigmas"):
                            step_idx = min(i, len(base.scheduler.sigmas) - 1)
                            alpha = base_alpha * float(base.scheduler.sigmas[step_idx])
                        latents = (latents_g - alpha * grad).detach()

                        # Re-run the UNet on the nudged latents to get a clean
                        # noise prediction for the scheduler step.
                        latent_in = base.scheduler.scale_model_input(latents, t)
                        latent_pair = torch.cat([latent_in, latent_in], dim=0)
                        noise_pred = base.unet(latent_pair, t, encoder_hidden_states=cond_embeds).sample
                    else:
                        noise_pred = base.unet(latent_pair, t, encoder_hidden_states=cond_embeds).sample

                    # Classifier-free guidance combine + scheduler step.
                    noise_uncond, noise_text = noise_pred.chunk(2)
                    noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)
                    latents = base.scheduler.step(noise, t, latents).prev_sample
        finally:
            # Always detach the forward hook, even if sampling raises.
            hook_handle.remove()

        with torch.no_grad():
            image = base.decode_latents(latents)
            image = base.image_processor.postprocess(image, output_type="pil")[0]
        return SDGOutput(images=[image])
scripts/push_to_hub.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import create_repo, upload_folder
2
+ import os
3
+
4
+ ORG_REPO = "basimazam/safe-diffusion-guidance"
5
+
6
def main():
    """Create (or reuse) the Hub repo and upload the current folder to it."""
    create_repo(ORG_REPO, repo_type="model", exist_ok=True)
    upload_folder(
        repo_id=ORG_REPO,
        folder_path=".",  # current folder
        repo_type="model",
        # NOTE(review): "tests/*__pycache__*" only matches bytecode caches
        # under tests/; kept for compatibility, with explicit global excludes
        # added so __pycache__/*.pyc never get uploaded from anywhere.
        ignore_patterns=[
            ".git/*",
            "outputs/*",
            "tests/*__pycache__*",
            "**/__pycache__/**",
            "*.pyc",
        ],
    )
    print(f"Uploaded to https://huggingface.co/{ORG_REPO}")
15
+
16
if __name__ == "__main__":
    # Fail loudly if launched from the wrong directory. A plain `assert` is
    # stripped under `python -O`, so raise explicitly instead.
    if not os.path.exists("safe_diffusion_guidance.py"):
        raise FileNotFoundError(
            "safe_diffusion_guidance.py not found — run this script from the repo root"
        )
    main()
tests/test_classifier.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from utils.adaptive_classifiers import SafetyClassifier1280
3
+
4
def test_forward_shape():
    """The classifier maps (B, 1280, H, W) mid-features to (B, 5) logits."""
    clf = SafetyClassifier1280()
    clf.eval()
    features = torch.randn(2, 1280, 8, 8)  # fake mid-block activations
    with torch.no_grad():
        logits = clf(features)
    assert logits.shape == (2, 5), f"Expected (2,5), got {tuple(logits.shape)}"
10
+
11
# Allow running this file directly (without pytest).
if __name__ == "__main__":
    test_forward_shape()
    print("OK: classifier forward shape is (B,5)")
tests/test_pipeline_integration.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionPipeline, DiffusionPipeline
3
+
4
def test_sdg_minimal():
    """End-to-end smoke test: two denoising steps through the SDG pipeline.

    Downloads the custom pipeline and the base SD checkpoint, so this is an
    integration test, not a unit test.
    """
    sdg = DiffusionPipeline.from_pretrained(
        # Fixed: was the placeholder "your-org/..."; must match ORG_REPO in
        # scripts/push_to_hub.py.
        "basimazam/safe-diffusion-guidance",
        custom_pipeline="safe_diffusion_guidance",
        torch_dtype=torch.float16
    )
    # NOTE(review): fp16 on CPU may fail for some ops — confirm CI has a GPU.
    sdg = sdg.to("cuda" if torch.cuda.is_available() else "cpu")
    out = sdg(
        prompt="test scene",
        base_model_id="runwayml/stable-diffusion-v1-5",
        num_inference_steps=2,  # small for CI
        guidance_scale=5.0,
        safety_scale=2.0,
        mid_fraction=0.5,
        safe_class_index=3
    )
    assert len(out.images) == 1
    print("OK: pipeline end-to-end")
22
+
23
# Allow running this file directly (without pytest). Downloads models, so
# treat it as an integration check rather than a unit test.
if __name__ == "__main__":
    test_sdg_minimal()
utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
# Namespace init for utils: re-export the public classifier API.
# pick_weights_for_pipe is included because safe_diffusion_guidance.py
# imports it from this module's namespace as well.
from .adaptive_classifiers import (
    SafetyClassifier1280,
    load_classifier_1280,
    pick_weights_for_pipe,
    CLASS_NAMES,
)

__all__ = [
    "SafetyClassifier1280",
    "load_classifier_1280",
    "pick_weights_for_pipe",
    "CLASS_NAMES",
]
utils/adaptive_classifiers.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/adaptive_classifiers.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Optional
5
+
6
+ CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
7
+
8
class SafetyClassifier1280(nn.Module):
    """
    Unified safety classifier over mid-UNet features of shape (B, 1280, H, W).

    An AdaptiveAvgPool2d((8, 8)) front-end makes the conv head robust to any
    incoming spatial size; the head emits `num_classes` logits.
    """

    def __init__(self, num_classes: int = 5):
        super().__init__()
        # Normalize any incoming spatial size to 8x8 before the conv head.
        self.pre = nn.AdaptiveAvgPool2d((8, 8))
        head = [
            nn.Conv2d(1280, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),           # -> 512 x 4 x 4
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),           # -> 256 x 2 x 2
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),              # -> 256
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        ]
        self.net = nn.Sequential(*head)
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        # Xavier for linear layers, He (fan-out) for convs, unit/zero for BN.
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return (B, num_classes) logits for mid-block features *x*."""
        pooled = self.pre(x)  # (B, 1280, 8, 8)
        return self.net(pooled)
40
+
41
+ def load_classifier_1280(
42
+ weights_path: str,
43
+ device: Optional[torch.device] = None,
44
+ dtype: torch.dtype = torch.float32
45
+ ) -> SafetyClassifier1280:
46
+ model = SafetyClassifier1280().to(device or "cpu", dtype=dtype)
47
+ state = torch.load(weights_path, map_location="cpu")
48
+ if isinstance(state, dict) and "model_state_dict" in state:
49
+ state = state["model_state_dict"]
50
+ model.load_state_dict(state, strict=True)
51
+ model.eval()
52
+ return model
53
+
54
+ def pick_weights_for_pipe(pipe) -> str:
55
+ """
56
+ Optional helper: return a default weights file based on the base SD pipeline id.
57
+ You can also use a single shared file 'classifiers/safety_classifier_1280.pth'.
58
+ """
59
+ name = str(getattr(pipe, "_internal_dict", {}).get("_name_or_path", "")).lower()
60
+ # Adjust logic as you like — default to a single shared file:
61
+ return "classifiers/safety_classifier_1280.pth"
utils/compose.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/compose.py
2
+ import torch
3
+ from diffusers import StableDiffusionPipeline
4
+ from safetensors.torch import load_file
5
+
6
def load_and_patch_sd_pipeline(repo_id, unet_weights_path, dtype=torch.float16, device="cuda"):
    """
    Load a base SD pipeline and patch its UNet with ESD/UCE weights.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        repo_id, torch_dtype=dtype, safety_checker=None
    ).to(device)

    # Read the patch weights (safetensors or a pickled torch state dict).
    if unet_weights_path.endswith(".safetensors"):
        overrides = load_file(unet_weights_path)
    else:
        overrides = torch.load(unet_weights_path, map_location="cpu")

    # Overlay the patch onto the full UNet state dict and reload strictly so
    # any unexpected/missing key is reported immediately.
    merged = {**pipe.unet.state_dict(), **overrides}
    pipe.unet.load_state_dict(merged, strict=True)
    return pipe