Yujivus committed (verified)
Commit 2cca0cf · Parent(s): 00b5fc3

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ source.spm filter=lfs diff=lfs merge=lfs -text
+ target.spm filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,41 @@
+ ---
+ tags:
+ - translation
+ - prism
+ - shimmer
+ - pytorch
+ datasets:
+ - wmt14
+ metrics:
+ - bleu
+ model-index:
+ - name: Yujivus/PRISM-Molecule-100k
+   results:
+   - task:
+       type: translation
+       name: Translation (de-en)
+     dataset:
+       name: WMT14
+       type: wmt14
+       config: de-en
+     metrics:
+     - name: BLEU
+       type: bleu
+       value: TBD
+ ---
+ # PRISM-Shimmer V5 (Experimental)
+
+ Official checkpoint for the **Shimmer** architecture (PRISM with Complex Embeddings + Intrinsic Phase).
+ This model uses **Harmonic Embeddings** instead of standard vector lookup tables to enable spectral alignment.
+
+ ## Architecture
+ - **Encoder:** 6-Layer PRISM (Spectral Gated Harmonic Convolution)
+ - **Decoder:** 6-Layer Transformer (RoPE + Flash Attention)
+ - **Embedding:** Complex-Valued Shimmer Embeddings (real and imaginary parts learned separately)
+
+ ## Usage
+ ```python
+ # Requires modeling_prism_gated.py (included in the repo)
+ from modeling_prism_gated import PRISMHybrid_RoPE
+ # The model definition is self-contained in this repository
+ ```
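The README stops at the import, so here is a minimal, hypothetical usage sketch (not part of the commit). It assumes the hyperparameters in the `config.json` added below, that `pytorch_model.bin` holds a plain `state_dict`, and that the repository's tokenizer files sit in the working directory.

```python
# Hypothetical usage sketch -- assumes config.json holds the constructor
# kwargs (plus an "architecture" tag) and pytorch_model.bin is a state_dict.
import json
import torch
from transformers import AutoTokenizer
from modeling_prism_gated import PRISMHybrid_RoPE

tokenizer = AutoTokenizer.from_pretrained(".")

with open("config.json") as f:
    cfg = json.load(f)
cfg.pop("architecture", None)          # tag only, not a constructor argument

model = PRISMHybrid_RoPE(**cfg)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

# Translate one German sentence with the built-in beam search.
batch = tokenizer(["Das ist ein Test."], return_tensors="pt",
                  padding=True, truncation=True, max_length=cfg["max_length"])
out_ids = model.generate(batch["input_ids"], max_length=cfg["max_length"], num_beams=5)
print(tokenizer.batch_decode(out_ids, skip_special_tokens=True))
```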
__pycache__/modeling_prism_gated.cpython-312.pyc ADDED
Binary file (18.9 kB).
config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "vocab_size": 58101,
+   "d_model": 512,
+   "num_heads": 8,
+   "dff": 2048,
+   "dropout": 0.1,
+   "max_length": 128,
+   "num_encoder_layers": 6,
+   "num_refining_layers": 0,
+   "num_decoder_layers": 6,
+   "architecture": "PRISM_Molecule"
+ }
modeling_prism_gated.py ADDED
@@ -0,0 +1,310 @@
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.fft
+ import math
+ from x_transformers import Decoder
+ from transformers import AutoTokenizer
+ import os
+
+ # --- GLOBAL TOKENIZER SETUP ---
+ try:
+     if os.path.exists("tokenizer_config.json"):
+         tokenizer = AutoTokenizer.from_pretrained(".")
+     else:
+         tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
+ except Exception as e:
+     print(f"Warning: Tokenizer load failed: {e}")
+
+ # ==================================================================
+ # SHIMMER ARCHITECTURE CLASSES
+ # ==================================================================
+
+ class ComplexDropout(nn.Module):
+     def __init__(self, p=0.5):
+         super().__init__()
+         self.p = p
+
+     def forward(self, z):
+         if not self.training or self.p == 0.0:
+             return z
+         mask = torch.ones_like(z.real)
+         mask = F.dropout(mask, self.p, self.training, inplace=False)
+         return z * mask
+
+ class PhasePreservingLayerNorm(nn.Module):
+     def __init__(self, d_model, eps=1e-5):
+         super().__init__()
+         self.layernorm = nn.LayerNorm(d_model, eps=eps)
+         self.eps = eps
+
+     def forward(self, x):
+         mag = torch.abs(x)
+         mag_norm = self.layernorm(mag)
+         return mag_norm.to(x.dtype) * (x / (mag + self.eps))
+
+ class HarmonicEmbedding(nn.Module):
+     def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):
+         super().__init__()
+         self.embedding_dim = embedding_dim
+         self.complex_embedding = nn.Embedding(num_embeddings, embedding_dim * 2)
+         freqs = torch.exp(torch.arange(0, embedding_dim, dtype=torch.float32) * -(math.log(max_period) / embedding_dim))
+         self.register_buffer('freqs', freqs)
+
+     def forward(self, input_ids):
+         raw_embeds = self.complex_embedding(input_ids)
+         real = raw_embeds[..., :self.embedding_dim]
+         imag = raw_embeds[..., self.embedding_dim:]
+         content_z = torch.complex(real, imag)
+         seq_len = input_ids.shape[1]
+         positions = torch.arange(seq_len, device=input_ids.device).float()
+         angles = torch.outer(positions, self.freqs)
+         pos_rotation = torch.polar(torch.ones_like(angles), angles).unsqueeze(0)
+         return content_z * pos_rotation
+
+ class ModReLU(nn.Module):
+     def __init__(self, features):
+         super().__init__()
+         self.b = nn.Parameter(torch.zeros(features))
+     def forward(self, z):
+         mag = torch.abs(z)
+         new_mag = F.relu(mag + self.b)
+         phase = z / (mag + 1e-6)
+         return new_mag * phase
+
+ # --- THE CORRECT LAYER (Cartesian Gated) ---
+ class PRISMLayer(nn.Module):
+     def __init__(self, d_model, max_len=5000, dropout=0.1):
+         super().__init__()
+         self.d_model = d_model
+         self.filter_len = max_len
+
+         # 1. THE GATE (Data Dependency)
+         self.gate_proj = nn.Linear(d_model * 2, d_model * 2)
+
+         # 2. THE FILTER (Global Pattern)
+         self.global_filter = nn.Parameter(torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02)
+
+         # 3. INPUT MIXING
+         self.mix_real = nn.Linear(d_model, d_model)
+         self.mix_imag = nn.Linear(d_model, d_model)
+
+         # 4. OUTPUT PROJECTION
+         self.out_real = nn.Linear(d_model, d_model)
+         self.out_imag = nn.Linear(d_model, d_model)
+
+         self.activation = ModReLU(d_model)
+         self.norm = PhasePreservingLayerNorm(d_model)
+         self.dropout = ComplexDropout(dropout)
+
+     def complex_linear(self, x, l_real, l_imag):
+         r, i = x.real, x.imag
+         new_r = l_real(r) - l_imag(i)
+         new_i = l_real(i) + l_imag(r)
+         return torch.complex(new_r, new_i)
+
+     def forward(self, x, src_mask=None):
+         if x is None: return None
+         residual = x
+         x_norm = self.norm(x)
+
+         if src_mask is not None:
+             x_norm = x_norm.masked_fill(src_mask.unsqueeze(-1), 0.0)
+
+         # A. GATE
+         x_cat = torch.cat([x_norm.real, x_norm.imag], dim=-1)
+         gates = torch.sigmoid(self.gate_proj(x_cat))
+         gate_r, gate_i = gates.chunk(2, dim=-1)
+
+         # B. FILTER
+         B, L, D = x_norm.shape
+         x_freq = torch.fft.fft(x_norm, n=self.filter_len, dim=1)
+         x_filtered = x_freq * self.global_filter.transpose(-1, -2)
+         x_time = torch.fft.ifft(x_filtered, n=self.filter_len, dim=1)
+         x_time = x_time[:, :L, :]
+
+         # C. APPLY GATE
+         gated_r = x_time.real * gate_r
+         gated_i = x_time.imag * gate_i
+         x_gated = torch.complex(gated_r, gated_i)
+
+         # D. OUT
+         x_mixed = self.complex_linear(x_gated, self.mix_real, self.mix_imag)
+         x_act = self.activation(x_mixed)
+         out = self.complex_linear(x_act, self.out_real, self.out_imag)
+         return self.dropout(out) + residual
+
+ # --- ENCODER MUST BE DEFINED AFTER LAYER ---
+ class PRISMEncoder(nn.Module):
+     def __init__(self, num_layers, d_model, max_len, dropout=0.1):
+         super().__init__()
+         self.layers = nn.ModuleList([PRISMLayer(d_model, max_len, dropout) for _ in range(num_layers)])
+         self.final_norm = PhasePreservingLayerNorm(d_model)
+
+     def forward(self, x, src_mask=None):
+         for layer in self.layers:
+             x = layer(x, src_mask)
+         return self.final_norm(x)
+
+ # --- THE CORRECT BRIDGE (Cartesian) ---
+ class ComplexToRealBridge(nn.Module):
+     def __init__(self, d_model):
+         super().__init__()
+         self.proj = nn.Linear(d_model * 2, d_model)
+         self.norm = nn.LayerNorm(d_model)
+
+     def forward(self, x_complex):
+         if x_complex is None: raise ValueError("Bridge None")
+         cat = torch.cat([x_complex.real, x_complex.imag], dim=-1)
+         return self.norm(self.proj(cat))
+
+ class PRISMHybrid_RoPE(nn.Module):
+     def __init__(self, num_encoder_layers, num_refining_layers, num_decoder_layers,
+                  num_heads, d_model, dff, vocab_size, max_length, dropout):
+         super().__init__()
+         self.d_model = d_model
+         self.harmonic_embedding = HarmonicEmbedding(vocab_size, d_model)
+         self.tgt_embedding = nn.Embedding(vocab_size, d_model)
+         self.dropout = nn.Dropout(dropout)
+
+         if num_encoder_layers > 0:
+             self.prism_encoder = PRISMEncoder(num_encoder_layers, d_model, max_length, dropout)
+         else:
+             self.prism_encoder = None
+
+         self.bridge = ComplexToRealBridge(d_model)
+
+         if num_refining_layers > 0:
+             refining_layer = nn.TransformerEncoderLayer(
+                 d_model, num_heads, dff, dropout,
+                 batch_first=True, norm_first=True
+             )
+             self.reasoning_encoder = nn.TransformerEncoder(refining_layer, num_layers=num_refining_layers)
+         else:
+             self.reasoning_encoder = None
+
+         self.decoder = Decoder(
+             dim = d_model, depth = num_decoder_layers, heads = num_heads, attn_dim_head = d_model // num_heads,
+             ff_mult = dff / d_model, rotary_pos_emb = True, cross_attend = True, attn_flash = True,
+             attn_dropout = dropout, ff_dropout = dropout, use_rmsnorm = True
+         )
+         self.final_linear = nn.Linear(d_model, vocab_size)
+         self.final_linear.weight = self.tgt_embedding.weight
+
+     def create_masks(self, src, tgt):
+         src_padding_mask = (src == tokenizer.pad_token_id)
+         tgt_padding_mask = (tgt == tokenizer.pad_token_id)
+         tgt_mask = nn.Transformer.generate_square_subsequent_mask(sz=tgt.size(1), device=src.device, dtype=torch.bool)
+         return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask
+
+     def forward(self, src, tgt, src_mask, tgt_pad, mem_pad, tgt_mask):
+         src_harmonic = self.harmonic_embedding(src)
+         if src_mask is not None:
+             src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)
+
+         if self.prism_encoder is not None:
+             if self.training:
+                 src_harmonic.requires_grad_(True)
+                 encoded_complex = torch.utils.checkpoint.checkpoint(
+                     self.prism_encoder.forward,  # Safest
+                     src_harmonic, src_mask, use_reentrant=False
+                 )
+             else:
+                 encoded_complex = self.prism_encoder(src_harmonic, src_mask)
+         else:
+             encoded_complex = src_harmonic
+
+         coarse_memory = self.bridge(encoded_complex)
+         if self.reasoning_encoder is not None:
+             refined_memory = self.reasoning_encoder(coarse_memory, src_key_padding_mask=mem_pad)
+         else:
+             refined_memory = coarse_memory
+
+         tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
+         tgt_emb = self.dropout(tgt_emb)
+         context_mask = ~mem_pad if mem_pad is not None else None
+         decoder_mask = ~tgt_pad if tgt_pad is not None else None
+
+         if self.training:
+             tgt_emb.requires_grad_(True)
+             output = torch.utils.checkpoint.checkpoint(
+                 self.decoder, tgt_emb, context=refined_memory, mask=decoder_mask, context_mask=context_mask, use_reentrant=False
+             )
+         else:
+             output = self.decoder(tgt_emb, context=refined_memory, mask=decoder_mask, context_mask=context_mask)
+
+         return self.final_linear(output)
+
+     # ... (generate function remains the same) ...
+     @torch.no_grad()
+     def generate(self, src, max_length, num_beams=5):
+         self.eval()
+         src_mask = (src == tokenizer.pad_token_id)
+         context_mask = ~src_mask
+         src_harmonic = self.harmonic_embedding(src)
+         if src_mask is not None:
+             src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)
+
+         if self.prism_encoder is not None:
+             encoded_complex = self.prism_encoder(src_harmonic, src_mask)
+         else:
+             encoded_complex = src_harmonic
+
+         coarse_memory = self.bridge(encoded_complex)
+
+         if self.reasoning_encoder is not None:
+             memory = self.reasoning_encoder(coarse_memory, src_key_padding_mask=src_mask)
+         else:
+             memory = coarse_memory
+
+         batch_size = src.shape[0]
+         memory = memory.repeat_interleave(num_beams, dim=0)
+         context_mask = context_mask.repeat_interleave(num_beams, dim=0)
+
+         beams = torch.full((batch_size * num_beams, 1), tokenizer.pad_token_id, dtype=torch.long, device=src.device)
+         beam_scores = torch.zeros(batch_size * num_beams, device=src.device)
+         finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)
+
+         for _ in range(max_length - 1):
+             if finished_beams.all(): break
+             tgt_emb = self.tgt_embedding(beams) * math.sqrt(self.d_model)
+             tgt_emb = self.dropout(tgt_emb)
+
+             # Decoder
+             decoder_output = self.decoder(tgt_emb, context=memory, context_mask=context_mask)
+             logits = self.final_linear(decoder_output[:, -1, :])
+             log_probs = F.log_softmax(logits, dim=-1)
+
+             # Masking
+             log_probs[:, tokenizer.pad_token_id] = -torch.inf
+             if finished_beams.any(): log_probs[finished_beams, tokenizer.eos_token_id] = 0
+
+             # --- BEAM SEARCH LOGIC FIX ---
+             if _ == 0:
+                 # First Step: Expand from the first beam only (since all are identical start tokens)
+                 # Reshape to (batch, beams, vocab)
+                 total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, num_beams, -1)
+                 # Mask out all beams except the first one (-inf)
+                 total[:, 1:, :] = -torch.inf
+                 # Flatten back to (batch, beams*vocab) to pick top k
+                 total = total.view(batch_size, -1)
+             else:
+                 # Subsequent Steps: Standard Flatten
+                 total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, -1)
+
+             top_scores, top_indices = torch.topk(total, k=num_beams, dim=1)
+
+             beam_indices = top_indices // log_probs.shape[-1]
+             token_indices = top_indices % log_probs.shape[-1]
+
+             # Now dimensions match: (batch_size, 1) + (batch_size, k)
+             effective = (torch.arange(batch_size, device=src.device).unsqueeze(1) * num_beams + beam_indices).view(-1)
+             beams = torch.cat([beams[effective], token_indices.view(-1, 1)], dim=1)
+             beam_scores = top_scores.view(-1)
+             finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)
+
+         final_beams = beams.view(batch_size, num_beams, -1)
+         best_beams = final_beams[:, 0, :]
+         self.train()
+         return best_beams
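As a quick, standalone sanity check of the `HarmonicEmbedding` class above (a hypothetical sketch with toy dimensions, not part of the commit): the positional term is applied as a unit-modulus phase rotation, so position is encoded purely in the phase of a token's embedding, never in its magnitude.

```python
# Toy check of HarmonicEmbedding: the same token at different positions keeps
# the same magnitude (the positional factor has unit modulus) but a different phase.
import torch
from modeling_prism_gated import HarmonicEmbedding

emb = HarmonicEmbedding(num_embeddings=100, embedding_dim=16)
ids = torch.tensor([[7, 7, 7, 7]])          # the same token at four positions
z = emb(ids)                                # complex tensor of shape (1, 4, 16)

print(z.dtype)                                          # torch.complex64
print(torch.allclose(z.abs()[0, 0], z.abs()[0, 3]))     # True: magnitude is position-invariant
print(torch.allclose(z[0, 0], z[0, 3]))                 # False: phase encodes the position
```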
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f27ea154d68f14b05bb7e7bb80336c5a33fb5177f566809ab7eb530dab45033
+ size 513741579
source.spm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bbd1f495eea99c8e21ae086d9146e0fa7b096c3dfdd9ba07ab8b631889df5c9b
+ size 796845
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "eos_token": "</s>",
+   "pad_token": "<pad>",
+   "unk_token": "<unk>"
+ }
target.spm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:678f2a1177d8389f67b66299762dcc4fc567e89b07e212ba91b0c56daecf47ce
+ size 768489
tokenizer_config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "58100": {
+       "content": "<pad>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "</s>",
+   "extra_special_tokens": {},
+   "model_max_length": 512,
+   "pad_token": "<pad>",
+   "separate_vocabs": false,
+   "source_lang": "de",
+   "sp_model_kwargs": {},
+   "target_lang": "en",
+   "tokenizer_class": "MarianTokenizer",
+   "unk_token": "<unk>"
+ }
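A small, hypothetical consistency check (not part of the commit): `tokenizer_config.json` assigns `<pad>` the id 58100 and `config.json` declares `vocab_size` 58101, so the pad token should be the last entry of the vocabulary.

```python
# Hypothetical check that the tokenizer files agree with config.json.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".")                      # MarianTokenizer from this repo
print(tok.pad_token_id, tok.eos_token_id, tok.unk_token_id)   # expected: 58100 0 1
print(len(tok))                                               # expected: 58101, matching vocab_size
```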
vocab.json ADDED
The diff for this file is too large to render.