HARRY07979 commited on
Commit
6b3fea6
·
verified ·
1 Parent(s): 6d24a2c

Update safety_checker/StableDiffusionSafetyChecker.py

Browse files
safety_checker/StableDiffusionSafetyChecker.py CHANGED
@@ -7,8 +7,10 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
7
 
8
  def __init__(self, config: CLIPConfig):
9
  super().__init__(config)
10
- self.vision_model = CLIPVisionModel(config)
11
- self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
 
 
12
  self.register_buffer("concept_embeds", torch.ones(1, 17, config.projection_dim))
13
  self.register_buffer("special_care_embeds", torch.ones(1, 3, config.projection_dim))
14
  self.register_buffer("concept_embeds_weights", torch.ones(1, 17))
@@ -16,10 +18,11 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
16
 
17
  @torch.no_grad()
18
  def forward(self, clip_input, images):
19
- pooled_output = self.vision_model(clip_input)[1]
20
  image_embeds = self.visual_projection(pooled_output)
21
  image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
22
 
 
23
  special_cos_dist = torch.mm(image_embeds, self.special_care_embeds[0].t())
24
  cos_dist = torch.mm(image_embeds, self.concept_embeds[0].t())
25
 
@@ -28,6 +31,6 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
28
  concept_idx = (cos_dist[i] > self.concept_embeds_weights[0]).any().item()
29
  has_nsfw_concepts.append(concept_idx)
30
  if concept_idx:
31
- images[i] = torch.zeros_like(images[i])
32
 
33
  return images, has_nsfw_concepts
 
7
 
8
  def __init__(self, config: CLIPConfig):
9
  super().__init__(config)
10
+ # Truy cập hidden_size từ vision_config thay vì từ config gốc
11
+ self.vision_model = CLIPVisionModel(config.vision_config)
12
+ self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
13
+
14
  self.register_buffer("concept_embeds", torch.ones(1, 17, config.projection_dim))
15
  self.register_buffer("special_care_embeds", torch.ones(1, 3, config.projection_dim))
16
  self.register_buffer("concept_embeds_weights", torch.ones(1, 17))
 
18
 
19
  @torch.no_grad()
20
  def forward(self, clip_input, images):
21
+ pooled_output = self.vision_model(clip_input)[1] # Get pooled_output
22
  image_embeds = self.visual_projection(pooled_output)
23
  image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
24
 
25
+ # Logic lọc nội dung
26
  special_cos_dist = torch.mm(image_embeds, self.special_care_embeds[0].t())
27
  cos_dist = torch.mm(image_embeds, self.concept_embeds[0].t())
28
 
 
31
  concept_idx = (cos_dist[i] > self.concept_embeds_weights[0]).any().item()
32
  has_nsfw_concepts.append(concept_idx)
33
  if concept_idx:
34
+ images[i] = torch.zeros_like(images[i]) # Trả về ảnh đen nếu dính NSFW
35
 
36
  return images, has_nsfw_concepts