basimazam commited on
Commit
bad5b88
·
verified ·
1 Parent(s): 31eb182

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. safe_diffusion_guidance.py +60 -8
safe_diffusion_guidance.py CHANGED
@@ -3,17 +3,69 @@ import torch
3
  from typing import Optional, List
4
  from diffusers import DiffusionPipeline, StableDiffusionPipeline
5
  from diffusers.utils import BaseOutput
6
- # at the top of safe_diffusion_guidance.py
7
 
8
- import os, sys
9
- FILE_DIR = os.path.dirname(__file__)
10
- UTILS_DIR = os.path.join(FILE_DIR, "utils")
11
- if UTILS_DIR not in sys.path:
12
- sys.path.insert(0, UTILS_DIR)
 
13
 
14
- # from adaptive_classifiers import load_classifier_1280, pick_weights_for_pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- from utils.adaptive_classifiers import load_classifier_1280, pick_weights_for_pipe
17
 
18
  class SDGOutput(BaseOutput):
19
  images: List
 
3
  from typing import Optional, List
4
  from diffusers import DiffusionPipeline, StableDiffusionPipeline
5
  from diffusers.utils import BaseOutput
 
6
 
7
+ # utils/adaptive_classifiers.py
8
+ import torch
9
+ import torch.nn as nn
10
+ from typing import Optional
11
+
12
+ CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
13
 
14
+ class SafetyClassifier1280(nn.Module):
15
+ """
16
+ Unified safety classifier for mid-UNet features of shape (B, 1280, H, W).
17
+ Robust to variable HxW via AdaptiveAvgPool2d((8,8)) before the head.
18
+ """
19
+ def __init__(self, num_classes: int = 5):
20
+ super().__init__()
21
+ self.pre = nn.AdaptiveAvgPool2d((8, 8))
22
+ self.net = nn.Sequential(
23
+ nn.Conv2d(1280, 512, 3, padding=1),
24
+ nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(2), # 512 x 4 x 4
25
+ nn.Conv2d(512, 256, 3, padding=1),
26
+ nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2), # 256 x 2 x 2
27
+ nn.AdaptiveAvgPool2d(1), nn.Flatten(), # 256
28
+ nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.3),
29
+ nn.Linear(128, num_classes)
30
+ )
31
+ self.apply(self._init_weights)
32
+
33
+ @staticmethod
34
+ def _init_weights(m):
35
+ if isinstance(m, nn.Linear):
36
+ nn.init.xavier_uniform_(m.weight); nn.init.zeros_(m.bias)
37
+ elif isinstance(m, nn.Conv2d):
38
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
39
+ if m.bias is not None: nn.init.zeros_(m.bias)
40
+ elif isinstance(m, nn.BatchNorm2d):
41
+ nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
42
+
43
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
44
+ x = self.pre(x) # (B, 1280, 8, 8)
45
+ return self.net(x)
46
+
47
+ def load_classifier_1280(
48
+ weights_path: str,
49
+ device: Optional[torch.device] = None,
50
+ dtype: torch.dtype = torch.float32
51
+ ) -> SafetyClassifier1280:
52
+ model = SafetyClassifier1280().to(device or "cpu", dtype=dtype)
53
+ state = torch.load(weights_path, map_location="cpu")
54
+ if isinstance(state, dict) and "model_state_dict" in state:
55
+ state = state["model_state_dict"]
56
+ model.load_state_dict(state, strict=True)
57
+ model.eval()
58
+ return model
59
+
60
+ def pick_weights_for_pipe(pipe) -> str:
61
+ """
62
+ Optional helper: return a default weights file based on the base SD pipeline id.
63
+ You can also use a single shared file 'classifiers/safety_classifier_1280.pth'.
64
+ """
65
+ name = str(getattr(pipe, "_internal_dict", {}).get("_name_or_path", "")).lower()
66
+ # Adjust logic as you like — default to a single shared file:
67
+ return "classifiers/safety_classifier_1280.pth"
68
 
 
69
 
70
  class SDGOutput(BaseOutput):
71
  images: List