| | """Dual-head safety classifier: binary (safe/unsafe) + multi-label categories.""" |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import AutoModel |
| |
|
| |
|
| | class SafetyClassifier(nn.Module): |
| | def __init__(self, base_model_name: str, num_categories: int = 7): |
| | super().__init__() |
| | self.backbone = AutoModel.from_pretrained(base_model_name) |
| | hidden = self.backbone.config.hidden_size |
| |
|
| | self.binary_head = nn.Sequential( |
| | nn.Dropout(0.1), |
| | nn.Linear(hidden, 1), |
| | ) |
| | self.category_head = nn.Sequential( |
| | nn.Dropout(0.1), |
| | nn.Linear(hidden, num_categories), |
| | ) |
| |
|
| | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): |
| | outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask) |
| | cls_embed = outputs.last_hidden_state[:, 0] |
| |
|
| | binary_logit = self.binary_head(cls_embed) |
| | category_logits = self.category_head(cls_embed) |
| |
|
| | return binary_logit, category_logits |
| |
|
| | def predict(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> dict: |
| | self.eval() |
| | with torch.no_grad(): |
| | binary_logit, category_logits = self.forward(input_ids, attention_mask) |
| |
|
| | unsafe_score = torch.sigmoid(binary_logit).squeeze(-1) |
| | category_scores = torch.sigmoid(category_logits) |
| |
|
| | return { |
| | "unsafe_score": unsafe_score, |
| | "category_scores": category_scores, |
| | } |
| |
|