basimazam committed on
Commit
855692d
·
verified ·
1 Parent(s): b8877ca

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. model_index.json +3 -9
  2. safe_diffusion_guidance.py +54 -70
model_index.json CHANGED
@@ -1,11 +1,5 @@
1
  {
2
- "library_name": "diffusers",
3
- "pipeline": "StableDiffusionPipeline",
4
- "tags": [
5
- "safety",
6
- "classifier-guidance",
7
- "stable-diffusion",
8
- "plug-and-play"
9
- ],
10
- "inference": "Use `custom_pipeline='safe_diffusion_guidance'` and pass your base SD pipeline via `base_pipe=...`."
11
  }
 
1
  {
2
+ "_class_name": "SafeDiffusionGuidance",
3
+ "_diffusers_version": "0.29.0",
4
+ "custom_pipeline": "safe_diffusion_guidance"
 
 
 
 
 
 
5
  }
safe_diffusion_guidance.py CHANGED
@@ -1,40 +1,30 @@
1
  # safe_diffusion_guidance.py
2
  import torch
3
  from typing import Optional, List
4
- from diffusers import StableDiffusionPipeline, DiffusionPipeline
5
  from diffusers.utils import BaseOutput
6
 
7
- from utils.adaptive_classifiers import (
8
- load_classifier_1280, pick_weights_for_pipe
9
- )
10
 
11
class SDGOutput(BaseOutput):
    """Pipeline output container.

    Attributes:
        images: list of generated images (PIL, per the decode path below).
    """
    images: List
13
 
14
- class SafeDiffusionGuidancePipeline(StableDiffusionPipeline):
15
  """
16
- Plug-and-play safety guidance for Stable Diffusion.
17
- - If `base_pipe` is None, we auto-load the base SD checkpoint specified by `base_model_id`.
18
- - Classifier weights are shipped in this repo; no extra installs required.
19
-
20
- Typical use:
21
- sdg = DiffusionPipeline.from_pretrained("your-org/safe-diffusion-guidance",
22
- custom_pipeline="safe_diffusion_guidance",
23
- torch_dtype=torch.float16).to("cuda")
24
- out = sdg(prompt="...", base_model_id="runwayml/stable-diffusion-v1-5",
25
- safety_scale=5.0, mid_fraction=1.0, safe_class_index=3)
26
  """
27
 
28
def __init__(self, *args, **kwargs):
    """Construct the pipeline, forwarding everything to StableDiffusionPipeline.

    Also prepares the cache slot used by `_ensure_base` for the lazily
    loaded base pipeline.
    """
    super().__init__(*args, **kwargs)
    # Populated on first use by _ensure_base; None means "not loaded yet".
    self.base_pipe_ = None
 
31
 
32
  def _ensure_base(self, base_pipe, base_model_id, torch_dtype):
33
  if base_pipe is not None:
34
  self.base_pipe_ = base_pipe
35
  return self.base_pipe_
36
  if self.base_pipe_ is None:
37
- # lazy load chosen SD checkpoint
38
  self.base_pipe_ = StableDiffusionPipeline.from_pretrained(
39
  base_model_id,
40
  torch_dtype=torch_dtype,
@@ -52,92 +42,86 @@ class SafeDiffusionGuidancePipeline(StableDiffusionPipeline):
52
  mid_fraction: float = 1.0,
53
  safe_class_index: int = 3,
54
  classifier_weights: Optional[str] = None,
55
- # new convenience arg (optional):
56
  base_pipe: Optional[StableDiffusionPipeline] = None,
57
  base_model_id: str = "runwayml/stable-diffusion-v1-5",
58
  generator: Optional[torch.Generator] = None,
59
  **kwargs
60
  ) -> SDGOutput:
61
 
62
- # 0) choose / load base SD
63
- base = self._ensure_base(base_pipe, base_model_id, torch_dtype=self.unet.dtype)
64
- device = base._execution_device if hasattr(base, "_execution_device") else base.device
65
- dtype = base.unet.dtype
66
 
67
- # 1) Text embeddings (CFG)
68
  tok = base.tokenizer
69
  max_len = tok.model_max_length
70
- text_inputs = tok([prompt], padding="max_length", max_length=max_len, return_tensors="pt")
71
- text_embeds = base.text_encoder(text_inputs.input_ids.to(device)).last_hidden_state
72
-
73
  if negative_prompt is not None:
74
- uncond_inputs = tok([negative_prompt], padding="max_length", max_length=max_len, return_tensors="pt")
75
  else:
76
- uncond_inputs = tok([""], padding="max_length", max_length=max_len, return_tensors="pt")
77
- uncond_embeds = base.text_encoder(uncond_inputs.input_ids.to(device)).last_hidden_state
78
- cond_embeds = torch.cat([uncond_embeds, text_embeds], dim=0)
79
 
80
- # 2) Latent init
81
- height = kwargs.pop("height", 512); width = kwargs.pop("width", 512)
82
- latents = torch.randn(
83
- (1, base.unet.in_channels, height // 8, width // 8),
84
- device=device, generator=generator, dtype=dtype
85
- )
86
 
87
  base.scheduler.set_timesteps(num_inference_steps, device=device)
88
  timesteps = base.scheduler.timesteps
89
 
90
- # 3) Load classifier (fp32) from this repo
91
- weights_path = classifier_weights or pick_weights_for_pipe(base)
92
- classifier = load_classifier_1280(weights_path, device=device, dtype=torch.float32).eval()
93
 
94
- # 4) Hook mid-block features
95
- mid_cache = {}
96
- def mid_hook(module, inputs, output):
97
- mid_cache["feat"] = output[0] if isinstance(output, tuple) else output
98
- h = base.unet.mid_block.register_forward_hook(mid_hook)
99
 
100
- base_alpha = 1e-3 # small step
101
 
102
  with torch.no_grad():
103
  for i, t in enumerate(timesteps):
104
- latent_in = base.scheduler.scale_model_input(latents, t)
105
- latent_pair = torch.cat([latent_in, latent_in], dim=0)
106
 
107
  do_guide = (i / len(timesteps)) <= mid_fraction and safety_scale > 0
108
  if do_guide:
109
  with torch.enable_grad():
110
- latents_g = latents.detach().clone().requires_grad_(True)
111
- latent_in_g = base.scheduler.scale_model_input(latents_g, t)
112
- latent_pair_g = torch.cat([latent_in_g, latent_in_g], dim=0)
113
 
114
- _ = base.unet(latent_pair_g, t, encoder_hidden_states=cond_embeds).sample
115
- feat = mid_cache["feat"].detach().float()
116
- logits = classifier(feat)
117
  probs = torch.softmax(logits, dim=-1)
118
- unsafe_prob = 1.0 - probs[:, safe_class_index].mean()
119
 
120
- loss = safety_scale * unsafe_prob
121
  loss.backward()
122
 
123
  alpha = base_alpha
124
  if hasattr(base.scheduler, "sigmas"):
125
- step_idx = min(i, len(base.scheduler.sigmas) - 1)
126
- alpha = base_alpha * float(base.scheduler.sigmas[step_idx])
127
- latents = (latents_g - alpha * latents_g.grad).detach()
128
 
129
- latent_in = base.scheduler.scale_model_input(latents, t)
130
- latent_pair = torch.cat([latent_in, latent_in], dim=0)
131
- noise_pred = base.unet(latent_pair, t, encoder_hidden_states=cond_embeds).sample
132
  else:
133
- noise_pred = base.unet(latent_pair, t, encoder_hidden_states=cond_embeds).sample
134
 
135
- noise_uncond, noise_text = noise_pred.chunk(2)
136
- noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)
137
  latents = base.scheduler.step(noise, t, latents).prev_sample
138
 
139
- h.remove()
140
  with torch.no_grad():
141
- image = base.decode_latents(latents)
142
- image = base.image_processor.postprocess(image, output_type="pil")[0]
143
- return SDGOutput(images=[image])
 
1
  # safe_diffusion_guidance.py
2
  import torch
3
  from typing import Optional, List
4
+ from diffusers import DiffusionPipeline, StableDiffusionPipeline
5
  from diffusers.utils import BaseOutput
6
 
7
+ from utils.adaptive_classifiers import load_classifier_1280, pick_weights_for_pipe
 
 
8
 
9
class SDGOutput(BaseOutput):
    """Output of SafeDiffusionGuidance.

    Attributes:
        images: list of generated images (PIL, per the decode path below).
    """
    images: List
11
 
12
+ class SafeDiffusionGuidance(DiffusionPipeline):
13
  """
14
+ Pure custom pipeline. No pre-saved components in the repo.
15
+ It auto-loads a base SD pipeline if `base_pipe` is None.
 
 
 
 
 
 
 
 
16
  """
17
 
18
def __init__(self, **kwargs):
    """Initialize the custom pipeline.

    Extra keyword arguments that Diffusers may forward during
    `from_pretrained` are accepted and deliberately ignored — this
    pipeline ships no pre-saved components of its own.
    """
    super().__init__()
    # Slot for the lazily loaded base SD pipeline (filled by _ensure_base).
    self.base_pipe_ = None
22
 
23
  def _ensure_base(self, base_pipe, base_model_id, torch_dtype):
24
  if base_pipe is not None:
25
  self.base_pipe_ = base_pipe
26
  return self.base_pipe_
27
  if self.base_pipe_ is None:
 
28
  self.base_pipe_ = StableDiffusionPipeline.from_pretrained(
29
  base_model_id,
30
  torch_dtype=torch_dtype,
 
42
  mid_fraction: float = 1.0,
43
  safe_class_index: int = 3,
44
  classifier_weights: Optional[str] = None,
 
45
  base_pipe: Optional[StableDiffusionPipeline] = None,
46
  base_model_id: str = "runwayml/stable-diffusion-v1-5",
47
  generator: Optional[torch.Generator] = None,
48
  **kwargs
49
  ) -> SDGOutput:
50
 
51
+ base = self._ensure_base(base_pipe, base_model_id, torch_dtype=torch.float16)
52
+ device = getattr(base, "_execution_device", base.device)
53
+ dtype = base.unet.dtype
 
54
 
55
+ # text embeds (CFG)
56
  tok = base.tokenizer
57
  max_len = tok.model_max_length
58
+ txt = tok([prompt], padding="max_length", max_length=max_len, return_tensors="pt")
59
+ cond = base.text_encoder(txt.input_ids.to(device)).last_hidden_state
 
60
  if negative_prompt is not None:
61
+ uncond_txt = tok([negative_prompt], padding="max_length", max_length=max_len, return_tensors="pt")
62
  else:
63
+ uncond_txt = tok([""], padding="max_length", max_length=max_len, return_tensors="pt")
64
+ uncond = base.text_encoder(uncond_txt.input_ids.to(device)).last_hidden_state
65
+ cond_embeds = torch.cat([uncond, cond], dim=0)
66
 
67
+ # latents
68
+ h = kwargs.pop("height", 512); w = kwargs.pop("width", 512)
69
+ latents = torch.randn((1, base.unet.in_channels, h // 8, w // 8),
70
+ device=device, generator=generator, dtype=dtype)
 
 
71
 
72
  base.scheduler.set_timesteps(num_inference_steps, device=device)
73
  timesteps = base.scheduler.timesteps
74
 
75
+ # classifier (fp32)
76
+ weights = classifier_weights or pick_weights_for_pipe(base)
77
+ clf = load_classifier_1280(weights, device=device, dtype=torch.float32).eval()
78
 
79
+ # mid-block hook
80
+ mid = {}
81
+ def hook(_, __, out): mid["feat"] = out[0] if isinstance(out, tuple) else out
82
+ handle = base.unet.mid_block.register_forward_hook(hook)
 
83
 
84
+ base_alpha = 1e-3
85
 
86
  with torch.no_grad():
87
  for i, t in enumerate(timesteps):
88
+ lat_in = base.scheduler.scale_model_input(latents, t)
89
+ lat_cat = torch.cat([lat_in, lat_in], dim=0)
90
 
91
  do_guide = (i / len(timesteps)) <= mid_fraction and safety_scale > 0
92
  if do_guide:
93
  with torch.enable_grad():
94
+ lg = latents.detach().clone().requires_grad_(True)
95
+ lin = base.scheduler.scale_model_input(lg, t)
96
+ lcat = torch.cat([lin, lin], dim=0)
97
 
98
+ _ = base.unet(lcat, t, encoder_hidden_states=cond_embeds).sample
99
+ feat = mid["feat"].detach().float()
100
+ logits = clf(feat)
101
  probs = torch.softmax(logits, dim=-1)
102
+ unsafe = 1.0 - probs[:, safe_class_index].mean()
103
 
104
+ loss = safety_scale * unsafe
105
  loss.backward()
106
 
107
  alpha = base_alpha
108
  if hasattr(base.scheduler, "sigmas"):
109
+ idx = min(i, len(base.scheduler.sigmas) - 1)
110
+ alpha = base_alpha * float(base.scheduler.sigmas[idx])
111
+ latents = (lg - alpha * lg.grad).detach()
112
 
113
+ lat_in = base.scheduler.scale_model_input(latents, t)
114
+ lat_cat = torch.cat([lat_in, lat_in], dim=0)
115
+ noise_pred = base.unet(lat_cat, t, encoder_hidden_states=cond_embeds).sample
116
  else:
117
+ noise_pred = base.unet(lat_cat, t, encoder_hidden_states=cond_embeds).sample
118
 
119
+ n_uncond, n_text = noise_pred.chunk(2)
120
+ noise = n_uncond + guidance_scale * (n_text - n_uncond)
121
  latents = base.scheduler.step(noise, t, latents).prev_sample
122
 
123
+ handle.remove()
124
  with torch.no_grad():
125
+ img = base.decode_latents(latents)
126
+ img = base.image_processor.postprocess(img, output_type="pil")[0]
127
+ return SDGOutput(images=[img])