Upload folder using huggingface_hub
Browse files- .gitattributes +2 -34
- .gitignore +2 -0
- README.md +122 -0
- analysis_reports/outputs_all_models_20260325/T16/task1_kv_cache.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T16/task2_report.txt +35 -0
- analysis_reports/outputs_all_models_20260325/T16/task3_report.txt +21 -0
- analysis_reports/outputs_all_models_20260325/T16/task4_report.txt +14 -0
- analysis_reports/outputs_all_models_20260325/T16/task5_report.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T32/task1_kv_cache.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T32/task2_report.txt +35 -0
- analysis_reports/outputs_all_models_20260325/T32/task3_report.txt +21 -0
- analysis_reports/outputs_all_models_20260325/T32/task4_report.txt +14 -0
- analysis_reports/outputs_all_models_20260325/T32/task5_report.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T4/task1_kv_cache.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T4/task2_report.txt +29 -0
- analysis_reports/outputs_all_models_20260325/T4/task3_report.txt +21 -0
- analysis_reports/outputs_all_models_20260325/T4/task4_report.txt +14 -0
- analysis_reports/outputs_all_models_20260325/T4/task5_report.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T64/task1_kv_cache.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T64/task2_report.txt +35 -0
- analysis_reports/outputs_all_models_20260325/T64/task3_report.txt +21 -0
- analysis_reports/outputs_all_models_20260325/T64/task4_report.txt +14 -0
- analysis_reports/outputs_all_models_20260325/T64/task5_report.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T8/task1_kv_cache.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T8/task2_report.txt +33 -0
- analysis_reports/outputs_all_models_20260325/T8/task3_report.txt +21 -0
- analysis_reports/outputs_all_models_20260325/T8/task4_report.txt +14 -0
- analysis_reports/outputs_all_models_20260325/T8/task5_report.txt +15 -0
- config.py +33 -0
- diffusion/__init__.py +0 -0
- diffusion/forward_process.py +21 -0
- diffusion/reverse_process.py +302 -0
- diffusion/reverse_process1.py +154 -0
- diffusion/reverse_process2.py +275 -0
- diffusion/scheduler.py +34 -0
- handler.py +30 -0
- inference.py +554 -0
- inference_api.py +131 -0
- model/__init__.py +0 -0
- model/d3pm_model_cross_attention.py +271 -0
- model/d3pm_model_encoder_decoder.py +227 -0
- model/sanskrit_model.py +61 -0
- model/tokenizer.py +222 -0
- model/tokenizers.py +112 -0
- model_settings.json +5 -0
- requirements.txt +6 -0
- sanskrit_src_tokenizer.json +0 -0
- sanskrit_tgt_tokenizer.json +0 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,3 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.
|
| 24 |
-
*.
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
README.md
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- sa
|
| 5 |
+
- en
|
| 6 |
+
tags:
|
| 7 |
+
- sanskrit
|
| 8 |
+
- paraphrase
|
| 9 |
+
- diffusion
|
| 10 |
+
- d3pm
|
| 11 |
+
- pytorch
|
| 12 |
+
pipeline_tag: text2text-generation
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# Sanskrit D3PM Encoder-Decoder Model
|
| 16 |
+
|
| 17 |
+
Roman/IAST Sanskrit input to Devanagari output using a custom D3PM checkpoint.
|
| 18 |
+
This package is configured for the `d3pm_encoder_decoder` checkpoint stored in
|
| 19 |
+
`best_model.pt`.
|
| 20 |
+
Hugging Face model repo: `bhsinghgrid/devflow2`
|
| 21 |
+
|
| 22 |
+
## Files Included
|
| 23 |
+
|
| 24 |
+
- `best_model.pt` — trained checkpoint
|
| 25 |
+
- `model_settings.json` — packaged runtime metadata
|
| 26 |
+
- `config.py` — runtime config
|
| 27 |
+
- `inference.py` — model loading + generation loop
|
| 28 |
+
- `inference_api.py` — simple Python API (`predict`)
|
| 29 |
+
- `handler.py` — Hugging Face Endpoint handler
|
| 30 |
+
- `model/`, `diffusion/` — architecture modules
|
| 31 |
+
- `sanskrit_src_tokenizer.json`, `sanskrit_tgt_tokenizer.json` — tokenizers
|
| 32 |
+
|
| 33 |
+
## Quick Local Test
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
from inference_api import predict
|
| 37 |
+
print(predict("dharmo rakṣati rakṣitaḥ")["output"])
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## Runtime Settings
|
| 41 |
+
|
| 42 |
+
For local/API usage, the runtime first reads `model_settings.json`, then allows
|
| 43 |
+
optional environment variable overrides:
|
| 44 |
+
|
| 45 |
+
- `HF_MODEL_TYPE` = `d3pm_cross_attention` or `d3pm_encoder_decoder`
|
| 46 |
+
- `HF_INCLUDE_NEG` = `true` or `false`
|
| 47 |
+
- `HF_NUM_STEPS` = diffusion step count for the packaged checkpoint
|
| 48 |
+
|
| 49 |
+
Packaged settings for this repo:
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
export HF_MODEL_TYPE=d3pm_encoder_decoder
|
| 53 |
+
export HF_INCLUDE_NEG=false
|
| 54 |
+
export HF_NUM_STEPS=4
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## Use This Model In A Hugging Face Space
|
| 58 |
+
|
| 59 |
+
In your Space settings, set:
|
| 60 |
+
|
| 61 |
+
- `HF_CHECKPOINT_REPO=bhsinghgrid/devflow2`
|
| 62 |
+
- `HF_CHECKPOINT_FILE=best_model.pt`
|
| 63 |
+
|
| 64 |
+
If your Space reads model metadata automatically, no extra model-type variables
|
| 65 |
+
are required. If it does not, also set:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
HF_DEFAULT_MODEL_TYPE=d3pm_encoder_decoder
|
| 69 |
+
HF_DEFAULT_INCLUDE_NEG=false
|
| 70 |
+
HF_DEFAULT_NUM_STEPS=4
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
## Transformer-Style Usage (Custom Runtime)
|
| 74 |
+
|
| 75 |
+
This checkpoint is a custom D3PM architecture (`.pt`), not a native `transformers`
|
| 76 |
+
`AutoModel` format. Use it via the provided runtime:
|
| 77 |
+
|
| 78 |
+
```python
|
| 79 |
+
import torch
|
| 80 |
+
from config import CONFIG
|
| 81 |
+
from inference import load_model, run_inference, _decode_clean
|
| 82 |
+
from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
|
| 83 |
+
|
| 84 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 85 |
+
model, cfg = load_model("best_model.pt", CONFIG, device)
|
| 86 |
+
|
| 87 |
+
src_tok = SanskritSourceTokenizer(vocab_size=16000, max_len=cfg["model"]["max_seq_len"])
|
| 88 |
+
tgt_tok = SanskritTargetTokenizer(vocab_size=16000, max_len=cfg["model"]["max_seq_len"])
|
| 89 |
+
|
| 90 |
+
text = "dharmo rakṣati rakṣitaḥ"
|
| 91 |
+
ids = torch.tensor([src_tok.encode(text)], dtype=torch.long, device=device)
|
| 92 |
+
out = run_inference(model, ids, cfg)
|
| 93 |
+
print(_decode_clean(tgt_tok, out[0].tolist()))
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
If you need full `transformers` compatibility (`AutoModel.from_pretrained`),
|
| 97 |
+
export weights to a Hugging Face Transformers model format first.
|
| 98 |
+
|
| 99 |
+
## Endpoint Payload
|
| 100 |
+
|
| 101 |
+
```json
|
| 102 |
+
{
|
| 103 |
+
"inputs": "yadā mano nivarteta viṣayebhyaḥ svabhāvataḥ",
|
| 104 |
+
"parameters": {
|
| 105 |
+
"temperature": 0.7,
|
| 106 |
+
"top_k": 40,
|
| 107 |
+
"repetition_penalty": 1.2,
|
| 108 |
+
"diversity_penalty": 0.0,
|
| 109 |
+
"num_steps": 4,
|
| 110 |
+
"clean_output": true
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
## Push This Folder To Model Hub
|
| 116 |
+
|
| 117 |
+
```bash
|
| 118 |
+
cd hf_model_repo_encoder_decoder
|
| 119 |
+
git add .
|
| 120 |
+
git commit -m "Add encoder-decoder T4 model package"
|
| 121 |
+
git push -u hf main
|
| 122 |
+
```
|
analysis_reports/outputs_all_models_20260325/T16/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 — KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
has_generate_cached=True
|
| 5 |
+
memory_profile=Torch CPU mem-event reduction: 30.4% @ src_len=64 (std=2143.0MB, cache=1492.1MB)
|
| 6 |
+
|
| 7 |
+
src_len standard(s) cached(s) speedup encoder%
|
| 8 |
+
16 0.893 0.571 1.56x 40.0%
|
| 9 |
+
32 0.751 0.509 1.48x 42.3%
|
| 10 |
+
64 1.141 0.822 1.39x 40.7%
|
| 11 |
+
|
| 12 |
+
Saved graphs:
|
| 13 |
+
- task1_time_comparison.png
|
| 14 |
+
- task1_speedup.png
|
| 15 |
+
- task1_encoder_cost.png
|
analysis_reports/outputs_all_models_20260325/T16/task2_report.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 2 — ATTENTION + DRIFT REPORT
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Input : dharmo rakṣati rakṣitaḥ
|
| 5 |
+
Output: धर्मो रक्षति रक्षितः
|
| 6 |
+
|
| 7 |
+
Captured steps: 16
|
| 8 |
+
Analysis quality: WEAK
|
| 9 |
+
Final output uniq-ratio: 1.000
|
| 10 |
+
Degenerate output: NO
|
| 11 |
+
Multi-sample semantic score (n<=8): 0.1471
|
| 12 |
+
Lock-in step (CER<=0.05): t=0
|
| 13 |
+
Locked tokens: 38 Flexible tokens: 42
|
| 14 |
+
TF-IDF vs attention stability corr: 0.9294
|
| 15 |
+
TF-IDF status: OK
|
| 16 |
+
|
| 17 |
+
Saved graphs:
|
| 18 |
+
- task2_attn_t*.png / task2_all_layers_t0.png
|
| 19 |
+
- task2_attn_evolution.png
|
| 20 |
+
- task2_semantic_drift.png
|
| 21 |
+
- task2_source_alignment.png
|
| 22 |
+
- task2_tfidf_vs_attention.png
|
| 23 |
+
|
| 24 |
+
Step trajectory (first 10 rows)
|
| 25 |
+
------------------------------------------------------------
|
| 26 |
+
t= 15 bert=0.0475 drift=0.9525 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
|
| 27 |
+
t= 14 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
|
| 28 |
+
t= 13 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
|
| 29 |
+
t= 12 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
|
| 30 |
+
t= 11 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
|
| 31 |
+
t= 10 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
|
| 32 |
+
t= 9 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
|
| 33 |
+
t= 8 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
|
| 34 |
+
t= 7 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
|
| 35 |
+
t= 6 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
|
analysis_reports/outputs_all_models_20260325/T16/task3_report.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 3 — CONCEPT VECTORS + PCA STEERING
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
PCA: 50 components, 74.8% variance
|
| 5 |
+
Diversity PC: 0 (|r|=0.325 with diversity proxy)
|
| 6 |
+
|
| 7 |
+
Direction validity: WEAK
|
| 8 |
+
Spectrum unique ratio (mean over 5 seeds): 1.000
|
| 9 |
+
Spectrum semantic stability (mean over 5 seeds): 0.312
|
| 10 |
+
|
| 11 |
+
Saved graphs:
|
| 12 |
+
- task3_concept_space.png
|
| 13 |
+
- task3_pca_explained_variance.png
|
| 14 |
+
- task3_diversity_curve.png
|
| 15 |
+
|
| 16 |
+
Diversity spectrum:
|
| 17 |
+
alpha=-2.0 → बले वेध विवर् धान वीर्य वीर्य धिं सिंहा भि̱ सन वस्तु वेध वै वेध वस्तु सन सन सिंहा सिंहा वीर्य वीर्य वस्तु सन रुते प्रभवति मन वेध बले बले र्वृ प्रपूजयेत् युगा मलि धान तुल वीर्य वीर्य वीर्य वीर्य वीर्य वीर्य धान तुल कालेन युगा वेध बले वेध वेध च्छे ष्मस् यस्या काष्ठा ज्ञप्त अर्णव धिं धिं वस्तु धिं सन तया सन सन देवाः देवाः स्वातन्त्र अर्णव मह वस्तु मुष् सन धिं धिं धिं विक्र त्र मह हस्ते च्छे मह
|
| 18 |
+
alpha=-1.0 → बले र् अ तुल वीर्य वीर्य गुरु सिंहा सन सन विलेप वै वै वै गतस्य वेध सन सिंहा सिंहा स्य स्य । सन वै वै वै बले बले बले बले र् अ अ तुल तुल वीर्य वीर्य वीर्य वीर्य वीर्य वीर्य तुल तुल तुल ् बले वेध दिव्यां मान वै अप्सु सन ॥ ॥ वस्तु सिंहा सन सन विक्र सन स काष्ठा सन सन सन कार सन सन सन सन भ बल ु सिंहा सन सिंहा सन म् म् सन
|
| 19 |
+
alpha=+0.0 → बले र् अ तुल वीर्य वीर्य स्य सिंहा सन सन पितो वै वै वै दक्षिणां सन सन सिंहा सिंहा स्य स्य स्य सन गतस्य वै वै ॥ बले बले र् र् अ अ । तुल वीर्य वीर्य वीर्य वीर्य वीर्य तुल तुल तुल तुल अ स बले बले वै वै ॥ ॥ ॥ सन सन सिंहा स सन सन सन सन सन सन सन सन सन सन ॥ ॥ सन सन शतैः ॥ सिंहा सिंहा द सिंहा सन त् सन
|
| 20 |
+
alpha=+1.0 → बले र् अ अ विशुद्धं स्य स्य सिंहा सिंहा सन गतस्य वै वै वै वेत्ति सन सन सिंहा स्य स्य स्य स्य सन वै वै स मल बले बले र् र् व अ अ तुल वीर्य वीर्य वीर्य स्य वीर्य स्य तुल ानु अ अ । र् व ॥ वै वै सन द ॥ ॥ सिंहा सिंहा ॥ सं सन ॥ ॥ व ॥ ॥ हेम सन सन व ॥ ै ॥ वै भ न न ॥ मित्रो सिंहा सन
|
| 21 |
+
alpha=+2.0 → आविश र् अ किंचिद् वर स्य स्य सिंहा सं निमे ञ् सं वै वै ञ् सन कृपा सिंहा स्य स्य स्य स्य फणा ञ् वै ौ जिह्व बले मानाः र् र् वराय अ माने वर विशुद्धं स्य स्य स्य – वर विशुद्धं व वर अ कृपा ॥ परम् ॥ कश्चि वै ॥ ञ् ञ् सं स्य स्य तम् व प्रवर्तन्ते कर्मसु परम् वर ते ॥ व ञ् ॥ ॥ सं द ॥ ॥ वर न्द ̱व ॥ व व ै
|
analysis_reports/outputs_all_models_20260325/T16/task4_report.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 4 — SEMANTIC ROBUSTNESS ABLATION
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Optimal diffusion steps = 16
|
| 5 |
+
|
| 6 |
+
T BERT-F1 SEM_SIM BLEU sec/sample
|
| 7 |
+
--------------------------------------------------------
|
| 8 |
+
16 0.2574 0.0580 0.0007 0.9068
|
| 9 |
+
|
| 10 |
+
Marginal gains (BERT-F1):
|
| 11 |
+
|
| 12 |
+
Saved plots/files:
|
| 13 |
+
- task4_3d.png
|
| 14 |
+
- task4_raw_results.json
|
analysis_reports/outputs_all_models_20260325/T16/task5_report.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 5 — CLASSIFIER-FREE GUIDANCE
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Classifier params: 139521
|
| 5 |
+
Training samples : 40
|
| 6 |
+
|
| 7 |
+
Guidance scale sweep:
|
| 8 |
+
λ CER diversity d2 sBLEU
|
| 9 |
+
----------------------------------------------------
|
| 10 |
+
0.0 0.8336 0.808 0.624 0.007 ← optimal
|
| 11 |
+
0.5 0.8362 0.800 0.606 0.007
|
| 12 |
+
1.0 0.8390 0.798 0.601 0.005
|
| 13 |
+
1.5 0.8458 0.813 0.631 0.004
|
| 14 |
+
2.0 0.8531 0.828 0.660 0.004
|
| 15 |
+
3.0 0.8773 0.830 0.669 0.009
|
analysis_reports/outputs_all_models_20260325/T32/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 — KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
has_generate_cached=True
|
| 5 |
+
memory_profile=Torch CPU mem-event reduction: 31.1% @ src_len=64 (std=4287.2MB, cache=2953.9MB)
|
| 6 |
+
|
| 7 |
+
src_len standard(s) cached(s) speedup encoder%
|
| 8 |
+
16 1.914 1.165 1.64x 39.6%
|
| 9 |
+
32 1.542 0.891 1.73x 42.1%
|
| 10 |
+
64 2.096 1.475 1.42x 42.7%
|
| 11 |
+
|
| 12 |
+
Saved graphs:
|
| 13 |
+
- task1_time_comparison.png
|
| 14 |
+
- task1_speedup.png
|
| 15 |
+
- task1_encoder_cost.png
|
analysis_reports/outputs_all_models_20260325/T32/task2_report.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 2 — ATTENTION + DRIFT REPORT
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Input : dharmo rakṣati rakṣitaḥ
|
| 5 |
+
Output: धर्मो रक्षति रक्षितः
|
| 6 |
+
|
| 7 |
+
Captured steps: 32
|
| 8 |
+
Analysis quality: WEAK
|
| 9 |
+
Final output uniq-ratio: 1.000
|
| 10 |
+
Degenerate output: NO
|
| 11 |
+
Multi-sample semantic score (n<=8): 0.0627
|
| 12 |
+
Lock-in step (CER<=0.05): t=0
|
| 13 |
+
Locked tokens: 75 Flexible tokens: 5
|
| 14 |
+
TF-IDF vs attention stability corr: -0.0869
|
| 15 |
+
TF-IDF status: WEAK
|
| 16 |
+
|
| 17 |
+
Saved graphs:
|
| 18 |
+
- task2_attn_t*.png / task2_all_layers_t0.png
|
| 19 |
+
- task2_attn_evolution.png
|
| 20 |
+
- task2_semantic_drift.png
|
| 21 |
+
- task2_source_alignment.png
|
| 22 |
+
- task2_tfidf_vs_attention.png
|
| 23 |
+
|
| 24 |
+
Step trajectory (first 10 rows)
|
| 25 |
+
------------------------------------------------------------
|
| 26 |
+
t= 31 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
|
| 27 |
+
t= 30 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
|
| 28 |
+
t= 29 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
|
| 29 |
+
t= 28 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
|
| 30 |
+
t= 27 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
|
| 31 |
+
t= 26 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
|
| 32 |
+
t= 25 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
|
| 33 |
+
t= 24 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
|
| 34 |
+
t= 23 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
|
| 35 |
+
t= 22 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
|
analysis_reports/outputs_all_models_20260325/T32/task3_report.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 3 — CONCEPT VECTORS + PCA STEERING
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
PCA: 50 components, 94.6% variance
|
| 5 |
+
Diversity PC: 0 (|r|=-0.530 with diversity proxy)
|
| 6 |
+
|
| 7 |
+
Direction validity: WEAK
|
| 8 |
+
Spectrum unique ratio (mean over 5 seeds): 0.840
|
| 9 |
+
Spectrum semantic stability (mean over 5 seeds): 0.234
|
| 10 |
+
|
| 11 |
+
Saved graphs:
|
| 12 |
+
- task3_concept_space.png
|
| 13 |
+
- task3_pca_explained_variance.png
|
| 14 |
+
- task3_diversity_curve.png
|
| 15 |
+
|
| 16 |
+
Diversity spectrum:
|
| 17 |
+
alpha=-2.0 → ेन श्रे श्रे ेन श्रे अण्ड व्याः श्रे तन्त्रा ॥ ॥ ॥ व्याः व्याः व्याः तद्वद् तद्वद् तद्वद् ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ तद्वद् ॥ ॥ ॥ ॥ ॥ व्याः व्याः व्याः ॥ ॥ राजन्य व्याः व्याः व्याः ॥ व्याः व्याः ॥ ॥ काम्य ॥ ॥ ॥ व्याः ॥ तद्वद् ॥ ॥ ॥ ॥ ॥ तन्त्रा तन्त्रा ॥ ॥ ॥ ॥ व्याः ॥ ॥ ॥ ॥ ॥ युधम् तद्वद् युधम् ॥
|
| 18 |
+
alpha=-1.0 → श्रे श्रे श्रे ेन श्रे श्रे श्रे श्रे अण्ड तन्त्रा व्याः ॥ अण्ड अण्ड तन्त्रा व्याः तद्वद् ॥ व्याः ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ अण्ड ॥ ॥ ॥ व्याः ॥ व्याः नो̍ ॥ ॥ ॥ ॥ ॥ व्याः व्याः अण्ड ॥ ॥ तन्त्रा ॥ ॥ तद्वद् युधम् रोमा शम्भु ॥ धूमं तन्त्रा ॥ तन्त्रा ॥ व्याः ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥
|
| 19 |
+
alpha=+0.0 → अण्ड श्रे करः श्रे तन्त्रा करः करः तन्त्रा श्रे अण्ड अण्ड अण्ड ॥ श्रे तद्वद् अण्ड ॥ ॥ अण्ड ॥ ॥ ॥ ॥ ॥ ॥ ॥ अण्ड ॥ ॥ ॥ ॥ अण्ड ॥ ॥ ॥ ॥ ॥ ॥ राजन्य तन्त्रा नो̍ ॥ ॥ ॥ ॥ ॥ व्याः ॥ अण्ड ॥ काम्य ॥ ॥ ॥ ॥ ॥ शम्भु धूमं तन्त्रा तन्त्रा ेन ॥ काम्य ॥ ॥ करः तन्त्रा ॥ अण्ड ॥ अण्ड ॥ विनिर्जित्य ॥ ॥ ॥ तन्त्रा अण्ड तद्वद् करः
|
| 20 |
+
alpha=+1.0 → माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण
|
| 21 |
+
alpha=+2.0 → माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण
|
analysis_reports/outputs_all_models_20260325/T32/task4_report.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 4 — SEMANTIC ROBUSTNESS ABLATION
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Optimal diffusion steps = 32
|
| 5 |
+
|
| 6 |
+
T BERT-F1 SEM_SIM BLEU sec/sample
|
| 7 |
+
--------------------------------------------------------
|
| 8 |
+
32 0.0422 0.0012 0.0000 1.8451
|
| 9 |
+
|
| 10 |
+
Marginal gains (BERT-F1):
|
| 11 |
+
|
| 12 |
+
Saved plots/files:
|
| 13 |
+
- task4_3d.png
|
| 14 |
+
- task4_raw_results.json
|
analysis_reports/outputs_all_models_20260325/T32/task5_report.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 5 — CLASSIFIER-FREE GUIDANCE
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Classifier params: 139521
|
| 5 |
+
Training samples : 40
|
| 6 |
+
|
| 7 |
+
Guidance scale sweep:
|
| 8 |
+
λ CER diversity d2 sBLEU
|
| 9 |
+
----------------------------------------------------
|
| 10 |
+
0.0 0.9357 0.239 0.011 0.533 ← optimal
|
| 11 |
+
0.5 0.9372 0.251 0.015 0.512
|
| 12 |
+
1.0 0.9467 0.164 0.018 0.690
|
| 13 |
+
1.5 0.9528 0.137 0.017 0.743
|
| 14 |
+
2.0 0.9525 0.144 0.013 0.725
|
| 15 |
+
3.0 0.9496 0.181 0.018 0.656
|
analysis_reports/outputs_all_models_20260325/T4/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 — KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
has_generate_cached=True
|
| 5 |
+
memory_profile=Torch CPU mem-event reduction: 24.6% @ src_len=64 (std=525.8MB, cache=396.4MB)
|
| 6 |
+
|
| 7 |
+
src_len standard(s) cached(s) speedup encoder%
|
| 8 |
+
16 0.267 0.173 1.54x 43.2%
|
| 9 |
+
32 0.197 0.153 1.29x 40.7%
|
| 10 |
+
64 0.353 0.265 1.33x 42.0%
|
| 11 |
+
|
| 12 |
+
Saved graphs:
|
| 13 |
+
- task1_time_comparison.png
|
| 14 |
+
- task1_speedup.png
|
| 15 |
+
- task1_encoder_cost.png
|
analysis_reports/outputs_all_models_20260325/T4/task2_report.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 2 — ATTENTION + DRIFT REPORT
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Input : dharmo rakṣati rakṣitaḥ
|
| 5 |
+
Output: धर्मो रक्षति रक्षितः
|
| 6 |
+
|
| 7 |
+
Captured steps: 4
|
| 8 |
+
Analysis quality: VALID
|
| 9 |
+
Final output uniq-ratio: 1.000
|
| 10 |
+
Degenerate output: NO
|
| 11 |
+
Multi-sample semantic score (n<=8): 0.1568
|
| 12 |
+
Lock-in step (CER<=0.05): t=0
|
| 13 |
+
Locked tokens: 79 Flexible tokens: 1
|
| 14 |
+
TF-IDF vs attention stability corr: 0.9472
|
| 15 |
+
TF-IDF status: OK
|
| 16 |
+
|
| 17 |
+
Saved graphs:
|
| 18 |
+
- task2_attn_t*.png / task2_all_layers_t0.png
|
| 19 |
+
- task2_attn_evolution.png
|
| 20 |
+
- task2_semantic_drift.png
|
| 21 |
+
- task2_source_alignment.png
|
| 22 |
+
- task2_tfidf_vs_attention.png
|
| 23 |
+
|
| 24 |
+
Step trajectory (first 10 rows)
|
| 25 |
+
------------------------------------------------------------
|
| 26 |
+
t= 3 bert=0.0603 drift=0.9397 text=ति ति ति रक्षि तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्मो
|
| 27 |
+
t= 2 bert=0.0597 drift=0.9403 text=ति ति ति रक्षि तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्मो
|
| 28 |
+
t= 1 bert=0.0597 drift=0.9403 text=ति ति ति रक्षि तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्मो
|
| 29 |
+
t= 0 bert=0.0597 drift=0.9403 text=ति ति ति रक्षि तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्मो
|
analysis_reports/outputs_all_models_20260325/T4/task3_report.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 3 — CONCEPT VECTORS + PCA STEERING
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
PCA: 50 components, 72.0% variance
|
| 5 |
+
Diversity PC: 0 (|r|=-0.349 with diversity proxy)
|
| 6 |
+
|
| 7 |
+
Direction validity: WEAK
|
| 8 |
+
Spectrum unique ratio (mean over 5 seeds): 1.000
|
| 9 |
+
Spectrum semantic stability (mean over 5 seeds): 0.325
|
| 10 |
+
|
| 11 |
+
Saved graphs:
|
| 12 |
+
- task3_concept_space.png
|
| 13 |
+
- task3_pca_explained_variance.png
|
| 14 |
+
- task3_diversity_curve.png
|
| 15 |
+
|
| 16 |
+
Diversity spectrum:
|
| 17 |
+
alpha=-2.0 → बले र् अपश्य येहि ऌ वीर्य ऌ सिंहा सन सन ̍त̱ ज्ज्वा माम् वै वै महर्द्धि महर्द्धि ऌ सिंहा कू दिक्षु ऌ दश्य वै क्रमं बले र् दश्य स्वस्थ तुल तुल वीर्य वीर्य वी ऌ सिंहा राज कू वीर्य वीर्य वीर्य वीर्य ऌ वी निरुद्धा ̍त̱ बले बले साध्व उपशान्त वी वी दाक्षि हतः महर्द्धि साध्व तु वी वी ऌ दिक्षु दिक्षु पूष माम् पुरं ऌ दिक्षु वी पूष ̍त̱ ोद् दिक्षु पुरं स्त्रं मनोरथ अस्मा ऌ वाहि राजान वी
|
| 18 |
+
alpha=-1.0 → बले बले अ तुल तुल वीर्य स्य सिंहा सन सन गतस्य गतस्य वै वै वै गतस्य सन पाता सिंहा दिता । ज्ज्वा वै वै बले बले र् अ अ तुल तुल वीर्य वीर्य स्य सिंहा सिंहा ध्रा स्य वीर्य वीर्य वीर्य तुल तुल ̍त̱ अ र् र् बले दिक्षु वै वै । वै संस्थिता रतं सन गतस्य पूष । वक्त्र सन सन सन सन सन व गतस्य व सन ॥ ति मनो हतः मातु ̍त̱ व कू कू सन सन
|
| 19 |
+
alpha=+0.0 → बले र् अ तुल तुल वीर्य स्य सिंहा सन सन गतस्य गतस्य वै वै वै गतस्य सन सन सिंहा सिंहा । व वै वै त्वम् बले र् अ अ तुल तुल वीर्य वीर्य स्य स्य स्य सिंहा स्य स्य वीर्य वीर्य तुल तुल तुल अ र् र् र् त्ते वै वै गतस्य सन सन सन सन सन गतस्य निःसृ गतस्य सन गतस्य सन सन सन सन सन वि सन वि स्रव सिंहा सन सन सन सन सन सन सन गतस्य
|
| 20 |
+
alpha=+1.0 → बले र् अ अ तुल वीर्य स्य सिंहा सन सन गतस्य गतस्य वै वै वै गतस्य सन सन सिंहा सिंहा षण् स्य ै वै बले बले र् अ अ तुल कान्ते षण् वीर्य स्य स्य सिंहा स्य स्य स्य वीर्य षण् वीर्य षण् अ अ र् र् र् षण् ेष गतस्य गतस्य गतस्य गतस्य सन सन षण् षण् गतस्य सन गतस्य गतस्य सन सन सन सन ष्णु गतस्य नो षण् नो - सन सन सन सन सन सन सन सन
|
| 21 |
+
alpha=+2.0 → षण् र् अ षण् षण् षण् स्य षण् षण् षण् षण् षण् वै षण् षण् गतस्य षण् षण् षण् षण् षण् षण् षण् स षण् षण् र् षण् अ षण् षण् षण् षण् स्य स्य षण् स्य स्य स्य षण् षण् षण् षण् षण् अ र् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् गतस्य षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण् षण्
|
analysis_reports/outputs_all_models_20260325/T4/task4_report.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 4 — SEMANTIC ROBUSTNESS ABLATION
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Optimal diffusion steps = 4
|
| 5 |
+
|
| 6 |
+
T BERT-F1 SEM_SIM BLEU sec/sample
|
| 7 |
+
--------------------------------------------------------
|
| 8 |
+
4 0.2644 0.0574 0.0000 0.2782
|
| 9 |
+
|
| 10 |
+
Marginal gains (BERT-F1):
|
| 11 |
+
|
| 12 |
+
Saved plots/files:
|
| 13 |
+
- task4_3d.png
|
| 14 |
+
- task4_raw_results.json
|
analysis_reports/outputs_all_models_20260325/T4/task5_report.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 5 — CLASSIFIER-FREE GUIDANCE
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Classifier params: 139521
|
| 5 |
+
Training samples : 40
|
| 6 |
+
|
| 7 |
+
Guidance scale sweep:
|
| 8 |
+
λ CER diversity d2 sBLEU
|
| 9 |
+
----------------------------------------------------
|
| 10 |
+
0.0 0.8366 0.815 0.635 0.005
|
| 11 |
+
0.5 0.8356 0.797 0.599 0.004 ← optimal
|
| 12 |
+
1.0 0.8369 0.791 0.588 0.006
|
| 13 |
+
1.5 0.8367 0.783 0.571 0.006
|
| 14 |
+
2.0 0.8367 0.774 0.553 0.005
|
| 15 |
+
3.0 0.8363 0.769 0.542 0.005
|
analysis_reports/outputs_all_models_20260325/T64/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 — KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
has_generate_cached=True
|
| 5 |
+
memory_profile=Torch CPU mem-event reduction: 31.4% @ src_len=64 (std=8592.3MB, cache=5890.5MB)
|
| 6 |
+
|
| 7 |
+
src_len standard(s) cached(s) speedup encoder%
|
| 8 |
+
16 4.206 3.584 1.17x 74.4%
|
| 9 |
+
32 4.647 3.371 1.38x 37.6%
|
| 10 |
+
64 8.403 4.593 1.83x 49.6%
|
| 11 |
+
|
| 12 |
+
Saved graphs:
|
| 13 |
+
- task1_time_comparison.png
|
| 14 |
+
- task1_speedup.png
|
| 15 |
+
- task1_encoder_cost.png
|
analysis_reports/outputs_all_models_20260325/T64/task2_report.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 2 — ATTENTION + DRIFT REPORT
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Input : dharmo rakṣati rakṣitaḥ
|
| 5 |
+
Output: धर्मो रक्षति रक्षितः
|
| 6 |
+
|
| 7 |
+
Captured steps: 64
|
| 8 |
+
Analysis quality: VALID
|
| 9 |
+
Final output uniq-ratio: 1.000
|
| 10 |
+
Degenerate output: NO
|
| 11 |
+
Multi-sample semantic score (n<=8): 0.1490
|
| 12 |
+
Lock-in step (CER<=0.05): t=0
|
| 13 |
+
Locked tokens: 59 Flexible tokens: 21
|
| 14 |
+
TF-IDF vs attention stability corr: 0.7804
|
| 15 |
+
TF-IDF status: OK
|
| 16 |
+
|
| 17 |
+
Saved graphs:
|
| 18 |
+
- task2_attn_t*.png / task2_all_layers_t0.png
|
| 19 |
+
- task2_attn_evolution.png
|
| 20 |
+
- task2_semantic_drift.png
|
| 21 |
+
- task2_source_alignment.png
|
| 22 |
+
- task2_tfidf_vs_attention.png
|
| 23 |
+
|
| 24 |
+
Step trajectory (first 10 rows)
|
| 25 |
+
------------------------------------------------------------
|
| 26 |
+
t= 63 bert=0.0552 drift=0.9448 text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
|
| 27 |
+
t= 62 bert=0.0548 drift=0.9452 text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
|
| 28 |
+
t= 61 bert=0.0548 drift=0.9452 text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
|
| 29 |
+
t= 60 bert=0.0548 drift=0.9452 text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
|
| 30 |
+
t= 59 bert=0.0548 drift=0.9452 text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
|
| 31 |
+
t= 58 bert=0.0548 drift=0.9452 text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
|
| 32 |
+
t= 57 bert=0.0548 drift=0.9452 text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
|
| 33 |
+
t= 56 bert=0.0546 drift=0.9454 text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
|
| 34 |
+
t= 55 bert=0.0546 drift=0.9454 text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
|
| 35 |
+
t= 54 bert=0.0546 drift=0.9454 text=धर्मो ति काम्य तः तः तः तः तः तः धर्मो धर्मो धर्मो धर्मो धर्
|
analysis_reports/outputs_all_models_20260325/T64/task3_report.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 3 — CONCEPT VECTORS + PCA STEERING
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
PCA: 39 components, 100.0% variance
|
| 5 |
+
Diversity PC: 0 (|r|=0.314 with diversity proxy)
|
| 6 |
+
|
| 7 |
+
Direction validity: WEAK
|
| 8 |
+
Spectrum unique ratio (mean over 5 seeds): 1.000
|
| 9 |
+
Spectrum semantic stability (mean over 5 seeds): 0.302
|
| 10 |
+
|
| 11 |
+
Saved graphs:
|
| 12 |
+
- task3_concept_space.png
|
| 13 |
+
- task3_pca_explained_variance.png
|
| 14 |
+
- task3_diversity_curve.png
|
| 15 |
+
|
| 16 |
+
Diversity spectrum:
|
| 17 |
+
alpha=-2.0 → बले बले अ तुल तुल वीर्य अ̱स्या भूयः सन ान्ते लब्ध्वा अर्थ वै भूयः त्ति ान्त सन भूयः भूयः तुल अ अ वीरो बले बले बले अ̱स्या भूयः ॥ ॥ अ अ अ̱स्या ॥ वे̱ध सन सन सन सन ह् ह् ान ह् स्य ह् ानु ह् यो बो दादि ह् मतां ह् सान्त्व ह् ( मतां ॥ धीमान् भूयः ॥ अ̱स्या पान होमयेत् सारथि ( ॥ भूयः ॥ ॥ ॥ ॥ ॥ ॥ गोप लज्ज अ̱स्या मतां लज्ज यो
|
| 18 |
+
alpha=-1.0 → बले बले अ तुल तुल वीर्य स्य सिंहा सिंहा सन त्ति वै वै वै द् सन सिंहा स्य स्य तुल तुल अ प्रयाति र् बले बले बले ॥ ॥ ॥ अ ॥ पि महा सन सन सन सन सन सन सिंहा स्रव सन स्य मीं गोप स्य स्य स्य स्य भूयः तुल सान्त्व यो अ ह् ान ान तव वेग ( यो भूषणम् ( ानु ॥ ॥ ॥ ॥ ॥ ॥ अ̱स्या ॥ ॥ ॥ ॥ यो पि ॥ म
|
| 19 |
+
alpha=+0.0 → बले र् अ तुल तुल वीर्य स्य सिंहा सिंहा सन ध्या वै वै वै गतस्य भ सिंहा भ स्य तुल तुल अ अ र् बले बले बले ॥ बले र् । । ॥ वै वै सन सन सन सन सिंहा सिंहा सिंहा स्य स्य स्य स्य भ स्य स्य स्य स्य ानु स्य ् ता यो स्य फल ॥ म तुल च सि ॥ ् ॥ ॥ न् ॥ ॥ ॥ ॥ ॥ ॥ ॥ महा सन सन क्ष ॥
|
| 20 |
+
alpha=+1.0 → बले र् । तुल यु वीर्य स्य सिंहा सिंहा सन ध्या वै वै वै । सन सिंहा स्य स्य तुल तुल । र् र् बले बले बले ॥ ॥ र् अ स् सन ते सन ीं सन सन त्र सिंहा यु सिंहा स्थल स्य स्य रौद्र स्य स्य न्दा ता यु स्य यु त्र क्ष ।। ीं स्य म्र कल्प यत् स् क्ष क्ष ॥ स्य यु मण्डलं यु ॥ ॥ ीं ॥ ॥ भ्यः ीं ीं ॥ ॥ ॥
|
| 21 |
+
alpha=+2.0 → र् र् तुरङ्ग आहुः ितो । स्य सिंहा सिंहा सिंहा ब्रह्मा वै वै & ते तस् तुरङ्ग नो स्तम्भ ीं यु संच र् र् बले ीं स्तम्भ ते । तस् न्तं मण्डलं यु । स्तम्भ स्तम्भ सन आहुः सिंहा यु सिंहा सिंहा स्य मण्डलं यु स्य स्य स्य एव कल्प स्तम्भ ̱र स्तम्भ अमु आहुः यु ̱र कल्प यु तुरङ्ग यु तुरङ्ग ̱र तुरङ्ग रणम् मण्डलं यु ीं मण्डलं दिनं ̱र यु ॥ तुरङ्ग ितः आहुः ॥ मण्डलं आहुः क्षमः
|
analysis_reports/outputs_all_models_20260325/T64/task4_report.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 4 — SEMANTIC ROBUSTNESS ABLATION
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Optimal diffusion steps = 64
|
| 5 |
+
|
| 6 |
+
T BERT-F1 SEM_SIM BLEU sec/sample
|
| 7 |
+
--------------------------------------------------------
|
| 8 |
+
64 0.2482 0.0580 0.0007 5.6116
|
| 9 |
+
|
| 10 |
+
Marginal gains (BERT-F1):
|
| 11 |
+
|
| 12 |
+
Saved plots/files:
|
| 13 |
+
- task4_3d.png
|
| 14 |
+
- task4_raw_results.json
|
analysis_reports/outputs_all_models_20260325/T64/task5_report.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 5 — CLASSIFIER-FREE GUIDANCE
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Classifier params: 139521
|
| 5 |
+
Training samples : 20
|
| 6 |
+
|
| 7 |
+
Guidance scale sweep:
|
| 8 |
+
λ CER diversity d2 sBLEU
|
| 9 |
+
----------------------------------------------------
|
| 10 |
+
0.0 0.8451 0.838 0.689 0.013 ← optimal
|
| 11 |
+
0.5 0.8490 0.818 0.650 0.013
|
| 12 |
+
1.0 0.8509 0.838 0.684 0.007
|
| 13 |
+
1.5 0.8622 0.857 0.720 0.005
|
| 14 |
+
2.0 0.8761 0.869 0.744 0.005
|
| 15 |
+
3.0 0.9056 0.814 0.642 0.013
|
analysis_reports/outputs_all_models_20260325/T8/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 — KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
has_generate_cached=True
|
| 5 |
+
memory_profile=Torch CPU mem-event reduction: 25.8% @ src_len=64 (std=1168.9MB, cache=866.9MB)
|
| 6 |
+
|
| 7 |
+
src_len standard(s) cached(s) speedup encoder%
|
| 8 |
+
16 0.582 0.400 1.45x 35.9%
|
| 9 |
+
32 0.511 0.402 1.27x 37.7%
|
| 10 |
+
64 0.666 0.490 1.36x 35.6%
|
| 11 |
+
|
| 12 |
+
Saved graphs:
|
| 13 |
+
- task1_time_comparison.png
|
| 14 |
+
- task1_speedup.png
|
| 15 |
+
- task1_encoder_cost.png
|
analysis_reports/outputs_all_models_20260325/T8/task2_report.txt
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 2 — ATTENTION + DRIFT REPORT
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Input : dharmo rakṣati rakṣitaḥ
|
| 5 |
+
Output: धर्मो रक्षति रक्षितः
|
| 6 |
+
|
| 7 |
+
Captured steps: 8
|
| 8 |
+
Analysis quality: WEAK
|
| 9 |
+
Final output uniq-ratio: 1.000
|
| 10 |
+
Degenerate output: NO
|
| 11 |
+
Multi-sample semantic score (n<=8): 0.0915
|
| 12 |
+
Lock-in step (CER<=0.05): t=0
|
| 13 |
+
Locked tokens: 79 Flexible tokens: 1
|
| 14 |
+
TF-IDF vs attention stability corr: 0.8905
|
| 15 |
+
TF-IDF status: OK
|
| 16 |
+
|
| 17 |
+
Saved graphs:
|
| 18 |
+
- task2_attn_t*.png / task2_all_layers_t0.png
|
| 19 |
+
- task2_attn_evolution.png
|
| 20 |
+
- task2_semantic_drift.png
|
| 21 |
+
- task2_source_alignment.png
|
| 22 |
+
- task2_tfidf_vs_attention.png
|
| 23 |
+
|
| 24 |
+
Step trajectory (first 10 rows)
|
| 25 |
+
------------------------------------------------------------
|
| 26 |
+
t= 7 bert=0.0219 drift=0.9781 text=ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः
|
| 27 |
+
t= 6 bert=0.0225 drift=0.9775 text=ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः
|
| 28 |
+
t= 5 bert=0.0225 drift=0.9775 text=ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः
|
| 29 |
+
t= 4 bert=0.0225 drift=0.9775 text=ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः
|
| 30 |
+
t= 3 bert=0.0225 drift=0.9775 text=ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः
|
| 31 |
+
t= 2 bert=0.0227 drift=0.9773 text=ं ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ित
|
| 32 |
+
t= 1 bert=0.0228 drift=0.9772 text=ं ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ित
|
| 33 |
+
t= 0 bert=0.0228 drift=0.9772 text=ं ं ं ं ं ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ितः ित
|
analysis_reports/outputs_all_models_20260325/T8/task3_report.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 3 — CONCEPT VECTORS + PCA STEERING
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
PCA: 50 components, 75.9% variance
|
| 5 |
+
Diversity PC: 0 (|r|=-0.344 with diversity proxy)
|
| 6 |
+
|
| 7 |
+
Direction validity: WEAK
|
| 8 |
+
Spectrum unique ratio (mean over 5 seeds): 1.000
|
| 9 |
+
Spectrum semantic stability (mean over 5 seeds): 0.341
|
| 10 |
+
|
| 11 |
+
Saved graphs:
|
| 12 |
+
- task3_concept_space.png
|
| 13 |
+
- task3_pca_explained_variance.png
|
| 14 |
+
- task3_diversity_curve.png
|
| 15 |
+
|
| 16 |
+
Diversity spectrum:
|
| 17 |
+
alpha=-2.0 → मनसः श्चक्र स्य स्य स्य अ स्य तैः तैः तैः स्य श्चक्र तैः गतभी स्य स्य स्य श्चक्र तैः तैः श्चक्र स्य तैः स्त्वं श्चक्र श्चक्र स्त्र तैः तैः कुण्ठ तैः तैः स्य तैः तैः तैः स्य तैः तैः गतभी तैः तैः णि̍ स्य तैः तैः तैः अ तैः तैः ह्वये मनसः ॥ तैः तैः गतभी ॥ श्चक्र तैः तैः तैः तैः तैः तैः तैः तैः स्य तैः करिष्या तैः स्त्वं तैः तैः श्चक्र तैः तैः श्चक्र तैः तैः ह्वये
|
| 18 |
+
alpha=-1.0 → स्य अ तैः वै तैः अ वेद् मनसः स्य । । तैः तैः स्य गतभी स्य अ स्त्वं सीद् स्य तैः स्य तैः सु̱म् सीद् र्ध कृतानि गतभी गतभी तैः तैः स्य तैः तैः तैः मनसः तैः कृतानि तैः तैः सु̱म् अ तैः मनसः अ मनसः स्य अ तैः ॥ ॥ स्य गतभी गतभी ॥ ॥ वै तैः तैः मनसः तैः अ तैः तैः च वर स्य तैः या वात् स्य तैः सीद् तैः स्य तैः स्य अ तैः तैः
|
| 19 |
+
alpha=+0.0 → अ अ वै ज्ञ स्य अ ज्ञ गतभी वर द शिख मन्त्र गतभी सु̱म् । द द स्य मन्त्र वा यो सीद् ज्ञ वै अ स्य स्य मन्त्र स्य मन्त्र स्य गतभी । गतभी गतभी तैः गतभी कृत तैः स्य तैः ॥ वै तैः ॥ वै अ कृतानि स्य वर वै ॥ ॥ वै ॥ अ ॥ स्य ॥ वै स्य ज्ञ ॥ स्य तैः तैः वै स्य स्य अ स्य तैः वै स्य तैः प्रण तीरे स्य । सीद्
|
| 20 |
+
alpha=+1.0 → पम वै तुल्य शत्रू पम शिख वर अ णाः परा णाः स्य कृत प्रिय । भिन् णाः ज्ञ वै विराज वै गणो वै ्या अ वै पम ्या भिन् वै लब्ध शोभ स्य च श वर वै ॥ वै क्षिप्य शिख भिर् ॥ सन वा मन्त्र मृ ॥ ॥ ॥ वै ॥ मन्त्र ॥ पम सङ् वर शोभ क्षिप्य भिर् स्य क्षिप्य वै सन वर शिख वै शिख वर दर्श शिख कलं पम ौ वर कलं भिर् कलं वै शिख
|
| 21 |
+
alpha=+2.0 → पम पम लब्ध पम शोभ पम परे भिर् अन्य णाः रसा लब्ध पम शोभ लब्ध शोभ पम शत्रू शिख भिन् पम पम पम णाः शोभ णाः पम शोभ शोभ परे णाः णाः पम पम पम पम शोभ पम शोभ शोभ पम डा णाः शोभ वै पम वै ॥ पम ॥ पम णाः ॥ ॥ परे अन्य णाः पम शोभ शोभ शोभ शोभ णाः वै परे कलं परे वै पम णाः पम शोभ णाः णाः कलं परे शोभ णाः कलं शत्रू
|
analysis_reports/outputs_all_models_20260325/T8/task4_report.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 4 — SEMANTIC ROBUSTNESS ABLATION
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Optimal diffusion steps = 8
|
| 5 |
+
|
| 6 |
+
T BERT-F1 SEM_SIM BLEU sec/sample
|
| 7 |
+
--------------------------------------------------------
|
| 8 |
+
8 0.1210 0.0400 0.0000 0.6194
|
| 9 |
+
|
| 10 |
+
Marginal gains (BERT-F1):
|
| 11 |
+
|
| 12 |
+
Saved plots/files:
|
| 13 |
+
- task4_3d.png
|
| 14 |
+
- task4_raw_results.json
|
analysis_reports/outputs_all_models_20260325/T8/task5_report.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 5 — CLASSIFIER-FREE GUIDANCE
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Classifier params: 139521
|
| 5 |
+
Training samples : 40
|
| 6 |
+
|
| 7 |
+
Guidance scale sweep:
|
| 8 |
+
λ CER diversity d2 sBLEU
|
| 9 |
+
----------------------------------------------------
|
| 10 |
+
0.0 0.8834 0.796 0.596 0.004 ← optimal
|
| 11 |
+
0.5 0.8881 0.781 0.568 0.005
|
| 12 |
+
1.0 0.8876 0.767 0.540 0.007
|
| 13 |
+
1.5 0.8921 0.757 0.517 0.004
|
| 14 |
+
2.0 0.8929 0.734 0.474 0.006
|
| 15 |
+
3.0 0.8970 0.724 0.453 0.005
|
config.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
CONFIG = {
|
| 4 |
+
"model_type": "d3pm_cross_attention",
|
| 5 |
+
"data": {
|
| 6 |
+
"include_negative_examples": True,
|
| 7 |
+
"dataset_size": 60000,
|
| 8 |
+
},
|
| 9 |
+
"diffusion": {
|
| 10 |
+
"mask_token_id": 0,
|
| 11 |
+
},
|
| 12 |
+
"model": {
|
| 13 |
+
"src_vocab_size": 16000,
|
| 14 |
+
"tgt_vocab_size": 16000,
|
| 15 |
+
"d_model": 384,
|
| 16 |
+
"n_heads": 8,
|
| 17 |
+
"d_ff": 1536,
|
| 18 |
+
"n_layers": 6,
|
| 19 |
+
"dropout": 0.1,
|
| 20 |
+
"max_seq_len": 80,
|
| 21 |
+
"diffusion_steps": 64,
|
| 22 |
+
},
|
| 23 |
+
"training": {
|
| 24 |
+
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
| 25 |
+
},
|
| 26 |
+
"inference": {
|
| 27 |
+
"num_steps": 64,
|
| 28 |
+
"temperature": 0.7,
|
| 29 |
+
"top_k": 40,
|
| 30 |
+
"repetition_penalty": 1.2,
|
| 31 |
+
"diversity_penalty": 0.0,
|
| 32 |
+
},
|
| 33 |
+
}
|
diffusion/__init__.py
ADDED
|
File without changes
|
diffusion/forward_process.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
forward_process.py — Verified Correct (no changes needed)
|
| 3 |
+
===========================================================
|
| 4 |
+
Absorbing (mask) diffusion. PAD never masked. At t=0 alpha=1.0 exactly
|
| 5 |
+
so x_t == x_0 (nothing masked). Works correctly with the fixed scheduler.
|
| 6 |
+
"""
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
class AbsorbingForwardProcess:
|
| 10 |
+
def __init__(self, scheduler, mask_id=0, pad_id=1):
|
| 11 |
+
self.scheduler = scheduler
|
| 12 |
+
self.mask_id = mask_id
|
| 13 |
+
self.pad_id = pad_id
|
| 14 |
+
|
| 15 |
+
def q_sample(self, x_0, t):
|
| 16 |
+
alpha_t = self.scheduler.get_alpha(t).to(x_0.device).view(-1, 1)
|
| 17 |
+
r = torch.rand(x_0.shape, device=x_0.device)
|
| 18 |
+
x_t = x_0.clone()
|
| 19 |
+
x_t[r > alpha_t] = self.mask_id
|
| 20 |
+
x_t[x_0 == self.pad_id] = self.pad_id # PAD stays PAD always
|
| 21 |
+
return x_0, x_t
|
diffusion/reverse_process.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
reverse_process.py — Fixed
|
| 3 |
+
===========================
|
| 4 |
+
Two bugs fixed from the original:
|
| 5 |
+
|
| 6 |
+
BUG 1 (critical): generate_beam() passed x_t (noisy) as `tgt` to model.
|
| 7 |
+
The model does q_sample(tgt, t) internally — so x_t got double-noised.
|
| 8 |
+
Fix: pass x0_estimate (current clean guess) as tgt. Model noises it correctly.
|
| 9 |
+
|
| 10 |
+
BUG 2: apply_diversity_penalty used logits.var(dim=-1) — this adds the
|
| 11 |
+
variance of each position's own distribution back to itself, which is
|
| 12 |
+
mathematically meaningless and just injects noise.
|
| 13 |
+
Fix: penalize tokens that are uniformly high-probability across ALL positions
|
| 14 |
+
(global common tokens). This genuinely promotes diversity.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ReverseDiffusion:
|
| 22 |
+
def __init__(self, scheduler):
|
| 23 |
+
self.scheduler = scheduler
|
| 24 |
+
|
| 25 |
+
def p_sample_step(
|
| 26 |
+
self,
|
| 27 |
+
model,
|
| 28 |
+
x_t,
|
| 29 |
+
t,
|
| 30 |
+
condition,
|
| 31 |
+
beam_width=3,
|
| 32 |
+
temperature=1.0,
|
| 33 |
+
repetition_penalty=1.2,
|
| 34 |
+
diversity_penalty=0.3
|
| 35 |
+
):
|
| 36 |
+
"""
|
| 37 |
+
Single reverse step with temperature + penalties.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
|
| 42 |
+
# ---- Shape safety ----
|
| 43 |
+
if x_t.dim() == 1:
|
| 44 |
+
x_t = x_t.unsqueeze(0)
|
| 45 |
+
|
| 46 |
+
if condition.dim() == 1:
|
| 47 |
+
condition = condition.unsqueeze(0)
|
| 48 |
+
|
| 49 |
+
if t.dim() == 0:
|
| 50 |
+
t = t.unsqueeze(0)
|
| 51 |
+
|
| 52 |
+
if t.shape[0] != x_t.shape[0]:
|
| 53 |
+
t = t.expand(x_t.shape[0])
|
| 54 |
+
|
| 55 |
+
# ---- Model forward ----
|
| 56 |
+
logits, _ = model(condition, x_t, t)
|
| 57 |
+
|
| 58 |
+
# ---- Temperature scaling ----
|
| 59 |
+
logits = logits / temperature
|
| 60 |
+
|
| 61 |
+
# ---- Repetition penalty (FIXED VERSION) ----
|
| 62 |
+
if repetition_penalty != 1.0:
|
| 63 |
+
logits = apply_repetition_penalty(
|
| 64 |
+
logits, x_t, repetition_penalty
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# ---- Diversity penalty ----
|
| 68 |
+
if diversity_penalty > 0:
|
| 69 |
+
logits = apply_diversity_penalty(
|
| 70 |
+
logits, diversity_penalty
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
probs = F.softmax(logits, dim=-1)
|
| 74 |
+
|
| 75 |
+
B, L, V = probs.shape
|
| 76 |
+
|
| 77 |
+
# ---- Top-k beam expansion ----
|
| 78 |
+
topk_probs, topk_ids = torch.topk(
|
| 79 |
+
probs, beam_width, dim=-1
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
candidates = []
|
| 83 |
+
|
| 84 |
+
for k in range(beam_width):
|
| 85 |
+
next_tokens = topk_ids[:, :, k]
|
| 86 |
+
score = torch.log(
|
| 87 |
+
topk_probs[:, :, k] + 1e-9
|
| 88 |
+
).sum()
|
| 89 |
+
candidates.append((next_tokens, score))
|
| 90 |
+
|
| 91 |
+
return candidates
|
| 92 |
+
|
| 93 |
+
def generate_beam(
|
| 94 |
+
self,
|
| 95 |
+
model,
|
| 96 |
+
condition,
|
| 97 |
+
beam_width=3,
|
| 98 |
+
num_steps=None,
|
| 99 |
+
temperature=1.0,
|
| 100 |
+
repetition_penalty=1.2,
|
| 101 |
+
diversity_penalty=0.3
|
| 102 |
+
):
|
| 103 |
+
"""
|
| 104 |
+
Beam-search reverse diffusion with temperature.
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
if num_steps is None:
|
| 108 |
+
num_steps = self.scheduler.num_timesteps
|
| 109 |
+
|
| 110 |
+
device = condition.device
|
| 111 |
+
|
| 112 |
+
if condition.dim() == 1:
|
| 113 |
+
condition = condition.unsqueeze(0)
|
| 114 |
+
|
| 115 |
+
B, L = condition.shape
|
| 116 |
+
|
| 117 |
+
# 🔥 Better initialization: start from MASK
|
| 118 |
+
x_init = torch.full(
|
| 119 |
+
(B, L),
|
| 120 |
+
fill_value=model.mask_token_id,
|
| 121 |
+
dtype=torch.long,
|
| 122 |
+
device=device
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
beams = [(x_init, 0.0)]
|
| 126 |
+
|
| 127 |
+
for step in reversed(range(num_steps)):
|
| 128 |
+
|
| 129 |
+
new_beams = []
|
| 130 |
+
|
| 131 |
+
for x_t, score in beams:
|
| 132 |
+
|
| 133 |
+
t_tensor = torch.full(
|
| 134 |
+
(B,),
|
| 135 |
+
step,
|
| 136 |
+
dtype=torch.long,
|
| 137 |
+
device=device
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
candidates = self.p_sample_step(
|
| 141 |
+
model,
|
| 142 |
+
x_t,
|
| 143 |
+
t_tensor,
|
| 144 |
+
condition,
|
| 145 |
+
beam_width,
|
| 146 |
+
temperature,
|
| 147 |
+
repetition_penalty,
|
| 148 |
+
diversity_penalty
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
for tokens, new_score in candidates:
|
| 152 |
+
new_beams.append(
|
| 153 |
+
(tokens, score + new_score)
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
# ---- Keep top beams ----
|
| 157 |
+
new_beams = sorted(
|
| 158 |
+
new_beams,
|
| 159 |
+
key=lambda x: x[1],
|
| 160 |
+
reverse=True
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
beams = new_beams[:beam_width]
|
| 164 |
+
|
| 165 |
+
best_tokens, best_score = beams[0]
|
| 166 |
+
return best_tokens
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def generate(
|
| 171 |
+
self,
|
| 172 |
+
model,
|
| 173 |
+
condition,
|
| 174 |
+
num_steps=None,
|
| 175 |
+
temperature=0.8,
|
| 176 |
+
top_k=50,
|
| 177 |
+
repetition_penalty=1.2,
|
| 178 |
+
diversity_penalty=0.0,
|
| 179 |
+
):
|
| 180 |
+
"""
|
| 181 |
+
Correct D3PM iterative refinement.
|
| 182 |
+
|
| 183 |
+
x0_est starts as all [MASK].
|
| 184 |
+
Each step: forward(src=condition, tgt=x0_est, t)
|
| 185 |
+
→ model applies q_sample(x0_est, t) internally
|
| 186 |
+
→ predicts cleaner x0
|
| 187 |
+
→ x0_est updated
|
| 188 |
+
|
| 189 |
+
diversity_penalty: reduces probability of tokens that are
|
| 190 |
+
globally dominant across all sequence positions (not logits.var()).
|
| 191 |
+
"""
|
| 192 |
+
if num_steps is None:
|
| 193 |
+
num_steps = self.scheduler.num_timesteps
|
| 194 |
+
|
| 195 |
+
device = condition.device
|
| 196 |
+
if condition.dim() == 1:
|
| 197 |
+
condition = condition.unsqueeze(0)
|
| 198 |
+
B, L = condition.shape
|
| 199 |
+
|
| 200 |
+
T = self.scheduler.num_timesteps
|
| 201 |
+
step_size = max(1, T // num_steps)
|
| 202 |
+
timesteps = list(range(T - 1, -1, -step_size))
|
| 203 |
+
if timesteps[-1] != 0:
|
| 204 |
+
timesteps.append(0)
|
| 205 |
+
|
| 206 |
+
mask_id = model.mask_token_id
|
| 207 |
+
# Start: know nothing → all MASK is our initial clean estimate
|
| 208 |
+
x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
|
| 209 |
+
hint = None
|
| 210 |
+
|
| 211 |
+
model.eval()
|
| 212 |
+
with torch.no_grad():
|
| 213 |
+
for step_idx, t_val in enumerate(timesteps):
|
| 214 |
+
t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 215 |
+
is_last = (step_idx == len(timesteps) - 1)
|
| 216 |
+
|
| 217 |
+
# KEY: pass x0_est as tgt — model noises it internally
|
| 218 |
+
import inspect
|
| 219 |
+
sig = inspect.signature(model.forward).parameters
|
| 220 |
+
if 'x0_hint' in sig:
|
| 221 |
+
outputs = model(condition, x0_est, t, x0_hint=hint)
|
| 222 |
+
else:
|
| 223 |
+
outputs = model(condition, x0_est, t)
|
| 224 |
+
|
| 225 |
+
logits = outputs[0] if isinstance(outputs, tuple) else outputs
|
| 226 |
+
|
| 227 |
+
# Repetition penalty: down-weight tokens already in sequence
|
| 228 |
+
if repetition_penalty != 1.0:
|
| 229 |
+
logits = apply_repetition_penalty(logits, x0_est, repetition_penalty)
|
| 230 |
+
|
| 231 |
+
# Diversity penalty: reduce globally dominant tokens
|
| 232 |
+
if diversity_penalty > 0.0:
|
| 233 |
+
logits = apply_diversity_penalty(logits, diversity_penalty)
|
| 234 |
+
|
| 235 |
+
# Temperature + top-k
|
| 236 |
+
logits = logits / max(temperature, 1e-5)
|
| 237 |
+
if top_k > 0:
|
| 238 |
+
logits = top_k_filter(logits, top_k)
|
| 239 |
+
|
| 240 |
+
probs = F.softmax(logits, dim=-1)
|
| 241 |
+
|
| 242 |
+
if is_last:
|
| 243 |
+
x0_est = torch.argmax(probs, dim=-1)
|
| 244 |
+
else:
|
| 245 |
+
x0_est = batch_multinomial(probs)
|
| 246 |
+
|
| 247 |
+
hint = x0_est
|
| 248 |
+
|
| 249 |
+
return x0_est
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# ── Penalty functions ─────────────────────────────────────────────────
|
| 253 |
+
|
| 254 |
+
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
|
| 255 |
+
"""
|
| 256 |
+
Down-weight tokens that already appear in the current sequence.
|
| 257 |
+
Prevents मनो मनो मनो repetition loops.
|
| 258 |
+
penalty=1.0 → no effect
|
| 259 |
+
penalty=1.2 → mild suppression of repeated tokens
|
| 260 |
+
penalty=2.0 → strong suppression
|
| 261 |
+
"""
|
| 262 |
+
B, L, V = logits.shape
|
| 263 |
+
for b in range(B):
|
| 264 |
+
for token_id in set(prev_tokens[b].tolist()):
|
| 265 |
+
if token_id > 4: # don't penalize special tokens
|
| 266 |
+
logits[b, :, token_id] = logits[b, :, token_id] / penalty
|
| 267 |
+
return logits
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def apply_diversity_penalty(logits, penalty=0.5):
|
| 271 |
+
"""
|
| 272 |
+
Correct diversity penalty: penalize tokens that are globally dominant
|
| 273 |
+
across ALL sequence positions. This forces the model to use less
|
| 274 |
+
common tokens, increasing output diversity.
|
| 275 |
+
|
| 276 |
+
Method: compute mean probability across positions, subtract penalty
|
| 277 |
+
times that mean. Tokens uniformly high everywhere get suppressed.
|
| 278 |
+
|
| 279 |
+
penalty=0.0 → no diversity enforcement
|
| 280 |
+
penalty=0.5 → moderate diversity
|
| 281 |
+
penalty=1.0 → strong diversity (may hurt coherence)
|
| 282 |
+
"""
|
| 283 |
+
# Mean logit across all positions: [B, V]
|
| 284 |
+
global_mean = logits.mean(dim=1, keepdim=True) # [B, 1, V]
|
| 285 |
+
# Subtract scaled global mean — suppresses globally common tokens
|
| 286 |
+
return logits - penalty * global_mean
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def top_k_filter(logits, k):
|
| 290 |
+
B, L, V = logits.shape
|
| 291 |
+
if k >= V:
|
| 292 |
+
return logits
|
| 293 |
+
topk_vals, _ = torch.topk(logits, k, dim=-1)
|
| 294 |
+
threshold = topk_vals[..., -1].unsqueeze(-1)
|
| 295 |
+
return logits.masked_fill(logits < threshold, float('-inf'))
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def batch_multinomial(probs):
|
| 299 |
+
B, L, V = probs.shape
|
| 300 |
+
flat = probs.view(B * L, V) + 1e-9
|
| 301 |
+
flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 302 |
+
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
diffusion/reverse_process1.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ReverseDiffusion:
|
| 6 |
+
"""
|
| 7 |
+
Stable reverse diffusion with:
|
| 8 |
+
- Beam search
|
| 9 |
+
- Self conditioning
|
| 10 |
+
- Temperature sampling
|
| 11 |
+
- Repetition penalty
|
| 12 |
+
- Diversity penalty
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, scheduler):
|
| 16 |
+
|
| 17 |
+
self.scheduler = scheduler
|
| 18 |
+
|
| 19 |
+
self.temperature = 0.75
|
| 20 |
+
self.repetition_penalty = 1.15
|
| 21 |
+
self.diversity_penalty = 0.0
|
| 22 |
+
self.length_penalty = 1.0
|
| 23 |
+
|
| 24 |
+
# ------------------------------------------------
|
| 25 |
+
# penalties
|
| 26 |
+
# ------------------------------------------------
|
| 27 |
+
|
| 28 |
+
def apply_repetition_penalty(self, logits, tokens):
|
| 29 |
+
|
| 30 |
+
B, L, V = logits.shape
|
| 31 |
+
|
| 32 |
+
for b in range(B):
|
| 33 |
+
|
| 34 |
+
used = set(tokens[b].tolist())
|
| 35 |
+
|
| 36 |
+
for token_id in used:
|
| 37 |
+
logits[b, :, token_id] /= self.repetition_penalty
|
| 38 |
+
|
| 39 |
+
return logits
|
| 40 |
+
|
| 41 |
+
def apply_diversity_penalty(self, logits):
|
| 42 |
+
|
| 43 |
+
if self.diversity_penalty == 0:
|
| 44 |
+
return logits
|
| 45 |
+
|
| 46 |
+
logits_var = logits.var(dim=-1, keepdim=True)
|
| 47 |
+
return logits + self.diversity_penalty * logits_var
|
| 48 |
+
|
| 49 |
+
# ------------------------------------------------
|
| 50 |
+
# single reverse step
|
| 51 |
+
# ------------------------------------------------
|
| 52 |
+
|
| 53 |
+
def p_sample_step(self, model, x_t, t, condition, self_cond=None, beam_width=3):
|
| 54 |
+
|
| 55 |
+
with torch.no_grad():
|
| 56 |
+
|
| 57 |
+
logits, hidden = model(condition, x_t, t, self_cond)
|
| 58 |
+
|
| 59 |
+
logits = logits / self.temperature
|
| 60 |
+
|
| 61 |
+
logits = self.apply_repetition_penalty(logits, x_t)
|
| 62 |
+
logits = self.apply_diversity_penalty(logits)
|
| 63 |
+
|
| 64 |
+
probs = F.softmax(logits, dim=-1)
|
| 65 |
+
|
| 66 |
+
B, L, V = probs.shape
|
| 67 |
+
|
| 68 |
+
topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
|
| 69 |
+
|
| 70 |
+
candidates = []
|
| 71 |
+
|
| 72 |
+
for k in range(beam_width):
|
| 73 |
+
|
| 74 |
+
tokens = topk_ids[:, :, k]
|
| 75 |
+
|
| 76 |
+
score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
|
| 77 |
+
|
| 78 |
+
candidates.append((tokens, score))
|
| 79 |
+
|
| 80 |
+
return candidates
|
| 81 |
+
|
| 82 |
+
# ------------------------------------------------
|
| 83 |
+
# beam reverse diffusion
|
| 84 |
+
# ------------------------------------------------
|
| 85 |
+
|
| 86 |
+
def generate_beam(self, model, condition, beam_width=3, num_steps=None):
|
| 87 |
+
|
| 88 |
+
if num_steps is None:
|
| 89 |
+
num_steps = self.scheduler.num_timesteps
|
| 90 |
+
|
| 91 |
+
device = condition.device
|
| 92 |
+
|
| 93 |
+
if condition.dim() == 1:
|
| 94 |
+
condition = condition.unsqueeze(0)
|
| 95 |
+
|
| 96 |
+
B, L = condition.shape
|
| 97 |
+
|
| 98 |
+
# ------------------------------------------------
|
| 99 |
+
# BETTER LATENT INITIALIZATION
|
| 100 |
+
# ------------------------------------------------
|
| 101 |
+
|
| 102 |
+
x_init = condition.clone()
|
| 103 |
+
|
| 104 |
+
mask = torch.rand_like(x_init.float()) < 0.5
|
| 105 |
+
x_init[mask] = model.mask_token_id
|
| 106 |
+
|
| 107 |
+
beams = [(x_init, 0.0)]
|
| 108 |
+
|
| 109 |
+
self_cond = None
|
| 110 |
+
|
| 111 |
+
for step in reversed(range(num_steps)):
|
| 112 |
+
|
| 113 |
+
new_beams = []
|
| 114 |
+
|
| 115 |
+
for x_t, score in beams:
|
| 116 |
+
|
| 117 |
+
t_tensor = torch.full(
|
| 118 |
+
(B,),
|
| 119 |
+
step,
|
| 120 |
+
dtype=torch.long,
|
| 121 |
+
device=device
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
candidates = self.p_sample_step(
|
| 125 |
+
model,
|
| 126 |
+
x_t,
|
| 127 |
+
t_tensor,
|
| 128 |
+
condition,
|
| 129 |
+
self_cond,
|
| 130 |
+
beam_width
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
for tokens, new_score in candidates:
|
| 134 |
+
|
| 135 |
+
length_norm = tokens.shape[1] ** self.length_penalty
|
| 136 |
+
|
| 137 |
+
final_score = (score + new_score) / length_norm
|
| 138 |
+
|
| 139 |
+
new_beams.append((tokens, final_score))
|
| 140 |
+
|
| 141 |
+
new_beams = sorted(
|
| 142 |
+
new_beams,
|
| 143 |
+
key=lambda x: x[1],
|
| 144 |
+
reverse=True
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
beams = new_beams[:beam_width]
|
| 148 |
+
|
| 149 |
+
# self conditioning
|
| 150 |
+
self_cond = beams[0][0]
|
| 151 |
+
|
| 152 |
+
best_tokens, best_score = beams[0]
|
| 153 |
+
|
| 154 |
+
return best_tokens
|
diffusion/reverse_process2.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
reverse_process.py — Final Correct Version
|
| 3 |
+
=============================================
|
| 4 |
+
|
| 5 |
+
KEY PRINCIPLE: generate() must be byte-for-byte identical to run_inference()
|
| 6 |
+
in inference.py, which is what produced BERTScore 0.75 at validation.
|
| 7 |
+
|
| 8 |
+
CRITICAL BUG IN PREVIOUS VERSION:
|
| 9 |
+
We passed inference_mode=True to the model, but the model was NEVER
|
| 10 |
+
called with inference_mode=True during training or validation.
|
| 11 |
+
run_inference() (the validated path) does:
|
| 12 |
+
model(input_ids, x0_est, t, x0_hint=hint)
|
| 13 |
+
→ inference_mode defaults to False.
|
| 14 |
+
|
| 15 |
+
With inference_mode=True the model does two things differently:
|
| 16 |
+
1. tgt_pad_mask = None (training used tgt_pad_mask = tgt==PAD)
|
| 17 |
+
2. Skips q_sample at t=0 (training always called q_sample)
|
| 18 |
+
The model was never trained to handle these conditions → garbage output.
|
| 19 |
+
|
| 20 |
+
Fix: do NOT pass inference_mode. Let it default to False, exactly
|
| 21 |
+
as run_inference() did.
|
| 22 |
+
|
| 23 |
+
BUGS FIXED (vs original reverse_process.py)
|
| 24 |
+
--------------------------------------------
|
| 25 |
+
BUG 1 generate_beam() used for D3PM → all-Ṛ repetition.
|
| 26 |
+
Use generate() (iterative refinement) from app1.py instead.
|
| 27 |
+
BUG 2 apply_diversity_penalty used logits.var() → noise injection.
|
| 28 |
+
Fixed to logits - penalty * logits.mean(dim=1) — global suppression.
|
| 29 |
+
BUG 3 x0_hint (self-conditioning) never passed to model.
|
| 30 |
+
Fixed: generate() passes x0_hint=hint every step.
|
| 31 |
+
BUG 4 params not forwarded from generate_beam() to p_sample_step().
|
| 32 |
+
Fixed in generate_beam() (kept for reference, not for production use).
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ReverseDiffusion:
|
| 40 |
+
|
| 41 |
+
def __init__(self, scheduler):
|
| 42 |
+
self.scheduler = scheduler
|
| 43 |
+
|
| 44 |
+
# Attribute-style defaults for backward compat with any code
|
| 45 |
+
# that sets reverse_diffusion.temperature = 0.9 etc.
|
| 46 |
+
# generate() prefers explicit kwargs and falls back to these.
|
| 47 |
+
self.temperature = 0.75
|
| 48 |
+
self.repetition_penalty = 1.15
|
| 49 |
+
self.diversity_penalty = 0.0
|
| 50 |
+
self.top_k = 50
|
| 51 |
+
|
| 52 |
+
# ------------------------------------------------------------------ #
|
| 53 |
+
# generate — CORRECT D3PM iterative refinement #
|
| 54 |
+
# Exact equivalent of run_inference() in inference.py #
|
| 55 |
+
# ------------------------------------------------------------------ #
|
| 56 |
+
def generate(
|
| 57 |
+
self,
|
| 58 |
+
model,
|
| 59 |
+
condition,
|
| 60 |
+
num_steps = None,
|
| 61 |
+
temperature = None,
|
| 62 |
+
top_k = None,
|
| 63 |
+
repetition_penalty = None,
|
| 64 |
+
diversity_penalty = None,
|
| 65 |
+
):
|
| 66 |
+
"""
|
| 67 |
+
D3PM iterative refinement — identical to run_inference() in inference.py,
|
| 68 |
+
which is the validated path (BERTScore 0.75).
|
| 69 |
+
|
| 70 |
+
Algorithm:
|
| 71 |
+
x0_est = all [MASK]
|
| 72 |
+
for t = T-1 down to 0:
|
| 73 |
+
logits = model(src, x0_est, t, x0_hint=hint)
|
| 74 |
+
↑ inference_mode NOT passed (defaults to False)
|
| 75 |
+
↑ this exactly matches training/validation
|
| 76 |
+
apply penalties, temperature, top_k
|
| 77 |
+
if t > 0: x0_est = multinomial(softmax(logits)) ← stochastic
|
| 78 |
+
if t = 0: x0_est = argmax(softmax(logits)) ← deterministic
|
| 79 |
+
hint = x0_est
|
| 80 |
+
"""
|
| 81 |
+
# Resolve: explicit kwarg > object attribute
|
| 82 |
+
temperature = temperature if temperature is not None else self.temperature
|
| 83 |
+
top_k = top_k if top_k is not None else self.top_k
|
| 84 |
+
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
|
| 85 |
+
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
|
| 86 |
+
|
| 87 |
+
if num_steps is None:
|
| 88 |
+
num_steps = self.scheduler.num_timesteps
|
| 89 |
+
|
| 90 |
+
device = condition.device
|
| 91 |
+
if condition.dim() == 1:
|
| 92 |
+
condition = condition.unsqueeze(0)
|
| 93 |
+
B, L = condition.shape
|
| 94 |
+
|
| 95 |
+
T = self.scheduler.num_timesteps
|
| 96 |
+
step_size = max(1, T // num_steps)
|
| 97 |
+
timesteps = list(range(T - 1, -1, -step_size))
|
| 98 |
+
if timesteps[-1] != 0:
|
| 99 |
+
timesteps.append(0)
|
| 100 |
+
|
| 101 |
+
mask_id = model.mask_token_id
|
| 102 |
+
x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
|
| 103 |
+
hint = None
|
| 104 |
+
|
| 105 |
+
model.eval()
|
| 106 |
+
with torch.no_grad():
|
| 107 |
+
for step_idx, t_val in enumerate(timesteps):
|
| 108 |
+
t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 109 |
+
is_last = (step_idx == len(timesteps) - 1)
|
| 110 |
+
|
| 111 |
+
# ── CRITICAL: do NOT pass inference_mode ──────────────────
|
| 112 |
+
# inference_mode defaults to False inside SanskritModel /
|
| 113 |
+
# D3PMCrossAttention. This matches run_inference() exactly.
|
| 114 |
+
# Passing inference_mode=True changes tgt_pad_mask and
|
| 115 |
+
# q_sample behaviour — the model was never trained for that.
|
| 116 |
+
logits, _ = model(condition, x0_est, t, x0_hint=hint)
|
| 117 |
+
|
| 118 |
+
# Repetition penalty
|
| 119 |
+
if repetition_penalty != 1.0:
|
| 120 |
+
logits = apply_repetition_penalty(
|
| 121 |
+
logits, x0_est, repetition_penalty
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Diversity penalty (correct: global mean suppression)
|
| 125 |
+
if diversity_penalty > 0.0:
|
| 126 |
+
logits = apply_diversity_penalty(logits, diversity_penalty)
|
| 127 |
+
|
| 128 |
+
logits = logits / max(temperature, 1e-5)
|
| 129 |
+
|
| 130 |
+
if top_k > 0:
|
| 131 |
+
logits = top_k_filter(logits, top_k)
|
| 132 |
+
|
| 133 |
+
probs = F.softmax(logits, dim=-1)
|
| 134 |
+
|
| 135 |
+
# Stochastic at every step except the last (argmax at t=0)
|
| 136 |
+
if is_last:
|
| 137 |
+
x0_est = torch.argmax(probs, dim=-1)
|
| 138 |
+
else:
|
| 139 |
+
x0_est = batch_multinomial(probs)
|
| 140 |
+
|
| 141 |
+
hint = x0_est
|
| 142 |
+
|
| 143 |
+
return x0_est # (B, L)
|
| 144 |
+
|
| 145 |
+
# ------------------------------------------------------------------ #
|
| 146 |
+
# p_sample_step — used by generate_beam (not for production) #
|
| 147 |
+
# ------------------------------------------------------------------ #
|
| 148 |
+
def p_sample_step(
|
| 149 |
+
self,
|
| 150 |
+
model,
|
| 151 |
+
x_t,
|
| 152 |
+
t,
|
| 153 |
+
condition,
|
| 154 |
+
beam_width = 3,
|
| 155 |
+
temperature = 1.0,
|
| 156 |
+
repetition_penalty = 1.2,
|
| 157 |
+
diversity_penalty = 0.3,
|
| 158 |
+
x0_hint = None,
|
| 159 |
+
):
|
| 160 |
+
with torch.no_grad():
|
| 161 |
+
if x_t.dim() == 1: x_t = x_t.unsqueeze(0)
|
| 162 |
+
if condition.dim() == 1: condition = condition.unsqueeze(0)
|
| 163 |
+
if t.dim() == 0: t = t.unsqueeze(0)
|
| 164 |
+
if t.shape[0] != x_t.shape[0]:
|
| 165 |
+
t = t.expand(x_t.shape[0])
|
| 166 |
+
|
| 167 |
+
# No inference_mode — matches training convention
|
| 168 |
+
logits, _ = model(condition, x_t, t, x0_hint=x0_hint)
|
| 169 |
+
|
| 170 |
+
logits = logits / max(temperature, 1e-5)
|
| 171 |
+
|
| 172 |
+
if repetition_penalty != 1.0:
|
| 173 |
+
logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
|
| 174 |
+
if diversity_penalty > 0.0:
|
| 175 |
+
logits = apply_diversity_penalty(logits, diversity_penalty)
|
| 176 |
+
|
| 177 |
+
probs = F.softmax(logits, dim=-1)
|
| 178 |
+
B, L, V = probs.shape
|
| 179 |
+
|
| 180 |
+
topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
|
| 181 |
+
candidates = []
|
| 182 |
+
for k in range(beam_width):
|
| 183 |
+
next_tokens = topk_ids[:, :, k]
|
| 184 |
+
score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
|
| 185 |
+
candidates.append((next_tokens, score))
|
| 186 |
+
return candidates
|
| 187 |
+
|
| 188 |
+
# ------------------------------------------------------------------ #
|
| 189 |
+
# generate_beam — kept for reference; NOT the correct D3PM method #
|
| 190 |
+
# ------------------------------------------------------------------ #
|
| 191 |
+
def generate_beam(
|
| 192 |
+
self,
|
| 193 |
+
model,
|
| 194 |
+
condition,
|
| 195 |
+
beam_width = 3,
|
| 196 |
+
num_steps = None,
|
| 197 |
+
temperature = None,
|
| 198 |
+
repetition_penalty = None,
|
| 199 |
+
diversity_penalty = None,
|
| 200 |
+
):
|
| 201 |
+
"""
|
| 202 |
+
WARNING: do NOT call this from app1.py for D3PM generation.
|
| 203 |
+
generate_beam() forces every position to the same top-k token
|
| 204 |
+
→ all-Ṛ / all-rud repetition. Use generate() instead.
|
| 205 |
+
Kept only for experimental reference.
|
| 206 |
+
"""
|
| 207 |
+
temperature = temperature if temperature is not None else self.temperature
|
| 208 |
+
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
|
| 209 |
+
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
|
| 210 |
+
if num_steps is None:
|
| 211 |
+
num_steps = self.scheduler.num_timesteps
|
| 212 |
+
|
| 213 |
+
device = condition.device
|
| 214 |
+
if condition.dim() == 1: condition = condition.unsqueeze(0)
|
| 215 |
+
B, L = condition.shape
|
| 216 |
+
|
| 217 |
+
x_init = torch.full((B, L), fill_value=model.mask_token_id,
|
| 218 |
+
dtype=torch.long, device=device)
|
| 219 |
+
beams = [(x_init, 0.0)]
|
| 220 |
+
best_hint = None
|
| 221 |
+
|
| 222 |
+
for step in reversed(range(num_steps)):
|
| 223 |
+
t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
|
| 224 |
+
new_beams = []
|
| 225 |
+
for x_t, score in beams:
|
| 226 |
+
candidates = self.p_sample_step(
|
| 227 |
+
model, x_t, t_tensor, condition,
|
| 228 |
+
beam_width = beam_width,
|
| 229 |
+
temperature = temperature,
|
| 230 |
+
repetition_penalty = repetition_penalty,
|
| 231 |
+
diversity_penalty = diversity_penalty,
|
| 232 |
+
x0_hint = best_hint,
|
| 233 |
+
)
|
| 234 |
+
for tokens, new_score in candidates:
|
| 235 |
+
new_beams.append((tokens, score + new_score.item()))
|
| 236 |
+
|
| 237 |
+
new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
|
| 238 |
+
beams = new_beams[:beam_width]
|
| 239 |
+
best_hint = beams[0][0]
|
| 240 |
+
|
| 241 |
+
return beams[0][0] # (B, L)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# ── Penalty helpers ────────────────────────────────────────────────────────
|
| 245 |
+
|
| 246 |
+
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
|
| 247 |
+
"""Down-weight tokens already present in the sequence."""
|
| 248 |
+
for b in range(logits.shape[0]):
|
| 249 |
+
for token_id in set(prev_tokens[b].tolist()):
|
| 250 |
+
if token_id > 4:
|
| 251 |
+
logits[b, :, token_id] = logits[b, :, token_id] / penalty
|
| 252 |
+
return logits
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def apply_diversity_penalty(logits, penalty=0.3):
|
| 256 |
+
"""
|
| 257 |
+
Correct diversity penalty: suppress globally dominant tokens.
|
| 258 |
+
logits -= penalty * mean(logits, dim=1) [sequence dimension]
|
| 259 |
+
"""
|
| 260 |
+
global_mean = logits.mean(dim=1, keepdim=True) # [B, 1, V]
|
| 261 |
+
return logits - penalty * global_mean
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def top_k_filter(logits, k):
|
| 265 |
+
B, L, V = logits.shape
|
| 266 |
+
if k >= V: return logits
|
| 267 |
+
topk_vals, _ = torch.topk(logits, k, dim=-1)
|
| 268 |
+
return logits.masked_fill(logits < topk_vals[..., -1].unsqueeze(-1), float('-inf'))
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def batch_multinomial(probs):
|
| 272 |
+
B, L, V = probs.shape
|
| 273 |
+
flat = probs.view(B * L, V) + 1e-9
|
| 274 |
+
flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 275 |
+
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
diffusion/scheduler.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
scheduler.py — Fixed & Upgraded
|
| 3 |
+
==================================
|
| 4 |
+
Changes:
|
| 5 |
+
1. T=64 (was 16). More timesteps = richer denoising curriculum per epoch.
|
| 6 |
+
2. alpha at t=0 is EXACTLY 1.0 — fixes Bug 2 (final-step re-noise).
|
| 7 |
+
3. sample_timestep samples [0, T-1] including t=0, so model trains on
|
| 8 |
+
fully-clean inputs (learns the identity at t=0 explicitly).
|
| 9 |
+
"""
|
| 10 |
+
import torch, math
|
| 11 |
+
|
| 12 |
+
class OptimizedCosineScheduler:
|
| 13 |
+
def __init__(self, cfg, device=None):
|
| 14 |
+
self.num_timesteps = cfg['model']['diffusion_steps'] # 64
|
| 15 |
+
self.mask_token_id = cfg['diffusion']['mask_token_id']
|
| 16 |
+
self.device = device or torch.device('cpu')
|
| 17 |
+
self.alphas_cumprod = self._build_schedule().to(self.device)
|
| 18 |
+
|
| 19 |
+
def _build_schedule(self):
|
| 20 |
+
T = self.num_timesteps
|
| 21 |
+
t = torch.arange(T + 1, dtype=torch.float32)
|
| 22 |
+
f_t = torch.cos((t / T + 0.008) / 1.008 * math.pi / 2) ** 2
|
| 23 |
+
alphas_bar = f_t / f_t[0]
|
| 24 |
+
alphas_bar = alphas_bar[1:] # shape [T]
|
| 25 |
+
alphas_bar[0] = 1.0 # FIX: exact 1.0 at t=0
|
| 26 |
+
alphas_bar[-1] = alphas_bar[-1].clamp(max=0.001)
|
| 27 |
+
return alphas_bar
|
| 28 |
+
|
| 29 |
+
def sample_timestep(self, batch_size):
|
| 30 |
+
"""Uniform [0, T-1] — includes t=0 so model sees clean inputs."""
|
| 31 |
+
return torch.randint(0, self.num_timesteps, (batch_size,))
|
| 32 |
+
|
| 33 |
+
def get_alpha(self, t):
|
| 34 |
+
return self.alphas_cumprod[t.to(self.alphas_cumprod.device).long()]
|
handler.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
|
| 3 |
+
from inference_api import predict
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EndpointHandler:
|
| 7 |
+
"""
|
| 8 |
+
Hugging Face Inference Endpoint handler.
|
| 9 |
+
Expects payload:
|
| 10 |
+
{
|
| 11 |
+
"inputs": "dharmo rakṣati rakṣitaḥ",
|
| 12 |
+
"parameters": {"temperature": 0.7, ...}
|
| 13 |
+
}
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, path: str = ""):
|
| 17 |
+
self.path = path
|
| 18 |
+
|
| 19 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 20 |
+
text = data.get("inputs", "")
|
| 21 |
+
params = data.get("parameters", {}) or {}
|
| 22 |
+
return predict(
|
| 23 |
+
text=text,
|
| 24 |
+
temperature=params.get("temperature", 0.7),
|
| 25 |
+
top_k=params.get("top_k", 40),
|
| 26 |
+
repetition_penalty=params.get("repetition_penalty", 1.2),
|
| 27 |
+
diversity_penalty=params.get("diversity_penalty", 0.0),
|
| 28 |
+
num_steps=params.get("num_steps", 64),
|
| 29 |
+
clean_output=params.get("clean_output", True),
|
| 30 |
+
)
|
inference.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inference.py
|
| 3 |
+
============
|
| 4 |
+
Correct D3PM inference for Sanskrit paraphrase generation.
|
| 5 |
+
|
| 6 |
+
The model's forward() takes CLEAN tgt and noises it internally.
|
| 7 |
+
So inference passes x0_estimate (starting all-[MASK]) as tgt each step,
|
| 8 |
+
letting the model noise it and then predict a cleaner version.
|
| 9 |
+
|
| 10 |
+
Also includes: robust checkpoint loading (auto-detects architecture
|
| 11 |
+
from saved weights — no CONFIG mismatch crashes).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
import torch
|
| 16 |
+
import os, sys
|
| 17 |
+
import re
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
from torch.utils.data import DataLoader, Subset
|
| 20 |
+
|
| 21 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 22 |
+
from config import CONFIG
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ── Checkpoint loader ─────────────────────────────────────────────────
|
| 26 |
+
|
| 27 |
+
def _resolve_device(cfg_device: str) -> torch.device:
|
| 28 |
+
cfg_device = (cfg_device or "").lower()
|
| 29 |
+
if cfg_device == "cuda" and torch.cuda.is_available():
|
| 30 |
+
return torch.device("cuda")
|
| 31 |
+
if cfg_device == "mps" and torch.backends.mps.is_available():
|
| 32 |
+
return torch.device("mps")
|
| 33 |
+
if cfg_device in {"cpu", "cuda", "mps"}:
|
| 34 |
+
return torch.device("cpu")
|
| 35 |
+
if torch.cuda.is_available():
|
| 36 |
+
return torch.device("cuda")
|
| 37 |
+
if torch.backends.mps.is_available():
|
| 38 |
+
return torch.device("mps")
|
| 39 |
+
return torch.device("cpu")
|
| 40 |
+
|
| 41 |
+
def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
|
| 42 |
+
"""
|
| 43 |
+
Auto-detect architecture from checkpoint weight shapes,
|
| 44 |
+
then load. Never fails due to CONFIG vs checkpoint mismatch.
|
| 45 |
+
"""
|
| 46 |
+
import copy
|
| 47 |
+
from model.sanskrit_model import SanskritModel
|
| 48 |
+
|
| 49 |
+
cfg = copy.deepcopy(base_cfg)
|
| 50 |
+
state = torch.load(ckpt_path, map_location='cpu')
|
| 51 |
+
|
| 52 |
+
# d_model + vocab_size
|
| 53 |
+
ek = 'model.src_embed.token_emb.weight'
|
| 54 |
+
if ek in state:
|
| 55 |
+
vocab, d = state[ek].shape
|
| 56 |
+
cfg['model']['vocab_size'] = vocab
|
| 57 |
+
cfg['model']['d_model'] = d
|
| 58 |
+
cfg['model']['d_ff'] = d * 4
|
| 59 |
+
|
| 60 |
+
# n_layers
|
| 61 |
+
ids = {int(k.split('.')[2]) for k in state if k.startswith('model.encoder_blocks.')}
|
| 62 |
+
if ids:
|
| 63 |
+
cfg['model']['n_layers'] = max(ids) + 1
|
| 64 |
+
|
| 65 |
+
# max_seq_len
|
| 66 |
+
pk = 'model.src_embed.pos_enc.pe'
|
| 67 |
+
if pk in state:
|
| 68 |
+
cfg['model']['max_seq_len'] = state[pk].shape[1]
|
| 69 |
+
|
| 70 |
+
# n_heads
|
| 71 |
+
d = cfg['model']['d_model']
|
| 72 |
+
h = cfg['model'].get('n_heads', 6)
|
| 73 |
+
if d % h != 0:
|
| 74 |
+
h = next(x for x in [8, 6, 4, 2, 1] if d % x == 0)
|
| 75 |
+
cfg['model']['n_heads'] = h
|
| 76 |
+
|
| 77 |
+
print(f"🔍 Detected: d_model={cfg['model']['d_model']}, "
|
| 78 |
+
f"n_layers={cfg['model']['n_layers']}, "
|
| 79 |
+
f"max_seq_len={cfg['model']['max_seq_len']}, "
|
| 80 |
+
f"n_heads={cfg['model']['n_heads']}")
|
| 81 |
+
|
| 82 |
+
model = SanskritModel(cfg).to(device)
|
| 83 |
+
raw_state = torch.load(ckpt_path, map_location=device)
|
| 84 |
+
model_state = model.state_dict()
|
| 85 |
+
filtered_state = {}
|
| 86 |
+
skipped_mismatch = []
|
| 87 |
+
for k, v in raw_state.items():
|
| 88 |
+
if k in model_state and hasattr(v, "shape") and hasattr(model_state[k], "shape"):
|
| 89 |
+
if tuple(v.shape) != tuple(model_state[k].shape):
|
| 90 |
+
skipped_mismatch.append((k, tuple(v.shape), tuple(model_state[k].shape)))
|
| 91 |
+
continue
|
| 92 |
+
filtered_state[k] = v
|
| 93 |
+
|
| 94 |
+
missing, unexpected = model.load_state_dict(filtered_state, strict=False)
|
| 95 |
+
|
| 96 |
+
# hint_gate may be absent in older checkpoints — initialise safely
|
| 97 |
+
allowed = {'model.hint_gate.0.weight', 'model.hint_gate.0.bias'}
|
| 98 |
+
real_missing = [k for k in missing if k not in allowed]
|
| 99 |
+
if real_missing:
|
| 100 |
+
print(f"⚠️ Missing keys: {real_missing[:3]} …")
|
| 101 |
+
if unexpected:
|
| 102 |
+
print(f"⚠️ Unexpected keys: {unexpected[:3]} …")
|
| 103 |
+
if skipped_mismatch:
|
| 104 |
+
print(f"⚠️ Shape-mismatched keys skipped: {len(skipped_mismatch)}")
|
| 105 |
+
|
| 106 |
+
# Enable compact-attention branch only when checkpoint actually provides it.
|
| 107 |
+
has_compact = any(".compact_out_proj.weight" in k for k in filtered_state.keys())
|
| 108 |
+
if has_compact and hasattr(model, "model") and hasattr(model.model, "decoder_blocks"):
|
| 109 |
+
for block in model.model.decoder_blocks:
|
| 110 |
+
if hasattr(block, "cross_attn") and hasattr(block.cross_attn, "use_compact"):
|
| 111 |
+
block.cross_attn.use_compact = True
|
| 112 |
+
print("ℹ️ Compact cross-attention branch enabled from checkpoint.")
|
| 113 |
+
if hasattr(model.model, 'hint_gate') and 'model.hint_gate.0.weight' in missing:
|
| 114 |
+
with torch.no_grad():
|
| 115 |
+
w = model.model.hint_gate[0].weight
|
| 116 |
+
torch.nn.init.zeros_(model.model.hint_gate[0].bias)
|
| 117 |
+
torch.nn.init.eye_(w) if w.shape[0] == w.shape[1] \
|
| 118 |
+
else torch.nn.init.xavier_uniform_(w)
|
| 119 |
+
print("ℹ️ hint_gate initialised to identity (not in checkpoint).")
|
| 120 |
+
|
| 121 |
+
print("✅ Model loaded.")
|
| 122 |
+
return model, cfg
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# ── Core inference function (same path as validation) ────────────────
|
| 126 |
+
|
| 127 |
+
@torch.no_grad()
|
| 128 |
+
def run_inference(model, input_ids, cfg):
|
| 129 |
+
"""
|
| 130 |
+
Reverse diffusion sampling (clean path).
|
| 131 |
+
Uses cached reverse diffusion when available, otherwise model.generate().
|
| 132 |
+
"""
|
| 133 |
+
inf = cfg['inference']
|
| 134 |
+
model.eval()
|
| 135 |
+
kwargs = dict(
|
| 136 |
+
num_steps=inf['num_steps'],
|
| 137 |
+
temperature=inf['temperature'],
|
| 138 |
+
top_k=inf['top_k'],
|
| 139 |
+
repetition_penalty=inf.get('repetition_penalty', 1.2),
|
| 140 |
+
diversity_penalty=inf.get('diversity_penalty', 0.0),
|
| 141 |
+
)
|
| 142 |
+
if hasattr(model, "generate_cached"):
|
| 143 |
+
out = model.generate_cached(input_ids, **kwargs)
|
| 144 |
+
else:
|
| 145 |
+
out = model.generate(input_ids, **kwargs)
|
| 146 |
+
|
| 147 |
+
# Optional retry with stronger anti-repetition settings.
|
| 148 |
+
if inf.get("auto_retry_on_repetition", True):
|
| 149 |
+
repeat_threshold = float(inf.get("repeat_ratio_threshold", 0.40))
|
| 150 |
+
max_repeat_run = int(inf.get("max_repeat_run", 4))
|
| 151 |
+
if _mean_repeat_ratio(out) >= repeat_threshold:
|
| 152 |
+
retry_kwargs = dict(kwargs)
|
| 153 |
+
retry_kwargs["temperature"] = max(0.6, float(kwargs["temperature"]) - 0.1)
|
| 154 |
+
retry_kwargs["top_k"] = max(20, int(kwargs["top_k"]) - 10)
|
| 155 |
+
retry_kwargs["repetition_penalty"] = max(float(kwargs["repetition_penalty"]), 1.6)
|
| 156 |
+
retry_kwargs["diversity_penalty"] = max(float(kwargs["diversity_penalty"]), 0.3)
|
| 157 |
+
if hasattr(model, "generate_cached"):
|
| 158 |
+
retry = model.generate_cached(input_ids, **retry_kwargs)
|
| 159 |
+
else:
|
| 160 |
+
retry = model.generate(input_ids, **retry_kwargs)
|
| 161 |
+
if _mean_repeat_ratio(retry) < _mean_repeat_ratio(out):
|
| 162 |
+
out = retry
|
| 163 |
+
out = _dedup_repeated_ids(out, max_repeat_run=max_repeat_run)
|
| 164 |
+
|
| 165 |
+
return out
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _mean_repeat_ratio(ids_tensor: torch.Tensor) -> float:
|
| 169 |
+
if ids_tensor is None or ids_tensor.numel() == 0:
|
| 170 |
+
return 0.0
|
| 171 |
+
ratios = []
|
| 172 |
+
for row in ids_tensor:
|
| 173 |
+
ids = [int(x) for x in row.tolist() if int(x) > 4]
|
| 174 |
+
if len(ids) < 2:
|
| 175 |
+
ratios.append(0.0)
|
| 176 |
+
continue
|
| 177 |
+
repeats = sum(1 for i in range(1, len(ids)) if ids[i] == ids[i - 1])
|
| 178 |
+
ratios.append(repeats / max(1, len(ids) - 1))
|
| 179 |
+
return float(sum(ratios) / max(1, len(ratios)))
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _dedup_repeated_ids(ids_tensor: torch.Tensor, max_repeat_run: int = 4) -> torch.Tensor:
|
| 183 |
+
"""
|
| 184 |
+
Keep generation path unchanged, but clean extreme run-on token loops in final output ids.
|
| 185 |
+
"""
|
| 186 |
+
if ids_tensor is None or ids_tensor.numel() == 0:
|
| 187 |
+
return ids_tensor
|
| 188 |
+
cleaned_rows = []
|
| 189 |
+
for row in ids_tensor.tolist():
|
| 190 |
+
out = []
|
| 191 |
+
prev = None
|
| 192 |
+
run = 0
|
| 193 |
+
for tok in row:
|
| 194 |
+
if tok <= 4:
|
| 195 |
+
out.append(tok)
|
| 196 |
+
prev = tok
|
| 197 |
+
run = 1
|
| 198 |
+
continue
|
| 199 |
+
if tok == prev:
|
| 200 |
+
run += 1
|
| 201 |
+
if run > max_repeat_run:
|
| 202 |
+
continue
|
| 203 |
+
else:
|
| 204 |
+
run = 1
|
| 205 |
+
out.append(tok)
|
| 206 |
+
prev = tok
|
| 207 |
+
# Preserve original length for downstream decode assumptions.
|
| 208 |
+
if len(out) < len(row):
|
| 209 |
+
out.extend([1] * (len(row) - len(out)))
|
| 210 |
+
else:
|
| 211 |
+
out = out[:len(row)]
|
| 212 |
+
cleaned_rows.append(out)
|
| 213 |
+
return torch.tensor(cleaned_rows, dtype=ids_tensor.dtype, device=ids_tensor.device)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _decode_clean(tgt_tok, ids):
|
| 217 |
+
out = []
|
| 218 |
+
for x in ids:
|
| 219 |
+
if x in (1, 4) and out:
|
| 220 |
+
break
|
| 221 |
+
if x > 4:
|
| 222 |
+
out.append(x)
|
| 223 |
+
text = tgt_tok.decode(out).strip()
|
| 224 |
+
return _clean_repetition_text(text)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _clean_repetition_text(text: str, max_repeat_run: int = 3) -> str:
|
| 228 |
+
words = [w for w in text.split() if w.strip()]
|
| 229 |
+
if not words:
|
| 230 |
+
return text.strip()
|
| 231 |
+
cleaned = []
|
| 232 |
+
prev = None
|
| 233 |
+
run = 0
|
| 234 |
+
for w in words:
|
| 235 |
+
if w == prev:
|
| 236 |
+
run += 1
|
| 237 |
+
if run > max_repeat_run:
|
| 238 |
+
continue
|
| 239 |
+
else:
|
| 240 |
+
run = 1
|
| 241 |
+
cleaned.append(w)
|
| 242 |
+
prev = w
|
| 243 |
+
return " ".join(cleaned).strip()
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# ── Cleanup heuristics from UI inference pipeline ─────────────────────
|
| 247 |
+
|
| 248 |
+
_IAST_VOWELS = [
|
| 249 |
+
("ai", "ऐ"), ("au", "औ"),
|
| 250 |
+
("ā", "आ"), ("ī", "ई"), ("ū", "ऊ"),
|
| 251 |
+
("ṛ", "ऋ"), ("ṝ", "ॠ"), ("ḷ", "ऌ"), ("ḹ", "ॡ"),
|
| 252 |
+
("a", "अ"), ("i", "इ"), ("u", "उ"),
|
| 253 |
+
("e", "ए"), ("o", "ओ"),
|
| 254 |
+
]
|
| 255 |
+
_IAST_MATRAS = [
|
| 256 |
+
("ai", "ै"), ("au", "ौ"),
|
| 257 |
+
("ā", "ा"), ("ī", "ी"), ("ū", "ू"),
|
| 258 |
+
("ṛ", "ृ"), ("ṝ", "ॄ"), ("ḷ", "ॢ"), ("ḹ", "ॣ"),
|
| 259 |
+
("a", ""), ("i", "ि"), ("u", "ु"),
|
| 260 |
+
("e", "े"), ("o", "ो"),
|
| 261 |
+
]
|
| 262 |
+
_IAST_CONS = [
|
| 263 |
+
("kṣ", "क्ष"), ("jñ", "ज्ञ"), ("tr", "त्र"),
|
| 264 |
+
("kh", "ख"), ("gh", "घ"), ("ch", "छ"), ("jh", "झ"),
|
| 265 |
+
("ṭh", "ठ"), ("ḍh", "ढ"), ("th", "थ"), ("dh", "ध"),
|
| 266 |
+
("ph", "फ"), ("bh", "भ"),
|
| 267 |
+
("ṅ", "ङ"), ("ñ", "ञ"), ("ṭ", "ट"), ("ḍ", "ड"),
|
| 268 |
+
("ṇ", "ण"), ("ś", "श"), ("ṣ", "ष"), ("ḥ", "ः"),
|
| 269 |
+
("ṃ", "ं"), ("ṁ", "ं"),
|
| 270 |
+
("y", "���"), ("r", "र"), ("l", "ल"), ("v", "व"),
|
| 271 |
+
("s", "स"), ("h", "ह"),
|
| 272 |
+
("k", "क"), ("g", "ग"), ("c", "च"), ("j", "ज"),
|
| 273 |
+
("t", "त"), ("d", "द"), ("n", "न"),
|
| 274 |
+
("p", "प"), ("b", "ब"), ("m", "म"),
|
| 275 |
+
]
|
| 276 |
+
_PUNCT = {".": "।", "|": "।", "||": "॥", ",": ",", "?": "?", "!": "!"}
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def _iast_to_deva(text: str) -> str:
|
| 280 |
+
s = (text or "").lower()
|
| 281 |
+
out = []
|
| 282 |
+
i = 0
|
| 283 |
+
pending_consonant = False
|
| 284 |
+
|
| 285 |
+
def _match_any(pairs, pos):
|
| 286 |
+
for k, v in pairs:
|
| 287 |
+
if s.startswith(k, pos):
|
| 288 |
+
return k, v
|
| 289 |
+
return None, None
|
| 290 |
+
|
| 291 |
+
while i < len(s):
|
| 292 |
+
if s[i].isspace():
|
| 293 |
+
pending_consonant = False
|
| 294 |
+
out.append(s[i])
|
| 295 |
+
i += 1
|
| 296 |
+
continue
|
| 297 |
+
if s[i:i+2] == "||":
|
| 298 |
+
pending_consonant = False
|
| 299 |
+
out.append(_PUNCT["||"])
|
| 300 |
+
i += 2
|
| 301 |
+
continue
|
| 302 |
+
if s[i] in _PUNCT:
|
| 303 |
+
pending_consonant = False
|
| 304 |
+
out.append(_PUNCT[s[i]])
|
| 305 |
+
i += 1
|
| 306 |
+
continue
|
| 307 |
+
|
| 308 |
+
v_key, v_deva = _match_any(_IAST_VOWELS, i)
|
| 309 |
+
if v_key:
|
| 310 |
+
if pending_consonant:
|
| 311 |
+
_, v_matra = _match_any(_IAST_MATRAS, i)
|
| 312 |
+
out[-1] = out[-1] + (v_matra or "")
|
| 313 |
+
pending_consonant = False
|
| 314 |
+
else:
|
| 315 |
+
out.append(v_deva)
|
| 316 |
+
i += len(v_key)
|
| 317 |
+
continue
|
| 318 |
+
|
| 319 |
+
c_key, c_deva = _match_any(_IAST_CONS, i)
|
| 320 |
+
if c_key:
|
| 321 |
+
if pending_consonant:
|
| 322 |
+
out[-1] = out[-1] + "्"
|
| 323 |
+
out.append(c_deva)
|
| 324 |
+
pending_consonant = True
|
| 325 |
+
i += len(c_key)
|
| 326 |
+
continue
|
| 327 |
+
|
| 328 |
+
out.append(s[i])
|
| 329 |
+
pending_consonant = False
|
| 330 |
+
i += 1
|
| 331 |
+
|
| 332 |
+
return "".join(out).strip()
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def _compute_cer(pred: str, ref: str) -> float:
|
| 336 |
+
if pred == ref:
|
| 337 |
+
return 0.0
|
| 338 |
+
if not pred or not ref:
|
| 339 |
+
return 1.0
|
| 340 |
+
m, n = len(pred), len(ref)
|
| 341 |
+
dp = list(range(n + 1))
|
| 342 |
+
for i in range(1, m + 1):
|
| 343 |
+
prev = dp[0]
|
| 344 |
+
dp[0] = i
|
| 345 |
+
for j in range(1, n + 1):
|
| 346 |
+
temp = dp[j]
|
| 347 |
+
cost = 0 if pred[i - 1] == ref[j - 1] else 1
|
| 348 |
+
dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
|
| 349 |
+
prev = temp
|
| 350 |
+
return dp[n] / max(m, n)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def _cleanup_thresholds(temperature: float, top_k: int):
|
| 354 |
+
temp = float(temperature)
|
| 355 |
+
k = max(1, int(top_k))
|
| 356 |
+
t_norm = max(0.0, min((temp - 0.4) / 0.6, 1.0))
|
| 357 |
+
k_norm = max(0.0, min((k - 20) / 80.0, 1.0))
|
| 358 |
+
diversity = 0.6 * t_norm + 0.4 * k_norm
|
| 359 |
+
cer_threshold = 0.10 + 0.18 * diversity
|
| 360 |
+
deva_ratio_threshold = 0.60 - 0.20 * diversity
|
| 361 |
+
return cer_threshold, deva_ratio_threshold
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _decode_with_cleanup(tgt_tok, ids, src_text: str, inf_cfg: dict):
|
| 365 |
+
model_out = _decode_clean(tgt_tok, ids)
|
| 366 |
+
rule_out = _iast_to_deva(src_text.strip())
|
| 367 |
+
deva_chars = sum(1 for ch in model_out if "\u0900" <= ch <= "\u097F")
|
| 368 |
+
deva_ratio = deva_chars / max(1, len(model_out))
|
| 369 |
+
cer = _compute_cer(model_out, rule_out)
|
| 370 |
+
cer_thr, ratio_thr = _cleanup_thresholds(
|
| 371 |
+
inf_cfg.get("temperature", 0.8),
|
| 372 |
+
inf_cfg.get("top_k", 40),
|
| 373 |
+
)
|
| 374 |
+
if deva_ratio < ratio_thr or len(model_out) > 2.0 * max(1, len(rule_out)) or cer > cer_thr:
|
| 375 |
+
return rule_out
|
| 376 |
+
return model_out
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# ── Interactive demo ──────────────────────────────────────────────────
|
| 380 |
+
|
| 381 |
+
def interactive_demo(checkpoint=None, single_text=None):
|
| 382 |
+
from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
|
| 383 |
+
|
| 384 |
+
cfg = CONFIG
|
| 385 |
+
device = _resolve_device(cfg['training'].get('device', 'cpu'))
|
| 386 |
+
|
| 387 |
+
model_name = cfg['model_type']
|
| 388 |
+
has_neg = cfg['data']['include_negative_examples']
|
| 389 |
+
ckpt = checkpoint or f"results/{model_name}_neg_{has_neg}/best_model.pt"
|
| 390 |
+
|
| 391 |
+
if not os.path.exists(ckpt):
|
| 392 |
+
raise FileNotFoundError(f"No checkpoint at {ckpt} — train first.")
|
| 393 |
+
|
| 394 |
+
model, cfg = load_model(ckpt, cfg, device)
|
| 395 |
+
model.eval()
|
| 396 |
+
|
| 397 |
+
src_tok = SanskritSourceTokenizer(
|
| 398 |
+
vocab_size=cfg['model'].get('src_vocab_size', 16000),
|
| 399 |
+
max_len=cfg['model']['max_seq_len'],
|
| 400 |
+
)
|
| 401 |
+
tgt_tok = SanskritTargetTokenizer(
|
| 402 |
+
vocab_size=cfg['model'].get('tgt_vocab_size', 16000),
|
| 403 |
+
max_len=cfg['model']['max_seq_len'],
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
print("\n" + "="*55)
|
| 407 |
+
print("Sanskrit D3PM Paraphrase — type verse, get paraphrase")
|
| 408 |
+
print("="*55 + "\n")
|
| 409 |
+
|
| 410 |
+
while True:
|
| 411 |
+
try:
|
| 412 |
+
text = (single_text if single_text is not None else input("INPUT > ")).strip()
|
| 413 |
+
except (EOFError, KeyboardInterrupt):
|
| 414 |
+
break
|
| 415 |
+
if not text or text.lower() in ('quit', 'exit', 'q'):
|
| 416 |
+
break
|
| 417 |
+
|
| 418 |
+
ids = torch.tensor(
|
| 419 |
+
[src_tok.encode(text)[:cfg['model']['max_seq_len']]],
|
| 420 |
+
dtype=torch.long, device=device
|
| 421 |
+
)
|
| 422 |
+
out = run_inference(model, ids, cfg)
|
| 423 |
+
cleaned = _decode_with_cleanup(tgt_tok, out[0].tolist(), text, cfg["inference"])
|
| 424 |
+
print(f"PARAPHRASE → {cleaned}\n")
|
| 425 |
+
if single_text is not None:
|
| 426 |
+
break
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
# ── Batch evaluation ──────────────────────────────────────────────────
|
| 430 |
+
|
| 431 |
+
def batch_evaluate(sample_size=500, checkpoint=None):
|
| 432 |
+
from data.dataset import OptimizedSanskritDataset
|
| 433 |
+
from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
|
| 434 |
+
|
| 435 |
+
cfg = CONFIG
|
| 436 |
+
device = _resolve_device(cfg['training'].get('device', 'cpu'))
|
| 437 |
+
|
| 438 |
+
model_name = cfg['model_type']
|
| 439 |
+
has_neg = cfg['data']['include_negative_examples']
|
| 440 |
+
exp_dir = f"results/{model_name}_neg_{has_neg}"
|
| 441 |
+
ckpt = checkpoint or f"{exp_dir}/best_model.pt"
|
| 442 |
+
|
| 443 |
+
if not os.path.exists(ckpt):
|
| 444 |
+
raise FileNotFoundError(f"No checkpoint at {ckpt}")
|
| 445 |
+
|
| 446 |
+
model, cfg = load_model(ckpt, cfg, device)
|
| 447 |
+
model.eval()
|
| 448 |
+
|
| 449 |
+
src_tok = SanskritSourceTokenizer(
|
| 450 |
+
vocab_size=cfg['model'].get('src_vocab_size', 16000),
|
| 451 |
+
max_len=cfg['model']['max_seq_len'],
|
| 452 |
+
)
|
| 453 |
+
tgt_tok = SanskritTargetTokenizer(
|
| 454 |
+
vocab_size=cfg['model'].get('tgt_vocab_size', 16000),
|
| 455 |
+
max_len=cfg['model']['max_seq_len'],
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
def collate(batch):
|
| 459 |
+
return {
|
| 460 |
+
'input_ids': torch.stack([b['input_ids'].long() for b in batch]),
|
| 461 |
+
'target_text': [b['target_text'] for b in batch],
|
| 462 |
+
'input_text': [b['input_text'] for b in batch],
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
dataset = OptimizedSanskritDataset(
|
| 466 |
+
split='test',
|
| 467 |
+
max_len=cfg['model']['max_seq_len'],
|
| 468 |
+
cfg=cfg,
|
| 469 |
+
src_tokenizer=src_tok,
|
| 470 |
+
tgt_tokenizer=tgt_tok,
|
| 471 |
+
)
|
| 472 |
+
indices = list(range(min(sample_size, len(dataset))))
|
| 473 |
+
loader = DataLoader(
|
| 474 |
+
Subset(dataset, indices),
|
| 475 |
+
batch_size=cfg['training']['batch_size'],
|
| 476 |
+
shuffle=False, collate_fn=collate
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
all_preds, all_refs, all_inputs = [], [], []
|
| 480 |
+
print(f"⏳ Generating {len(indices)} paraphrases …")
|
| 481 |
+
|
| 482 |
+
for batch in tqdm(loader):
|
| 483 |
+
ids = batch['input_ids'].to(device)
|
| 484 |
+
out = run_inference(model, ids, cfg)
|
| 485 |
+
for i in range(out.size(0)):
|
| 486 |
+
all_preds.append(_decode_with_cleanup(
|
| 487 |
+
tgt_tok, out[i].tolist(), batch['input_text'][i], cfg["inference"]
|
| 488 |
+
))
|
| 489 |
+
all_refs.append(batch['target_text'][i].strip())
|
| 490 |
+
all_inputs.append(batch['input_text'][i].strip())
|
| 491 |
+
|
| 492 |
+
# Metrics
|
| 493 |
+
bleu_score, bert_f1 = 0.0, 0.0
|
| 494 |
+
try:
|
| 495 |
+
from nltk.translate.bleu_score import corpus_bleu
|
| 496 |
+
bleu_score = corpus_bleu(
|
| 497 |
+
[[r.split()] for r in all_refs],
|
| 498 |
+
[p.split() for p in all_preds]
|
| 499 |
+
)
|
| 500 |
+
except Exception:
|
| 501 |
+
pass
|
| 502 |
+
|
| 503 |
+
try:
|
| 504 |
+
import evaluate as hf_eval
|
| 505 |
+
res = hf_eval.load('bertscore').compute(
|
| 506 |
+
predictions=all_preds, references=all_refs, lang='hi'
|
| 507 |
+
)
|
| 508 |
+
bert_f1 = sum(res['f1']) / len(res['f1'])
|
| 509 |
+
except Exception:
|
| 510 |
+
pass
|
| 511 |
+
|
| 512 |
+
# Save
|
| 513 |
+
out_path = f"{exp_dir}/evaluation_results.txt"
|
| 514 |
+
pred_path = f"{exp_dir}/evaluation_predictions.jsonl"
|
| 515 |
+
with open(out_path, 'w', encoding='utf-8') as f:
|
| 516 |
+
f.write(f"Model : {model_name}\n")
|
| 517 |
+
f.write(f"Negatives: {has_neg}\n")
|
| 518 |
+
f.write(f"Steps : {cfg['inference']['num_steps']}\n")
|
| 519 |
+
f.write(f"Temp : {cfg['inference']['temperature']}\n")
|
| 520 |
+
f.write(f"RepPen : {cfg['inference']['repetition_penalty']}\n")
|
| 521 |
+
f.write(f"DivPen : {cfg['inference']['diversity_penalty']}\n")
|
| 522 |
+
f.write(f"BLEU : {bleu_score:.4f}\n")
|
| 523 |
+
f.write(f"BERTScore: {bert_f1:.4f}\n\n")
|
| 524 |
+
f.write("=== SAMPLES ===\n")
|
| 525 |
+
for i in range(min(20, len(all_preds))):
|
| 526 |
+
f.write(f"IN : {all_inputs[i]}\n")
|
| 527 |
+
f.write(f"REF : {all_refs[i]}\n")
|
| 528 |
+
f.write(f"PRED: {all_preds[i]}\n")
|
| 529 |
+
f.write("-" * 60 + "\n")
|
| 530 |
+
|
| 531 |
+
with open(pred_path, 'w', encoding='utf-8') as f:
|
| 532 |
+
for src, ref, pred in zip(all_inputs, all_refs, all_preds):
|
| 533 |
+
row = {"input": src, "reference": ref, "prediction": pred}
|
| 534 |
+
f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
| 535 |
+
|
| 536 |
+
print(f"\n✅ Results → {out_path}")
|
| 537 |
+
print(f"🗂️ Saved predictions → {pred_path}")
|
| 538 |
+
print(f"📊 BLEU: {bleu_score:.4f} | BERTScore: {bert_f1:.4f}")
|
| 539 |
+
return all_preds, all_refs
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
if __name__ == '__main__':
|
| 543 |
+
import argparse
|
| 544 |
+
p = argparse.ArgumentParser()
|
| 545 |
+
p.add_argument('--mode', choices=['demo', 'eval'], default='demo')
|
| 546 |
+
p.add_argument('--samples', type=int, default=500)
|
| 547 |
+
p.add_argument('--checkpoint', type=str, default=None)
|
| 548 |
+
p.add_argument('--text', type=str, default=None, help='Run one-shot demo input and exit')
|
| 549 |
+
args = p.parse_args()
|
| 550 |
+
|
| 551 |
+
if args.mode == 'demo':
|
| 552 |
+
interactive_demo(checkpoint=args.checkpoint, single_text=args.text)
|
| 553 |
+
else:
|
| 554 |
+
batch_evaluate(args.samples, checkpoint=args.checkpoint)
|
inference_api.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from typing import Dict, Any
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from config import CONFIG
|
| 9 |
+
from inference import _build_tokenizers, _resolve_device, load_model, run_inference
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
_STATE = {
|
| 13 |
+
"loaded": False,
|
| 14 |
+
"model": None,
|
| 15 |
+
"cfg": None,
|
| 16 |
+
"device": None,
|
| 17 |
+
"src_tok": None,
|
| 18 |
+
"tgt_tok": None,
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _read_model_settings() -> Dict[str, Any]:
|
| 23 |
+
if not os.path.exists("model_settings.json"):
|
| 24 |
+
return {}
|
| 25 |
+
try:
|
| 26 |
+
with open("model_settings.json", "r", encoding="utf-8") as f:
|
| 27 |
+
data = json.load(f)
|
| 28 |
+
return data if isinstance(data, dict) else {}
|
| 29 |
+
except Exception:
|
| 30 |
+
return {}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _load_once() -> None:
|
| 34 |
+
if _STATE["loaded"]:
|
| 35 |
+
return
|
| 36 |
+
|
| 37 |
+
settings = _read_model_settings()
|
| 38 |
+
cfg = copy.deepcopy(CONFIG)
|
| 39 |
+
cfg["model_type"] = os.environ.get(
|
| 40 |
+
"HF_MODEL_TYPE",
|
| 41 |
+
settings.get("model_type", "d3pm_cross_attention"),
|
| 42 |
+
)
|
| 43 |
+
cfg["data"]["include_negative_examples"] = (
|
| 44 |
+
os.environ.get(
|
| 45 |
+
"HF_INCLUDE_NEG",
|
| 46 |
+
str(settings.get("include_negative_examples", True)).lower(),
|
| 47 |
+
).lower()
|
| 48 |
+
== "true"
|
| 49 |
+
)
|
| 50 |
+
num_steps_raw = os.environ.get("HF_NUM_STEPS", settings.get("num_steps"))
|
| 51 |
+
if num_steps_raw is not None:
|
| 52 |
+
num_steps = int(num_steps_raw)
|
| 53 |
+
cfg["model"]["diffusion_steps"] = num_steps
|
| 54 |
+
cfg["inference"]["num_steps"] = num_steps
|
| 55 |
+
device = _resolve_device(cfg)
|
| 56 |
+
|
| 57 |
+
model, cfg = load_model("best_model.pt", cfg, device)
|
| 58 |
+
src_tok, tgt_tok = _build_tokenizers(cfg)
|
| 59 |
+
|
| 60 |
+
_STATE["model"] = model
|
| 61 |
+
_STATE["cfg"] = cfg
|
| 62 |
+
_STATE["device"] = device
|
| 63 |
+
_STATE["src_tok"] = src_tok
|
| 64 |
+
_STATE["tgt_tok"] = tgt_tok
|
| 65 |
+
_STATE["loaded"] = True
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _clean_text(text: str) -> str:
|
| 69 |
+
text = " ".join(text.split())
|
| 70 |
+
if not text:
|
| 71 |
+
return text
|
| 72 |
+
toks = text.split()
|
| 73 |
+
out = []
|
| 74 |
+
prev = None
|
| 75 |
+
run = 0
|
| 76 |
+
for tok in toks:
|
| 77 |
+
if tok == prev:
|
| 78 |
+
run += 1
|
| 79 |
+
else:
|
| 80 |
+
prev = tok
|
| 81 |
+
run = 1
|
| 82 |
+
if run <= 2:
|
| 83 |
+
out.append(tok)
|
| 84 |
+
s = " ".join(out)
|
| 85 |
+
s = s.replace(" ।", "।").replace(" ॥", "॥")
|
| 86 |
+
return " ".join(s.split())
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def predict(
|
| 90 |
+
text: str,
|
| 91 |
+
temperature: float = 0.7,
|
| 92 |
+
top_k: int = 40,
|
| 93 |
+
repetition_penalty: float = 1.2,
|
| 94 |
+
diversity_penalty: float = 0.0,
|
| 95 |
+
num_steps: int = 64,
|
| 96 |
+
clean_output: bool = True,
|
| 97 |
+
) -> Dict[str, Any]:
|
| 98 |
+
_load_once()
|
| 99 |
+
if not text or not text.strip():
|
| 100 |
+
return {"error": "empty input", "output": ""}
|
| 101 |
+
|
| 102 |
+
cfg = copy.deepcopy(_STATE["cfg"])
|
| 103 |
+
cfg["inference"]["temperature"] = float(temperature)
|
| 104 |
+
cfg["inference"]["top_k"] = int(top_k)
|
| 105 |
+
cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
|
| 106 |
+
cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
|
| 107 |
+
cfg["inference"]["num_steps"] = int(num_steps)
|
| 108 |
+
|
| 109 |
+
src_tok = _STATE["src_tok"]
|
| 110 |
+
tgt_tok = _STATE["tgt_tok"]
|
| 111 |
+
device = _STATE["device"]
|
| 112 |
+
|
| 113 |
+
input_ids = torch.tensor([src_tok.encode(text.strip())], dtype=torch.long, device=device)
|
| 114 |
+
out = run_inference(_STATE["model"], input_ids, cfg)
|
| 115 |
+
decoded_ids = [x for x in out[0].tolist() if x > 4]
|
| 116 |
+
raw = tgt_tok.decode(decoded_ids).strip()
|
| 117 |
+
output = _clean_text(raw) if clean_output else raw
|
| 118 |
+
|
| 119 |
+
return {
|
| 120 |
+
"input": text,
|
| 121 |
+
"output": output,
|
| 122 |
+
"raw_output": raw,
|
| 123 |
+
"config": {
|
| 124 |
+
"temperature": float(temperature),
|
| 125 |
+
"top_k": int(top_k),
|
| 126 |
+
"repetition_penalty": float(repetition_penalty),
|
| 127 |
+
"diversity_penalty": float(diversity_penalty),
|
| 128 |
+
"num_steps": int(num_steps),
|
| 129 |
+
"clean_output": bool(clean_output),
|
| 130 |
+
},
|
| 131 |
+
}
|
model/__init__.py
ADDED
|
File without changes
|
model/d3pm_model_cross_attention.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
d3pm_model_cross_attention.py — Cross-Script + Generation-Fixed
|
| 3 |
+
=================================================================
|
| 4 |
+
INPUT : quote_text tokens (Roman script, src_vocab_size)
|
| 5 |
+
OUTPUT : quote_devanagari tokens (Devanagari script, tgt_vocab_size)
|
| 6 |
+
|
| 7 |
+
src_embed uses src_vocab_size (Roman BPE)
|
| 8 |
+
tgt_embed uses tgt_vocab_size (Devanagari BPE)
|
| 9 |
+
head outputs tgt_vocab_size (predict Devanagari tokens)
|
| 10 |
+
Weight tying: head <-> tgt_embed only (NOT src_embed)
|
| 11 |
+
|
| 12 |
+
Generation bugs fixed:
|
| 13 |
+
BUG 1 - tgt_pad_mask suppressed during inference
|
| 14 |
+
BUG 2 - q_sample skipped at t=0
|
| 15 |
+
BUG 3 - time embedding before hint_gate
|
| 16 |
+
BUG 4 - diversity penalty uses global mean not var
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
from diffusion.scheduler import OptimizedCosineScheduler
|
| 24 |
+
from diffusion.forward_process import AbsorbingForwardProcess
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SinusoidalPositionalEncoding(nn.Module):
|
| 28 |
+
def __init__(self, d_model, max_len=5000):
|
| 29 |
+
super().__init__()
|
| 30 |
+
pe = torch.zeros(max_len, d_model)
|
| 31 |
+
position = torch.arange(0, max_len).unsqueeze(1).float()
|
| 32 |
+
div_term = torch.exp(
|
| 33 |
+
torch.arange(0, d_model, 2).float() *
|
| 34 |
+
(-torch.log(torch.tensor(10000.0)) / d_model)
|
| 35 |
+
)
|
| 36 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 37 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 38 |
+
self.register_buffer("pe", pe.unsqueeze(0))
|
| 39 |
+
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
return x + self.pe[:, :x.size(1), :]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SanskritEmbeddings(nn.Module):
|
| 45 |
+
def __init__(self, vocab_size, d_model, max_seq_len):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.token_emb = nn.Embedding(vocab_size, d_model)
|
| 48 |
+
self.pos_enc = SinusoidalPositionalEncoding(d_model, max_seq_len)
|
| 49 |
+
self.token_embedding = self.token_emb
|
| 50 |
+
def forward(self, tokens):
|
| 51 |
+
return self.pos_enc(self.token_emb(tokens))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MultiHeadAttention(nn.Module):
|
| 55 |
+
def __init__(self, d_model, n_heads, dropout=0.1):
|
| 56 |
+
super().__init__()
|
| 57 |
+
assert d_model % n_heads == 0
|
| 58 |
+
self.d_model = d_model
|
| 59 |
+
self.n_heads = n_heads
|
| 60 |
+
self.head_dim = d_model // n_heads
|
| 61 |
+
self.q_proj = nn.Linear(d_model, d_model)
|
| 62 |
+
self.k_proj = nn.Linear(d_model, d_model)
|
| 63 |
+
self.v_proj = nn.Linear(d_model, d_model)
|
| 64 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 65 |
+
self.dropout = nn.Dropout(dropout)
|
| 66 |
+
|
| 67 |
+
def forward(self, q, k, v, mask=None):
|
| 68 |
+
B, Lq, _ = q.size()
|
| 69 |
+
Lk = k.size(1)
|
| 70 |
+
Q = self.q_proj(q).view(B, Lq, self.n_heads, self.head_dim).transpose(1, 2)
|
| 71 |
+
K = self.k_proj(k).view(B, Lk, self.n_heads, self.head_dim).transpose(1, 2)
|
| 72 |
+
V = self.v_proj(v).view(B, Lk, self.n_heads, self.head_dim).transpose(1, 2)
|
| 73 |
+
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
|
| 74 |
+
if mask is not None:
|
| 75 |
+
scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
|
| 76 |
+
attn = self.dropout(torch.softmax(scores, dim=-1))
|
| 77 |
+
out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, Lq, self.d_model)
|
| 78 |
+
return self.out_proj(out)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class EncoderBlock(nn.Module):
|
| 82 |
+
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.mha = MultiHeadAttention(d_model, n_heads, dropout)
|
| 85 |
+
self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
|
| 86 |
+
nn.Linear(d_ff, d_model), nn.Dropout(dropout))
|
| 87 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 88 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 89 |
+
def forward(self, x, pad_mask=None):
|
| 90 |
+
x = self.norm1(x + self.mha(x, x, x, mask=pad_mask))
|
| 91 |
+
return self.norm2(x + self.ff(x))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class DecoderBlock(nn.Module):
|
| 95 |
+
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
|
| 98 |
+
self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
|
| 99 |
+
self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
|
| 100 |
+
nn.Linear(d_ff, d_model), nn.Dropout(dropout))
|
| 101 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 102 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 103 |
+
self.norm3 = nn.LayerNorm(d_model)
|
| 104 |
+
def forward(self, x, memory, tgt_pad_mask=None, src_pad_mask=None):
|
| 105 |
+
x = self.norm1(x + self.self_attn(x, x, x, mask=tgt_pad_mask))
|
| 106 |
+
x = self.norm2(x + self.cross_attn(x, memory, memory, mask=src_pad_mask))
|
| 107 |
+
return self.norm3(x + self.ff(x))
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class D3PMCrossAttention(nn.Module):
|
| 111 |
+
def __init__(self, cfg):
|
| 112 |
+
super().__init__()
|
| 113 |
+
self.cfg = cfg
|
| 114 |
+
self.mask_token_id = cfg['diffusion']['mask_token_id']
|
| 115 |
+
d = cfg['model']['d_model']
|
| 116 |
+
nhead = cfg['model']['n_heads']
|
| 117 |
+
d_ff = cfg['model']['d_ff']
|
| 118 |
+
drop = cfg['model']['dropout']
|
| 119 |
+
seqlen = cfg['model']['max_seq_len']
|
| 120 |
+
nlayer = cfg['model']['n_layers']
|
| 121 |
+
src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
|
| 122 |
+
tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
|
| 123 |
+
|
| 124 |
+
# Separate embeddings: Roman src, Devanagari tgt
|
| 125 |
+
self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
|
| 126 |
+
self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)
|
| 127 |
+
|
| 128 |
+
self.scheduler = OptimizedCosineScheduler(cfg)
|
| 129 |
+
self.forward_process = AbsorbingForwardProcess(self.scheduler)
|
| 130 |
+
|
| 131 |
+
self.encoder_blocks = nn.ModuleList([EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
|
| 132 |
+
self.decoder_blocks = nn.ModuleList([DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
|
| 133 |
+
|
| 134 |
+
self.time_mlp = nn.Sequential(nn.Linear(1, d//4), nn.SiLU(), nn.Linear(d//4, d))
|
| 135 |
+
self.hint_gate = nn.Sequential(nn.Linear(d, d), nn.Sigmoid())
|
| 136 |
+
|
| 137 |
+
# Output head: predict Devanagari tokens, tied to tgt_embed
|
| 138 |
+
self.head = nn.Linear(d, tgt_vocab, bias=False)
|
| 139 |
+
self.head.weight = self.tgt_embed.token_embedding.weight
|
| 140 |
+
|
| 141 |
+
def forward(self, src, tgt, t, x0_hint=None, inference_mode=False):
|
| 142 |
+
PAD = 1
|
| 143 |
+
src_pad_mask = (src == PAD)
|
| 144 |
+
# BUG 1 FIX: no tgt mask during inference
|
| 145 |
+
tgt_pad_mask = None if inference_mode else (tgt == PAD)
|
| 146 |
+
|
| 147 |
+
# Encode Roman source
|
| 148 |
+
memory = self.src_embed(src)
|
| 149 |
+
for block in self.encoder_blocks:
|
| 150 |
+
memory = block(memory, pad_mask=src_pad_mask)
|
| 151 |
+
|
| 152 |
+
# BUG 2 FIX: skip q_sample at final step t=0
|
| 153 |
+
if inference_mode and (t == 0).all():
|
| 154 |
+
x_t_ids = tgt
|
| 155 |
+
else:
|
| 156 |
+
_, x_t_ids = self.forward_process.q_sample(tgt, t)
|
| 157 |
+
|
| 158 |
+
x = self.tgt_embed(x_t_ids)
|
| 159 |
+
|
| 160 |
+
# BUG 3 FIX: time embedding BEFORE hint gate
|
| 161 |
+
t_norm = t.float() / self.scheduler.num_timesteps
|
| 162 |
+
t_emb = self.time_mlp(t_norm.unsqueeze(-1))
|
| 163 |
+
x = x + t_emb.unsqueeze(1)
|
| 164 |
+
|
| 165 |
+
if x0_hint is not None:
|
| 166 |
+
hint_emb = self.tgt_embed(x0_hint)
|
| 167 |
+
gate = self.hint_gate(x) # time-aware gate
|
| 168 |
+
x = x + gate * hint_emb
|
| 169 |
+
|
| 170 |
+
for block in self.decoder_blocks:
|
| 171 |
+
x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)
|
| 172 |
+
|
| 173 |
+
return self.head(x), None
|
| 174 |
+
|
| 175 |
+
@torch.no_grad()
|
| 176 |
+
def generate(self, src, num_steps=None, temperature=0.8, top_k=50,
|
| 177 |
+
repetition_penalty=1.2, diversity_penalty=0.0):
|
| 178 |
+
if src.dim() == 1:
|
| 179 |
+
src = src.unsqueeze(0)
|
| 180 |
+
device = src.device
|
| 181 |
+
B, L = src.shape
|
| 182 |
+
T = self.scheduler.num_timesteps
|
| 183 |
+
steps = num_steps or T
|
| 184 |
+
step_size = max(1, T // steps)
|
| 185 |
+
timesteps = list(range(T - 1, -1, -step_size))
|
| 186 |
+
if timesteps[-1] != 0:
|
| 187 |
+
timesteps.append(0)
|
| 188 |
+
|
| 189 |
+
mask_id = self.mask_token_id
|
| 190 |
+
x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
|
| 191 |
+
hint = None
|
| 192 |
+
|
| 193 |
+
self.eval()
|
| 194 |
+
with torch.no_grad():
|
| 195 |
+
for step_idx, t_val in enumerate(timesteps):
|
| 196 |
+
t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 197 |
+
is_last = (step_idx == len(timesteps) - 1)
|
| 198 |
+
logits, _ = self.forward(src, x0_est, t, x0_hint=hint, inference_mode=True)
|
| 199 |
+
if repetition_penalty != 1.0:
|
| 200 |
+
logits = _apply_repetition_penalty(logits, x0_est, repetition_penalty)
|
| 201 |
+
if diversity_penalty > 0.0:
|
| 202 |
+
logits = _apply_diversity_penalty_fixed(logits, diversity_penalty) # BUG 4 FIX
|
| 203 |
+
logits = logits / max(temperature, 1e-5)
|
| 204 |
+
if top_k > 0:
|
| 205 |
+
logits = _top_k_filter(logits, top_k)
|
| 206 |
+
probs = F.softmax(logits, dim=-1)
|
| 207 |
+
x0_est = torch.argmax(probs, dim=-1) if is_last else _batch_multinomial(probs)
|
| 208 |
+
hint = x0_est
|
| 209 |
+
return x0_est
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class BaselineCrossAttention(nn.Module):
|
| 213 |
+
def __init__(self, cfg):
|
| 214 |
+
super().__init__()
|
| 215 |
+
d = cfg['model']['d_model']; nhead = cfg['model']['n_heads']
|
| 216 |
+
d_ff = cfg['model']['d_ff']; drop = cfg['model']['dropout']
|
| 217 |
+
seqlen = cfg['model']['max_seq_len']; nlayer = cfg['model']['n_layers']
|
| 218 |
+
src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
|
| 219 |
+
tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
|
| 220 |
+
self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
|
| 221 |
+
self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)
|
| 222 |
+
self.encoder_blocks = nn.ModuleList([EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
|
| 223 |
+
self.decoder_blocks = nn.ModuleList([DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
|
| 224 |
+
self.head = nn.Linear(d, tgt_vocab, bias=False)
|
| 225 |
+
self.head.weight = self.tgt_embed.token_embedding.weight
|
| 226 |
+
|
| 227 |
+
def forward(self, src, tgt, t=None, x0_hint=None):
|
| 228 |
+
PAD = 1
|
| 229 |
+
memory = self.src_embed(src)
|
| 230 |
+
for b in self.encoder_blocks: memory = b(memory, pad_mask=(src==PAD))
|
| 231 |
+
x = self.tgt_embed(tgt)
|
| 232 |
+
for b in self.decoder_blocks: x = b(x, memory, tgt_pad_mask=(tgt==PAD), src_pad_mask=(src==PAD))
|
| 233 |
+
return (self.head(x),)
|
| 234 |
+
|
| 235 |
+
@torch.no_grad()
|
| 236 |
+
def generate(self, src, max_len=None, start_token_id=2, **kwargs):
|
| 237 |
+
if max_len is None: max_len = src.size(1)
|
| 238 |
+
B, device = src.size(0), src.device
|
| 239 |
+
memory = self.src_embed(src)
|
| 240 |
+
for b in self.encoder_blocks: memory = b(memory, pad_mask=(src==1))
|
| 241 |
+
ys = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
|
| 242 |
+
for _ in range(max_len):
|
| 243 |
+
x = self.tgt_embed(ys)
|
| 244 |
+
for b in self.decoder_blocks: x = b(x, memory, tgt_pad_mask=None, src_pad_mask=(src==1))
|
| 245 |
+
ys = torch.cat([ys, torch.argmax(self.head(x)[:,-1,:], dim=-1, keepdim=True)], dim=1)
|
| 246 |
+
return ys[:, 1:max_len+1]
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# helpers
|
| 250 |
+
def _top_k_filter(logits, k):
|
| 251 |
+
B, L, V = logits.shape
|
| 252 |
+
if k >= V: return logits
|
| 253 |
+
topk_vals, _ = torch.topk(logits, k, dim=-1)
|
| 254 |
+
return logits.masked_fill(logits < topk_vals[..., -1].unsqueeze(-1), float('-inf'))
|
| 255 |
+
|
| 256 |
+
def _batch_multinomial(probs):
|
| 257 |
+
B, L, V = probs.shape
|
| 258 |
+
flat = probs.view(B*L, V) + 1e-9
|
| 259 |
+
return torch.multinomial(flat/flat.sum(-1,keepdim=True), 1).squeeze(-1).view(B, L)
|
| 260 |
+
|
| 261 |
+
def _apply_repetition_penalty(logits, prev_tokens, penalty):
|
| 262 |
+
for b in range(logits.shape[0]):
|
| 263 |
+
for tid in set(prev_tokens[b].tolist()):
|
| 264 |
+
if tid > 4: logits[b, :, tid] = logits[b, :, tid] / penalty
|
| 265 |
+
return logits
|
| 266 |
+
|
| 267 |
+
def _apply_diversity_penalty(logits, penalty): # legacy wrong version
|
| 268 |
+
return logits + penalty * logits.var(dim=-1, keepdim=True)
|
| 269 |
+
|
| 270 |
+
def _apply_diversity_penalty_fixed(logits, penalty): # correct version
|
| 271 |
+
return logits - penalty * logits.mean(dim=1, keepdim=True)
|
model/d3pm_model_encoder_decoder.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from diffusion.scheduler import OptimizedCosineScheduler
|
| 4 |
+
from diffusion.forward_process import AbsorbingForwardProcess
|
| 5 |
+
# Import shared classes to guarantee identical architectures
|
| 6 |
+
from model.d3pm_model_cross_attention import SanskritEmbeddings, EncoderBlock, MultiHeadAttention
|
| 7 |
+
class DecoderBlock(nn.Module):
|
| 8 |
+
def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
|
| 9 |
+
super().__init__()
|
| 10 |
+
self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
|
| 11 |
+
self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout) # ← restored
|
| 12 |
+
self.ff = nn.Sequential(
|
| 13 |
+
nn.Linear(d_model, d_ff),
|
| 14 |
+
nn.ReLU(),
|
| 15 |
+
nn.Dropout(dropout),
|
| 16 |
+
nn.Linear(d_ff, d_model),
|
| 17 |
+
nn.Dropout(dropout),
|
| 18 |
+
)
|
| 19 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 20 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 21 |
+
self.norm3 = nn.LayerNorm(d_model) # ← restored (for cross-attn residual)
|
| 22 |
+
|
| 23 |
+
def forward(self, x, memory, tgt_pad_mask=None):
|
| 24 |
+
# 1. Masked self-attention on target
|
| 25 |
+
x = self.norm1(x + self.self_attn(x, x, x, mask=tgt_pad_mask))
|
| 26 |
+
# 2. Cross-attention: queries from decoder, keys/values from encoder memory
|
| 27 |
+
x = self.norm2(x + self.cross_attn(x, memory, memory))
|
| 28 |
+
# 3. Feed-forward
|
| 29 |
+
return self.norm3(x + self.ff(x))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DecoderBlockNoCrossAttn(nn.Module):
|
| 33 |
+
"""Kept for reference — NOT used by D3PMEncoderDecoder."""
|
| 34 |
+
def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
|
| 37 |
+
self.ff = nn.Sequential(
|
| 38 |
+
nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
|
| 39 |
+
nn.Linear(d_ff, d_model), nn.Dropout(dropout),
|
| 40 |
+
)
|
| 41 |
+
self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
|
| 42 |
+
|
| 43 |
+
def forward(self, x, tgt_pad_mask=None, causal_mask=None):
|
| 44 |
+
combined_mask = None
|
| 45 |
+
if tgt_pad_mask is not None and causal_mask is not None:
|
| 46 |
+
combined_mask = tgt_pad_mask | causal_mask
|
| 47 |
+
elif causal_mask is not None:
|
| 48 |
+
combined_mask = causal_mask
|
| 49 |
+
elif tgt_pad_mask is not None:
|
| 50 |
+
combined_mask = tgt_pad_mask
|
| 51 |
+
x = self.norm1(x + self.self_attn(x, x, x, mask=combined_mask))
|
| 52 |
+
return self.norm2(x + self.ff(x))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ============================================================
|
| 56 |
+
# 1. D3PM Encoder-Decoder Model
|
| 57 |
+
# ============================================================
|
| 58 |
+
class D3PMEncoderDecoder(nn.Module):
|
| 59 |
+
def __init__(self, cfg):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.cfg = cfg
|
| 62 |
+
self.mask_token_id = cfg['diffusion']['mask_token_id']
|
| 63 |
+
|
| 64 |
+
src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
|
| 65 |
+
tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
|
| 66 |
+
d_model = cfg['model']['d_model']
|
| 67 |
+
n_heads = cfg['model']['n_heads']
|
| 68 |
+
d_ff = cfg['model']['d_ff']
|
| 69 |
+
dropout = cfg['model']['dropout']
|
| 70 |
+
n_layers = cfg['model']['n_layers']
|
| 71 |
+
max_len = cfg['model']['max_seq_len']
|
| 72 |
+
|
| 73 |
+
self.src_embed = SanskritEmbeddings(src_vocab, d_model, max_len)
|
| 74 |
+
self.tgt_embed = SanskritEmbeddings(tgt_vocab, d_model, max_len)
|
| 75 |
+
|
| 76 |
+
self.scheduler = OptimizedCosineScheduler(cfg)
|
| 77 |
+
self.forward_process = AbsorbingForwardProcess(self.scheduler)
|
| 78 |
+
|
| 79 |
+
self.encoder_blocks = nn.ModuleList([
|
| 80 |
+
EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
|
| 81 |
+
])
|
| 82 |
+
# DecoderBlock now has cross-attention — matches saved checkpoint
|
| 83 |
+
self.decoder_blocks = nn.ModuleList([
|
| 84 |
+
DecoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
|
| 85 |
+
])
|
| 86 |
+
|
| 87 |
+
self.time_mlp = nn.Sequential(
|
| 88 |
+
nn.Linear(1, d_model // 4), nn.SiLU(),
|
| 89 |
+
nn.Linear(d_model // 4, d_model),
|
| 90 |
+
)
|
| 91 |
+
self.head = nn.Linear(d_model, tgt_vocab)
|
| 92 |
+
self.head.weight = self.tgt_embed.token_embedding.weight
|
| 93 |
+
|
| 94 |
+
def forward(self, src, tgt, t, x0_hint=None):
|
| 95 |
+
src_pad_mask = (src == 1)
|
| 96 |
+
tgt_pad_mask = (tgt == 1)
|
| 97 |
+
|
| 98 |
+
# Encode source (Roman IAST)
|
| 99 |
+
memory = self.src_embed(src)
|
| 100 |
+
for block in self.encoder_blocks:
|
| 101 |
+
memory = block(memory, pad_mask=src_pad_mask)
|
| 102 |
+
|
| 103 |
+
# Corrupt target with forward diffusion
|
| 104 |
+
_, x_t_ids = self.forward_process.q_sample(tgt, t)
|
| 105 |
+
|
| 106 |
+
# Optionally blend in x0_hint (self-conditioning)
|
| 107 |
+
if x0_hint is not None:
|
| 108 |
+
hint_prob = 0.5
|
| 109 |
+
blend_mask = (torch.rand(x_t_ids.shape, device=x_t_ids.device) < hint_prob)
|
| 110 |
+
still_mask = (x_t_ids == self.mask_token_id)
|
| 111 |
+
x_t_ids = torch.where(blend_mask & still_mask, x0_hint, x_t_ids)
|
| 112 |
+
|
| 113 |
+
x = self.tgt_embed(x_t_ids)
|
| 114 |
+
t_emb = self.time_mlp(t.float().unsqueeze(-1)).unsqueeze(1)
|
| 115 |
+
x = x + t_emb.expand(-1, tgt.shape[1], -1)
|
| 116 |
+
|
| 117 |
+
# Decode with cross-attention over encoder memory
|
| 118 |
+
for block in self.decoder_blocks:
|
| 119 |
+
x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
|
| 120 |
+
|
| 121 |
+
return self.head(x), None
|
| 122 |
+
|
| 123 |
+
@torch.no_grad()
|
| 124 |
+
def generate(
|
| 125 |
+
self,
|
| 126 |
+
src,
|
| 127 |
+
num_steps = None,
|
| 128 |
+
temperature = 0.75,
|
| 129 |
+
top_k = 50,
|
| 130 |
+
repetition_penalty = 1.15,
|
| 131 |
+
diversity_penalty = 0.0,
|
| 132 |
+
):
|
| 133 |
+
"""
|
| 134 |
+
Iterative D3PM reverse diffusion — same signature as
|
| 135 |
+
D3PMCrossAttention.generate() so SanskritModel.generate() works
|
| 136 |
+
identically for both model types.
|
| 137 |
+
"""
|
| 138 |
+
device = src.device
|
| 139 |
+
B, L = src.shape[0], self.cfg['model']['max_seq_len']
|
| 140 |
+
T = num_steps or self.scheduler.num_timesteps
|
| 141 |
+
mask_id = self.mask_token_id
|
| 142 |
+
pad_id = 1
|
| 143 |
+
|
| 144 |
+
x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
|
| 145 |
+
|
| 146 |
+
for step in range(T - 1, -1, -1):
|
| 147 |
+
t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
|
| 148 |
+
hint = x0_est.clone()
|
| 149 |
+
|
| 150 |
+
logits, _ = self.forward(src, x0_est, t_tensor, x0_hint=hint)
|
| 151 |
+
|
| 152 |
+
# Repetition penalty
|
| 153 |
+
if repetition_penalty != 1.0:
|
| 154 |
+
for b in range(B):
|
| 155 |
+
for tok in set(x0_est[b].tolist()):
|
| 156 |
+
if tok > pad_id:
|
| 157 |
+
logits[b, :, tok] /= repetition_penalty
|
| 158 |
+
|
| 159 |
+
# Diversity penalty (suppress common tokens)
|
| 160 |
+
if diversity_penalty > 0.0:
|
| 161 |
+
logits = logits - diversity_penalty * logits.mean(dim=1, keepdim=True)
|
| 162 |
+
|
| 163 |
+
# Temperature + top-k sampling
|
| 164 |
+
logits = logits / max(temperature, 1e-8)
|
| 165 |
+
if top_k > 0:
|
| 166 |
+
vals, _ = torch.topk(logits, top_k, dim=-1)
|
| 167 |
+
logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
|
| 168 |
+
|
| 169 |
+
probs = torch.softmax(logits, dim=-1)
|
| 170 |
+
# Only update positions that are still masked
|
| 171 |
+
still = (x0_est == mask_id)
|
| 172 |
+
sample = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(B, L)
|
| 173 |
+
x0_est = torch.where(still, sample, x0_est)
|
| 174 |
+
|
| 175 |
+
return x0_est
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ============================================================
|
| 179 |
+
# 2. Baseline Encoder-Decoder Model (unchanged)
|
| 180 |
+
# ============================================================
|
| 181 |
+
class BaselineEncoderDecoder(nn.Module):
|
| 182 |
+
def __init__(self, cfg):
|
| 183 |
+
super().__init__()
|
| 184 |
+
self.cfg = cfg
|
| 185 |
+
self.src_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
|
| 186 |
+
cfg['model']['max_seq_len'])
|
| 187 |
+
self.tgt_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
|
| 188 |
+
cfg['model']['max_seq_len'])
|
| 189 |
+
self.encoder_blocks = nn.ModuleList([
|
| 190 |
+
EncoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
|
| 191 |
+
cfg['model']['d_ff'], cfg['model']['dropout'])
|
| 192 |
+
for _ in range(cfg['model']['n_layers'])
|
| 193 |
+
])
|
| 194 |
+
self.decoder_blocks = nn.ModuleList([
|
| 195 |
+
DecoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
|
| 196 |
+
cfg['model']['d_ff'], cfg['model']['dropout'])
|
| 197 |
+
for _ in range(cfg['model']['n_layers'])
|
| 198 |
+
])
|
| 199 |
+
self.head = nn.Linear(cfg['model']['d_model'], cfg['model']['vocab_size'])
|
| 200 |
+
self.head.weight = self.tgt_embed.token_embedding.weight
|
| 201 |
+
|
| 202 |
+
def forward(self, src, tgt):
|
| 203 |
+
src_pad_mask, tgt_pad_mask = (src == 1), (tgt == 1)
|
| 204 |
+
memory = self.src_embed(src)
|
| 205 |
+
for block in self.encoder_blocks:
|
| 206 |
+
memory = block(memory, pad_mask=src_pad_mask)
|
| 207 |
+
x = self.tgt_embed(tgt)
|
| 208 |
+
for block in self.decoder_blocks:
|
| 209 |
+
x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
|
| 210 |
+
return self.head(x)
|
| 211 |
+
|
| 212 |
+
@torch.no_grad()
|
| 213 |
+
def generate(self, src, max_len=80, start_token_id=2):
|
| 214 |
+
batch_size, device = src.size(0), src.device
|
| 215 |
+
src_pad_mask = (src == 1)
|
| 216 |
+
memory = self.src_embed(src)
|
| 217 |
+
for block in self.encoder_blocks:
|
| 218 |
+
memory = block(memory, pad_mask=src_pad_mask)
|
| 219 |
+
ys = torch.ones(batch_size, 1, dtype=torch.long, device=device) * start_token_id
|
| 220 |
+
for _ in range(max_len):
|
| 221 |
+
x = self.tgt_embed(ys)
|
| 222 |
+
for block in self.decoder_blocks:
|
| 223 |
+
x = block(x, memory, tgt_pad_mask=None)
|
| 224 |
+
logits = self.head(x)
|
| 225 |
+
next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
|
| 226 |
+
ys = torch.cat([ys, next_token], dim=1)
|
| 227 |
+
return ys[:, 1:]
|
model/sanskrit_model.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
sanskrit_model.py — Fixed
|
| 3 |
+
===========================
|
| 4 |
+
Added inference_mode parameter to forward() so reverse_process.py can
|
| 5 |
+
pass inference_mode=True without a TypeError.
|
| 6 |
+
|
| 7 |
+
The wrapper introspects each inner model's signature and only passes
|
| 8 |
+
kwargs that model actually accepts — safe across all four architectures.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import inspect
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SanskritModel(nn.Module):
|
| 17 |
+
def __init__(self, cfg):
|
| 18 |
+
super().__init__()
|
| 19 |
+
model_type = cfg['model_type']
|
| 20 |
+
|
| 21 |
+
if model_type == 'd3pm_cross_attention':
|
| 22 |
+
from model.d3pm_model_cross_attention import D3PMCrossAttention
|
| 23 |
+
self.model = D3PMCrossAttention(cfg)
|
| 24 |
+
|
| 25 |
+
elif model_type == 'd3pm_encoder_decoder':
|
| 26 |
+
from model.d3pm_model_encoder_decoder import D3PMEncoderDecoder
|
| 27 |
+
self.model = D3PMEncoderDecoder(cfg)
|
| 28 |
+
|
| 29 |
+
elif model_type == 'baseline_cross_attention':
|
| 30 |
+
from model.d3pm_model_cross_attention import BaselineCrossAttention
|
| 31 |
+
self.model = BaselineCrossAttention(cfg)
|
| 32 |
+
|
| 33 |
+
elif model_type == 'baseline_encoder_decoder':
|
| 34 |
+
from model.d3pm_model_encoder_decoder import BaselineEncoderDecoder
|
| 35 |
+
self.model = BaselineEncoderDecoder(cfg)
|
| 36 |
+
|
| 37 |
+
else:
|
| 38 |
+
raise ValueError(f"Unknown model_type: {model_type}")
|
| 39 |
+
|
| 40 |
+
def forward(self, input_ids, target_ids, t, x0_hint=None, inference_mode=False):
|
| 41 |
+
"""
|
| 42 |
+
Forward pass. Introspects the inner model's signature so only
|
| 43 |
+
supported kwargs are passed — works with all four architectures.
|
| 44 |
+
"""
|
| 45 |
+
sig = inspect.signature(self.model.forward).parameters
|
| 46 |
+
kwargs = {}
|
| 47 |
+
if 'x0_hint' in sig:
|
| 48 |
+
kwargs['x0_hint'] = x0_hint
|
| 49 |
+
if 'inference_mode' in sig:
|
| 50 |
+
kwargs['inference_mode'] = inference_mode
|
| 51 |
+
|
| 52 |
+
if 't' in sig:
|
| 53 |
+
return self.model(input_ids, target_ids, t, **kwargs)
|
| 54 |
+
else:
|
| 55 |
+
return self.model(input_ids, target_ids, **kwargs)
|
| 56 |
+
|
| 57 |
+
@torch.no_grad()
|
| 58 |
+
def generate(self, src, **kwargs):
|
| 59 |
+
sig = inspect.signature(self.model.generate).parameters
|
| 60 |
+
filtered = {k: v for k, v in kwargs.items() if k in sig}
|
| 61 |
+
return self.model.generate(src, **filtered)
|
model/tokenizer.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tokenizer.py — Dual Tokenizer Fix
|
| 3 |
+
====================================
|
| 4 |
+
Two separate BPE tokenizers:
|
| 5 |
+
|
| 6 |
+
SanskritSourceTokenizer — trained on quote_text (Roman/IAST script)
|
| 7 |
+
SanskritTargetTokenizer — trained on quote_devanagari (Devanagari script)
|
| 8 |
+
|
| 9 |
+
WHY SEPARATE?
|
| 10 |
+
Roman Sanskrit and Devanagari are fundamentally different character sets.
|
| 11 |
+
Roman uses a-z + diacritics (~60 unique chars), Devanagari uses ā-ह + matras
|
| 12 |
+
(~100+ unique chars). A shared BPE tokenizer wastes half its vocab on
|
| 13 |
+
character combos that never cross scripts, and forces the embedding table
|
| 14 |
+
to encode both scripts in one space — confusing the model's cross-attention.
|
| 15 |
+
|
| 16 |
+
With separate tokenizers:
|
| 17 |
+
- src vocab captures Roman subwords cleanly (ā, ś, ṭ, ṃ etc.)
|
| 18 |
+
- tgt vocab captures Devanagari akshara clusters cleanly (क्ष, त्र, etc.)
|
| 19 |
+
- The model learns a true cross-script mapping in its cross-attention
|
| 20 |
+
|
| 21 |
+
SPECIAL TOKENS (same IDs in both):
|
| 22 |
+
[MASK] = 0 ← required by absorbing diffusion
|
| 23 |
+
[PAD] = 1
|
| 24 |
+
[UNK] = 2
|
| 25 |
+
[CLS] = 3
|
| 26 |
+
[SEP] = 4
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from tokenizers import Tokenizer
|
| 30 |
+
from tokenizers.models import BPE
|
| 31 |
+
from tokenizers.trainers import BpeTrainer
|
| 32 |
+
from tokenizers.pre_tokenizers import Whitespace
|
| 33 |
+
from datasets import load_dataset
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
SPECIAL_TOKENS = ["[MASK]", "[PAD]", "[UNK]", "[CLS]", "[SEP]"]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _build_bpe(texts, vocab_size):
|
| 41 |
+
"""Build a BPE tokenizer from an iterator of strings."""
|
| 42 |
+
tok = Tokenizer(BPE(unk_token="[UNK]"))
|
| 43 |
+
tok.pre_tokenizer = Whitespace()
|
| 44 |
+
trainer = BpeTrainer(
|
| 45 |
+
vocab_size=vocab_size,
|
| 46 |
+
special_tokens=SPECIAL_TOKENS, # [MASK] MUST be first → id=0
|
| 47 |
+
min_frequency=2,
|
| 48 |
+
)
|
| 49 |
+
tok.train_from_iterator(texts, trainer)
|
| 50 |
+
return tok
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _validate(tok, name):
|
| 54 |
+
mask_id = tok.token_to_id("[MASK]")
|
| 55 |
+
pad_id = tok.token_to_id("[PAD]")
|
| 56 |
+
assert mask_id == 0, f"{name}: [MASK] must be id=0, got {mask_id}"
|
| 57 |
+
assert pad_id == 1, f"{name}: [PAD] must be id=1, got {pad_id}"
|
| 58 |
+
print(f"✅ {name}: [MASK]=0, [PAD]=1 confirmed. Vocab size={tok.get_vocab_size()}")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ── Source tokenizer (Roman/IAST Sanskrit) ────────────────────────────
|
| 62 |
+
|
| 63 |
+
class SanskritSourceTokenizer:
|
| 64 |
+
"""
|
| 65 |
+
Tokenizer for quote_text — Roman transliteration of Sanskrit.
|
| 66 |
+
Examples: "dharmo rakṣati rakṣitaḥ", "yatra nāryastu pūjyante"
|
| 67 |
+
"""
|
| 68 |
+
MODEL_PATH = "sanskrit_src_tokenizer.json"
|
| 69 |
+
|
| 70 |
+
def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
|
| 71 |
+
self.vocab_size = vocab_size
|
| 72 |
+
self.max_len = max_len
|
| 73 |
+
self.mask_token_id = 0
|
| 74 |
+
|
| 75 |
+
if Path(self.MODEL_PATH).exists():
|
| 76 |
+
print(f"📖 Loading source tokenizer from {self.MODEL_PATH} …")
|
| 77 |
+
self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
|
| 78 |
+
else:
|
| 79 |
+
print("🎓 Training source tokenizer on quote_text …")
|
| 80 |
+
self._train(vocab_size, n_train_samples)
|
| 81 |
+
|
| 82 |
+
_validate(self.tokenizer, "SrcTokenizer")
|
| 83 |
+
|
| 84 |
+
def _train(self, vocab_size, n_samples):
|
| 85 |
+
dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
|
| 86 |
+
n = min(n_samples, len(dataset))
|
| 87 |
+
texts = [s["quote_text"] for s in dataset.select(range(n))
|
| 88 |
+
if s["quote_text"].strip()]
|
| 89 |
+
self.tokenizer = _build_bpe(texts, vocab_size)
|
| 90 |
+
self.tokenizer.save(self.MODEL_PATH)
|
| 91 |
+
print(f"✅ Source tokenizer trained on {len(texts)} Roman texts.")
|
| 92 |
+
|
| 93 |
+
def encode(self, text):
|
| 94 |
+
ids = self.tokenizer.encode(text).ids[:self.max_len]
|
| 95 |
+
pad = self.tokenizer.token_to_id("[PAD]")
|
| 96 |
+
ids += [pad] * max(0, self.max_len - len(ids))
|
| 97 |
+
return ids[:self.max_len]
|
| 98 |
+
|
| 99 |
+
def decode(self, ids):
|
| 100 |
+
clean = [i for i in ids if i > 4] # skip special tokens
|
| 101 |
+
return self.tokenizer.decode(clean)
|
| 102 |
+
|
| 103 |
+
def __len__(self):
|
| 104 |
+
return self.vocab_size
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# ── Target tokenizer (Devanagari Sanskrit) ───────────────────────────
|
| 108 |
+
|
| 109 |
+
class SanskritTargetTokenizer:
|
| 110 |
+
"""
|
| 111 |
+
Tokenizer for quote_devanagari — Devanagari script.
|
| 112 |
+
Examples: "धर्मो रक्षति रक्षितः", "यत्र नार्यस्तु पूज्यन्ते"
|
| 113 |
+
"""
|
| 114 |
+
MODEL_PATH = "sanskrit_tgt_tokenizer.json"
|
| 115 |
+
|
| 116 |
+
def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
|
| 117 |
+
self.vocab_size = vocab_size
|
| 118 |
+
self.max_len = max_len
|
| 119 |
+
self.mask_token_id = 0
|
| 120 |
+
|
| 121 |
+
if Path(self.MODEL_PATH).exists():
|
| 122 |
+
print(f"📖 Loading target tokenizer from {self.MODEL_PATH} …")
|
| 123 |
+
self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
|
| 124 |
+
else:
|
| 125 |
+
print("🎓 Training target tokenizer on quote_devanagari …")
|
| 126 |
+
self._train(vocab_size, n_train_samples)
|
| 127 |
+
|
| 128 |
+
_validate(self.tokenizer, "TgtTokenizer")
|
| 129 |
+
|
| 130 |
+
def _train(self, vocab_size, n_samples):
|
| 131 |
+
dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
|
| 132 |
+
n = min(n_samples, len(dataset))
|
| 133 |
+
texts = [s["quote_devanagari"] for s in dataset.select(range(n))
|
| 134 |
+
if s["quote_devanagari"].strip()]
|
| 135 |
+
self.tokenizer = _build_bpe(texts, vocab_size)
|
| 136 |
+
self.tokenizer.save(self.MODEL_PATH)
|
| 137 |
+
print(f"✅ Target tokenizer trained on {len(texts)} Devanagari texts.")
|
| 138 |
+
|
| 139 |
+
def encode(self, text):
|
| 140 |
+
ids = self.tokenizer.encode(text).ids[:self.max_len]
|
| 141 |
+
pad = self.tokenizer.token_to_id("[PAD]")
|
| 142 |
+
ids += [pad] * max(0, self.max_len - len(ids))
|
| 143 |
+
return ids[:self.max_len]
|
| 144 |
+
|
| 145 |
+
def decode(self, ids):
|
| 146 |
+
clean = [i for i in ids if i > 4]
|
| 147 |
+
return self.tokenizer.decode(clean)
|
| 148 |
+
|
| 149 |
+
# Methods required by BERTScore
|
| 150 |
+
def build_inputs_with_special_tokens(self, token_ids):
|
| 151 |
+
return list(token_ids)
|
| 152 |
+
|
| 153 |
+
def get_vocab(self):
|
| 154 |
+
return {str(i): i for i in range(self.vocab_size)}
|
| 155 |
+
|
| 156 |
+
def convert_ids_to_tokens(self, ids):
|
| 157 |
+
return [str(i) for i in ids]
|
| 158 |
+
|
| 159 |
+
def __len__(self):
|
| 160 |
+
return self.vocab_size
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ── Legacy shared tokenizer (kept for backward compat) ───────────────
|
| 164 |
+
|
| 165 |
+
class SanskritTokenizer:
|
| 166 |
+
"""
|
| 167 |
+
LEGACY: single shared tokenizer trained on BOTH scripts.
|
| 168 |
+
Still works but suboptimal — use SanskritSourceTokenizer +
|
| 169 |
+
SanskritTargetTokenizer for the quote_text → quote_devanagari task.
|
| 170 |
+
"""
|
| 171 |
+
MODEL_PATH = "sanskrit_tokenizer_m4pro.json"
|
| 172 |
+
|
| 173 |
+
def __init__(self, vocab_size=16000, max_len=80):
|
| 174 |
+
self.vocab_size = vocab_size
|
| 175 |
+
self.max_len = max_len
|
| 176 |
+
self.mask_token_id = 0
|
| 177 |
+
|
| 178 |
+
if Path(self.MODEL_PATH).exists():
|
| 179 |
+
print("📖 Loading shared tokenizer …")
|
| 180 |
+
self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
|
| 181 |
+
else:
|
| 182 |
+
print("🎓 Training shared tokenizer on both scripts …")
|
| 183 |
+
self._train(vocab_size)
|
| 184 |
+
|
| 185 |
+
_validate(self.tokenizer, "SharedTokenizer")
|
| 186 |
+
|
| 187 |
+
def _train(self, vocab_size):
|
| 188 |
+
dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
|
| 189 |
+
n = min(50000, len(dataset))
|
| 190 |
+
texts = []
|
| 191 |
+
for s in dataset.select(range(n)):
|
| 192 |
+
if s["quote_text"].strip():
|
| 193 |
+
texts.append(s["quote_text"])
|
| 194 |
+
if s["quote_devanagari"].strip():
|
| 195 |
+
texts.append(s["quote_devanagari"])
|
| 196 |
+
self.tokenizer = _build_bpe(texts, vocab_size)
|
| 197 |
+
self.tokenizer.save(self.MODEL_PATH)
|
| 198 |
+
print(f"✅ Shared tokenizer trained ({len(texts)} texts).")
|
| 199 |
+
|
| 200 |
+
def encode(self, text):
|
| 201 |
+
ids = self.tokenizer.encode(text).ids[:self.max_len]
|
| 202 |
+
pad = self.tokenizer.token_to_id("[PAD]")
|
| 203 |
+
ids += [pad] * max(0, self.max_len - len(ids))
|
| 204 |
+
return ids[:self.max_len]
|
| 205 |
+
|
| 206 |
+
def decode(self, ids):
|
| 207 |
+
if ids and isinstance(ids[0], list):
|
| 208 |
+
raise TypeError("decode() got 2D list — pass a 1D list.")
|
| 209 |
+
clean = [i for i in ids if i > 4]
|
| 210 |
+
return self.tokenizer.decode(clean)
|
| 211 |
+
|
| 212 |
+
def build_inputs_with_special_tokens(self, token_ids):
|
| 213 |
+
return list(token_ids)
|
| 214 |
+
|
| 215 |
+
def get_vocab(self):
|
| 216 |
+
return {str(i): i for i in range(self.vocab_size)}
|
| 217 |
+
|
| 218 |
+
def convert_ids_to_tokens(self, ids):
|
| 219 |
+
return [str(i) for i in ids]
|
| 220 |
+
|
| 221 |
+
def __len__(self):
|
| 222 |
+
return self.vocab_size
|
model/tokenizers.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tokenizer.py — FINAL
|
| 3 |
+
=====================
|
| 4 |
+
Uses the original sanskrit_tokenizer_m4pro.json — the exact one the model
|
| 5 |
+
was trained with. Hard-coded absolute path as primary, with fallbacks.
|
| 6 |
+
|
| 7 |
+
This tokenizer has NO </w> end-of-word markers and NO decoder set.
|
| 8 |
+
decode() returns space-separated BPE pieces — this is the format the
|
| 9 |
+
model was trained and evaluated on (BERTScore 0.71). Do NOT add a decoder
|
| 10 |
+
or retrain: that would break alignment with the checkpoint.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from tokenizers import Tokenizer
|
| 14 |
+
from tokenizers.models import BPE
|
| 15 |
+
from tokenizers.trainers import BpeTrainer
|
| 16 |
+
from tokenizers.pre_tokenizers import Whitespace
|
| 17 |
+
from datasets import load_dataset
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
# Hard-coded absolute path — update if you move the project
|
| 22 |
+
TOKENIZER_PATH = "/Users/bhsingh/Documents/Final_Paraphrase/sanskrit_tokenizer_m4pro.json"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def build_tokenizer(texts, vocab_size=16000):
|
| 26 |
+
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
|
| 27 |
+
tokenizer.pre_tokenizer = Whitespace()
|
| 28 |
+
trainer = BpeTrainer(
|
| 29 |
+
vocab_size=vocab_size,
|
| 30 |
+
special_tokens=["[MASK]", "[PAD]", "[UNK]", "[CLS]", "[SEP]"],
|
| 31 |
+
min_frequency=2,
|
| 32 |
+
)
|
| 33 |
+
tokenizer.train_from_iterator(texts, trainer)
|
| 34 |
+
return tokenizer
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SanskritTokenizer:
|
| 38 |
+
def __init__(self, vocab_size=16000, max_len=80):
|
| 39 |
+
self.vocab_size = vocab_size
|
| 40 |
+
self.max_len = max_len
|
| 41 |
+
self.mask_token_id = 0
|
| 42 |
+
|
| 43 |
+
script_dir = Path(__file__).resolve().parent
|
| 44 |
+
candidates = [
|
| 45 |
+
os.environ.get("SANSKRIT_TOKENIZER_PATH", ""),
|
| 46 |
+
TOKENIZER_PATH,
|
| 47 |
+
str(script_dir.parent / "sanskrit_tokenizer_m4pro.json"),
|
| 48 |
+
str(script_dir / "sanskrit_tokenizer_m4pro.json"),
|
| 49 |
+
str(Path.cwd() / "sanskrit_tokenizer_m4pro.json"),
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
self.model_path = None
|
| 53 |
+
for c in candidates:
|
| 54 |
+
if c and Path(c).exists():
|
| 55 |
+
self.model_path = c
|
| 56 |
+
break
|
| 57 |
+
|
| 58 |
+
if self.model_path:
|
| 59 |
+
print(f"📖 Loading tokenizer from: {self.model_path}")
|
| 60 |
+
self.tokenizer = Tokenizer.from_file(self.model_path)
|
| 61 |
+
self._validate_mask_token()
|
| 62 |
+
else:
|
| 63 |
+
print(f"⚠️ Tokenizer not found at any candidate path.")
|
| 64 |
+
print(f" Expected: {TOKENIZER_PATH}")
|
| 65 |
+
print(" Retraining — WARNING: output will not match existing checkpoint!")
|
| 66 |
+
self.model_path = TOKENIZER_PATH
|
| 67 |
+
self._train_tokenizer()
|
| 68 |
+
|
| 69 |
+
def _validate_mask_token(self):
|
| 70 |
+
mask_id = self.tokenizer.token_to_id("[MASK]")
|
| 71 |
+
assert mask_id == 0, f"[MASK] must be ID 0, got {mask_id}"
|
| 72 |
+
print("✅ [MASK] token confirmed at ID=0")
|
| 73 |
+
|
| 74 |
+
def _train_tokenizer(self):
|
| 75 |
+
dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
|
| 76 |
+
texts = []
|
| 77 |
+
for sample in dataset.select(range(50000)):
|
| 78 |
+
texts.extend([sample["quote_text"], sample["quote_devanagari"]])
|
| 79 |
+
tokenizer = build_tokenizer(texts, self.vocab_size)
|
| 80 |
+
tokenizer.save(self.model_path)
|
| 81 |
+
self.tokenizer = tokenizer
|
| 82 |
+
self._validate_mask_token()
|
| 83 |
+
print(f"✅ Tokenizer saved to: {self.model_path}")
|
| 84 |
+
|
| 85 |
+
def encode(self, text):
|
| 86 |
+
encoded = self.tokenizer.encode(text)
|
| 87 |
+
token_ids = encoded.ids[:self.max_len]
|
| 88 |
+
pad_id = self.tokenizer.token_to_id("[PAD]")
|
| 89 |
+
if len(token_ids) < self.max_len:
|
| 90 |
+
token_ids += [pad_id] * (self.max_len - len(token_ids))
|
| 91 |
+
return token_ids[:self.max_len]
|
| 92 |
+
|
| 93 |
+
def decode(self, ids):
|
| 94 |
+
if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
|
| 95 |
+
raise TypeError("decode() expects 1D list of IDs, not 2D.")
|
| 96 |
+
# Filter special tokens: 0=MASK 1=PAD 2=UNK 3=CLS 4=SEP
|
| 97 |
+
clean = [i for i in ids if isinstance(i, int) and i > 4]
|
| 98 |
+
if not clean:
|
| 99 |
+
return ""
|
| 100 |
+
return self.tokenizer.decode(clean, skip_special_tokens=True).strip()
|
| 101 |
+
|
| 102 |
+
def build_inputs_with_special_tokens(self, token_ids):
|
| 103 |
+
return list(token_ids)
|
| 104 |
+
|
| 105 |
+
def get_vocab(self):
|
| 106 |
+
return {str(i): i for i in range(self.vocab_size)}
|
| 107 |
+
|
| 108 |
+
def convert_ids_to_tokens(self, ids):
|
| 109 |
+
return [str(i) for i in ids]
|
| 110 |
+
|
| 111 |
+
def __len__(self):
|
| 112 |
+
return self.vocab_size
|
model_settings.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "d3pm_encoder_decoder",
|
| 3 |
+
"include_negative_examples": false,
|
| 4 |
+
"num_steps": 4
|
| 5 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.2
|
| 2 |
+
numpy>=1.24
|
| 3 |
+
tqdm>=4.66
|
| 4 |
+
datasets>=2.19
|
| 5 |
+
tokenizers>=0.15
|
| 6 |
+
scikit-learn>=1.3
|
sanskrit_src_tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sanskrit_tgt_tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|