File size: 1,713 Bytes
1b13b35
 
 
 
 
 
 
 
 
6b3fea6
 
 
be07eac
 
 
 
 
 
1b13b35
 
 
be07eac
1b13b35
 
 
be07eac
 
 
1b13b35
 
 
be07eac
1b13b35
 
be07eac
1b13b35
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel

class StableDiffusionSafetyChecker(PreTrainedModel):
    """CLIP-based safety checker that blanks images matching flagged concepts.

    A CLIP vision tower plus a linear projection embed each image. The
    L2-normalised embedding is dot-multiplied against two concept banks:
    ``special_care_embeds`` and ``concept_embeds``.

    An image is flagged when any concept similarity exceeds that concept's
    threshold in ``concept_embeds_weights``. If the image also exceeds any
    threshold in ``special_care_embeds_weights``, every concept threshold is
    lowered by ``_SPECIAL_CARE_ADJUSTMENT`` for that image. This makes the
    check stricter for it.
    """

    config_class = CLIPConfig

    # Row counts of the concept banks. These must match the shapes of the
    # buffers stored in the checkpoint, otherwise the weights will not load.
    _NUM_CONCEPTS = 17
    _NUM_SPECIAL_CARE_CONCEPTS = 3
    # Amount subtracted from every concept threshold for images that trip a
    # special-care concept.
    _SPECIAL_CARE_ADJUSTMENT = 0.01

    def __init__(self, config: CLIPConfig) -> None:
        """Build the vision tower, the projection and the concept buffers.

        Args:
            config: CLIP config. ``config.vision_config`` configures the
                vision tower. ``config.projection_dim`` is the embedding width
                shared by the projection and the concept banks.
        """
        super().__init__(config)
        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)

        # The banks are 2-D with no leading singleton dimension, so they match
        # checkpoint tensors shaped [17, projection_dim] and
        # [3, projection_dim] (e.g. [17, 768] / [3, 768]).
        # They start as ones; presumably the real values are loaded from a
        # pretrained checkpoint. TODO confirm the loading path.
        self.register_buffer("concept_embeds", torch.ones(self._NUM_CONCEPTS, config.projection_dim))
        self.register_buffer(
            "special_care_embeds", torch.ones(self._NUM_SPECIAL_CARE_CONCEPTS, config.projection_dim)
        )

        # Per-concept similarity thresholds, one per row of the matching bank.
        self.register_buffer("concept_embeds_weights", torch.ones(self._NUM_CONCEPTS))
        self.register_buffer("special_care_embeds_weights", torch.ones(self._NUM_SPECIAL_CARE_CONCEPTS))

    @torch.no_grad()
    def forward(self, clip_input, images):
        """Flag unsafe images and zero them out in place.

        Args:
            clip_input: Passed straight to ``self.vision_model``. Presumably
                preprocessed CLIP pixel values, one row per image. TODO
                confirm against callers.
            images: Indexable batch whose ``i``-th element corresponds to the
                ``i``-th row of ``clip_input``. Flagged elements are replaced
                with ``torch.zeros_like(images[i])``, which assumes they are
                torch tensors.

        Returns:
            A tuple ``(images, has_nsfw_concepts)``. ``images`` is the same
            object that was passed in, mutated in place. ``has_nsfw_concepts``
            is a list of Python bools, one per image.
        """
        # Element 1 of the vision-model output is used as the pooled embedding.
        pooled_output = self.vision_model(clip_input)[1]
        image_embeds = self.visual_projection(pooled_output)
        # Only the image side is L2-normalised here; the concept banks are
        # used as stored.
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # Similarity matrices: (batch, n_special) and (batch, n_concepts).
        special_cos_dist = torch.mm(image_embeds, self.special_care_embeds.t())
        cos_dist = torch.mm(image_embeds, self.concept_embeds.t())

        # An image needs special care if it exceeds any special-care threshold.
        special_care = (special_cos_dist > self.special_care_embeds_weights).any(dim=1)
        # For those images, lower every concept threshold by the adjustment.
        special_adjustment = special_care.to(cos_dist.dtype) * self._SPECIAL_CARE_ADJUSTMENT
        concept_scores = cos_dist - self.concept_embeds_weights + special_adjustment.unsqueeze(1)

        # Decide the whole batch at once and move it to the host only once.
        has_nsfw_concepts = (concept_scores > 0).any(dim=1).tolist()

        for i, is_nsfw in enumerate(has_nsfw_concepts):
            if is_nsfw:
                images[i] = torch.zeros_like(images[i])

        return images, has_nsfw_concepts