"""Prisma model for HuggingFace integration.



Usage:

    from transformers import AutoModelForCausalLM, AutoTokenizer



    model = AutoModelForCausalLM.from_pretrained("y3i12/Prisma", trust_remote_code=True)

    tokenizer = AutoTokenizer.from_pretrained("y3i12/Prisma")

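A minimal generation sketch (assuming the standard HF ``generate`` API is
available on this class; the prompt and arguments are illustrative):

    inputs = tokenizer("Hello, world", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))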
"""

import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_prisma import PrismaConfig
from .mirrored import MirroredTransformer, MirroredConfig
from .layers import build_word_start_table, compute_word_positions


class PrismaForCausalLM(PreTrainedModel):
    """Prisma mirrored transformer for causal language modeling."""

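    # HF integration hooks: the config class, weight keys tied to the input
    # embeddings, modules `accelerate` must keep on a single device, and
    # rotary-frequency buffers that may legitimately be missing from
    # checkpoints (they are recomputed at runtime).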
    config_class = PrismaConfig
    _tied_weights_keys = ["transformer.lm_head.weight"]
    _no_split_modules = ["MirroredBlock", "MiddleBlock"]
    _keys_to_ignore_on_load_missing = [
        r"transformer\..*\.rotary\.inv_freq",
        r"transformer\..*\.word_rope\.word_inv_freq",
    ]

    def __init__(self, config: PrismaConfig):
        super().__init__(config)

        mirrored_config = MirroredConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            num_layers=config.num_layers,
            n_middle=config.n_middle,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            aux_skip_k=config.aux_skip_k,
            aux_skip_weight=config.aux_skip_weight,
            use_g2lu=config.use_g2lu,
            word_rope_dims=config.word_rope_dims,
            word_rope_base=config.word_rope_base,
            embed_dim=config.embed_dim,
            head_dim=config.head_dim,
        )
        self.transformer = MirroredTransformer(mirrored_config)

        # Word-position table for WoRPE (populated by from_pretrained or set_tokenizer)
        if config.word_rope_dims > 0:
            self.register_buffer(
                "word_start_table",
                torch.zeros(config.vocab_size, dtype=torch.bool),
                persistent=True,
            )
        else:
            self.word_start_table = None

        # Track word position during autoregressive generation
        self._word_pos_counter = 0

        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Build word_start_table from tokenizer. Call this if not loading from pretrained."""
        if self.config.word_rope_dims > 0:
            table = build_word_start_table(tokenizer, self.config.vocab_size)
            self.word_start_table = table.to(self.device)

    def get_input_embeddings(self):
        return self.transformer.embed

    def set_input_embeddings(self, value):
        self.transformer.embed = value

    def get_output_embeddings(self):
        return self.transformer.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.transformer.lm_head = new_embeddings

    def tie_weights(self):
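        # Weight tying is only possible when the embedding and LM head operate
        # at the same width; with separate embed_dim/head_dim overrides the two
        # matrices may differ in shape and are left untied.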
        if self.config.tie_word_embeddings:
            embed_dim = self.config.embed_dim or self.config.hidden_size
            head_dim = self.config.head_dim or self.config.hidden_size
            if embed_dim == head_dim:
                self.transformer.lm_head.weight = self.transformer.embed.weight

    def forward(
        self,
        input_ids=None,
        attention_mask=None,  # accepted for HF API compatibility; not forwarded below
        past_key_values=None,
        labels=None,
        use_cache=False,
        return_dict=True,
        **kwargs,
    ):
        # Convert HF DynamicCache to our list-of-tuples format
        past_kv_list = None
        if past_key_values is not None:
            # Check if cache has actual content (not just pre-allocated empty layers)
            has_content = False
            if isinstance(past_key_values, (list, tuple)):
                has_content = len(past_key_values) > 0
                past_kv_list = past_key_values if has_content else None
            elif hasattr(past_key_values, 'get_seq_length'):
                has_content = past_key_values.get_seq_length() > 0
                if has_content:
                    past_kv_list = [past_key_values[i] for i in range(len(past_key_values))]

        # Compute word positions if WoRPE is enabled
        word_positions = None
        if self.word_start_table is not None and self.config.word_rope_dims > 0:
            if past_kv_list is not None and input_ids.size(1) == 1:
                # Cached generation: track word position step by step
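                # NOTE: this path assumes batch size 1: it reads input_ids[0, -1]
                # and keeps a single scalar counter across generation steps.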
                last_token = input_ids[0, -1].item()
                if self.word_start_table[last_token]:
                    self._word_pos_counter = 0
                else:
                    self._word_pos_counter += 1
                word_positions = torch.tensor(
                    [[float(self._word_pos_counter)]],
                    device=input_ids.device,
                )
            else:
                # Full sequence: compute all word positions
                word_positions = compute_word_positions(input_ids, self.word_start_table)
                # Save last position for subsequent generation steps
                self._word_pos_counter = int(word_positions[0, -1].item())

        output = self.transformer(
            input_ids,
            labels=labels,
            use_cache=use_cache,
            past_kv=past_kv_list,
            word_positions=word_positions,
        )

        # Convert our list-of-tuples back to DynamicCache
        new_cache = None
        if use_cache and output.get("past_kv") is not None:
            from transformers.cache_utils import DynamicCache
            new_cache = DynamicCache()
            for layer_idx, (k, v) in enumerate(output["past_kv"]):
                new_cache.update(k, v, layer_idx)

        if not return_dict:
            # Tuple output follows the HF convention: loss (if computed) first,
            # then logits, then the cache.
            result = (output["logits"],)
            if use_cache:
                result += (new_cache,)
            loss = output.get("loss")
            return (loss,) + result if loss is not None else result

        return CausalLMOutputWithPast(
            loss=output.get("loss"),
            logits=output["logits"],
            past_key_values=new_cache,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, **kwargs
    ):
        # Only trim to last token if cache has actual KV content
        has_cache = False
        if past_key_values is not None:
            if hasattr(past_key_values, 'get_seq_length'):
                has_cache = past_key_values.get_seq_length() > 0
            elif isinstance(past_key_values, (list, tuple)):
                has_cache = len(past_key_values) > 0
        if has_cache:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": True,
        }