# modeling_caputemendatoris.py
import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    PretrainedConfig,
    PreTrainedModel,
    T5Config,
    T5ForConditionalGeneration,
)


# ============================================================
# Configuration
# ============================================================
class CaputemendatorisConfig(PretrainedConfig):
    model_type = "caputemendatoris"

    def __init__(
        self,
        byt5_config=None,
        max_position_embeddings=256,
        detector_hidden_dim=512,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Must allow None during save_pretrained() diff construction.
        self.byt5_config = byt5_config
        self.max_position_embeddings = max_position_embeddings
        self.detector_hidden_dim = detector_hidden_dim

    def validate(self):
        if self.byt5_config is None:
            raise ValueError(
                "Invalid Caputemendatoris config: byt5_config missing."
            )
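

# Example (illustrative sketch): the config embeds the finetuned ByT5
# configuration as a plain dict so it round-trips through save_pretrained()
# and from_pretrained(). The checkpoint name "google/byt5-small" below is a
# placeholder, not necessarily the one used in training.
#
#   byt5_cfg = T5Config.from_pretrained("google/byt5-small")
#   cfg = CaputemendatorisConfig(byt5_config=byt5_cfg.to_dict())
#   cfg.validate()
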
# ============================================================
# Model
# ============================================================
class Caputemendatoris(PreTrainedModel):
    config_class = CaputemendatorisConfig
    base_model_prefix = "caputemendatoris"

    def __init__(self, config: CaputemendatorisConfig):
        super().__init__(config)
        # Enforce a real configuration during actual loading.
        if config.byt5_config is None:
            raise ValueError(
                "Caputemendatoris loaded without embedded ByT5 configuration."
            )
        # Reconstruct the finetuned ByT5.
        t5_config = T5Config(**config.byt5_config)
        self.t5 = T5ForConditionalGeneration(t5_config)
        self.encoder = self.t5.encoder
        d_model = self.t5.config.d_model
        # Learned positional embedding (matches the training setup).
        self.pos_emb = nn.Embedding(
            config.max_position_embeddings,
            d_model,
        )
        # Detection head (identical to the training architecture).
        self.head = nn.Sequential(
            nn.Linear(2 * d_model, config.detector_hidden_dim),
            nn.LayerNorm(config.detector_hidden_dim),
            nn.GELU(),
            nn.Linear(config.detector_hidden_dim, 1),
        )
        self.post_init()

    # ---------------- detection ----------------
    def detect(self, input_ids, attention_mask=None):
        enc = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden = enc.last_hidden_state
        B, T, _ = hidden.shape
        # Guard against indexing past the positional embedding table.
        if T > self.pos_emb.num_embeddings:
            raise ValueError(
                f"Sequence length {T} exceeds max_position_embeddings "
                f"({self.pos_emb.num_embeddings})."
            )
        pos_ids = torch.arange(
            T, device=input_ids.device
        ).unsqueeze(0).expand(B, T)
        pos = self.pos_emb(pos_ids)
        # Concatenate contextual and positional features -> (B, T, 2 * d_model).
        h = torch.cat([hidden, pos], dim=-1)
        # Per-position detection score in [0, 1].
        return torch.sigmoid(self.head(h).squeeze(-1))

    # forward = detector
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        return self.detect(input_ids, attention_mask)

    # correction: delegate to the wrapped ByT5
    def generate(self, **kwargs):
        return self.t5.generate(**kwargs)
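

# Correction usage (illustrative sketch): generation goes straight through the
# wrapped ByT5, so a ByT5-compatible tokenizer is assumed; the tokenizer name
# below is a placeholder, not confirmed by this repo.
#
#   tok = AutoTokenizer.from_pretrained("google/byt5-small")
#   batch = tok(["soem corupted text"], return_tensors="pt")
#   fixed = model.generate(**batch, max_new_tokens=64)
#   print(tok.batch_decode(fixed, skip_special_tokens=True))
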
# ============================================================
# Registration (required for AutoModel)
# ============================================================
AutoConfig.register("caputemendatoris", CaputemendatorisConfig)
AutoModel.register(CaputemendatorisConfig, Caputemendatoris)
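

# Minimal smoke test (illustrative): builds a tiny model from scratch rather
# than loading a real checkpoint. The T5 hyperparameters below are arbitrary
# stand-ins, not the finetuned configuration; only the byte-level vocabulary
# size (384) is taken from the ByT5 family.
if __name__ == "__main__":
    tiny_t5 = T5Config(
        vocab_size=384,
        d_model=64,
        d_ff=128,
        num_layers=2,
        num_decoder_layers=2,
        num_heads=4,
    )
    cfg = CaputemendatorisConfig(byt5_config=tiny_t5.to_dict())
    model = Caputemendatoris(cfg).eval()

    ids = torch.randint(0, 384, (1, 16))
    with torch.no_grad():
        scores = model(input_ids=ids, attention_mask=torch.ones_like(ids))
    print(scores.shape)  # torch.Size([1, 16]): one detection score per position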