File size: 3,421 Bytes
98a901e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import torch.nn as nn

from transformers import (
    PreTrainedModel,
    T5ForConditionalGeneration,
    T5Config,
    AutoConfig,
    AutoModel,
)
from transformers.configuration_utils import PretrainedConfig


# ============================================================
# Configuration
# ============================================================

class CaputemendatorisConfig(PretrainedConfig):
    """Configuration for the Caputemendatoris detector/corrector model.

    Wraps a full ByT5 (T5) configuration plus the hyperparameters of the
    token-level detection head.

    Args:
        byt5_config: The embedded ByT5 configuration, as a plain ``dict``
            or as a ``T5Config``/``PretrainedConfig`` instance (normalized
            to a dict for JSON serialization). May be ``None`` transiently
            during ``save_pretrained()`` diff construction; call
            :meth:`validate` (or load the model) to enforce presence.
        max_position_embeddings: Size of the detector's learned positional
            embedding table, i.e. the maximum supported sequence length
            for detection.
        detector_hidden_dim: Hidden width of the detection head MLP.
    """

    model_type = "caputemendatoris"

    def __init__(
        self,
        byt5_config=None,
        max_position_embeddings=256,
        detector_hidden_dim=512,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Normalize config objects to a plain dict so that
        # (a) save_pretrained() can JSON-serialize this config, and
        # (b) Caputemendatoris.__init__ can splat it into T5Config(**...).
        # A bare object would break both.
        if byt5_config is not None and not isinstance(byt5_config, dict):
            byt5_config = byt5_config.to_dict()

        # Must allow None during save_pretrained() diff construction
        self.byt5_config = byt5_config
        self.max_position_embeddings = max_position_embeddings
        self.detector_hidden_dim = detector_hidden_dim

    def validate(self):
        """Raise ``ValueError`` if the embedded ByT5 config is missing."""
        if self.byt5_config is None:
            raise ValueError(
                "Invalid Caputemendatoris config: byt5_config missing."
            )


# ============================================================
# Model
# ============================================================

class Caputemendatoris(PreTrainedModel):
    """ByT5-backed model combining a token-level detection head with
    seq2seq correction.

    Detection: the ByT5 encoder's hidden states are concatenated with a
    learned positional embedding and passed through a small MLP, yielding
    a per-token probability in [0, 1]. Correction: delegated to the full
    ``T5ForConditionalGeneration`` via :meth:`generate`.
    """

    config_class = CaputemendatorisConfig
    base_model_prefix = "caputemendatoris"

    def __init__(self, config: CaputemendatorisConfig):
        """Build the model from a fully-populated config.

        Raises:
            ValueError: if ``config.byt5_config`` is ``None`` (the config
                is only allowed to lack it transiently during
                ``save_pretrained()`` diff construction, never at load).
        """
        super().__init__(config)

        # enforce real configuration during actual loading
        if config.byt5_config is None:
            raise ValueError(
                "Caputemendatoris loaded without embedded ByT5 configuration."
            )

        # reconstruct finetuned ByT5
        t5_config = T5Config(**config.byt5_config)
        self.t5 = T5ForConditionalGeneration(t5_config)
        # Alias to the shared encoder (same module, not a copy).
        self.encoder = self.t5.encoder

        d_model = self.t5.config.d_model

        # positional embedding (matches your training)
        self.pos_emb = nn.Embedding(
            config.max_position_embeddings,
            d_model,
        )

        # detection head (identical to training architecture)
        self.head = nn.Sequential(
            nn.Linear(2 * d_model, config.detector_hidden_dim),
            nn.LayerNorm(config.detector_hidden_dim),
            nn.GELU(),
            nn.Linear(config.detector_hidden_dim, 1),
        )

        self.post_init()

    # ---------------- detection ----------------

    def detect(self, input_ids, attention_mask=None):
        """Return per-token detection probabilities.

        Args:
            input_ids: Long tensor of shape (batch, seq_len).
            attention_mask: Optional mask of shape (batch, seq_len).

        Returns:
            Float tensor of shape (batch, seq_len) with sigmoid
            probabilities. NOTE: positions under ``attention_mask == 0``
            are NOT zeroed out here; masking padded positions is the
            caller's responsibility.

        Raises:
            ValueError: if seq_len exceeds the positional-embedding table
                (``config.max_position_embeddings``). Without this guard
                the embedding lookup fails with an opaque IndexError (or
                a deferred device-side assert on CUDA).
        """
        enc = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        hidden = enc.last_hidden_state
        B, T, _ = hidden.shape

        if T > self.pos_emb.num_embeddings:
            raise ValueError(
                f"Sequence length {T} exceeds max_position_embeddings "
                f"({self.pos_emb.num_embeddings}) of the detection head."
            )

        pos_ids = torch.arange(
            T, device=input_ids.device
        ).unsqueeze(0).expand(B, T)

        pos = self.pos_emb(pos_ids)

        # [hidden ; positional] -> MLP -> per-token logit -> probability
        h = torch.cat([hidden, pos], dim=-1)

        return torch.sigmoid(self.head(h).squeeze(-1))

    # forward = detector
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Alias for :meth:`detect`; extra kwargs are accepted and ignored."""
        return self.detect(input_ids, attention_mask)

    # correction
    def generate(self, **kwargs):
        """Delegate text correction to the underlying ByT5 seq2seq model."""
        return self.t5.generate(**kwargs)


# ============================================================
# Registration (required for AutoModel)
# ============================================================

# Make the custom architecture discoverable through the Auto* factories:
# AutoConfig.from_pretrained(...) can now resolve model_type
# "caputemendatoris" to CaputemendatorisConfig, and AutoModel maps that
# config class to the Caputemendatoris model. The string MUST match
# CaputemendatorisConfig.model_type exactly.
AutoConfig.register("caputemendatoris", CaputemendatorisConfig)
AutoModel.register(CaputemendatorisConfig, Caputemendatoris)