bhsinghgrid committed on
Commit 7d6a683 · verified · 1 Parent(s): bb4b181

Add files using upload-large-folder tool

Files changed (41)
  1. .gitattributes +2 -34
  2. README.md +65 -0
  3. best_model.pt +3 -0
  4. config.py +33 -0
  5. diffusion/__init__.py +0 -0
  6. diffusion/__pycache__/__init__.cpython-311.pyc +0 -0
  7. diffusion/__pycache__/__init__.cpython-312.pyc +0 -0
  8. diffusion/__pycache__/forward_process.cpython-311.pyc +0 -0
  9. diffusion/__pycache__/forward_process.cpython-312.pyc +0 -0
  10. diffusion/__pycache__/reverse_process.cpython-311.pyc +0 -0
  11. diffusion/__pycache__/reverse_process1.cpython-311.pyc +0 -0
  12. diffusion/__pycache__/reverse_process2.cpython-311.pyc +0 -0
  13. diffusion/__pycache__/scheduler.cpython-311.pyc +0 -0
  14. diffusion/__pycache__/scheduler.cpython-312.pyc +0 -0
  15. diffusion/forward_process.py +21 -0
  16. diffusion/reverse_process.py +302 -0
  17. diffusion/reverse_process1.py +154 -0
  18. diffusion/reverse_process2.py +275 -0
  19. diffusion/scheduler.py +34 -0
  20. handler.py +30 -0
  21. inference.py +122 -0
  22. inference_api.py +103 -0
  23. model/__init__.py +0 -0
  24. model/__pycache__/__init__.cpython-311.pyc +0 -0
  25. model/__pycache__/__init__.cpython-312.pyc +0 -0
  26. model/__pycache__/d3pm_model_cross_attention.cpython-311.pyc +0 -0
  27. model/__pycache__/d3pm_model_cross_attention.cpython-312.pyc +0 -0
  28. model/__pycache__/d3pm_model_encoder_decoder.cpython-311.pyc +0 -0
  29. model/__pycache__/sanskrit_model.cpython-311.pyc +0 -0
  30. model/__pycache__/sanskrit_model.cpython-312.pyc +0 -0
  31. model/__pycache__/tokenizer.cpython-311.pyc +0 -0
  32. model/__pycache__/tokenizer.cpython-312.pyc +0 -0
  33. model/__pycache__/tokenizers.cpython-311.pyc +0 -0
  34. model/d3pm_model_cross_attention.py +271 -0
  35. model/d3pm_model_encoder_decoder.py +227 -0
  36. model/sanskrit_model.py +61 -0
  37. model/tokenizer.py +222 -0
  38. model/tokenizers.py +112 -0
  39. requirements.txt +6 -0
  40. sanskrit_src_tokenizer.json +0 -0
  41. sanskrit_tgt_tokenizer.json +0 -0
.gitattributes CHANGED
@@ -1,35 +1,3 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,65 @@
+ ---
+ license: mit
+ language:
+ - sa
+ - en
+ tags:
+ - sanskrit
+ - paraphrase
+ - diffusion
+ - d3pm
+ - pytorch
+ pipeline_tag: text2text-generation
+ ---
+
+ # Sanskrit D3PM Paraphrase Model
+
+ Roman/IAST Sanskrit input to Devanagari output using a D3PM cross-attention model.
+
+ ## Files Included
+
+ - `best_model.pt` — trained checkpoint
+ - `config.py` — runtime config
+ - `inference.py` — model loading + generation loop
+ - `inference_api.py` — simple Python API (`predict`)
+ - `handler.py` — Hugging Face Endpoint handler
+ - `model/`, `diffusion/` — architecture modules
+ - `sanskrit_src_tokenizer.json`, `sanskrit_tgt_tokenizer.json` — tokenizers
+
+ ## Quick Local Test
+
+ ```python
+ from inference_api import predict
+ print(predict("dharmo rakṣati rakṣitaḥ")["output"])
+ ```
+
+ ## Endpoint Payload
+
+ ```json
+ {
+   "inputs": "yadā mano nivarteta viṣayebhyaḥ svabhāvataḥ",
+   "parameters": {
+     "temperature": 0.7,
+     "top_k": 40,
+     "repetition_penalty": 1.2,
+     "diversity_penalty": 0.0,
+     "num_steps": 64,
+     "clean_output": true
+   }
+ }
+ ```
+
+ ## Push This Folder To Model Hub
+
+ ```bash
+ huggingface-cli login
+ huggingface-cli repo create <your-username>/sanskrit-d3pm --type model
+ cd hf_model_repo
+ git init
+ git lfs install
+ git remote add origin https://huggingface.co/<your-username>/sanskrit-d3pm
+ git add .
+ git commit -m "Initial model release"
+ git push -u origin main
+ ```
+
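A minimal sketch of calling a deployed Inference Endpoint with the payload documented above. The URL and token are placeholders, and reading `resp.json()["output"]` assumes the endpoint returns the dict produced by `handler.py` / `predict` unchanged:

```python
import requests

# Hypothetical endpoint URL and token: substitute your own.
ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

payload = {
    "inputs": "yadā mano nivarteta viṣayebhyaḥ svabhāvataḥ",
    "parameters": {"temperature": 0.7, "top_k": 40, "num_steps": 64},
}
resp = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
    timeout=120,
)
print(resp.json()["output"])  # cleaned Devanagari text, per handler.py below
```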
best_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:280b944be1ed396c93f64deef18b07d258b5dd1c74d59284342864a532c95f8b
+ size 1077681643
config.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+
+ CONFIG = {
+     "model_type": "d3pm_cross_attention",
+     "data": {
+         "include_negative_examples": True,
+         "dataset_size": 60000,
+     },
+     "diffusion": {
+         "mask_token_id": 0,
+     },
+     "model": {
+         "src_vocab_size": 16000,
+         "tgt_vocab_size": 16000,
+         "d_model": 384,
+         "n_heads": 8,
+         "d_ff": 1536,
+         "n_layers": 6,
+         "dropout": 0.1,
+         "max_seq_len": 80,
+         "diffusion_steps": 64,
+     },
+     "training": {
+         "device": "cuda" if torch.cuda.is_available() else "cpu",
+     },
+     "inference": {
+         "num_steps": 64,
+         "temperature": 0.7,
+         "top_k": 40,
+         "repetition_penalty": 1.2,
+         "diversity_penalty": 0.0,
+     },
+ }
diffusion/__init__.py ADDED
File without changes
diffusion/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (189 Bytes).
diffusion/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (177 Bytes).
diffusion/__pycache__/forward_process.cpython-311.pyc ADDED
Binary file (1.75 kB).
diffusion/__pycache__/forward_process.cpython-312.pyc ADDED
Binary file (1.66 kB).
diffusion/__pycache__/reverse_process.cpython-311.pyc ADDED
Binary file (11.2 kB).
diffusion/__pycache__/reverse_process1.cpython-311.pyc ADDED
Binary file (5.37 kB).
diffusion/__pycache__/reverse_process2.cpython-311.pyc ADDED
Binary file (12 kB).
diffusion/__pycache__/scheduler.cpython-311.pyc ADDED
Binary file (2.93 kB).
diffusion/__pycache__/scheduler.cpython-312.pyc ADDED
Binary file (2.75 kB).
diffusion/forward_process.py ADDED
@@ -0,0 +1,21 @@
+ """
+ forward_process.py — Verified Correct (no changes needed)
+ ===========================================================
+ Absorbing (mask) diffusion. PAD never masked. At t=0 alpha=1.0 exactly,
+ so x_t == x_0 (nothing masked). Works correctly with the fixed scheduler.
+ """
+ import torch
+
+ class AbsorbingForwardProcess:
+     def __init__(self, scheduler, mask_id=0, pad_id=1):
+         self.scheduler = scheduler
+         self.mask_id = mask_id
+         self.pad_id = pad_id
+
+     def q_sample(self, x_0, t):
+         alpha_t = self.scheduler.get_alpha(t).to(x_0.device).view(-1, 1)
+         r = torch.rand(x_0.shape, device=x_0.device)
+         x_t = x_0.clone()
+         x_t[r > alpha_t] = self.mask_id
+         x_t[x_0 == self.pad_id] = self.pad_id  # PAD stays PAD always
+         return x_0, x_t
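The masking rule above keeps each non-PAD token with probability alpha_t and replaces the rest with the mask id. A quick self-contained check, assuming this repo is on PYTHONPATH; `StubScheduler` is a stand-in for `OptimizedCosineScheduler`, not part of the repo:

```python
import torch

from diffusion.forward_process import AbsorbingForwardProcess

class StubScheduler:
    """Minimal stand-in exposing get_alpha(t); t is ignored in this demo."""
    def get_alpha(self, t):
        return torch.tensor([0.25])  # keep ~25% of tokens

fp = AbsorbingForwardProcess(StubScheduler(), mask_id=0, pad_id=1)
x_0 = torch.randint(5, 100, (1, 10_000))        # fake token ids above specials
_, x_t = fp.q_sample(x_0, t=torch.tensor([0]))
print((x_t == 0).float().mean())                # ~0.75 of positions masked
```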
diffusion/reverse_process.py ADDED
@@ -0,0 +1,302 @@
+ """
+ reverse_process.py — Fixed
+ ===========================
+ Two bugs fixed from the original:
+
+ BUG 1 (critical): generate_beam() passed x_t (noisy) as `tgt` to model.
+ The model does q_sample(tgt, t) internally — so x_t got double-noised.
+ Fix: pass x0_estimate (current clean guess) as tgt. Model noises it correctly.
+
+ BUG 2: apply_diversity_penalty used logits.var(dim=-1) — this adds the
+ variance of each position's own distribution back to itself, which is
+ mathematically meaningless and just injects noise.
+ Fix: penalize tokens that are uniformly high-probability across ALL positions
+ (global common tokens). This genuinely promotes diversity.
+ """
+
+ import inspect
+
+ import torch
+ import torch.nn.functional as F
+
+
+ class ReverseDiffusion:
+     def __init__(self, scheduler):
+         self.scheduler = scheduler
+
+     def p_sample_step(
+         self,
+         model,
+         x_t,
+         t,
+         condition,
+         beam_width=3,
+         temperature=1.0,
+         repetition_penalty=1.2,
+         diversity_penalty=0.3,
+     ):
+         """Single reverse step with temperature + penalties."""
+         with torch.no_grad():
+             # ---- Shape safety ----
+             if x_t.dim() == 1:
+                 x_t = x_t.unsqueeze(0)
+             if condition.dim() == 1:
+                 condition = condition.unsqueeze(0)
+             if t.dim() == 0:
+                 t = t.unsqueeze(0)
+             if t.shape[0] != x_t.shape[0]:
+                 t = t.expand(x_t.shape[0])
+
+             # ---- Model forward ----
+             logits, _ = model(condition, x_t, t)
+
+             # ---- Temperature scaling ----
+             logits = logits / temperature
+
+             # ---- Repetition penalty (fixed version) ----
+             if repetition_penalty != 1.0:
+                 logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
+
+             # ---- Diversity penalty ----
+             if diversity_penalty > 0:
+                 logits = apply_diversity_penalty(logits, diversity_penalty)
+
+             probs = F.softmax(logits, dim=-1)
+             B, L, V = probs.shape
+
+             # ---- Top-k beam expansion ----
+             topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
+
+             candidates = []
+             for k in range(beam_width):
+                 next_tokens = topk_ids[:, :, k]
+                 score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
+                 candidates.append((next_tokens, score))
+
+             return candidates
+
+     def generate_beam(
+         self,
+         model,
+         condition,
+         beam_width=3,
+         num_steps=None,
+         temperature=1.0,
+         repetition_penalty=1.2,
+         diversity_penalty=0.3,
+     ):
+         """Beam-search reverse diffusion with temperature."""
+         if num_steps is None:
+             num_steps = self.scheduler.num_timesteps
+
+         device = condition.device
+         if condition.dim() == 1:
+             condition = condition.unsqueeze(0)
+         B, L = condition.shape
+
+         # Better initialization: start from MASK
+         x_init = torch.full(
+             (B, L),
+             fill_value=model.mask_token_id,
+             dtype=torch.long,
+             device=device,
+         )
+         beams = [(x_init, 0.0)]
+
+         for step in reversed(range(num_steps)):
+             new_beams = []
+             for x_t, score in beams:
+                 t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
+                 candidates = self.p_sample_step(
+                     model,
+                     x_t,
+                     t_tensor,
+                     condition,
+                     beam_width,
+                     temperature,
+                     repetition_penalty,
+                     diversity_penalty,
+                 )
+                 for tokens, new_score in candidates:
+                     new_beams.append((tokens, score + new_score))
+
+             # ---- Keep top beams ----
+             new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
+             beams = new_beams[:beam_width]
+
+         best_tokens, best_score = beams[0]
+         return best_tokens
+
+     def generate(
+         self,
+         model,
+         condition,
+         num_steps=None,
+         temperature=0.8,
+         top_k=50,
+         repetition_penalty=1.2,
+         diversity_penalty=0.0,
+     ):
+         """
+         Correct D3PM iterative refinement.
+
+         x0_est starts as all [MASK].
+         Each step: forward(src=condition, tgt=x0_est, t)
+             → model applies q_sample(x0_est, t) internally
+             → predicts cleaner x0
+             → x0_est updated
+
+         diversity_penalty: reduces probability of tokens that are
+         globally dominant across all sequence positions (not logits.var()).
+         """
+         if num_steps is None:
+             num_steps = self.scheduler.num_timesteps
+
+         device = condition.device
+         if condition.dim() == 1:
+             condition = condition.unsqueeze(0)
+         B, L = condition.shape
+
+         T = self.scheduler.num_timesteps
+         step_size = max(1, T // num_steps)
+         timesteps = list(range(T - 1, -1, -step_size))
+         if timesteps[-1] != 0:
+             timesteps.append(0)
+
+         mask_id = model.mask_token_id
+         # Start: know nothing → all MASK is our initial clean estimate
+         x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
+         hint = None
+
+         # Check once (not per step) whether the model supports self-conditioning
+         sig = inspect.signature(model.forward).parameters
+         supports_hint = 'x0_hint' in sig
+
+         model.eval()
+         with torch.no_grad():
+             for step_idx, t_val in enumerate(timesteps):
+                 t = torch.full((B,), t_val, dtype=torch.long, device=device)
+                 is_last = (step_idx == len(timesteps) - 1)
+
+                 # KEY: pass x0_est as tgt — model noises it internally
+                 if supports_hint:
+                     outputs = model(condition, x0_est, t, x0_hint=hint)
+                 else:
+                     outputs = model(condition, x0_est, t)
+
+                 logits = outputs[0] if isinstance(outputs, tuple) else outputs
+
+                 # Repetition penalty: down-weight tokens already in sequence
+                 if repetition_penalty != 1.0:
+                     logits = apply_repetition_penalty(logits, x0_est, repetition_penalty)
+
+                 # Diversity penalty: reduce globally dominant tokens
+                 if diversity_penalty > 0.0:
+                     logits = apply_diversity_penalty(logits, diversity_penalty)
+
+                 # Temperature + top-k
+                 logits = logits / max(temperature, 1e-5)
+                 if top_k > 0:
+                     logits = top_k_filter(logits, top_k)
+
+                 probs = F.softmax(logits, dim=-1)
+
+                 if is_last:
+                     x0_est = torch.argmax(probs, dim=-1)
+                 else:
+                     x0_est = batch_multinomial(probs)
+
+                 hint = x0_est
+
+         return x0_est
+
+
+ # ── Penalty functions ─────────────────────────────────────────────────
+
+ def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
+     """
+     Down-weight tokens that already appear in the current sequence.
+     Prevents मनो मनो मनो repetition loops.
+     penalty=1.0 → no effect
+     penalty=1.2 → mild suppression of repeated tokens
+     penalty=2.0 → strong suppression
+     """
+     B, L, V = logits.shape
+     for b in range(B):
+         for token_id in set(prev_tokens[b].tolist()):
+             if token_id > 4:  # don't penalize special tokens
+                 logits[b, :, token_id] = logits[b, :, token_id] / penalty
+     return logits
+
+
+ def apply_diversity_penalty(logits, penalty=0.5):
+     """
+     Correct diversity penalty: penalize tokens that are globally dominant
+     across ALL sequence positions. This forces the model to use less
+     common tokens, increasing output diversity.
+
+     Method: compute the mean logit across positions and subtract penalty
+     times that mean. Tokens uniformly high everywhere get suppressed.
+
+     penalty=0.0 → no diversity enforcement
+     penalty=0.5 → moderate diversity
+     penalty=1.0 → strong diversity (may hurt coherence)
+     """
+     # Mean logit across all positions: [B, 1, V]
+     global_mean = logits.mean(dim=1, keepdim=True)
+     # Subtract scaled global mean — suppresses globally common tokens
+     return logits - penalty * global_mean
+
+
+ def top_k_filter(logits, k):
+     B, L, V = logits.shape
+     if k >= V:
+         return logits
+     topk_vals, _ = torch.topk(logits, k, dim=-1)
+     threshold = topk_vals[..., -1].unsqueeze(-1)
+     return logits.masked_fill(logits < threshold, float('-inf'))
+
+
+ def batch_multinomial(probs):
+     B, L, V = probs.shape
+     flat = probs.view(B * L, V) + 1e-9
+     flat = flat / flat.sum(dim=-1, keepdim=True)
+     return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
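The three sampling helpers above are pure tensor ops, so they can be exercised without the model. A small self-contained check, assuming this repo is on PYTHONPATH:

```python
import torch
import torch.nn.functional as F

from diffusion.reverse_process import (
    apply_diversity_penalty, batch_multinomial, top_k_filter,
)

torch.manual_seed(0)
logits = torch.randn(1, 5, 32)             # (B=1, L=5, V=32) fake logits

filtered = top_k_filter(logits.clone(), k=4)
print((filtered > float('-inf')).sum(-1))  # exactly 4 finite entries per position

diverse = apply_diversity_penalty(logits.clone(), penalty=0.5)
sample = batch_multinomial(F.softmax(diverse, dim=-1))
print(sample.shape)                        # torch.Size([1, 5])
```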
diffusion/reverse_process1.py ADDED
@@ -0,0 +1,154 @@
+ import torch
+ import torch.nn.functional as F
+
+
+ class ReverseDiffusion:
+     """
+     Stable reverse diffusion with:
+     - Beam search
+     - Self-conditioning
+     - Temperature sampling
+     - Repetition penalty
+     - Diversity penalty
+     """
+
+     def __init__(self, scheduler):
+         self.scheduler = scheduler
+
+         self.temperature = 0.75
+         self.repetition_penalty = 1.15
+         self.diversity_penalty = 0.0
+         self.length_penalty = 1.0
+
+     # ------------------------------------------------
+     # penalties
+     # ------------------------------------------------
+
+     def apply_repetition_penalty(self, logits, tokens):
+         B, L, V = logits.shape
+         for b in range(B):
+             used = set(tokens[b].tolist())
+             for token_id in used:
+                 logits[b, :, token_id] /= self.repetition_penalty
+         return logits
+
+     def apply_diversity_penalty(self, logits):
+         if self.diversity_penalty == 0:
+             return logits
+         # NOTE: legacy variance-based version, superseded by the
+         # global-mean penalty in reverse_process.py
+         logits_var = logits.var(dim=-1, keepdim=True)
+         return logits + self.diversity_penalty * logits_var
+
+     # ------------------------------------------------
+     # single reverse step
+     # ------------------------------------------------
+
+     def p_sample_step(self, model, x_t, t, condition, self_cond=None, beam_width=3):
+         with torch.no_grad():
+             logits, hidden = model(condition, x_t, t, self_cond)
+
+             logits = logits / self.temperature
+             logits = self.apply_repetition_penalty(logits, x_t)
+             logits = self.apply_diversity_penalty(logits)
+
+             probs = F.softmax(logits, dim=-1)
+             B, L, V = probs.shape
+
+             topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
+
+             candidates = []
+             for k in range(beam_width):
+                 tokens = topk_ids[:, :, k]
+                 score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
+                 candidates.append((tokens, score))
+
+             return candidates
+
+     # ------------------------------------------------
+     # beam reverse diffusion
+     # ------------------------------------------------
+
+     def generate_beam(self, model, condition, beam_width=3, num_steps=None):
+         if num_steps is None:
+             num_steps = self.scheduler.num_timesteps
+
+         device = condition.device
+         if condition.dim() == 1:
+             condition = condition.unsqueeze(0)
+         B, L = condition.shape
+
+         # ------------------------------------------------
+         # better latent initialization: start from the
+         # condition with ~50% of positions masked
+         # ------------------------------------------------
+         x_init = condition.clone()
+         mask = torch.rand_like(x_init.float()) < 0.5
+         x_init[mask] = model.mask_token_id
+
+         beams = [(x_init, 0.0)]
+         self_cond = None
+
+         for step in reversed(range(num_steps)):
+             new_beams = []
+             for x_t, score in beams:
+                 t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
+                 candidates = self.p_sample_step(
+                     model, x_t, t_tensor, condition, self_cond, beam_width
+                 )
+                 for tokens, new_score in candidates:
+                     length_norm = tokens.shape[1] ** self.length_penalty
+                     final_score = (score + new_score) / length_norm
+                     new_beams.append((tokens, final_score))
+
+             new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
+             beams = new_beams[:beam_width]
+
+             # self-conditioning
+             self_cond = beams[0][0]
+
+         best_tokens, best_score = beams[0]
+         return best_tokens
diffusion/reverse_process2.py ADDED
@@ -0,0 +1,275 @@
+ """
+ reverse_process.py — Final Correct Version
+ =============================================
+
+ KEY PRINCIPLE: generate() must be byte-for-byte identical to run_inference()
+ in inference.py, which is what produced BERTScore 0.75 at validation.
+
+ CRITICAL BUG IN PREVIOUS VERSION:
+ We passed inference_mode=True to the model, but the model was NEVER
+ called with inference_mode=True during training or validation.
+ run_inference() (the validated path) does:
+     model(input_ids, x0_est, t, x0_hint=hint)
+     → inference_mode defaults to False.
+
+ With inference_mode=True the model does two things differently:
+   1. tgt_pad_mask = None (training used tgt_pad_mask = tgt==PAD)
+   2. Skips q_sample at t=0 (training always called q_sample)
+ The model was never trained to handle these conditions → garbage output.
+
+ Fix: do NOT pass inference_mode. Let it default to False, exactly
+ as run_inference() did.
+
+ BUGS FIXED (vs original reverse_process.py)
+ --------------------------------------------
+ BUG 1  generate_beam() used for D3PM → all-Ṛ repetition.
+        Use generate() (iterative refinement) from app1.py instead.
+ BUG 2  apply_diversity_penalty used logits.var() → noise injection.
+        Fixed to logits - penalty * logits.mean(dim=1) — global suppression.
+ BUG 3  x0_hint (self-conditioning) never passed to model.
+        Fixed: generate() passes x0_hint=hint every step.
+ BUG 4  params not forwarded from generate_beam() to p_sample_step().
+        Fixed in generate_beam() (kept for reference, not for production use).
+ """
+
+ import torch
+ import torch.nn.functional as F
+
+
+ class ReverseDiffusion:
+
+     def __init__(self, scheduler):
+         self.scheduler = scheduler
+
+         # Attribute-style defaults for backward compat with any code
+         # that sets reverse_diffusion.temperature = 0.9 etc.
+         # generate() prefers explicit kwargs and falls back to these.
+         self.temperature = 0.75
+         self.repetition_penalty = 1.15
+         self.diversity_penalty = 0.0
+         self.top_k = 50
+
+     # ------------------------------------------------------------------ #
+     #  generate — CORRECT D3PM iterative refinement                       #
+     #  Exact equivalent of run_inference() in inference.py                #
+     # ------------------------------------------------------------------ #
+     def generate(
+         self,
+         model,
+         condition,
+         num_steps=None,
+         temperature=None,
+         top_k=None,
+         repetition_penalty=None,
+         diversity_penalty=None,
+     ):
+         """
+         D3PM iterative refinement — identical to run_inference() in
+         inference.py, which is the validated path (BERTScore 0.75).
+
+         Algorithm:
+             x0_est = all [MASK]
+             for t = T-1 down to 0:
+                 logits = model(src, x0_est, t, x0_hint=hint)
+                     ↑ inference_mode NOT passed (defaults to False)
+                     ↑ this exactly matches training/validation
+                 apply penalties, temperature, top_k
+                 if t > 0: x0_est = multinomial(softmax(logits))  ← stochastic
+                 if t = 0: x0_est = argmax(softmax(logits))       ← deterministic
+                 hint = x0_est
+         """
+         # Resolve: explicit kwarg > object attribute
+         temperature = temperature if temperature is not None else self.temperature
+         top_k = top_k if top_k is not None else self.top_k
+         repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
+         diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
+
+         if num_steps is None:
+             num_steps = self.scheduler.num_timesteps
+
+         device = condition.device
+         if condition.dim() == 1:
+             condition = condition.unsqueeze(0)
+         B, L = condition.shape
+
+         T = self.scheduler.num_timesteps
+         step_size = max(1, T // num_steps)
+         timesteps = list(range(T - 1, -1, -step_size))
+         if timesteps[-1] != 0:
+             timesteps.append(0)
+
+         mask_id = model.mask_token_id
+         x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
+         hint = None
+
+         model.eval()
+         with torch.no_grad():
+             for step_idx, t_val in enumerate(timesteps):
+                 t = torch.full((B,), t_val, dtype=torch.long, device=device)
+                 is_last = (step_idx == len(timesteps) - 1)
+
+                 # ── CRITICAL: do NOT pass inference_mode ──────────────────
+                 # inference_mode defaults to False inside SanskritModel /
+                 # D3PMCrossAttention. This matches run_inference() exactly.
+                 # Passing inference_mode=True changes tgt_pad_mask and
+                 # q_sample behaviour — the model was never trained for that.
+                 logits, _ = model(condition, x0_est, t, x0_hint=hint)
+
+                 # Repetition penalty
+                 if repetition_penalty != 1.0:
+                     logits = apply_repetition_penalty(logits, x0_est, repetition_penalty)
+
+                 # Diversity penalty (correct: global mean suppression)
+                 if diversity_penalty > 0.0:
+                     logits = apply_diversity_penalty(logits, diversity_penalty)
+
+                 logits = logits / max(temperature, 1e-5)
+
+                 if top_k > 0:
+                     logits = top_k_filter(logits, top_k)
+
+                 probs = F.softmax(logits, dim=-1)
+
+                 # Stochastic at every step except the last (argmax at t=0)
+                 if is_last:
+                     x0_est = torch.argmax(probs, dim=-1)
+                 else:
+                     x0_est = batch_multinomial(probs)
+
+                 hint = x0_est
+
+         return x0_est  # (B, L)
+
+     # ------------------------------------------------------------------ #
+     #  p_sample_step — used by generate_beam (not for production)         #
+     # ------------------------------------------------------------------ #
+     def p_sample_step(
+         self,
+         model,
+         x_t,
+         t,
+         condition,
+         beam_width=3,
+         temperature=1.0,
+         repetition_penalty=1.2,
+         diversity_penalty=0.3,
+         x0_hint=None,
+     ):
+         with torch.no_grad():
+             if x_t.dim() == 1:
+                 x_t = x_t.unsqueeze(0)
+             if condition.dim() == 1:
+                 condition = condition.unsqueeze(0)
+             if t.dim() == 0:
+                 t = t.unsqueeze(0)
+             if t.shape[0] != x_t.shape[0]:
+                 t = t.expand(x_t.shape[0])
+
+             # No inference_mode — matches training convention
+             logits, _ = model(condition, x_t, t, x0_hint=x0_hint)
+
+             logits = logits / max(temperature, 1e-5)
+
+             if repetition_penalty != 1.0:
+                 logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
+             if diversity_penalty > 0.0:
+                 logits = apply_diversity_penalty(logits, diversity_penalty)
+
+             probs = F.softmax(logits, dim=-1)
+             B, L, V = probs.shape
+
+             topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
+             candidates = []
+             for k in range(beam_width):
+                 next_tokens = topk_ids[:, :, k]
+                 score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
+                 candidates.append((next_tokens, score))
+             return candidates
+
+     # ------------------------------------------------------------------ #
+     #  generate_beam — kept for reference; NOT the correct D3PM method    #
+     # ------------------------------------------------------------------ #
+     def generate_beam(
+         self,
+         model,
+         condition,
+         beam_width=3,
+         num_steps=None,
+         temperature=None,
+         repetition_penalty=None,
+         diversity_penalty=None,
+     ):
+         """
+         WARNING: do NOT call this from app1.py for D3PM generation.
+         generate_beam() forces every position to the same top-k token
+         → all-Ṛ / all-rud repetition. Use generate() instead.
+         Kept only for experimental reference.
+         """
+         temperature = temperature if temperature is not None else self.temperature
+         repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
+         diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
+         if num_steps is None:
+             num_steps = self.scheduler.num_timesteps
+
+         device = condition.device
+         if condition.dim() == 1:
+             condition = condition.unsqueeze(0)
+         B, L = condition.shape
+
+         x_init = torch.full((B, L), fill_value=model.mask_token_id,
+                             dtype=torch.long, device=device)
+         beams = [(x_init, 0.0)]
+         best_hint = None
+
+         for step in reversed(range(num_steps)):
+             t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
+             new_beams = []
+             for x_t, score in beams:
+                 candidates = self.p_sample_step(
+                     model, x_t, t_tensor, condition,
+                     beam_width=beam_width,
+                     temperature=temperature,
+                     repetition_penalty=repetition_penalty,
+                     diversity_penalty=diversity_penalty,
+                     x0_hint=best_hint,
+                 )
+                 for tokens, new_score in candidates:
+                     new_beams.append((tokens, score + new_score.item()))
+
+             new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
+             beams = new_beams[:beam_width]
+             best_hint = beams[0][0]
+
+         return beams[0][0]  # (B, L)
+
+
+ # ── Penalty helpers ────────────────────────────────────────────────────────
+
+ def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
+     """Down-weight tokens already present in the sequence."""
+     for b in range(logits.shape[0]):
+         for token_id in set(prev_tokens[b].tolist()):
+             if token_id > 4:
+                 logits[b, :, token_id] = logits[b, :, token_id] / penalty
+     return logits
+
+
+ def apply_diversity_penalty(logits, penalty=0.3):
+     """
+     Correct diversity penalty: suppress globally dominant tokens.
+     logits -= penalty * mean(logits, dim=1)   [sequence dimension]
+     """
+     global_mean = logits.mean(dim=1, keepdim=True)  # [B, 1, V]
+     return logits - penalty * global_mean
+
+
+ def top_k_filter(logits, k):
+     B, L, V = logits.shape
+     if k >= V:
+         return logits
+     topk_vals, _ = torch.topk(logits, k, dim=-1)
+     return logits.masked_fill(logits < topk_vals[..., -1].unsqueeze(-1), float('-inf'))
+
+
+ def batch_multinomial(probs):
+     B, L, V = probs.shape
+     flat = probs.view(B * L, V) + 1e-9
+     flat = flat / flat.sum(dim=-1, keepdim=True)
+     return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
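The timestep subsampling used by `generate()` (and `run_inference()`) is worth seeing with concrete numbers. A pure-Python walk-through of the same arithmetic, with T=64 and 16 requested steps:

```python
T = 64                       # scheduler.num_timesteps
num_steps = 16               # requested refinement steps
step_size = max(1, T // num_steps)           # 4
timesteps = list(range(T - 1, -1, -step_size))
if timesteps[-1] != 0:
    timesteps.append(0)      # always finish with a deterministic t=0 pass
print(timesteps)             # [63, 59, 55, ..., 7, 3, 0]
```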
diffusion/scheduler.py ADDED
@@ -0,0 +1,34 @@
+ """
+ scheduler.py — Fixed & Upgraded
+ ==================================
+ Changes:
+ 1. T=64 (was 16). More timesteps = richer denoising curriculum per epoch.
+ 2. alpha at t=0 is EXACTLY 1.0 — fixes Bug 2 (final-step re-noise).
+ 3. sample_timestep samples [0, T-1] including t=0, so the model trains on
+    fully-clean inputs (it learns the identity at t=0 explicitly).
+ """
+ import math
+
+ import torch
+
+ class OptimizedCosineScheduler:
+     def __init__(self, cfg, device=None):
+         self.num_timesteps = cfg['model']['diffusion_steps']  # 64
+         self.mask_token_id = cfg['diffusion']['mask_token_id']
+         self.device = device or torch.device('cpu')
+         self.alphas_cumprod = self._build_schedule().to(self.device)
+
+     def _build_schedule(self):
+         T = self.num_timesteps
+         t = torch.arange(T + 1, dtype=torch.float32)
+         f_t = torch.cos((t / T + 0.008) / 1.008 * math.pi / 2) ** 2
+         alphas_bar = f_t / f_t[0]
+         alphas_bar = alphas_bar[1:]                # shape [T]
+         alphas_bar[0] = 1.0                        # FIX: exact 1.0 at t=0
+         alphas_bar[-1] = alphas_bar[-1].clamp(max=0.001)
+         return alphas_bar
+
+     def sample_timestep(self, batch_size):
+         """Uniform [0, T-1] — includes t=0 so the model sees clean inputs."""
+         return torch.randint(0, self.num_timesteps, (batch_size,))
+
+     def get_alpha(self, t):
+         return self.alphas_cumprod[t.to(self.alphas_cumprod.device).long()]
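A quick check of the two endpoint fixes described in the docstring, using a minimal config dict and assuming the repo is importable:

```python
import torch

from diffusion.scheduler import OptimizedCosineScheduler

cfg = {"model": {"diffusion_steps": 64}, "diffusion": {"mask_token_id": 0}}
sched = OptimizedCosineScheduler(cfg)

print(sched.get_alpha(torch.tensor([0])))   # tensor([1.]) : nothing masked at t=0
print(sched.get_alpha(torch.tensor([63])))  # <= 0.001 : almost fully masked at t=T-1
print(sched.sample_timestep(4))             # four uniform draws from [0, 63]
```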
handler.py ADDED
@@ -0,0 +1,30 @@
+ from typing import Any, Dict
+
+ from inference_api import predict
+
+
+ class EndpointHandler:
+     """
+     Hugging Face Inference Endpoint handler.
+     Expects payload:
+         {
+           "inputs": "dharmo rakṣati rakṣitaḥ",
+           "parameters": {"temperature": 0.7, ...}
+         }
+     """
+
+     def __init__(self, path: str = ""):
+         self.path = path
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         text = data.get("inputs", "")
+         params = data.get("parameters", {}) or {}
+         return predict(
+             text=text,
+             temperature=params.get("temperature", 0.7),
+             top_k=params.get("top_k", 40),
+             repetition_penalty=params.get("repetition_penalty", 1.2),
+             diversity_penalty=params.get("diversity_penalty", 0.0),
+             num_steps=params.get("num_steps", 64),
+             clean_output=params.get("clean_output", True),
+         )
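The handler can be exercised locally before deploying, a sketch assuming `best_model.pt` and both tokenizer JSON files sit in the working directory (the same layout the endpoint container sees):

```python
from handler import EndpointHandler

handler = EndpointHandler()
result = handler({
    "inputs": "dharmo rakṣati rakṣitaḥ",
    "parameters": {"temperature": 0.7, "num_steps": 64},
})
print(result["output"])       # cleaned Devanagari text
print(result["raw_output"])   # before _clean_text post-processing
```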
inference.py ADDED
@@ -0,0 +1,122 @@
+ import copy
+
+ import torch
+ import torch.nn.functional as F
+
+ from config import CONFIG
+
+
+ def _resolve_device(cfg: dict) -> torch.device:
+     requested = cfg["training"]["device"]
+     if requested == "cuda" and not torch.cuda.is_available():
+         requested = "cpu"
+     if requested == "mps" and not torch.backends.mps.is_available():
+         requested = "cpu"
+     cfg["training"]["device"] = requested
+     return torch.device(requested)
+
+
+ def _build_tokenizers(cfg):
+     from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
+
+     src_tok = SanskritSourceTokenizer(
+         vocab_size=cfg["model"].get("src_vocab_size", 16000),
+         max_len=cfg["model"]["max_seq_len"],
+     )
+     tgt_tok = SanskritTargetTokenizer(
+         vocab_size=cfg["model"].get("tgt_vocab_size", 16000),
+         max_len=cfg["model"]["max_seq_len"],
+     )
+     return src_tok, tgt_tok
+
+
+ def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
+     from model.sanskrit_model import SanskritModel
+
+     cfg = copy.deepcopy(base_cfg)
+     state = torch.load(ckpt_path, map_location="cpu")
+
+     # Derive architecture hyperparameters from checkpoint tensor shapes
+     emb_key = "model.src_embed.token_emb.weight"
+     if emb_key in state:
+         vocab, d_model = state[emb_key].shape
+         cfg["model"]["src_vocab_size"] = vocab
+         cfg["model"]["d_model"] = d_model
+         cfg["model"]["d_ff"] = d_model * 4
+
+     layer_ids = {int(k.split(".")[2]) for k in state if k.startswith("model.encoder_blocks.")}
+     if layer_ids:
+         cfg["model"]["n_layers"] = max(layer_ids) + 1
+
+     pos_key = "model.src_embed.pos_enc.pe"
+     if pos_key in state:
+         cfg["model"]["max_seq_len"] = state[pos_key].shape[1]
+
+     d_model = cfg["model"]["d_model"]
+     n_heads = cfg["model"].get("n_heads", 8)
+     if d_model % n_heads != 0:
+         n_heads = next(h for h in [8, 6, 4, 2, 1] if d_model % h == 0)
+     cfg["model"]["n_heads"] = n_heads
+
+     model = SanskritModel(cfg).to(device)
+     model.load_state_dict(torch.load(ckpt_path, map_location=device), strict=False)
+     model.eval()
+     return model, cfg
+
+
+ def run_inference(model, input_ids, cfg):
+     # Hoisted out of the loop: these helpers are needed once, not per step
+     from model.d3pm_model_cross_attention import (
+         _apply_diversity_penalty_fixed,
+         _apply_repetition_penalty,
+         _batch_multinomial,
+         _top_k_filter,
+     )
+
+     inf = cfg["inference"]
+     device = input_ids.device
+     bsz, seqlen = input_ids.shape
+     inner = model.model
+
+     total_steps = inner.scheduler.num_timesteps
+     steps = int(inf["num_steps"])
+     step_size = max(1, total_steps // max(steps, 1))
+     timesteps = list(range(total_steps - 1, -1, -step_size))
+     if timesteps[-1] != 0:
+         timesteps.append(0)
+
+     x0_est = torch.full((bsz, seqlen), inner.mask_token_id, dtype=torch.long, device=device)
+     hint = None
+
+     with torch.no_grad():
+         for i, t_val in enumerate(timesteps):
+             is_last = i == len(timesteps) - 1
+             t = torch.full((bsz,), t_val, dtype=torch.long, device=device)
+
+             logits, _ = model(input_ids, x0_est, t, x0_hint=hint, inference_mode=True)
+
+             if inf["repetition_penalty"] != 1.0:
+                 logits = _apply_repetition_penalty(logits, x0_est, float(inf["repetition_penalty"]))
+             if inf["diversity_penalty"] > 0.0:
+                 logits = _apply_diversity_penalty_fixed(logits, float(inf["diversity_penalty"]))
+
+             logits = logits / max(float(inf["temperature"]), 1e-5)
+             if int(inf["top_k"]) > 0:
+                 logits = _top_k_filter(logits, int(inf["top_k"]))
+
+             probs = F.softmax(logits, dim=-1)
+             if is_last:
+                 x0_est = torch.argmax(probs, dim=-1)
+             else:
+                 x0_est = _batch_multinomial(probs)
+             hint = x0_est
+
+     return x0_est
+
+
+ __all__ = [
+     "CONFIG",
+     "_resolve_device",
+     "_build_tokenizers",
+     "load_model",
+     "run_inference",
+ ]
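`load_model` recovers d_model, vocab size, layer count, and max sequence length from checkpoint tensor shapes rather than trusting the config. A toy illustration of that pattern with a fabricated state dict (key names mirror the ones probed above; no real checkpoint needed):

```python
import torch

# Fabricated state dict with the key layout load_model probes.
state = {
    "model.src_embed.token_emb.weight": torch.zeros(16000, 384),
    "model.encoder_blocks.0.norm1.weight": torch.zeros(384),
    "model.encoder_blocks.5.norm1.weight": torch.zeros(384),
    "model.src_embed.pos_enc.pe": torch.zeros(1, 80, 384),
}

vocab, d_model = state["model.src_embed.token_emb.weight"].shape
n_layers = max(int(k.split(".")[2]) for k in state
               if k.startswith("model.encoder_blocks.")) + 1
max_seq_len = state["model.src_embed.pos_enc.pe"].shape[1]
print(vocab, d_model, n_layers, max_seq_len)   # 16000 384 6 80
```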
inference_api.py ADDED
@@ -0,0 +1,103 @@
+ import copy
+ from typing import Dict, Any
+
+ import torch
+
+ from config import CONFIG
+ from inference import _build_tokenizers, _resolve_device, load_model, run_inference
+
+
+ _STATE = {
+     "loaded": False,
+     "model": None,
+     "cfg": None,
+     "device": None,
+     "src_tok": None,
+     "tgt_tok": None,
+ }
+
+
+ def _load_once() -> None:
+     if _STATE["loaded"]:
+         return
+
+     cfg = copy.deepcopy(CONFIG)
+     cfg["model_type"] = "d3pm_cross_attention"
+     cfg["data"]["include_negative_examples"] = True
+     device = _resolve_device(cfg)
+
+     model, cfg = load_model("best_model.pt", cfg, device)
+     src_tok, tgt_tok = _build_tokenizers(cfg)
+
+     _STATE["model"] = model
+     _STATE["cfg"] = cfg
+     _STATE["device"] = device
+     _STATE["src_tok"] = src_tok
+     _STATE["tgt_tok"] = tgt_tok
+     _STATE["loaded"] = True
+
+
+ def _clean_text(text: str) -> str:
+     text = " ".join(text.split())
+     if not text:
+         return text
+     toks = text.split()
+     out = []
+     prev = None
+     run = 0
+     for tok in toks:
+         if tok == prev:
+             run += 1
+         else:
+             prev = tok
+             run = 1
+         if run <= 2:  # collapse runs of a repeated token to at most two
+             out.append(tok)
+     s = " ".join(out)
+     s = s.replace(" ।", "।").replace(" ॥", "॥")
+     return " ".join(s.split())
+
+
+ def predict(
+     text: str,
+     temperature: float = 0.7,
+     top_k: int = 40,
+     repetition_penalty: float = 1.2,
+     diversity_penalty: float = 0.0,
+     num_steps: int = 64,
+     clean_output: bool = True,
+ ) -> Dict[str, Any]:
+     _load_once()
+     if not text or not text.strip():
+         return {"error": "empty input", "output": ""}
+
+     cfg = copy.deepcopy(_STATE["cfg"])
+     cfg["inference"]["temperature"] = float(temperature)
+     cfg["inference"]["top_k"] = int(top_k)
+     cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
+     cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
+     cfg["inference"]["num_steps"] = int(num_steps)
+
+     src_tok = _STATE["src_tok"]
+     tgt_tok = _STATE["tgt_tok"]
+     device = _STATE["device"]
+
+     input_ids = torch.tensor([src_tok.encode(text.strip())], dtype=torch.long, device=device)
+     out = run_inference(_STATE["model"], input_ids, cfg)
+     decoded_ids = [x for x in out[0].tolist() if x > 4]  # drop special-token ids
+     raw = tgt_tok.decode(decoded_ids).strip()
+     output = _clean_text(raw) if clean_output else raw
+
+     return {
+         "input": text,
+         "output": output,
+         "raw_output": raw,
+         "config": {
+             "temperature": float(temperature),
+             "top_k": int(top_k),
+             "repetition_penalty": float(repetition_penalty),
+             "diversity_penalty": float(diversity_penalty),
+             "num_steps": int(num_steps),
+             "clean_output": bool(clean_output),
+         },
+     }
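`_clean_text` is pure string processing and can be tried in isolation (importing it does not load the model, only the surrounding modules):

```python
from inference_api import _clean_text

# Collapses runs of a repeated token to at most two, then fixes danda spacing.
print(_clean_text("मनो मनो मनो निवर्तते ।"))
# -> "मनो मनो निवर्तते।"
```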
model/__init__.py ADDED
File without changes
model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (126 Bytes).
model/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (173 Bytes).
model/__pycache__/d3pm_model_cross_attention.cpython-311.pyc ADDED
Binary file (30.7 kB).
model/__pycache__/d3pm_model_cross_attention.cpython-312.pyc ADDED
Binary file (27.2 kB).
model/__pycache__/d3pm_model_encoder_decoder.cpython-311.pyc ADDED
Binary file (15.5 kB).
model/__pycache__/sanskrit_model.cpython-311.pyc ADDED
Binary file (5.67 kB).
model/__pycache__/sanskrit_model.cpython-312.pyc ADDED
Binary file (5.26 kB).
model/__pycache__/tokenizer.cpython-311.pyc ADDED
Binary file (15.3 kB).
model/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (12.9 kB).
model/__pycache__/tokenizers.cpython-311.pyc ADDED
Binary file (7.94 kB).
model/d3pm_model_cross_attention.py ADDED
@@ -0,0 +1,271 @@
+ """
+ d3pm_model_cross_attention.py — Cross-Script + Generation-Fixed
+ =================================================================
+ INPUT  : quote_text tokens (Roman script, src_vocab_size)
+ OUTPUT : quote_devanagari tokens (Devanagari script, tgt_vocab_size)
+
+ src_embed uses src_vocab_size (Roman BPE)
+ tgt_embed uses tgt_vocab_size (Devanagari BPE)
+ head outputs tgt_vocab_size (predict Devanagari tokens)
+ Weight tying: head <-> tgt_embed only (NOT src_embed)
+
+ Generation bugs fixed:
+ BUG 1 - tgt_pad_mask suppressed during inference
+ BUG 2 - q_sample skipped at t=0
+ BUG 3 - time embedding before hint_gate
+ BUG 4 - diversity penalty uses global mean, not var
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from diffusion.scheduler import OptimizedCosineScheduler
+ from diffusion.forward_process import AbsorbingForwardProcess
+
+
+ class SinusoidalPositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super().__init__()
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len).unsqueeze(1).float()
+         div_term = torch.exp(
+             torch.arange(0, d_model, 2).float() *
+             (-torch.log(torch.tensor(10000.0)) / d_model)
+         )
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         self.register_buffer("pe", pe.unsqueeze(0))
+
+     def forward(self, x):
+         return x + self.pe[:, :x.size(1), :]
+
+
+ class SanskritEmbeddings(nn.Module):
+     def __init__(self, vocab_size, d_model, max_seq_len):
+         super().__init__()
+         self.token_emb = nn.Embedding(vocab_size, d_model)
+         self.pos_enc = SinusoidalPositionalEncoding(d_model, max_seq_len)
+         self.token_embedding = self.token_emb  # alias kept for compatibility
+
+     def forward(self, tokens):
+         return self.pos_enc(self.token_emb(tokens))
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, d_model, n_heads, dropout=0.1):
+         super().__init__()
+         assert d_model % n_heads == 0
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+         self.q_proj = nn.Linear(d_model, d_model)
+         self.k_proj = nn.Linear(d_model, d_model)
+         self.v_proj = nn.Linear(d_model, d_model)
+         self.out_proj = nn.Linear(d_model, d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, q, k, v, mask=None):
+         B, Lq, _ = q.size()
+         Lk = k.size(1)
+         Q = self.q_proj(q).view(B, Lq, self.n_heads, self.head_dim).transpose(1, 2)
+         K = self.k_proj(k).view(B, Lk, self.n_heads, self.head_dim).transpose(1, 2)
+         V = self.v_proj(v).view(B, Lk, self.n_heads, self.head_dim).transpose(1, 2)
+         scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
+         if mask is not None:
+             scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
+         attn = self.dropout(torch.softmax(scores, dim=-1))
+         out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, Lq, self.d_model)
+         return self.out_proj(out)
+
+
+ class EncoderBlock(nn.Module):
+     def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
+         super().__init__()
+         self.mha = MultiHeadAttention(d_model, n_heads, dropout)
+         self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
+                                 nn.Linear(d_ff, d_model), nn.Dropout(dropout))
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+
+     def forward(self, x, pad_mask=None):
+         x = self.norm1(x + self.mha(x, x, x, mask=pad_mask))
+         return self.norm2(x + self.ff(x))
+
+
+ class DecoderBlock(nn.Module):
+     def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
+         super().__init__()
+         self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
+         self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
+         self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
+                                 nn.Linear(d_ff, d_model), nn.Dropout(dropout))
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+
+     def forward(self, x, memory, tgt_pad_mask=None, src_pad_mask=None):
+         x = self.norm1(x + self.self_attn(x, x, x, mask=tgt_pad_mask))
+         x = self.norm2(x + self.cross_attn(x, memory, memory, mask=src_pad_mask))
+         return self.norm3(x + self.ff(x))
+
+
+ class D3PMCrossAttention(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         self.cfg = cfg
+         self.mask_token_id = cfg['diffusion']['mask_token_id']
+         d = cfg['model']['d_model']
+         nhead = cfg['model']['n_heads']
+         d_ff = cfg['model']['d_ff']
+         drop = cfg['model']['dropout']
+         seqlen = cfg['model']['max_seq_len']
+         nlayer = cfg['model']['n_layers']
+         src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
+         tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
+
+         # Separate embeddings: Roman src, Devanagari tgt
+         self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
+         self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)
+
+         self.scheduler = OptimizedCosineScheduler(cfg)
+         self.forward_process = AbsorbingForwardProcess(self.scheduler)
+
+         self.encoder_blocks = nn.ModuleList([EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
+         self.decoder_blocks = nn.ModuleList([DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
+
+         self.time_mlp = nn.Sequential(nn.Linear(1, d // 4), nn.SiLU(), nn.Linear(d // 4, d))
+         self.hint_gate = nn.Sequential(nn.Linear(d, d), nn.Sigmoid())
+
+         # Output head: predict Devanagari tokens, tied to tgt_embed
+         self.head = nn.Linear(d, tgt_vocab, bias=False)
+         self.head.weight = self.tgt_embed.token_embedding.weight
+
+     def forward(self, src, tgt, t, x0_hint=None, inference_mode=False):
+         PAD = 1
+         src_pad_mask = (src == PAD)
+         # BUG 1 FIX: no tgt mask during inference
+         tgt_pad_mask = None if inference_mode else (tgt == PAD)
+
+         # Encode Roman source
+         memory = self.src_embed(src)
+         for block in self.encoder_blocks:
+             memory = block(memory, pad_mask=src_pad_mask)
+
+         # BUG 2 FIX: skip q_sample at final step t=0
+         if inference_mode and (t == 0).all():
+             x_t_ids = tgt
+         else:
+             _, x_t_ids = self.forward_process.q_sample(tgt, t)
+
+         x = self.tgt_embed(x_t_ids)
+
+         # BUG 3 FIX: time embedding BEFORE hint gate
+         t_norm = t.float() / self.scheduler.num_timesteps
+         t_emb = self.time_mlp(t_norm.unsqueeze(-1))
+         x = x + t_emb.unsqueeze(1)
+
+         if x0_hint is not None:
+             hint_emb = self.tgt_embed(x0_hint)
+             gate = self.hint_gate(x)  # time-aware gate
+             x = x + gate * hint_emb
+
+         for block in self.decoder_blocks:
+             x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)
+
+         return self.head(x), None
+
+     @torch.no_grad()
+     def generate(self, src, num_steps=None, temperature=0.8, top_k=50,
+                  repetition_penalty=1.2, diversity_penalty=0.0):
+         if src.dim() == 1:
+             src = src.unsqueeze(0)
+         device = src.device
+         B, L = src.shape
+         T = self.scheduler.num_timesteps
+         steps = num_steps or T
+         step_size = max(1, T // steps)
+         timesteps = list(range(T - 1, -1, -step_size))
+         if timesteps[-1] != 0:
+             timesteps.append(0)
+
+         mask_id = self.mask_token_id
+         x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
+         hint = None
+
+         self.eval()
+         with torch.no_grad():
+             for step_idx, t_val in enumerate(timesteps):
+                 t = torch.full((B,), t_val, dtype=torch.long, device=device)
+                 is_last = (step_idx == len(timesteps) - 1)
+                 logits, _ = self.forward(src, x0_est, t, x0_hint=hint, inference_mode=True)
+                 if repetition_penalty != 1.0:
+                     logits = _apply_repetition_penalty(logits, x0_est, repetition_penalty)
+                 if diversity_penalty > 0.0:
+                     logits = _apply_diversity_penalty_fixed(logits, diversity_penalty)  # BUG 4 FIX
+                 logits = logits / max(temperature, 1e-5)
+                 if top_k > 0:
+                     logits = _top_k_filter(logits, top_k)
+                 probs = F.softmax(logits, dim=-1)
+                 x0_est = torch.argmax(probs, dim=-1) if is_last else _batch_multinomial(probs)
+                 hint = x0_est
+         return x0_est
+
+
+ class BaselineCrossAttention(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         d = cfg['model']['d_model']
+         nhead = cfg['model']['n_heads']
+         d_ff = cfg['model']['d_ff']
+         drop = cfg['model']['dropout']
+         seqlen = cfg['model']['max_seq_len']
+         nlayer = cfg['model']['n_layers']
+         src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
+         tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
+         self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
+         self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)
+         self.encoder_blocks = nn.ModuleList([EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
+         self.decoder_blocks = nn.ModuleList([DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
+         self.head = nn.Linear(d, tgt_vocab, bias=False)
+         self.head.weight = self.tgt_embed.token_embedding.weight
+
+     def forward(self, src, tgt, t=None, x0_hint=None):
+         PAD = 1
+         memory = self.src_embed(src)
+         for b in self.encoder_blocks:
+             memory = b(memory, pad_mask=(src == PAD))
+         x = self.tgt_embed(tgt)
+         for b in self.decoder_blocks:
+             x = b(x, memory, tgt_pad_mask=(tgt == PAD), src_pad_mask=(src == PAD))
+         return (self.head(x),)
+
+     @torch.no_grad()
+     def generate(self, src, max_len=None, start_token_id=2, **kwargs):
+         if max_len is None:
+             max_len = src.size(1)
+         B, device = src.size(0), src.device
+         memory = self.src_embed(src)
+         for b in self.encoder_blocks:
+             memory = b(memory, pad_mask=(src == 1))
+         ys = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
+         for _ in range(max_len):
+             x = self.tgt_embed(ys)
+             for b in self.decoder_blocks:
+                 x = b(x, memory, tgt_pad_mask=None, src_pad_mask=(src == 1))
+             ys = torch.cat([ys, torch.argmax(self.head(x)[:, -1, :], dim=-1, keepdim=True)], dim=1)
+         return ys[:, 1:max_len + 1]
+
+
+ # helpers
+ def _top_k_filter(logits, k):
+     B, L, V = logits.shape
+     if k >= V:
+         return logits
+     topk_vals, _ = torch.topk(logits, k, dim=-1)
+     return logits.masked_fill(logits < topk_vals[..., -1].unsqueeze(-1), float('-inf'))
+
+
+ def _batch_multinomial(probs):
+     B, L, V = probs.shape
+     flat = probs.view(B * L, V) + 1e-9
+     return torch.multinomial(flat / flat.sum(-1, keepdim=True), 1).squeeze(-1).view(B, L)
+
+
+ def _apply_repetition_penalty(logits, prev_tokens, penalty):
+     for b in range(logits.shape[0]):
+         for tid in set(prev_tokens[b].tolist()):
+             if tid > 4:
+                 logits[b, :, tid] = logits[b, :, tid] / penalty
+     return logits
+
+
+ def _apply_diversity_penalty(logits, penalty):  # legacy wrong version
+     return logits + penalty * logits.var(dim=-1, keepdim=True)
+
+
+ def _apply_diversity_penalty_fixed(logits, penalty):  # correct version
+     return logits - penalty * logits.mean(dim=1, keepdim=True)
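The head/tgt_embed weight tying and the two-vocabulary layout can be verified with a tiny config; a sketch assuming the repo root is on PYTHONPATH (all hyperparameter values below are arbitrary demo choices):

```python
import torch

from model.d3pm_model_cross_attention import D3PMCrossAttention

cfg = {
    "diffusion": {"mask_token_id": 0},
    "model": {"d_model": 64, "n_heads": 4, "d_ff": 256, "dropout": 0.1,
              "max_seq_len": 16, "n_layers": 2, "diffusion_steps": 8,
              "src_vocab_size": 100, "tgt_vocab_size": 120},
}
m = D3PMCrossAttention(cfg)

# Head is tied to the target embedding, not the source embedding.
print(m.head.weight is m.tgt_embed.token_emb.weight)   # True

src = torch.randint(5, 100, (2, 16))   # Roman-side ids
tgt = torch.randint(5, 120, (2, 16))   # Devanagari-side ids
t = torch.tensor([3, 3])
logits, _ = m(src, tgt, t)
print(logits.shape)   # torch.Size([2, 16, 120]) : tgt_vocab_size logits
```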
model/d3pm_model_encoder_decoder.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+import torch
+import torch.nn as nn
+from diffusion.scheduler import OptimizedCosineScheduler
+from diffusion.forward_process import AbsorbingForwardProcess
+# Import shared classes to guarantee identical architectures
+from model.d3pm_model_cross_attention import SanskritEmbeddings, EncoderBlock, MultiHeadAttention
+
+
+class DecoderBlock(nn.Module):
+    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
+        super().__init__()
+        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
+        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)  # ← restored
+        self.ff = nn.Sequential(
+            nn.Linear(d_model, d_ff),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(d_ff, d_model),
+            nn.Dropout(dropout),
+        )
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)  # ← restored (third norm, for the feed-forward residual)
+
+    def forward(self, x, memory, tgt_pad_mask=None):
+        # 1. Self-attention on target (pad mask only — diffusion decoding is bidirectional)
+        x = self.norm1(x + self.self_attn(x, x, x, mask=tgt_pad_mask))
+        # 2. Cross-attention: queries from decoder, keys/values from encoder memory
+        #    (no source pad mask is applied here — kept as trained)
+        x = self.norm2(x + self.cross_attn(x, memory, memory))
+        # 3. Feed-forward
+        return self.norm3(x + self.ff(x))
+
+
+class DecoderBlockNoCrossAttn(nn.Module):
+    """Kept for reference — NOT used by D3PMEncoderDecoder."""
+    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
+        super().__init__()
+        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
+        self.ff = nn.Sequential(
+            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
+            nn.Linear(d_ff, d_model), nn.Dropout(dropout),
+        )
+        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
+
+    def forward(self, x, tgt_pad_mask=None, causal_mask=None):
+        combined_mask = None
+        if tgt_pad_mask is not None and causal_mask is not None:
+            combined_mask = tgt_pad_mask | causal_mask
+        elif causal_mask is not None:
+            combined_mask = causal_mask
+        elif tgt_pad_mask is not None:
+            combined_mask = tgt_pad_mask
+        x = self.norm1(x + self.self_attn(x, x, x, mask=combined_mask))
+        return self.norm2(x + self.ff(x))
+
+
+# ============================================================
+# 1. D3PM Encoder-Decoder Model
+# ============================================================
+class D3PMEncoderDecoder(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+        self.mask_token_id = cfg['diffusion']['mask_token_id']
+
+        src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
+        tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
+        d_model = cfg['model']['d_model']
+        n_heads = cfg['model']['n_heads']
+        d_ff = cfg['model']['d_ff']
+        dropout = cfg['model']['dropout']
+        n_layers = cfg['model']['n_layers']
+        max_len = cfg['model']['max_seq_len']
+
+        self.src_embed = SanskritEmbeddings(src_vocab, d_model, max_len)
+        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d_model, max_len)
+
+        self.scheduler = OptimizedCosineScheduler(cfg)
+        self.forward_process = AbsorbingForwardProcess(self.scheduler)
+
+        self.encoder_blocks = nn.ModuleList([
+            EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
+        ])
+        # DecoderBlock now has cross-attention — matches the saved checkpoint
+        self.decoder_blocks = nn.ModuleList([
+            DecoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
+        ])
+
+        self.time_mlp = nn.Sequential(
+            nn.Linear(1, d_model // 4), nn.SiLU(),
+            nn.Linear(d_model // 4, d_model),
+        )
+        self.head = nn.Linear(d_model, tgt_vocab)
+        self.head.weight = self.tgt_embed.token_embedding.weight  # weight tying
+
+    def forward(self, src, tgt, t, x0_hint=None):
+        src_pad_mask = (src == 1)
+        tgt_pad_mask = (tgt == 1)
+
+        # Encode source (Roman IAST)
+        memory = self.src_embed(src)
+        for block in self.encoder_blocks:
+            memory = block(memory, pad_mask=src_pad_mask)
+
+        # Corrupt target with the forward diffusion process
+        _, x_t_ids = self.forward_process.q_sample(tgt, t)
+
+        # Optionally blend in x0_hint (self-conditioning): on average, half
+        # of the still-masked positions are replaced with the previous estimate
+        if x0_hint is not None:
+            hint_prob = 0.5
+            blend_mask = (torch.rand(x_t_ids.shape, device=x_t_ids.device) < hint_prob)
+            still_mask = (x_t_ids == self.mask_token_id)
+            x_t_ids = torch.where(blend_mask & still_mask, x0_hint, x_t_ids)
+
+        x = self.tgt_embed(x_t_ids)
+        t_emb = self.time_mlp(t.float().unsqueeze(-1)).unsqueeze(1)
+        x = x + t_emb.expand(-1, tgt.shape[1], -1)
+
+        # Decode with cross-attention over encoder memory
+        for block in self.decoder_blocks:
+            x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
+
+        return self.head(x), None
+
+    @torch.no_grad()
+    def generate(
+        self,
+        src,
+        num_steps=None,
+        temperature=0.75,
+        top_k=50,
+        repetition_penalty=1.15,
+        diversity_penalty=0.0,
+    ):
+        """
+        Iterative D3PM reverse diffusion — same signature as
+        D3PMCrossAttention.generate() so SanskritModel.generate() works
+        identically for both model types.
+        """
+        device = src.device
+        B, L = src.shape[0], self.cfg['model']['max_seq_len']
+        T = num_steps or self.scheduler.num_timesteps
+        mask_id = self.mask_token_id
+        pad_id = 1
+
+        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
+
+        for step in range(T - 1, -1, -1):
+            t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
+            hint = x0_est.clone()
+
+            logits, _ = self.forward(src, x0_est, t_tensor, x0_hint=hint)
+
+            # Repetition penalty
+            if repetition_penalty != 1.0:
+                for b in range(B):
+                    for tok in set(x0_est[b].tolist()):
+                        if tok > pad_id:
+                            logits[b, :, tok] /= repetition_penalty
+
+            # Diversity penalty (suppress tokens favoured at every position)
+            if diversity_penalty > 0.0:
+                logits = logits - diversity_penalty * logits.mean(dim=1, keepdim=True)
+
+            # Temperature + top-k sampling
+            logits = logits / max(temperature, 1e-8)
+            if top_k > 0:
+                vals, _ = torch.topk(logits, top_k, dim=-1)
+                logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
+
+            probs = torch.softmax(logits, dim=-1)
+            # Only update positions that are still masked
+            still = (x0_est == mask_id)
+            sample = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(B, L)
+            x0_est = torch.where(still, sample, x0_est)
+
+        return x0_est
+
+
+# ============================================================
+# 2. Baseline Encoder-Decoder Model (unchanged)
+# ============================================================
+class BaselineEncoderDecoder(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+        self.src_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
+                                            cfg['model']['max_seq_len'])
+        self.tgt_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
+                                            cfg['model']['max_seq_len'])
+        self.encoder_blocks = nn.ModuleList([
+            EncoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
+                         cfg['model']['d_ff'], cfg['model']['dropout'])
+            for _ in range(cfg['model']['n_layers'])
+        ])
+        self.decoder_blocks = nn.ModuleList([
+            DecoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
+                         cfg['model']['d_ff'], cfg['model']['dropout'])
+            for _ in range(cfg['model']['n_layers'])
+        ])
+        self.head = nn.Linear(cfg['model']['d_model'], cfg['model']['vocab_size'])
+        self.head.weight = self.tgt_embed.token_embedding.weight
+
+    def forward(self, src, tgt):
+        src_pad_mask, tgt_pad_mask = (src == 1), (tgt == 1)
+        memory = self.src_embed(src)
+        for block in self.encoder_blocks:
+            memory = block(memory, pad_mask=src_pad_mask)
+        x = self.tgt_embed(tgt)
+        for block in self.decoder_blocks:
+            x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
+        return self.head(x)
+
+    @torch.no_grad()
+    def generate(self, src, max_len=80, start_token_id=2):
+        batch_size, device = src.size(0), src.device
+        src_pad_mask = (src == 1)
+        memory = self.src_embed(src)
+        for block in self.encoder_blocks:
+            memory = block(memory, pad_mask=src_pad_mask)
+        ys = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
+        for _ in range(max_len):
+            x = self.tgt_embed(ys)
+            for block in self.decoder_blocks:
+                x = block(x, memory, tgt_pad_mask=None)
+            logits = self.head(x)
+            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
+            ys = torch.cat([ys, next_token], dim=1)
+        return ys[:, 1:]
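
The D3PMEncoderDecoder.generate() loop above walks the timestep from T-1 down to 0 and commits tokens via the final torch.where, so already-decoded positions are never overwritten by a later step. A standalone illustration of that update rule (editor's sketch, independent of the model; the tensor values are made up):

    import torch

    mask_id = 0
    x0_est = torch.tensor([[0, 42, 0, 7]])            # positions 0 and 2 still masked
    probs = torch.full((1, 4, 10), 0.1)               # dummy uniform distribution, V=10
    sample = torch.multinomial(probs.view(-1, 10), 1).view(1, 4)
    still = (x0_est == mask_id)
    x0_est = torch.where(still, sample, x0_est)       # 42 and 7 survive unchanged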
model/sanskrit_model.py ADDED
@@ -0,0 +1,61 @@
+"""
+sanskrit_model.py — Fixed
+=========================
+Adds an inference_mode parameter to forward() so reverse_process.py can
+pass inference_mode=True without a TypeError.
+
+The wrapper introspects each inner model's signature and only passes
+kwargs that model actually accepts — safe across all four architectures.
+"""
+
+import torch
+import torch.nn as nn
+import inspect
+
+
+class SanskritModel(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        model_type = cfg['model_type']
+
+        if model_type == 'd3pm_cross_attention':
+            from model.d3pm_model_cross_attention import D3PMCrossAttention
+            self.model = D3PMCrossAttention(cfg)
+
+        elif model_type == 'd3pm_encoder_decoder':
+            from model.d3pm_model_encoder_decoder import D3PMEncoderDecoder
+            self.model = D3PMEncoderDecoder(cfg)
+
+        elif model_type == 'baseline_cross_attention':
+            from model.d3pm_model_cross_attention import BaselineCrossAttention
+            self.model = BaselineCrossAttention(cfg)
+
+        elif model_type == 'baseline_encoder_decoder':
+            from model.d3pm_model_encoder_decoder import BaselineEncoderDecoder
+            self.model = BaselineEncoderDecoder(cfg)
+
+        else:
+            raise ValueError(f"Unknown model_type: {model_type}")
+
+    def forward(self, input_ids, target_ids, t, x0_hint=None, inference_mode=False):
+        """
+        Forward pass. Introspects the inner model's signature so only
+        supported kwargs are passed — works with all four architectures.
+        """
+        sig = inspect.signature(self.model.forward).parameters
+        kwargs = {}
+        if 'x0_hint' in sig:
+            kwargs['x0_hint'] = x0_hint
+        if 'inference_mode' in sig:
+            kwargs['inference_mode'] = inference_mode
+
+        if 't' in sig:
+            return self.model(input_ids, target_ids, t, **kwargs)
+        else:
+            return self.model(input_ids, target_ids, **kwargs)
+
+    @torch.no_grad()
+    def generate(self, src, **kwargs):
+        sig = inspect.signature(self.model.generate).parameters
+        filtered = {k: v for k, v in kwargs.items() if k in sig}
+        return self.model.generate(src, **filtered)
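
The signature-introspection trick generalizes: a wrapper can forward only the kwargs its wrapped callable actually declares, so callers never hit a TypeError for an unsupported option. A standalone sketch of the same pattern (editor's illustration; `call_with_supported` and `decode` are hypothetical names, not part of the repo):

    import inspect

    def call_with_supported(fn, *args, **kwargs):
        # Drop any keyword argument the target function does not declare
        params = inspect.signature(fn).parameters
        return fn(*args, **{k: v for k, v in kwargs.items() if k in params})

    def decode(src, temperature=1.0):            # no 'top_k' parameter
        return f"decode(temp={temperature})"

    print(call_with_supported(decode, "src", temperature=0.7, top_k=50))
    # → decode(temp=0.7) — top_k is silently dropped, no TypeError

One caveat of the approach: a target that accepts **kwargs declares no named parameters for them, so extras would be dropped even though it could handle them; that does not arise for the four architectures here.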
model/tokenizer.py ADDED
@@ -0,0 +1,222 @@
+"""
+tokenizer.py — Dual Tokenizer Fix
+=================================
+Two separate BPE tokenizers:
+
+  SanskritSourceTokenizer — trained on quote_text (Roman/IAST script)
+  SanskritTargetTokenizer — trained on quote_devanagari (Devanagari script)
+
+WHY SEPARATE?
+Roman Sanskrit and Devanagari are fundamentally different character sets.
+Roman IAST uses a-z plus diacritics (~60 unique chars); Devanagari uses
+अ-ह plus mātrās (100+ unique chars). A shared BPE tokenizer wastes half
+its vocab on character combinations that never cross scripts, and forces
+the embedding table to encode both scripts in one space — confusing the
+model's cross-attention.
+
+With separate tokenizers:
+  - the src vocab captures Roman subwords cleanly (ā, ś, ṭ, ṃ, etc.)
+  - the tgt vocab captures Devanagari akshara clusters cleanly (क्ष, त्र, etc.)
+  - the model learns a true cross-script mapping in its cross-attention
+
+SPECIAL TOKENS (same IDs in both):
+  [MASK] = 0   ← required by absorbing diffusion
+  [PAD]  = 1
+  [UNK]  = 2
+  [CLS]  = 3
+  [SEP]  = 4
+"""
+
+from tokenizers import Tokenizer
+from tokenizers.models import BPE
+from tokenizers.trainers import BpeTrainer
+from tokenizers.pre_tokenizers import Whitespace
+from datasets import load_dataset
+from pathlib import Path
+
+
+SPECIAL_TOKENS = ["[MASK]", "[PAD]", "[UNK]", "[CLS]", "[SEP]"]
+
+
+def _build_bpe(texts, vocab_size):
+    """Build a BPE tokenizer from an iterable of strings."""
+    tok = Tokenizer(BPE(unk_token="[UNK]"))
+    tok.pre_tokenizer = Whitespace()
+    trainer = BpeTrainer(
+        vocab_size=vocab_size,
+        special_tokens=SPECIAL_TOKENS,  # [MASK] MUST be first → id=0
+        min_frequency=2,
+    )
+    tok.train_from_iterator(texts, trainer)
+    return tok
+
+
+def _validate(tok, name):
+    mask_id = tok.token_to_id("[MASK]")
+    pad_id = tok.token_to_id("[PAD]")
+    assert mask_id == 0, f"{name}: [MASK] must be id=0, got {mask_id}"
+    assert pad_id == 1, f"{name}: [PAD] must be id=1, got {pad_id}"
+    print(f"✅ {name}: [MASK]=0, [PAD]=1 confirmed. Vocab size={tok.get_vocab_size()}")
+
+
+# ── Source tokenizer (Roman/IAST Sanskrit) ────────────────────────────
+
+class SanskritSourceTokenizer:
+    """
+    Tokenizer for quote_text — Roman transliteration of Sanskrit.
+    Examples: "dharmo rakṣati rakṣitaḥ", "yatra nāryastu pūjyante"
+    """
+    MODEL_PATH = "sanskrit_src_tokenizer.json"
+
+    def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
+        self.vocab_size = vocab_size
+        self.max_len = max_len
+        self.mask_token_id = 0
+
+        if Path(self.MODEL_PATH).exists():
+            print(f"📖 Loading source tokenizer from {self.MODEL_PATH} …")
+            self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
+        else:
+            print("🎓 Training source tokenizer on quote_text …")
+            self._train(vocab_size, n_train_samples)
+
+        _validate(self.tokenizer, "SrcTokenizer")
+
+    def _train(self, vocab_size, n_samples):
+        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
+        n = min(n_samples, len(dataset))
+        texts = [s["quote_text"] for s in dataset.select(range(n))
+                 if s["quote_text"].strip()]
+        self.tokenizer = _build_bpe(texts, vocab_size)
+        self.tokenizer.save(self.MODEL_PATH)
+        print(f"✅ Source tokenizer trained on {len(texts)} Roman texts.")
+
+    def encode(self, text):
+        ids = self.tokenizer.encode(text).ids[:self.max_len]
+        pad = self.tokenizer.token_to_id("[PAD]")
+        ids += [pad] * max(0, self.max_len - len(ids))
+        return ids[:self.max_len]
+
+    def decode(self, ids):
+        clean = [i for i in ids if i > 4]  # skip special tokens (IDs 0-4)
+        return self.tokenizer.decode(clean)
+
+    def __len__(self):
+        return self.vocab_size
+
+
+# ── Target tokenizer (Devanagari Sanskrit) ────────────────────────────
+
+class SanskritTargetTokenizer:
+    """
+    Tokenizer for quote_devanagari — Devanagari script.
+    Examples: "धर्मो रक्षति रक्षितः", "यत्र नार्यस्तु पूज्यन्ते"
+    """
+    MODEL_PATH = "sanskrit_tgt_tokenizer.json"
+
+    def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
+        self.vocab_size = vocab_size
+        self.max_len = max_len
+        self.mask_token_id = 0
+
+        if Path(self.MODEL_PATH).exists():
+            print(f"📖 Loading target tokenizer from {self.MODEL_PATH} …")
+            self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
+        else:
+            print("🎓 Training target tokenizer on quote_devanagari …")
+            self._train(vocab_size, n_train_samples)
+
+        _validate(self.tokenizer, "TgtTokenizer")
+
+    def _train(self, vocab_size, n_samples):
+        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
+        n = min(n_samples, len(dataset))
+        texts = [s["quote_devanagari"] for s in dataset.select(range(n))
+                 if s["quote_devanagari"].strip()]
+        self.tokenizer = _build_bpe(texts, vocab_size)
+        self.tokenizer.save(self.MODEL_PATH)
+        print(f"✅ Target tokenizer trained on {len(texts)} Devanagari texts.")
+
+    def encode(self, text):
+        ids = self.tokenizer.encode(text).ids[:self.max_len]
+        pad = self.tokenizer.token_to_id("[PAD]")
+        ids += [pad] * max(0, self.max_len - len(ids))
+        return ids[:self.max_len]
+
+    def decode(self, ids):
+        clean = [i for i in ids if i > 4]
+        return self.tokenizer.decode(clean)
+
+    # Methods required by BERTScore
+    def build_inputs_with_special_tokens(self, token_ids):
+        return list(token_ids)
+
+    def get_vocab(self):
+        return {str(i): i for i in range(self.vocab_size)}
+
+    def convert_ids_to_tokens(self, ids):
+        return [str(i) for i in ids]
+
+    def __len__(self):
+        return self.vocab_size
+
+
+# ── Legacy shared tokenizer (kept for backward compat) ────────────────
+
+class SanskritTokenizer:
+    """
+    LEGACY: single shared tokenizer trained on BOTH scripts.
+    Still works but suboptimal — use SanskritSourceTokenizer +
+    SanskritTargetTokenizer for the quote_text → quote_devanagari task.
+    """
+    MODEL_PATH = "sanskrit_tokenizer_m4pro.json"
+
+    def __init__(self, vocab_size=16000, max_len=80):
+        self.vocab_size = vocab_size
+        self.max_len = max_len
+        self.mask_token_id = 0
+
+        if Path(self.MODEL_PATH).exists():
+            print("📖 Loading shared tokenizer …")
+            self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
+        else:
+            print("🎓 Training shared tokenizer on both scripts …")
+            self._train(vocab_size)
+
+        _validate(self.tokenizer, "SharedTokenizer")
+
+    def _train(self, vocab_size):
+        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
+        n = min(50000, len(dataset))
+        texts = []
+        for s in dataset.select(range(n)):
+            if s["quote_text"].strip():
+                texts.append(s["quote_text"])
+            if s["quote_devanagari"].strip():
+                texts.append(s["quote_devanagari"])
+        self.tokenizer = _build_bpe(texts, vocab_size)
+        self.tokenizer.save(self.MODEL_PATH)
+        print(f"✅ Shared tokenizer trained ({len(texts)} texts).")
+
+    def encode(self, text):
+        ids = self.tokenizer.encode(text).ids[:self.max_len]
+        pad = self.tokenizer.token_to_id("[PAD]")
+        ids += [pad] * max(0, self.max_len - len(ids))
+        return ids[:self.max_len]
+
+    def decode(self, ids):
+        if ids and isinstance(ids[0], list):
+            raise TypeError("decode() got a 2D list — pass a 1D list of IDs.")
+        clean = [i for i in ids if i > 4]
+        return self.tokenizer.decode(clean)
+
+    def build_inputs_with_special_tokens(self, token_ids):
+        return list(token_ids)
+
+    def get_vocab(self):
+        return {str(i): i for i in range(self.vocab_size)}
+
+    def convert_ids_to_tokens(self, ids):
+        return [str(i) for i in ids]
+
+    def __len__(self):
+        return self.vocab_size
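
Both tokenizers expose the same fixed-length encode()/decode() contract: encode() always returns exactly max_len IDs ([PAD]-padded or truncated), and decode() strips the five special IDs before detokenizing. A usage sketch (editor's addition; run it from the repo root so the relative MODEL_PATH files sanskrit_src_tokenizer.json and sanskrit_tgt_tokenizer.json resolve, otherwise the constructors retrain from the dataset):

    from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer

    src_tok = SanskritSourceTokenizer(max_len=80)
    tgt_tok = SanskritTargetTokenizer(max_len=80)

    ids = src_tok.encode("dharmo rakṣati rakṣitaḥ")   # exactly 80 IDs, [PAD]-padded
    assert len(ids) == 80
    print(src_tok.decode(ids))                        # special tokens (IDs 0-4) stripped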
model/tokenizers.py ADDED
@@ -0,0 +1,112 @@
+"""
+tokenizers.py — FINAL
+=====================
+Uses the original sanskrit_tokenizer_m4pro.json — the exact one the model
+was trained with. Hard-coded absolute path as primary, with fallbacks.
+
+This tokenizer has NO </w> end-of-word markers and NO decoder set.
+decode() returns space-separated BPE pieces — this is the format the
+model was trained and evaluated on (BERTScore 0.71). Do NOT add a decoder
+or retrain: that would break alignment with the checkpoint.
+"""
+
+from tokenizers import Tokenizer
+from tokenizers.models import BPE
+from tokenizers.trainers import BpeTrainer
+from tokenizers.pre_tokenizers import Whitespace
+from datasets import load_dataset
+from pathlib import Path
+import os
+
+# Hard-coded absolute path — update if you move the project
+TOKENIZER_PATH = "/Users/bhsingh/Documents/Final_Paraphrase/sanskrit_tokenizer_m4pro.json"
+
+
+def build_tokenizer(texts, vocab_size=16000):
+    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
+    tokenizer.pre_tokenizer = Whitespace()
+    trainer = BpeTrainer(
+        vocab_size=vocab_size,
+        special_tokens=["[MASK]", "[PAD]", "[UNK]", "[CLS]", "[SEP]"],
+        min_frequency=2,
+    )
+    tokenizer.train_from_iterator(texts, trainer)
+    return tokenizer
+
+
+class SanskritTokenizer:
+    def __init__(self, vocab_size=16000, max_len=80):
+        self.vocab_size = vocab_size
+        self.max_len = max_len
+        self.mask_token_id = 0
+
+        script_dir = Path(__file__).resolve().parent
+        candidates = [
+            os.environ.get("SANSKRIT_TOKENIZER_PATH", ""),
+            TOKENIZER_PATH,
+            str(script_dir.parent / "sanskrit_tokenizer_m4pro.json"),
+            str(script_dir / "sanskrit_tokenizer_m4pro.json"),
+            str(Path.cwd() / "sanskrit_tokenizer_m4pro.json"),
+        ]
+
+        self.model_path = None
+        for c in candidates:
+            if c and Path(c).exists():
+                self.model_path = c
+                break
+
+        if self.model_path:
+            print(f"📖 Loading tokenizer from: {self.model_path}")
+            self.tokenizer = Tokenizer.from_file(self.model_path)
+            self._validate_mask_token()
+        else:
+            print("⚠️ Tokenizer not found at any candidate path.")
+            print(f"   Expected: {TOKENIZER_PATH}")
+            print("   Retraining — WARNING: output will not match the existing checkpoint!")
+            self.model_path = TOKENIZER_PATH
+            self._train_tokenizer()
+
+    def _validate_mask_token(self):
+        mask_id = self.tokenizer.token_to_id("[MASK]")
+        assert mask_id == 0, f"[MASK] must be ID 0, got {mask_id}"
+        print("✅ [MASK] token confirmed at ID=0")
+
+    def _train_tokenizer(self):
+        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
+        texts = []
+        for sample in dataset.select(range(50000)):
+            texts.extend([sample["quote_text"], sample["quote_devanagari"]])
+        tokenizer = build_tokenizer(texts, self.vocab_size)
+        tokenizer.save(self.model_path)
+        self.tokenizer = tokenizer
+        self._validate_mask_token()
+        print(f"✅ Tokenizer saved to: {self.model_path}")
+
+    def encode(self, text):
+        encoded = self.tokenizer.encode(text)
+        token_ids = encoded.ids[:self.max_len]
+        pad_id = self.tokenizer.token_to_id("[PAD]")
+        if len(token_ids) < self.max_len:
+            token_ids += [pad_id] * (self.max_len - len(token_ids))
+        return token_ids[:self.max_len]
+
+    def decode(self, ids):
+        if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
+            raise TypeError("decode() expects a 1D list of IDs, not 2D.")
+        # Filter special tokens: 0=MASK 1=PAD 2=UNK 3=CLS 4=SEP
+        clean = [i for i in ids if isinstance(i, int) and i > 4]
+        if not clean:
+            return ""
+        return self.tokenizer.decode(clean, skip_special_tokens=True).strip()
+
+    def build_inputs_with_special_tokens(self, token_ids):
+        return list(token_ids)
+
+    def get_vocab(self):
+        return {str(i): i for i in range(self.vocab_size)}
+
+    def convert_ids_to_tokens(self, ids):
+        return [str(i) for i in ids]
+
+    def __len__(self):
+        return self.vocab_size
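
Because the environment variable sits first in the candidate list, it overrides every hard-coded and relative fallback — useful when the project is deployed outside the author's machine. A brief sketch (editor's addition; the path shown is hypothetical):

    import os
    # Must be set before the constructor runs so the candidate scan sees it
    os.environ["SANSKRIT_TOKENIZER_PATH"] = "/data/ckpt/sanskrit_tokenizer_m4pro.json"

    from model.tokenizers import SanskritTokenizer
    tok = SanskritTokenizer()   # loads from the env-var path before any fallback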
requirements.txt ADDED
@@ -0,0 +1,6 @@
+torch>=2.2
+numpy>=1.24
+tqdm>=4.66
+datasets>=2.19
+tokenizers>=0.15
+scikit-learn>=1.3
sanskrit_src_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
sanskrit_tgt_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff