Yujivus committed
Commit 9cff9de · verified · Parent: 500bb46

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ source.spm filter=lfs diff=lfs merge=lfs -text
+ target.spm filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,41 @@
+ ---
+ tags:
+ - translation
+ - prism
+ - shimmer
+ - pytorch
+ datasets:
+ - wmt14
+ metrics:
+ - bleu
+ model-index:
+ - name: Yujivus/PRISM-Shimmer
+   results:
+   - task:
+       type: translation
+       name: Translation (de-en)
+     dataset:
+       name: WMT14
+       type: wmt14
+       config: de-en
+     metrics:
+     - name: BLEU
+       type: bleu
+       value: TBD
+ ---
+ # PRISM-Shimmer V5 (Experimental)
+
+ Official checkpoint for the **Shimmer** architecture (PRISM with Complex Embeddings + Intrinsic Phase).
+ This model uses **Harmonic Embeddings** instead of standard vector lookup tables to enable spectral alignment.
+
+ ## Architecture
+ - **Encoder:** 6-Layer PRISM (Spectral Gated Harmonic Convolution)
+ - **Decoder:** 6-Layer Transformer (RoPE + Flash Attention)
+ - **Embedding:** Complex-Valued Shimmer Embeddings (Real + Imag parts learned separately)
+
+ ## Usage
+ ```python
+ # Requires modeling_prism.py (included in repo)
+ from modeling_prism import PRISMHybrid_RoPE
+ # Model definition is self-contained in the repo
+ ```
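Beyond the bare import in the README, a hedged load-and-translate sketch could look like the following. It is not an official API: it assumes the repo files sit in the working directory, that `pytorch_model.bin` holds a plain `state_dict`, and that `torch`, `transformers`, and `x-transformers` are installed (the constructor arguments are taken verbatim from `config.json`, and `generate` is the beam search implemented in `modeling_prism.py`).

```python
# Hedged loading sketch -- file names and constructor kwargs come from this repo
# (config.json, pytorch_model.bin, modeling_prism.py); adjust paths as needed.
import json

import torch

from modeling_prism import PRISMHybrid_RoPE, tokenizer  # tokenizer is created at import time

with open("config.json") as f:
    cfg = json.load(f)

model = PRISMHybrid_RoPE(
    num_encoder_layers=cfg["num_encoder_layers"],
    num_refining_layers=cfg["num_refining_layers"],
    num_decoder_layers=cfg["num_decoder_layers"],
    num_heads=cfg["num_heads"],
    d_model=cfg["d_model"],
    dff=cfg["dff"],
    vocab_size=cfg["vocab_size"],
    max_length=cfg["max_length"],
    dropout=cfg["dropout"],
)

# Assumption: pytorch_model.bin stores a plain state_dict (unwrap it first if it does not).
state = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state)
model.eval()

# Translate one German sentence with the beam search built into generate().
batch = tokenizer(["Das ist ein Test."], return_tensors="pt", padding=True)
out_ids = model.generate(batch["input_ids"], max_length=cfg["max_length"], num_beams=5)
print(tokenizer.decode(out_ids[0], skip_special_tokens=True))
```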
__pycache__/modeling_prism.cpython-312.pyc ADDED
Binary file (18.4 kB).
config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "vocab_size": 58101,
+   "d_model": 512,
+   "num_heads": 8,
+   "dff": 2048,
+   "dropout": 0.1,
+   "max_length": 128,
+   "num_encoder_layers": 6,
+   "num_refining_layers": 0,
+   "num_decoder_layers": 6,
+   "architecture": "PRISM_Shimmer_v5"
+ }
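Every key in this file except `architecture` matches a constructor argument of `PRISMHybrid_RoPE`, so a hypothetical instantiation (a sketch, not an official loader) can unpack the config directly:

```python
# Hypothetical sketch: config keys map one-to-one onto PRISMHybrid_RoPE kwargs.
import json

from modeling_prism import PRISMHybrid_RoPE

with open("config.json") as f:
    cfg = json.load(f)

kwargs = {k: v for k, v in cfg.items() if k != "architecture"}  # drop the info-only field
model = PRISMHybrid_RoPE(**kwargs)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```

Note that `num_refining_layers` is 0 here, so the optional refining Transformer encoder between the PRISM encoder and the decoder is skipped entirely (see `PRISMHybrid_RoPE.__init__` in `modeling_prism.py`).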
modeling_prism.py ADDED
@@ -0,0 +1,365 @@
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.fft
+ import math
+ from x_transformers import Decoder
+ from transformers import AutoTokenizer
+ import os
+
+ # --- GLOBAL TOKENIZER SETUP ---
+ try:
+     if os.path.exists("tokenizer_config.json"):
+         tokenizer = AutoTokenizer.from_pretrained(".")
+     else:
+         tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
+ except Exception as e:
+     print(f"Warning: Tokenizer load failed: {e}")
+
+ # ==================================================================
+ # SHIMMER ARCHITECTURE CLASSES
+ # ==================================================================
+
+ class ComplexDropout(nn.Module):
+     """
+     Standard nn.Dropout doesn't work on ComplexFloat.
+     This module generates a mask based on the shape and applies it to both
+     Real and Imaginary parts identically to preserve Phase.
+     """
+     def __init__(self, p=0.5):
+         super().__init__()
+         self.p = p
+
+     def forward(self, z):
+         if not self.training or self.p == 0.0:
+             return z
+
+         # Generate mask using F.dropout on a ones tensor of the same shape (Real part)
+         # F.dropout handles the scaling (1 / (1 - p)) automatically
+         mask = torch.ones_like(z.real)
+         mask = F.dropout(mask, self.p, self.training, inplace=False)
+
+         # Apply mask to the complex tensor
+         return z * mask
+
+ class PhasePreservingLayerNorm(nn.Module):
+     def __init__(self, d_model, eps=1e-5):
+         super().__init__()
+         self.layernorm = nn.LayerNorm(d_model, eps=eps)
+         self.eps = eps
+
+     def forward(self, x):
+         mag = torch.abs(x)
+         mag_norm = self.layernorm(mag)
+         # Avoid division by zero
+         return mag_norm.to(x.dtype) * (x / (mag + self.eps))
+
+ class HarmonicEmbedding(nn.Module):
+     def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):
+         super().__init__()
+         self.embedding_dim = embedding_dim
+
+         # 1. Learnable Real and Imaginary parts (Cartesian coordinates)
+         # This allows learning both Amplitude AND Intrinsic Phase implicitly
+         self.complex_embedding = nn.Embedding(num_embeddings, embedding_dim * 2)
+
+         # Frequencies (Fixed)
+         freqs = torch.exp(torch.arange(0, embedding_dim, dtype=torch.float32) * -(math.log(max_period) / embedding_dim))
+         self.register_buffer('freqs', freqs)
+
+     def forward(self, input_ids):
+         # A. Get Learnable Content (Mag + Intrinsic Phase)
+         # Shape: [Batch, Seq, Dim * 2]
+         raw_embeds = self.complex_embedding(input_ids)
+
+         # Split into Real/Imag
+         real = raw_embeds[..., :self.embedding_dim]
+         imag = raw_embeds[..., self.embedding_dim:]
+
+         # Convert to Complex Tensor
+         # This Z already has Amplitude AND Intrinsic Phase
+         content_z = torch.complex(real, imag)
+
+         # B. Apply Positional Rotation (The "Clock")
+         seq_len = input_ids.shape[1]
+         positions = torch.arange(seq_len, device=input_ids.device).float()
+         angles = torch.outer(positions, self.freqs)
+
+         # Create Rotation (Phase Shift)
+         # e^(i * theta)
+         pos_rotation = torch.polar(torch.ones_like(angles), angles).unsqueeze(0)
+
+         # C. Rotate the Content
+         # Z_final = Z_content * e^(i * pos)
+         return content_z * pos_rotation
+
+ class PRISMEncoder(nn.Module):
+     def __init__(self, num_layers, d_model, max_len, dropout=0.1):
+         super().__init__()
+         self.layers = nn.ModuleList([PRISMLayer(d_model, max_len, dropout) for _ in range(num_layers)])
+
+         self.final_norm = PhasePreservingLayerNorm(d_model)
+
+     def forward(self, x, src_mask=None):
+         for layer in self.layers:
+             x = layer(x, src_mask)
+
+         # Apply Final Norm
+         return self.final_norm(x)
+
+ class ModReLU(nn.Module):
+     def __init__(self, features):
+         super().__init__()
+         self.b = nn.Parameter(torch.zeros(features))
+     def forward(self, z):
+         mag = torch.abs(z)
+         new_mag = F.relu(mag + self.b)
+         phase = z / (mag + 1e-6)
+         return new_mag * phase
+
+ class PRISMLayer(nn.Module):
+     def __init__(self, d_model, max_len=5000, dropout=0.1):
+         super().__init__()
+         self.d_model = d_model
+         self.filter_len = max_len
+
+         # --- REMOVED GATING PARAMS ---
+         # self.pre_gate = nn.Linear(d_model * 2, d_model)
+
+         # Global Filter
+         self.global_filter = nn.Parameter(torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02)
+
+         # Mixing
+         self.mix_real = nn.Linear(d_model, d_model)
+         self.mix_imag = nn.Linear(d_model, d_model)
+         self.out_real = nn.Linear(d_model, d_model)
+         self.out_imag = nn.Linear(d_model, d_model)
+
+         self.activation = ModReLU(d_model)
+         self.norm = PhasePreservingLayerNorm(d_model)
+         self.dropout = ComplexDropout(dropout)
+
+     def complex_linear(self, x, l_real, l_imag):
+         r, i = x.real, x.imag
+         new_r = l_real(r) - l_imag(i)
+         new_i = l_real(i) + l_imag(r)
+         return torch.complex(new_r, new_i)
+
+     def forward(self, x, src_mask=None):
+         residual = x
+         x_norm = self.norm(x)
+
+         if src_mask is not None:
+             mask_expanded = src_mask.unsqueeze(-1)
+             x_norm = x_norm.masked_fill(mask_expanded, 0.0)
+
+         # --- REMOVED GATING LOGIC ---
+         # Pass x_norm directly to FFT
+         x_gated = x_norm
+
+         # B. FFT Resonance
+         B, L, D = x_gated.shape
+         x_freq = torch.fft.fft(x_gated, n=self.filter_len, dim=1)
+         filter_transposed = self.global_filter.transpose(-1, -2)
+         x_filtered = x_freq * filter_transposed
+         x_time = torch.fft.ifft(x_filtered, n=self.filter_len, dim=1)
+         x_time = x_time[:, :L, :]
+
+         # C. Mix & Activate
+         x_mixed = self.complex_linear(x_time, self.mix_real, self.mix_imag)
+         x_act = self.activation(x_mixed)
+         out = self.complex_linear(x_act, self.out_real, self.out_imag)
+
+         return self.dropout(out) + residual
+
+ class ComplexToRealBridge(nn.Module):
+     def __init__(self, d_model):
+         super().__init__()
+         self.proj = nn.Linear(d_model * 2, d_model)
+     def forward(self, x_complex):
+         cat = torch.cat([x_complex.real, x_complex.imag], dim=-1)
+         return self.proj(cat)
+
+ class PRISMHybrid_RoPE(nn.Module):
+     def __init__(self, num_encoder_layers, num_refining_layers, num_decoder_layers,
+                  num_heads, d_model, dff, vocab_size, max_length, dropout):
+         super().__init__()
+         self.d_model = d_model
+
+         # 1. Embeddings
+         self.harmonic_embedding = HarmonicEmbedding(vocab_size, d_model)
+         self.tgt_embedding = nn.Embedding(vocab_size, d_model)
+         self.dropout = nn.Dropout(dropout)
+
+         # 2. Harmonic Body (PRISM Encoder)
+         if num_encoder_layers > 0:
+             self.prism_encoder = PRISMEncoder(num_encoder_layers, d_model, max_length, dropout)
+         else:
+             self.prism_encoder = None
+
+         # 3. The Bridge
+         self.bridge = ComplexToRealBridge(d_model)
+
+         # 4. Refining Encoder
+         if num_refining_layers > 0:
+             refining_layer = nn.TransformerEncoderLayer(
+                 d_model, num_heads, dff, dropout,
+                 batch_first=True, norm_first=True
+             )
+             self.reasoning_encoder = nn.TransformerEncoder(refining_layer, num_layers=num_refining_layers)
+         else:
+             self.reasoning_encoder = None
+
+         # 5. Decoder (x-transformers)
+         self.decoder = Decoder(
+             dim = d_model,
+             depth = num_decoder_layers,
+             heads = num_heads,
+             attn_dim_head = d_model // num_heads,
+             ff_mult = dff / d_model,
+             rotary_pos_emb = True,
+             cross_attend = True,
+             attn_flash = True,
+             attn_dropout = dropout,
+             ff_dropout = dropout,
+             use_rmsnorm = True
+         )
+
+         # 6. Output Head (weight-tied to the target embedding)
+         self.final_linear = nn.Linear(d_model, vocab_size)
+         self.final_linear.weight = self.tgt_embedding.weight
+
+     def create_masks(self, src, tgt):
+         src_padding_mask = (src == tokenizer.pad_token_id)
+         tgt_padding_mask = (tgt == tokenizer.pad_token_id)
+         tgt_mask = nn.Transformer.generate_square_subsequent_mask(
+             sz=tgt.size(1), device=src.device, dtype=torch.bool
+         )
+         return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask
+
+     def forward(self, src, tgt, src_mask, tgt_pad, mem_pad, tgt_mask):
+         # A. Harmonic Phase
+         src_harmonic = self.harmonic_embedding(src)
+         if src_mask is not None:
+             src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)
+
+         # PRISM Encoder Pass
+         if self.prism_encoder is not None:
+             if self.training:
+                 src_harmonic.requires_grad_(True)
+                 encoded_complex = torch.utils.checkpoint.checkpoint(
+                     self.prism_encoder, src_harmonic, src_mask, use_reentrant=False
+                 )
+             else:
+                 encoded_complex = self.prism_encoder(src_harmonic, src_mask)
+         else:
+             encoded_complex = src_harmonic
+
+         # B. The Bridge
+         coarse_memory = self.bridge(encoded_complex)
+
+         # C. Refining Phase
+         if self.reasoning_encoder is not None:
+             refined_memory = self.reasoning_encoder(coarse_memory, src_key_padding_mask=mem_pad)
+         else:
+             refined_memory = coarse_memory
+
+         # D. Decoder Prep
+         tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
+         tgt_emb = self.dropout(tgt_emb)
+         context_mask = ~mem_pad if mem_pad is not None else None
+         decoder_mask = ~tgt_pad if tgt_pad is not None else None
+
+         # E. Decoder Pass (Checkpointing)
+         if self.training:
+             tgt_emb.requires_grad_(True)
+             output = torch.utils.checkpoint.checkpoint(
+                 self.decoder,
+                 tgt_emb,
+                 context=refined_memory,
+                 mask=decoder_mask,
+                 context_mask=context_mask,
+                 use_reentrant=False
+             )
+         else:
+             output = self.decoder(
+                 tgt_emb,
+                 context=refined_memory,
+                 mask=decoder_mask,
+                 context_mask=context_mask
+             )
+
+         return self.final_linear(output)
+
+     @torch.no_grad()
+     def generate(self, src, max_length, num_beams=5):
+         self.eval()
+         src_mask = (src == tokenizer.pad_token_id)
+         context_mask = ~src_mask
+         src_harmonic = self.harmonic_embedding(src)
+         if src_mask is not None:
+             src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)
+
+         if self.prism_encoder is not None:
+             encoded_complex = self.prism_encoder(src_harmonic, src_mask)
+         else:
+             encoded_complex = src_harmonic
+
+         coarse_memory = self.bridge(encoded_complex)
+
+         if self.reasoning_encoder is not None:
+             memory = self.reasoning_encoder(coarse_memory, src_key_padding_mask=src_mask)
+         else:
+             memory = coarse_memory
+
+         batch_size = src.shape[0]
+         memory = memory.repeat_interleave(num_beams, dim=0)
+         context_mask = context_mask.repeat_interleave(num_beams, dim=0)
+
+         beams = torch.full((batch_size * num_beams, 1), tokenizer.pad_token_id, dtype=torch.long, device=src.device)
+         beam_scores = torch.zeros(batch_size * num_beams, device=src.device)
+         finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)
+
+         for step in range(max_length - 1):
+             if finished_beams.all(): break
+             tgt_emb = self.tgt_embedding(beams) * math.sqrt(self.d_model)
+             tgt_emb = self.dropout(tgt_emb)
+
+             # Decoder
+             decoder_output = self.decoder(tgt_emb, context=memory, context_mask=context_mask)
+             logits = self.final_linear(decoder_output[:, -1, :])
+             log_probs = F.log_softmax(logits, dim=-1)
+
+             # Masking
+             log_probs[:, tokenizer.pad_token_id] = -torch.inf
+             if finished_beams.any(): log_probs[finished_beams, tokenizer.eos_token_id] = 0
+
+             # --- BEAM SEARCH LOGIC FIX ---
+             if step == 0:
+                 # First step: expand from the first beam only (all beams hold identical start tokens)
+                 # Reshape to (batch, beams, vocab)
+                 total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, num_beams, -1)
+                 # Mask out all beams except the first one (-inf)
+                 total[:, 1:, :] = -torch.inf
+                 # Flatten back to (batch, beams * vocab) to pick top k
+                 total = total.view(batch_size, -1)
+             else:
+                 # Subsequent steps: standard flatten
+                 total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, -1)
+
+             top_scores, top_indices = torch.topk(total, k=num_beams, dim=1)
+
+             beam_indices = top_indices // log_probs.shape[-1]
+             token_indices = top_indices % log_probs.shape[-1]
+
+             # Now dimensions match: (batch_size, 1) + (batch_size, k)
+             effective = (torch.arange(batch_size, device=src.device).unsqueeze(1) * num_beams + beam_indices).view(-1)
+             beams = torch.cat([beams[effective], token_indices.view(-1, 1)], dim=1)
+             beam_scores = top_scores.view(-1)
+             finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)
+
+         final_beams = beams.view(batch_size, num_beams, -1)
+         best_beams = final_beams[:, 0, :]
+         self.train()
+         return best_beams
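As a quick sanity check of the harmonic embedding defined above, a hypothetical smoke test (dimensions taken from config.json; requires torch, transformers, and x-transformers so the module imports cleanly) could look like this:

```python
# Hypothetical smoke test for HarmonicEmbedding; values mirror config.json.
import torch

from modeling_prism import HarmonicEmbedding

emb = HarmonicEmbedding(num_embeddings=58101, embedding_dim=512)
ids = torch.randint(0, 58101, (2, 16))   # [batch, seq]
z = emb(ids)                              # complex-valued output
print(z.shape, z.dtype)                   # torch.Size([2, 16, 512]) torch.complex64
```

The output is complex-valued: the learned real/imaginary halves give each token its amplitude and intrinsic phase, and the fixed-frequency rotation multiplies in e^(i·θ) per position before the tensor enters the PRISM encoder.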
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:46e35e235b03decdc11649a5bb289e8c5bb0c27411d9b02a0bd0618bbbad12aa
+ size 488542227
source.spm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bbd1f495eea99c8e21ae086d9146e0fa7b096c3dfdd9ba07ab8b631889df5c9b
+ size 796845
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "eos_token": "</s>",
+   "pad_token": "<pad>",
+   "unk_token": "<unk>"
+ }
target.spm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:678f2a1177d8389f67b66299762dcc4fc567e89b07e212ba91b0c56daecf47ce
+ size 768489
tokenizer_config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "58100": {
+       "content": "<pad>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "</s>",
+   "extra_special_tokens": {},
+   "model_max_length": 512,
+   "pad_token": "<pad>",
+   "separate_vocabs": false,
+   "source_lang": "de",
+   "sp_model_kwargs": {},
+   "target_lang": "en",
+   "tokenizer_class": "MarianTokenizer",
+   "unk_token": "<unk>"
+ }
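For reference, a small hedged check that the tokenizer shipped in this repo resolves to a MarianTokenizer with the special-token ids declared above (`</s>` = 0, `<unk>` = 1, `<pad>` = 58100); the `from_pretrained(".")` call mirrors the one in modeling_prism.py:

```python
# Hedged check of the repo tokenizer; run from the repository root.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".")   # same call modeling_prism.py makes at import
print(type(tok).__name__)                  # MarianTokenizer
print(tok.eos_token_id, tok.unk_token_id, tok.pad_token_id)  # 0 1 58100

batch = tok(["Guten Morgen!"], return_tensors="pt", padding=True)
print(batch["input_ids"].shape)
```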
vocab.json ADDED
The diff for this file is too large to render.