HARRY07979 commited on
Commit
be07eac
·
verified ·
1 Parent(s): 6b3fea6

Update safety_checker/StableDiffusionSafetyChecker.py

Browse files
safety_checker/StableDiffusionSafetyChecker.py CHANGED
@@ -7,30 +7,31 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
7
 
8
  def __init__(self, config: CLIPConfig):
9
  super().__init__(config)
10
- # Truy cập hidden_size từ vision_config thay vì từ config gốc
11
  self.vision_model = CLIPVisionModel(config.vision_config)
12
  self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
13
 
14
- self.register_buffer("concept_embeds", torch.ones(1, 17, config.projection_dim))
15
- self.register_buffer("special_care_embeds", torch.ones(1, 3, config.projection_dim))
16
- self.register_buffer("concept_embeds_weights", torch.ones(1, 17))
17
- self.register_buffer("special_care_embeds_weights", torch.ones(1, 3))
 
 
18
 
19
  @torch.no_grad()
20
  def forward(self, clip_input, images):
21
- pooled_output = self.vision_model(clip_input)[1] # Get pooled_output
22
  image_embeds = self.visual_projection(pooled_output)
23
  image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
24
 
25
- # Logic lọc nội dung
26
- special_cos_dist = torch.mm(image_embeds, self.special_care_embeds[0].t())
27
- cos_dist = torch.mm(image_embeds, self.concept_embeds[0].t())
28
 
29
  has_nsfw_concepts = []
30
  for i in range(image_embeds.shape[0]):
31
- concept_idx = (cos_dist[i] > self.concept_embeds_weights[0]).any().item()
32
  has_nsfw_concepts.append(concept_idx)
33
  if concept_idx:
34
- images[i] = torch.zeros_like(images[i]) # Trả về ảnh đen nếu dính NSFW
35
 
36
  return images, has_nsfw_concepts
 
7
 
8
  def __init__(self, config: CLIPConfig):
9
  super().__init__(config)
 
10
  self.vision_model = CLIPVisionModel(config.vision_config)
11
  self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
12
 
13
+ # SỬA TẠI ĐÂY: Bỏ số 1 ở đầu để khớp với shape [17, 768] và [3, 768]
14
+ self.register_buffer("concept_embeds", torch.ones(17, config.projection_dim))
15
+ self.register_buffer("special_care_embeds", torch.ones(3, config.projection_dim))
16
+
17
+ self.register_buffer("concept_embeds_weights", torch.ones(17))
18
+ self.register_buffer("special_care_embeds_weights", torch.ones(3))
19
 
20
  @torch.no_grad()
21
  def forward(self, clip_input, images):
22
+ pooled_output = self.vision_model(clip_input)[1]
23
  image_embeds = self.visual_projection(pooled_output)
24
  image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
25
 
26
+ # Sửa logic nhân ma trận để khớp với shape mới
27
+ special_cos_dist = torch.mm(image_embeds, self.special_care_embeds.t())
28
+ cos_dist = torch.mm(image_embeds, self.concept_embeds.t())
29
 
30
  has_nsfw_concepts = []
31
  for i in range(image_embeds.shape[0]):
32
+ concept_idx = (cos_dist[i] > self.concept_embeds_weights).any().item()
33
  has_nsfw_concepts.append(concept_idx)
34
  if concept_idx:
35
+ images[i] = torch.zeros_like(images[i])
36
 
37
  return images, has_nsfw_concepts