import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    T5ForConditionalGeneration,
    T5Config,
    AutoConfig,
    AutoModel,
)
from transformers.configuration_utils import PretrainedConfig


# ============================================================
# Configuration
# ============================================================
class CaputemendatorisConfig(PretrainedConfig):
    model_type = "caputemendatoris"

    def __init__(
        self,
        byt5_config=None,
        max_position_embeddings=256,
        detector_hidden_dim=512,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Must allow None during save_pretrained() diff construction
        self.byt5_config = byt5_config
        self.max_position_embeddings = max_position_embeddings
        self.detector_hidden_dim = detector_hidden_dim

    def validate(self):
        if self.byt5_config is None:
            raise ValueError(
                "Invalid Caputemendatoris config: byt5_config missing."
            )


# ============================================================
# Model
# ============================================================
class Caputemendatoris(PreTrainedModel):
    config_class = CaputemendatorisConfig
    base_model_prefix = "caputemendatoris"

    def __init__(self, config: CaputemendatorisConfig):
        super().__init__(config)

        # enforce real configuration during actual loading
        if config.byt5_config is None:
            raise ValueError(
                "Caputemendatoris loaded without embedded ByT5 configuration."
            )

        # reconstruct finetuned ByT5
        t5_config = T5Config(**config.byt5_config)
        self.t5 = T5ForConditionalGeneration(t5_config)
        self.encoder = self.t5.encoder

        d_model = self.t5.config.d_model

        # positional embedding (matches your training)
        self.pos_emb = nn.Embedding(
            config.max_position_embeddings,
            d_model,
        )

        # detection head (identical to training architecture)
        self.head = nn.Sequential(
            nn.Linear(2 * d_model, config.detector_hidden_dim),
            nn.LayerNorm(config.detector_hidden_dim),
            nn.GELU(),
            nn.Linear(config.detector_hidden_dim, 1),
        )

        self.post_init()

    # ---------------- detection ----------------
    def detect(self, input_ids, attention_mask=None):
        enc = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden = enc.last_hidden_state

        B, T, _ = hidden.shape
        pos_ids = torch.arange(
            T, device=input_ids.device
        ).unsqueeze(0).expand(B, T)
        pos = self.pos_emb(pos_ids)

        h = torch.cat([hidden, pos], dim=-1)
        return torch.sigmoid(self.head(h).squeeze(-1))

    # forward = detector
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        return self.detect(input_ids, attention_mask)

    # correction
    def generate(self, **kwargs):
        return self.t5.generate(**kwargs)


# ============================================================
# Registration (required for AutoModel)
# ============================================================
AutoConfig.register("caputemendatoris", CaputemendatorisConfig)
AutoModel.register(CaputemendatorisConfig, Caputemendatoris)
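

# ============================================================
# Usage sketch (illustrative only)
# ============================================================
# A minimal sketch of loading the model through the Auto classes and running
# both heads: detection (per-token probabilities) and correction (ByT5
# generation). The checkpoint path is a hypothetical placeholder for a
# directory produced by save_pretrained(); the tokenizer name assumes the
# model was finetuned from google/byt5-small.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # Hypothetical checkpoint directory; registration above lets AutoModel
    # resolve model_type "caputemendatoris" to this class.
    model = AutoModel.from_pretrained("path/to/caputemendatoris-checkpoint")
    tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

    batch = tokenizer(["Example input text"], return_tensors="pt", padding=True)

    # Per-token error probabilities from the detection head (shape: B x T).
    probs = model.detect(batch["input_ids"], batch["attention_mask"])

    # Corrected output from the underlying ByT5 seq2seq model.
    generated = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_new_tokens=64,
    )
    print(probs.shape)
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))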