Yujivus committed on
Commit a39b612 · verified · 1 Parent(s): 19dc73b

Upload folder using huggingface_hub
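The commit message indicates the folder was pushed with huggingface_hub. A minimal sketch of such an upload, assuming an `HfApi.upload_folder` call; the local folder path is illustrative and the repo id is taken from the README's model-index name:

```python
from huggingface_hub import HfApi

# Sketch of an upload_folder call of the kind named in the commit message.
# The local path is hypothetical; the repo id comes from the README below.
api = HfApi()
api.upload_folder(
    folder_path="./PRISM-Baseline-6-6",          # hypothetical local folder
    repo_id="Yujivus/PRISM-Baseline-6-6",
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
```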
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+source.spm filter=lfs diff=lfs merge=lfs -text
+target.spm filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,20 @@
---
tags:
- translation
- baseline
- pytorch
model-index:
- name: Yujivus/PRISM-Baseline-6-6
  results:
  - task:
      type: translation
      name: Translation
    metrics:
    - name: BLEU
      type: bleu
      value: Unknown
---
# Baseline-6-6
Standard RoPE Transformer Baseline.
- Encoder Layers: 6
- Decoder Layers: 6
config.json ADDED
@@ -0,0 +1,11 @@
{
  "vocab_size": 58101,
  "d_model": 512,
  "num_heads": 8,
  "dff": 2048,
  "dropout": 0.1,
  "max_length": 128,
  "num_encoder_layers": 6,
  "num_decoder_layers": 6,
  "architecture": "RoPETransformer"
}
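These hyperparameters match the constructor arguments of `RoPETransformer` in `modeling_baseline.py` below (the extra `architecture` key is informational). A minimal sketch of building the model from this file; the `json.load`-based loading shown here is illustrative, not an API shipped with the repo:

```python
import json

from modeling_baseline import RoPETransformer  # module added in this commit

# Read the shipped hyperparameters and hand them to the constructor.
with open("config.json") as f:
    cfg = json.load(f)

model = RoPETransformer(
    num_encoder_layers=cfg["num_encoder_layers"],
    num_decoder_layers=cfg["num_decoder_layers"],
    num_heads=cfg["num_heads"],
    d_model=cfg["d_model"],
    dff=cfg["dff"],
    vocab_size=cfg["vocab_size"],
    max_length=cfg["max_length"],
    dropout=cfg["dropout"],
)
```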
modeling_baseline.py ADDED
@@ -0,0 +1,177 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
from x_transformers import Encoder, Decoder
from transformers import AutoTokenizer

# --- Smart tokenizer setup ---
# Prefer the tokenizer files shipped alongside this model; otherwise fall
# back to the original Helsinki-NLP/opus-mt-de-en tokenizer.
try:
    if os.path.exists("tokenizer_config.json"):
        tokenizer = AutoTokenizer.from_pretrained(".")
    else:
        tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
except Exception as e:
    print(f"Warning: Tokenizer load failed: {e}")
# -----------------------------


class RoPETransformer(nn.Module):
    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

        # No absolute positional encoder: RoPE handles positions inside attention.
        self.dropout_layer = nn.Dropout(dropout)

        # --- x-transformers Encoder ---
        self.encoder = Encoder(
            dim=d_model,
            depth=num_encoder_layers,
            heads=num_heads,
            attn_dim_head=d_model // num_heads,
            ff_mult=dff / d_model,
            rotary_pos_emb=True,
            attn_flash=True,
            attn_dropout=dropout,
            ff_dropout=dropout,
            use_rmsnorm=True
        )

        # --- x-transformers Decoder ---
        self.decoder = Decoder(
            dim=d_model,
            depth=num_decoder_layers,
            heads=num_heads,
            attn_dim_head=d_model // num_heads,
            ff_mult=dff / d_model,
            rotary_pos_emb=True,
            cross_attend=True,
            attn_flash=True,
            attn_dropout=dropout,
            ff_dropout=dropout,
            use_rmsnorm=True
        )

        # Output projection tied to the embedding matrix.
        self.final_linear = nn.Linear(d_model, vocab_size)
        self.final_linear.weight = self.embedding.weight

    def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):
        # 1. Embeddings (no absolute positional encoding added)
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        src_emb = self.dropout_layer(src_emb)

        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.dropout_layer(tgt_emb)

        # 2. Mask conversion
        # The caller provides True = PAD; x-transformers expects True = KEEP,
        # so the boolean masks are inverted with ~.
        enc_mask = ~src_padding_mask if src_padding_mask is not None else None
        dec_mask = ~tgt_padding_mask if tgt_padding_mask is not None else None

        # 'tgt_mask' (the square causal mask) is applied internally by the
        # x-transformers Decoder, so it is not passed on here.

        # 3. Encoder (x-transformers takes embeddings directly)
        memory = self.encoder(src_emb, mask=enc_mask)

        # 4. Decoder
        # context = encoder memory, context_mask = encoder padding mask
        decoder_output = self.decoder(
            tgt_emb,
            context=memory,
            mask=dec_mask,
            context_mask=enc_mask
        )

        return self.final_linear(decoder_output)

    # Kept for compatibility with the data-processing pipeline.
    def create_masks(self, src, tgt):
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        # Still generated for compatibility, although x-transformers handles causality internally.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            sz=tgt.size(1), device=src.device, dtype=torch.bool
        )
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    @torch.no_grad()
    def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:
        self.eval()
        # Create mask (True = PAD), then invert for x-transformers (True = KEEP).
        src_padding_mask = (src == tokenizer.pad_token_id)
        enc_mask = ~src_padding_mask

        # Encode (no positional encoder).
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        memory = self.encoder(self.dropout_layer(src_emb), mask=enc_mask)

        batch_size = src.shape[0]
        # Expand encoder outputs for the beams.
        memory = memory.repeat_interleave(num_beams, dim=0)
        enc_mask = enc_mask.repeat_interleave(num_beams, dim=0)

        # Marian convention: decoding starts from the pad token.
        initial_token = tokenizer.pad_token_id
        beams = torch.full((batch_size * num_beams, 1), initial_token, dtype=torch.long, device=src.device)
        beam_scores = torch.zeros(batch_size * num_beams, device=src.device)
        finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)

        for step in range(max_length - 1):
            if finished_beams.all():
                break

            # Embed current beams (no positional encoder).
            tgt_emb = self.embedding(beams) * math.sqrt(self.d_model)

            # Decode: x-transformers applies the causal mask for the current
            # target length automatically.
            decoder_output = self.decoder(
                self.dropout_layer(tgt_emb),
                context=memory,
                context_mask=enc_mask
            )

            logits = self.final_linear(decoder_output[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)

            # Never emit pad; finished beams keep appending eos at zero cost.
            log_probs[:, tokenizer.pad_token_id] = -torch.inf
            if finished_beams.any():
                log_probs[finished_beams, tokenizer.eos_token_id] = 0

            total_scores = beam_scores.unsqueeze(1) + log_probs
            if step == 0:
                # All beams start identical, so only keep the first beam per sample.
                total_scores = total_scores.view(batch_size, num_beams, -1)
                total_scores[:, 1:, :] = -torch.inf
                total_scores = total_scores.view(batch_size * num_beams, -1)

            total_scores = total_scores.view(batch_size, -1)
            top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)

            beam_indices = top_indices // log_probs.shape[-1]
            token_indices = top_indices % log_probs.shape[-1]

            batch_indices = torch.arange(batch_size, device=src.device).unsqueeze(1)
            effective_indices = (batch_indices * num_beams + beam_indices).view(-1)

            beams = beams[effective_indices]
            beams = torch.cat([beams, token_indices.view(-1, 1)], dim=1)
            beam_scores = top_scores.view(-1)
            finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)

        # Length-normalised scores pick the best beam per sample.
        final_beams = beams.view(batch_size, num_beams, -1)
        final_scores = beam_scores.view(batch_size, num_beams)
        normalized_scores = final_scores / (final_beams != tokenizer.pad_token_id).sum(-1).float().clamp(min=1)
        best_beams = final_beams[torch.arange(batch_size), normalized_scores.argmax(1), :]
        self.train()
        return best_beams
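A hedged end-to-end usage sketch for the class above: it assumes `pytorch_model.bin` (added below) holds a plain `state_dict` for `RoPETransformer` and that the tokenizer files in this repo load with `AutoTokenizer`; adjust the loading if the checkpoint is wrapped differently.

```python
import torch
from transformers import AutoTokenizer

from modeling_baseline import RoPETransformer

# Assumption: pytorch_model.bin is a raw state_dict for this class.
tokenizer = AutoTokenizer.from_pretrained(".")  # spm/vocab files from this repo
model = RoPETransformer(
    num_encoder_layers=6, num_decoder_layers=6, num_heads=8,
    d_model=512, dff=2048, vocab_size=58101, max_length=128, dropout=0.1,
)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))

# Beam-search translation of a German sentence (de -> en, per tokenizer_config.json).
batch = tokenizer(["Das ist ein Test."], return_tensors="pt", padding=True)
output_ids = model.generate(batch["input_ids"], max_length=128, num_beams=5)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```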
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:962cbd6495aea167bee37a50cb05dab32f1df2c9ea19282b3c14d8176eabc847
size 295642003
source.spm ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bbd1f495eea99c8e21ae086d9146e0fa7b096c3dfdd9ba07ab8b631889df5c9b
size 796845
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
{
  "eos_token": "</s>",
  "pad_token": "<pad>",
  "unk_token": "<unk>"
}
target.spm ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:678f2a1177d8389f67b66299762dcc4fc567e89b07e212ba91b0c56daecf47ce
size 768489
tokenizer_config.json ADDED
@@ -0,0 +1,39 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "58100": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "extra_special_tokens": {},
  "model_max_length": 512,
  "pad_token": "<pad>",
  "separate_vocabs": false,
  "source_lang": "de",
  "sp_model_kwargs": {},
  "target_lang": "en",
  "tokenizer_class": "MarianTokenizer",
  "unk_token": "<unk>"
}
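The ids registered here are the ones `modeling_baseline.py` relies on: `<pad>` (58100) doubles as the decoder start token in `generate()`, `</s>` (0) marks finished beams, and `<unk>` is 1. A small sanity-check sketch, assuming the repo's tokenizer files load via `AutoTokenizer`:

```python
from transformers import AutoTokenizer

# Check that the special-token ids match what the model code expects.
tokenizer = AutoTokenizer.from_pretrained(".")
assert tokenizer.pad_token_id == 58100  # decoder start token in generate()
assert tokenizer.eos_token_id == 0      # finished-beam marker
assert tokenizer.unk_token_id == 1
print(tokenizer.convert_ids_to_tokens([0, 1, 58100]))  # ['</s>', '<unk>', '<pad>']
```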
vocab.json ADDED
The diff for this file is too large to render.