import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel

from .configuration_guardito import GuarditoConfig


class GuarditoForSequenceClassification(PreTrainedModel):
    config_class = GuarditoConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.backbone = AutoModel.from_pretrained(config.base_model_name)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(self.backbone.config.hidden_size, 1)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        # Some backbones (e.g. RoBERTa-style models) do not accept token_type_ids
        kwargs.pop("token_type_ids", None)

        # Fall back to an all-ones mask so mean pooling works without an explicit mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            **kwargs,
        )
        last_hidden_state = outputs.last_hidden_state

        # Mean pooling over non-padding tokens
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        pooled_output = sum_embeddings / sum_mask

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1), labels.float().view(-1))

        return {
            "loss": loss,
            "logits": logits,
        }

    # Helper method returning the PII probability and a boolean flag
    def predict_pii(self, input_ids, attention_mask=None, custom_threshold=None):
        with torch.no_grad():
            outputs = self.forward(input_ids, attention_mask=attention_mask)
        probs = torch.sigmoid(outputs["logits"])
        # If no threshold is passed to the method, fall back to the one in the config
        threshold = custom_threshold if custom_threshold is not None else self.config.threshold
        return {
            "probs": probs,
            "is_pii": probs >= threshold,
        }
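

# A minimal usage sketch (not part of the original module). It assumes the backbone
# named in `config.base_model_name` ships a matching tokenizer, and that GuarditoConfig
# accepts `base_model_name`, `dropout`, and `threshold` as keyword arguments, as used
# above; the concrete values below are hypothetical placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = GuarditoConfig(base_model_name="bert-base-uncased", dropout=0.1, threshold=0.5)
    model = GuarditoForSequenceClassification(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
    batch = tokenizer(
        ["My phone number is 555-0123", "The weather is nice today"],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    result = model.predict_pii(batch["input_ids"], attention_mask=batch["attention_mask"])
    print(result["probs"], result["is_pii"])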