| import torch |
| import torch.nn as nn |
|
|
| from transformers import ( |
| PreTrainedModel, |
| T5ForConditionalGeneration, |
| T5Config, |
| AutoConfig, |
| AutoModel, |
| ) |
| from transformers.configuration_utils import PretrainedConfig |
|
|
|
|
| |
| |
| |
|
|
class CaputemendatorisConfig(PretrainedConfig):
    """Configuration for the Caputemendatoris detection model.

    Embeds the full configuration of a ByT5 (T5) backbone plus the
    hyperparameters of the per-token detection head.
    """

    model_type = "caputemendatoris"

    def __init__(
        self,
        byt5_config=None,
        max_position_embeddings=256,
        detector_hidden_dim=512,
        **kwargs,
    ):
        """
        Args:
            byt5_config: dict of ``T5Config`` parameters for the ByT5
                backbone. A config object exposing ``to_dict()`` is also
                accepted and normalized to a dict.
            max_position_embeddings: size of the learned position table
                used by the detection head.
            detector_hidden_dim: hidden width of the detection MLP.
            **kwargs: forwarded to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)

        # Normalize config objects to a plain dict so this config
        # serializes cleanly to JSON (save_pretrained). Storing a live
        # T5Config object here would break serialization.
        if byt5_config is not None and hasattr(byt5_config, "to_dict"):
            byt5_config = byt5_config.to_dict()

        self.byt5_config = byt5_config
        self.max_position_embeddings = max_position_embeddings
        self.detector_hidden_dim = detector_hidden_dim

    def validate(self):
        """Raise ``ValueError`` if the embedded backbone config is missing."""
        if self.byt5_config is None:
            raise ValueError(
                "Invalid Caputemendatoris config: byt5_config missing."
            )
|
|
|
|
| |
| |
| |
|
|
class Caputemendatoris(PreTrainedModel):
    """ByT5 encoder with a per-token binary detection head.

    ``detect`` / ``forward`` return a ``(batch, seq_len)`` tensor of
    sigmoid probabilities, one per input token. ``generate`` delegates
    to the wrapped seq2seq T5 model.
    """

    config_class = CaputemendatorisConfig
    base_model_prefix = "caputemendatoris"

    def __init__(self, config: CaputemendatorisConfig):
        """Build the backbone and detection head from *config*.

        Raises:
            ValueError: if ``config.byt5_config`` is missing.
        """
        super().__init__(config)

        if config.byt5_config is None:
            raise ValueError(
                "Caputemendatoris loaded without embedded ByT5 configuration."
            )

        # Rebuild the ByT5 backbone from the embedded config dict; the
        # encoder is aliased for direct access in detect().
        t5_config = T5Config(**config.byt5_config)
        self.t5 = T5ForConditionalGeneration(t5_config)
        self.encoder = self.t5.encoder

        d_model = self.t5.config.d_model

        # Learned absolute position embeddings fed to the detector head
        # (concatenated with the encoder's contextual states).
        self.pos_emb = nn.Embedding(
            config.max_position_embeddings,
            d_model,
        )

        # Per-token binary classifier over [hidden ; position] features.
        self.head = nn.Sequential(
            nn.Linear(2 * d_model, config.detector_hidden_dim),
            nn.LayerNorm(config.detector_hidden_dim),
            nn.GELU(),
            nn.Linear(config.detector_hidden_dim, 1),
        )

        self.post_init()

    def detect(self, input_ids, attention_mask=None):
        """Return per-token detection probabilities of shape ``(B, T)``.

        Args:
            input_ids: token id tensor of shape ``(B, T)``.
            attention_mask: optional mask of shape ``(B, T)``.

        Raises:
            ValueError: if ``T`` exceeds the position-embedding table.
        """
        enc = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        hidden = enc.last_hidden_state
        B, T, _ = hidden.shape

        # Fail loudly with a clear message instead of an opaque embedding
        # index error (or a device-side assert on CUDA) when the input is
        # longer than the position table.
        max_pos = self.pos_emb.num_embeddings
        if T > max_pos:
            raise ValueError(
                f"Sequence length {T} exceeds max_position_embeddings "
                f"({max_pos})."
            )

        # Positions are created on the encoder output's device so the
        # lookup never crosses devices.
        pos_ids = torch.arange(
            T, device=hidden.device
        ).unsqueeze(0).expand(B, T)

        pos = self.pos_emb(pos_ids)

        # Concatenate contextual and positional features per token.
        h = torch.cat([hidden, pos], dim=-1)

        return torch.sigmoid(self.head(h).squeeze(-1))

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """HF-style entry point; extra kwargs (e.g. labels) are ignored."""
        return self.detect(input_ids, attention_mask)

    def generate(self, **kwargs):
        """Delegate text generation to the wrapped T5 model."""
        return self.t5.generate(**kwargs)
|
|
|
|
| |
| |
| |
|
|
# Register with the transformers Auto* factories so that
# AutoConfig/AutoModel.from_pretrained can resolve model_type
# "caputemendatoris" to these classes.
AutoConfig.register("caputemendatoris", CaputemendatorisConfig)
AutoModel.register(CaputemendatorisConfig, Caputemendatoris)