Add files using upload-large-folder tool
Browse files- .gitattributes +2 -34
- README.md +65 -0
- best_model.pt +3 -0
- config.py +33 -0
- diffusion/__init__.py +0 -0
- diffusion/__pycache__/__init__.cpython-311.pyc +0 -0
- diffusion/__pycache__/__init__.cpython-312.pyc +0 -0
- diffusion/__pycache__/forward_process.cpython-311.pyc +0 -0
- diffusion/__pycache__/forward_process.cpython-312.pyc +0 -0
- diffusion/__pycache__/reverse_process.cpython-311.pyc +0 -0
- diffusion/__pycache__/reverse_process1.cpython-311.pyc +0 -0
- diffusion/__pycache__/reverse_process2.cpython-311.pyc +0 -0
- diffusion/__pycache__/scheduler.cpython-311.pyc +0 -0
- diffusion/__pycache__/scheduler.cpython-312.pyc +0 -0
- diffusion/forward_process.py +21 -0
- diffusion/reverse_process.py +302 -0
- diffusion/reverse_process1.py +154 -0
- diffusion/reverse_process2.py +275 -0
- diffusion/scheduler.py +34 -0
- handler.py +30 -0
- inference.py +122 -0
- inference_api.py +103 -0
- model/__init__.py +0 -0
- model/__pycache__/__init__.cpython-311.pyc +0 -0
- model/__pycache__/__init__.cpython-312.pyc +0 -0
- model/__pycache__/d3pm_model_cross_attention.cpython-311.pyc +0 -0
- model/__pycache__/d3pm_model_cross_attention.cpython-312.pyc +0 -0
- model/__pycache__/d3pm_model_encoder_decoder.cpython-311.pyc +0 -0
- model/__pycache__/sanskrit_model.cpython-311.pyc +0 -0
- model/__pycache__/sanskrit_model.cpython-312.pyc +0 -0
- model/__pycache__/tokenizer.cpython-311.pyc +0 -0
- model/__pycache__/tokenizer.cpython-312.pyc +0 -0
- model/__pycache__/tokenizers.cpython-311.pyc +0 -0
- model/d3pm_model_cross_attention.py +271 -0
- model/d3pm_model_encoder_decoder.py +227 -0
- model/sanskrit_model.py +61 -0
- model/tokenizer.py +222 -0
- model/tokenizers.py +112 -0
- requirements.txt +6 -0
- sanskrit_src_tokenizer.json +0 -0
- sanskrit_tgt_tokenizer.json +0 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,3 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- sa
|
| 5 |
+
- en
|
| 6 |
+
tags:
|
| 7 |
+
- sanskrit
|
| 8 |
+
- paraphrase
|
| 9 |
+
- diffusion
|
| 10 |
+
- d3pm
|
| 11 |
+
- pytorch
|
| 12 |
+
pipeline_tag: text2text-generation
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# Sanskrit D3PM Paraphrase Model
|
| 16 |
+
|
| 17 |
+
Roman/IAST Sanskrit input to Devanagari output using a D3PM cross-attention model.
|
| 18 |
+
|
| 19 |
+
## Files Included
|
| 20 |
+
|
| 21 |
+
- `best_model.pt` — trained checkpoint
|
| 22 |
+
- `config.py` — runtime config
|
| 23 |
+
- `inference.py` — model loading + generation loop
|
| 24 |
+
- `inference_api.py` — simple Python API (`predict`)
|
| 25 |
+
- `handler.py` — Hugging Face Endpoint handler
|
| 26 |
+
- `model/`, `diffusion/` — architecture modules
|
| 27 |
+
- `sanskrit_src_tokenizer.json`, `sanskrit_tgt_tokenizer.json` — tokenizers
|
| 28 |
+
|
| 29 |
+
## Quick Local Test
|
| 30 |
+
|
| 31 |
+
```python
|
| 32 |
+
from inference_api import predict
|
| 33 |
+
print(predict("dharmo rakṣati rakṣitaḥ")["output"])
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## Endpoint Payload
|
| 37 |
+
|
| 38 |
+
```json
|
| 39 |
+
{
|
| 40 |
+
"inputs": "yadā mano nivarteta viṣayebhyaḥ svabhāvataḥ",
|
| 41 |
+
"parameters": {
|
| 42 |
+
"temperature": 0.7,
|
| 43 |
+
"top_k": 40,
|
| 44 |
+
"repetition_penalty": 1.2,
|
| 45 |
+
"diversity_penalty": 0.0,
|
| 46 |
+
"num_steps": 64,
|
| 47 |
+
"clean_output": true
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## Push This Folder To Model Hub
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
huggingface-cli login
|
| 56 |
+
huggingface-cli repo create <your-username>/sanskrit-d3pm --type model
|
| 57 |
+
cd hf_model_repo
|
| 58 |
+
git init
|
| 59 |
+
git lfs install
|
| 60 |
+
git remote add origin https://huggingface.co/<your-username>/sanskrit-d3pm
|
| 61 |
+
git add .
|
| 62 |
+
git commit -m "Initial model release"
|
| 63 |
+
git push -u origin main
|
| 64 |
+
```
|
| 65 |
+
|
best_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:280b944be1ed396c93f64deef18b07d258b5dd1c74d59284342864a532c95f8b
|
| 3 |
+
size 1077681643
|
config.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch

# Runtime configuration for the Sanskrit D3PM paraphrase model.
# Values must match the checkpoint (best_model.pt) — changing "model" keys
# will break state_dict loading.
CONFIG = {
    "model_type": "d3pm_cross_attention",
    "data": {
        "include_negative_examples": True,
        "dataset_size": 60000,
    },
    "diffusion": {
        # id of the absorbing [MASK] token used by the forward process
        "mask_token_id": 0,
    },
    "model": {
        "src_vocab_size": 16000,
        "tgt_vocab_size": 16000,
        "d_model": 384,
        "n_heads": 8,
        "d_ff": 1536,
        "n_layers": 6,
        "dropout": 0.1,
        "max_seq_len": 80,
        "diffusion_steps": 64,
    },
    "training": {
        # fall back to CPU when CUDA is unavailable
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    },
    "inference": {
        # defaults mirrored in the endpoint payload documented in README.md
        "num_steps": 64,
        "temperature": 0.7,
        "top_k": 40,
        "repetition_penalty": 1.2,
        "diversity_penalty": 0.0,
    },
}
|
diffusion/__init__.py
ADDED
|
File without changes
|
diffusion/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (189 Bytes). View file
|
|
|
diffusion/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (177 Bytes). View file
|
|
|
diffusion/__pycache__/forward_process.cpython-311.pyc
ADDED
|
Binary file (1.75 kB). View file
|
|
|
diffusion/__pycache__/forward_process.cpython-312.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
diffusion/__pycache__/reverse_process.cpython-311.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
diffusion/__pycache__/reverse_process1.cpython-311.pyc
ADDED
|
Binary file (5.37 kB). View file
|
|
|
diffusion/__pycache__/reverse_process2.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
diffusion/__pycache__/scheduler.cpython-311.pyc
ADDED
|
Binary file (2.93 kB). View file
|
|
|
diffusion/__pycache__/scheduler.cpython-312.pyc
ADDED
|
Binary file (2.75 kB). View file
|
|
|
diffusion/forward_process.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
forward_process.py — Verified Correct (no changes needed)
|
| 3 |
+
===========================================================
|
| 4 |
+
Absorbing (mask) diffusion. PAD never masked. At t=0 alpha=1.0 exactly
|
| 5 |
+
so x_t == x_0 (nothing masked). Works correctly with the fixed scheduler.
|
| 6 |
+
"""
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
class AbsorbingForwardProcess:
    """Absorbing-state (mask) forward diffusion q(x_t | x_0).

    Each non-PAD token is independently replaced by the [MASK] token with
    probability 1 - alpha_t, where alpha_t is the keep-probability given by
    the scheduler. At t=0 the scheduler yields alpha=1.0 exactly, so
    x_t == x_0 (nothing masked). PAD positions are never masked.
    """

    def __init__(self, scheduler, mask_id=0, pad_id=1):
        # scheduler must expose get_alpha(t) -> per-example keep-probability
        self.scheduler = scheduler
        self.mask_id = mask_id
        self.pad_id = pad_id

    def q_sample(self, x_0, t):
        """Sample x_t ~ q(x_t | x_0) for per-example timesteps t.

        Args:
            x_0: LongTensor [B, L] of clean token ids.
            t:   LongTensor [B] of timesteps.

        Returns:
            (x_0, x_t) — the clean input is returned unchanged alongside the
            noised sequence; callers rely on receiving this pair.
        """
        # Keep-probability per example, broadcast over the sequence dimension.
        alpha_t = self.scheduler.get_alpha(t).to(x_0.device).view(-1, 1)
        r = torch.rand(x_0.shape, device=x_0.device)
        x_t = x_0.clone()
        x_t[r > alpha_t] = self.mask_id          # mask with prob 1 - alpha_t
        x_t[x_0 == self.pad_id] = self.pad_id    # PAD stays PAD always
        return x_0, x_t
|
diffusion/reverse_process.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
reverse_process.py — Fixed
|
| 3 |
+
===========================
|
| 4 |
+
Two bugs fixed from the original:
|
| 5 |
+
|
| 6 |
+
BUG 1 (critical): generate_beam() passed x_t (noisy) as `tgt` to model.
|
| 7 |
+
The model does q_sample(tgt, t) internally — so x_t got double-noised.
|
| 8 |
+
Fix: pass x0_estimate (current clean guess) as tgt. Model noises it correctly.
|
| 9 |
+
|
| 10 |
+
BUG 2: apply_diversity_penalty used logits.var(dim=-1) — this adds the
|
| 11 |
+
variance of each position's own distribution back to itself, which is
|
| 12 |
+
mathematically meaningless and just injects noise.
|
| 13 |
+
Fix: penalize tokens that are uniformly high-probability across ALL positions
|
| 14 |
+
(global common tokens). This genuinely promotes diversity.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ReverseDiffusion:
    """Reverse (denoising) process for the absorbing-state D3PM.

    Two entry points:
      - generate():      iterative x0-refinement (the validated path).
      - generate_beam(): beam-style expansion over per-step top-k tokens
                         (kept for reference/experimentation).
    """

    def __init__(self, scheduler):
        # scheduler must expose num_timesteps
        self.scheduler = scheduler

    def p_sample_step(
        self,
        model,
        x_t,
        t,
        condition,
        beam_width=3,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3
    ):
        """Single reverse step: returns beam_width (tokens, score) candidates.

        Score is the summed log-probability of the k-th-ranked token over
        every position (and over the batch) — a whole-sequence score, not a
        per-example one.
        """
        with torch.no_grad():
            # ---- Shape safety: promote unbatched inputs to batch size 1 ----
            if x_t.dim() == 1:
                x_t = x_t.unsqueeze(0)
            if condition.dim() == 1:
                condition = condition.unsqueeze(0)
            if t.dim() == 0:
                t = t.unsqueeze(0)
            if t.shape[0] != x_t.shape[0]:
                t = t.expand(x_t.shape[0])

            # ---- Model forward ----
            logits, _ = model(condition, x_t, t)

            # ---- Temperature scaling ----
            logits = logits / temperature

            # ---- Repetition penalty ----
            if repetition_penalty != 1.0:
                logits = apply_repetition_penalty(logits, x_t, repetition_penalty)

            # ---- Diversity penalty ----
            if diversity_penalty > 0:
                logits = apply_diversity_penalty(logits, diversity_penalty)

            probs = F.softmax(logits, dim=-1)

            # ---- Top-k beam expansion ----
            topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)

            candidates = []
            for k in range(beam_width):
                next_tokens = topk_ids[:, :, k]
                score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
                candidates.append((next_tokens, score))
            return candidates

    def generate_beam(
        self,
        model,
        condition,
        beam_width=3,
        num_steps=None,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3
    ):
        """Beam-search reverse diffusion with temperature.

        Starts from an all-[MASK] sequence; at each step expands every beam
        into beam_width candidates and keeps the globally best beam_width.
        Passes the current (noisy) beam as `tgt` to the model, which applies
        q_sample internally — see module docstring (BUG 1 note).
        """
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        # Initialization: start from an all-MASK sequence.
        x_init = torch.full(
            (B, L),
            fill_value=model.mask_token_id,
            dtype=torch.long,
            device=device
        )

        beams = [(x_init, 0.0)]

        for step in reversed(range(num_steps)):
            new_beams = []
            for x_t, score in beams:
                t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
                candidates = self.p_sample_step(
                    model,
                    x_t,
                    t_tensor,
                    condition,
                    beam_width,
                    temperature,
                    repetition_penalty,
                    diversity_penalty
                )
                for tokens, new_score in candidates:
                    new_beams.append((tokens, score + new_score))

            # ---- Keep top beams by accumulated score ----
            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_width]

        best_tokens, best_score = beams[0]
        return best_tokens

    def generate(
        self,
        model,
        condition,
        num_steps=None,
        temperature=0.8,
        top_k=50,
        repetition_penalty=1.2,
        diversity_penalty=0.0,
    ):
        """Correct D3PM iterative refinement.

        x0_est starts as all [MASK].
        Each step: forward(src=condition, tgt=x0_est, t)
            -> model applies q_sample(x0_est, t) internally
            -> predicts a cleaner x0
            -> x0_est updated (stochastic sample at t>0, argmax at t=0)

        diversity_penalty: reduces probability of tokens that are globally
        dominant across all sequence positions (not logits.var()).
        """
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        T = self.scheduler.num_timesteps
        step_size = max(1, T // num_steps)
        timesteps = list(range(T - 1, -1, -step_size))
        if timesteps[-1] != 0:
            timesteps.append(0)

        mask_id = model.mask_token_id
        # Start: know nothing -> all MASK is our initial clean estimate.
        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
        hint = None

        # Check the model's signature ONCE, outside the loop (the original
        # re-imported inspect and recomputed the signature every timestep).
        import inspect
        accepts_hint = 'x0_hint' in inspect.signature(model.forward).parameters

        model.eval()
        with torch.no_grad():
            for step_idx, t_val in enumerate(timesteps):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (step_idx == len(timesteps) - 1)

                # KEY: pass x0_est as tgt — model noises it internally.
                if accepts_hint:
                    outputs = model(condition, x0_est, t, x0_hint=hint)
                else:
                    outputs = model(condition, x0_est, t)

                logits = outputs[0] if isinstance(outputs, tuple) else outputs

                # Repetition penalty: down-weight tokens already in sequence.
                if repetition_penalty != 1.0:
                    logits = apply_repetition_penalty(logits, x0_est, repetition_penalty)

                # Diversity penalty: reduce globally dominant tokens.
                if diversity_penalty > 0.0:
                    logits = apply_diversity_penalty(logits, diversity_penalty)

                # Temperature + top-k.
                logits = logits / max(temperature, 1e-5)
                if top_k > 0:
                    logits = top_k_filter(logits, top_k)

                probs = F.softmax(logits, dim=-1)

                if is_last:
                    x0_est = torch.argmax(probs, dim=-1)   # deterministic final step
                else:
                    x0_est = batch_multinomial(probs)      # stochastic refinement

                hint = x0_est  # self-conditioning for the next step

        return x0_est


# ── Penalty functions ─────────────────────────────────────────────────

def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
    """Down-weight tokens that already appear in the current sequence.

    Prevents मनो मनो मनो repetition loops.
    penalty=1.0 -> no effect
    penalty=1.2 -> mild suppression of repeated tokens
    penalty=2.0 -> strong suppression

    FIX: the original divided the logit unconditionally; dividing a NEGATIVE
    logit by penalty > 1 moves it toward zero and *raises* the token's
    probability. Standard (CTRL-style) behavior: divide positive logits,
    multiply negative ones. Mutates `logits` in place and returns it.
    """
    B, L, V = logits.shape
    for b in range(B):
        for token_id in set(prev_tokens[b].tolist()):
            if token_id > 4:  # don't penalize special tokens
                vals = logits[b, :, token_id]
                logits[b, :, token_id] = torch.where(
                    vals > 0, vals / penalty, vals * penalty
                )
    return logits


def apply_diversity_penalty(logits, penalty=0.5):
    """Correct diversity penalty: penalize tokens that are globally dominant
    across ALL sequence positions. This forces the model to use less
    common tokens, increasing output diversity.

    Method: compute mean logit across positions, subtract penalty times that
    mean. Tokens uniformly high everywhere get suppressed.

    penalty=0.0 -> no diversity enforcement
    penalty=0.5 -> moderate diversity
    penalty=1.0 -> strong diversity (may hurt coherence)
    """
    # Mean logit across all positions: [B, 1, V]
    global_mean = logits.mean(dim=1, keepdim=True)
    # Subtract scaled global mean — suppresses globally common tokens.
    return logits - penalty * global_mean


def top_k_filter(logits, k):
    """Keep only the k highest logits per position; everything else -> -inf."""
    B, L, V = logits.shape
    if k >= V:
        return logits
    topk_vals, _ = torch.topk(logits, k, dim=-1)
    threshold = topk_vals[..., -1].unsqueeze(-1)  # k-th largest per position
    return logits.masked_fill(logits < threshold, float('-inf'))


def batch_multinomial(probs):
    """Sample one token id per position from [B, L, V] probabilities."""
    B, L, V = probs.shape
    flat = probs.view(B * L, V) + 1e-9          # guard against all-zero rows
    flat = flat / flat.sum(dim=-1, keepdim=True)
    return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
diffusion/reverse_process1.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ReverseDiffusion:
    """
    Stable reverse diffusion (legacy variant) with:
    - Beam search
    - Self conditioning
    - Temperature sampling
    - Repetition penalty
    - Diversity penalty

    NOTE(review): this is the older implementation kept alongside
    reverse_process.py; the model here is called with a 4th positional
    self-conditioning argument and must return (logits, hidden).
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler

        # Sampling hyperparameters (attribute-style, not per-call kwargs).
        self.temperature = 0.75
        self.repetition_penalty = 1.15
        self.diversity_penalty = 0.0
        self.length_penalty = 1.0

    # ------------------------------------------------
    # penalties
    # ------------------------------------------------

    def apply_repetition_penalty(self, logits, tokens):
        """Divide the logits of every token present in `tokens` (in place).

        NOTE(review): unlike the fixed module, this penalizes special
        tokens too and is not sign-aware for negative logits.
        """
        B, L, V = logits.shape
        for b in range(B):
            used = set(tokens[b].tolist())
            for token_id in used:
                logits[b, :, token_id] /= self.repetition_penalty
        return logits

    def apply_diversity_penalty(self, logits):
        """Suppress globally dominant tokens when diversity_penalty > 0.

        FIX: the original added logits.var(dim=-1) back to the logits —
        documented as BUG 2 ("mathematically meaningless … injects noise")
        in reverse_process.py. Replaced by the corrected global-mean
        suppression. Default diversity_penalty is 0.0, so default behavior
        is unchanged.
        """
        if self.diversity_penalty == 0:
            return logits
        global_mean = logits.mean(dim=1, keepdim=True)  # [B, 1, V]
        return logits - self.diversity_penalty * global_mean

    # ------------------------------------------------
    # single reverse step
    # ------------------------------------------------

    def p_sample_step(self, model, x_t, t, condition, self_cond=None, beam_width=3):
        """One reverse step: returns beam_width (tokens, score) candidates."""
        with torch.no_grad():
            logits, hidden = model(condition, x_t, t, self_cond)

            logits = logits / self.temperature
            logits = self.apply_repetition_penalty(logits, x_t)
            logits = self.apply_diversity_penalty(logits)

            probs = F.softmax(logits, dim=-1)

            topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)

            candidates = []
            for k in range(beam_width):
                tokens = topk_ids[:, :, k]
                # Whole-sequence log-probability of the k-th ranked tokens.
                score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
                candidates.append((tokens, score))
            return candidates

    # ------------------------------------------------
    # beam reverse diffusion
    # ------------------------------------------------

    def generate_beam(self, model, condition, beam_width=3, num_steps=None):
        """Beam-search reverse diffusion with self-conditioning."""
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        # ------------------------------------------------
        # LATENT INITIALIZATION: copy the condition, then mask
        # roughly half the positions at random.
        # ------------------------------------------------
        x_init = condition.clone()
        mask = torch.rand_like(x_init.float()) < 0.5
        x_init[mask] = model.mask_token_id

        beams = [(x_init, 0.0)]
        self_cond = None

        for step in reversed(range(num_steps)):
            new_beams = []
            for x_t, score in beams:
                t_tensor = torch.full((B,), step, dtype=torch.long, device=device)

                candidates = self.p_sample_step(
                    model, x_t, t_tensor, condition, self_cond, beam_width
                )

                for tokens, new_score in candidates:
                    # Length-normalize the accumulated score.
                    length_norm = tokens.shape[1] ** self.length_penalty
                    final_score = (score + new_score) / length_norm
                    new_beams.append((tokens, final_score))

            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_width]

            # self conditioning: feed the best beam into the next step
            self_cond = beams[0][0]

        best_tokens, best_score = beams[0]
        return best_tokens
|
diffusion/reverse_process2.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
reverse_process.py — Final Correct Version
|
| 3 |
+
=============================================
|
| 4 |
+
|
| 5 |
+
KEY PRINCIPLE: generate() must be byte-for-byte identical to run_inference()
|
| 6 |
+
in inference.py, which is what produced BERTScore 0.75 at validation.
|
| 7 |
+
|
| 8 |
+
CRITICAL BUG IN PREVIOUS VERSION:
|
| 9 |
+
We passed inference_mode=True to the model, but the model was NEVER
|
| 10 |
+
called with inference_mode=True during training or validation.
|
| 11 |
+
run_inference() (the validated path) does:
|
| 12 |
+
model(input_ids, x0_est, t, x0_hint=hint)
|
| 13 |
+
→ inference_mode defaults to False.
|
| 14 |
+
|
| 15 |
+
With inference_mode=True the model does two things differently:
|
| 16 |
+
1. tgt_pad_mask = None (training used tgt_pad_mask = tgt==PAD)
|
| 17 |
+
2. Skips q_sample at t=0 (training always called q_sample)
|
| 18 |
+
The model was never trained to handle these conditions → garbage output.
|
| 19 |
+
|
| 20 |
+
Fix: do NOT pass inference_mode. Let it default to False, exactly
|
| 21 |
+
as run_inference() did.
|
| 22 |
+
|
| 23 |
+
BUGS FIXED (vs original reverse_process.py)
|
| 24 |
+
--------------------------------------------
|
| 25 |
+
BUG 1 generate_beam() used for D3PM → all-Ṛ repetition.
|
| 26 |
+
Use generate() (iterative refinement) from app1.py instead.
|
| 27 |
+
BUG 2 apply_diversity_penalty used logits.var() → noise injection.
|
| 28 |
+
Fixed to logits - penalty * logits.mean(dim=1) — global suppression.
|
| 29 |
+
BUG 3 x0_hint (self-conditioning) never passed to model.
|
| 30 |
+
Fixed: generate() passes x0_hint=hint every step.
|
| 31 |
+
BUG 4 params not forwarded from generate_beam() to p_sample_step().
|
| 32 |
+
Fixed in generate_beam() (kept for reference, not for production use).
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ReverseDiffusion:
    """Driver for the D3PM reverse (denoising) process.

    generate() is the validated iterative-refinement sampler (it mirrors
    run_inference() in inference.py). p_sample_step()/generate_beam() are
    kept only for experimental reference — see their docstrings.
    """

    def __init__(self, scheduler):
        # Scheduler supplies num_timesteps for the denoising ladder.
        self.scheduler = scheduler

        # Attribute-style defaults for backward compat with any code
        # that sets reverse_diffusion.temperature = 0.9 etc.
        # generate() prefers explicit kwargs and falls back to these.
        self.temperature = 0.75
        self.repetition_penalty = 1.15
        self.diversity_penalty = 0.0
        self.top_k = 50

    # ------------------------------------------------------------------ #
    # generate — CORRECT D3PM iterative refinement                        #
    # Exact equivalent of run_inference() in inference.py                 #
    # ------------------------------------------------------------------ #
    def generate(
        self,
        model,
        condition,
        num_steps = None,
        temperature = None,
        top_k = None,
        repetition_penalty = None,
        diversity_penalty = None,
    ):
        """
        D3PM iterative refinement — identical to run_inference() in inference.py,
        which is the validated path (BERTScore 0.75).

        Algorithm:
            x0_est = all [MASK]
            for t = T-1 down to 0:
                logits = model(src, x0_est, t, x0_hint=hint)
                      ↑ inference_mode NOT passed (defaults to False)
                      ↑ this exactly matches training/validation
                apply penalties, temperature, top_k
                if t > 0:  x0_est = multinomial(softmax(logits))   ← stochastic
                if t = 0:  x0_est = argmax(softmax(logits))        ← deterministic
                hint = x0_est

        Returns a (B, L) LongTensor of target token ids.
        """
        # Resolve: explicit kwarg > object attribute
        temperature = temperature if temperature is not None else self.temperature
        top_k = top_k if top_k is not None else self.top_k
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
        diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty

        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        # Strided descending timestep ladder, always ending at t=0.
        T = self.scheduler.num_timesteps
        step_size = max(1, T // num_steps)
        timesteps = list(range(T - 1, -1, -step_size))
        if timesteps[-1] != 0:
            timesteps.append(0)

        mask_id = model.mask_token_id
        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
        hint = None  # self-conditioning hint, updated every step

        model.eval()
        with torch.no_grad():
            for step_idx, t_val in enumerate(timesteps):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (step_idx == len(timesteps) - 1)

                # ── CRITICAL: do NOT pass inference_mode ──────────────────
                # inference_mode defaults to False inside SanskritModel /
                # D3PMCrossAttention. This matches run_inference() exactly.
                # Passing inference_mode=True changes tgt_pad_mask and
                # q_sample behaviour — the model was never trained for that.
                logits, _ = model(condition, x0_est, t, x0_hint=hint)

                # Repetition penalty
                if repetition_penalty != 1.0:
                    logits = apply_repetition_penalty(
                        logits, x0_est, repetition_penalty
                    )

                # Diversity penalty (correct: global mean suppression)
                if diversity_penalty > 0.0:
                    logits = apply_diversity_penalty(logits, diversity_penalty)

                logits = logits / max(temperature, 1e-5)

                if top_k > 0:
                    logits = top_k_filter(logits, top_k)

                probs = F.softmax(logits, dim=-1)

                # Stochastic at every step except the last (argmax at t=0)
                if is_last:
                    x0_est = torch.argmax(probs, dim=-1)
                else:
                    x0_est = batch_multinomial(probs)

                hint = x0_est

        return x0_est  # (B, L)

    # ------------------------------------------------------------------ #
    # p_sample_step — used by generate_beam (not for production)          #
    # ------------------------------------------------------------------ #
    def p_sample_step(
        self,
        model,
        x_t,
        t,
        condition,
        beam_width = 3,
        temperature = 1.0,
        repetition_penalty = 1.2,
        diversity_penalty = 0.3,
        x0_hint = None,
    ):
        """Return beam_width (tokens, score) candidates for one reverse step.

        NOTE: each candidate takes the k-th ranked token at EVERY position
        simultaneously — this is why beam search degenerates for D3PM.
        """
        with torch.no_grad():
            # Normalize shapes: allow unbatched inputs and scalar t.
            if x_t.dim() == 1: x_t = x_t.unsqueeze(0)
            if condition.dim() == 1: condition = condition.unsqueeze(0)
            if t.dim() == 0: t = t.unsqueeze(0)
            if t.shape[0] != x_t.shape[0]:
                t = t.expand(x_t.shape[0])

            # No inference_mode — matches training convention
            logits, _ = model(condition, x_t, t, x0_hint=x0_hint)

            logits = logits / max(temperature, 1e-5)

            if repetition_penalty != 1.0:
                logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
            if diversity_penalty > 0.0:
                logits = apply_diversity_penalty(logits, diversity_penalty)

            probs = F.softmax(logits, dim=-1)
            B, L, V = probs.shape

            topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
            candidates = []
            for k in range(beam_width):
                # k-th ranked token at every position; score = sum log-prob.
                next_tokens = topk_ids[:, :, k]
                score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
                candidates.append((next_tokens, score))
            return candidates

    # ------------------------------------------------------------------ #
    # generate_beam — kept for reference; NOT the correct D3PM method     #
    # ------------------------------------------------------------------ #
    def generate_beam(
        self,
        model,
        condition,
        beam_width = 3,
        num_steps = None,
        temperature = None,
        repetition_penalty = None,
        diversity_penalty = None,
    ):
        """
        WARNING: do NOT call this from app1.py for D3PM generation.
        generate_beam() forces every position to the same top-k token
        → all-Ṛ / all-rud repetition. Use generate() instead.
        Kept only for experimental reference.
        """
        # Resolve: explicit kwarg > object attribute (same as generate()).
        temperature = temperature if temperature is not None else self.temperature
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
        diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1: condition = condition.unsqueeze(0)
        B, L = condition.shape

        # Start from an all-[MASK] hypothesis with score 0.
        x_init = torch.full((B, L), fill_value=model.mask_token_id,
                            dtype=torch.long, device=device)
        beams = [(x_init, 0.0)]
        best_hint = None  # self-conditioning from the best beam so far

        for step in reversed(range(num_steps)):
            t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
            new_beams = []
            for x_t, score in beams:
                candidates = self.p_sample_step(
                    model, x_t, t_tensor, condition,
                    beam_width = beam_width,
                    temperature = temperature,
                    repetition_penalty = repetition_penalty,
                    diversity_penalty = diversity_penalty,
                    x0_hint = best_hint,
                )
                for tokens, new_score in candidates:
                    new_beams.append((tokens, score + new_score.item()))

            # Keep the beam_width highest-scoring hypotheses.
            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_width]
            best_hint = beams[0][0]

        return beams[0][0]  # (B, L)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# ── Penalty helpers ────────────────────────────────────────────────────────
|
| 245 |
+
|
| 246 |
+
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
    """Down-weight tokens already present in the sequence (CTRL-style).

    Args:
        logits:      (B, L, V) raw model scores, modified in place.
        prev_tokens: (B, L) token ids of the current sequence estimate.
        penalty:     > 1.0 suppresses repeated tokens; 1.0 is a no-op.

    BUG FIX: the original divided the logit by ``penalty`` unconditionally.
    For a NEGATIVE logit, dividing by penalty > 1 moves it toward zero and
    therefore INCREASES its probability — the opposite of the intent.
    CTRL-style penalisation scales by sign: divide positive logits,
    multiply negative ones.
    """
    if penalty == 1.0:
        return logits
    for b in range(logits.shape[0]):
        for token_id in set(prev_tokens[b].tolist()):
            if token_id > 4:  # ids 0-4 are special tokens — never penalise
                vals = logits[b, :, token_id]
                logits[b, :, token_id] = torch.where(
                    vals > 0, vals / penalty, vals * penalty
                )
    return logits
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def apply_diversity_penalty(logits, penalty=0.3):
    """Suppress globally dominant tokens.

    Subtracts ``penalty`` times the per-vocabulary mean logit, where the
    mean is taken over the sequence dimension:
        logits - penalty * mean(logits, dim=1)
    """
    seq_mean = torch.mean(logits, dim=1, keepdim=True)  # shape [B, 1, V]
    return logits - penalty * seq_mean
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def top_k_filter(logits, k):
    """Keep only the k largest logits per position; mask the rest to -inf."""
    batch, seq_len, vocab = logits.shape
    if k >= vocab:
        return logits  # nothing to filter
    # Value of the k-th largest logit at each (batch, position).
    kth_value = torch.topk(logits, k, dim=-1).values[..., -1:]
    below_cutoff = logits < kth_value
    return logits.masked_fill(below_cutoff, float('-inf'))
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def batch_multinomial(probs):
    """Sample one token id per position from a (B, L, V) probability tensor.

    A tiny epsilon is added so rows that are all-zero after filtering still
    form a valid distribution for torch.multinomial.
    """
    batch, seq_len, vocab = probs.shape
    flat = probs.view(batch * seq_len, vocab) + 1e-9
    flat = flat / flat.sum(dim=-1, keepdim=True)  # renormalise per row
    draws = torch.multinomial(flat, 1)
    return draws.squeeze(-1).view(batch, seq_len)
|
diffusion/scheduler.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
scheduler.py — Fixed & Upgraded
|
| 3 |
+
==================================
|
| 4 |
+
Changes:
|
| 5 |
+
1. T=64 (was 16). More timesteps = richer denoising curriculum per epoch.
|
| 6 |
+
2. alpha at t=0 is EXACTLY 1.0 — fixes Bug 2 (final-step re-noise).
|
| 7 |
+
3. sample_timestep samples [0, T-1] including t=0, so model trains on
|
| 8 |
+
fully-clean inputs (learns the identity at t=0 explicitly).
|
| 9 |
+
"""
|
| 10 |
+
import torch, math
|
| 11 |
+
|
| 12 |
+
class OptimizedCosineScheduler:
    """Cosine noise schedule for the absorbing-state D3PM.

    The cumulative-alpha table is pinned to exactly 1.0 at t=0 (clean
    input) and clamped to at most 0.001 at the final step (near-fully
    masked). Timesteps are sampled uniformly over [0, T-1], t=0 included.
    """

    def __init__(self, cfg, device=None):
        # T comes from config (e.g. 64); mask id identifies the absorbing state.
        self.num_timesteps = cfg['model']['diffusion_steps']
        self.mask_token_id = cfg['diffusion']['mask_token_id']
        self.device = device if device else torch.device('cpu')
        self.alphas_cumprod = self._build_schedule().to(self.device)

    def _build_schedule(self):
        """Build the length-T cumulative-alpha table (Nichol & Dhariwal cosine)."""
        steps = self.num_timesteps
        grid = torch.arange(steps + 1, dtype=torch.float32)
        cos_sq = torch.cos((grid / steps + 0.008) / 1.008 * math.pi / 2) ** 2
        bar = (cos_sq / cos_sq[0])[1:]          # normalise, drop anchor -> [T]
        bar[0] = 1.0                            # exact identity at t=0
        bar[-1] = bar[-1].clamp(max=0.001)      # near-total masking at t=T-1
        return bar

    def sample_timestep(self, batch_size):
        """Uniform [0, T-1] — includes t=0 so model sees clean inputs."""
        return torch.randint(0, self.num_timesteps, (batch_size,))

    def get_alpha(self, t):
        """Look up cumulative alpha for a (possibly batched) timestep tensor."""
        idx = t.to(self.alphas_cumprod.device).long()
        return self.alphas_cumprod[idx]
|
handler.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
|
| 3 |
+
from inference_api import predict
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EndpointHandler:
    """
    Hugging Face Inference Endpoint handler.
    Expects payload:
        {
            "inputs": "dharmo rakṣati rakṣitaḥ",
            "parameters": {"temperature": 0.7, ...}
        }
    """

    def __init__(self, path: str = ""):
        # Model directory HF passes in; loading itself is deferred to
        # inference_api, so the path is only recorded.
        self.path = path

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        text = data.get("inputs", "")
        params = data.get("parameters", {}) or {}
        # Per-parameter fallbacks mirror predict()'s own defaults.
        defaults = {
            "temperature": 0.7,
            "top_k": 40,
            "repetition_penalty": 1.2,
            "diversity_penalty": 0.0,
            "num_steps": 64,
            "clean_output": True,
        }
        resolved = {name: params.get(name, value) for name, value in defaults.items()}
        return predict(text=text, **resolved)
|
inference.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
from config import CONFIG
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _resolve_device(cfg: dict) -> torch.device:
|
| 10 |
+
requested = cfg["training"]["device"]
|
| 11 |
+
if requested == "cuda" and not torch.cuda.is_available():
|
| 12 |
+
requested = "cpu"
|
| 13 |
+
if requested == "mps" and not torch.backends.mps.is_available():
|
| 14 |
+
requested = "cpu"
|
| 15 |
+
cfg["training"]["device"] = requested
|
| 16 |
+
return torch.device(requested)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _build_tokenizers(cfg):
    """Build the (Roman source, Devanagari target) tokenizer pair from cfg."""
    from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer

    max_len = cfg["model"]["max_seq_len"]
    source = SanskritSourceTokenizer(
        vocab_size=cfg["model"].get("src_vocab_size", 16000),
        max_len=max_len,
    )
    target = SanskritTargetTokenizer(
        vocab_size=cfg["model"].get("tgt_vocab_size", 16000),
        max_len=max_len,
    )
    return source, target
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
    """Build a SanskritModel whose architecture matches a checkpoint.

    Reads the state dict ONCE, infers vocab size / d_model / layer count /
    max_seq_len / n_heads from the stored tensors, then constructs the
    model and loads the weights (strict=False tolerates extra keys).

    Args:
        ckpt_path: path to a torch checkpoint (plain state dict).
        base_cfg:  config template; a deep copy is adjusted and returned.
        device:    device the model is moved to.

    Returns:
        (model, cfg) — the loaded eval-mode model and the adjusted config.
    """
    from model.sanskrit_model import SanskritModel

    cfg = copy.deepcopy(base_cfg)
    state = torch.load(ckpt_path, map_location="cpu")

    # Infer embedding-derived hyperparameters from the checkpoint tensors.
    emb_key = "model.src_embed.token_emb.weight"
    if emb_key in state:
        vocab, d_model = state[emb_key].shape
        cfg["model"]["src_vocab_size"] = vocab
        cfg["model"]["d_model"] = d_model
        cfg["model"]["d_ff"] = d_model * 4

    # Layer count = highest encoder-block index + 1.
    layer_ids = {int(k.split(".")[2]) for k in state if k.startswith("model.encoder_blocks.")}
    if layer_ids:
        cfg["model"]["n_layers"] = max(layer_ids) + 1

    # Positional-encoding buffer fixes the maximum sequence length.
    pos_key = "model.src_embed.pos_enc.pe"
    if pos_key in state:
        cfg["model"]["max_seq_len"] = state[pos_key].shape[1]

    # n_heads must divide d_model; fall back to the largest count that does.
    d_model = cfg["model"]["d_model"]
    n_heads = cfg["model"].get("n_heads", 8)
    if d_model % n_heads != 0:
        n_heads = next(h for h in [8, 6, 4, 2, 1] if d_model % h == 0)
    cfg["model"]["n_heads"] = n_heads

    model = SanskritModel(cfg).to(device)
    # BUG FIX: reuse the already-loaded state dict instead of reading the
    # checkpoint file from disk a second time (original called torch.load
    # twice on the same path).
    model.load_state_dict(state, strict=False)
    model.eval()
    return model, cfg
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def run_inference(model, input_ids, cfg):
    """Iteratively denoise from an all-[MASK] sequence to target token ids.

    D3PM refinement loop: at each timestep the model predicts x0,
    repetition/diversity penalties, temperature and top-k filtering are
    applied, and the estimate is resampled via multinomial — except at the
    final step, which takes the argmax. The previous estimate is fed back
    as a self-conditioning hint (x0_hint).

    Args:
        model:     wrapper whose .model exposes scheduler and mask_token_id.
        input_ids: (B, L) source token ids.
        cfg:       config dict; cfg["inference"] supplies the sampling knobs.

    Returns:
        (B, L) LongTensor of predicted target token ids.
    """
    # IMPROVED: imports hoisted out of the per-timestep loop — the original
    # re-imported these helpers on every iteration.
    from model.d3pm_model_cross_attention import (
        _apply_diversity_penalty_fixed,
        _apply_repetition_penalty,
        _batch_multinomial,
        _top_k_filter,
    )

    inf = cfg["inference"]
    device = input_ids.device
    bsz, seqlen = input_ids.shape
    inner = model.model

    # Strided descending timestep ladder, always ending at t=0.
    total_steps = inner.scheduler.num_timesteps
    steps = int(inf["num_steps"])
    step_size = max(1, total_steps // max(steps, 1))
    timesteps = list(range(total_steps - 1, -1, -step_size))
    if timesteps[-1] != 0:
        timesteps.append(0)

    # Resolve sampling knobs once rather than per step.
    rep_pen = float(inf["repetition_penalty"])
    div_pen = float(inf["diversity_penalty"])
    temp = max(float(inf["temperature"]), 1e-5)
    top_k = int(inf["top_k"])

    x0_est = torch.full((bsz, seqlen), inner.mask_token_id, dtype=torch.long, device=device)
    hint = None

    with torch.no_grad():
        for i, t_val in enumerate(timesteps):
            is_last = i == len(timesteps) - 1
            t = torch.full((bsz,), t_val, dtype=torch.long, device=device)

            # NOTE(review): this path passes inference_mode=True while
            # ReverseDiffusion.generate() deliberately does not — the two
            # samplers disagree; confirm which convention the checkpoint
            # was validated with before unifying.
            logits, _ = model(input_ids, x0_est, t, x0_hint=hint, inference_mode=True)

            if rep_pen != 1.0:
                logits = _apply_repetition_penalty(logits, x0_est, rep_pen)
            if div_pen > 0.0:
                logits = _apply_diversity_penalty_fixed(logits, div_pen)

            logits = logits / temp
            if top_k > 0:
                logits = _top_k_filter(logits, top_k)

            probs = F.softmax(logits, dim=-1)
            if is_last:
                x0_est = torch.argmax(probs, dim=-1)   # deterministic final step
            else:
                x0_est = _batch_multinomial(probs)     # stochastic refinement
            hint = x0_est                              # self-conditioning feedback

    return x0_est
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
__all__ = [
|
| 117 |
+
"CONFIG",
|
| 118 |
+
"_resolve_device",
|
| 119 |
+
"_build_tokenizers",
|
| 120 |
+
"load_model",
|
| 121 |
+
"run_inference",
|
| 122 |
+
]
|
inference_api.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Dict, Any
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from config import CONFIG
|
| 7 |
+
from inference import _build_tokenizers, _resolve_device, load_model, run_inference
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
_STATE = {
|
| 11 |
+
"loaded": False,
|
| 12 |
+
"model": None,
|
| 13 |
+
"cfg": None,
|
| 14 |
+
"device": None,
|
| 15 |
+
"src_tok": None,
|
| 16 |
+
"tgt_tok": None,
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _load_once() -> None:
    """Lazily initialise model, tokenizers and device into _STATE.

    Idempotent: subsequent calls return immediately once loaded.
    """
    if _STATE["loaded"]:
        return

    cfg = copy.deepcopy(CONFIG)
    cfg["model_type"] = "d3pm_cross_attention"
    cfg["data"]["include_negative_examples"] = True
    device = _resolve_device(cfg)

    model, cfg = load_model("best_model.pt", cfg, device)
    src_tok, tgt_tok = _build_tokenizers(cfg)

    # Publish everything at once; the flag goes last so a half-initialised
    # state is never observed as "loaded".
    _STATE["model"] = model
    _STATE["cfg"] = cfg
    _STATE["device"] = device
    _STATE["src_tok"] = src_tok
    _STATE["tgt_tok"] = tgt_tok
    _STATE["loaded"] = True
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _clean_text(text: str) -> str:
|
| 41 |
+
text = " ".join(text.split())
|
| 42 |
+
if not text:
|
| 43 |
+
return text
|
| 44 |
+
toks = text.split()
|
| 45 |
+
out = []
|
| 46 |
+
prev = None
|
| 47 |
+
run = 0
|
| 48 |
+
for tok in toks:
|
| 49 |
+
if tok == prev:
|
| 50 |
+
run += 1
|
| 51 |
+
else:
|
| 52 |
+
prev = tok
|
| 53 |
+
run = 1
|
| 54 |
+
if run <= 2:
|
| 55 |
+
out.append(tok)
|
| 56 |
+
s = " ".join(out)
|
| 57 |
+
s = s.replace(" ।", "।").replace(" ॥", "॥")
|
| 58 |
+
return " ".join(s.split())
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def predict(
    text: str,
    temperature: float = 0.7,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
    diversity_penalty: float = 0.0,
    num_steps: int = 64,
    clean_output: bool = True,
) -> Dict[str, Any]:
    """Run one D3PM translation request and return a JSON-serialisable dict.

    Lazily loads model/tokenizers on first call, applies per-request
    sampling knobs on a copy of the shared config, decodes the result
    (dropping special-token ids 0-4), and optionally cleans it.
    """
    _load_once()
    if not text or not text.strip():
        return {"error": "empty input", "output": ""}

    # Per-request knobs go into a COPY so the shared config never mutates.
    knobs = {
        "temperature": float(temperature),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "diversity_penalty": float(diversity_penalty),
        "num_steps": int(num_steps),
    }
    cfg = copy.deepcopy(_STATE["cfg"])
    cfg["inference"].update(knobs)

    src_tok = _STATE["src_tok"]
    tgt_tok = _STATE["tgt_tok"]
    device = _STATE["device"]

    input_ids = torch.tensor([src_tok.encode(text.strip())], dtype=torch.long, device=device)
    out = run_inference(_STATE["model"], input_ids, cfg)
    # Special tokens occupy ids 0-4; keep only real vocabulary ids.
    decoded_ids = [tok_id for tok_id in out[0].tolist() if tok_id > 4]
    raw = tgt_tok.decode(decoded_ids).strip()
    output = _clean_text(raw) if clean_output else raw

    return {
        "input": text,
        "output": output,
        "raw_output": raw,
        "config": {**knobs, "clean_output": bool(clean_output)},
    }
|
model/__init__.py
ADDED
|
File without changes
|
model/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (126 Bytes). View file
|
|
|
model/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (173 Bytes). View file
|
|
|
model/__pycache__/d3pm_model_cross_attention.cpython-311.pyc
ADDED
|
Binary file (30.7 kB). View file
|
|
|
model/__pycache__/d3pm_model_cross_attention.cpython-312.pyc
ADDED
|
Binary file (27.2 kB). View file
|
|
|
model/__pycache__/d3pm_model_encoder_decoder.cpython-311.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
model/__pycache__/sanskrit_model.cpython-311.pyc
ADDED
|
Binary file (5.67 kB). View file
|
|
|
model/__pycache__/sanskrit_model.cpython-312.pyc
ADDED
|
Binary file (5.26 kB). View file
|
|
|
model/__pycache__/tokenizer.cpython-311.pyc
ADDED
|
Binary file (15.3 kB). View file
|
|
|
model/__pycache__/tokenizer.cpython-312.pyc
ADDED
|
Binary file (12.9 kB). View file
|
|
|
model/__pycache__/tokenizers.cpython-311.pyc
ADDED
|
Binary file (7.94 kB). View file
|
|
|
model/d3pm_model_cross_attention.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
d3pm_model_cross_attention.py — Cross-Script + Generation-Fixed
|
| 3 |
+
=================================================================
|
| 4 |
+
INPUT : quote_text tokens (Roman script, src_vocab_size)
|
| 5 |
+
OUTPUT : quote_devanagari tokens (Devanagari script, tgt_vocab_size)
|
| 6 |
+
|
| 7 |
+
src_embed uses src_vocab_size (Roman BPE)
|
| 8 |
+
tgt_embed uses tgt_vocab_size (Devanagari BPE)
|
| 9 |
+
head outputs tgt_vocab_size (predict Devanagari tokens)
|
| 10 |
+
Weight tying: head <-> tgt_embed only (NOT src_embed)
|
| 11 |
+
|
| 12 |
+
Generation bugs fixed:
|
| 13 |
+
BUG 1 - tgt_pad_mask suppressed during inference
|
| 14 |
+
BUG 2 - q_sample skipped at t=0
|
| 15 |
+
BUG 3 - time embedding before hint_gate
|
| 16 |
+
BUG 4 - diversity penalty uses global mean not var
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
from diffusion.scheduler import OptimizedCosineScheduler
|
| 24 |
+
from diffusion.forward_process import AbsorbingForwardProcess
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sin/cos positional encoding added to (B, L, d_model) inputs."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        # Geometric frequency ladder: 10000^(-2i/d_model) for even dims.
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        # Buffer, not a parameter: moves with .to(device), no gradients.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        # Add the first seq_len rows; broadcasts over the batch dimension.
        return x + self.pe[:, :x.size(1), :]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SanskritEmbeddings(nn.Module):
    """Token embedding plus sinusoidal positional encoding for one vocab."""

    def __init__(self, vocab_size, d_model, max_seq_len):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = SinusoidalPositionalEncoding(d_model, max_seq_len)
        # Alias kept for checkpoints/code that reference `token_embedding`.
        self.token_embedding = self.token_emb

    def forward(self, tokens):
        embedded = self.token_emb(tokens)
        return self.pos_enc(embedded)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MultiHeadAttention(nn.Module):
    """Standard multi-head scaled-dot-product attention (batch-first).

    `mask` is a (B, Lk) boolean tensor with True at padding positions; it
    is broadcast over heads and query positions.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch, q_len, _ = q.size()
        k_len = k.size(1)

        # Project and reshape to (B, heads, len, head_dim).
        queries = self.q_proj(q).view(batch, q_len, self.n_heads, self.head_dim).transpose(1, 2)
        keys = self.k_proj(k).view(batch, k_len, self.n_heads, self.head_dim).transpose(1, 2)
        values = self.v_proj(v).view(batch, k_len, self.n_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # (B, Lk) -> (B, 1, 1, Lk); masked positions get -inf pre-softmax.
            scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))

        weights = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(weights, values)
        merged = context.transpose(1, 2).contiguous().view(batch, q_len, self.d_model)
        return self.out_proj(merged)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class EncoderBlock(nn.Module):
    """Post-norm transformer encoder layer: self-attention then feed-forward."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, pad_mask=None):
        # Residual + norm around self-attention, then around the FFN.
        attended = self.mha(x, x, x, mask=pad_mask)
        x = self.norm1(x + attended)
        return self.norm2(x + self.ff(x))
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class DecoderBlock(nn.Module):
    """Post-norm decoder layer: self-attn, cross-attn over encoder memory, FFN."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, memory, tgt_pad_mask=None, src_pad_mask=None):
        # Three residual sub-layers, each followed by its own LayerNorm.
        self_attended = self.self_attn(x, x, x, mask=tgt_pad_mask)
        x = self.norm1(x + self_attended)
        cross_attended = self.cross_attn(x, memory, memory, mask=src_pad_mask)
        x = self.norm2(x + cross_attended)
        return self.norm3(x + self.ff(x))
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class D3PMCrossAttention(nn.Module):
    """Absorbing-state D3PM transliteration model with encoder/decoder cross-attention.

    The Roman/IAST source is encoded once; the (noised) Devanagari target is
    denoised by the decoder, which cross-attends to the encoder memory.
    Output logits are over the target vocabulary, with the head weight tied
    to the target embedding table.
    """

    def __init__(self, cfg):
        # cfg is a nested dict; 'model' holds architecture sizes and 'diffusion'
        # holds diffusion hyperparameters (at least 'mask_token_id').
        super().__init__()
        self.cfg = cfg
        self.mask_token_id = cfg['diffusion']['mask_token_id']
        d = cfg['model']['d_model']
        nhead = cfg['model']['n_heads']
        d_ff = cfg['model']['d_ff']
        drop = cfg['model']['dropout']
        seqlen = cfg['model']['max_seq_len']
        nlayer = cfg['model']['n_layers']
        # Fall back to a single shared vocab size when per-side sizes are absent.
        src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
        tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])

        # Separate embeddings: Roman src, Devanagari tgt
        self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)

        self.scheduler = OptimizedCosineScheduler(cfg)
        self.forward_process = AbsorbingForwardProcess(self.scheduler)

        self.encoder_blocks = nn.ModuleList([EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
        self.decoder_blocks = nn.ModuleList([DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])

        # Maps a scalar (normalized) timestep to a d_model-sized embedding.
        self.time_mlp = nn.Sequential(nn.Linear(1, d//4), nn.SiLU(), nn.Linear(d//4, d))
        # Per-dimension sigmoid gate applied to the self-conditioning hint.
        self.hint_gate = nn.Sequential(nn.Linear(d, d), nn.Sigmoid())

        # Output head: predict Devanagari tokens, tied to tgt_embed
        self.head = nn.Linear(d, tgt_vocab, bias=False)
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt, t, x0_hint=None, inference_mode=False):
        """One denoising step.

        Args:
            src: (B, L_src) Roman token ids.
            tgt: (B, L_tgt) target ids; treated as clean x0 during training
                (it is re-noised via q_sample), or as the current estimate
                during inference.
            t: (B,) integer diffusion timesteps.
            x0_hint: optional previous x0 estimate for self-conditioning.
            inference_mode: when True, skips the target pad mask and skips
                q_sample at t == 0.

        Returns:
            (logits over target vocab of shape (B, L_tgt, tgt_vocab), None).
        """
        PAD = 1  # pad token id is fixed at 1 by the tokenizer layout
        src_pad_mask = (src == PAD)
        # BUG 1 FIX: no tgt mask during inference
        tgt_pad_mask = None if inference_mode else (tgt == PAD)

        # Encode Roman source
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)

        # BUG 2 FIX: skip q_sample at final step t=0
        if inference_mode and (t == 0).all():
            x_t_ids = tgt
        else:
            _, x_t_ids = self.forward_process.q_sample(tgt, t)

        x = self.tgt_embed(x_t_ids)

        # BUG 3 FIX: time embedding BEFORE hint gate
        t_norm = t.float() / self.scheduler.num_timesteps
        t_emb = self.time_mlp(t_norm.unsqueeze(-1))
        x = x + t_emb.unsqueeze(1)

        if x0_hint is not None:
            hint_emb = self.tgt_embed(x0_hint)
            gate = self.hint_gate(x)  # time-aware gate
            x = x + gate * hint_emb

        for block in self.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)

        return self.head(x), None

    @torch.no_grad()
    def generate(self, src, num_steps=None, temperature=0.8, top_k=50,
                 repetition_penalty=1.2, diversity_penalty=0.0):
        """Iterative reverse-diffusion decoding from a fully masked target.

        Each step re-predicts the whole sequence (forward() re-noises the
        estimate via q_sample for t > 0); the final step (t = 0) takes the
        argmax, earlier steps sample from the filtered distribution.
        """
        if src.dim() == 1:
            src = src.unsqueeze(0)
        device = src.device
        B, L = src.shape
        T = self.scheduler.num_timesteps
        steps = num_steps or T
        step_size = max(1, T // steps)
        timesteps = list(range(T - 1, -1, -step_size))
        # Guarantee the schedule ends exactly at t = 0 (argmax/clean step).
        if timesteps[-1] != 0:
            timesteps.append(0)

        mask_id = self.mask_token_id
        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
        hint = None

        self.eval()  # NOTE(review): mutates module mode and does not restore it
        with torch.no_grad():  # NOTE(review): redundant under the @torch.no_grad() decorator
            for step_idx, t_val in enumerate(timesteps):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (step_idx == len(timesteps) - 1)
                logits, _ = self.forward(src, x0_est, t, x0_hint=hint, inference_mode=True)
                if repetition_penalty != 1.0:
                    logits = _apply_repetition_penalty(logits, x0_est, repetition_penalty)
                if diversity_penalty > 0.0:
                    logits = _apply_diversity_penalty_fixed(logits, diversity_penalty)  # BUG 4 FIX
                logits = logits / max(temperature, 1e-5)  # floor avoids divide-by-zero
                if top_k > 0:
                    logits = _top_k_filter(logits, top_k)
                probs = F.softmax(logits, dim=-1)
                # Greedy on the last step; stochastic sampling otherwise.
                x0_est = torch.argmax(probs, dim=-1) if is_last else _batch_multinomial(probs)
                hint = x0_est
        return x0_est
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class BaselineCrossAttention(nn.Module):
    """Non-diffusion baseline sharing the same embed/encoder/decoder stack.

    forward() accepts (and ignores) `t` / `x0_hint` so it is call-compatible
    with the D3PM variants through the SanskritModel wrapper.
    """

    def __init__(self, cfg):
        super().__init__()
        d = cfg['model']['d_model']; nhead = cfg['model']['n_heads']
        d_ff = cfg['model']['d_ff']; drop = cfg['model']['dropout']
        seqlen = cfg['model']['max_seq_len']; nlayer = cfg['model']['n_layers']
        # Fall back to one shared vocab size when per-side sizes are absent.
        src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
        tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
        self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)
        self.encoder_blocks = nn.ModuleList([EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
        self.decoder_blocks = nn.ModuleList([DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
        # Weight-tied output head over the target vocabulary.
        self.head = nn.Linear(d, tgt_vocab, bias=False)
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt, t=None, x0_hint=None):
        """Teacher-forced pass; returns a 1-tuple (logits,). `t`/`x0_hint` unused."""
        PAD = 1  # pad token id fixed at 1
        memory = self.src_embed(src)
        for b in self.encoder_blocks: memory = b(memory, pad_mask=(src==PAD))
        x = self.tgt_embed(tgt)
        for b in self.decoder_blocks: x = b(x, memory, tgt_pad_mask=(tgt==PAD), src_pad_mask=(src==PAD))
        return (self.head(x),)

    @torch.no_grad()
    def generate(self, src, max_len=None, start_token_id=2, **kwargs):
        """Greedy left-to-right decoding; extra kwargs are accepted and ignored.

        NOTE(review): the decoder is run with no causal mask, so each step
        re-reads the whole prefix bidirectionally — confirm this matches how
        the checkpoint was trained.
        """
        if max_len is None: max_len = src.size(1)
        B, device = src.size(0), src.device
        memory = self.src_embed(src)
        for b in self.encoder_blocks: memory = b(memory, pad_mask=(src==1))
        ys = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
        for _ in range(max_len):
            x = self.tgt_embed(ys)
            for b in self.decoder_blocks: x = b(x, memory, tgt_pad_mask=None, src_pad_mask=(src==1))
            # Append the argmax of the last position's logits.
            ys = torch.cat([ys, torch.argmax(self.head(x)[:,-1,:], dim=-1, keepdim=True)], dim=1)
        # Drop the start token; cap at max_len generated tokens.
        return ys[:, 1:max_len+1]
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# helpers
|
| 250 |
+
def _top_k_filter(logits, k):
|
| 251 |
+
B, L, V = logits.shape
|
| 252 |
+
if k >= V: return logits
|
| 253 |
+
topk_vals, _ = torch.topk(logits, k, dim=-1)
|
| 254 |
+
return logits.masked_fill(logits < topk_vals[..., -1].unsqueeze(-1), float('-inf'))
|
| 255 |
+
|
| 256 |
+
def _batch_multinomial(probs):
|
| 257 |
+
B, L, V = probs.shape
|
| 258 |
+
flat = probs.view(B*L, V) + 1e-9
|
| 259 |
+
return torch.multinomial(flat/flat.sum(-1,keepdim=True), 1).squeeze(-1).view(B, L)
|
| 260 |
+
|
| 261 |
+
def _apply_repetition_penalty(logits, prev_tokens, penalty):
|
| 262 |
+
for b in range(logits.shape[0]):
|
| 263 |
+
for tid in set(prev_tokens[b].tolist()):
|
| 264 |
+
if tid > 4: logits[b, :, tid] = logits[b, :, tid] / penalty
|
| 265 |
+
return logits
|
| 266 |
+
|
| 267 |
+
def _apply_diversity_penalty(logits, penalty): # legacy wrong version
|
| 268 |
+
return logits + penalty * logits.var(dim=-1, keepdim=True)
|
| 269 |
+
|
| 270 |
+
def _apply_diversity_penalty_fixed(logits, penalty): # correct version
|
| 271 |
+
return logits - penalty * logits.mean(dim=1, keepdim=True)
|
model/d3pm_model_encoder_decoder.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from diffusion.scheduler import OptimizedCosineScheduler
|
| 4 |
+
from diffusion.forward_process import AbsorbingForwardProcess
|
| 5 |
+
# Import shared classes to guarantee identical architectures
|
| 6 |
+
from model.d3pm_model_cross_attention import SanskritEmbeddings, EncoderBlock, MultiHeadAttention
|
| 7 |
+
class DecoderBlock(nn.Module):
    """Post-norm transformer decoder layer: self-attn → cross-attn → FFN.

    Cross-attention and norm3 are present so the parameter layout matches
    the saved checkpoint.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
        super().__init__()
        # Submodule creation order is kept identical to preserve init RNG draws.
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)  # restored
        ffn_layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*ffn_layers)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)  # restored (normalizes the FFN residual)

    def forward(self, x, memory, tgt_pad_mask=None):
        """Run one decoder layer; `memory` is the encoder output."""
        attended = self.self_attn(x, x, x, mask=tgt_pad_mask)
        x = self.norm1(x + attended)
        crossed = self.cross_attn(x, memory, memory)
        x = self.norm2(x + crossed)
        return self.norm3(x + self.ff(x))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DecoderBlockNoCrossAttn(nn.Module):
    """Kept for reference — NOT used by D3PMEncoderDecoder.

    Self-attention-only decoder layer; pad and causal masks are OR-combined
    when both are supplied.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_ff, d_model), nn.Dropout(dropout),
        )
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

    def forward(self, x, tgt_pad_mask=None, causal_mask=None):
        """Self-attention + FFN with optional merged attention mask."""
        if tgt_pad_mask is None:
            merged_mask = causal_mask
        elif causal_mask is None:
            merged_mask = tgt_pad_mask
        else:
            merged_mask = tgt_pad_mask | causal_mask
        x = self.norm1(x + self.self_attn(x, x, x, mask=merged_mask))
        return self.norm2(x + self.ff(x))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ============================================================
|
| 56 |
+
# 1. D3PM Encoder-Decoder Model
|
| 57 |
+
# ============================================================
|
| 58 |
+
class D3PMEncoderDecoder(nn.Module):
    """D3PM absorbing diffusion model, encoder-decoder variant.

    Same overall scheme as D3PMCrossAttention: encode Roman source once,
    iteratively denoise a masked Devanagari target. Differs in details:
    the decoder takes only a target pad mask, self-conditioning blends hint
    token ids directly into x_t, and the time embedding uses the raw
    (unnormalized) timestep.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.mask_token_id = cfg['diffusion']['mask_token_id']

        # Fall back to one shared vocab size when per-side sizes are absent.
        src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
        tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
        d_model = cfg['model']['d_model']
        n_heads = cfg['model']['n_heads']
        d_ff = cfg['model']['d_ff']
        dropout = cfg['model']['dropout']
        n_layers = cfg['model']['n_layers']
        max_len = cfg['model']['max_seq_len']

        self.src_embed = SanskritEmbeddings(src_vocab, d_model, max_len)
        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d_model, max_len)

        self.scheduler = OptimizedCosineScheduler(cfg)
        self.forward_process = AbsorbingForwardProcess(self.scheduler)

        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        # DecoderBlock now has cross-attention — matches saved checkpoint
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])

        self.time_mlp = nn.Sequential(
            nn.Linear(1, d_model // 4), nn.SiLU(),
            nn.Linear(d_model // 4, d_model),
        )
        # NOTE(review): head keeps its default bias while the weight is tied
        # to tgt_embed — confirm the checkpoint contains a head bias.
        self.head = nn.Linear(d_model, tgt_vocab)
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt, t, x0_hint=None):
        """Denoise one step: returns (logits (B, L, tgt_vocab), None).

        `tgt` is always re-noised via q_sample here (no inference_mode
        shortcut, unlike the cross-attention variant).
        """
        src_pad_mask = (src == 1)
        tgt_pad_mask = (tgt == 1)

        # Encode source (Roman IAST)
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)

        # Corrupt target with forward diffusion
        _, x_t_ids = self.forward_process.q_sample(tgt, t)

        # Optionally blend in x0_hint (self-conditioning)
        if x0_hint is not None:
            hint_prob = 0.5  # each still-masked slot has a 50% chance to take the hint
            blend_mask = (torch.rand(x_t_ids.shape, device=x_t_ids.device) < hint_prob)
            still_mask = (x_t_ids == self.mask_token_id)
            x_t_ids = torch.where(blend_mask & still_mask, x0_hint, x_t_ids)

        x = self.tgt_embed(x_t_ids)
        # NOTE(review): `t` is fed to time_mlp unnormalized, unlike
        # D3PMCrossAttention which divides by num_timesteps — confirm intended.
        t_emb = self.time_mlp(t.float().unsqueeze(-1)).unsqueeze(1)
        x = x + t_emb.expand(-1, tgt.shape[1], -1)

        # Decode with cross-attention over encoder memory
        for block in self.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask)

        return self.head(x), None

    @torch.no_grad()
    def generate(
        self,
        src,
        num_steps = None,
        temperature = 0.75,
        top_k = 50,
        repetition_penalty = 1.15,
        diversity_penalty = 0.0,
    ):
        """
        Iterative D3PM reverse diffusion — same signature as
        D3PMCrossAttention.generate() so SanskritModel.generate() works
        identically for both model types.

        Unlike the cross-attention variant, already-unmasked positions are
        frozen: only slots still equal to mask_id are resampled each step.
        """
        device = src.device
        B, L = src.shape[0], self.cfg['model']['max_seq_len']
        T = num_steps or self.scheduler.num_timesteps
        mask_id = self.mask_token_id
        pad_id = 1

        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)

        for step in range(T - 1, -1, -1):
            t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
            hint = x0_est.clone()

            logits, _ = self.forward(src, x0_est, t_tensor, x0_hint=hint)

            # Repetition penalty
            # NOTE(review): plain division also boosts negative logits —
            # same issue fixed in _apply_repetition_penalty (CTRL-style rule).
            if repetition_penalty != 1.0:
                for b in range(B):
                    for tok in set(x0_est[b].tolist()):
                        if tok > pad_id:
                            logits[b, :, tok] /= repetition_penalty

            # Diversity penalty (suppress common tokens)
            if diversity_penalty > 0.0:
                logits = logits - diversity_penalty * logits.mean(dim=1, keepdim=True)

            # Temperature + top-k sampling
            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                vals, _ = torch.topk(logits, top_k, dim=-1)
                logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

            probs = torch.softmax(logits, dim=-1)
            # Only update positions that are still masked
            still = (x0_est == mask_id)
            sample = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(B, L)
            x0_est = torch.where(still, sample, x0_est)

        return x0_est
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ============================================================
|
| 179 |
+
# 2. Baseline Encoder-Decoder Model (unchanged)
|
| 180 |
+
# ============================================================
|
| 181 |
+
class BaselineEncoderDecoder(nn.Module):
    """Non-diffusion baseline using a single shared vocabulary for both sides.

    Greedy autoregressive generation; the decoder is run without a causal
    mask (see note on generate()).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.src_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
                                            cfg['model']['max_seq_len'])
        self.tgt_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
                                            cfg['model']['max_seq_len'])
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
                         cfg['model']['d_ff'], cfg['model']['dropout'])
            for _ in range(cfg['model']['n_layers'])
        ])
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
                         cfg['model']['d_ff'], cfg['model']['dropout'])
            for _ in range(cfg['model']['n_layers'])
        ])
        # Weight tied to the target embedding table (bias kept at default).
        self.head = nn.Linear(cfg['model']['d_model'], cfg['model']['vocab_size'])
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt, t=None, x0_hint=None):
        """Teacher-forced pass; returns raw logits (B, L, vocab)."""
        src_pad_mask, tgt_pad_mask = (src == 1), (tgt == 1)
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)
        x = self.tgt_embed(tgt)
        for block in self.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
        return self.head(x)

    @torch.no_grad()
    def generate(self, src, max_len=80, start_token_id=2):
        """Greedy decode up to max_len tokens, starting from start_token_id.

        NOTE(review): no causal mask is applied during decoding — confirm
        this matches the training setup before relying on output quality.
        """
        batch_size, device = src.size(0), src.device
        src_pad_mask = (src == 1)
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)
        ys = torch.ones(batch_size, 1, dtype=torch.long, device=device) * start_token_id
        for _ in range(max_len):
            x = self.tgt_embed(ys)
            for block in self.decoder_blocks:
                x = block(x, memory, tgt_pad_mask=None)
            logits = self.head(x)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            ys = torch.cat([ys, next_token], dim=1)
        # Drop the start token from the returned sequence.
        return ys[:, 1:]
|
model/sanskrit_model.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
sanskrit_model.py — Fixed
|
| 3 |
+
===========================
|
| 4 |
+
Added inference_mode parameter to forward() so reverse_process.py can
|
| 5 |
+
pass inference_mode=True without a TypeError.
|
| 6 |
+
|
| 7 |
+
The wrapper introspects each inner model's signature and only passes
|
| 8 |
+
kwargs that model actually accepts — safe across all four architectures.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import inspect
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SanskritModel(nn.Module):
    """Thin dispatcher over four concrete architectures.

    The architecture is chosen by cfg['model_type']. forward() and
    generate() introspect the wrapped model's signature and forward only
    the keyword arguments it actually declares, so the same wrapper works
    for both diffusion and baseline models without TypeErrors.
    """

    def __init__(self, cfg):
        super().__init__()
        model_type = cfg['model_type']

        # Imports deferred so only the selected architecture's module is loaded.
        if model_type == 'd3pm_cross_attention':
            from model.d3pm_model_cross_attention import D3PMCrossAttention
            self.model = D3PMCrossAttention(cfg)
        elif model_type == 'd3pm_encoder_decoder':
            from model.d3pm_model_encoder_decoder import D3PMEncoderDecoder
            self.model = D3PMEncoderDecoder(cfg)
        elif model_type == 'baseline_cross_attention':
            from model.d3pm_model_cross_attention import BaselineCrossAttention
            self.model = BaselineCrossAttention(cfg)
        elif model_type == 'baseline_encoder_decoder':
            from model.d3pm_model_encoder_decoder import BaselineEncoderDecoder
            self.model = BaselineEncoderDecoder(cfg)
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

    def forward(self, input_ids, target_ids, t, x0_hint=None, inference_mode=False):
        """Call the wrapped model, passing only kwargs its forward() declares."""
        accepted = inspect.signature(self.model.forward).parameters
        optional = {'x0_hint': x0_hint, 'inference_mode': inference_mode}
        extra = {name: value for name, value in optional.items() if name in accepted}

        if 't' in accepted:
            return self.model(input_ids, target_ids, t, **extra)
        return self.model(input_ids, target_ids, **extra)

    @torch.no_grad()
    def generate(self, src, **kwargs):
        """Delegate generation, silently dropping unsupported kwargs."""
        accepted = inspect.signature(self.model.generate).parameters
        supported = {name: value for name, value in kwargs.items() if name in accepted}
        return self.model.generate(src, **supported)
|
model/tokenizer.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tokenizer.py — Dual Tokenizer Fix
|
| 3 |
+
====================================
|
| 4 |
+
Two separate BPE tokenizers:
|
| 5 |
+
|
| 6 |
+
SanskritSourceTokenizer — trained on quote_text (Roman/IAST script)
|
| 7 |
+
SanskritTargetTokenizer — trained on quote_devanagari (Devanagari script)
|
| 8 |
+
|
| 9 |
+
WHY SEPARATE?
|
| 10 |
+
Roman Sanskrit and Devanagari are fundamentally different character sets.
|
| 11 |
+
Roman uses a-z + diacritics (~60 unique chars), Devanagari uses ā-ह + matras
|
| 12 |
+
(~100+ unique chars). A shared BPE tokenizer wastes half its vocab on
|
| 13 |
+
character combos that never cross scripts, and forces the embedding table
|
| 14 |
+
to encode both scripts in one space — confusing the model's cross-attention.
|
| 15 |
+
|
| 16 |
+
With separate tokenizers:
|
| 17 |
+
- src vocab captures Roman subwords cleanly (ā, ś, ṭ, ṃ etc.)
|
| 18 |
+
- tgt vocab captures Devanagari akshara clusters cleanly (क्ष, त्र, etc.)
|
| 19 |
+
- The model learns a true cross-script mapping in its cross-attention
|
| 20 |
+
|
| 21 |
+
SPECIAL TOKENS (same IDs in both):
|
| 22 |
+
[MASK] = 0 ← required by absorbing diffusion
|
| 23 |
+
[PAD] = 1
|
| 24 |
+
[UNK] = 2
|
| 25 |
+
[CLS] = 3
|
| 26 |
+
[SEP] = 4
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from tokenizers import Tokenizer
|
| 30 |
+
from tokenizers.models import BPE
|
| 31 |
+
from tokenizers.trainers import BpeTrainer
|
| 32 |
+
from tokenizers.pre_tokenizers import Whitespace
|
| 33 |
+
from datasets import load_dataset
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
SPECIAL_TOKENS = ["[MASK]", "[PAD]", "[UNK]", "[CLS]", "[SEP]"]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _build_bpe(texts, vocab_size):
    """Train a fresh BPE tokenizer from an iterator of strings.

    SPECIAL_TOKENS are registered first so [MASK] is guaranteed id 0
    (required by the absorbing diffusion process) and [PAD] id 1.
    """
    bpe_tok = Tokenizer(BPE(unk_token="[UNK]"))
    bpe_tok.pre_tokenizer = Whitespace()
    bpe_trainer = BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=SPECIAL_TOKENS,  # [MASK] MUST be first → id=0
        min_frequency=2,
    )
    bpe_tok.train_from_iterator(texts, bpe_trainer)
    return bpe_tok
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _validate(tok, name):
|
| 54 |
+
mask_id = tok.token_to_id("[MASK]")
|
| 55 |
+
pad_id = tok.token_to_id("[PAD]")
|
| 56 |
+
assert mask_id == 0, f"{name}: [MASK] must be id=0, got {mask_id}"
|
| 57 |
+
assert pad_id == 1, f"{name}: [PAD] must be id=1, got {pad_id}"
|
| 58 |
+
print(f"✅ {name}: [MASK]=0, [PAD]=1 confirmed. Vocab size={tok.get_vocab_size()}")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ── Source tokenizer (Roman/IAST Sanskrit) ────────────────────────────
|
| 62 |
+
|
| 63 |
+
class SanskritSourceTokenizer:
    """
    Tokenizer for quote_text — Roman transliteration of Sanskrit.
    Examples: "dharmo rakṣati rakṣitaḥ", "yatra nāryastu pūjyante"

    Loads a cached tokenizer from MODEL_PATH when present; otherwise trains
    a new BPE model from the dataset (network access required) and saves it.
    """
    # Relative path — resolved against the current working directory.
    MODEL_PATH = "sanskrit_src_tokenizer.json"

    def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.mask_token_id = 0  # guaranteed by SPECIAL_TOKENS ordering, checked below

        if Path(self.MODEL_PATH).exists():
            print(f"📖 Loading source tokenizer from {self.MODEL_PATH} …")
            self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
        else:
            print("🎓 Training source tokenizer on quote_text …")
            self._train(vocab_size, n_train_samples)

        _validate(self.tokenizer, "SrcTokenizer")

    def _train(self, vocab_size, n_samples):
        # Downloads the dataset on first use; keeps only non-empty texts.
        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
        n = min(n_samples, len(dataset))
        texts = [s["quote_text"] for s in dataset.select(range(n))
                 if s["quote_text"].strip()]
        self.tokenizer = _build_bpe(texts, vocab_size)
        self.tokenizer.save(self.MODEL_PATH)
        print(f"✅ Source tokenizer trained on {len(texts)} Roman texts.")

    def encode(self, text):
        # Truncate to max_len, then right-pad with [PAD] to exactly max_len ids.
        ids = self.tokenizer.encode(text).ids[:self.max_len]
        pad = self.tokenizer.token_to_id("[PAD]")
        ids += [pad] * max(0, self.max_len - len(ids))
        return ids[:self.max_len]

    def decode(self, ids):
        clean = [i for i in ids if i > 4]  # skip special tokens (ids 0-4)
        return self.tokenizer.decode(clean)

    def __len__(self):
        # NOTE(review): returns the requested vocab size, not the trained
        # tokenizer's actual size — they can differ.
        return self.vocab_size
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# ── Target tokenizer (Devanagari Sanskrit) ───────────────────────────
|
| 108 |
+
|
| 109 |
+
class SanskritTargetTokenizer:
    """
    Tokenizer for quote_devanagari — Devanagari script.
    Examples: "धर्मो रक्षति रक्षितः", "यत्र नार्यस्तु पूज्यन्ते"

    Mirrors SanskritSourceTokenizer but trains on the Devanagari column and
    additionally exposes the minimal interface BERTScore expects.
    """
    # Relative path — resolved against the current working directory.
    MODEL_PATH = "sanskrit_tgt_tokenizer.json"

    def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.mask_token_id = 0  # guaranteed by SPECIAL_TOKENS ordering, checked below

        if Path(self.MODEL_PATH).exists():
            print(f"📖 Loading target tokenizer from {self.MODEL_PATH} …")
            self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
        else:
            print("🎓 Training target tokenizer on quote_devanagari …")
            self._train(vocab_size, n_train_samples)

        _validate(self.tokenizer, "TgtTokenizer")

    def _train(self, vocab_size, n_samples):
        # Downloads the dataset on first use; keeps only non-empty texts.
        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
        n = min(n_samples, len(dataset))
        texts = [s["quote_devanagari"] for s in dataset.select(range(n))
                 if s["quote_devanagari"].strip()]
        self.tokenizer = _build_bpe(texts, vocab_size)
        self.tokenizer.save(self.MODEL_PATH)
        print(f"✅ Target tokenizer trained on {len(texts)} Devanagari texts.")

    def encode(self, text):
        # Truncate to max_len, then right-pad with [PAD] to exactly max_len ids.
        ids = self.tokenizer.encode(text).ids[:self.max_len]
        pad = self.tokenizer.token_to_id("[PAD]")
        ids += [pad] * max(0, self.max_len - len(ids))
        return ids[:self.max_len]

    def decode(self, ids):
        clean = [i for i in ids if i > 4]  # skip special tokens (ids 0-4)
        return self.tokenizer.decode(clean)

    # Methods required by BERTScore
    def build_inputs_with_special_tokens(self, token_ids):
        # No special tokens are added — return the ids unchanged (as a list).
        return list(token_ids)

    def get_vocab(self):
        # Synthetic mapping: BERTScore only needs a dict of the right size.
        return {str(i): i for i in range(self.vocab_size)}

    def convert_ids_to_tokens(self, ids):
        # Stringified ids stand in for real token strings.
        return [str(i) for i in ids]

    def __len__(self):
        # NOTE(review): requested vocab size, not the trained tokenizer's
        # actual size — they can differ.
        return self.vocab_size
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ── Legacy shared tokenizer (kept for backward compat) ───────────────
|
| 164 |
+
|
| 165 |
+
class SanskritTokenizer:
    """
    LEGACY: single shared tokenizer trained on BOTH scripts.
    Still works but suboptimal — use SanskritSourceTokenizer +
    SanskritTargetTokenizer for the quote_text → quote_devanagari task.
    """
    # Relative path — resolved against the current working directory.
    MODEL_PATH = "sanskrit_tokenizer_m4pro.json"

    def __init__(self, vocab_size=16000, max_len=80):
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.mask_token_id = 0  # guaranteed by SPECIAL_TOKENS ordering, checked below

        if Path(self.MODEL_PATH).exists():
            print("📖 Loading shared tokenizer …")
            self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
        else:
            print("🎓 Training shared tokenizer on both scripts …")
            self._train(vocab_size)

        _validate(self.tokenizer, "SharedTokenizer")

    def _train(self, vocab_size):
        # Interleaves Roman and Devanagari texts into one training corpus.
        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
        n = min(50000, len(dataset))
        texts = []
        for s in dataset.select(range(n)):
            if s["quote_text"].strip():
                texts.append(s["quote_text"])
            if s["quote_devanagari"].strip():
                texts.append(s["quote_devanagari"])
        self.tokenizer = _build_bpe(texts, vocab_size)
        self.tokenizer.save(self.MODEL_PATH)
        print(f"✅ Shared tokenizer trained ({len(texts)} texts).")

    def encode(self, text):
        # Truncate to max_len, then right-pad with [PAD] to exactly max_len ids.
        ids = self.tokenizer.encode(text).ids[:self.max_len]
        pad = self.tokenizer.token_to_id("[PAD]")
        ids += [pad] * max(0, self.max_len - len(ids))
        return ids[:self.max_len]

    def decode(self, ids):
        # Guard against a common mistake: passing a batch instead of one row.
        if ids and isinstance(ids[0], list):
            raise TypeError("decode() got 2D list — pass a 1D list.")
        clean = [i for i in ids if i > 4]  # skip special tokens (ids 0-4)
        return self.tokenizer.decode(clean)

    def build_inputs_with_special_tokens(self, token_ids):
        # No special tokens are added — return the ids unchanged (as a list).
        return list(token_ids)

    def get_vocab(self):
        # Synthetic mapping for BERTScore compatibility.
        return {str(i): i for i in range(self.vocab_size)}

    def convert_ids_to_tokens(self, ids):
        return [str(i) for i in ids]

    def __len__(self):
        # NOTE(review): requested vocab size, not the trained tokenizer's
        # actual size — they can differ.
        return self.vocab_size
|
model/tokenizers.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tokenizer.py — FINAL
|
| 3 |
+
=====================
|
| 4 |
+
Uses the original sanskrit_tokenizer_m4pro.json — the exact one the model
|
| 5 |
+
was trained with. Hard-coded absolute path as primary, with fallbacks.
|
| 6 |
+
|
| 7 |
+
This tokenizer has NO </w> end-of-word markers and NO decoder set.
|
| 8 |
+
decode() returns space-separated BPE pieces — this is the format the
|
| 9 |
+
model was trained and evaluated on (BERTScore 0.71). Do NOT add a decoder
|
| 10 |
+
or retrain: that would break alignment with the checkpoint.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from tokenizers import Tokenizer
|
| 14 |
+
from tokenizers.models import BPE
|
| 15 |
+
from tokenizers.trainers import BpeTrainer
|
| 16 |
+
from tokenizers.pre_tokenizers import Whitespace
|
| 17 |
+
from datasets import load_dataset
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
# Hard-coded absolute path — update if you move the project
|
| 22 |
+
TOKENIZER_PATH = "/Users/bhsingh/Documents/Final_Paraphrase/sanskrit_tokenizer_m4pro.json"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def build_tokenizer(texts, vocab_size=16000):
    """Train a whitespace-pretokenized BPE tokenizer on *texts*.

    Special tokens are registered in a fixed order so that [MASK] lands
    at ID 0, followed by [PAD], [UNK], [CLS], [SEP].
    """
    bpe = Tokenizer(BPE(unk_token="[UNK]"))
    bpe.pre_tokenizer = Whitespace()
    bpe.train_from_iterator(
        texts,
        BpeTrainer(
            vocab_size=vocab_size,
            special_tokens=["[MASK]", "[PAD]", "[UNK]", "[CLS]", "[SEP]"],
            min_frequency=2,
        ),
    )
    return bpe
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SanskritTokenizer:
    """Wrapper around the pretrained BPE tokenizer shipped with the model.

    Loads ``sanskrit_tokenizer_m4pro.json`` from the first existing
    candidate path (env var override first). Only when no copy is found
    does it retrain from scratch — with a loud warning, since a retrained
    vocabulary will not match the released checkpoint.
    """

    def __init__(self, vocab_size=16000, max_len=80):
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.mask_token_id = 0  # [MASK] is pinned to ID 0 (validated on load)

        script_dir = Path(__file__).resolve().parent
        # Search order: env override, hard-coded path, repo root, module dir, CWD.
        candidates = [
            os.environ.get("SANSKRIT_TOKENIZER_PATH", ""),
            TOKENIZER_PATH,
            str(script_dir.parent / "sanskrit_tokenizer_m4pro.json"),
            str(script_dir / "sanskrit_tokenizer_m4pro.json"),
            str(Path.cwd() / "sanskrit_tokenizer_m4pro.json"),
        ]

        self.model_path = None
        for c in candidates:
            if c and Path(c).exists():
                self.model_path = c
                break

        if self.model_path:
            print(f"📖 Loading tokenizer from: {self.model_path}")
            self.tokenizer = Tokenizer.from_file(self.model_path)
            self._validate_mask_token()
        else:
            # Plain strings (the first had a placeholder-free f-string).
            print("⚠️ Tokenizer not found at any candidate path.")
            print(f"   Expected: {TOKENIZER_PATH}")
            print("   Retraining — WARNING: output will not match existing checkpoint!")
            self.model_path = TOKENIZER_PATH
            self._train_tokenizer()

    def _validate_mask_token(self):
        """Verify [MASK] is ID 0; the model's corruption process relies on it."""
        mask_id = self.tokenizer.token_to_id("[MASK]")
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if mask_id != 0:
            raise ValueError(f"[MASK] must be ID 0, got {mask_id}")
        print("✅ [MASK] token confirmed at ID=0")

    def _train_tokenizer(self):
        """Fallback: retrain from the GRETIL verse corpus and save to disk."""
        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
        texts = []
        # Bound by corpus size — select(range(50000)) raises on smaller datasets.
        n = min(50000, len(dataset))
        for sample in dataset.select(range(n)):
            # Skip blank entries, consistent with tokenizer.py's trainer.
            for field in ("quote_text", "quote_devanagari"):
                if sample[field].strip():
                    texts.append(sample[field])
        tokenizer = build_tokenizer(texts, self.vocab_size)
        tokenizer.save(self.model_path)
        self.tokenizer = tokenizer
        self._validate_mask_token()
        print(f"✅ Tokenizer saved to: {self.model_path}")

    def encode(self, text):
        """Encode *text* into exactly ``max_len`` IDs (truncate or [PAD]-pad)."""
        encoded = self.tokenizer.encode(text)
        token_ids = encoded.ids[:self.max_len]
        pad_id = self.tokenizer.token_to_id("[PAD]")
        if len(token_ids) < self.max_len:
            token_ids += [pad_id] * (self.max_len - len(token_ids))
        return token_ids[:self.max_len]

    def decode(self, ids):
        """Decode a 1D ID list to text, dropping special tokens (IDs 0-4).

        Raises:
            TypeError: if a 2D (batch-shaped) list is passed.
        """
        if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
            raise TypeError("decode() expects 1D list of IDs, not 2D.")
        # Filter special tokens: 0=MASK 1=PAD 2=UNK 3=CLS 4=SEP
        clean = [i for i in ids if isinstance(i, int) and i > 4]
        if not clean:
            return ""
        return self.tokenizer.decode(clean, skip_special_tokens=True).strip()

    def build_inputs_with_special_tokens(self, token_ids):
        """Pass-through: return *token_ids* as a plain list."""
        return list(token_ids)

    def get_vocab(self):
        """Synthetic vocab view mapping str(ID) -> ID."""
        return {str(i): i for i in range(self.vocab_size)}

    def convert_ids_to_tokens(self, ids):
        """Stringify each ID — IDs stand in for token text here."""
        return [str(i) for i in ids]

    def __len__(self):
        """Return the configured vocabulary size."""
        return self.vocab_size
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.2
|
| 2 |
+
numpy>=1.24
|
| 3 |
+
tqdm>=4.66
|
| 4 |
+
datasets>=2.19
|
| 5 |
+
tokenizers>=0.15
|
| 6 |
+
scikit-learn>=1.3
|
sanskrit_src_tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sanskrit_tgt_tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|