bhsinghgrid committed on
Commit 6d34b0d · verified · 1 Parent(s): 2a01299

Upload folder using huggingface_hub
Files changed (48)
  1. .gitattributes +2 -34
  2. .gitignore +2 -0
  3. README.md +122 -0
  4. analysis_reports/outputs_all_models_20260325/T16/task1_kv_cache.txt +15 -0
  5. analysis_reports/outputs_all_models_20260325/T16/task2_report.txt +35 -0
  6. analysis_reports/outputs_all_models_20260325/T16/task3_report.txt +21 -0
  7. analysis_reports/outputs_all_models_20260325/T16/task4_report.txt +14 -0
  8. analysis_reports/outputs_all_models_20260325/T16/task5_report.txt +15 -0
  9. analysis_reports/outputs_all_models_20260325/T32/task1_kv_cache.txt +15 -0
  10. analysis_reports/outputs_all_models_20260325/T32/task2_report.txt +35 -0
  11. analysis_reports/outputs_all_models_20260325/T32/task3_report.txt +21 -0
  12. analysis_reports/outputs_all_models_20260325/T32/task4_report.txt +14 -0
  13. analysis_reports/outputs_all_models_20260325/T32/task5_report.txt +15 -0
  14. analysis_reports/outputs_all_models_20260325/T4/task1_kv_cache.txt +15 -0
  15. analysis_reports/outputs_all_models_20260325/T4/task2_report.txt +29 -0
  16. analysis_reports/outputs_all_models_20260325/T4/task3_report.txt +21 -0
  17. analysis_reports/outputs_all_models_20260325/T4/task4_report.txt +14 -0
  18. analysis_reports/outputs_all_models_20260325/T4/task5_report.txt +15 -0
  19. analysis_reports/outputs_all_models_20260325/T64/task1_kv_cache.txt +15 -0
  20. analysis_reports/outputs_all_models_20260325/T64/task2_report.txt +35 -0
  21. analysis_reports/outputs_all_models_20260325/T64/task3_report.txt +21 -0
  22. analysis_reports/outputs_all_models_20260325/T64/task4_report.txt +14 -0
  23. analysis_reports/outputs_all_models_20260325/T64/task5_report.txt +15 -0
  24. analysis_reports/outputs_all_models_20260325/T8/task1_kv_cache.txt +15 -0
  25. analysis_reports/outputs_all_models_20260325/T8/task2_report.txt +33 -0
  26. analysis_reports/outputs_all_models_20260325/T8/task3_report.txt +21 -0
  27. analysis_reports/outputs_all_models_20260325/T8/task4_report.txt +14 -0
  28. analysis_reports/outputs_all_models_20260325/T8/task5_report.txt +15 -0
  29. config.py +33 -0
  30. diffusion/__init__.py +0 -0
  31. diffusion/forward_process.py +21 -0
  32. diffusion/reverse_process.py +302 -0
  33. diffusion/reverse_process1.py +154 -0
  34. diffusion/reverse_process2.py +275 -0
  35. diffusion/scheduler.py +34 -0
  36. handler.py +30 -0
  37. inference.py +554 -0
  38. inference_api.py +131 -0
  39. model/__init__.py +0 -0
  40. model/d3pm_model_cross_attention.py +271 -0
  41. model/d3pm_model_encoder_decoder.py +227 -0
  42. model/sanskrit_model.py +61 -0
  43. model/tokenizer.py +222 -0
  44. model/tokenizers.py +112 -0
  45. model_settings.json +5 -0
  46. requirements.txt +6 -0
  47. sanskrit_src_tokenizer.json +0 -0
  48. sanskrit_tgt_tokenizer.json +0 -0
.gitattributes CHANGED
@@ -1,35 +1,3 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__/
+ *.pyc
README.md ADDED
@@ -0,0 +1,122 @@
+ ---
+ license: mit
+ language:
+ - sa
+ - en
+ tags:
+ - sanskrit
+ - paraphrase
+ - diffusion
+ - d3pm
+ - pytorch
+ pipeline_tag: text2text-generation
+ ---
+
+ # Sanskrit D3PM Encoder-Decoder Model
+
+ Roman/IAST Sanskrit input to Devanagari output using a custom D3PM checkpoint.
+ This package is configured for the `d3pm_encoder_decoder` checkpoint stored in
+ `best_model.pt`.
+ Hugging Face model repo: `bhsinghgrid/devflow2`
+
+ ## Files Included
+
+ - `best_model.pt` — trained checkpoint
+ - `model_settings.json` — packaged runtime metadata
+ - `config.py` — runtime config
+ - `inference.py` — model loading + generation loop
+ - `inference_api.py` — simple Python API (`predict`)
+ - `handler.py` — Hugging Face Endpoint handler
+ - `model/`, `diffusion/` — architecture modules
+ - `sanskrit_src_tokenizer.json`, `sanskrit_tgt_tokenizer.json` — tokenizers
+
+ ## Quick Local Test
+
+ ```python
+ from inference_api import predict
+ print(predict("dharmo rakṣati rakṣitaḥ")["output"])
+ ```
+
+ ## Runtime Settings
+
+ For local/API usage, the runtime first reads `model_settings.json`, then allows
+ optional environment variable overrides:
+
+ - `HF_MODEL_TYPE` = `d3pm_cross_attention` or `d3pm_encoder_decoder`
+ - `HF_INCLUDE_NEG` = `true` or `false`
+ - `HF_NUM_STEPS` = diffusion step count for the packaged checkpoint
+
+ Packaged settings for this repo:
+
+ ```bash
+ export HF_MODEL_TYPE=d3pm_encoder_decoder
+ export HF_INCLUDE_NEG=false
+ export HF_NUM_STEPS=4
+ ```
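+
+ A minimal sketch of that resolution order (assuming `model_settings.json` uses
+ the keys `model_type`, `include_neg`, and `num_steps`; the actual key names in
+ this repo may differ):
+
+ ```python
+ import json, os
+
+ def resolve_settings(path="model_settings.json"):
+     with open(path) as f:
+         s = json.load(f)
+     # Environment variables, when set, override the packaged values.
+     s["model_type"] = os.getenv("HF_MODEL_TYPE", s.get("model_type"))
+     s["include_neg"] = os.getenv("HF_INCLUDE_NEG", str(s.get("include_neg"))).lower() == "true"
+     s["num_steps"] = int(os.getenv("HF_NUM_STEPS", s.get("num_steps", 4)))
+     return s
+ ```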
+
+ ## Use This Model In A Hugging Face Space
+
+ In your Space settings, set:
+
+ - `HF_CHECKPOINT_REPO=bhsinghgrid/devflow2`
+ - `HF_CHECKPOINT_FILE=best_model.pt`
+
+ If your Space reads model metadata automatically, no extra model-type variables
+ are required. If it does not, also set:
+
+ ```bash
+ HF_DEFAULT_MODEL_TYPE=d3pm_encoder_decoder
+ HF_DEFAULT_INCLUDE_NEG=false
+ HF_DEFAULT_NUM_STEPS=4
+ ```
+
+ ## Transformer-Style Usage (Custom Runtime)
+
+ This checkpoint is a custom D3PM architecture (`.pt`), not a native `transformers`
+ `AutoModel` format. Use it via the provided runtime:
+
+ ```python
+ import torch
+ from config import CONFIG
+ from inference import load_model, run_inference, _decode_clean
+ from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model, cfg = load_model("best_model.pt", CONFIG, device)
+
+ src_tok = SanskritSourceTokenizer(vocab_size=16000, max_len=cfg["model"]["max_seq_len"])
+ tgt_tok = SanskritTargetTokenizer(vocab_size=16000, max_len=cfg["model"]["max_seq_len"])
+
+ text = "dharmo rakṣati rakṣitaḥ"
+ ids = torch.tensor([src_tok.encode(text)], dtype=torch.long, device=device)
+ out = run_inference(model, ids, cfg)
+ print(_decode_clean(tgt_tok, out[0].tolist()))
+ ```
+
+ If you need full `transformers` compatibility (`AutoModel.from_pretrained`),
+ export the weights to a Hugging Face Transformers model format first.
+
+ ## Endpoint Payload
+
+ ```json
+ {
+   "inputs": "yadā mano nivarteta viṣayebhyaḥ svabhāvataḥ",
+   "parameters": {
+     "temperature": 0.7,
+     "top_k": 40,
+     "repetition_penalty": 1.2,
+     "diversity_penalty": 0.0,
+     "num_steps": 4,
+     "clean_output": true
+   }
+ }
+ ```
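+
+ A deployed Endpoint accepting this payload can be called with `requests`
+ (the URL and token below are placeholders):
+
+ ```python
+ import requests
+
+ API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
+ HEADERS = {"Authorization": "Bearer hf_..."}  # your HF access token
+
+ payload = {
+     "inputs": "yadā mano nivarteta viṣayebhyaḥ svabhāvataḥ",
+     "parameters": {"temperature": 0.7, "top_k": 40, "num_steps": 4},
+ }
+ print(requests.post(API_URL, headers=HEADERS, json=payload).json())
+ ```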
+
+ ## Push This Folder To Model Hub
+
+ ```bash
+ cd hf_model_repo_encoder_decoder
+ git add .
+ git commit -m "Add encoder-decoder T4 model package"
+ git push -u hf main
+ ```
analysis_reports/outputs_all_models_20260325/T16/task1_kv_cache.txt ADDED
@@ -0,0 +1,15 @@
+ TASK 1 — KV CACHE BENCHMARK
+ ========================================
+
+ has_generate_cached=True
+ memory_profile=Torch CPU mem-event reduction: 30.4% @ src_len=64 (std=2143.0MB, cache=1492.1MB)
+
+ src_len   standard(s)   cached(s)   speedup   encoder%
+ 16        0.893         0.571       1.56x     40.0%
+ 32        0.751         0.509       1.48x     42.3%
+ 64        1.141         0.822       1.39x     40.7%
+
+ Saved graphs:
+ - task1_time_comparison.png
+ - task1_speedup.png
+ - task1_encoder_cost.png
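+
+ Note: the source sentence is fixed across all diffusion steps, so the encoder
+ and its cross-attention key/value projections only need to run once. A minimal
+ sketch of the idea behind the cached path (illustrative, not this repo's API):
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ d = 384
+ enc = nn.TransformerEncoderLayer(d_model=d, nhead=8, batch_first=True)
+ kv_proj = nn.Linear(d, 2 * d)
+
+ src = torch.randn(1, 64, d)
+ memory = enc(src)                        # encoder runs once
+ k, v = kv_proj(memory).chunk(2, dim=-1)  # cache cross-attention K/V
+
+ for t in reversed(range(4)):             # every diffusion step reuses (k, v)
+     q = torch.randn(1, 80, d)            # stand-in for decoder queries at step t
+     attn = torch.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1) @ v
+ ```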
analysis_reports/outputs_all_models_20260325/T16/task2_report.txt ADDED
@@ -0,0 +1,35 @@
+ TASK 2 — ATTENTION + DRIFT REPORT
+ ==================================================
+
+ Input : dharmo rakṣati rakṣitaḥ
+ Output: धर्मो रक्षति रक्षितः
+
+ Captured steps: 16
+ Analysis quality: WEAK
+ Final output uniq-ratio: 1.000
+ Degenerate output: NO
+ Multi-sample semantic score (n<=8): 0.1471
+ Lock-in step (CER<=0.05): t=0
+ Locked tokens: 38   Flexible tokens: 42
+ TF-IDF vs attention stability corr: 0.9294
+ TF-IDF status: OK
+
+ Saved graphs:
+ - task2_attn_t*.png / task2_all_layers_t0.png
+ - task2_attn_evolution.png
+ - task2_semantic_drift.png
+ - task2_source_alignment.png
+ - task2_tfidf_vs_attention.png
+
+ Step trajectory (first 10 rows)
+ ------------------------------------------------------------
+ t= 15  bert=0.0475  drift=0.9525  text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+ t= 14  bert=0.0478  drift=0.9522  text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+ t= 13  bert=0.0478  drift=0.9522  text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+ t= 12  bert=0.0478  drift=0.9522  text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+ t= 11  bert=0.0478  drift=0.9522  text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+ t= 10  bert=0.0478  drift=0.9522  text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+ t=  9  bert=0.0478  drift=0.9522  text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+ t=  8  bert=0.0478  drift=0.9522  text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+ t=  7  bert=0.0478  drift=0.9522  text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+ t=  6  bert=0.0478  drift=0.9522  text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
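+
+ The lock-in step is the earliest t whose decode stays within CER <= 0.05 of the
+ final output. A minimal character-error-rate helper for reproducing that check
+ (a sketch, not this repo's implementation):
+
+ ```python
+ def levenshtein(a: str, b: str) -> int:
+     prev = list(range(len(b) + 1))
+     for i, ca in enumerate(a, 1):
+         cur = [i]
+         for j, cb in enumerate(b, 1):
+             cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
+         prev = cur
+     return prev[-1]
+
+ def cer(hyp: str, ref: str) -> float:
+     return levenshtein(hyp, ref) / max(len(ref), 1)
+ ```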
analysis_reports/outputs_all_models_20260325/T16/task3_report.txt ADDED
@@ -0,0 +1,21 @@
+ TASK 3 — CONCEPT VECTORS + PCA STEERING
+ ==================================================
+
+ PCA: 50 components, 74.8% variance
+ Diversity PC: 0 (|r|=0.325 with diversity proxy)
+
+ Direction validity: WEAK
+ Spectrum unique ratio (mean over 5 seeds): 1.000
+ Spectrum semantic stability (mean over 5 seeds): 0.312
+
+ Saved graphs:
+ - task3_concept_space.png
+ - task3_pca_explained_variance.png
+ - task3_diversity_curve.png
+
+ Diversity spectrum:
+ alpha=-2.0 → बले वेध विवर् धान वीर्य वीर्य धिं सिंहा भि̱ सन वस्तु वेध वै वेध वस्तु सन सन सिंहा सिंहा वीर्य वीर्य वस्तु सन रुते प्रभवति मन वेध बले बले र्वृ प्रपूजयेत् युगा मलि धान तुल वीर्य वीर्य वीर्य वीर्य वीर्य वीर्य धान तुल कालेन युगा वेध बले वेध वेध च्छे ष्मस् यस्या काष्ठा ज्ञप्त अर्णव धिं धिं वस्तु धिं सन तया सन सन देवाः देवाः स्वातन्त्र अर्णव मह वस्तु मुष् सन धिं धिं धिं विक्र त्र मह हस्ते च्छे मह
+ alpha=-1.0 → बले र् अ तुल वीर्य वीर्य गुरु सिंहा सन सन विलेप वै वै वै गतस्य वेध सन सिंहा सिंहा स्य स्य । सन वै वै वै बले बले बले बले र् अ अ तुल तुल वीर्य वीर्य वीर्य वीर्य वीर्य वीर्य तुल तुल तुल ् बले वेध दिव्यां मान वै अप्सु सन ॥ ॥ वस्तु सिंहा सन सन विक्र सन स काष्ठा सन सन सन कार सन सन सन सन भ बल ु सिंहा सन सिंहा सन म् म् सन
+ alpha=+0.0 → बले र् अ तुल वीर्य वीर्य स्य सिंहा सन सन पितो वै वै वै दक्षिणां सन सन सिंहा सिंहा स्य स्य स्य सन गतस्य वै वै ॥ बले बले र् र् अ अ । तुल वीर्य वीर्य वीर्य वीर्य वीर्य तुल तुल तुल तुल अ स बले बले वै वै ॥ ॥ ॥ सन सन सिंहा स सन सन सन सन सन सन सन सन सन सन ॥ ॥ सन सन शतैः ॥ सिंहा सिंहा द सिंहा सन त् सन
+ alpha=+1.0 → बले र् अ अ विशुद्धं स्य स्य सिंहा सिंहा सन गतस्य वै वै वै वेत्ति सन सन सिंहा स्य स्य स्य स्य सन वै वै स मल बले बले र् र् व अ अ तुल वीर्य वीर्य वीर्य स्य वीर्य स्य तुल ानु अ अ । र् व ॥ वै वै सन द ॥ ॥ सिंहा सिंहा ॥ सं सन ॥ ॥ व ॥ ॥ हेम सन सन व ॥ ै ॥ वै भ न न ॥ मित्रो सिंहा सन
+ alpha=+2.0 → आविश र् अ किंचिद् वर स्य स्य सिंहा सं निमे ञ् सं वै वै ञ् सन कृपा सिंहा स्य स्य स्य स्य फणा ञ् वै ौ जिह्व बले मानाः र् र् वराय अ माने वर विशुद्धं स्य स्य स्य – वर विशुद्धं व वर अ कृपा ॥ परम् ॥ कश्चि वै ॥ ञ् ञ् सं स्य स्य तम् व प्रवर्तन्ते कर्मसु परम् वर ते ॥ व ञ् ॥ ॥ सं द ॥ ॥ वर न्द ̱व ॥ व व ै
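+
+ The alpha spectrum above steers decoding along the top PCA direction of pooled
+ hidden states. Schematically (an illustrative sklearn sketch on stand-in data,
+ not this repo's code):
+
+ ```python
+ import numpy as np
+ from sklearn.decomposition import PCA
+
+ H = np.random.randn(200, 384)           # pooled hidden states (stand-in data)
+ pca = PCA(n_components=50).fit(H)
+ direction = pca.components_[0]          # the "diversity" PC from the report
+
+ h = H[0]
+ for alpha in (-2.0, -1.0, 0.0, 1.0, 2.0):
+     h_steered = h + alpha * direction   # decode from h_steered → one spectrum row
+ ```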
analysis_reports/outputs_all_models_20260325/T16/task4_report.txt ADDED
@@ -0,0 +1,14 @@
+ TASK 4 — SEMANTIC ROBUSTNESS ABLATION
+ ==================================================
+
+ Optimal diffusion steps = 16
+
+ T     BERT-F1   SEM_SIM   BLEU     sec/sample
+ --------------------------------------------------------
+ 16    0.2574    0.0580    0.0007   0.9068
+
+ Marginal gains (BERT-F1):
+
+ Saved plots/files:
+ - task4_3d.png
+ - task4_raw_results.json
analysis_reports/outputs_all_models_20260325/T16/task5_report.txt ADDED
@@ -0,0 +1,15 @@
+ TASK 5 — CLASSIFIER-FREE GUIDANCE
+ ==================================================
+
+ Classifier params: 139521
+ Training samples : 40
+
+ Guidance scale sweep:
+ λ      CER      diversity   d2      sBLEU
+ ----------------------------------------------------
+ 0.0    0.8336   0.808       0.624   0.007   ← optimal
+ 0.5    0.8362   0.800       0.606   0.007
+ 1.0    0.8390   0.798       0.601   0.005
+ 1.5    0.8458   0.813       0.631   0.004
+ 2.0    0.8531   0.828       0.660   0.004
+ 3.0    0.8773   0.830       0.669   0.009
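+
+ The sweep varies the guidance weight λ. Assuming the runtime follows the usual
+ classifier-free guidance rule (an assumption; λ=0 then reduces to the purely
+ conditional model, matching the optimum above):
+
+ ```python
+ import torch
+
+ lam = 0.5
+ cond_logits = torch.randn(1, 80, 16000)    # logits given the source (stand-ins)
+ uncond_logits = torch.randn(1, 80, 16000)  # logits with the source dropped
+ # guided = conditional + λ · (conditional − unconditional); λ=0 → conditional
+ guided = cond_logits + lam * (cond_logits - uncond_logits)
+ ```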
analysis_reports/outputs_all_models_20260325/T32/task1_kv_cache.txt ADDED
@@ -0,0 +1,15 @@
+ TASK 1 — KV CACHE BENCHMARK
+ ========================================
+
+ has_generate_cached=True
+ memory_profile=Torch CPU mem-event reduction: 31.1% @ src_len=64 (std=4287.2MB, cache=2953.9MB)
+
+ src_len   standard(s)   cached(s)   speedup   encoder%
+ 16        1.914         1.165       1.64x     39.6%
+ 32        1.542         0.891       1.73x     42.1%
+ 64        2.096         1.475       1.42x     42.7%
+
+ Saved graphs:
+ - task1_time_comparison.png
+ - task1_speedup.png
+ - task1_encoder_cost.png
analysis_reports/outputs_all_models_20260325/T32/task2_report.txt ADDED
@@ -0,0 +1,35 @@
+ TASK 2 — ATTENTION + DRIFT REPORT
+ ==================================================
+
+ Input : dharmo rakṣati rakṣitaḥ
+ Output: धर्मो रक्षति रक्षितः
+
+ Captured steps: 32
+ Analysis quality: WEAK
+ Final output uniq-ratio: 1.000
+ Degenerate output: NO
+ Multi-sample semantic score (n<=8): 0.0627
+ Lock-in step (CER<=0.05): t=0
+ Locked tokens: 75   Flexible tokens: 5
+ TF-IDF vs attention stability corr: -0.0869
+ TF-IDF status: WEAK
+
+ Saved graphs:
+ - task2_attn_t*.png / task2_all_layers_t0.png
+ - task2_attn_evolution.png
+ - task2_semantic_drift.png
+ - task2_source_alignment.png
+ - task2_tfidf_vs_attention.png
+
+ Step trajectory (first 10 rows)
+ ------------------------------------------------------------
+ t= 31  bert=0.0167  drift=0.9833  text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+ t= 30  bert=0.0167  drift=0.9833  text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+ t= 29  bert=0.0167  drift=0.9833  text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+ t= 28  bert=0.0167  drift=0.9833  text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+ t= 27  bert=0.0167  drift=0.9833  text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+ t= 26  bert=0.0167  drift=0.9833  text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+ t= 25  bert=0.0167  drift=0.9833  text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+ t= 24  bert=0.0167  drift=0.9833  text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+ t= 23  bert=0.0167  drift=0.9833  text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+ t= 22  bert=0.0167  drift=0.9833  text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
analysis_reports/outputs_all_models_20260325/T32/task3_report.txt ADDED
@@ -0,0 +1,21 @@
+ TASK 3 — CONCEPT VECTORS + PCA STEERING
+ ==================================================
+
+ PCA: 50 components, 94.6% variance
+ Diversity PC: 0 (|r|=-0.530 with diversity proxy)
+
+ Direction validity: WEAK
+ Spectrum unique ratio (mean over 5 seeds): 0.840
+ Spectrum semantic stability (mean over 5 seeds): 0.234
+
+ Saved graphs:
+ - task3_concept_space.png
+ - task3_pca_explained_variance.png
+ - task3_diversity_curve.png
+
+ Diversity spectrum:
+ alpha=-2.0 → ेन श्रे श्रे ेन श्रे अण्ड व्याः श्रे तन्त्रा ॥ ॥ ॥ व्याः व्याः व्याः तद्वद् तद्वद् तद्वद् ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ तद्वद् ॥ ॥ ॥ ॥ ॥ व्याः व्याः व्याः ॥ ॥ राजन्य व्याः व्याः व्याः ॥ व्याः व्याः ॥ ॥ काम्य ॥ ॥ ॥ व्याः ॥ तद्वद् ॥ ॥ ॥ ॥ ॥ तन्त्रा तन्त्रा ॥ ॥ ॥ ॥ व्याः ॥ ॥ ॥ ॥ ॥ युधम् तद्वद् युधम् ॥
+ alpha=-1.0 → श्रे श्रे श्रे ेन श्रे श्रे श्रे श्रे अण्ड तन्त्रा व्याः ॥ अण्ड अण्ड तन्त्रा व्याः तद्वद् ॥ व्याः ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ अण्ड ॥ ॥ ॥ व्याः ॥ व्याः नो̍ ॥ ॥ ॥ ॥ ॥ व्याः व्याः अण्ड ॥ ॥ तन्त्रा ॥ ॥ तद्वद् युधम् रोमा शम्भु ॥ धूमं तन्त्रा ॥ तन्त्रा ॥ व्याः ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥
+ alpha=+0.0 → अण्ड श्रे करः श्रे तन्त्रा करः करः तन्त्रा श्रे अण्ड अण्ड अण्ड ॥ श्रे तद्वद् अण्ड ॥ ॥ अण्ड ॥ ॥ ॥ ॥ ॥ ॥ ॥ अण्ड ॥ ॥ ॥ ॥ अण्ड ॥ ॥ ॥ ॥ ॥ ॥ राजन्य तन्त्रा नो̍ ॥ ॥ ॥ ॥ ॥ व्याः ॥ अण्ड ॥ काम्य ॥ ॥ ॥ ॥ ॥ शम्भु धूमं तन्त्रा तन्त्रा ेन ॥ काम्य ॥ ॥ करः तन्त्रा ॥ अण्ड ॥ अण्ड ॥ विनिर्जित्य ॥ ॥ ॥ तन्त्रा अण्ड तद्वद् करः
+ alpha=+1.0 → माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण
+ alpha=+2.0 → माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण
analysis_reports/outputs_all_models_20260325/T32/task4_report.txt ADDED
@@ -0,0 +1,14 @@
+ TASK 4 — SEMANTIC ROBUSTNESS ABLATION
+ ==================================================
+
+ Optimal diffusion steps = 32
+
+ T     BERT-F1   SEM_SIM   BLEU     sec/sample
+ --------------------------------------------------------
+ 32    0.0422    0.0012    0.0000   1.8451
+
+ Marginal gains (BERT-F1):
+
+ Saved plots/files:
+ - task4_3d.png
+ - task4_raw_results.json
analysis_reports/outputs_all_models_20260325/T32/task5_report.txt ADDED
@@ -0,0 +1,15 @@
+ TASK 5 — CLASSIFIER-FREE GUIDANCE
+ ==================================================
+
+ Classifier params: 139521
+ Training samples : 40
+
+ Guidance scale sweep:
+ λ      CER      diversity   d2      sBLEU
+ ----------------------------------------------------
+ 0.0    0.9357   0.239       0.011   0.533   ← optimal
+ 0.5    0.9372   0.251       0.015   0.512
+ 1.0    0.9467   0.164       0.018   0.690
+ 1.5    0.9528   0.137       0.017   0.743
+ 2.0    0.9525   0.144       0.013   0.725
+ 3.0    0.9496   0.181       0.018   0.656
analysis_reports/outputs_all_models_20260325/T4/task1_kv_cache.txt ADDED
@@ -0,0 +1,15 @@
+ TASK 1 — KV CACHE BENCHMARK
+ ========================================
+
+ has_generate_cached=True
+ memory_profile=Torch CPU mem-event reduction: 24.6% @ src_len=64 (std=525.8MB, cache=396.4MB)
+
+ src_len   standard(s)   cached(s)   speedup   encoder%
+ 16        0.267         0.173       1.54x     43.2%
+ 32        0.197         0.153       1.29x     40.7%
+ 64        0.353         0.265       1.33x     42.0%
+
+ Saved graphs:
+ - task1_time_comparison.png
+ - task1_speedup.png
+ - task1_encoder_cost.png
analysis_reports/outputs_all_models_20260325/T4/task2_report.txt ADDED
@@ -0,0 +1,29 @@
+ TASK 2 — ATTENTION + DRIFT REPORT
+ ==================================================
+
+ Input : dharmo rakṣati rakṣitaḥ
+ Output: धर्मो रक्षति रक्षितः
+
+ Captured steps: 4
+ Analysis quality: VALID
+ Final output uniq-ratio: 1.000
+ Degenerate output: NO
+ Multi-sample semantic score (n<=8): 0.1568
+ Lock-in step (CER<=0.05): t=0
+ Locked tokens: 79   Flexible tokens: 1
+ TF-IDF vs attention stability corr: 0.9472
+ TF-IDF status: OK
+
+ Saved graphs:
+ - task2_attn_t*.png / task2_all_layers_t0.png
+ - task2_attn_evolution.png
+ - task2_semantic_drift.png
+ - task2_source_alignment.png
+ - task2_tfidf_vs_attention.png
+
+ Step trajectory (first 10 rows)
+ ------------------------------------------------------------
+ t=  3  bert=0.0603  drift=0.9397  text=ति ति ति रक्षि तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्मो
+ t=  2  bert=0.0597  drift=0.9403  text=ति ति ति रक्षि तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्मो
+ t=  1  bert=0.0597  drift=0.9403  text=ति ति ति रक्षि तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्मो
+ t=  0  bert=0.0597  drift=0.9403  text=ति ति ति रक्षि तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्मो
analysis_reports/outputs_all_models_20260325/T4/task3_report.txt ADDED
@@ -0,0 +1,21 @@
+ TASK 3 — CONCEPT VECTORS + PCA STEERING
+ ==================================================
+
+ PCA: 50 components, 72.0% variance
+ Diversity PC: 0 (|r|=-0.349 with diversity proxy)
+
+ Direction validity: WEAK
+ Spectrum unique ratio (mean over 5 seeds): 1.000
+ Spectrum semantic stability (mean over 5 seeds): 0.325
+
+ Saved graphs:
+ - task3_concept_space.png
+ - task3_pca_explained_variance.png
+ - task3_diversity_curve.png
+
+ Diversity spectrum:
+ alpha=-2.0 → बले र् अपश्य येहि ऌ वीर्य ऌ सिंहा सन सन ̍त̱ ज्ज्वा माम् वै वै महर्द्धि महर्द्धि ऌ सिंहा कू दिक्षु ऌ दश्य वै क्रमं बले र् दश्य स्वस्थ तुल तुल वीर्य वीर्य वी ऌ सिंहा राज कू वीर्य वीर्य वीर्य वीर्य ऌ वी निरुद्धा ̍त̱ बले बले साध्व उपशान्त वी वी दाक्षि हतः महर्द्धि साध्व तु वी वी ऌ दिक्षु दिक्षु पूष माम् पुरं ऌ दिक्षु वी पूष ̍त̱ ोद् दिक्षु पुरं स्त्रं मनोरथ अस्मा ऌ वाहि राजान वी
+ alpha=-1.0 → बले बले अ तुल तुल वीर्य स्य सिंहा सन सन गतस्य गतस्य वै वै वै गतस्य सन पाता सिंहा दिता । ज्ज्वा वै वै बले बले र् अ अ तुल तुल वीर्य वीर्य स्य सिंहा सिंहा ध्रा स्य वीर्य वीर्य वीर्य तुल तुल ̍त̱ अ र् र् बले दिक्षु वै वै । वै संस्थिता रतं सन गतस्य पूष । वक्त्र सन सन सन सन सन व गतस्य व सन ॥ ति मनो हतः मातु ̍त̱ व कू कू सन सन
+ alpha=+0.0 → बले र् अ तुल तुल वीर्य स्य सिंहा सन सन गतस्य गतस्य वै वै वै गतस्य सन सन सिंहा सिंहा । व वै वै त्वम् बले र् अ अ तुल तुल वीर्य वीर्य स्य स्य स्य सिंहा स्य स्य वीर्य वीर्य तुल तुल तुल अ र् र् र् त्ते वै वै गतस्य सन सन सन सन सन गतस्य निःसृ गतस्य सन गतस्य सन सन सन सन सन वि सन वि स्रव सिंहा सन सन सन सन सन सन सन गतस्य
+ alpha=+1.0 → बले र् अ अ तुल वीर्य स्य सिंहा सन सन गतस्य गतस्य वै वै वै गतस्य सन सन सिंहा सिंहा षण् स्य ै वै बले बले र् अ अ तुल कान्ते षण् वीर्य स्य स्य सिंहा स्य स्य स्य वीर्य षण् वीर्य षण् अ अ र् र् र् षण् ेष गतस्य गतस्य गतस्य गतस्य सन सन षण् षण् गतस्य सन गतस्य गतस्य सन सन सन सन ष्णु गतस्य नो षण् नो - सन सन सन सन सन सन सन सन
+ alpha=+2.0 → षण् र् अ षण् षण् षण् स्य षण् षण् षण् षण् षण् वै षण् षण् गतस्य षण् षण् षण् षण् षण् षण् षण् स षण् षण् र् षण् अ षण् षण् षण् षण् स्य स्य षण् स्य स्य स्य षण् षण् षण् षण् षण् अ र् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् गतस्य षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण्
analysis_reports/outputs_all_models_20260325/T4/task4_report.txt ADDED
@@ -0,0 +1,14 @@
+ TASK 4 — SEMANTIC ROBUSTNESS ABLATION
+ ==================================================
+
+ Optimal diffusion steps = 4
+
+ T     BERT-F1   SEM_SIM   BLEU     sec/sample
+ --------------------------------------------------------
+ 4     0.2644    0.0574    0.0000   0.2782
+
+ Marginal gains (BERT-F1):
+
+ Saved plots/files:
+ - task4_3d.png
+ - task4_raw_results.json
analysis_reports/outputs_all_models_20260325/T4/task5_report.txt ADDED
@@ -0,0 +1,15 @@
+ TASK 5 — CLASSIFIER-FREE GUIDANCE
+ ==================================================
+
+ Classifier params: 139521
+ Training samples : 40
+
+ Guidance scale sweep:
+ λ      CER      diversity   d2      sBLEU
+ ----------------------------------------------------
+ 0.0    0.8366   0.815       0.635   0.005
+ 0.5    0.8356   0.797       0.599   0.004   ← optimal
+ 1.0    0.8369   0.791       0.588   0.006
+ 1.5    0.8367   0.783       0.571   0.006
+ 2.0    0.8367   0.774       0.553   0.005
+ 3.0    0.8363   0.769       0.542   0.005
analysis_reports/outputs_all_models_20260325/T64/task1_kv_cache.txt ADDED
@@ -0,0 +1,15 @@
+ TASK 1 — KV CACHE BENCHMARK
+ ========================================
+
+ has_generate_cached=True
+ memory_profile=Torch CPU mem-event reduction: 31.4% @ src_len=64 (std=8592.3MB, cache=5890.5MB)
+
+ src_len   standard(s)   cached(s)   speedup   encoder%
+ 16        4.206         3.584       1.17x     74.4%
+ 32        4.647         3.371       1.38x     37.6%
+ 64        8.403         4.593       1.83x     49.6%
+
+ Saved graphs:
+ - task1_time_comparison.png
+ - task1_speedup.png
+ - task1_encoder_cost.png
analysis_reports/outputs_all_models_20260325/T64/task2_report.txt ADDED
@@ -0,0 +1,35 @@
+ TASK 2 — ATTENTION + DRIFT REPORT
+ ==================================================
+
+ Input : dharmo rakṣati rakṣitaḥ
+ Output: धर्मो रक्षति रक्षितः
+
+ Captured steps: 64
+ Analysis quality: VALID
+ Final output uniq-ratio: 1.000
+ Degenerate output: NO
+ Multi-sample semantic score (n<=8): 0.1490
+ Lock-in step (CER<=0.05): t=0
+ Locked tokens: 59   Flexible tokens: 21
+ TF-IDF vs attention stability corr: 0.7804
+ TF-IDF status: OK
+
+ Saved graphs:
+ - task2_attn_t*.png / task2_all_layers_t0.png
+ - task2_attn_evolution.png
+ - task2_semantic_drift.png
+ - task2_source_alignment.png
+ - task2_tfidf_vs_attention.png
+
+ Step trajectory (first 10 rows)
+ ------------------------------------------------------------
+ t= 63  bert=0.0552  drift=0.9448  text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
+ t= 62  bert=0.0548  drift=0.9452  text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
+ t= 61  bert=0.0548  drift=0.9452  text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
+ t= 60  bert=0.0548  drift=0.9452  text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
+ t= 59  bert=0.0548  drift=0.9452  text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
+ t= 58  bert=0.0548  drift=0.9452  text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
+ t= 57  bert=0.0548  drift=0.9452  text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
+ t= 56  bert=0.0546  drift=0.9454  text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
+ t= 55  bert=0.0546  drift=0.9454  text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
+ t= 54  bert=0.0546  drift=0.9454  text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
analysis_reports/outputs_all_models_20260325/T64/task3_report.txt ADDED
@@ -0,0 +1,21 @@
+ TASK 3 — CONCEPT VECTORS + PCA STEERING
+ ==================================================
+
+ PCA: 39 components, 100.0% variance
+ Diversity PC: 0 (|r|=0.314 with diversity proxy)
+
+ Direction validity: WEAK
+ Spectrum unique ratio (mean over 5 seeds): 1.000
+ Spectrum semantic stability (mean over 5 seeds): 0.302
+
+ Saved graphs:
+ - task3_concept_space.png
+ - task3_pca_explained_variance.png
+ - task3_diversity_curve.png
+
+ Diversity spectrum:
+ alpha=-2.0 → बले बले अ तुल तुल वीर्य अ̱स्या भूयः सन ान्ते लब्ध्वा अर्थ वै भूयः त्ति ान्त सन भूयः भूयः तुल अ अ वीरो बले बले बले अ̱स्या भूयः ॥ ॥ अ अ अ̱स्या ॥ वे̱ध सन सन सन सन ह् ह् ान ह् स्य ह् ानु ह् यो बो दादि ह् मतां ह् सान्त्व ह् ( मतां ॥ धीमान् भूयः ॥ अ̱स्या पान होमयेत् सारथि ( ॥ भूयः ॥ ॥ ॥ ॥ ॥ ॥ गोप लज्ज अ̱स्या मतां लज्ज यो
+ alpha=-1.0 → बले बले अ तुल तुल वीर्य स्य सिंहा सिंहा सन त्ति वै वै वै द् सन सिंहा स्य स्य तुल तुल अ प्रयाति र् बले बले बले ॥ ॥ ॥ अ ॥ पि महा सन सन सन सन सन सन सिंहा स्रव सन स्य मीं गोप स्य स्य स्य स्य भूयः तुल सान्त्व यो अ ह् ान ान तव वेग ( यो भूषणम् ( ानु ॥ ॥ ॥ ॥ ॥ ॥ अ̱स्या ॥ ॥ ॥ ॥ यो पि ॥ म
+ alpha=+0.0 → बले र् अ तुल तुल वीर्य स्य सिंहा सिंहा सन ध्या वै वै वै गतस्य भ सिंहा भ स्य तुल तुल अ अ र् बले बले बले ॥ बले र् । । ॥ वै वै सन सन सन सन सिंहा सिंहा सिंहा स्य स्य स्य स्य भ स्य स्य स्य स्य ानु स्य ् ता यो स्य फल ॥ म तुल च सि ॥ ् ॥ ॥ न् ॥ ॥ ॥ ॥ ॥ ॥ ॥ महा सन सन क्ष ॥
+ alpha=+1.0 → बले र् । तुल यु वीर्य स्य सिंहा सिंहा सन ध्या वै वै वै । सन सिंहा स्य स्य तुल तुल । र् र् बले बले बले ॥ ॥ र् अ स् सन ते सन ीं सन सन त्र सिंहा यु सिंहा स्थल स्य स्य रौद्र स्य स्य न्दा ता यु स्य यु त्र क्ष ।। ीं स्य म्र कल्प यत् स् क्ष क्ष ॥ स्य यु मण्डलं यु ॥ ॥ ीं ॥ ॥ भ्यः ीं ीं ॥ ॥ ॥
+ alpha=+2.0 → र् र् तुरङ्ग आहुः ितो । स्य सिंहा सिंहा सिंहा ब्रह्मा वै वै & ते तस् तुरङ्ग नो स्तम्भ ीं यु संच र् र् बले ीं स्तम्भ ते । तस् न्तं मण्डलं यु । स्तम्भ स्तम्भ सन आहुः सिंहा यु सिंहा सिंहा स्य मण्डलं यु स्य स्य स्य एव कल्प स्तम्भ ̱र स्तम्भ अमु आहुः यु ̱र कल्प यु तुरङ्ग यु तुरङ्ग ̱र तुरङ्ग रणम् मण्डलं यु ीं मण्डलं दिनं ̱र यु ॥ तुरङ्ग ितः आहुः ॥ मण्डलं आहुः क्षमः
analysis_reports/outputs_all_models_20260325/T64/task4_report.txt ADDED
@@ -0,0 +1,14 @@
+ TASK 4 — SEMANTIC ROBUSTNESS ABLATION
+ ==================================================
+
+ Optimal diffusion steps = 64
+
+ T     BERT-F1   SEM_SIM   BLEU     sec/sample
+ --------------------------------------------------------
+ 64    0.2482    0.0580    0.0007   5.6116
+
+ Marginal gains (BERT-F1):
+
+ Saved plots/files:
+ - task4_3d.png
+ - task4_raw_results.json
analysis_reports/outputs_all_models_20260325/T64/task5_report.txt ADDED
@@ -0,0 +1,15 @@
+ TASK 5 — CLASSIFIER-FREE GUIDANCE
+ ==================================================
+
+ Classifier params: 139521
+ Training samples : 20
+
+ Guidance scale sweep:
+ λ      CER      diversity   d2      sBLEU
+ ----------------------------------------------------
+ 0.0    0.8451   0.838       0.689   0.013   ← optimal
+ 0.5    0.8490   0.818       0.650   0.013
+ 1.0    0.8509   0.838       0.684   0.007
+ 1.5    0.8622   0.857       0.720   0.005
+ 2.0    0.8761   0.869       0.744   0.005
+ 3.0    0.9056   0.814       0.642   0.013
analysis_reports/outputs_all_models_20260325/T8/task1_kv_cache.txt ADDED
@@ -0,0 +1,15 @@
+ TASK 1 — KV CACHE BENCHMARK
+ ========================================
+
+ has_generate_cached=True
+ memory_profile=Torch CPU mem-event reduction: 25.8% @ src_len=64 (std=1168.9MB, cache=866.9MB)
+
+ src_len   standard(s)   cached(s)   speedup   encoder%
+ 16        0.582         0.400       1.45x     35.9%
+ 32        0.511         0.402       1.27x     37.7%
+ 64        0.666         0.490       1.36x     35.6%
+
+ Saved graphs:
+ - task1_time_comparison.png
+ - task1_speedup.png
+ - task1_encoder_cost.png
analysis_reports/outputs_all_models_20260325/T8/task2_report.txt ADDED
@@ -0,0 +1,33 @@
+ TASK 2 — ATTENTION + DRIFT REPORT
+ ==================================================
+
+ Input : dharmo rakṣati rakṣitaḥ
+ Output: धर्मो रक्षति रक्षितः
+
+ Captured steps: 8
+ Analysis quality: WEAK
+ Final output uniq-ratio: 1.000
+ Degenerate output: NO
+ Multi-sample semantic score (n<=8): 0.0915
+ Lock-in step (CER<=0.05): t=0
+ Locked tokens: 79   Flexible tokens: 1
+ TF-IDF vs attention stability corr: 0.8905
+ TF-IDF status: OK
+
+ Saved graphs:
+ - task2_attn_t*.png / task2_all_layers_t0.png
+ - task2_attn_evolution.png
+ - task2_semantic_drift.png
+ - task2_source_alignment.png
+ - task2_tfidf_vs_attention.png
+
+ Step trajectory (first 10 rows)
+ ------------------------------------------------------------
+ t=  7  bert=0.0219  drift=0.9781  text=ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः
+ t=  6  bert=0.0225  drift=0.9775  text=ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः
+ t=  5  bert=0.0225  drift=0.9775  text=ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः
+ t=  4  bert=0.0225  drift=0.9775  text=ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः
+ t=  3  bert=0.0225  drift=0.9775  text=ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः
+ t=  2  bert=0.0227  drift=0.9773  text=ं ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ित
+ t=  1  bert=0.0228  drift=0.9772  text=ं ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ित
+ t=  0  bert=0.0228  drift=0.9772  text=ं ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ित
analysis_reports/outputs_all_models_20260325/T8/task3_report.txt ADDED
@@ -0,0 +1,21 @@
+ TASK 3 — CONCEPT VECTORS + PCA STEERING
+ ==================================================
+
+ PCA: 50 components, 75.9% variance
+ Diversity PC: 0 (|r|=-0.344 with diversity proxy)
+
+ Direction validity: WEAK
+ Spectrum unique ratio (mean over 5 seeds): 1.000
+ Spectrum semantic stability (mean over 5 seeds): 0.341
+
+ Saved graphs:
+ - task3_concept_space.png
+ - task3_pca_explained_variance.png
+ - task3_diversity_curve.png
+
+ Diversity spectrum:
+ alpha=-2.0 → मनसः श्चक्र स्य स्य स्य अ स्य तैः तैः तैः स्य श्चक्र तैः गतभी स्य स्य स्य श्चक्र तैः तैः श्चक्र स्य तैः स्त्वं श्चक्र श्चक्र स्त्र तैः तैः कुण्ठ तैः तैः स्य तैः तैः तैः स्य तैः तैः गतभी तैः तैः णि̍ स्य तैः तैः तैः अ तैः तैः ह्वये मनसः ॥ तैः तैः गतभी ॥ श्चक्र तैः तैः तैः तैः तैः तैः तैः तैः स्य तैः करिष्या तैः स्त्वं तैः तैः श्चक्र तैः तैः श्चक्र तैः तैः ह्वये
+ alpha=-1.0 → स्य अ तैः वै तैः अ वेद् मनसः स्य । । तैः तैः स्य गतभी स्य अ स्त्वं सीद् स्य तैः स्य तैः सु̱म् सीद् र्ध कृतानि गतभी गतभी तैः तैः स्य तैः तैः तैः मनसः तैः कृतानि तैः तैः सु̱म् अ तैः मनसः अ मनसः स्य अ तैः ॥ ॥ स्य गतभी गतभी ॥ ॥ वै तैः तैः मनसः तैः अ तैः तैः च वर स्य तैः या वात् स्य तैः सीद् तैः स्य तैः स्य अ तैः तैः
+ alpha=+0.0 → अ अ वै ज्ञ स्य अ ज्ञ गतभी वर द शिख मन्त्र गतभी सु̱म् । द द स्य मन्त्र वा यो सीद् ज्ञ वै अ स्य स्य मन्त्र स्य मन्त्र स्य गतभी । गतभी गतभी तैः गतभी कृत तैः स्य तैः ॥ वै तैः ॥ वै अ कृतानि स्य वर वै ॥ ॥ वै ॥ अ ॥ स्य ॥ वै स्य ज्ञ ॥ स्य तैः तैः वै स्य स्य अ स्य तैः वै स्य तैः प्रण तीरे स्य । सीद्
+ alpha=+1.0 → पम वै तुल्य शत्रू पम शिख वर अ णाः परा णाः स्य कृत प्रिय । भिन् णाः ज्ञ वै विराज वै गणो वै ्या अ वै पम ्या भिन् वै लब्ध शोभ स्य च श वर वै ॥ वै क्षिप्य शिख भिर् ॥ सन वा मन्त्र मृ ॥ ॥ ॥ वै ॥ मन्त्र ॥ पम सङ् वर शोभ क्षिप्य भिर् स्य क्षिप्य वै सन वर शिख वै शिख वर दर्श शिख कलं पम ौ वर कलं भिर् कलं वै शिख
+ alpha=+2.0 → पम पम लब्ध पम शोभ पम परे भिर् अन्य णाः रसा लब्ध पम शोभ लब्ध शोभ पम शत्रू शिख भिन् पम पम पम णाः शोभ णाः पम शोभ शोभ परे णाः णाः पम पम पम पम शोभ पम शोभ शोभ पम डा णाः शोभ वै पम वै ॥ पम ॥ पम णाः ॥ ॥ परे अन्य णाः पम शोभ शोभ शोभ शोभ णाः वै परे कलं परे वै पम णाः पम शोभ णाः णाः कलं परे शोभ णाः कलं शत्रू
analysis_reports/outputs_all_models_20260325/T8/task4_report.txt ADDED
@@ -0,0 +1,14 @@
+ TASK 4 — SEMANTIC ROBUSTNESS ABLATION
+ ==================================================
+
+ Optimal diffusion steps = 8
+
+ T     BERT-F1   SEM_SIM   BLEU     sec/sample
+ --------------------------------------------------------
+ 8     0.1210    0.0400    0.0000   0.6194
+
+ Marginal gains (BERT-F1):
+
+ Saved plots/files:
+ - task4_3d.png
+ - task4_raw_results.json
analysis_reports/outputs_all_models_20260325/T8/task5_report.txt ADDED
@@ -0,0 +1,15 @@
+ TASK 5 — CLASSIFIER-FREE GUIDANCE
+ ==================================================
+
+ Classifier params: 139521
+ Training samples : 40
+
+ Guidance scale sweep:
+ λ      CER      diversity   d2      sBLEU
+ ----------------------------------------------------
+ 0.0    0.8834   0.796       0.596   0.004   ← optimal
+ 0.5    0.8881   0.781       0.568   0.005
+ 1.0    0.8876   0.767       0.540   0.007
+ 1.5    0.8921   0.757       0.517   0.004
+ 2.0    0.8929   0.734       0.474   0.006
+ 3.0    0.8970   0.724       0.453   0.005
config.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+
+ CONFIG = {
+     "model_type": "d3pm_cross_attention",
+     "data": {
+         "include_negative_examples": True,
+         "dataset_size": 60000,
+     },
+     "diffusion": {
+         "mask_token_id": 0,
+     },
+     "model": {
+         "src_vocab_size": 16000,
+         "tgt_vocab_size": 16000,
+         "d_model": 384,
+         "n_heads": 8,
+         "d_ff": 1536,
+         "n_layers": 6,
+         "dropout": 0.1,
+         "max_seq_len": 80,
+         "diffusion_steps": 64,
+     },
+     "training": {
+         "device": "cuda" if torch.cuda.is_available() else "cpu",
+     },
+     "inference": {
+         "num_steps": 64,
+         "temperature": 0.7,
+         "top_k": 40,
+         "repetition_penalty": 1.2,
+         "diversity_penalty": 0.0,
+     },
+ }
diffusion/__init__.py ADDED
File without changes
diffusion/forward_process.py ADDED
@@ -0,0 +1,21 @@
+ """
+ forward_process.py — Verified Correct (no changes needed)
+ ===========================================================
+ Absorbing (mask) diffusion. PAD is never masked. At t=0, alpha=1.0 exactly,
+ so x_t == x_0 (nothing masked). Works correctly with the fixed scheduler.
+ """
+ import torch
+
+ class AbsorbingForwardProcess:
+     def __init__(self, scheduler, mask_id=0, pad_id=1):
+         self.scheduler = scheduler
+         self.mask_id = mask_id
+         self.pad_id = pad_id
+
+     def q_sample(self, x_0, t):
+         alpha_t = self.scheduler.get_alpha(t).to(x_0.device).view(-1, 1)
+         r = torch.rand(x_0.shape, device=x_0.device)
+         x_t = x_0.clone()
+         x_t[r > alpha_t] = self.mask_id          # mask ~(1 - alpha_t) of tokens
+         x_t[x_0 == self.pad_id] = self.pad_id    # PAD stays PAD always
+         return x_0, x_t
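+
+
+ # Quick sanity check of the absorbing forward process (illustrative only —
+ # the linear ToyScheduler below is a stand-in, not diffusion/scheduler.py):
+ if __name__ == "__main__":
+     class ToyScheduler:
+         def __init__(self, T=64):
+             self.num_timesteps = T
+         def get_alpha(self, t):
+             # alpha: 1.0 at t=0 (nothing masked) → 0.0 at t=T-1 (all masked)
+             return 1.0 - t.float() / (self.num_timesteps - 1)
+
+     fp = AbsorbingForwardProcess(ToyScheduler(), mask_id=0, pad_id=1)
+     x_0 = torch.tensor([[5, 6, 7, 1, 1]])            # last two tokens are PAD
+     _, x_t = fp.q_sample(x_0, torch.tensor([32]))    # roughly half masked
+     assert (x_t[x_0 == 1] == 1).all()                # PAD is never masked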
diffusion/reverse_process.py ADDED
@@ -0,0 +1,302 @@
+ """
+ reverse_process.py — Fixed
+ ===========================
+ Two bugs fixed from the original:
+
+ BUG 1 (critical): generate_beam() passed x_t (noisy) as `tgt` to the model.
+ The model does q_sample(tgt, t) internally — so x_t got double-noised.
+ Fix: pass x0_estimate (the current clean guess) as tgt. The model noises it correctly.
+
+ BUG 2: apply_diversity_penalty used logits.var(dim=-1) — this adds the
+ variance of each position's own distribution back to itself, which is
+ mathematically meaningless and just injects noise.
+ Fix: penalize tokens that are uniformly high-probability across ALL positions
+ (globally common tokens). This genuinely promotes diversity.
+ """
+
+ import inspect
+
+ import torch
+ import torch.nn.functional as F
+
+
+ class ReverseDiffusion:
+     def __init__(self, scheduler):
+         self.scheduler = scheduler
+
+     def p_sample_step(
+         self,
+         model,
+         x_t,
+         t,
+         condition,
+         beam_width=3,
+         temperature=1.0,
+         repetition_penalty=1.2,
+         diversity_penalty=0.3,
+     ):
+         """Single reverse step with temperature + penalties."""
+         with torch.no_grad():
+             # ---- Shape safety ----
+             if x_t.dim() == 1:
+                 x_t = x_t.unsqueeze(0)
+             if condition.dim() == 1:
+                 condition = condition.unsqueeze(0)
+             if t.dim() == 0:
+                 t = t.unsqueeze(0)
+             if t.shape[0] != x_t.shape[0]:
+                 t = t.expand(x_t.shape[0])
+
+             # ---- Model forward ----
+             logits, _ = model(condition, x_t, t)
+
+             # ---- Temperature scaling ----
+             logits = logits / temperature
+
+             # ---- Repetition penalty (fixed version) ----
+             if repetition_penalty != 1.0:
+                 logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
+
+             # ---- Diversity penalty ----
+             if diversity_penalty > 0:
+                 logits = apply_diversity_penalty(logits, diversity_penalty)
+
+             probs = F.softmax(logits, dim=-1)
+             B, L, V = probs.shape
+
+             # ---- Top-k beam expansion ----
+             topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
+
+             candidates = []
+             for k in range(beam_width):
+                 next_tokens = topk_ids[:, :, k]
+                 score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
+                 candidates.append((next_tokens, score))
+
+             return candidates
+
+     def generate_beam(
+         self,
+         model,
+         condition,
+         beam_width=3,
+         num_steps=None,
+         temperature=1.0,
+         repetition_penalty=1.2,
+         diversity_penalty=0.3,
+     ):
+         """Beam-search reverse diffusion with temperature."""
+         if num_steps is None:
+             num_steps = self.scheduler.num_timesteps
+
+         device = condition.device
+         if condition.dim() == 1:
+             condition = condition.unsqueeze(0)
+         B, L = condition.shape
+
+         # Better initialization: start from MASK
+         x_init = torch.full(
+             (B, L),
+             fill_value=model.mask_token_id,
+             dtype=torch.long,
+             device=device,
+         )
+
+         beams = [(x_init, 0.0)]
+
+         for step in reversed(range(num_steps)):
+             new_beams = []
+             for x_t, score in beams:
+                 t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
+                 candidates = self.p_sample_step(
+                     model,
+                     x_t,
+                     t_tensor,
+                     condition,
+                     beam_width,
+                     temperature,
+                     repetition_penalty,
+                     diversity_penalty,
+                 )
+                 for tokens, new_score in candidates:
+                     new_beams.append((tokens, score + new_score))
+
+             # ---- Keep top beams ----
+             new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
+             beams = new_beams[:beam_width]
+
+         best_tokens, best_score = beams[0]
+         return best_tokens
+
+     def generate(
+         self,
+         model,
+         condition,
+         num_steps=None,
+         temperature=0.8,
+         top_k=50,
+         repetition_penalty=1.2,
+         diversity_penalty=0.0,
+     ):
+         """
+         Correct D3PM iterative refinement.
+
+         x0_est starts as all [MASK].
+         Each step: forward(src=condition, tgt=x0_est, t)
+             → model applies q_sample(x0_est, t) internally
+             → predicts a cleaner x0
+             → x0_est updated
+
+         diversity_penalty: reduces the probability of tokens that are
+         globally dominant across all sequence positions (not logits.var()).
+         """
+         if num_steps is None:
+             num_steps = self.scheduler.num_timesteps
+
+         device = condition.device
+         if condition.dim() == 1:
+             condition = condition.unsqueeze(0)
+         B, L = condition.shape
+
+         T = self.scheduler.num_timesteps
+         step_size = max(1, T // num_steps)
+         timesteps = list(range(T - 1, -1, -step_size))
+         if timesteps[-1] != 0:
+             timesteps.append(0)
+
+         mask_id = model.mask_token_id
+         # Start: we know nothing → all MASK is our initial clean estimate
+         x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
+         hint = None
+
+         # Check once (not per step) whether the model supports self-conditioning
+         supports_hint = 'x0_hint' in inspect.signature(model.forward).parameters
+
+         model.eval()
+         with torch.no_grad():
+             for step_idx, t_val in enumerate(timesteps):
+                 t = torch.full((B,), t_val, dtype=torch.long, device=device)
+                 is_last = (step_idx == len(timesteps) - 1)
+
+                 # KEY: pass x0_est as tgt — the model noises it internally
+                 if supports_hint:
+                     outputs = model(condition, x0_est, t, x0_hint=hint)
+                 else:
+                     outputs = model(condition, x0_est, t)
+
+                 logits = outputs[0] if isinstance(outputs, tuple) else outputs
+
+                 # Repetition penalty: down-weight tokens already in the sequence
+                 if repetition_penalty != 1.0:
+                     logits = apply_repetition_penalty(logits, x0_est, repetition_penalty)
+
+                 # Diversity penalty: reduce globally dominant tokens
+                 if diversity_penalty > 0.0:
+                     logits = apply_diversity_penalty(logits, diversity_penalty)
+
+                 # Temperature + top-k
+                 logits = logits / max(temperature, 1e-5)
+                 if top_k > 0:
+                     logits = top_k_filter(logits, top_k)
+
+                 probs = F.softmax(logits, dim=-1)
+
+                 if is_last:
+                     x0_est = torch.argmax(probs, dim=-1)
+                 else:
+                     x0_est = batch_multinomial(probs)
+
+                 hint = x0_est
+
+         return x0_est
+
+
+ # ── Penalty functions ─────────────────────────────────────────────────
+
+ def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
+     """
+     Down-weight tokens that already appear in the current sequence.
+     Prevents मनो मनो मनो repetition loops.
+     penalty=1.0 → no effect
+     penalty=1.2 → mild suppression of repeated tokens
+     penalty=2.0 → strong suppression
+     """
+     B, L, V = logits.shape
+     for b in range(B):
+         for token_id in set(prev_tokens[b].tolist()):
+             if token_id > 4:  # don't penalize special tokens
+                 logits[b, :, token_id] = logits[b, :, token_id] / penalty
+     return logits
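+
+
+ # Vectorized sketch of the same penalty (an illustrative alternative to the
+ # Python loops above; assumes ids 0-4 are the special tokens, as above):
+ def apply_repetition_penalty_vectorized(logits, prev_tokens, penalty=1.2):
+     B, L, V = logits.shape
+     seen = torch.zeros(B, V, dtype=torch.bool, device=logits.device)
+     seen.scatter_(1, prev_tokens, True)   # mark tokens already in the sequence
+     seen[:, :5] = False                   # never penalize special tokens
+     return torch.where(seen.unsqueeze(1), logits / penalty, logits)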
+
+
+ def apply_diversity_penalty(logits, penalty=0.5):
+     """
+     Correct diversity penalty: penalize tokens that are globally dominant
+     across ALL sequence positions. This forces the model to use less
+     common tokens, increasing output diversity.
+
+     Method: compute the mean logit across positions and subtract penalty
+     times that mean. Tokens uniformly high everywhere get suppressed.
+
+     penalty=0.0 → no diversity enforcement
+     penalty=0.5 → moderate diversity
+     penalty=1.0 → strong diversity (may hurt coherence)
+     """
+     # Mean logit across all positions: [B, 1, V]
+     global_mean = logits.mean(dim=1, keepdim=True)
+     # Subtract the scaled global mean — suppresses globally common tokens
+     return logits - penalty * global_mean
+
+
+ def top_k_filter(logits, k):
+     B, L, V = logits.shape
+     if k >= V:
+         return logits
+     topk_vals, _ = torch.topk(logits, k, dim=-1)
+     threshold = topk_vals[..., -1].unsqueeze(-1)
+     return logits.masked_fill(logits < threshold, float('-inf'))
+
+
+ def batch_multinomial(probs):
+     # Flatten (B, L, V) → (B·L, V) so each position is sampled independently
+     B, L, V = probs.shape
+     flat = probs.view(B * L, V) + 1e-9
+     flat = flat / flat.sum(dim=-1, keepdim=True)
+     return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
diffusion/reverse_process1.py ADDED
@@ -0,0 +1,154 @@
+ import torch
+ import torch.nn.functional as F
+
+
+ class ReverseDiffusion:
+     """
+     Stable reverse diffusion with:
+     - Beam search
+     - Self-conditioning
+     - Temperature sampling
+     - Repetition penalty
+     - Diversity penalty
+     """
+
+     def __init__(self, scheduler):
+         self.scheduler = scheduler
+
+         self.temperature = 0.75
+         self.repetition_penalty = 1.15
+         self.diversity_penalty = 0.0
+         self.length_penalty = 1.0
+
+     # ------------------------------------------------
+     # penalties
+     # ------------------------------------------------
+
+     def apply_repetition_penalty(self, logits, tokens):
+         B, L, V = logits.shape
+         for b in range(B):
+             used = set(tokens[b].tolist())
+             for token_id in used:
+                 logits[b, :, token_id] /= self.repetition_penalty
+         return logits
+
+     def apply_diversity_penalty(self, logits):
+         if self.diversity_penalty == 0:
+             return logits
+         # NOTE: this is the legacy variance-based penalty that
+         # reverse_process.py's header documents as a bug (it injects noise);
+         # kept here unchanged for reference.
+         logits_var = logits.var(dim=-1, keepdim=True)
+         return logits + self.diversity_penalty * logits_var
+
+     # ------------------------------------------------
+     # single reverse step
+     # ------------------------------------------------
+
+     def p_sample_step(self, model, x_t, t, condition, self_cond=None, beam_width=3):
+         with torch.no_grad():
+             logits, hidden = model(condition, x_t, t, self_cond)
+
+             logits = logits / self.temperature
+             logits = self.apply_repetition_penalty(logits, x_t)
+             logits = self.apply_diversity_penalty(logits)
+
+             probs = F.softmax(logits, dim=-1)
+             B, L, V = probs.shape
+
+             topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
+
+             candidates = []
+             for k in range(beam_width):
+                 tokens = topk_ids[:, :, k]
+                 score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
+                 candidates.append((tokens, score))
+
+             return candidates
+
+     # ------------------------------------------------
+     # beam reverse diffusion
+     # ------------------------------------------------
+
+     def generate_beam(self, model, condition, beam_width=3, num_steps=None):
+         if num_steps is None:
+             num_steps = self.scheduler.num_timesteps
+
+         device = condition.device
+         if condition.dim() == 1:
+             condition = condition.unsqueeze(0)
+         B, L = condition.shape
+
+         # ------------------------------------------------
+         # BETTER LATENT INITIALIZATION
+         # ------------------------------------------------
+         # Start from the condition with ~50% of positions masked
+         x_init = condition.clone()
+         mask = torch.rand_like(x_init.float()) < 0.5
+         x_init[mask] = model.mask_token_id
+
+         beams = [(x_init, 0.0)]
+         self_cond = None
+
+         for step in reversed(range(num_steps)):
+             new_beams = []
+             for x_t, score in beams:
+                 t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
+                 candidates = self.p_sample_step(
+                     model,
+                     x_t,
+                     t_tensor,
+                     condition,
+                     self_cond,
+                     beam_width,
+                 )
+                 for tokens, new_score in candidates:
+                     length_norm = tokens.shape[1] ** self.length_penalty
+                     final_score = (score + new_score) / length_norm
+                     new_beams.append((tokens, final_score))
+
+             new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
+             beams = new_beams[:beam_width]
+
+             # self-conditioning
+             self_cond = beams[0][0]
+
+         best_tokens, best_score = beams[0]
+         return best_tokens
diffusion/reverse_process2.py ADDED
@@ -0,0 +1,275 @@
+ """
+ reverse_process.py — Final Correct Version
+ =============================================
+
+ KEY PRINCIPLE: generate() must be byte-for-byte identical to run_inference()
+ in inference.py, which is what produced BERTScore 0.75 at validation.
+
+ CRITICAL BUG IN THE PREVIOUS VERSION:
+ We passed inference_mode=True to the model, but the model was NEVER
+ called with inference_mode=True during training or validation.
+ run_inference() (the validated path) does:
+     model(input_ids, x0_est, t, x0_hint=hint)
+     → inference_mode defaults to False.
+
+ With inference_mode=True the model does two things differently:
+   1. tgt_pad_mask = None (training used tgt_pad_mask = tgt==PAD)
+   2. Skips q_sample at t=0 (training always called q_sample)
+ The model was never trained to handle these conditions → garbage output.
+
+ Fix: do NOT pass inference_mode. Let it default to False, exactly
+ as run_inference() did.
+
+ BUGS FIXED (vs the original reverse_process.py)
+ --------------------------------------------
+ BUG 1  generate_beam() used for D3PM → all-Ṛ repetition.
+        Use generate() (iterative refinement) from app1.py instead.
+ BUG 2  apply_diversity_penalty used logits.var() → noise injection.
+        Fixed to logits - penalty * logits.mean(dim=1) — global suppression.
+ BUG 3  x0_hint (self-conditioning) was never passed to the model.
+        Fixed: generate() passes x0_hint=hint every step.
+ BUG 4  params not forwarded from generate_beam() to p_sample_step().
+        Fixed in generate_beam() (kept for reference, not for production use).
+ """
+
+ import torch
+ import torch.nn.functional as F
+
+
+ class ReverseDiffusion:
+
+     def __init__(self, scheduler):
+         self.scheduler = scheduler
+
+         # Attribute-style defaults for backward compat with any code
+         # that sets reverse_diffusion.temperature = 0.9 etc.
+         # generate() prefers explicit kwargs and falls back to these.
+         self.temperature = 0.75
+         self.repetition_penalty = 1.15
+         self.diversity_penalty = 0.0
+         self.top_k = 50
+
+     # ------------------------------------------------------------------ #
+     # generate — CORRECT D3PM iterative refinement                        #
+     # Exact equivalent of run_inference() in inference.py                 #
+     # ------------------------------------------------------------------ #
+     def generate(
+         self,
+         model,
+         condition,
+         num_steps=None,
+         temperature=None,
+         top_k=None,
+         repetition_penalty=None,
+         diversity_penalty=None,
+     ):
+         """
+         D3PM iterative refinement — identical to run_inference() in
+         inference.py, which is the validated path (BERTScore 0.75).
+
+         Algorithm:
+             x0_est = all [MASK]
+             for t = T-1 down to 0:
+                 logits = model(src, x0_est, t, x0_hint=hint)
+                     ↑ inference_mode NOT passed (defaults to False)
+                     ↑ this exactly matches training/validation
+                 apply penalties, temperature, top_k
+                 if t > 0: x0_est = multinomial(softmax(logits))  ← stochastic
+                 if t = 0: x0_est = argmax(softmax(logits))       ← deterministic
+                 hint = x0_est
+         """
+         # Resolve: explicit kwarg > object attribute
+         temperature = temperature if temperature is not None else self.temperature
+         top_k = top_k if top_k is not None else self.top_k
+         repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
+         diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
+
+         if num_steps is None:
+             num_steps = self.scheduler.num_timesteps
+
+         device = condition.device
+         if condition.dim() == 1:
+             condition = condition.unsqueeze(0)
+         B, L = condition.shape
+
+         T = self.scheduler.num_timesteps
+         step_size = max(1, T // num_steps)
+         timesteps = list(range(T - 1, -1, -step_size))
+         if timesteps[-1] != 0:
+             timesteps.append(0)
+
+         mask_id = model.mask_token_id
+         x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
+         hint = None
+
+         model.eval()
+         with torch.no_grad():
+             for step_idx, t_val in enumerate(timesteps):
+                 t = torch.full((B,), t_val, dtype=torch.long, device=device)
+                 is_last = (step_idx == len(timesteps) - 1)
+
+                 # ── CRITICAL: do NOT pass inference_mode ──────────────────
+                 # inference_mode defaults to False inside SanskritModel /
+                 # D3PMCrossAttention. This matches run_inference() exactly.
+                 # Passing inference_mode=True changes tgt_pad_mask and
+                 # q_sample behaviour — the model was never trained for that.
+                 logits, _ = model(condition, x0_est, t, x0_hint=hint)
+
+                 # Repetition penalty
+                 if repetition_penalty != 1.0:
+                     logits = apply_repetition_penalty(
+                         logits, x0_est, repetition_penalty
+                     )
+
+                 # Diversity penalty (correct: global mean suppression)
+                 if diversity_penalty > 0.0:
+                     logits = apply_diversity_penalty(logits, diversity_penalty)
+
+                 logits = logits / max(temperature, 1e-5)
+
+                 if top_k > 0:
+                     logits = top_k_filter(logits, top_k)
+
+                 probs = F.softmax(logits, dim=-1)
+
+                 # Stochastic at every step except the last (argmax at t=0)
+                 if is_last:
+                     x0_est = torch.argmax(probs, dim=-1)
+                 else:
+                     x0_est = batch_multinomial(probs)
+
+                 hint = x0_est
+
+         return x0_est  # (B, L)
+
+     # ------------------------------------------------------------------ #
+     # p_sample_step — used by generate_beam (not for production)          #
+     # ------------------------------------------------------------------ #
+     def p_sample_step(
+         self,
+         model,
+         x_t,
+         t,
+         condition,
+         beam_width=3,
+         temperature=1.0,
+         repetition_penalty=1.2,
+         diversity_penalty=0.3,
+         x0_hint=None,
+     ):
+         with torch.no_grad():
+             if x_t.dim() == 1:
+                 x_t = x_t.unsqueeze(0)
+             if condition.dim() == 1:
+                 condition = condition.unsqueeze(0)
+             if t.dim() == 0:
+                 t = t.unsqueeze(0)
+             if t.shape[0] != x_t.shape[0]:
+                 t = t.expand(x_t.shape[0])
+
+             # No inference_mode — matches the training convention
+             logits, _ = model(condition, x_t, t, x0_hint=x0_hint)
+
+             logits = logits / max(temperature, 1e-5)
+
+             if repetition_penalty != 1.0:
+                 logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
+             if diversity_penalty > 0.0:
+                 logits = apply_diversity_penalty(logits, diversity_penalty)
+
+             probs = F.softmax(logits, dim=-1)
+             B, L, V = probs.shape
+
+             topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
+             candidates = []
+             for k in range(beam_width):
+                 next_tokens = topk_ids[:, :, k]
+                 score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
+                 candidates.append((next_tokens, score))
+             return candidates
+
+     # ------------------------------------------------------------------ #
+     # generate_beam — kept for reference; NOT the correct D3PM method     #
+     # ------------------------------------------------------------------ #
191
+ def generate_beam(
192
+ self,
193
+ model,
194
+ condition,
195
+ beam_width = 3,
196
+ num_steps = None,
197
+ temperature = None,
198
+ repetition_penalty = None,
199
+ diversity_penalty = None,
200
+ ):
201
+ """
202
+ WARNING: do NOT call this from app1.py for D3PM generation.
203
+ generate_beam() forces every position to the same top-k token
204
+ → all-Ṛ / all-rud repetition. Use generate() instead.
205
+ Kept only for experimental reference.
206
+ """
207
+ temperature = temperature if temperature is not None else self.temperature
208
+ repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
209
+ diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
210
+ if num_steps is None:
211
+ num_steps = self.scheduler.num_timesteps
212
+
213
+ device = condition.device
214
+ if condition.dim() == 1: condition = condition.unsqueeze(0)
215
+ B, L = condition.shape
216
+
217
+ x_init = torch.full((B, L), fill_value=model.mask_token_id,
218
+ dtype=torch.long, device=device)
219
+ beams = [(x_init, 0.0)]
220
+ best_hint = None
221
+
222
+ for step in reversed(range(num_steps)):
223
+ t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
224
+ new_beams = []
225
+ for x_t, score in beams:
226
+ candidates = self.p_sample_step(
227
+ model, x_t, t_tensor, condition,
228
+ beam_width = beam_width,
229
+ temperature = temperature,
230
+ repetition_penalty = repetition_penalty,
231
+ diversity_penalty = diversity_penalty,
232
+ x0_hint = best_hint,
233
+ )
234
+ for tokens, new_score in candidates:
235
+ new_beams.append((tokens, score + new_score.item()))
236
+
237
+ new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
238
+ beams = new_beams[:beam_width]
239
+ best_hint = beams[0][0]
240
+
241
+ return beams[0][0] # (B, L)
242
+
243
+
244
+ # ── Penalty helpers ────────────────────────────────────────────────────────
245
+
246
+ def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
247
+ """Down-weight tokens already present in the sequence."""
248
+ for b in range(logits.shape[0]):
249
+ for token_id in set(prev_tokens[b].tolist()):
250
+ if token_id > 4:
251
+ logits[b, :, token_id] = logits[b, :, token_id] / penalty
252
+ return logits
253
+
254
+
255
+ def apply_diversity_penalty(logits, penalty=0.3):
256
+ """
257
+ Correct diversity penalty: suppress globally dominant tokens.
258
+ logits -= penalty * mean(logits, dim=1) [sequence dimension]
259
+ """
260
+ global_mean = logits.mean(dim=1, keepdim=True) # [B, 1, V]
261
+ return logits - penalty * global_mean
262
+
263
+
264
+ def top_k_filter(logits, k):
265
+ B, L, V = logits.shape
266
+ if k >= V: return logits
267
+ topk_vals, _ = torch.topk(logits, k, dim=-1)
268
+ return logits.masked_fill(logits < topk_vals[..., -1].unsqueeze(-1), float('-inf'))
269
+
270
+
271
+ def batch_multinomial(probs):
272
+ B, L, V = probs.shape
273
+ flat = probs.view(B * L, V) + 1e-9
274
+ flat = flat / flat.sum(dim=-1, keepdim=True)
275
+ return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
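
For reference, a minimal smoke test of the sampling helpers above on random logits. It assumes this file is importable as diffusion.reverse_process; the shapes and parameter values are illustrative only.

    import torch
    import torch.nn.functional as F

    from diffusion.reverse_process import (
        apply_repetition_penalty, apply_diversity_penalty,
        top_k_filter, batch_multinomial,
    )

    B, L, V = 2, 6, 32                         # toy batch / length / vocab
    logits = torch.randn(B, L, V)
    prev = torch.randint(5, V, (B, L))         # pretend current x0 estimate

    # Same order as generate(): penalties, then temperature, then top-k.
    logits = apply_repetition_penalty(logits, prev, penalty=1.15)
    logits = apply_diversity_penalty(logits, penalty=0.3)
    logits = top_k_filter(logits / 0.75, k=10)
    probs = F.softmax(logits, dim=-1)

    sample = batch_multinomial(probs)          # stochastic step (t > 0)
    greedy = probs.argmax(dim=-1)              # deterministic step (t == 0)
    print(sample.shape, greedy.shape)          # torch.Size([2, 6]) twice
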
diffusion/scheduler.py ADDED
@@ -0,0 +1,34 @@
1
+ """
2
+ scheduler.py — Fixed & Upgraded
3
+ ==================================
4
+ Changes:
5
+ 1. T=64 (was 16). More timesteps = richer denoising curriculum per epoch.
6
+ 2. alpha at t=0 is EXACTLY 1.0 — fixes Bug 2 (final-step re-noise).
7
+ 3. sample_timestep samples [0, T-1] including t=0, so model trains on
8
+ fully-clean inputs (learns the identity at t=0 explicitly).
9
+ """
10
+ import torch, math
11
+
12
+ class OptimizedCosineScheduler:
13
+ def __init__(self, cfg, device=None):
14
+ self.num_timesteps = cfg['model']['diffusion_steps'] # 64
15
+ self.mask_token_id = cfg['diffusion']['mask_token_id']
16
+ self.device = device or torch.device('cpu')
17
+ self.alphas_cumprod = self._build_schedule().to(self.device)
18
+
19
+ def _build_schedule(self):
20
+ T = self.num_timesteps
21
+ t = torch.arange(T + 1, dtype=torch.float32)
22
+ f_t = torch.cos((t / T + 0.008) / 1.008 * math.pi / 2) ** 2
23
+ alphas_bar = f_t / f_t[0]
24
+ alphas_bar = alphas_bar[1:] # shape [T]
25
+ alphas_bar[0] = 1.0 # FIX: exact 1.0 at t=0
26
+ alphas_bar[-1] = alphas_bar[-1].clamp(max=0.001)
27
+ return alphas_bar
28
+
29
+ def sample_timestep(self, batch_size):
30
+ """Uniform [0, T-1] — includes t=0 so model sees clean inputs."""
31
+ return torch.randint(0, self.num_timesteps, (batch_size,))
32
+
33
+ def get_alpha(self, t):
34
+ return self.alphas_cumprod[t.to(self.alphas_cumprod.device).long()]
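
Both schedule fixes can be verified directly. A quick check, assuming a minimal cfg dict carrying only the keys the constructor reads (not the project's full CONFIG):

    import torch
    from diffusion.scheduler import OptimizedCosineScheduler

    cfg = {'model': {'diffusion_steps': 64}, 'diffusion': {'mask_token_id': 0}}
    sched = OptimizedCosineScheduler(cfg)

    alphas = sched.get_alpha(torch.tensor([0, 63]))
    print(alphas[0].item())   # 1.0: t=0 keeps the target fully clean
    print(alphas[1].item())   # ~0.0 (clamped to at most 0.001): final step fully masked
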
handler.py ADDED
@@ -0,0 +1,30 @@
1
+ from typing import Any, Dict
2
+
3
+ from inference_api import predict
4
+
5
+
6
+ class EndpointHandler:
7
+ """
8
+ Hugging Face Inference Endpoint handler.
9
+ Expects payload:
10
+ {
11
+ "inputs": "dharmo rakṣati rakṣitaḥ",
12
+ "parameters": {"temperature": 0.7, ...}
13
+ }
14
+ """
15
+
16
+ def __init__(self, path: str = ""):
17
+ self.path = path
18
+
19
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
20
+ text = data.get("inputs", "")
21
+ params = data.get("parameters", {}) or {}
22
+ return predict(
23
+ text=text,
24
+ temperature=params.get("temperature", 0.7),
25
+ top_k=params.get("top_k", 40),
26
+ repetition_penalty=params.get("repetition_penalty", 1.2),
27
+ diversity_penalty=params.get("diversity_penalty", 0.0),
28
+ num_steps=params.get("num_steps", 64),
29
+ clean_output=params.get("clean_output", True),
30
+ )
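
The handler can be exercised locally before deployment. A sketch, assuming best_model.pt, model_settings.json, and the tokenizer JSONs sit in the working directory as inference_api.py expects:

    from handler import EndpointHandler

    handler = EndpointHandler()
    result = handler({
        "inputs": "dharmo rakṣati rakṣitaḥ",
        "parameters": {"temperature": 0.7, "top_k": 40, "num_steps": 64},
    })
    print(result["output"])   # cleaned Devanagari paraphrase
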
inference.py ADDED
@@ -0,0 +1,554 @@
1
+ """
2
+ inference.py
3
+ ============
4
+ Correct D3PM inference for Sanskrit paraphrase generation.
5
+
6
+ The model's forward() takes CLEAN tgt and noises it internally.
7
+ So inference passes x0_estimate (starting all-[MASK]) as tgt each step,
8
+ letting the model noise it and then predict a cleaner version.
9
+
10
+ Also includes: robust checkpoint loading (auto-detects architecture
11
+ from saved weights — no CONFIG mismatch crashes).
12
+ """
13
+
14
+ import json
15
+ import torch
16
+ import os, sys
17
+ import re
18
+ from tqdm import tqdm
19
+ from torch.utils.data import DataLoader, Subset
20
+
21
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
22
+ from config import CONFIG
23
+
24
+
25
+ # ── Checkpoint loader ─────────────────────────────────────────────────
26
+
27
+ def _resolve_device(cfg_device: str) -> torch.device:
28
+ cfg_device = (cfg_device or "").lower()
29
+ if cfg_device == "cuda" and torch.cuda.is_available():
30
+ return torch.device("cuda")
31
+ if cfg_device == "mps" and torch.backends.mps.is_available():
32
+ return torch.device("mps")
33
+ if cfg_device in {"cpu", "cuda", "mps"}:
34
+ return torch.device("cpu")
35
+ if torch.cuda.is_available():
36
+ return torch.device("cuda")
37
+ if torch.backends.mps.is_available():
38
+ return torch.device("mps")
39
+ return torch.device("cpu")
40
+
41
+ def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
42
+ """
43
+ Auto-detect architecture from checkpoint weight shapes,
44
+ then load. Never fails due to CONFIG vs checkpoint mismatch.
45
+ """
46
+ import copy
47
+ from model.sanskrit_model import SanskritModel
48
+
49
+ cfg = copy.deepcopy(base_cfg)
50
+ state = torch.load(ckpt_path, map_location='cpu')
51
+
52
+ # d_model + vocab_size
53
+ ek = 'model.src_embed.token_emb.weight'
54
+ if ek in state:
55
+ vocab, d = state[ek].shape
56
+ cfg['model']['vocab_size'] = vocab
57
+ cfg['model']['d_model'] = d
58
+ cfg['model']['d_ff'] = d * 4
59
+
60
+ # n_layers
61
+ ids = {int(k.split('.')[2]) for k in state if k.startswith('model.encoder_blocks.')}
62
+ if ids:
63
+ cfg['model']['n_layers'] = max(ids) + 1
64
+
65
+ # max_seq_len
66
+ pk = 'model.src_embed.pos_enc.pe'
67
+ if pk in state:
68
+ cfg['model']['max_seq_len'] = state[pk].shape[1]
69
+
70
+ # n_heads
71
+ d = cfg['model']['d_model']
72
+ h = cfg['model'].get('n_heads', 6)
73
+ if d % h != 0:
74
+ h = next(x for x in [8, 6, 4, 2, 1] if d % x == 0)
75
+ cfg['model']['n_heads'] = h
76
+
77
+ print(f"🔍 Detected: d_model={cfg['model']['d_model']}, "
78
+ f"n_layers={cfg['model']['n_layers']}, "
79
+ f"max_seq_len={cfg['model']['max_seq_len']}, "
80
+ f"n_heads={cfg['model']['n_heads']}")
81
+
82
+ model = SanskritModel(cfg).to(device)
83
+ raw_state = state  # reuse the already-loaded state dict; load_state_dict copies onto the model's device
84
+ model_state = model.state_dict()
85
+ filtered_state = {}
86
+ skipped_mismatch = []
87
+ for k, v in raw_state.items():
88
+ if k in model_state and hasattr(v, "shape") and hasattr(model_state[k], "shape"):
89
+ if tuple(v.shape) != tuple(model_state[k].shape):
90
+ skipped_mismatch.append((k, tuple(v.shape), tuple(model_state[k].shape)))
91
+ continue
92
+ filtered_state[k] = v
93
+
94
+ missing, unexpected = model.load_state_dict(filtered_state, strict=False)
95
+
96
+ # hint_gate may be absent in older checkpoints — initialise safely
97
+ allowed = {'model.hint_gate.0.weight', 'model.hint_gate.0.bias'}
98
+ real_missing = [k for k in missing if k not in allowed]
99
+ if real_missing:
100
+ print(f"⚠️ Missing keys: {real_missing[:3]} …")
101
+ if unexpected:
102
+ print(f"⚠️ Unexpected keys: {unexpected[:3]} …")
103
+ if skipped_mismatch:
104
+ print(f"⚠️ Shape-mismatched keys skipped: {len(skipped_mismatch)}")
105
+
106
+ # Enable compact-attention branch only when checkpoint actually provides it.
107
+ has_compact = any(".compact_out_proj.weight" in k for k in filtered_state.keys())
108
+ if has_compact and hasattr(model, "model") and hasattr(model.model, "decoder_blocks"):
109
+ for block in model.model.decoder_blocks:
110
+ if hasattr(block, "cross_attn") and hasattr(block.cross_attn, "use_compact"):
111
+ block.cross_attn.use_compact = True
112
+ print("ℹ️ Compact cross-attention branch enabled from checkpoint.")
113
+ if hasattr(model.model, 'hint_gate') and 'model.hint_gate.0.weight' in missing:
114
+ with torch.no_grad():
115
+ w = model.model.hint_gate[0].weight
116
+ torch.nn.init.zeros_(model.model.hint_gate[0].bias)
117
+ if w.shape[0] == w.shape[1]: torch.nn.init.eye_(w)
118
+ else: torch.nn.init.xavier_uniform_(w)
119
+ print("ℹ️ hint_gate initialised to identity (not in checkpoint).")
120
+
121
+ print("✅ Model loaded.")
122
+ return model, cfg
123
+
124
+
125
+ # ── Core inference function (same path as validation) ────────────────
126
+
127
+ @torch.no_grad()
128
+ def run_inference(model, input_ids, cfg):
129
+ """
130
+ Reverse diffusion sampling (clean path).
131
+ Uses cached reverse diffusion when available, otherwise model.generate().
132
+ """
133
+ inf = cfg['inference']
134
+ model.eval()
135
+ kwargs = dict(
136
+ num_steps=inf['num_steps'],
137
+ temperature=inf['temperature'],
138
+ top_k=inf['top_k'],
139
+ repetition_penalty=inf.get('repetition_penalty', 1.2),
140
+ diversity_penalty=inf.get('diversity_penalty', 0.0),
141
+ )
142
+ if hasattr(model, "generate_cached"):
143
+ out = model.generate_cached(input_ids, **kwargs)
144
+ else:
145
+ out = model.generate(input_ids, **kwargs)
146
+
147
+ # Optional retry with stronger anti-repetition settings.
148
+ if inf.get("auto_retry_on_repetition", True):
149
+ repeat_threshold = float(inf.get("repeat_ratio_threshold", 0.40))
150
+ max_repeat_run = int(inf.get("max_repeat_run", 4))
151
+ if _mean_repeat_ratio(out) >= repeat_threshold:
152
+ retry_kwargs = dict(kwargs)
153
+ retry_kwargs["temperature"] = max(0.6, float(kwargs["temperature"]) - 0.1)
154
+ retry_kwargs["top_k"] = max(20, int(kwargs["top_k"]) - 10)
155
+ retry_kwargs["repetition_penalty"] = max(float(kwargs["repetition_penalty"]), 1.6)
156
+ retry_kwargs["diversity_penalty"] = max(float(kwargs["diversity_penalty"]), 0.3)
157
+ if hasattr(model, "generate_cached"):
158
+ retry = model.generate_cached(input_ids, **retry_kwargs)
159
+ else:
160
+ retry = model.generate(input_ids, **retry_kwargs)
161
+ if _mean_repeat_ratio(retry) < _mean_repeat_ratio(out):
162
+ out = retry
163
+ out = _dedup_repeated_ids(out, max_repeat_run=max_repeat_run)
164
+
165
+ return out
166
+
167
+
168
+ def _mean_repeat_ratio(ids_tensor: torch.Tensor) -> float:
169
+ if ids_tensor is None or ids_tensor.numel() == 0:
170
+ return 0.0
171
+ ratios = []
172
+ for row in ids_tensor:
173
+ ids = [int(x) for x in row.tolist() if int(x) > 4]
174
+ if len(ids) < 2:
175
+ ratios.append(0.0)
176
+ continue
177
+ repeats = sum(1 for i in range(1, len(ids)) if ids[i] == ids[i - 1])
178
+ ratios.append(repeats / max(1, len(ids) - 1))
179
+ return float(sum(ratios) / max(1, len(ratios)))
180
+
181
+
182
+ def _dedup_repeated_ids(ids_tensor: torch.Tensor, max_repeat_run: int = 4) -> torch.Tensor:
183
+ """
184
+ Keep generation path unchanged, but clean extreme run-on token loops in final output ids.
185
+ """
186
+ if ids_tensor is None or ids_tensor.numel() == 0:
187
+ return ids_tensor
188
+ cleaned_rows = []
189
+ for row in ids_tensor.tolist():
190
+ out = []
191
+ prev = None
192
+ run = 0
193
+ for tok in row:
194
+ if tok <= 4:
195
+ out.append(tok)
196
+ prev = tok
197
+ run = 1
198
+ continue
199
+ if tok == prev:
200
+ run += 1
201
+ if run > max_repeat_run:
202
+ continue
203
+ else:
204
+ run = 1
205
+ out.append(tok)
206
+ prev = tok
207
+ # Preserve original length for downstream decode assumptions.
208
+ if len(out) < len(row):
209
+ out.extend([1] * (len(row) - len(out)))
210
+ else:
211
+ out = out[:len(row)]
212
+ cleaned_rows.append(out)
213
+ return torch.tensor(cleaned_rows, dtype=ids_tensor.dtype, device=ids_tensor.device)
214
+
215
+
216
+ def _decode_clean(tgt_tok, ids):
217
+ out = []
218
+ for x in ids:
219
+ if x in (1, 4) and out:
220
+ break
221
+ if x > 4:
222
+ out.append(x)
223
+ text = tgt_tok.decode(out).strip()
224
+ return _clean_repetition_text(text)
225
+
226
+
227
+ def _clean_repetition_text(text: str, max_repeat_run: int = 3) -> str:
228
+ words = [w for w in text.split() if w.strip()]
229
+ if not words:
230
+ return text.strip()
231
+ cleaned = []
232
+ prev = None
233
+ run = 0
234
+ for w in words:
235
+ if w == prev:
236
+ run += 1
237
+ if run > max_repeat_run:
238
+ continue
239
+ else:
240
+ run = 1
241
+ cleaned.append(w)
242
+ prev = w
243
+ return " ".join(cleaned).strip()
244
+
245
+
246
+ # ── Cleanup heuristics from UI inference pipeline ─────────────────────
247
+
248
+ _IAST_VOWELS = [
249
+ ("ai", "ऐ"), ("au", "औ"),
250
+ ("ā", "आ"), ("ī", "ई"), ("ū", "ऊ"),
251
+ ("ṛ", "ऋ"), ("ṝ", "ॠ"), ("ḷ", "ऌ"), ("ḹ", "ॡ"),
252
+ ("a", "अ"), ("i", "इ"), ("u", "उ"),
253
+ ("e", "ए"), ("o", "ओ"),
254
+ ]
255
+ _IAST_MATRAS = [
256
+ ("ai", "ै"), ("au", "ौ"),
257
+ ("ā", "ा"), ("ī", "ी"), ("ū", "ू"),
258
+ ("ṛ", "ृ"), ("ṝ", "ॄ"), ("ḷ", "ॢ"), ("ḹ", "ॣ"),
259
+ ("a", ""), ("i", "ि"), ("u", "ु"),
260
+ ("e", "े"), ("o", "ो"),
261
+ ]
262
+ _IAST_CONS = [
263
+ ("kṣ", "क्ष"), ("jñ", "ज्ञ"), ("tr", "त्र"),
264
+ ("kh", "ख"), ("gh", "घ"), ("ch", "छ"), ("jh", "झ"),
265
+ ("ṭh", "ठ"), ("ḍh", "ढ"), ("th", "थ"), ("dh", "ध"),
266
+ ("ph", "फ"), ("bh", "भ"),
267
+ ("ṅ", "ङ"), ("ñ", "ञ"), ("ṭ", "ट"), ("ḍ", "ड"),
268
+ ("ṇ", "ण"), ("ś", "श"), ("ṣ", "ष"), ("ḥ", "ः"),
269
+ ("ṃ", "ं"), ("ṁ", "ं"),
270
+ ("y", "���"), ("r", "र"), ("l", "ल"), ("v", "व"),
271
+ ("s", "स"), ("h", "ह"),
272
+ ("k", "क"), ("g", "ग"), ("c", "च"), ("j", "ज"),
273
+ ("t", "त"), ("d", "द"), ("n", "न"),
274
+ ("p", "प"), ("b", "ब"), ("m", "म"),
275
+ ]
276
+ _PUNCT = {".": "।", "|": "।", "||": "॥", ",": ",", "?": "?", "!": "!"}
277
+
278
+
279
+ def _iast_to_deva(text: str) -> str:
280
+ s = (text or "").lower()
281
+ out = []
282
+ i = 0
283
+ pending_consonant = False
284
+
285
+ def _match_any(pairs, pos):
286
+ for k, v in pairs:
287
+ if s.startswith(k, pos):
288
+ return k, v
289
+ return None, None
290
+
291
+ while i < len(s):
292
+ if s[i].isspace():
293
+ pending_consonant = False
294
+ out.append(s[i])
295
+ i += 1
296
+ continue
297
+ if s[i:i+2] == "||":
298
+ pending_consonant = False
299
+ out.append(_PUNCT["||"])
300
+ i += 2
301
+ continue
302
+ if s[i] in _PUNCT:
303
+ pending_consonant = False
304
+ out.append(_PUNCT[s[i]])
305
+ i += 1
306
+ continue
307
+
308
+ v_key, v_deva = _match_any(_IAST_VOWELS, i)
309
+ if v_key:
310
+ if pending_consonant:
311
+ _, v_matra = _match_any(_IAST_MATRAS, i)
312
+ out[-1] = out[-1] + (v_matra or "")
313
+ pending_consonant = False
314
+ else:
315
+ out.append(v_deva)
316
+ i += len(v_key)
317
+ continue
318
+
319
+ c_key, c_deva = _match_any(_IAST_CONS, i)
320
+ if c_key:
321
+ if pending_consonant:
322
+ out[-1] = out[-1] + "्"
323
+ out.append(c_deva)
324
+ pending_consonant = True
325
+ i += len(c_key)
326
+ continue
327
+
328
+ out.append(s[i])
329
+ pending_consonant = False
330
+ i += 1
331
+
332
+ return "".join(out).strip()
333
+
334
+
335
+ def _compute_cer(pred: str, ref: str) -> float:
336
+ if pred == ref:
337
+ return 0.0
338
+ if not pred or not ref:
339
+ return 1.0
340
+ m, n = len(pred), len(ref)
341
+ dp = list(range(n + 1))
342
+ for i in range(1, m + 1):
343
+ prev = dp[0]
344
+ dp[0] = i
345
+ for j in range(1, n + 1):
346
+ temp = dp[j]
347
+ cost = 0 if pred[i - 1] == ref[j - 1] else 1
348
+ dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
349
+ prev = temp
350
+ return dp[n] / max(m, n)
351
+
352
+
353
+ def _cleanup_thresholds(temperature: float, top_k: int):
354
+ temp = float(temperature)
355
+ k = max(1, int(top_k))
356
+ t_norm = max(0.0, min((temp - 0.4) / 0.6, 1.0))
357
+ k_norm = max(0.0, min((k - 20) / 80.0, 1.0))
358
+ diversity = 0.6 * t_norm + 0.4 * k_norm
359
+ cer_threshold = 0.10 + 0.18 * diversity
360
+ deva_ratio_threshold = 0.60 - 0.20 * diversity
361
+ return cer_threshold, deva_ratio_threshold
362
+
363
+
364
+ def _decode_with_cleanup(tgt_tok, ids, src_text: str, inf_cfg: dict):
365
+ model_out = _decode_clean(tgt_tok, ids)
366
+ rule_out = _iast_to_deva(src_text.strip())
367
+ deva_chars = sum(1 for ch in model_out if "\u0900" <= ch <= "\u097F")
368
+ deva_ratio = deva_chars / max(1, len(model_out))
369
+ cer = _compute_cer(model_out, rule_out)
370
+ cer_thr, ratio_thr = _cleanup_thresholds(
371
+ inf_cfg.get("temperature", 0.8),
372
+ inf_cfg.get("top_k", 40),
373
+ )
374
+ if deva_ratio < ratio_thr or len(model_out) > 2.0 * max(1, len(rule_out)) or cer > cer_thr:
375
+ return rule_out
376
+ return model_out
377
+
378
+
379
+ # ── Interactive demo ──────────────────────────────────────────────────
380
+
381
+ def interactive_demo(checkpoint=None, single_text=None):
382
+ from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
383
+
384
+ cfg = CONFIG
385
+ device = _resolve_device(cfg['training'].get('device', 'cpu'))
386
+
387
+ model_name = cfg['model_type']
388
+ has_neg = cfg['data']['include_negative_examples']
389
+ ckpt = checkpoint or f"results/{model_name}_neg_{has_neg}/best_model.pt"
390
+
391
+ if not os.path.exists(ckpt):
392
+ raise FileNotFoundError(f"No checkpoint at {ckpt} — train first.")
393
+
394
+ model, cfg = load_model(ckpt, cfg, device)
395
+ model.eval()
396
+
397
+ src_tok = SanskritSourceTokenizer(
398
+ vocab_size=cfg['model'].get('src_vocab_size', 16000),
399
+ max_len=cfg['model']['max_seq_len'],
400
+ )
401
+ tgt_tok = SanskritTargetTokenizer(
402
+ vocab_size=cfg['model'].get('tgt_vocab_size', 16000),
403
+ max_len=cfg['model']['max_seq_len'],
404
+ )
405
+
406
+ print("\n" + "="*55)
407
+ print("Sanskrit D3PM Paraphrase — type verse, get paraphrase")
408
+ print("="*55 + "\n")
409
+
410
+ while True:
411
+ try:
412
+ text = (single_text if single_text is not None else input("INPUT > ")).strip()
413
+ except (EOFError, KeyboardInterrupt):
414
+ break
415
+ if not text or text.lower() in ('quit', 'exit', 'q'):
416
+ break
417
+
418
+ ids = torch.tensor(
419
+ [src_tok.encode(text)[:cfg['model']['max_seq_len']]],
420
+ dtype=torch.long, device=device
421
+ )
422
+ out = run_inference(model, ids, cfg)
423
+ cleaned = _decode_with_cleanup(tgt_tok, out[0].tolist(), text, cfg["inference"])
424
+ print(f"PARAPHRASE → {cleaned}\n")
425
+ if single_text is not None:
426
+ break
427
+
428
+
429
+ # ── Batch evaluation ──────────────────────────────────────────────────
430
+
431
+ def batch_evaluate(sample_size=500, checkpoint=None):
432
+ from data.dataset import OptimizedSanskritDataset
433
+ from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
434
+
435
+ cfg = CONFIG
436
+ device = _resolve_device(cfg['training'].get('device', 'cpu'))
437
+
438
+ model_name = cfg['model_type']
439
+ has_neg = cfg['data']['include_negative_examples']
440
+ exp_dir = f"results/{model_name}_neg_{has_neg}"
441
+ ckpt = checkpoint or f"{exp_dir}/best_model.pt"
442
+
443
+ if not os.path.exists(ckpt):
444
+ raise FileNotFoundError(f"No checkpoint at {ckpt}")
445
+
446
+ model, cfg = load_model(ckpt, cfg, device)
447
+ model.eval()
448
+
449
+ src_tok = SanskritSourceTokenizer(
450
+ vocab_size=cfg['model'].get('src_vocab_size', 16000),
451
+ max_len=cfg['model']['max_seq_len'],
452
+ )
453
+ tgt_tok = SanskritTargetTokenizer(
454
+ vocab_size=cfg['model'].get('tgt_vocab_size', 16000),
455
+ max_len=cfg['model']['max_seq_len'],
456
+ )
457
+
458
+ def collate(batch):
459
+ return {
460
+ 'input_ids': torch.stack([b['input_ids'].long() for b in batch]),
461
+ 'target_text': [b['target_text'] for b in batch],
462
+ 'input_text': [b['input_text'] for b in batch],
463
+ }
464
+
465
+ dataset = OptimizedSanskritDataset(
466
+ split='test',
467
+ max_len=cfg['model']['max_seq_len'],
468
+ cfg=cfg,
469
+ src_tokenizer=src_tok,
470
+ tgt_tokenizer=tgt_tok,
471
+ )
472
+ indices = list(range(min(sample_size, len(dataset))))
473
+ loader = DataLoader(
474
+ Subset(dataset, indices),
475
+ batch_size=cfg['training']['batch_size'],
476
+ shuffle=False, collate_fn=collate
477
+ )
478
+
479
+ all_preds, all_refs, all_inputs = [], [], []
480
+ print(f"⏳ Generating {len(indices)} paraphrases …")
481
+
482
+ for batch in tqdm(loader):
483
+ ids = batch['input_ids'].to(device)
484
+ out = run_inference(model, ids, cfg)
485
+ for i in range(out.size(0)):
486
+ all_preds.append(_decode_with_cleanup(
487
+ tgt_tok, out[i].tolist(), batch['input_text'][i], cfg["inference"]
488
+ ))
489
+ all_refs.append(batch['target_text'][i].strip())
490
+ all_inputs.append(batch['input_text'][i].strip())
491
+
492
+ # Metrics
493
+ bleu_score, bert_f1 = 0.0, 0.0
494
+ try:
495
+ from nltk.translate.bleu_score import corpus_bleu
496
+ bleu_score = corpus_bleu(
497
+ [[r.split()] for r in all_refs],
498
+ [p.split() for p in all_preds]
499
+ )
500
+ except Exception:
501
+ pass
502
+
503
+ try:
504
+ import evaluate as hf_eval
505
+ res = hf_eval.load('bertscore').compute(
506
+ predictions=all_preds, references=all_refs, lang='hi'
507
+ )
508
+ bert_f1 = sum(res['f1']) / len(res['f1'])
509
+ except Exception:
510
+ pass
511
+
512
+ # Save
513
+ out_path = f"{exp_dir}/evaluation_results.txt"
514
+ pred_path = f"{exp_dir}/evaluation_predictions.jsonl"
515
+ with open(out_path, 'w', encoding='utf-8') as f:
516
+ f.write(f"Model : {model_name}\n")
517
+ f.write(f"Negatives: {has_neg}\n")
518
+ f.write(f"Steps : {cfg['inference']['num_steps']}\n")
519
+ f.write(f"Temp : {cfg['inference']['temperature']}\n")
520
+ f.write(f"RepPen : {cfg['inference']['repetition_penalty']}\n")
521
+ f.write(f"DivPen : {cfg['inference']['diversity_penalty']}\n")
522
+ f.write(f"BLEU : {bleu_score:.4f}\n")
523
+ f.write(f"BERTScore: {bert_f1:.4f}\n\n")
524
+ f.write("=== SAMPLES ===\n")
525
+ for i in range(min(20, len(all_preds))):
526
+ f.write(f"IN : {all_inputs[i]}\n")
527
+ f.write(f"REF : {all_refs[i]}\n")
528
+ f.write(f"PRED: {all_preds[i]}\n")
529
+ f.write("-" * 60 + "\n")
530
+
531
+ with open(pred_path, 'w', encoding='utf-8') as f:
532
+ for src, ref, pred in zip(all_inputs, all_refs, all_preds):
533
+ row = {"input": src, "reference": ref, "prediction": pred}
534
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
535
+
536
+ print(f"\n✅ Results → {out_path}")
537
+ print(f"🗂️ Saved predictions → {pred_path}")
538
+ print(f"📊 BLEU: {bleu_score:.4f} | BERTScore: {bert_f1:.4f}")
539
+ return all_preds, all_refs
540
+
541
+
542
+ if __name__ == '__main__':
543
+ import argparse
544
+ p = argparse.ArgumentParser()
545
+ p.add_argument('--mode', choices=['demo', 'eval'], default='demo')
546
+ p.add_argument('--samples', type=int, default=500)
547
+ p.add_argument('--checkpoint', type=str, default=None)
548
+ p.add_argument('--text', type=str, default=None, help='Run one-shot demo input and exit')
549
+ args = p.parse_args()
550
+
551
+ if args.mode == 'demo':
552
+ interactive_demo(checkpoint=args.checkpoint, single_text=args.text)
553
+ else:
554
+ batch_evaluate(args.samples, checkpoint=args.checkpoint)
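
Note that inference_api.py (the next file) imports _build_tokenizers from this module, while this version of inference.py only builds the tokenizers inline in interactive_demo() and batch_evaluate(). A minimal sketch that factors out that inline construction; the helper body is inferred from those call sites, not confirmed by the commit:

    from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer

    def _build_tokenizers(cfg):
        # Mirrors the inline construction in interactive_demo()/batch_evaluate().
        src_tok = SanskritSourceTokenizer(
            vocab_size=cfg['model'].get('src_vocab_size', 16000),
            max_len=cfg['model']['max_seq_len'],
        )
        tgt_tok = SanskritTargetTokenizer(
            vocab_size=cfg['model'].get('tgt_vocab_size', 16000),
            max_len=cfg['model']['max_seq_len'],
        )
        return src_tok, tgt_tok

From the command line, python inference.py --mode demo --text "..." runs a one-shot paraphrase and exits, while python inference.py --mode eval --samples 500 writes metrics and predictions under results/.
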
inference_api.py ADDED
@@ -0,0 +1,131 @@
1
+ import copy
2
+ import json
3
+ import os
4
+ from typing import Dict, Any
5
+
6
+ import torch
7
+
8
+ from config import CONFIG
9
+ from inference import _build_tokenizers, _resolve_device, load_model, run_inference
10
+
11
+
12
+ _STATE = {
13
+ "loaded": False,
14
+ "model": None,
15
+ "cfg": None,
16
+ "device": None,
17
+ "src_tok": None,
18
+ "tgt_tok": None,
19
+ }
20
+
21
+
22
+ def _read_model_settings() -> Dict[str, Any]:
23
+ if not os.path.exists("model_settings.json"):
24
+ return {}
25
+ try:
26
+ with open("model_settings.json", "r", encoding="utf-8") as f:
27
+ data = json.load(f)
28
+ return data if isinstance(data, dict) else {}
29
+ except Exception:
30
+ return {}
31
+
32
+
33
+ def _load_once() -> None:
34
+ if _STATE["loaded"]:
35
+ return
36
+
37
+ settings = _read_model_settings()
38
+ cfg = copy.deepcopy(CONFIG)
39
+ cfg["model_type"] = os.environ.get(
40
+ "HF_MODEL_TYPE",
41
+ settings.get("model_type", "d3pm_cross_attention"),
42
+ )
43
+ cfg["data"]["include_negative_examples"] = (
44
+ os.environ.get(
45
+ "HF_INCLUDE_NEG",
46
+ str(settings.get("include_negative_examples", True)).lower(),
47
+ ).lower()
48
+ == "true"
49
+ )
50
+ num_steps_raw = os.environ.get("HF_NUM_STEPS", settings.get("num_steps"))
51
+ if num_steps_raw is not None:
52
+ num_steps = int(num_steps_raw)
53
+ cfg["model"]["diffusion_steps"] = num_steps
54
+ cfg["inference"]["num_steps"] = num_steps
55
+ device = _resolve_device(cfg["training"].get("device", "cpu"))
56
+
57
+ model, cfg = load_model("best_model.pt", cfg, device)
58
+ src_tok, tgt_tok = _build_tokenizers(cfg)
59
+
60
+ _STATE["model"] = model
61
+ _STATE["cfg"] = cfg
62
+ _STATE["device"] = device
63
+ _STATE["src_tok"] = src_tok
64
+ _STATE["tgt_tok"] = tgt_tok
65
+ _STATE["loaded"] = True
66
+
67
+
68
+ def _clean_text(text: str) -> str:
69
+ text = " ".join(text.split())
70
+ if not text:
71
+ return text
72
+ toks = text.split()
73
+ out = []
74
+ prev = None
75
+ run = 0
76
+ for tok in toks:
77
+ if tok == prev:
78
+ run += 1
79
+ else:
80
+ prev = tok
81
+ run = 1
82
+ if run <= 2:
83
+ out.append(tok)
84
+ s = " ".join(out)
85
+ s = s.replace(" ।", "।").replace(" ॥", "॥")
86
+ return " ".join(s.split())
87
+
88
+
89
+ def predict(
90
+ text: str,
91
+ temperature: float = 0.7,
92
+ top_k: int = 40,
93
+ repetition_penalty: float = 1.2,
94
+ diversity_penalty: float = 0.0,
95
+ num_steps: int = 64,
96
+ clean_output: bool = True,
97
+ ) -> Dict[str, Any]:
98
+ _load_once()
99
+ if not text or not text.strip():
100
+ return {"error": "empty input", "output": ""}
101
+
102
+ cfg = copy.deepcopy(_STATE["cfg"])
103
+ cfg["inference"]["temperature"] = float(temperature)
104
+ cfg["inference"]["top_k"] = int(top_k)
105
+ cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
106
+ cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
107
+ cfg["inference"]["num_steps"] = int(num_steps)
108
+
109
+ src_tok = _STATE["src_tok"]
110
+ tgt_tok = _STATE["tgt_tok"]
111
+ device = _STATE["device"]
112
+
113
+ input_ids = torch.tensor([src_tok.encode(text.strip())], dtype=torch.long, device=device)
114
+ out = run_inference(_STATE["model"], input_ids, cfg)
115
+ decoded_ids = [x for x in out[0].tolist() if x > 4]
116
+ raw = tgt_tok.decode(decoded_ids).strip()
117
+ output = _clean_text(raw) if clean_output else raw
118
+
119
+ return {
120
+ "input": text,
121
+ "output": output,
122
+ "raw_output": raw,
123
+ "config": {
124
+ "temperature": float(temperature),
125
+ "top_k": int(top_k),
126
+ "repetition_penalty": float(repetition_penalty),
127
+ "diversity_penalty": float(diversity_penalty),
128
+ "num_steps": int(num_steps),
129
+ "clean_output": bool(clean_output),
130
+ },
131
+ }
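
The module-level _STATE cache means the checkpoint is loaded once per process and reused by every predict() call. A usage sketch, under the same file-layout assumptions as above:

    from inference_api import predict

    res = predict("dharmo rakṣati rakṣitaḥ", temperature=0.7, top_k=40)
    print(res["output"])   # cleaned text
    print(res["config"])   # echo of the sampling parameters actually used
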
model/__init__.py ADDED
File without changes
model/d3pm_model_cross_attention.py ADDED
@@ -0,0 +1,271 @@
1
+ """
2
+ d3pm_model_cross_attention.py — Cross-Script + Generation-Fixed
3
+ =================================================================
4
+ INPUT : quote_text tokens (Roman script, src_vocab_size)
5
+ OUTPUT : quote_devanagari tokens (Devanagari script, tgt_vocab_size)
6
+
7
+ src_embed uses src_vocab_size (Roman BPE)
8
+ tgt_embed uses tgt_vocab_size (Devanagari BPE)
9
+ head outputs tgt_vocab_size (predict Devanagari tokens)
10
+ Weight tying: head <-> tgt_embed only (NOT src_embed)
11
+
12
+ Generation bugs fixed:
13
+ BUG 1 - tgt_pad_mask suppressed during inference
14
+ BUG 2 - q_sample skipped at t=0
15
+ BUG 3 - time embedding before hint_gate
16
+ BUG 4 - diversity penalty uses global mean not var
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusion.scheduler import OptimizedCosineScheduler
24
+ from diffusion.forward_process import AbsorbingForwardProcess
25
+
26
+
27
+ class SinusoidalPositionalEncoding(nn.Module):
28
+ def __init__(self, d_model, max_len=5000):
29
+ super().__init__()
30
+ pe = torch.zeros(max_len, d_model)
31
+ position = torch.arange(0, max_len).unsqueeze(1).float()
32
+ div_term = torch.exp(
33
+ torch.arange(0, d_model, 2).float() *
34
+ (-torch.log(torch.tensor(10000.0)) / d_model)
35
+ )
36
+ pe[:, 0::2] = torch.sin(position * div_term)
37
+ pe[:, 1::2] = torch.cos(position * div_term)
38
+ self.register_buffer("pe", pe.unsqueeze(0))
39
+
40
+ def forward(self, x):
41
+ return x + self.pe[:, :x.size(1), :]
42
+
43
+
44
+ class SanskritEmbeddings(nn.Module):
45
+ def __init__(self, vocab_size, d_model, max_seq_len):
46
+ super().__init__()
47
+ self.token_emb = nn.Embedding(vocab_size, d_model)
48
+ self.pos_enc = SinusoidalPositionalEncoding(d_model, max_seq_len)
49
+ self.token_embedding = self.token_emb
50
+ def forward(self, tokens):
51
+ return self.pos_enc(self.token_emb(tokens))
52
+
53
+
54
+ class MultiHeadAttention(nn.Module):
55
+ def __init__(self, d_model, n_heads, dropout=0.1):
56
+ super().__init__()
57
+ assert d_model % n_heads == 0
58
+ self.d_model = d_model
59
+ self.n_heads = n_heads
60
+ self.head_dim = d_model // n_heads
61
+ self.q_proj = nn.Linear(d_model, d_model)
62
+ self.k_proj = nn.Linear(d_model, d_model)
63
+ self.v_proj = nn.Linear(d_model, d_model)
64
+ self.out_proj = nn.Linear(d_model, d_model)
65
+ self.dropout = nn.Dropout(dropout)
66
+
67
+ def forward(self, q, k, v, mask=None):
68
+ B, Lq, _ = q.size()
69
+ Lk = k.size(1)
70
+ Q = self.q_proj(q).view(B, Lq, self.n_heads, self.head_dim).transpose(1, 2)
71
+ K = self.k_proj(k).view(B, Lk, self.n_heads, self.head_dim).transpose(1, 2)
72
+ V = self.v_proj(v).view(B, Lk, self.n_heads, self.head_dim).transpose(1, 2)
73
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
74
+ if mask is not None:
75
+ scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
76
+ attn = self.dropout(torch.softmax(scores, dim=-1))
77
+ out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, Lq, self.d_model)
78
+ return self.out_proj(out)
79
+
80
+
81
+ class EncoderBlock(nn.Module):
82
+ def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
83
+ super().__init__()
84
+ self.mha = MultiHeadAttention(d_model, n_heads, dropout)
85
+ self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
86
+ nn.Linear(d_ff, d_model), nn.Dropout(dropout))
87
+ self.norm1 = nn.LayerNorm(d_model)
88
+ self.norm2 = nn.LayerNorm(d_model)
89
+ def forward(self, x, pad_mask=None):
90
+ x = self.norm1(x + self.mha(x, x, x, mask=pad_mask))
91
+ return self.norm2(x + self.ff(x))
92
+
93
+
94
+ class DecoderBlock(nn.Module):
95
+ def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
96
+ super().__init__()
97
+ self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
98
+ self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
99
+ self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
100
+ nn.Linear(d_ff, d_model), nn.Dropout(dropout))
101
+ self.norm1 = nn.LayerNorm(d_model)
102
+ self.norm2 = nn.LayerNorm(d_model)
103
+ self.norm3 = nn.LayerNorm(d_model)
104
+ def forward(self, x, memory, tgt_pad_mask=None, src_pad_mask=None):
105
+ x = self.norm1(x + self.self_attn(x, x, x, mask=tgt_pad_mask))
106
+ x = self.norm2(x + self.cross_attn(x, memory, memory, mask=src_pad_mask))
107
+ return self.norm3(x + self.ff(x))
108
+
109
+
110
+ class D3PMCrossAttention(nn.Module):
111
+ def __init__(self, cfg):
112
+ super().__init__()
113
+ self.cfg = cfg
114
+ self.mask_token_id = cfg['diffusion']['mask_token_id']
115
+ d = cfg['model']['d_model']
116
+ nhead = cfg['model']['n_heads']
117
+ d_ff = cfg['model']['d_ff']
118
+ drop = cfg['model']['dropout']
119
+ seqlen = cfg['model']['max_seq_len']
120
+ nlayer = cfg['model']['n_layers']
121
+ src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
122
+ tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
123
+
124
+ # Separate embeddings: Roman src, Devanagari tgt
125
+ self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
126
+ self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)
127
+
128
+ self.scheduler = OptimizedCosineScheduler(cfg)
129
+ self.forward_process = AbsorbingForwardProcess(self.scheduler)
130
+
131
+ self.encoder_blocks = nn.ModuleList([EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
132
+ self.decoder_blocks = nn.ModuleList([DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
133
+
134
+ self.time_mlp = nn.Sequential(nn.Linear(1, d//4), nn.SiLU(), nn.Linear(d//4, d))
135
+ self.hint_gate = nn.Sequential(nn.Linear(d, d), nn.Sigmoid())
136
+
137
+ # Output head: predict Devanagari tokens, tied to tgt_embed
138
+ self.head = nn.Linear(d, tgt_vocab, bias=False)
139
+ self.head.weight = self.tgt_embed.token_embedding.weight
140
+
141
+ def forward(self, src, tgt, t, x0_hint=None, inference_mode=False):
142
+ PAD = 1
143
+ src_pad_mask = (src == PAD)
144
+ # BUG 1 FIX: no tgt mask during inference
145
+ tgt_pad_mask = None if inference_mode else (tgt == PAD)
146
+
147
+ # Encode Roman source
148
+ memory = self.src_embed(src)
149
+ for block in self.encoder_blocks:
150
+ memory = block(memory, pad_mask=src_pad_mask)
151
+
152
+ # BUG 2 FIX: skip q_sample at final step t=0
153
+ if inference_mode and (t == 0).all():
154
+ x_t_ids = tgt
155
+ else:
156
+ _, x_t_ids = self.forward_process.q_sample(tgt, t)
157
+
158
+ x = self.tgt_embed(x_t_ids)
159
+
160
+ # BUG 3 FIX: time embedding BEFORE hint gate
161
+ t_norm = t.float() / self.scheduler.num_timesteps
162
+ t_emb = self.time_mlp(t_norm.unsqueeze(-1))
163
+ x = x + t_emb.unsqueeze(1)
164
+
165
+ if x0_hint is not None:
166
+ hint_emb = self.tgt_embed(x0_hint)
167
+ gate = self.hint_gate(x) # time-aware gate
168
+ x = x + gate * hint_emb
169
+
170
+ for block in self.decoder_blocks:
171
+ x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)
172
+
173
+ return self.head(x), None
174
+
175
+ @torch.no_grad()
176
+ def generate(self, src, num_steps=None, temperature=0.8, top_k=50,
177
+ repetition_penalty=1.2, diversity_penalty=0.0):
178
+ if src.dim() == 1:
179
+ src = src.unsqueeze(0)
180
+ device = src.device
181
+ B, L = src.shape
182
+ T = self.scheduler.num_timesteps
183
+ steps = num_steps or T
184
+ step_size = max(1, T // steps)
185
+ timesteps = list(range(T - 1, -1, -step_size))
186
+ if timesteps[-1] != 0:
187
+ timesteps.append(0)
188
+
189
+ mask_id = self.mask_token_id
190
+ x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
191
+ hint = None
192
+
193
+ self.eval()
194
+ with torch.no_grad():
195
+ for step_idx, t_val in enumerate(timesteps):
196
+ t = torch.full((B,), t_val, dtype=torch.long, device=device)
197
+ is_last = (step_idx == len(timesteps) - 1)
198
+ logits, _ = self.forward(src, x0_est, t, x0_hint=hint, inference_mode=True)
199
+ if repetition_penalty != 1.0:
200
+ logits = _apply_repetition_penalty(logits, x0_est, repetition_penalty)
201
+ if diversity_penalty > 0.0:
202
+ logits = _apply_diversity_penalty_fixed(logits, diversity_penalty) # BUG 4 FIX
203
+ logits = logits / max(temperature, 1e-5)
204
+ if top_k > 0:
205
+ logits = _top_k_filter(logits, top_k)
206
+ probs = F.softmax(logits, dim=-1)
207
+ x0_est = torch.argmax(probs, dim=-1) if is_last else _batch_multinomial(probs)
208
+ hint = x0_est
209
+ return x0_est
210
+
211
+
212
+ class BaselineCrossAttention(nn.Module):
213
+ def __init__(self, cfg):
214
+ super().__init__()
215
+ d = cfg['model']['d_model']; nhead = cfg['model']['n_heads']
216
+ d_ff = cfg['model']['d_ff']; drop = cfg['model']['dropout']
217
+ seqlen = cfg['model']['max_seq_len']; nlayer = cfg['model']['n_layers']
218
+ src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
219
+ tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
220
+ self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
221
+ self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)
222
+ self.encoder_blocks = nn.ModuleList([EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
223
+ self.decoder_blocks = nn.ModuleList([DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
224
+ self.head = nn.Linear(d, tgt_vocab, bias=False)
225
+ self.head.weight = self.tgt_embed.token_embedding.weight
226
+
227
+ def forward(self, src, tgt, t=None, x0_hint=None):
228
+ PAD = 1
229
+ memory = self.src_embed(src)
230
+ for b in self.encoder_blocks: memory = b(memory, pad_mask=(src==PAD))
231
+ x = self.tgt_embed(tgt)
232
+ for b in self.decoder_blocks: x = b(x, memory, tgt_pad_mask=(tgt==PAD), src_pad_mask=(src==PAD))
233
+ return (self.head(x),)
234
+
235
+ @torch.no_grad()
236
+ def generate(self, src, max_len=None, start_token_id=2, **kwargs):
237
+ if max_len is None: max_len = src.size(1)
238
+ B, device = src.size(0), src.device
239
+ memory = self.src_embed(src)
240
+ for b in self.encoder_blocks: memory = b(memory, pad_mask=(src==1))
241
+ ys = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
242
+ for _ in range(max_len):
243
+ x = self.tgt_embed(ys)
244
+ for b in self.decoder_blocks: x = b(x, memory, tgt_pad_mask=None, src_pad_mask=(src==1))
245
+ ys = torch.cat([ys, torch.argmax(self.head(x)[:,-1,:], dim=-1, keepdim=True)], dim=1)
246
+ return ys[:, 1:max_len+1]
247
+
248
+
249
+ # helpers
250
+ def _top_k_filter(logits, k):
251
+ B, L, V = logits.shape
252
+ if k >= V: return logits
253
+ topk_vals, _ = torch.topk(logits, k, dim=-1)
254
+ return logits.masked_fill(logits < topk_vals[..., -1].unsqueeze(-1), float('-inf'))
255
+
256
+ def _batch_multinomial(probs):
257
+ B, L, V = probs.shape
258
+ flat = probs.view(B*L, V) + 1e-9
259
+ return torch.multinomial(flat/flat.sum(-1,keepdim=True), 1).squeeze(-1).view(B, L)
260
+
261
+ def _apply_repetition_penalty(logits, prev_tokens, penalty):
262
+ for b in range(logits.shape[0]):
263
+ for tid in set(prev_tokens[b].tolist()):
264
+ if tid > 4: logits[b, :, tid] = logits[b, :, tid] / penalty
265
+ return logits
266
+
267
+ def _apply_diversity_penalty(logits, penalty): # legacy wrong version
268
+ return logits + penalty * logits.var(dim=-1, keepdim=True)
269
+
270
+ def _apply_diversity_penalty_fixed(logits, penalty): # correct version
271
+ return logits - penalty * logits.mean(dim=1, keepdim=True)
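
The difference between the two diversity penalties shows up on a toy tensor: the legacy version adds one per-position scalar to every vocabulary entry, so relative gaps are untouched, while the fixed version subtracts the sequence-wide mean logit per vocabulary entry and shrinks the advantage of a token that dominates the whole sequence:

    import torch
    from model.d3pm_model_cross_attention import (
        _apply_diversity_penalty, _apply_diversity_penalty_fixed,
    )

    logits = torch.zeros(1, 4, 8)
    logits[..., 3] = 5.0   # token 3 dominates every position

    legacy = _apply_diversity_penalty(logits.clone(), 0.3)
    fixed = _apply_diversity_penalty_fixed(logits.clone(), 0.3)

    print((legacy[..., 3] - legacy[..., 0]).unique())   # tensor([5.]): gap unchanged
    print((fixed[..., 3] - fixed[..., 0]).unique())     # tensor([3.5]): gap suppressed
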
model/d3pm_model_encoder_decoder.py ADDED
@@ -0,0 +1,227 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from diffusion.scheduler import OptimizedCosineScheduler
4
+ from diffusion.forward_process import AbsorbingForwardProcess
5
+ # Import shared classes to guarantee identical architectures
6
+ from model.d3pm_model_cross_attention import SanskritEmbeddings, EncoderBlock, MultiHeadAttention
7
+ class DecoderBlock(nn.Module):
8
+ def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
9
+ super().__init__()
10
+ self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
11
+ self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout) # ← restored
12
+ self.ff = nn.Sequential(
13
+ nn.Linear(d_model, d_ff),
14
+ nn.ReLU(),
15
+ nn.Dropout(dropout),
16
+ nn.Linear(d_ff, d_model),
17
+ nn.Dropout(dropout),
18
+ )
19
+ self.norm1 = nn.LayerNorm(d_model)
20
+ self.norm2 = nn.LayerNorm(d_model)
21
+ self.norm3 = nn.LayerNorm(d_model) # ← restored (for cross-attn residual)
22
+
23
+ def forward(self, x, memory, tgt_pad_mask=None):
24
+ # 1. Masked self-attention on target
25
+ x = self.norm1(x + self.self_attn(x, x, x, mask=tgt_pad_mask))
26
+ # 2. Cross-attention: queries from decoder, keys/values from encoder memory
27
+ x = self.norm2(x + self.cross_attn(x, memory, memory))
28
+ # 3. Feed-forward
29
+ return self.norm3(x + self.ff(x))
30
+
31
+
32
+ class DecoderBlockNoCrossAttn(nn.Module):
33
+ """Kept for reference — NOT used by D3PMEncoderDecoder."""
34
+ def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
35
+ super().__init__()
36
+ self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
37
+ self.ff = nn.Sequential(
38
+ nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
39
+ nn.Linear(d_ff, d_model), nn.Dropout(dropout),
40
+ )
41
+ self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
42
+
43
+ def forward(self, x, tgt_pad_mask=None, causal_mask=None):
44
+ combined_mask = None
45
+ if tgt_pad_mask is not None and causal_mask is not None:
46
+ combined_mask = tgt_pad_mask | causal_mask
47
+ elif causal_mask is not None:
48
+ combined_mask = causal_mask
49
+ elif tgt_pad_mask is not None:
50
+ combined_mask = tgt_pad_mask
51
+ x = self.norm1(x + self.self_attn(x, x, x, mask=combined_mask))
52
+ return self.norm2(x + self.ff(x))
53
+
54
+
55
+ # ============================================================
56
+ # 1. D3PM Encoder-Decoder Model
57
+ # ============================================================
58
+ class D3PMEncoderDecoder(nn.Module):
59
+ def __init__(self, cfg):
60
+ super().__init__()
61
+ self.cfg = cfg
62
+ self.mask_token_id = cfg['diffusion']['mask_token_id']
63
+
64
+ src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
65
+ tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
66
+ d_model = cfg['model']['d_model']
67
+ n_heads = cfg['model']['n_heads']
68
+ d_ff = cfg['model']['d_ff']
69
+ dropout = cfg['model']['dropout']
70
+ n_layers = cfg['model']['n_layers']
71
+ max_len = cfg['model']['max_seq_len']
72
+
73
+ self.src_embed = SanskritEmbeddings(src_vocab, d_model, max_len)
74
+ self.tgt_embed = SanskritEmbeddings(tgt_vocab, d_model, max_len)
75
+
76
+ self.scheduler = OptimizedCosineScheduler(cfg)
77
+ self.forward_process = AbsorbingForwardProcess(self.scheduler)
78
+
79
+ self.encoder_blocks = nn.ModuleList([
80
+ EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
81
+ ])
82
+ # DecoderBlock now has cross-attention — matches saved checkpoint
83
+ self.decoder_blocks = nn.ModuleList([
84
+ DecoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
85
+ ])
86
+
87
+ self.time_mlp = nn.Sequential(
88
+ nn.Linear(1, d_model // 4), nn.SiLU(),
89
+ nn.Linear(d_model // 4, d_model),
90
+ )
91
+ self.head = nn.Linear(d_model, tgt_vocab)
92
+ self.head.weight = self.tgt_embed.token_embedding.weight
93
+
94
+ def forward(self, src, tgt, t, x0_hint=None):
95
+ src_pad_mask = (src == 1)
96
+ tgt_pad_mask = (tgt == 1)
97
+
98
+ # Encode source (Roman IAST)
99
+ memory = self.src_embed(src)
100
+ for block in self.encoder_blocks:
101
+ memory = block(memory, pad_mask=src_pad_mask)
102
+
103
+ # Corrupt target with forward diffusion
104
+ _, x_t_ids = self.forward_process.q_sample(tgt, t)
105
+
106
+ # Optionally blend in x0_hint (self-conditioning)
107
+ if x0_hint is not None:
108
+ hint_prob = 0.5
109
+ blend_mask = (torch.rand(x_t_ids.shape, device=x_t_ids.device) < hint_prob)
110
+ still_mask = (x_t_ids == self.mask_token_id)
111
+ x_t_ids = torch.where(blend_mask & still_mask, x0_hint, x_t_ids)
112
+
113
+ x = self.tgt_embed(x_t_ids)
114
+ t_emb = self.time_mlp(t.float().unsqueeze(-1)).unsqueeze(1)
115
+ x = x + t_emb.expand(-1, tgt.shape[1], -1)
116
+
117
+ # Decode with cross-attention over encoder memory
118
+ for block in self.decoder_blocks:
119
+ x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
120
+
121
+ return self.head(x), None
122
+
123
+ @torch.no_grad()
124
+ def generate(
125
+ self,
126
+ src,
127
+ num_steps = None,
128
+ temperature = 0.75,
129
+ top_k = 50,
130
+ repetition_penalty = 1.15,
131
+ diversity_penalty = 0.0,
132
+ ):
133
+ """
134
+ Iterative D3PM reverse diffusion — same signature as
135
+ D3PMCrossAttention.generate() so SanskritModel.generate() works
136
+ identically for both model types.
137
+ """
138
+ device = src.device
139
+ B, L = src.shape[0], self.cfg['model']['max_seq_len']
140
+ T = num_steps or self.scheduler.num_timesteps
141
+ mask_id = self.mask_token_id
142
+ pad_id = 1
143
+
144
+ x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
145
+
146
+ for step in range(T - 1, -1, -1):
147
+ t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
148
+ hint = x0_est.clone()
149
+
150
+ logits, _ = self.forward(src, x0_est, t_tensor, x0_hint=hint)
151
+
152
+ # Repetition penalty
153
+ if repetition_penalty != 1.0:
154
+ for b in range(B):
155
+ for tok in set(x0_est[b].tolist()):
156
+ if tok > pad_id:
157
+ logits[b, :, tok] /= repetition_penalty
158
+
159
+ # Diversity penalty (suppress common tokens)
160
+ if diversity_penalty > 0.0:
161
+ logits = logits - diversity_penalty * logits.mean(dim=1, keepdim=True)
162
+
163
+ # Temperature + top-k sampling
164
+ logits = logits / max(temperature, 1e-8)
165
+ if top_k > 0:
166
+ vals, _ = torch.topk(logits, top_k, dim=-1)
167
+ logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
168
+
169
+ probs = torch.softmax(logits, dim=-1)
170
+ # Only update positions that are still masked
171
+ still = (x0_est == mask_id)
172
+ sample = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(B, L)
173
+ x0_est = torch.where(still, sample, x0_est)
174
+
175
+ return x0_est
176
+
177
+
178
+ # ============================================================
179
+ # 2. Baseline Encoder-Decoder Model (unchanged)
180
+ # ============================================================
181
+ class BaselineEncoderDecoder(nn.Module):
182
+ def __init__(self, cfg):
183
+ super().__init__()
184
+ self.cfg = cfg
185
+ self.src_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
186
+ cfg['model']['max_seq_len'])
187
+ self.tgt_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
188
+ cfg['model']['max_seq_len'])
189
+ self.encoder_blocks = nn.ModuleList([
190
+ EncoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
191
+ cfg['model']['d_ff'], cfg['model']['dropout'])
192
+ for _ in range(cfg['model']['n_layers'])
193
+ ])
194
+ self.decoder_blocks = nn.ModuleList([
195
+ DecoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
196
+ cfg['model']['d_ff'], cfg['model']['dropout'])
197
+ for _ in range(cfg['model']['n_layers'])
198
+ ])
199
+ self.head = nn.Linear(cfg['model']['d_model'], cfg['model']['vocab_size'])
200
+ self.head.weight = self.tgt_embed.token_embedding.weight
201
+
202
+ def forward(self, src, tgt):
203
+ src_pad_mask, tgt_pad_mask = (src == 1), (tgt == 1)
204
+ memory = self.src_embed(src)
205
+ for block in self.encoder_blocks:
206
+ memory = block(memory, pad_mask=src_pad_mask)
207
+ x = self.tgt_embed(tgt)
208
+ for block in self.decoder_blocks:
209
+ x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
210
+ return self.head(x)
211
+
212
+ @torch.no_grad()
213
+ def generate(self, src, max_len=80, start_token_id=2):
214
+ batch_size, device = src.size(0), src.device
215
+ src_pad_mask = (src == 1)
216
+ memory = self.src_embed(src)
217
+ for block in self.encoder_blocks:
218
+ memory = block(memory, pad_mask=src_pad_mask)
219
+ ys = torch.ones(batch_size, 1, dtype=torch.long, device=device) * start_token_id
220
+ for _ in range(max_len):
221
+ x = self.tgt_embed(ys)
222
+ for block in self.decoder_blocks:
223
+ x = block(x, memory, tgt_pad_mask=None)
224
+ logits = self.head(x)
225
+ next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
226
+ ys = torch.cat([ys, next_token], dim=1)
227
+ return ys[:, 1:]
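
One caveat in D3PMEncoderDecoder.generate(): once a position leaves the [MASK] state it is frozen by the torch.where update, so with an all-[MASK] start every position is committed on the first step and later iterations only touch positions that happen to resample [MASK]. The update rule itself, in isolation:

    import torch

    mask_id = 0
    x0_est = torch.tensor([[0, 7, 0, 9]])   # positions 0 and 2 still masked
    sample = torch.tensor([[5, 6, 8, 2]])   # fresh samples for this step

    still = (x0_est == mask_id)
    x0_est = torch.where(still, sample, x0_est)
    print(x0_est)                            # tensor([[5, 7, 8, 9]])
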
model/sanskrit_model.py ADDED
@@ -0,0 +1,61 @@
1
+ """
2
+ sanskrit_model.py — Fixed
3
+ ===========================
4
+ Added inference_mode parameter to forward() so reverse_process.py can
5
+ pass inference_mode=True without a TypeError.
6
+
7
+ The wrapper introspects each inner model's signature and only passes
8
+ kwargs that model actually accepts — safe across all four architectures.
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import inspect
14
+
15
+
16
+ class SanskritModel(nn.Module):
17
+ def __init__(self, cfg):
18
+ super().__init__()
19
+ model_type = cfg['model_type']
20
+
21
+ if model_type == 'd3pm_cross_attention':
22
+ from model.d3pm_model_cross_attention import D3PMCrossAttention
23
+ self.model = D3PMCrossAttention(cfg)
24
+
25
+ elif model_type == 'd3pm_encoder_decoder':
26
+ from model.d3pm_model_encoder_decoder import D3PMEncoderDecoder
27
+ self.model = D3PMEncoderDecoder(cfg)
28
+
29
+ elif model_type == 'baseline_cross_attention':
30
+ from model.d3pm_model_cross_attention import BaselineCrossAttention
31
+ self.model = BaselineCrossAttention(cfg)
32
+
33
+ elif model_type == 'baseline_encoder_decoder':
34
+ from model.d3pm_model_encoder_decoder import BaselineEncoderDecoder
35
+ self.model = BaselineEncoderDecoder(cfg)
36
+
37
+ else:
38
+ raise ValueError(f"Unknown model_type: {model_type}")
39
+
40
+ def forward(self, input_ids, target_ids, t, x0_hint=None, inference_mode=False):
41
+ """
42
+ Forward pass. Introspects the inner model's signature so only
43
+ supported kwargs are passed — works with all four architectures.
44
+ """
45
+ sig = inspect.signature(self.model.forward).parameters
46
+ kwargs = {}
47
+ if 'x0_hint' in sig:
48
+ kwargs['x0_hint'] = x0_hint
49
+ if 'inference_mode' in sig:
50
+ kwargs['inference_mode'] = inference_mode
51
+
52
+ if 't' in sig:
53
+ return self.model(input_ids, target_ids, t, **kwargs)
54
+ else:
55
+ return self.model(input_ids, target_ids, **kwargs)
56
+
57
+ @torch.no_grad()
58
+ def generate(self, src, **kwargs):
59
+ sig = inspect.signature(self.model.generate).parameters
60
+ filtered = {k: v for k, v in kwargs.items() if k in sig}
61
+ return self.model.generate(src, **filtered)
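The kwarg filtering in this wrapper is plain inspect.signature; here is a self-contained sketch of the same dispatch pattern, with toy functions standing in for the real forward methods:

    import inspect

    def diffusion_forward(x, t, inference_mode=False):   # accepts t and inference_mode
        return ("d3pm", x, t, inference_mode)

    def baseline_forward(x):                             # accepts neither
        return ("baseline", x)

    def dispatch(fn, x, t=None, **maybe):
        sig = inspect.signature(fn).parameters
        kwargs = {k: v for k, v in maybe.items() if k in sig}  # drop unsupported kwargs
        return fn(x, t, **kwargs) if 't' in sig else fn(x, **kwargs)

    print(dispatch(diffusion_forward, "src", t=3, inference_mode=True))  # ('d3pm', 'src', 3, True)
    print(dispatch(baseline_forward, "src", t=3, inference_mode=True))   # ('baseline', 'src')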
model/tokenizer.py ADDED
@@ -0,0 +1,222 @@
+ """
+ tokenizer.py — Dual Tokenizer Fix
+ ====================================
+ Two separate BPE tokenizers:
+
+ SanskritSourceTokenizer — trained on quote_text (Roman/IAST script)
+ SanskritTargetTokenizer — trained on quote_devanagari (Devanagari script)
+
+ WHY SEPARATE?
+ Roman Sanskrit and Devanagari are fundamentally different character sets.
+ Roman uses a-z + diacritics (~60 unique chars), Devanagari uses अ-ह + matras
+ (~100+ unique chars). A shared BPE tokenizer wastes half its vocab on
+ character combos that never cross scripts, and forces the embedding table
+ to encode both scripts in one space — confusing the model's cross-attention.
+
+ With separate tokenizers:
+ - src vocab captures Roman subwords cleanly (ā, ś, ṭ, ṃ etc.)
+ - tgt vocab captures Devanagari akshara clusters cleanly (क्ष, त्र, etc.)
+ - The model learns a true cross-script mapping in its cross-attention
+
+ SPECIAL TOKENS (same IDs in both):
+ [MASK] = 0 ← required by absorbing diffusion
+ [PAD] = 1
+ [UNK] = 2
+ [CLS] = 3
+ [SEP] = 4
+ """
+
+ from tokenizers import Tokenizer
+ from tokenizers.models import BPE
+ from tokenizers.trainers import BpeTrainer
+ from tokenizers.pre_tokenizers import Whitespace
+ from datasets import load_dataset
+ from pathlib import Path
+
+
+ SPECIAL_TOKENS = ["[MASK]", "[PAD]", "[UNK]", "[CLS]", "[SEP]"]
+
+
+ def _build_bpe(texts, vocab_size):
+     """Build a BPE tokenizer from an iterator of strings."""
+     tok = Tokenizer(BPE(unk_token="[UNK]"))
+     tok.pre_tokenizer = Whitespace()
+     trainer = BpeTrainer(
+         vocab_size=vocab_size,
+         special_tokens=SPECIAL_TOKENS,  # [MASK] MUST be first → id=0
+         min_frequency=2,
+     )
+     tok.train_from_iterator(texts, trainer)
+     return tok
+
+
+ def _validate(tok, name):
+     mask_id = tok.token_to_id("[MASK]")
+     pad_id = tok.token_to_id("[PAD]")
+     assert mask_id == 0, f"{name}: [MASK] must be id=0, got {mask_id}"
+     assert pad_id == 1, f"{name}: [PAD] must be id=1, got {pad_id}"
+     print(f"✅ {name}: [MASK]=0, [PAD]=1 confirmed. Vocab size={tok.get_vocab_size()}")
+
+
+ # ── Source tokenizer (Roman/IAST Sanskrit) ────────────────────────────
+
+ class SanskritSourceTokenizer:
+     """
+     Tokenizer for quote_text — Roman transliteration of Sanskrit.
+     Examples: "dharmo rakṣati rakṣitaḥ", "yatra nāryastu pūjyante"
+     """
+     MODEL_PATH = "sanskrit_src_tokenizer.json"
+
+     def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
+         self.vocab_size = vocab_size
+         self.max_len = max_len
+         self.mask_token_id = 0
+
+         if Path(self.MODEL_PATH).exists():
+             print(f"📖 Loading source tokenizer from {self.MODEL_PATH} …")
+             self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
+         else:
+             print("🎓 Training source tokenizer on quote_text …")
+             self._train(vocab_size, n_train_samples)
+
+         _validate(self.tokenizer, "SrcTokenizer")
+
+     def _train(self, vocab_size, n_samples):
+         dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
+         n = min(n_samples, len(dataset))
+         texts = [s["quote_text"] for s in dataset.select(range(n))
+                  if s["quote_text"].strip()]
+         self.tokenizer = _build_bpe(texts, vocab_size)
+         self.tokenizer.save(self.MODEL_PATH)
+         print(f"✅ Source tokenizer trained on {len(texts)} Roman texts.")
+
+     def encode(self, text):
+         ids = self.tokenizer.encode(text).ids[:self.max_len]
+         pad = self.tokenizer.token_to_id("[PAD]")
+         ids += [pad] * max(0, self.max_len - len(ids))
+         return ids[:self.max_len]
+
+     def decode(self, ids):
+         clean = [i for i in ids if i > 4]  # skip special tokens
+         return self.tokenizer.decode(clean)
+
+     def __len__(self):
+         return self.vocab_size
+
+
+ # ── Target tokenizer (Devanagari Sanskrit) ───────────────────────────
+
+ class SanskritTargetTokenizer:
+     """
+     Tokenizer for quote_devanagari — Devanagari script.
+     Examples: "धर्मो रक्षति रक्षितः", "यत्र नार्यस्तु पूज्यन्ते"
+     """
+     MODEL_PATH = "sanskrit_tgt_tokenizer.json"
+
+     def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
+         self.vocab_size = vocab_size
+         self.max_len = max_len
+         self.mask_token_id = 0
+
+         if Path(self.MODEL_PATH).exists():
+             print(f"📖 Loading target tokenizer from {self.MODEL_PATH} …")
+             self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
+         else:
+             print("🎓 Training target tokenizer on quote_devanagari …")
+             self._train(vocab_size, n_train_samples)
+
+         _validate(self.tokenizer, "TgtTokenizer")
+
+     def _train(self, vocab_size, n_samples):
+         dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
+         n = min(n_samples, len(dataset))
+         texts = [s["quote_devanagari"] for s in dataset.select(range(n))
+                  if s["quote_devanagari"].strip()]
+         self.tokenizer = _build_bpe(texts, vocab_size)
+         self.tokenizer.save(self.MODEL_PATH)
+         print(f"✅ Target tokenizer trained on {len(texts)} Devanagari texts.")
+
+     def encode(self, text):
+         ids = self.tokenizer.encode(text).ids[:self.max_len]
+         pad = self.tokenizer.token_to_id("[PAD]")
+         ids += [pad] * max(0, self.max_len - len(ids))
+         return ids[:self.max_len]
+
+     def decode(self, ids):
+         clean = [i for i in ids if i > 4]
+         return self.tokenizer.decode(clean)
+
+     # Methods required by BERTScore
+     def build_inputs_with_special_tokens(self, token_ids):
+         return list(token_ids)
+
+     def get_vocab(self):
+         return {str(i): i for i in range(self.vocab_size)}
+
+     def convert_ids_to_tokens(self, ids):
+         return [str(i) for i in ids]
+
+     def __len__(self):
+         return self.vocab_size
+
+
+ # ── Legacy shared tokenizer (kept for backward compat) ───────────────
+
+ class SanskritTokenizer:
+     """
+     LEGACY: single shared tokenizer trained on BOTH scripts.
+     Still works but suboptimal — use SanskritSourceTokenizer +
+     SanskritTargetTokenizer for the quote_text → quote_devanagari task.
+     """
+     MODEL_PATH = "sanskrit_tokenizer_m4pro.json"
+
+     def __init__(self, vocab_size=16000, max_len=80):
+         self.vocab_size = vocab_size
+         self.max_len = max_len
+         self.mask_token_id = 0
+
+         if Path(self.MODEL_PATH).exists():
+             print("📖 Loading shared tokenizer …")
+             self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
+         else:
+             print("🎓 Training shared tokenizer on both scripts …")
+             self._train(vocab_size)
+
+         _validate(self.tokenizer, "SharedTokenizer")
+
+     def _train(self, vocab_size):
+         dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
+         n = min(50000, len(dataset))
+         texts = []
+         for s in dataset.select(range(n)):
+             if s["quote_text"].strip():
+                 texts.append(s["quote_text"])
+             if s["quote_devanagari"].strip():
+                 texts.append(s["quote_devanagari"])
+         self.tokenizer = _build_bpe(texts, vocab_size)
+         self.tokenizer.save(self.MODEL_PATH)
+         print(f"✅ Shared tokenizer trained ({len(texts)} texts).")
+
+     def encode(self, text):
+         ids = self.tokenizer.encode(text).ids[:self.max_len]
+         pad = self.tokenizer.token_to_id("[PAD]")
+         ids += [pad] * max(0, self.max_len - len(ids))
+         return ids[:self.max_len]
+
+     def decode(self, ids):
+         if ids and isinstance(ids[0], list):
+             raise TypeError("decode() got 2D list — pass a 1D list.")
+         clean = [i for i in ids if i > 4]
+         return self.tokenizer.decode(clean)
+
+     def build_inputs_with_special_tokens(self, token_ids):
+         return list(token_ids)
+
+     def get_vocab(self):
+         return {str(i): i for i in range(self.vocab_size)}
+
+     def convert_ids_to_tokens(self, ids):
+         return [str(i) for i in ids]
+
+     def __len__(self):
+         return self.vocab_size
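A round-trip sketch of the dual-tokenizer contract above; instantiation loads the saved JSON files, or trains from paws/sanskrit-verses-gretil if they are missing:

    from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer

    src_tok = SanskritSourceTokenizer()   # Roman/IAST side
    tgt_tok = SanskritTargetTokenizer()   # Devanagari side

    ids = src_tok.encode("dharmo rakṣati rakṣitaḥ")
    assert len(ids) == src_tok.max_len    # always padded/truncated to max_len (80)

    # decode() drops ids 0-4 ([MASK], [PAD], [UNK], [CLS], [SEP]) before decoding
    print(src_tok.decode(ids))
    print(tgt_tok.decode(tgt_tok.encode("धर्मो रक्षति रक्षितः")))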
model/tokenizers.py ADDED
@@ -0,0 +1,112 @@
+ """
+ tokenizers.py — FINAL
+ =====================
+ Uses the original sanskrit_tokenizer_m4pro.json — the exact one the model
+ was trained with. Hard-coded absolute path as primary, with fallbacks.
+
+ This tokenizer has NO </w> end-of-word markers and NO decoder set.
+ decode() returns space-separated BPE pieces — this is the format the
+ model was trained and evaluated on (BERTScore 0.71). Do NOT add a decoder
+ or retrain: that would break alignment with the checkpoint.
+ """
+
+ from tokenizers import Tokenizer
+ from tokenizers.models import BPE
+ from tokenizers.trainers import BpeTrainer
+ from tokenizers.pre_tokenizers import Whitespace
+ from datasets import load_dataset
+ from pathlib import Path
+ import os
+
+ # Hard-coded absolute path — update if you move the project
+ TOKENIZER_PATH = "/Users/bhsingh/Documents/Final_Paraphrase/sanskrit_tokenizer_m4pro.json"
+
+
+ def build_tokenizer(texts, vocab_size=16000):
+     tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
+     tokenizer.pre_tokenizer = Whitespace()
+     trainer = BpeTrainer(
+         vocab_size=vocab_size,
+         special_tokens=["[MASK]", "[PAD]", "[UNK]", "[CLS]", "[SEP]"],
+         min_frequency=2,
+     )
+     tokenizer.train_from_iterator(texts, trainer)
+     return tokenizer
+
+
+ class SanskritTokenizer:
+     def __init__(self, vocab_size=16000, max_len=80):
+         self.vocab_size = vocab_size
+         self.max_len = max_len
+         self.mask_token_id = 0
+
+         script_dir = Path(__file__).resolve().parent
+         candidates = [
+             os.environ.get("SANSKRIT_TOKENIZER_PATH", ""),
+             TOKENIZER_PATH,
+             str(script_dir.parent / "sanskrit_tokenizer_m4pro.json"),
+             str(script_dir / "sanskrit_tokenizer_m4pro.json"),
+             str(Path.cwd() / "sanskrit_tokenizer_m4pro.json"),
+         ]
+
+         self.model_path = None
+         for c in candidates:
+             if c and Path(c).exists():
+                 self.model_path = c
+                 break
+
+         if self.model_path:
+             print(f"📖 Loading tokenizer from: {self.model_path}")
+             self.tokenizer = Tokenizer.from_file(self.model_path)
+             self._validate_mask_token()
+         else:
+             print("⚠️ Tokenizer not found at any candidate path.")
+             print(f"   Expected: {TOKENIZER_PATH}")
+             print("   Retraining — WARNING: output will not match existing checkpoint!")
+             self.model_path = TOKENIZER_PATH
+             self._train_tokenizer()
+
+     def _validate_mask_token(self):
+         mask_id = self.tokenizer.token_to_id("[MASK]")
+         assert mask_id == 0, f"[MASK] must be ID 0, got {mask_id}"
+         print("✅ [MASK] token confirmed at ID=0")
+
+     def _train_tokenizer(self):
+         dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
+         texts = []
+         for sample in dataset.select(range(50000)):
+             texts.extend([sample["quote_text"], sample["quote_devanagari"]])
+         tokenizer = build_tokenizer(texts, self.vocab_size)
+         tokenizer.save(self.model_path)
+         self.tokenizer = tokenizer
+         self._validate_mask_token()
+         print(f"✅ Tokenizer saved to: {self.model_path}")
+
+     def encode(self, text):
+         encoded = self.tokenizer.encode(text)
+         token_ids = encoded.ids[:self.max_len]
+         pad_id = self.tokenizer.token_to_id("[PAD]")
+         if len(token_ids) < self.max_len:
+             token_ids += [pad_id] * (self.max_len - len(token_ids))
+         return token_ids[:self.max_len]
+
+     def decode(self, ids):
+         if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
+             raise TypeError("decode() expects 1D list of IDs, not 2D.")
+         # Filter special tokens: 0=MASK 1=PAD 2=UNK 3=CLS 4=SEP
+         clean = [i for i in ids if isinstance(i, int) and i > 4]
+         if not clean:
+             return ""
+         return self.tokenizer.decode(clean, skip_special_tokens=True).strip()
+
+     def build_inputs_with_special_tokens(self, token_ids):
+         return list(token_ids)
+
+     def get_vocab(self):
+         return {str(i): i for i in range(self.vocab_size)}
+
+     def convert_ids_to_tokens(self, ids):
+         return [str(i) for i in ids]
+
+     def __len__(self):
+         return self.vocab_size
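Because the loader walks the candidate list in order, the SANSKRIT_TOKENIZER_PATH environment variable overrides the hard-coded default; a sketch with a hypothetical relocated path:

    import os
    # Hypothetical location; set before constructing the tokenizer.
    os.environ["SANSKRIT_TOKENIZER_PATH"] = "/data/checkpoints/sanskrit_tokenizer_m4pro.json"

    from model.tokenizers import SanskritTokenizer
    tok = SanskritTokenizer()   # env path is checked first, then TOKENIZER_PATH and the script-relative fallbacks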
model_settings.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "model_type": "d3pm_encoder_decoder",
+   "include_negative_examples": false,
+   "num_steps": 4
+ }
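model_settings.json pins only the architecture and sampler depth; here is a hedged sketch of merging it into the full config consumed by SanskritModel. The nested "model" hyperparameters below are illustrative assumptions (the real values live in config.py), not part of this file:

    import json

    with open("model_settings.json") as f:
        settings = json.load(f)

    cfg = {
        **settings,                      # model_type selects D3PMEncoderDecoder here
        "model": {                       # assumed hyperparameters, for illustration only
            "d_model": 512, "n_heads": 8, "d_ff": 2048, "n_layers": 6,
            "dropout": 0.1, "vocab_size": 8000, "max_seq_len": 80,
        },
    }

    from model.sanskrit_model import SanskritModel
    model = SanskritModel(cfg)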
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch>=2.2
+ numpy>=1.24
+ tqdm>=4.66
+ datasets>=2.19
+ tokenizers>=0.15
+ scikit-learn>=1.3
sanskrit_src_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
sanskrit_tgt_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff