Upload folder using huggingface_hub
Browse files- .gitattributes +2 -34
- .gitignore +2 -0
- README.md +122 -0
- analysis_reports/outputs_all_models_20260325/T16/task1_kv_cache.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T16/task2_report.txt +35 -0
- analysis_reports/outputs_all_models_20260325/T16/task3_report.txt +21 -0
- analysis_reports/outputs_all_models_20260325/T16/task4_report.txt +14 -0
- analysis_reports/outputs_all_models_20260325/T16/task5_report.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T32/task1_kv_cache.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T32/task2_report.txt +35 -0
- analysis_reports/outputs_all_models_20260325/T32/task3_report.txt +21 -0
- analysis_reports/outputs_all_models_20260325/T32/task4_report.txt +14 -0
- analysis_reports/outputs_all_models_20260325/T32/task5_report.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T4/task1_kv_cache.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T4/task2_report.txt +29 -0
- analysis_reports/outputs_all_models_20260325/T4/task3_report.txt +21 -0
- analysis_reports/outputs_all_models_20260325/T4/task4_report.txt +14 -0
- analysis_reports/outputs_all_models_20260325/T4/task5_report.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T64/task1_kv_cache.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T64/task2_report.txt +35 -0
- analysis_reports/outputs_all_models_20260325/T64/task3_report.txt +21 -0
- analysis_reports/outputs_all_models_20260325/T64/task4_report.txt +14 -0
- analysis_reports/outputs_all_models_20260325/T64/task5_report.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T8/task1_kv_cache.txt +15 -0
- analysis_reports/outputs_all_models_20260325/T8/task2_report.txt +33 -0
- analysis_reports/outputs_all_models_20260325/T8/task3_report.txt +21 -0
- analysis_reports/outputs_all_models_20260325/T8/task4_report.txt +14 -0
- analysis_reports/outputs_all_models_20260325/T8/task5_report.txt +15 -0
- config.py +33 -0
- diffusion/__init__.py +0 -0
- diffusion/forward_process.py +21 -0
- diffusion/reverse_process.py +302 -0
- diffusion/reverse_process1.py +154 -0
- diffusion/reverse_process2.py +275 -0
- diffusion/scheduler.py +34 -0
- handler.py +30 -0
- inference.py +554 -0
- inference_api.py +131 -0
- model/__init__.py +0 -0
- model/d3pm_model_cross_attention.py +271 -0
- model/d3pm_model_encoder_decoder.py +227 -0
- model/sanskrit_model.py +61 -0
- model/tokenizer.py +222 -0
- model/tokenizers.py +112 -0
- model_settings.json +5 -0
- requirements.txt +6 -0
- sanskrit_src_tokenizer.json +0 -0
- sanskrit_tgt_tokenizer.json +0 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,3 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.
|
| 24 |
-
*.
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
README.md
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- sa
|
| 5 |
+
- en
|
| 6 |
+
tags:
|
| 7 |
+
- sanskrit
|
| 8 |
+
- paraphrase
|
| 9 |
+
- diffusion
|
| 10 |
+
- d3pm
|
| 11 |
+
- pytorch
|
| 12 |
+
pipeline_tag: text2text-generation
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# Sanskrit D3PM Encoder-Decoder Model
|
| 16 |
+
|
| 17 |
+
Roman/IAST Sanskrit input to Devanagari output using a custom D3PM checkpoint.
|
| 18 |
+
This package is configured for the `d3pm_encoder_decoder` checkpoint stored in
|
| 19 |
+
`best_model.pt`.
|
| 20 |
+
Hugging Face model repo: `bhsinghgrid/devflow2`
|
| 21 |
+
|
| 22 |
+
## Files Included
|
| 23 |
+
|
| 24 |
+
- `best_model.pt` โ trained checkpoint
|
| 25 |
+
- `model_settings.json` โ packaged runtime metadata
|
| 26 |
+
- `config.py` โ runtime config
|
| 27 |
+
- `inference.py` โ model loading + generation loop
|
| 28 |
+
- `inference_api.py` โ simple Python API (`predict`)
|
| 29 |
+
- `handler.py` โ Hugging Face Endpoint handler
|
| 30 |
+
- `model/`, `diffusion/` โ architecture modules
|
| 31 |
+
- `sanskrit_src_tokenizer.json`, `sanskrit_tgt_tokenizer.json` โ tokenizers
|
| 32 |
+
|
| 33 |
+
## Quick Local Test
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
from inference_api import predict
|
| 37 |
+
print(predict("dharmo rakแนฃati rakแนฃitaแธฅ")["output"])
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## Runtime Settings
|
| 41 |
+
|
| 42 |
+
For local/API usage, the runtime first reads `model_settings.json`, then allows
|
| 43 |
+
optional environment variable overrides:
|
| 44 |
+
|
| 45 |
+
- `HF_MODEL_TYPE` = `d3pm_cross_attention` or `d3pm_encoder_decoder`
|
| 46 |
+
- `HF_INCLUDE_NEG` = `true` or `false`
|
| 47 |
+
- `HF_NUM_STEPS` = diffusion step count for the packaged checkpoint
|
| 48 |
+
|
| 49 |
+
Packaged settings for this repo:
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
export HF_MODEL_TYPE=d3pm_encoder_decoder
|
| 53 |
+
export HF_INCLUDE_NEG=false
|
| 54 |
+
export HF_NUM_STEPS=4
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## Use This Model In A Hugging Face Space
|
| 58 |
+
|
| 59 |
+
In your Space settings, set:
|
| 60 |
+
|
| 61 |
+
- `HF_CHECKPOINT_REPO=bhsinghgrid/devflow2`
|
| 62 |
+
- `HF_CHECKPOINT_FILE=best_model.pt`
|
| 63 |
+
|
| 64 |
+
If your Space reads model metadata automatically, no extra model-type variables
|
| 65 |
+
are required. If it does not, also set:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
HF_DEFAULT_MODEL_TYPE=d3pm_encoder_decoder
|
| 69 |
+
HF_DEFAULT_INCLUDE_NEG=false
|
| 70 |
+
HF_DEFAULT_NUM_STEPS=4
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
## Transformer-Style Usage (Custom Runtime)
|
| 74 |
+
|
| 75 |
+
This checkpoint is a custom D3PM architecture (`.pt`), not a native `transformers`
|
| 76 |
+
`AutoModel` format. Use it via the provided runtime:
|
| 77 |
+
|
| 78 |
+
```python
|
| 79 |
+
import torch
|
| 80 |
+
from config import CONFIG
|
| 81 |
+
from inference import load_model, run_inference, _decode_clean
|
| 82 |
+
from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
|
| 83 |
+
|
| 84 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 85 |
+
model, cfg = load_model("best_model.pt", CONFIG, device)
|
| 86 |
+
|
| 87 |
+
src_tok = SanskritSourceTokenizer(vocab_size=16000, max_len=cfg["model"]["max_seq_len"])
|
| 88 |
+
tgt_tok = SanskritTargetTokenizer(vocab_size=16000, max_len=cfg["model"]["max_seq_len"])
|
| 89 |
+
|
| 90 |
+
text = "dharmo rakแนฃati rakแนฃitaแธฅ"
|
| 91 |
+
ids = torch.tensor([src_tok.encode(text)], dtype=torch.long, device=device)
|
| 92 |
+
out = run_inference(model, ids, cfg)
|
| 93 |
+
print(_decode_clean(tgt_tok, out[0].tolist()))
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
If you need full `transformers` compatibility (`AutoModel.from_pretrained`),
|
| 97 |
+
export weights to a Hugging Face Transformers model format first.
|
| 98 |
+
|
| 99 |
+
## Endpoint Payload
|
| 100 |
+
|
| 101 |
+
```json
|
| 102 |
+
{
|
| 103 |
+
"inputs": "yadฤ mano nivarteta viแนฃayebhyaแธฅ svabhฤvataแธฅ",
|
| 104 |
+
"parameters": {
|
| 105 |
+
"temperature": 0.7,
|
| 106 |
+
"top_k": 40,
|
| 107 |
+
"repetition_penalty": 1.2,
|
| 108 |
+
"diversity_penalty": 0.0,
|
| 109 |
+
"num_steps": 4,
|
| 110 |
+
"clean_output": true
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
## Push This Folder To Model Hub
|
| 116 |
+
|
| 117 |
+
```bash
|
| 118 |
+
cd hf_model_repo_encoder_decoder
|
| 119 |
+
git add .
|
| 120 |
+
git commit -m "Add encoder-decoder T4 model package"
|
| 121 |
+
git push -u hf main
|
| 122 |
+
```
|
analysis_reports/outputs_all_models_20260325/T16/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 โ KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
has_generate_cached=True
|
| 5 |
+
memory_profile=Torch CPU mem-event reduction: 30.4% @ src_len=64 (std=2143.0MB, cache=1492.1MB)
|
| 6 |
+
|
| 7 |
+
src_len standard(s) cached(s) speedup encoder%
|
| 8 |
+
16 0.893 0.571 1.56x 40.0%
|
| 9 |
+
32 0.751 0.509 1.48x 42.3%
|
| 10 |
+
64 1.141 0.822 1.39x 40.7%
|
| 11 |
+
|
| 12 |
+
Saved graphs:
|
| 13 |
+
- task1_time_comparison.png
|
| 14 |
+
- task1_speedup.png
|
| 15 |
+
- task1_encoder_cost.png
|
analysis_reports/outputs_all_models_20260325/T16/task2_report.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 2 โ ATTENTION + DRIFT REPORT
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Input : dharmo rakแนฃati rakแนฃitaแธฅ
|
| 5 |
+
Output: เคงเคฐเฅเคฎเฅ เคฐเคเฅเคทเคคเคฟ เคฐเคเฅเคทเคฟเคคเค
|
| 6 |
+
|
| 7 |
+
Captured steps: 16
|
| 8 |
+
Analysis quality: WEAK
|
| 9 |
+
Final output uniq-ratio: 1.000
|
| 10 |
+
Degenerate output: NO
|
| 11 |
+
Multi-sample semantic score (n<=8): 0.1471
|
| 12 |
+
Lock-in step (CER<=0.05): t=0
|
| 13 |
+
Locked tokens: 38 Flexible tokens: 42
|
| 14 |
+
TF-IDF vs attention stability corr: 0.9294
|
| 15 |
+
TF-IDF status: OK
|
| 16 |
+
|
| 17 |
+
Saved graphs:
|
| 18 |
+
- task2_attn_t*.png / task2_all_layers_t0.png
|
| 19 |
+
- task2_attn_evolution.png
|
| 20 |
+
- task2_semantic_drift.png
|
| 21 |
+
- task2_source_alignment.png
|
| 22 |
+
- task2_tfidf_vs_attention.png
|
| 23 |
+
|
| 24 |
+
Step trajectory (first 10 rows)
|
| 25 |
+
------------------------------------------------------------
|
| 26 |
+
t= 15 bert=0.0475 drift=0.9525 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคฐเคเฅเคท เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคฟเคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 27 |
+
t= 14 bert=0.0478 drift=0.9522 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคฐเคเฅเคท เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคฟเคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 28 |
+
t= 13 bert=0.0478 drift=0.9522 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคฐเคเฅเคท เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคฟเคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 29 |
+
t= 12 bert=0.0478 drift=0.9522 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคฐเคเฅเคท เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคฟเคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 30 |
+
t= 11 bert=0.0478 drift=0.9522 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคฐเคเฅเคท เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคฟเคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 31 |
+
t= 10 bert=0.0478 drift=0.9522 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคฐเคเฅเคท เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคฟเคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 32 |
+
t= 9 bert=0.0478 drift=0.9522 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคฐเคเฅเคท เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคฟเคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 33 |
+
t= 8 bert=0.0478 drift=0.9522 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคฐเคเฅเคท เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคฟเคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 34 |
+
t= 7 bert=0.0478 drift=0.9522 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคฐเคเฅเคท เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคฟเคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 35 |
+
t= 6 bert=0.0478 drift=0.9522 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคฐเคเฅเคท เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคฟเคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
analysis_reports/outputs_all_models_20260325/T16/task3_report.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 3 โ CONCEPT VECTORS + PCA STEERING
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
PCA: 50 components, 74.8% variance
|
| 5 |
+
Diversity PC: 0 (|r|=0.325 with diversity proxy)
|
| 6 |
+
|
| 7 |
+
Direction validity: WEAK
|
| 8 |
+
Spectrum unique ratio (mean over 5 seeds): 1.000
|
| 9 |
+
Spectrum semantic stability (mean over 5 seeds): 0.312
|
| 10 |
+
|
| 11 |
+
Saved graphs:
|
| 12 |
+
- task3_concept_space.png
|
| 13 |
+
- task3_pca_explained_variance.png
|
| 14 |
+
- task3_diversity_curve.png
|
| 15 |
+
|
| 16 |
+
Diversity spectrum:
|
| 17 |
+
alpha=-2.0 โ เคฌเคฒเฅ เคตเฅเคง เคตเคฟเคตเคฐเฅ เคงเคพเคจ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคงเคฟเค เคธเคฟเคเคนเคพ เคญเคฟฬฑ เคธเคจ เคตเคธเฅเคคเฅ เคตเฅเคง เคตเฅ เคตเฅเคง เคตเคธเฅเคคเฅ เคธเคจ เคธเคจ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเคธเฅเคคเฅ เคธเคจ เคฐเฅเคคเฅ เคชเฅเคฐเคญเคตเคคเคฟ เคฎเคจ เคตเฅเคง เคฌเคฒเฅ เคฌเคฒเฅ เคฐเฅเคตเฅ เคชเฅเคฐเคชเฅเคเคฏเฅเคคเฅ เคฏเฅเคเคพ เคฎเคฒเคฟ เคงเคพเคจ เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคงเคพเคจ เคคเฅเคฒ เคเคพเคฒเฅเคจ เคฏเฅเคเคพ เคตเฅเคง เคฌเคฒเฅ เคตเฅเคง เคตเฅเคง เคเฅเคเฅ เคทเฅเคฎเคธเฅ เคฏเคธเฅเคฏเคพ เคเคพเคทเฅเค เคพ เคเฅเคเคชเฅเคค เค
เคฐเฅเคฃเคต เคงเคฟเค เคงเคฟเค เคตเคธเฅเคคเฅ เคงเคฟเค เคธเคจ เคคเคฏเคพ เคธเคจ เคธเคจ เคฆเฅเคตเคพเค เคฆเฅเคตเคพเค เคธเฅเคตเคพเคคเคจเฅเคคเฅเคฐ เค
เคฐเฅเคฃเคต เคฎเคน เคตเคธเฅเคคเฅ เคฎเฅเคทเฅ เคธเคจ เคงเคฟเค เคงเคฟเค เคงเคฟเค เคตเคฟเคเฅเคฐ เคคเฅเคฐ เคฎเคน เคนเคธเฅเคคเฅ เคเฅเคเฅ เคฎเคน
|
| 18 |
+
alpha=-1.0 โ เคฌเคฒเฅ เคฐเฅ เค
เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคเฅเคฐเฅ เคธเคฟเคเคนเคพ เคธเคจ เคธเคจ เคตเคฟเคฒเฅเคช เคตเฅ เคตเฅ เคตเฅ เคเคคเคธเฅเคฏ เคตเฅเคง เคธเคจ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคธเฅเคฏ เคธเฅเคฏ เฅค เคธเคจ เคตเฅ เคตเฅ เคตเฅ เคฌเคฒเฅ เคฌเคฒเฅ เคฌเคฒเฅ เคฌเคฒเฅ เคฐเฅ เค
เค
เคคเฅเคฒ เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคคเฅเคฒ เคคเฅเคฒ เคคเฅเคฒ เฅ เคฌเคฒเฅ เคตเฅเคง เคฆเคฟเคตเฅเคฏเคพเค เคฎเคพเคจ เคตเฅ เค
เคชเฅเคธเฅ เคธเคจ เฅฅ เฅฅ เคตเคธเฅเคคเฅ เคธเคฟเคเคนเคพ เคธเคจ เคธเคจ เคตเคฟเคเฅเคฐ เคธเคจ เคธ เคเคพเคทเฅเค เคพ เคธเคจ เคธเคจ เคธเคจ เคเคพเคฐ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคญ เคฌเคฒ เฅ เคธเคฟเคเคนเคพ เคธเคจ เคธเคฟเคเคนเคพ เคธเคจ เคฎเฅ เคฎเฅ เคธเคจ
|
| 19 |
+
alpha=+0.0 โ เคฌเคฒเฅ เคฐเฅ เค
เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเคจ เคธเคจ เคชเคฟเคคเฅ เคตเฅ เคตเฅ เคตเฅ เคฆเคเฅเคทเคฟเคฃเคพเค เคธเคจ เคธเคจ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคธเคจ เคเคคเคธเฅเคฏ เคตเฅ เคตเฅ เฅฅ เคฌเคฒเฅ เคฌเคฒเฅ เคฐเฅ เคฐเฅ เค
เค
เฅค เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคคเฅเคฒ เคคเฅเคฒ เคคเฅเคฒ เคคเฅเคฒ เค
เคธ เคฌเคฒเฅ เคฌเคฒเฅ เคตเฅ เคตเฅ เฅฅ เฅฅ เฅฅ เคธเคจ เคธเคจ เคธเคฟเคเคนเคพ เคธ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เฅฅ เฅฅ เคธเคจ เคธเคจ เคถเคคเฅเค เฅฅ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคฆ เคธเคฟเคเคนเคพ เคธเคจ เคคเฅ เคธเคจ
|
| 20 |
+
alpha=+1.0 โ เคฌเคฒเฅ เคฐเฅ เค
เค
เคตเคฟเคถเฅเคฆเฅเคงเค เคธเฅเคฏ เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคธเคจ เคเคคเคธเฅเคฏ เคตเฅ เคตเฅ เคตเฅ เคตเฅเคคเฅเคคเคฟ เคธเคจ เคธเคจ เคธเคฟเคเคนเคพ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคธเคจ เคตเฅ เคตเฅ เคธ เคฎเคฒ เคฌเคฒเฅ เคฌเคฒเฅ เคฐเฅ เคฐเฅ เคต เค
เค
เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคธเฅเคฏ เคตเฅเคฐเฅเคฏ เคธเฅเคฏ เคคเฅเคฒ เคพเคจเฅ เค
เค
เฅค เคฐเฅ เคต เฅฅ เคตเฅ เคตเฅ เคธเคจ เคฆ เฅฅ เฅฅ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เฅฅ เคธเค เคธเคจ เฅฅ เฅฅ เคต เฅฅ เฅฅ เคนเฅเคฎ เคธเคจ เคธเคจ เคต เฅฅ เฅ เฅฅ เคตเฅ เคญ เคจ เคจ เฅฅ เคฎเคฟเคคเฅเคฐเฅ เคธเคฟเคเคนเคพ เคธเคจ
|
| 21 |
+
alpha=+2.0 โ เคเคตเคฟเคถ เคฐเฅ เค
เคเคฟเคเคเคฟเคฆเฅ เคตเคฐ เคธเฅเคฏ เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเค เคจเคฟเคฎเฅ เคเฅ เคธเค เคตเฅ เคตเฅ เคเฅ เคธเคจ เคเฅเคชเคพ เคธเคฟเคเคนเคพ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคซเคฃเคพ เคเฅ เคตเฅ เฅ เคเคฟเคนเฅเคต เคฌเคฒเฅ เคฎเคพเคจเคพเค เคฐเฅ เคฐเฅ เคตเคฐเคพเคฏ เค
เคฎเคพเคจเฅ เคตเคฐ เคตเคฟเคถเฅเคฆเฅเคงเค เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ โ เคตเคฐ เคตเคฟเคถเฅเคฆเฅเคงเค เคต เคตเคฐ เค
เคเฅเคชเคพ เฅฅ เคชเคฐเคฎเฅ เฅฅ เคเคถเฅเคเคฟ เคตเฅ เฅฅ เคเฅ เคเฅ เคธเค เคธเฅเคฏ เคธเฅเคฏ เคคเคฎเฅ เคต เคชเฅเคฐเคตเคฐเฅเคคเคจเฅเคคเฅ เคเคฐเฅเคฎเคธเฅ เคชเคฐเคฎเฅ เคตเคฐ เคคเฅ เฅฅ เคต เคเฅ เฅฅ เฅฅ เคธเค เคฆ เฅฅ เฅฅ เคตเคฐ เคจเฅเคฆ ฬฑเคต เฅฅ เคต เคต เฅ
|
analysis_reports/outputs_all_models_20260325/T16/task4_report.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 4 โ SEMANTIC ROBUSTNESS ABLATION
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Optimal diffusion steps = 16
|
| 5 |
+
|
| 6 |
+
T BERT-F1 SEM_SIM BLEU sec/sample
|
| 7 |
+
--------------------------------------------------------
|
| 8 |
+
16 0.2574 0.0580 0.0007 0.9068
|
| 9 |
+
|
| 10 |
+
Marginal gains (BERT-F1):
|
| 11 |
+
|
| 12 |
+
Saved plots/files:
|
| 13 |
+
- task4_3d.png
|
| 14 |
+
- task4_raw_results.json
|
analysis_reports/outputs_all_models_20260325/T16/task5_report.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 5 โ CLASSIFIER-FREE GUIDANCE
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Classifier params: 139521
|
| 5 |
+
Training samples : 40
|
| 6 |
+
|
| 7 |
+
Guidance scale sweep:
|
| 8 |
+
ฮป CER diversity d2 sBLEU
|
| 9 |
+
----------------------------------------------------
|
| 10 |
+
0.0 0.8336 0.808 0.624 0.007 โ optimal
|
| 11 |
+
0.5 0.8362 0.800 0.606 0.007
|
| 12 |
+
1.0 0.8390 0.798 0.601 0.005
|
| 13 |
+
1.5 0.8458 0.813 0.631 0.004
|
| 14 |
+
2.0 0.8531 0.828 0.660 0.004
|
| 15 |
+
3.0 0.8773 0.830 0.669 0.009
|
analysis_reports/outputs_all_models_20260325/T32/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 โ KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
has_generate_cached=True
|
| 5 |
+
memory_profile=Torch CPU mem-event reduction: 31.1% @ src_len=64 (std=4287.2MB, cache=2953.9MB)
|
| 6 |
+
|
| 7 |
+
src_len standard(s) cached(s) speedup encoder%
|
| 8 |
+
16 1.914 1.165 1.64x 39.6%
|
| 9 |
+
32 1.542 0.891 1.73x 42.1%
|
| 10 |
+
64 2.096 1.475 1.42x 42.7%
|
| 11 |
+
|
| 12 |
+
Saved graphs:
|
| 13 |
+
- task1_time_comparison.png
|
| 14 |
+
- task1_speedup.png
|
| 15 |
+
- task1_encoder_cost.png
|
analysis_reports/outputs_all_models_20260325/T32/task2_report.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 2 โ ATTENTION + DRIFT REPORT
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Input : dharmo rakแนฃati rakแนฃitaแธฅ
|
| 5 |
+
Output: เคงเคฐเฅเคฎเฅ เคฐเคเฅเคทเคคเคฟ เคฐเคเฅเคทเคฟเคคเค
|
| 6 |
+
|
| 7 |
+
Captured steps: 32
|
| 8 |
+
Analysis quality: WEAK
|
| 9 |
+
Final output uniq-ratio: 1.000
|
| 10 |
+
Degenerate output: NO
|
| 11 |
+
Multi-sample semantic score (n<=8): 0.0627
|
| 12 |
+
Lock-in step (CER<=0.05): t=0
|
| 13 |
+
Locked tokens: 75 Flexible tokens: 5
|
| 14 |
+
TF-IDF vs attention stability corr: -0.0869
|
| 15 |
+
TF-IDF status: WEAK
|
| 16 |
+
|
| 17 |
+
Saved graphs:
|
| 18 |
+
- task2_attn_t*.png / task2_all_layers_t0.png
|
| 19 |
+
- task2_attn_evolution.png
|
| 20 |
+
- task2_semantic_drift.png
|
| 21 |
+
- task2_source_alignment.png
|
| 22 |
+
- task2_tfidf_vs_attention.png
|
| 23 |
+
|
| 24 |
+
Step trajectory (first 10 rows)
|
| 25 |
+
------------------------------------------------------------
|
| 26 |
+
t= 31 bert=0.0167 drift=0.9833 text=เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ
|
| 27 |
+
t= 30 bert=0.0167 drift=0.9833 text=เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ
|
| 28 |
+
t= 29 bert=0.0167 drift=0.9833 text=เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ
|
| 29 |
+
t= 28 bert=0.0167 drift=0.9833 text=เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ
|
| 30 |
+
t= 27 bert=0.0167 drift=0.9833 text=เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ
|
| 31 |
+
t= 26 bert=0.0167 drift=0.9833 text=เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ
|
| 32 |
+
t= 25 bert=0.0167 drift=0.9833 text=เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ
|
| 33 |
+
t= 24 bert=0.0167 drift=0.9833 text=เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ
|
| 34 |
+
t= 23 bert=0.0167 drift=0.9833 text=เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ
|
| 35 |
+
t= 22 bert=0.0167 drift=0.9833 text=เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ เคคเฅ
|
analysis_reports/outputs_all_models_20260325/T32/task3_report.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 3 โ CONCEPT VECTORS + PCA STEERING
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
PCA: 50 components, 94.6% variance
|
| 5 |
+
Diversity PC: 0 (|r|=-0.530 with diversity proxy)
|
| 6 |
+
|
| 7 |
+
Direction validity: WEAK
|
| 8 |
+
Spectrum unique ratio (mean over 5 seeds): 0.840
|
| 9 |
+
Spectrum semantic stability (mean over 5 seeds): 0.234
|
| 10 |
+
|
| 11 |
+
Saved graphs:
|
| 12 |
+
- task3_concept_space.png
|
| 13 |
+
- task3_pca_explained_variance.png
|
| 14 |
+
- task3_diversity_curve.png
|
| 15 |
+
|
| 16 |
+
Diversity spectrum:
|
| 17 |
+
alpha=-2.0 โ เฅเคจ เคถเฅเคฐเฅ เคถเฅเคฐเฅ เฅเคจ เคถเฅเคฐเฅ เค
เคฃเฅเคก เคตเฅเคฏเคพเค เคถเฅเคฐเฅ เคคเคจเฅเคคเฅเคฐเคพ เฅฅ เฅฅ เฅฅ เคตเฅเคฏเคพเค เคตเฅเคฏเคพเค เคตเฅเคฏเคพเค เคคเคฆเฅเคตเคฆเฅ เคคเคฆเฅเคตเคฆเฅ เคคเคฆเฅเคตเคฆเฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เคคเคฆเฅเคตเคฆเฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เคตเฅเคฏเคพเค เคตเฅเคฏเคพเค เคตเฅเคฏเคพเค เฅฅ เฅฅ เคฐเคพเคเคจเฅเคฏ เคตเฅเคฏเคพเค เคตเฅเคฏเคพเค เคตเฅเคฏเคพเค เฅฅ เคตเฅเคฏเคพเค เคตเฅเคฏเคพเค เฅฅ เฅฅ เคเคพเคฎเฅเคฏ เฅฅ เฅฅ เฅฅ เคตเฅเคฏเคพเค เฅฅ เคคเคฆเฅเคตเคฆเฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เคคเคจเฅเคคเฅเคฐเคพ เคคเคจเฅเคคเฅเคฐเคพ เฅฅ เฅฅ เฅฅ เฅฅ เคตเฅเคฏเคพเค เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เคฏเฅเคงเคฎเฅ เคคเคฆเฅเคตเคฆเฅ เคฏเฅเคงเคฎเฅ เฅฅ
|
| 18 |
+
alpha=-1.0 โ เคถเฅเคฐเฅ เคถเฅเคฐเฅ เคถเฅเคฐเฅ เฅเคจ เคถเฅเคฐเฅ เคถเฅเคฐเฅ เคถเฅเคฐเฅ เคถเฅเคฐเฅ เค
เคฃเฅเคก เคคเคจเฅเคคเฅเคฐเคพ เคตเฅเคฏเคพเค เฅฅ เค
เคฃเฅเคก เค
เคฃเฅเคก เคคเคจเฅเคคเฅเคฐเคพ เคตเฅเคฏเคพเค เคคเคฆเฅเคตเคฆเฅ เฅฅ เคตเฅเคฏเคพเค เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เค
เคฃเฅเคก เฅฅ เฅฅ เฅฅ เคตเฅเคฏเคพเค เฅฅ เคตเฅเคฏเคพเค เคจเฅฬ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เคตเฅเคฏเคพเค เคตเฅเคฏเคพเค เค
เคฃเฅเคก เฅฅ เฅฅ เคคเคจเฅเคคเฅเคฐเคพ เฅฅ เฅฅ เคคเคฆเฅเคตเคฆเฅ เคฏเฅเคงเคฎเฅ เคฐเฅเคฎเคพ เคถเคฎเฅเคญเฅ เฅฅ เคงเฅเคฎเค เคคเคจเฅเคคเฅเคฐเคพ เฅฅ เคคเคจเฅเคคเฅเคฐเคพ เฅฅ เคตเฅเคฏเคพเค เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ
|
| 19 |
+
alpha=+0.0 โ เค
เคฃเฅเคก เคถเฅเคฐเฅ เคเคฐเค เคถเฅเคฐเฅ เคคเคจเฅเคคเฅเคฐเคพ เคเคฐเค เคเคฐเค เคคเคจเฅเคคเฅเคฐเคพ เคถเฅเคฐเฅ เค
เคฃเฅเคก เค
เคฃเฅเคก เค
เคฃเฅเคก เฅฅ เคถเฅเคฐเฅ เคคเคฆเฅเคตเคฆเฅ เค
เคฃเฅเคก เฅฅ เฅฅ เค
เคฃเฅเคก เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เค
เคฃเฅเคก เฅฅ เฅฅ เฅฅ เฅฅ เค
เคฃเฅเคก เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เคฐเคพเคเคจเฅเคฏ เคคเคจเฅเคคเฅเคฐเคพ เคจเฅฬ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เคตเฅเคฏเคพเค เฅฅ เค
เคฃเฅเคก เฅฅ เคเคพเคฎเฅเคฏ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เคถเคฎเฅเคญเฅ เคงเฅเคฎเค เคคเคจเฅเคคเฅเคฐเคพ เคคเคจเฅเคคเฅเคฐเคพ เฅเคจ เฅฅ เคเคพเคฎเฅเคฏ เฅฅ เฅฅ เคเคฐเค เคคเคจเฅเคคเฅเคฐเคพ เฅฅ เค
เคฃเฅเคก เฅฅ เค
เคฃเฅเคก เฅฅ เคตเคฟเคจเคฟเคฐเฅเคเคฟเคคเฅเคฏ เฅฅ เฅฅ เฅฅ เคคเคจเฅเคคเฅเคฐเคพ เค
เคฃเฅเคก เคคเคฆเฅเคตเคฆเฅ เคเคฐเค
|
| 20 |
+
alpha=+1.0 โ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ
|
| 21 |
+
alpha=+2.0 โ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ เคฎเคพเคฃ
|
analysis_reports/outputs_all_models_20260325/T32/task4_report.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 4 โ SEMANTIC ROBUSTNESS ABLATION
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Optimal diffusion steps = 32
|
| 5 |
+
|
| 6 |
+
T BERT-F1 SEM_SIM BLEU sec/sample
|
| 7 |
+
--------------------------------------------------------
|
| 8 |
+
32 0.0422 0.0012 0.0000 1.8451
|
| 9 |
+
|
| 10 |
+
Marginal gains (BERT-F1):
|
| 11 |
+
|
| 12 |
+
Saved plots/files:
|
| 13 |
+
- task4_3d.png
|
| 14 |
+
- task4_raw_results.json
|
analysis_reports/outputs_all_models_20260325/T32/task5_report.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 5 โ CLASSIFIER-FREE GUIDANCE
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Classifier params: 139521
|
| 5 |
+
Training samples : 40
|
| 6 |
+
|
| 7 |
+
Guidance scale sweep:
|
| 8 |
+
ฮป CER diversity d2 sBLEU
|
| 9 |
+
----------------------------------------------------
|
| 10 |
+
0.0 0.9357 0.239 0.011 0.533 โ optimal
|
| 11 |
+
0.5 0.9372 0.251 0.015 0.512
|
| 12 |
+
1.0 0.9467 0.164 0.018 0.690
|
| 13 |
+
1.5 0.9528 0.137 0.017 0.743
|
| 14 |
+
2.0 0.9525 0.144 0.013 0.725
|
| 15 |
+
3.0 0.9496 0.181 0.018 0.656
|
analysis_reports/outputs_all_models_20260325/T4/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 โ KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
has_generate_cached=True
|
| 5 |
+
memory_profile=Torch CPU mem-event reduction: 24.6% @ src_len=64 (std=525.8MB, cache=396.4MB)
|
| 6 |
+
|
| 7 |
+
src_len standard(s) cached(s) speedup encoder%
|
| 8 |
+
16 0.267 0.173 1.54x 43.2%
|
| 9 |
+
32 0.197 0.153 1.29x 40.7%
|
| 10 |
+
64 0.353 0.265 1.33x 42.0%
|
| 11 |
+
|
| 12 |
+
Saved graphs:
|
| 13 |
+
- task1_time_comparison.png
|
| 14 |
+
- task1_speedup.png
|
| 15 |
+
- task1_encoder_cost.png
|
analysis_reports/outputs_all_models_20260325/T4/task2_report.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 2 โ ATTENTION + DRIFT REPORT
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Input : dharmo rakแนฃati rakแนฃitaแธฅ
|
| 5 |
+
Output: เคงเคฐเฅเคฎเฅ เคฐเคเฅเคทเคคเคฟ เคฐเคเฅเคทเคฟเคคเค
|
| 6 |
+
|
| 7 |
+
Captured steps: 4
|
| 8 |
+
Analysis quality: VALID
|
| 9 |
+
Final output uniq-ratio: 1.000
|
| 10 |
+
Degenerate output: NO
|
| 11 |
+
Multi-sample semantic score (n<=8): 0.1568
|
| 12 |
+
Lock-in step (CER<=0.05): t=0
|
| 13 |
+
Locked tokens: 79 Flexible tokens: 1
|
| 14 |
+
TF-IDF vs attention stability corr: 0.9472
|
| 15 |
+
TF-IDF status: OK
|
| 16 |
+
|
| 17 |
+
Saved graphs:
|
| 18 |
+
- task2_attn_t*.png / task2_all_layers_t0.png
|
| 19 |
+
- task2_attn_evolution.png
|
| 20 |
+
- task2_semantic_drift.png
|
| 21 |
+
- task2_source_alignment.png
|
| 22 |
+
- task2_tfidf_vs_attention.png
|
| 23 |
+
|
| 24 |
+
Step trajectory (first 10 rows)
|
| 25 |
+
------------------------------------------------------------
|
| 26 |
+
t= 3 bert=0.0603 drift=0.9397 text=เคคเคฟ เคคเคฟ เคคเคฟ เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ
|
| 27 |
+
t= 2 bert=0.0597 drift=0.9403 text=เคคเคฟ เคคเคฟ เคคเคฟ เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ
|
| 28 |
+
t= 1 bert=0.0597 drift=0.9403 text=เคคเคฟ เคคเคฟ เคคเคฟ เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ
|
| 29 |
+
t= 0 bert=0.0597 drift=0.9403 text=เคคเคฟ เคคเคฟ เคคเคฟ เคฐเคเฅเคทเคฟ เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ
|
analysis_reports/outputs_all_models_20260325/T4/task3_report.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 3 โ CONCEPT VECTORS + PCA STEERING
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
PCA: 50 components, 72.0% variance
|
| 5 |
+
Diversity PC: 0 (|r|=-0.349 with diversity proxy)
|
| 6 |
+
|
| 7 |
+
Direction validity: WEAK
|
| 8 |
+
Spectrum unique ratio (mean over 5 seeds): 1.000
|
| 9 |
+
Spectrum semantic stability (mean over 5 seeds): 0.325
|
| 10 |
+
|
| 11 |
+
Saved graphs:
|
| 12 |
+
- task3_concept_space.png
|
| 13 |
+
- task3_pca_explained_variance.png
|
| 14 |
+
- task3_diversity_curve.png
|
| 15 |
+
|
| 16 |
+
Diversity spectrum:
|
| 17 |
+
alpha=-2.0 โ เคฌเคฒเฅ เคฐเฅ เค
เคชเคถเฅเคฏ เคฏเฅเคนเคฟ เค เคตเฅเคฐเฅเคฏ เค เคธเคฟเคเคนเคพ เคธเคจ เคธเคจ ฬเคคฬฑ เคเฅเคเฅเคตเคพ เคฎเคพเคฎเฅ เคตเฅ เคตเฅ เคฎเคนเคฐเฅเคฆเฅเคงเคฟ เคฎเคนเคฐเฅเคฆเฅเคงเคฟ เค เคธเคฟเคเคนเคพ เคเฅ เคฆเคฟเคเฅเคทเฅ เค เคฆเคถเฅเคฏ เคตเฅ เคเฅเคฐเคฎเค เคฌเคฒเฅ เคฐเฅ เคฆเคถเฅเคฏ เคธเฅเคตเคธเฅเคฅ เคคเฅเคฒ เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅ เค เคธเคฟเคเคนเคพ เคฐเคพเค เคเฅ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เค เคตเฅ เคจเคฟเคฐเฅเคฆเฅเคงเคพ ฬเคคฬฑ เคฌเคฒเฅ เคฌเคฒเฅ เคธเคพเคงเฅเคต เคเคชเคถเคพเคจเฅเคค เคตเฅ เคตเฅ เคฆเคพเคเฅเคทเคฟ เคนเคคเค เคฎเคนเคฐเฅเคฆเฅเคงเคฟ เคธเคพเคงเฅเคต เคคเฅ เคตเฅ เคตเฅ เค เคฆเคฟเคเฅเคทเฅ เคฆเคฟเคเฅเคทเฅ เคชเฅเคท เคฎเคพเคฎเฅ เคชเฅเคฐเค เค เคฆเคฟเคเฅเคทเฅ เคตเฅ เคชเฅเคท ฬเคคฬฑ เฅเคฆเฅ เคฆเคฟเคเฅเคทเฅ เคชเฅเคฐเค เคธเฅเคคเฅเคฐเค เคฎเคจเฅเคฐเคฅ เค
เคธเฅเคฎเคพ เค เคตเคพเคนเคฟ เคฐเคพเคเคพเคจ เคตเฅ
|
| 18 |
+
alpha=-1.0 โ เคฌเคฒเฅ เคฌเคฒเฅ เค
เคคเฅเคฒ เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเคจ เคธเคจ เคเคคเคธเฅเคฏ เคเคคเคธเฅเคฏ เคตเฅ เคตเฅ เคตเฅ เคเคคเคธเฅเคฏ เคธเคจ เคชเคพเคคเคพ เคธเคฟเคเคนเคพ เคฆเคฟเคคเคพ เฅค เคเฅเคเฅเคตเคพ เคตเฅ เคตเฅ เคฌเคฒเฅ เคฌเคฒเฅ เคฐเฅ เค
เค
เคคเฅเคฒ เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคงเฅเคฐเคพ เคธเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคคเฅเคฒ เคคเฅเคฒ ฬเคคฬฑ เค
เคฐเฅ เคฐเฅ เคฌเคฒเฅ เคฆเคฟเคเฅเคทเฅ เคตเฅ เคตเฅ เฅค เคตเฅ เคธเคเคธเฅเคฅเคฟเคคเคพ เคฐเคคเค เคธเคจ เคเคคเคธเฅเคฏ เคชเฅเคท เฅค เคตเคเฅเคคเฅเคฐ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคต เคเคคเคธเฅเคฏ เคต เคธเคจ เฅฅ เคคเคฟ เคฎเคจเฅ เคนเคคเค เคฎเคพเคคเฅ ฬเคคฬฑ เคต เคเฅ เคเฅ เคธเคจ เคธเคจ
|
| 19 |
+
alpha=+0.0 โ เคฌเคฒเฅ เคฐเฅ เค
เคคเฅเคฒ เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเคจ เคธเคจ เคเคคเคธเฅเคฏ เคเคคเคธเฅเคฏ เคตเฅ เคตเฅ เคตเฅ เคเคคเคธเฅเคฏ เคธเคจ เคธเคจ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เฅค เคต เคตเฅ เคตเฅ เคคเฅเคตเคฎเฅ เคฌเคฒเฅ เคฐเฅ เค
เค
เคคเฅเคฒ เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเฅเคฏ เคธเฅเคฏ เคตเฅเคฐเฅเคฏ เคตเฅเคฐเฅเคฏ เคคเฅเคฒ เคคเฅเคฒ เคคเฅเคฒ เค
เคฐเฅ เคฐเฅ เคฐเฅ เคคเฅเคคเฅ เคตเฅ เคตเฅ เคเคคเคธเฅเคฏ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคเคคเคธเฅเคฏ เคจเคฟเคเคธเฅ เคเคคเคธเฅเคฏ เคธเคจ เคเคคเคธเฅเคฏ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคตเคฟ เคธเคจ เคตเคฟ เคธเฅเคฐเคต เคธเคฟเคเคนเคพ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคเคคเคธเฅเคฏ
|
| 20 |
+
alpha=+1.0 โ เคฌเคฒเฅ เคฐเฅ เค
เค
เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเคจ เคธเคจ เคเคคเคธเฅเคฏ เคเคคเคธเฅเคฏ เคตเฅ เคตเฅ เคตเฅ เคเคคเคธเฅเคฏ เคธเคจ เคธเคจ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคทเคฃเฅ เคธเฅเคฏ เฅ เคตเฅ เคฌเคฒเฅ เคฌเคฒเฅ เคฐเฅ เค
เค
เคคเฅเคฒ เคเคพเคจเฅเคคเฅ เคทเคฃเฅ เคตเฅเคฐเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคตเฅเคฐเฅเคฏ เคทเคฃเฅ เคตเฅเคฐเฅเคฏ เคทเคฃเฅ เค
เค
เคฐเฅ เคฐเฅ เคฐเฅ เคทเคฃเฅ เฅเคท เคเคคเคธเฅเคฏ เคเคคเคธเฅเคฏ เคเคคเคธเฅเคฏ เคเคคเคธเฅเคฏ เคธเคจ เคธเคจ เคทเคฃเฅ เคทเคฃเฅ เคเคคเคธเฅเคฏ เคธเคจ เคเคคเคธเฅเคฏ เคเคคเคธเฅเคฏ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคทเฅเคฃเฅ เคเคคเคธเฅเคฏ เคจเฅ เคทเคฃเฅ เคจเฅ - เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ
|
| 21 |
+
alpha=+2.0 โ เคทเคฃเฅ เคฐเฅ เค
เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคธเฅเคฏ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคตเฅ เคทเคฃเฅ เคทเคฃเฅ เคเคคเคธเฅเคฏ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคธ เคทเคฃเฅ เคทเคฃเฅ เคฐเฅ เคทเคฃเฅ เค
เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคธเฅเคฏ เคธเฅเคฏ เคทเคฃเฅ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เค
เคฐเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคเคคเคธเฅเคฏ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ เคทเคฃเฅ
|
analysis_reports/outputs_all_models_20260325/T4/task4_report.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 4 โ SEMANTIC ROBUSTNESS ABLATION
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Optimal diffusion steps = 4
|
| 5 |
+
|
| 6 |
+
T BERT-F1 SEM_SIM BLEU sec/sample
|
| 7 |
+
--------------------------------------------------------
|
| 8 |
+
4 0.2644 0.0574 0.0000 0.2782
|
| 9 |
+
|
| 10 |
+
Marginal gains (BERT-F1):
|
| 11 |
+
|
| 12 |
+
Saved plots/files:
|
| 13 |
+
- task4_3d.png
|
| 14 |
+
- task4_raw_results.json
|
analysis_reports/outputs_all_models_20260325/T4/task5_report.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 5 โ CLASSIFIER-FREE GUIDANCE
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Classifier params: 139521
|
| 5 |
+
Training samples : 40
|
| 6 |
+
|
| 7 |
+
Guidance scale sweep:
|
| 8 |
+
ฮป CER diversity d2 sBLEU
|
| 9 |
+
----------------------------------------------------
|
| 10 |
+
0.0 0.8366 0.815 0.635 0.005
|
| 11 |
+
0.5 0.8356 0.797 0.599 0.004 โ optimal
|
| 12 |
+
1.0 0.8369 0.791 0.588 0.006
|
| 13 |
+
1.5 0.8367 0.783 0.571 0.006
|
| 14 |
+
2.0 0.8367 0.774 0.553 0.005
|
| 15 |
+
3.0 0.8363 0.769 0.542 0.005
|
analysis_reports/outputs_all_models_20260325/T64/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 โ KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
has_generate_cached=True
|
| 5 |
+
memory_profile=Torch CPU mem-event reduction: 31.4% @ src_len=64 (std=8592.3MB, cache=5890.5MB)
|
| 6 |
+
|
| 7 |
+
src_len standard(s) cached(s) speedup encoder%
|
| 8 |
+
16 4.206 3.584 1.17x 74.4%
|
| 9 |
+
32 4.647 3.371 1.38x 37.6%
|
| 10 |
+
64 8.403 4.593 1.83x 49.6%
|
| 11 |
+
|
| 12 |
+
Saved graphs:
|
| 13 |
+
- task1_time_comparison.png
|
| 14 |
+
- task1_speedup.png
|
| 15 |
+
- task1_encoder_cost.png
|
analysis_reports/outputs_all_models_20260325/T64/task2_report.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 2 โ ATTENTION + DRIFT REPORT
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Input : dharmo rakแนฃati rakแนฃitaแธฅ
|
| 5 |
+
Output: เคงเคฐเฅเคฎเฅ เคฐเคเฅเคทเคคเคฟ เคฐเคเฅเคทเคฟเคคเค
|
| 6 |
+
|
| 7 |
+
Captured steps: 64
|
| 8 |
+
Analysis quality: VALID
|
| 9 |
+
Final output uniq-ratio: 1.000
|
| 10 |
+
Degenerate output: NO
|
| 11 |
+
Multi-sample semantic score (n<=8): 0.1490
|
| 12 |
+
Lock-in step (CER<=0.05): t=0
|
| 13 |
+
Locked tokens: 59 Flexible tokens: 21
|
| 14 |
+
TF-IDF vs attention stability corr: 0.7804
|
| 15 |
+
TF-IDF status: OK
|
| 16 |
+
|
| 17 |
+
Saved graphs:
|
| 18 |
+
- task2_attn_t*.png / task2_all_layers_t0.png
|
| 19 |
+
- task2_attn_evolution.png
|
| 20 |
+
- task2_semantic_drift.png
|
| 21 |
+
- task2_source_alignment.png
|
| 22 |
+
- task2_tfidf_vs_attention.png
|
| 23 |
+
|
| 24 |
+
Step trajectory (first 10 rows)
|
| 25 |
+
------------------------------------------------------------
|
| 26 |
+
t= 63 bert=0.0552 drift=0.9448 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคเคพเคฎเฅเคฏ เคคเค เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 27 |
+
t= 62 bert=0.0548 drift=0.9452 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคเคพเคฎเฅเคฏ เคคเค เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 28 |
+
t= 61 bert=0.0548 drift=0.9452 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคเคพเคฎเฅเคฏ เคคเค เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 29 |
+
t= 60 bert=0.0548 drift=0.9452 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคเคพเคฎเฅเคฏ เคคเค เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 30 |
+
t= 59 bert=0.0548 drift=0.9452 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคเคพเคฎเฅเคฏ เคคเค เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 31 |
+
t= 58 bert=0.0548 drift=0.9452 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคเคพเคฎเฅเคฏ เคคเค เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 32 |
+
t= 57 bert=0.0548 drift=0.9452 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคเคพเคฎเฅเคฏ เคคเค เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 33 |
+
t= 56 bert=0.0546 drift=0.9454 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคเคพเคฎเฅเคฏ เคคเค เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 34 |
+
t= 55 bert=0.0546 drift=0.9454 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคเคพเคฎเฅเคฏ เคคเค เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
| 35 |
+
t= 54 bert=0.0546 drift=0.9454 text=เคงเคฐเฅเคฎเฅ เคคเคฟ เคเคพเคฎเฅเคฏ เคคเค เคคเค เคคเค เคคเค เคคเค เคคเค เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅเคฎเฅ เคงเคฐเฅ
|
analysis_reports/outputs_all_models_20260325/T64/task3_report.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 3 โ CONCEPT VECTORS + PCA STEERING
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
PCA: 39 components, 100.0% variance
|
| 5 |
+
Diversity PC: 0 (|r|=0.314 with diversity proxy)
|
| 6 |
+
|
| 7 |
+
Direction validity: WEAK
|
| 8 |
+
Spectrum unique ratio (mean over 5 seeds): 1.000
|
| 9 |
+
Spectrum semantic stability (mean over 5 seeds): 0.302
|
| 10 |
+
|
| 11 |
+
Saved graphs:
|
| 12 |
+
- task3_concept_space.png
|
| 13 |
+
- task3_pca_explained_variance.png
|
| 14 |
+
- task3_diversity_curve.png
|
| 15 |
+
|
| 16 |
+
Diversity spectrum:
|
| 17 |
+
alpha=-2.0 โ เคฌเคฒเฅ เคฌเคฒเฅ เค
เคคเฅเคฒ เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เค
ฬฑเคธเฅเคฏเคพ เคญเฅเคฏเค เคธเคจ เคพเคจเฅเคคเฅ เคฒเคฌเฅเคงเฅเคตเคพ เค
เคฐเฅเคฅ เคตเฅ เคญเฅเคฏเค เคคเฅเคคเคฟ เคพเคจเฅเคค เคธเคจ เคญเฅเคฏเค เคญเฅเคฏเค เคคเฅเคฒ เค
เค
เคตเฅเคฐเฅ เคฌเคฒเฅ เคฌเคฒเฅ เคฌเคฒเฅ เค
ฬฑเคธเฅเคฏเคพ เคญเฅเคฏเค เฅฅ เฅฅ เค
เค
เค
ฬฑเคธเฅเคฏเคพ เฅฅ เคตเฅฬฑเคง เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคนเฅ เคนเฅ เคพเคจ เคนเฅ เคธเฅเคฏ เคนเฅ เคพเคจเฅ เคนเฅ เคฏเฅ เคฌเฅ เคฆเคพเคฆเคฟ เคนเฅ เคฎเคคเคพเค เคนเฅ เคธเคพเคจเฅเคคเฅเคต เคนเฅ ( เคฎเคคเคพเค เฅฅ เคงเฅเคฎเคพเคจเฅ เคญเฅเคฏเค เฅฅ เค
ฬฑเคธเฅเคฏเคพ เคชเคพเคจ เคนเฅเคฎเคฏเฅเคคเฅ เคธเคพเคฐเคฅเคฟ ( เฅฅ เคญเฅเคฏเค เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เคเฅเคช เคฒเคเฅเค เค
ฬฑเคธเฅเคฏเคพ เคฎเคคเคพเค เคฒเคเฅเค เคฏเฅ
|
| 18 |
+
alpha=-1.0 โ เคฌเคฒเฅ เคฌเคฒเฅ เค
เคคเฅเคฒ เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคธเคจ เคคเฅเคคเคฟ เคตเฅ เคตเฅ เคตเฅ เคฆเฅ เคธเคจ เคธเคฟเคเคนเคพ เคธเฅเคฏ เคธเฅเคฏ เคคเฅเคฒ เคคเฅเคฒ เค
เคชเฅเคฐเคฏเคพเคคเคฟ เคฐเฅ เคฌเคฒเฅ เคฌเคฒเฅ เคฌเคฒเฅ เฅฅ เฅฅ เฅฅ เค
เฅฅ เคชเคฟ เคฎเคนเคพ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคฟเคเคนเคพ เคธเฅเคฐเคต เคธเคจ เคธเฅเคฏ เคฎเฅเค เคเฅเคช เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคญเฅเคฏเค เคคเฅเคฒ เคธเคพเคจเฅเคคเฅเคต เคฏเฅ เค
เคนเฅ เคพเคจ เคพเคจ เคคเคต เคตเฅเค ( เคฏเฅ เคญเฅเคทเคฃเคฎเฅ ( เคพเคจเฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เค
ฬฑเคธเฅเคฏเคพ เฅฅ เฅฅ เฅฅ เฅฅ เคฏเฅ เคชเคฟ เฅฅ เคฎ
|
| 19 |
+
alpha=+0.0 โ เคฌเคฒเฅ เคฐเฅ เค
เคคเฅเคฒ เคคเฅเคฒ เคตเฅเคฐเฅเคฏ เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคธเคจ เคงเฅเคฏเคพ เคตเฅ เคตเฅ เคตเฅ เคเคคเคธเฅเคฏ เคญ เคธเคฟเคเคนเคพ เคญ เคธเฅเคฏ เคคเฅเคฒ เคคเฅเคฒ เค
เค
เคฐเฅ เคฌเคฒเฅ เคฌเคฒเฅ เคฌเคฒเฅ เฅฅ เคฌเคฒเฅ เคฐเฅ เฅค เฅค เฅฅ เคตเฅ เคตเฅ เคธเคจ เคธเคจ เคธเคจ เคธเคจ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคญ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคพเคจเฅ เคธเฅเคฏ เฅ เคคเคพ เคฏเฅ เคธเฅเคฏ เคซเคฒ เฅฅ เคฎ เคคเฅเคฒ เค เคธเคฟ เฅฅ เฅ เฅฅ เฅฅ เคจเฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เฅฅ เคฎเคนเคพ เคธเคจ เคธเคจ เคเฅเคท เฅฅ
|
| 20 |
+
alpha=+1.0 โ เคฌเคฒเฅ เคฐเฅ เฅค เคคเฅเคฒ เคฏเฅ เคตเฅเคฐเฅเคฏ เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคธเคจ เคงเฅเคฏเคพ เคตเฅ เคตเฅ เคตเฅ เฅค เคธเคจ เคธเคฟเคเคนเคพ เคธเฅเคฏ เคธเฅเคฏ เคคเฅเคฒ เคคเฅเคฒ เฅค เคฐเฅ เคฐเฅ เคฌเคฒเฅ เคฌเคฒเฅ เคฌเคฒเฅ เฅฅ เฅฅ เคฐเฅ เค
เคธเฅ เคธเคจ เคคเฅ เคธเคจ เฅเค เคธเคจ เคธเคจ เคคเฅเคฐ เคธเคฟเคเคนเคพ เคฏเฅ เคธเคฟเคเคนเคพ เคธเฅเคฅเคฒ เคธเฅเคฏ เคธเฅเคฏ เคฐเฅเคฆเฅเคฐ เคธเฅเคฏ เคธเฅเคฏ เคจเฅเคฆเคพ เคคเคพ เคฏเฅ เคธเฅเคฏ เคฏเฅ เคคเฅเคฐ เคเฅเคท เฅคเฅค เฅเค เคธเฅเคฏ เคฎเฅเคฐ เคเคฒเฅเคช เคฏเคคเฅ เคธเฅ เคเฅเคท เคเฅเคท เฅฅ เคธเฅเคฏ เคฏเฅ เคฎเคฃเฅเคกเคฒเค เคฏเฅ เฅฅ เฅฅ เฅเค เฅฅ เฅฅ เคญเฅเคฏเค เฅเค เฅเค เฅฅ เฅฅ เฅฅ
|
| 21 |
+
alpha=+2.0 โ เคฐเฅ เคฐเฅ เคคเฅเคฐเคเฅเค เคเคนเฅเค เคฟเคคเฅ เฅค เคธเฅเคฏ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคฌเฅเคฐเคนเฅเคฎเคพ เคตเฅ เคตเฅ & เคคเฅ เคคเคธเฅ เคคเฅเคฐเคเฅเค เคจเฅ เคธเฅเคคเคฎเฅเคญ เฅเค เคฏเฅ เคธเคเค เคฐเฅ เคฐเฅ เคฌเคฒเฅ เฅเค เคธเฅเคคเคฎเฅเคญ เคคเฅ เฅค เคคเคธเฅ เคจเฅเคคเค เคฎเคฃเฅเคกเคฒเค เคฏเฅ เฅค เคธเฅเคคเคฎเฅเคญ เคธเฅเคคเคฎเฅเคญ เคธเคจ เคเคนเฅเค เคธเคฟเคเคนเคพ เคฏเฅ เคธเคฟเคเคนเคพ เคธเคฟเคเคนเคพ เคธเฅเคฏ เคฎเคฃเฅเคกเคฒเค เคฏเฅ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคเคต เคเคฒเฅเคช เคธเฅเคคเคฎเฅเคญ ฬฑเคฐ เคธเฅเคคเคฎเฅเคญ เค
เคฎเฅ เคเคนเฅเค เคฏเฅ ฬฑเคฐ เคเคฒเฅเคช เคฏเฅ เคคเฅเคฐเคเฅเค เคฏเฅ เคคเฅเคฐเคเฅเค ฬฑเคฐ เคคเฅเคฐเคเฅเค เคฐเคฃเคฎเฅ เคฎเคฃเฅเคกเคฒเค เคฏเฅ เฅเค เคฎเคฃเฅเคกเคฒเค เคฆเคฟเคจเค ฬฑเคฐ เคฏเฅ เฅฅ เคคเฅเคฐเคเฅเค เคฟเคคเค เคเคนเฅเค เฅฅ เคฎเคฃเฅเคกเคฒเค เคเคนเฅเค เคเฅเคทเคฎเค
|
analysis_reports/outputs_all_models_20260325/T64/task4_report.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 4 โ SEMANTIC ROBUSTNESS ABLATION
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Optimal diffusion steps = 64
|
| 5 |
+
|
| 6 |
+
T BERT-F1 SEM_SIM BLEU sec/sample
|
| 7 |
+
--------------------------------------------------------
|
| 8 |
+
64 0.2482 0.0580 0.0007 5.6116
|
| 9 |
+
|
| 10 |
+
Marginal gains (BERT-F1):
|
| 11 |
+
|
| 12 |
+
Saved plots/files:
|
| 13 |
+
- task4_3d.png
|
| 14 |
+
- task4_raw_results.json
|
analysis_reports/outputs_all_models_20260325/T64/task5_report.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 5 โ CLASSIFIER-FREE GUIDANCE
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Classifier params: 139521
|
| 5 |
+
Training samples : 20
|
| 6 |
+
|
| 7 |
+
Guidance scale sweep:
|
| 8 |
+
ฮป CER diversity d2 sBLEU
|
| 9 |
+
----------------------------------------------------
|
| 10 |
+
0.0 0.8451 0.838 0.689 0.013 โ optimal
|
| 11 |
+
0.5 0.8490 0.818 0.650 0.013
|
| 12 |
+
1.0 0.8509 0.838 0.684 0.007
|
| 13 |
+
1.5 0.8622 0.857 0.720 0.005
|
| 14 |
+
2.0 0.8761 0.869 0.744 0.005
|
| 15 |
+
3.0 0.9056 0.814 0.642 0.013
|
analysis_reports/outputs_all_models_20260325/T8/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 โ KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
has_generate_cached=True
|
| 5 |
+
memory_profile=Torch CPU mem-event reduction: 25.8% @ src_len=64 (std=1168.9MB, cache=866.9MB)
|
| 6 |
+
|
| 7 |
+
src_len standard(s) cached(s) speedup encoder%
|
| 8 |
+
16 0.582 0.400 1.45x 35.9%
|
| 9 |
+
32 0.511 0.402 1.27x 37.7%
|
| 10 |
+
64 0.666 0.490 1.36x 35.6%
|
| 11 |
+
|
| 12 |
+
Saved graphs:
|
| 13 |
+
- task1_time_comparison.png
|
| 14 |
+
- task1_speedup.png
|
| 15 |
+
- task1_encoder_cost.png
|
analysis_reports/outputs_all_models_20260325/T8/task2_report.txt
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 2 โ ATTENTION + DRIFT REPORT
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Input : dharmo rakแนฃati rakแนฃitaแธฅ
|
| 5 |
+
Output: เคงเคฐเฅเคฎเฅ เคฐเคเฅเคทเคคเคฟ เคฐเคเฅเคทเคฟเคคเค
|
| 6 |
+
|
| 7 |
+
Captured steps: 8
|
| 8 |
+
Analysis quality: WEAK
|
| 9 |
+
Final output uniq-ratio: 1.000
|
| 10 |
+
Degenerate output: NO
|
| 11 |
+
Multi-sample semantic score (n<=8): 0.0915
|
| 12 |
+
Lock-in step (CER<=0.05): t=0
|
| 13 |
+
Locked tokens: 79 Flexible tokens: 1
|
| 14 |
+
TF-IDF vs attention stability corr: 0.8905
|
| 15 |
+
TF-IDF status: OK
|
| 16 |
+
|
| 17 |
+
Saved graphs:
|
| 18 |
+
- task2_attn_t*.png / task2_all_layers_t0.png
|
| 19 |
+
- task2_attn_evolution.png
|
| 20 |
+
- task2_semantic_drift.png
|
| 21 |
+
- task2_source_alignment.png
|
| 22 |
+
- task2_tfidf_vs_attention.png
|
| 23 |
+
|
| 24 |
+
Step trajectory (first 10 rows)
|
| 25 |
+
------------------------------------------------------------
|
| 26 |
+
t= 7 bert=0.0219 drift=0.9781 text=เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค
|
| 27 |
+
t= 6 bert=0.0225 drift=0.9775 text=เค เค เค เค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค
|
| 28 |
+
t= 5 bert=0.0225 drift=0.9775 text=เค เค เค เค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค
|
| 29 |
+
t= 4 bert=0.0225 drift=0.9775 text=เค เค เค เค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค
|
| 30 |
+
t= 3 bert=0.0225 drift=0.9775 text=เค เค เค เค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค
|
| 31 |
+
t= 2 bert=0.0227 drift=0.9773 text=เค เค เค เค เค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคค
|
| 32 |
+
t= 1 bert=0.0228 drift=0.9772 text=เค เค เค เค เค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคค
|
| 33 |
+
t= 0 bert=0.0228 drift=0.9772 text=เค เค เค เค เค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคคเค เคฟเคค
|
analysis_reports/outputs_all_models_20260325/T8/task3_report.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 3 โ CONCEPT VECTORS + PCA STEERING
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
PCA: 50 components, 75.9% variance
|
| 5 |
+
Diversity PC: 0 (|r|=-0.344 with diversity proxy)
|
| 6 |
+
|
| 7 |
+
Direction validity: WEAK
|
| 8 |
+
Spectrum unique ratio (mean over 5 seeds): 1.000
|
| 9 |
+
Spectrum semantic stability (mean over 5 seeds): 0.341
|
| 10 |
+
|
| 11 |
+
Saved graphs:
|
| 12 |
+
- task3_concept_space.png
|
| 13 |
+
- task3_pca_explained_variance.png
|
| 14 |
+
- task3_diversity_curve.png
|
| 15 |
+
|
| 16 |
+
Diversity spectrum:
|
| 17 |
+
alpha=-2.0 โ เคฎเคจเคธเค เคถเฅเคเคเฅเคฐ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เค
เคธเฅเคฏ เคคเฅเค เคคเฅเค เคคเฅเค เคธเฅเคฏ เคถเฅเคเคเฅเคฐ เคคเฅเค เคเคคเคญเฅ เคธเฅเคฏ เคธเฅเคฏ เคธเฅเคฏ เคถเฅเคเคเฅเคฐ เคคเฅเค เคคเฅเค เคถเฅเคเคเฅเคฐ เคธเฅเคฏ เคคเฅเค เคธเฅเคคเฅเคตเค เคถเฅเคเคเฅเคฐ เคถเฅเคเคเฅเคฐ เคธเฅเคคเฅเคฐ เคคเฅเค เคคเฅเค เคเฅเคฃเฅเค เคคเฅเค เคคเฅเค เคธเฅเคฏ เคคเฅเค เคคเฅเค เคคเฅเค เคธเฅเคฏ เคคเฅเค เคคเฅเค เคเคคเคญเฅ เคคเฅเค เคคเฅเค เคฃเคฟฬ เคธเฅเคฏ เคคเฅเค เคคเฅเค เคคเฅเค เค
เคคเฅเค เคคเฅเค เคนเฅเคตเคฏเฅ เคฎเคจเคธเค เฅฅ เคคเฅเค เคคเฅเค เคเคคเคญเฅ เฅฅ เคถเฅเคเคเฅเคฐ เคคเฅเค เคคเฅเค เคคเฅเค เคคเฅเค เคคเฅเค เคคเฅเค เคคเฅเค เคคเฅเค เคธเฅเคฏ เคคเฅเค เคเคฐเคฟเคทเฅเคฏเคพ เคคเฅเค เคธเฅเคคเฅเคตเค เคคเฅเค เคคเฅเค เคถเฅเคเคเฅเคฐ เคคเฅเค เคคเฅเค เคถเฅเคเคเฅเคฐ เคคเฅเค เคคเฅเค เคนเฅเคตเคฏเฅ
|
| 18 |
+
alpha=-1.0 โ เคธเฅเคฏ เค
เคคเฅเค เคตเฅ เคคเฅเค เค
เคตเฅเคฆเฅ เคฎเคจเคธเค เคธเฅเคฏ เฅค เฅค เคคเฅเค เคคเฅเค เคธเฅเคฏ เคเคคเคญเฅ เคธเฅเคฏ เค
เคธเฅเคคเฅเคตเค เคธเฅเคฆเฅ เคธเฅเคฏ เคคเฅเค เคธเฅเคฏ เคคเฅเค เคธเฅฬฑเคฎเฅ เคธเฅเคฆเฅ เคฐเฅเคง เคเฅเคคเคพเคจเคฟ เคเคคเคญเฅ เคเคคเคญเฅ เคคเฅเค เคคเฅเค เคธเฅเคฏ เคคเฅเค เคคเฅเค เคคเฅเค เคฎเคจเคธเค เคคเฅเค เคเฅเคคเคพเคจเคฟ เคคเฅเค เคคเฅเค เคธเฅฬฑเคฎเฅ เค
เคคเฅเค เคฎเคจเคธเค เค
เคฎเคจเคธเค เคธเฅเคฏ เค
เคคเฅเค เฅฅ เฅฅ เคธเฅเคฏ เคเคคเคญเฅ เคเคคเคญเฅ เฅฅ เฅฅ เคตเฅ เคคเฅเค เคคเฅเค เคฎเคจเคธเค เคคเฅเค เค
เคคเฅเค เคคเฅเค เค เคตเคฐ เคธเฅเคฏ เคคเฅเค เคฏเคพ เคตเคพเคคเฅ เคธเฅเคฏ เคคเฅเค เคธเฅเคฆเฅ เคคเฅเค เคธเฅเคฏ เคคเฅเค เคธเฅเคฏ เค
เคคเฅเค เคคเฅเค
|
| 19 |
+
alpha=+0.0 โ เค
เค
เคตเฅ เคเฅเค เคธเฅเคฏ เค
เคเฅเค เคเคคเคญเฅ เคตเคฐ เคฆ เคถเคฟเค เคฎเคจเฅเคคเฅเคฐ เคเคคเคญเฅ เคธเฅฬฑเคฎเฅ เฅค เคฆ เคฆ เคธเฅเคฏ เคฎเคจเฅเคคเฅเคฐ เคตเคพ เคฏเฅ เคธเฅเคฆเฅ เคเฅเค เคตเฅ เค
เคธเฅเคฏ เคธเฅเคฏ เคฎเคจเฅเคคเฅเคฐ เคธเฅเคฏ เคฎเคจเฅเคคเฅเคฐ เคธเฅเคฏ เคเคคเคญเฅ เฅค เคเคคเคญเฅ เคเคคเคญเฅ เคคเฅเค เคเคคเคญเฅ เคเฅเคค เคคเฅเค เคธเฅเคฏ เคคเฅเค เฅฅ เคตเฅ เคคเฅเค เฅฅ เคตเฅ เค
เคเฅเคคเคพเคจเคฟ เคธเฅเคฏ เคตเคฐ เคตเฅ เฅฅ เฅฅ เคตเฅ เฅฅ เค
เฅฅ เคธเฅเคฏ เฅฅ เคตเฅ เคธเฅเคฏ เคเฅเค เฅฅ เคธเฅเคฏ เคคเฅเค เคคเฅเค เคตเฅ เคธเฅเคฏ เคธเฅเคฏ เค
เคธเฅเคฏ เคคเฅเค เคตเฅ เคธเฅเคฏ เคคเฅเค เคชเฅเคฐเคฃ เคคเฅเคฐเฅ เคธเฅเคฏ เฅค เคธเฅเคฆเฅ
|
| 20 |
+
alpha=+1.0 โ เคชเคฎ เคตเฅ เคคเฅเคฒเฅเคฏ เคถเคคเฅเคฐเฅ เคชเคฎ เคถเคฟเค เคตเคฐ เค
เคฃเคพเค เคชเคฐเคพ เคฃเคพเค เคธเฅเคฏ เคเฅเคค เคชเฅเคฐเคฟเคฏ เฅค เคญเคฟเคจเฅ เคฃเคพเค เคเฅเค เคตเฅ เคตเคฟเคฐเคพเค เคตเฅ เคเคฃเฅ เคตเฅ เฅเคฏเคพ เค
เคตเฅ เคชเคฎ เฅเคฏเคพ เคญเคฟเคจเฅ เคตเฅ เคฒเคฌเฅเคง เคถเฅเคญ เคธเฅเคฏ เค เคถ เคตเคฐ เคตเฅ เฅฅ เคตเฅ เคเฅเคทเคฟเคชเฅเคฏ เคถเคฟเค เคญเคฟเคฐเฅ เฅฅ เคธเคจ เคตเคพ เคฎเคจเฅเคคเฅเคฐ เคฎเฅ เฅฅ เฅฅ เฅฅ เคตเฅ เฅฅ เคฎเคจเฅเคคเฅเคฐ เฅฅ เคชเคฎ เคธเคเฅ เคตเคฐ เคถเฅเคญ เคเฅเคทเคฟเคชเฅเคฏ เคญเคฟเคฐเฅ เคธเฅเคฏ เคเฅเคทเคฟเคชเฅเคฏ เคตเฅ เคธเคจ เคตเคฐ เคถเคฟเค เคตเฅ เคถเคฟเค เคตเคฐ เคฆเคฐเฅเคถ เคถเคฟเค เคเคฒเค เคชเคฎ เฅ เคตเคฐ เคเคฒเค เคญเคฟเคฐเฅ เคเคฒเค เคตเฅ เคถเคฟเค
|
| 21 |
+
alpha=+2.0 โ เคชเคฎ เคชเคฎ เคฒเคฌเฅเคง เคชเคฎ เคถเฅเคญ เคชเคฎ เคชเคฐเฅ เคญเคฟเคฐเฅ เค
เคจเฅเคฏ เคฃเคพเค เคฐเคธเคพ เคฒเคฌเฅเคง เคชเคฎ เคถเฅเคญ เคฒเคฌเฅเคง เคถเฅเคญ เคชเคฎ เคถเคคเฅเคฐเฅ เคถเคฟเค เคญเคฟเคจเฅ เคชเคฎ เคชเคฎ เคชเคฎ เคฃเคพเค เคถเฅเคญ เคฃเคพเค เคชเคฎ เคถเฅเคญ เคถเฅเคญ เคชเคฐเฅ เคฃเคพเค เคฃเคพเค เคชเคฎ เคชเคฎ เคชเคฎ เคชเคฎ เคถเฅเคญ เคชเคฎ เคถเฅเคญ เคถเฅเคญ เคชเคฎ เคกเคพ เคฃเคพเค เคถเฅเคญ เคตเฅ เคชเคฎ เคตเฅ เฅฅ เคชเคฎ เฅฅ เคชเคฎ เคฃเคพเค เฅฅ เฅฅ เคชเคฐเฅ เค
เคจเฅเคฏ เคฃเคพเค เคชเคฎ เคถเฅเคญ เคถเฅเคญ เคถเฅเคญ เคถเฅเคญ เคฃเคพเค เคตเฅ เคชเคฐเฅ เคเคฒเค เคชเคฐเฅ เคตเฅ เคชเคฎ เคฃเคพเค เคชเคฎ เคถเฅเคญ เคฃเคพเค เคฃเคพเค เคเคฒเค เคชเคฐเฅ เคถเฅเคญ เคฃเคพเค เคเคฒเค เคถเคคเฅเคฐเฅ
|
analysis_reports/outputs_all_models_20260325/T8/task4_report.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 4 โ SEMANTIC ROBUSTNESS ABLATION
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Optimal diffusion steps = 8
|
| 5 |
+
|
| 6 |
+
T BERT-F1 SEM_SIM BLEU sec/sample
|
| 7 |
+
--------------------------------------------------------
|
| 8 |
+
8 0.1210 0.0400 0.0000 0.6194
|
| 9 |
+
|
| 10 |
+
Marginal gains (BERT-F1):
|
| 11 |
+
|
| 12 |
+
Saved plots/files:
|
| 13 |
+
- task4_3d.png
|
| 14 |
+
- task4_raw_results.json
|
analysis_reports/outputs_all_models_20260325/T8/task5_report.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 5 โ CLASSIFIER-FREE GUIDANCE
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Classifier params: 139521
|
| 5 |
+
Training samples : 40
|
| 6 |
+
|
| 7 |
+
Guidance scale sweep:
|
| 8 |
+
ฮป CER diversity d2 sBLEU
|
| 9 |
+
----------------------------------------------------
|
| 10 |
+
0.0 0.8834 0.796 0.596 0.004 โ optimal
|
| 11 |
+
0.5 0.8881 0.781 0.568 0.005
|
| 12 |
+
1.0 0.8876 0.767 0.540 0.007
|
| 13 |
+
1.5 0.8921 0.757 0.517 0.004
|
| 14 |
+
2.0 0.8929 0.734 0.474 0.006
|
| 15 |
+
3.0 0.8970 0.724 0.453 0.005
|
config.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
CONFIG = {
|
| 4 |
+
"model_type": "d3pm_cross_attention",
|
| 5 |
+
"data": {
|
| 6 |
+
"include_negative_examples": True,
|
| 7 |
+
"dataset_size": 60000,
|
| 8 |
+
},
|
| 9 |
+
"diffusion": {
|
| 10 |
+
"mask_token_id": 0,
|
| 11 |
+
},
|
| 12 |
+
"model": {
|
| 13 |
+
"src_vocab_size": 16000,
|
| 14 |
+
"tgt_vocab_size": 16000,
|
| 15 |
+
"d_model": 384,
|
| 16 |
+
"n_heads": 8,
|
| 17 |
+
"d_ff": 1536,
|
| 18 |
+
"n_layers": 6,
|
| 19 |
+
"dropout": 0.1,
|
| 20 |
+
"max_seq_len": 80,
|
| 21 |
+
"diffusion_steps": 64,
|
| 22 |
+
},
|
| 23 |
+
"training": {
|
| 24 |
+
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
| 25 |
+
},
|
| 26 |
+
"inference": {
|
| 27 |
+
"num_steps": 64,
|
| 28 |
+
"temperature": 0.7,
|
| 29 |
+
"top_k": 40,
|
| 30 |
+
"repetition_penalty": 1.2,
|
| 31 |
+
"diversity_penalty": 0.0,
|
| 32 |
+
},
|
| 33 |
+
}
|
diffusion/__init__.py
ADDED
|
File without changes
|
diffusion/forward_process.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
forward_process.py โ Verified Correct (no changes needed)
|
| 3 |
+
===========================================================
|
| 4 |
+
Absorbing (mask) diffusion. PAD never masked. At t=0 alpha=1.0 exactly
|
| 5 |
+
so x_t == x_0 (nothing masked). Works correctly with the fixed scheduler.
|
| 6 |
+
"""
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
class AbsorbingForwardProcess:
|
| 10 |
+
def __init__(self, scheduler, mask_id=0, pad_id=1):
|
| 11 |
+
self.scheduler = scheduler
|
| 12 |
+
self.mask_id = mask_id
|
| 13 |
+
self.pad_id = pad_id
|
| 14 |
+
|
| 15 |
+
def q_sample(self, x_0, t):
|
| 16 |
+
alpha_t = self.scheduler.get_alpha(t).to(x_0.device).view(-1, 1)
|
| 17 |
+
r = torch.rand(x_0.shape, device=x_0.device)
|
| 18 |
+
x_t = x_0.clone()
|
| 19 |
+
x_t[r > alpha_t] = self.mask_id
|
| 20 |
+
x_t[x_0 == self.pad_id] = self.pad_id # PAD stays PAD always
|
| 21 |
+
return x_0, x_t
|
diffusion/reverse_process.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
reverse_process.py โ Fixed
|
| 3 |
+
===========================
|
| 4 |
+
Two bugs fixed from the original:
|
| 5 |
+
|
| 6 |
+
BUG 1 (critical): generate_beam() passed x_t (noisy) as `tgt` to model.
|
| 7 |
+
The model does q_sample(tgt, t) internally โ so x_t got double-noised.
|
| 8 |
+
Fix: pass x0_estimate (current clean guess) as tgt. Model noises it correctly.
|
| 9 |
+
|
| 10 |
+
BUG 2: apply_diversity_penalty used logits.var(dim=-1) โ this adds the
|
| 11 |
+
variance of each position's own distribution back to itself, which is
|
| 12 |
+
mathematically meaningless and just injects noise.
|
| 13 |
+
Fix: penalize tokens that are uniformly high-probability across ALL positions
|
| 14 |
+
(global common tokens). This genuinely promotes diversity.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ReverseDiffusion:
|
| 22 |
+
def __init__(self, scheduler):
|
| 23 |
+
self.scheduler = scheduler
|
| 24 |
+
|
| 25 |
+
def p_sample_step(
|
| 26 |
+
self,
|
| 27 |
+
model,
|
| 28 |
+
x_t,
|
| 29 |
+
t,
|
| 30 |
+
condition,
|
| 31 |
+
beam_width=3,
|
| 32 |
+
temperature=1.0,
|
| 33 |
+
repetition_penalty=1.2,
|
| 34 |
+
diversity_penalty=0.3
|
| 35 |
+
):
|
| 36 |
+
"""
|
| 37 |
+
Single reverse step with temperature + penalties.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
|
| 42 |
+
# ---- Shape safety ----
|
| 43 |
+
if x_t.dim() == 1:
|
| 44 |
+
x_t = x_t.unsqueeze(0)
|
| 45 |
+
|
| 46 |
+
if condition.dim() == 1:
|
| 47 |
+
condition = condition.unsqueeze(0)
|
| 48 |
+
|
| 49 |
+
if t.dim() == 0:
|
| 50 |
+
t = t.unsqueeze(0)
|
| 51 |
+
|
| 52 |
+
if t.shape[0] != x_t.shape[0]:
|
| 53 |
+
t = t.expand(x_t.shape[0])
|
| 54 |
+
|
| 55 |
+
# ---- Model forward ----
|
| 56 |
+
logits, _ = model(condition, x_t, t)
|
| 57 |
+
|
| 58 |
+
# ---- Temperature scaling ----
|
| 59 |
+
logits = logits / temperature
|
| 60 |
+
|
| 61 |
+
# ---- Repetition penalty (FIXED VERSION) ----
|
| 62 |
+
if repetition_penalty != 1.0:
|
| 63 |
+
logits = apply_repetition_penalty(
|
| 64 |
+
logits, x_t, repetition_penalty
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# ---- Diversity penalty ----
|
| 68 |
+
if diversity_penalty > 0:
|
| 69 |
+
logits = apply_diversity_penalty(
|
| 70 |
+
logits, diversity_penalty
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
probs = F.softmax(logits, dim=-1)
|
| 74 |
+
|
| 75 |
+
B, L, V = probs.shape
|
| 76 |
+
|
| 77 |
+
# ---- Top-k beam expansion ----
|
| 78 |
+
topk_probs, topk_ids = torch.topk(
|
| 79 |
+
probs, beam_width, dim=-1
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
candidates = []
|
| 83 |
+
|
| 84 |
+
for k in range(beam_width):
|
| 85 |
+
next_tokens = topk_ids[:, :, k]
|
| 86 |
+
score = torch.log(
|
| 87 |
+
topk_probs[:, :, k] + 1e-9
|
| 88 |
+
).sum()
|
| 89 |
+
candidates.append((next_tokens, score))
|
| 90 |
+
|
| 91 |
+
return candidates
|
| 92 |
+
|
| 93 |
+
def generate_beam(
|
| 94 |
+
self,
|
| 95 |
+
model,
|
| 96 |
+
condition,
|
| 97 |
+
beam_width=3,
|
| 98 |
+
num_steps=None,
|
| 99 |
+
temperature=1.0,
|
| 100 |
+
repetition_penalty=1.2,
|
| 101 |
+
diversity_penalty=0.3
|
| 102 |
+
):
|
| 103 |
+
"""
|
| 104 |
+
Beam-search reverse diffusion with temperature.
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
if num_steps is None:
|
| 108 |
+
num_steps = self.scheduler.num_timesteps
|
| 109 |
+
|
| 110 |
+
device = condition.device
|
| 111 |
+
|
| 112 |
+
if condition.dim() == 1:
|
| 113 |
+
condition = condition.unsqueeze(0)
|
| 114 |
+
|
| 115 |
+
B, L = condition.shape
|
| 116 |
+
|
| 117 |
+
# ๐ฅ Better initialization: start from MASK
|
| 118 |
+
x_init = torch.full(
|
| 119 |
+
(B, L),
|
| 120 |
+
fill_value=model.mask_token_id,
|
| 121 |
+
dtype=torch.long,
|
| 122 |
+
device=device
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
beams = [(x_init, 0.0)]
|
| 126 |
+
|
| 127 |
+
for step in reversed(range(num_steps)):
|
| 128 |
+
|
| 129 |
+
new_beams = []
|
| 130 |
+
|
| 131 |
+
for x_t, score in beams:
|
| 132 |
+
|
| 133 |
+
t_tensor = torch.full(
|
| 134 |
+
(B,),
|
| 135 |
+
step,
|
| 136 |
+
dtype=torch.long,
|
| 137 |
+
device=device
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
candidates = self.p_sample_step(
|
| 141 |
+
model,
|
| 142 |
+
x_t,
|
| 143 |
+
t_tensor,
|
| 144 |
+
condition,
|
| 145 |
+
beam_width,
|
| 146 |
+
temperature,
|
| 147 |
+
repetition_penalty,
|
| 148 |
+
diversity_penalty
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
for tokens, new_score in candidates:
|
| 152 |
+
new_beams.append(
|
| 153 |
+
(tokens, score + new_score)
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
# ---- Keep top beams ----
|
| 157 |
+
new_beams = sorted(
|
| 158 |
+
new_beams,
|
| 159 |
+
key=lambda x: x[1],
|
| 160 |
+
reverse=True
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
beams = new_beams[:beam_width]
|
| 164 |
+
|
| 165 |
+
best_tokens, best_score = beams[0]
|
| 166 |
+
return best_tokens
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def generate(
|
| 171 |
+
self,
|
| 172 |
+
model,
|
| 173 |
+
condition,
|
| 174 |
+
num_steps=None,
|
| 175 |
+
temperature=0.8,
|
| 176 |
+
top_k=50,
|
| 177 |
+
repetition_penalty=1.2,
|
| 178 |
+
diversity_penalty=0.0,
|
| 179 |
+
):
|
| 180 |
+
"""
|
| 181 |
+
Correct D3PM iterative refinement.
|
| 182 |
+
|
| 183 |
+
x0_est starts as all [MASK].
|
| 184 |
+
Each step: forward(src=condition, tgt=x0_est, t)
|
| 185 |
+
โ model applies q_sample(x0_est, t) internally
|
| 186 |
+
โ predicts cleaner x0
|
| 187 |
+
โ x0_est updated
|
| 188 |
+
|
| 189 |
+
diversity_penalty: reduces probability of tokens that are
|
| 190 |
+
globally dominant across all sequence positions (not logits.var()).
|
| 191 |
+
"""
|
| 192 |
+
if num_steps is None:
|
| 193 |
+
num_steps = self.scheduler.num_timesteps
|
| 194 |
+
|
| 195 |
+
device = condition.device
|
| 196 |
+
if condition.dim() == 1:
|
| 197 |
+
condition = condition.unsqueeze(0)
|
| 198 |
+
B, L = condition.shape
|
| 199 |
+
|
| 200 |
+
T = self.scheduler.num_timesteps
|
| 201 |
+
step_size = max(1, T // num_steps)
|
| 202 |
+
timesteps = list(range(T - 1, -1, -step_size))
|
| 203 |
+
if timesteps[-1] != 0:
|
| 204 |
+
timesteps.append(0)
|
| 205 |
+
|
| 206 |
+
mask_id = model.mask_token_id
|
| 207 |
+
# Start: know nothing โ all MASK is our initial clean estimate
|
| 208 |
+
x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
|
| 209 |
+
hint = None
|
| 210 |
+
|
| 211 |
+
model.eval()
|
| 212 |
+
with torch.no_grad():
|
| 213 |
+
for step_idx, t_val in enumerate(timesteps):
|
| 214 |
+
t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 215 |
+
is_last = (step_idx == len(timesteps) - 1)
|
| 216 |
+
|
| 217 |
+
# KEY: pass x0_est as tgt โ model noises it internally
|
| 218 |
+
import inspect
|
| 219 |
+
sig = inspect.signature(model.forward).parameters
|
| 220 |
+
if 'x0_hint' in sig:
|
| 221 |
+
outputs = model(condition, x0_est, t, x0_hint=hint)
|
| 222 |
+
else:
|
| 223 |
+
outputs = model(condition, x0_est, t)
|
| 224 |
+
|
| 225 |
+
logits = outputs[0] if isinstance(outputs, tuple) else outputs
|
| 226 |
+
|
| 227 |
+
# Repetition penalty: down-weight tokens already in sequence
|
| 228 |
+
if repetition_penalty != 1.0:
|
| 229 |
+
logits = apply_repetition_penalty(logits, x0_est, repetition_penalty)
|
| 230 |
+
|
| 231 |
+
# Diversity penalty: reduce globally dominant tokens
|
| 232 |
+
if diversity_penalty > 0.0:
|
| 233 |
+
logits = apply_diversity_penalty(logits, diversity_penalty)
|
| 234 |
+
|
| 235 |
+
# Temperature + top-k
|
| 236 |
+
logits = logits / max(temperature, 1e-5)
|
| 237 |
+
if top_k > 0:
|
| 238 |
+
logits = top_k_filter(logits, top_k)
|
| 239 |
+
|
| 240 |
+
probs = F.softmax(logits, dim=-1)
|
| 241 |
+
|
| 242 |
+
if is_last:
|
| 243 |
+
x0_est = torch.argmax(probs, dim=-1)
|
| 244 |
+
else:
|
| 245 |
+
x0_est = batch_multinomial(probs)
|
| 246 |
+
|
| 247 |
+
hint = x0_est
|
| 248 |
+
|
| 249 |
+
return x0_est
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# โโ Penalty functions โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 253 |
+
|
| 254 |
+
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
|
| 255 |
+
"""
|
| 256 |
+
Down-weight tokens that already appear in the current sequence.
|
| 257 |
+
Prevents เคฎเคจเฅ เคฎเคจเฅ เคฎเคจเฅ repetition loops.
|
| 258 |
+
penalty=1.0 โ no effect
|
| 259 |
+
penalty=1.2 โ mild suppression of repeated tokens
|
| 260 |
+
penalty=2.0 โ strong suppression
|
| 261 |
+
"""
|
| 262 |
+
B, L, V = logits.shape
|
| 263 |
+
for b in range(B):
|
| 264 |
+
for token_id in set(prev_tokens[b].tolist()):
|
| 265 |
+
if token_id > 4: # don't penalize special tokens
|
| 266 |
+
logits[b, :, token_id] = logits[b, :, token_id] / penalty
|
| 267 |
+
return logits
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def apply_diversity_penalty(logits, penalty=0.5):
|
| 271 |
+
"""
|
| 272 |
+
Correct diversity penalty: penalize tokens that are globally dominant
|
| 273 |
+
across ALL sequence positions. This forces the model to use less
|
| 274 |
+
common tokens, increasing output diversity.
|
| 275 |
+
|
| 276 |
+
Method: compute mean probability across positions, subtract penalty
|
| 277 |
+
times that mean. Tokens uniformly high everywhere get suppressed.
|
| 278 |
+
|
| 279 |
+
penalty=0.0 โ no diversity enforcement
|
| 280 |
+
penalty=0.5 โ moderate diversity
|
| 281 |
+
penalty=1.0 โ strong diversity (may hurt coherence)
|
| 282 |
+
"""
|
| 283 |
+
# Mean logit across all positions: [B, V]
|
| 284 |
+
global_mean = logits.mean(dim=1, keepdim=True) # [B, 1, V]
|
| 285 |
+
# Subtract scaled global mean โ suppresses globally common tokens
|
| 286 |
+
return logits - penalty * global_mean
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def top_k_filter(logits, k):
|
| 290 |
+
B, L, V = logits.shape
|
| 291 |
+
if k >= V:
|
| 292 |
+
return logits
|
| 293 |
+
topk_vals, _ = torch.topk(logits, k, dim=-1)
|
| 294 |
+
threshold = topk_vals[..., -1].unsqueeze(-1)
|
| 295 |
+
return logits.masked_fill(logits < threshold, float('-inf'))
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def batch_multinomial(probs):
|
| 299 |
+
B, L, V = probs.shape
|
| 300 |
+
flat = probs.view(B * L, V) + 1e-9
|
| 301 |
+
flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 302 |
+
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
diffusion/reverse_process1.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ReverseDiffusion:
|
| 6 |
+
"""
|
| 7 |
+
Stable reverse diffusion with:
|
| 8 |
+
- Beam search
|
| 9 |
+
- Self conditioning
|
| 10 |
+
- Temperature sampling
|
| 11 |
+
- Repetition penalty
|
| 12 |
+
- Diversity penalty
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, scheduler):
|
| 16 |
+
|
| 17 |
+
self.scheduler = scheduler
|
| 18 |
+
|
| 19 |
+
self.temperature = 0.75
|
| 20 |
+
self.repetition_penalty = 1.15
|
| 21 |
+
self.diversity_penalty = 0.0
|
| 22 |
+
self.length_penalty = 1.0
|
| 23 |
+
|
| 24 |
+
# ------------------------------------------------
|
| 25 |
+
# penalties
|
| 26 |
+
# ------------------------------------------------
|
| 27 |
+
|
| 28 |
+
def apply_repetition_penalty(self, logits, tokens):
|
| 29 |
+
|
| 30 |
+
B, L, V = logits.shape
|
| 31 |
+
|
| 32 |
+
for b in range(B):
|
| 33 |
+
|
| 34 |
+
used = set(tokens[b].tolist())
|
| 35 |
+
|
| 36 |
+
for token_id in used:
|
| 37 |
+
logits[b, :, token_id] /= self.repetition_penalty
|
| 38 |
+
|
| 39 |
+
return logits
|
| 40 |
+
|
| 41 |
+
def apply_diversity_penalty(self, logits):
|
| 42 |
+
|
| 43 |
+
if self.diversity_penalty == 0:
|
| 44 |
+
return logits
|
| 45 |
+
|
| 46 |
+
logits_var = logits.var(dim=-1, keepdim=True)
|
| 47 |
+
return logits + self.diversity_penalty * logits_var
|
| 48 |
+
|
| 49 |
+
# ------------------------------------------------
|
| 50 |
+
# single reverse step
|
| 51 |
+
# ------------------------------------------------
|
| 52 |
+
|
| 53 |
+
def p_sample_step(self, model, x_t, t, condition, self_cond=None, beam_width=3):
|
| 54 |
+
|
| 55 |
+
with torch.no_grad():
|
| 56 |
+
|
| 57 |
+
logits, hidden = model(condition, x_t, t, self_cond)
|
| 58 |
+
|
| 59 |
+
logits = logits / self.temperature
|
| 60 |
+
|
| 61 |
+
logits = self.apply_repetition_penalty(logits, x_t)
|
| 62 |
+
logits = self.apply_diversity_penalty(logits)
|
| 63 |
+
|
| 64 |
+
probs = F.softmax(logits, dim=-1)
|
| 65 |
+
|
| 66 |
+
B, L, V = probs.shape
|
| 67 |
+
|
| 68 |
+
topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
|
| 69 |
+
|
| 70 |
+
candidates = []
|
| 71 |
+
|
| 72 |
+
for k in range(beam_width):
|
| 73 |
+
|
| 74 |
+
tokens = topk_ids[:, :, k]
|
| 75 |
+
|
| 76 |
+
score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
|
| 77 |
+
|
| 78 |
+
candidates.append((tokens, score))
|
| 79 |
+
|
| 80 |
+
return candidates
|
| 81 |
+
|
| 82 |
+
# ------------------------------------------------
|
| 83 |
+
# beam reverse diffusion
|
| 84 |
+
# ------------------------------------------------
|
| 85 |
+
|
| 86 |
+
def generate_beam(self, model, condition, beam_width=3, num_steps=None):
|
| 87 |
+
|
| 88 |
+
if num_steps is None:
|
| 89 |
+
num_steps = self.scheduler.num_timesteps
|
| 90 |
+
|
| 91 |
+
device = condition.device
|
| 92 |
+
|
| 93 |
+
if condition.dim() == 1:
|
| 94 |
+
condition = condition.unsqueeze(0)
|
| 95 |
+
|
| 96 |
+
B, L = condition.shape
|
| 97 |
+
|
| 98 |
+
# ------------------------------------------------
|
| 99 |
+
# BETTER LATENT INITIALIZATION
|
| 100 |
+
# ------------------------------------------------
|
| 101 |
+
|
| 102 |
+
x_init = condition.clone()
|
| 103 |
+
|
| 104 |
+
mask = torch.rand_like(x_init.float()) < 0.5
|
| 105 |
+
x_init[mask] = model.mask_token_id
|
| 106 |
+
|
| 107 |
+
beams = [(x_init, 0.0)]
|
| 108 |
+
|
| 109 |
+
self_cond = None
|
| 110 |
+
|
| 111 |
+
for step in reversed(range(num_steps)):
|
| 112 |
+
|
| 113 |
+
new_beams = []
|
| 114 |
+
|
| 115 |
+
for x_t, score in beams:
|
| 116 |
+
|
| 117 |
+
t_tensor = torch.full(
|
| 118 |
+
(B,),
|
| 119 |
+
step,
|
| 120 |
+
dtype=torch.long,
|
| 121 |
+
device=device
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
candidates = self.p_sample_step(
|
| 125 |
+
model,
|
| 126 |
+
x_t,
|
| 127 |
+
t_tensor,
|
| 128 |
+
condition,
|
| 129 |
+
self_cond,
|
| 130 |
+
beam_width
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
for tokens, new_score in candidates:
|
| 134 |
+
|
| 135 |
+
length_norm = tokens.shape[1] ** self.length_penalty
|
| 136 |
+
|
| 137 |
+
final_score = (score + new_score) / length_norm
|
| 138 |
+
|
| 139 |
+
new_beams.append((tokens, final_score))
|
| 140 |
+
|
| 141 |
+
new_beams = sorted(
|
| 142 |
+
new_beams,
|
| 143 |
+
key=lambda x: x[1],
|
| 144 |
+
reverse=True
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
beams = new_beams[:beam_width]
|
| 148 |
+
|
| 149 |
+
# self conditioning
|
| 150 |
+
self_cond = beams[0][0]
|
| 151 |
+
|
| 152 |
+
best_tokens, best_score = beams[0]
|
| 153 |
+
|
| 154 |
+
return best_tokens
|
diffusion/reverse_process2.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
reverse_process.py โ Final Correct Version
|
| 3 |
+
=============================================
|
| 4 |
+
|
| 5 |
+
KEY PRINCIPLE: generate() must be byte-for-byte identical to run_inference()
|
| 6 |
+
in inference.py, which is what produced BERTScore 0.75 at validation.
|
| 7 |
+
|
| 8 |
+
CRITICAL BUG IN PREVIOUS VERSION:
|
| 9 |
+
We passed inference_mode=True to the model, but the model was NEVER
|
| 10 |
+
called with inference_mode=True during training or validation.
|
| 11 |
+
run_inference() (the validated path) does:
|
| 12 |
+
model(input_ids, x0_est, t, x0_hint=hint)
|
| 13 |
+
โ inference_mode defaults to False.
|
| 14 |
+
|
| 15 |
+
With inference_mode=True the model does two things differently:
|
| 16 |
+
1. tgt_pad_mask = None (training used tgt_pad_mask = tgt==PAD)
|
| 17 |
+
2. Skips q_sample at t=0 (training always called q_sample)
|
| 18 |
+
The model was never trained to handle these conditions โ garbage output.
|
| 19 |
+
|
| 20 |
+
Fix: do NOT pass inference_mode. Let it default to False, exactly
|
| 21 |
+
as run_inference() did.
|
| 22 |
+
|
| 23 |
+
BUGS FIXED (vs original reverse_process.py)
|
| 24 |
+
--------------------------------------------
|
| 25 |
+
BUG 1 generate_beam() used for D3PM โ all-แน repetition.
|
| 26 |
+
Use generate() (iterative refinement) from app1.py instead.
|
| 27 |
+
BUG 2 apply_diversity_penalty used logits.var() โ noise injection.
|
| 28 |
+
Fixed to logits - penalty * logits.mean(dim=1) โ global suppression.
|
| 29 |
+
BUG 3 x0_hint (self-conditioning) never passed to model.
|
| 30 |
+
Fixed: generate() passes x0_hint=hint every step.
|
| 31 |
+
BUG 4 params not forwarded from generate_beam() to p_sample_step().
|
| 32 |
+
Fixed in generate_beam() (kept for reference, not for production use).
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ReverseDiffusion:
|
| 40 |
+
|
| 41 |
+
def __init__(self, scheduler):
|
| 42 |
+
self.scheduler = scheduler
|
| 43 |
+
|
| 44 |
+
# Attribute-style defaults for backward compat with any code
|
| 45 |
+
# that sets reverse_diffusion.temperature = 0.9 etc.
|
| 46 |
+
# generate() prefers explicit kwargs and falls back to these.
|
| 47 |
+
self.temperature = 0.75
|
| 48 |
+
self.repetition_penalty = 1.15
|
| 49 |
+
self.diversity_penalty = 0.0
|
| 50 |
+
self.top_k = 50
|
| 51 |
+
|
| 52 |
+
# ------------------------------------------------------------------ #
|
| 53 |
+
# generate โ CORRECT D3PM iterative refinement #
|
| 54 |
+
# Exact equivalent of run_inference() in inference.py #
|
| 55 |
+
# ------------------------------------------------------------------ #
|
| 56 |
+
def generate(
|
| 57 |
+
self,
|
| 58 |
+
model,
|
| 59 |
+
condition,
|
| 60 |
+
num_steps = None,
|
| 61 |
+
temperature = None,
|
| 62 |
+
top_k = None,
|
| 63 |
+
repetition_penalty = None,
|
| 64 |
+
diversity_penalty = None,
|
| 65 |
+
):
|
| 66 |
+
"""
|
| 67 |
+
D3PM iterative refinement โ identical to run_inference() in inference.py,
|
| 68 |
+
which is the validated path (BERTScore 0.75).
|
| 69 |
+
|
| 70 |
+
Algorithm:
|
| 71 |
+
x0_est = all [MASK]
|
| 72 |
+
for t = T-1 down to 0:
|
| 73 |
+
logits = model(src, x0_est, t, x0_hint=hint)
|
| 74 |
+
โ inference_mode NOT passed (defaults to False)
|
| 75 |
+
โ this exactly matches training/validation
|
| 76 |
+
apply penalties, temperature, top_k
|
| 77 |
+
if t > 0: x0_est = multinomial(softmax(logits)) โ stochastic
|
| 78 |
+
if t = 0: x0_est = argmax(softmax(logits)) โ deterministic
|
| 79 |
+
hint = x0_est
|
| 80 |
+
"""
|
| 81 |
+
# Resolve: explicit kwarg > object attribute
|
| 82 |
+
temperature = temperature if temperature is not None else self.temperature
|
| 83 |
+
top_k = top_k if top_k is not None else self.top_k
|
| 84 |
+
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
|
| 85 |
+
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
|
| 86 |
+
|
| 87 |
+
if num_steps is None:
|
| 88 |
+
num_steps = self.scheduler.num_timesteps
|
| 89 |
+
|
| 90 |
+
device = condition.device
|
| 91 |
+
if condition.dim() == 1:
|
| 92 |
+
condition = condition.unsqueeze(0)
|
| 93 |
+
B, L = condition.shape
|
| 94 |
+
|
| 95 |
+
T = self.scheduler.num_timesteps
|
| 96 |
+
step_size = max(1, T // num_steps)
|
| 97 |
+
timesteps = list(range(T - 1, -1, -step_size))
|
| 98 |
+
if timesteps[-1] != 0:
|
| 99 |
+
timesteps.append(0)
|
| 100 |
+
|
| 101 |
+
mask_id = model.mask_token_id
|
| 102 |
+
x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
|
| 103 |
+
hint = None
|
| 104 |
+
|
| 105 |
+
model.eval()
|
| 106 |
+
with torch.no_grad():
|
| 107 |
+
for step_idx, t_val in enumerate(timesteps):
|
| 108 |
+
t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 109 |
+
is_last = (step_idx == len(timesteps) - 1)
|
| 110 |
+
|
| 111 |
+
# โโ CRITICAL: do NOT pass inference_mode โโโโโโโโโโโโโโโโโโ
|
| 112 |
+
# inference_mode defaults to False inside SanskritModel /
|
| 113 |
+
# D3PMCrossAttention. This matches run_inference() exactly.
|
| 114 |
+
# Passing inference_mode=True changes tgt_pad_mask and
|
| 115 |
+
# q_sample behaviour โ the model was never trained for that.
|
| 116 |
+
logits, _ = model(condition, x0_est, t, x0_hint=hint)
|
| 117 |
+
|
| 118 |
+
# Repetition penalty
|
| 119 |
+
if repetition_penalty != 1.0:
|
| 120 |
+
logits = apply_repetition_penalty(
|
| 121 |
+
logits, x0_est, repetition_penalty
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Diversity penalty (correct: global mean suppression)
|
| 125 |
+
if diversity_penalty > 0.0:
|
| 126 |
+
logits = apply_diversity_penalty(logits, diversity_penalty)
|
| 127 |
+
|
| 128 |
+
logits = logits / max(temperature, 1e-5)
|
| 129 |
+
|
| 130 |
+
if top_k > 0:
|
| 131 |
+
logits = top_k_filter(logits, top_k)
|
| 132 |
+
|
| 133 |
+
probs = F.softmax(logits, dim=-1)
|
| 134 |
+
|
| 135 |
+
# Stochastic at every step except the last (argmax at t=0)
|
| 136 |
+
if is_last:
|
| 137 |
+
x0_est = torch.argmax(probs, dim=-1)
|
| 138 |
+
else:
|
| 139 |
+
x0_est = batch_multinomial(probs)
|
| 140 |
+
|
| 141 |
+
hint = x0_est
|
| 142 |
+
|
| 143 |
+
return x0_est # (B, L)
|
| 144 |
+
|
| 145 |
+
# ------------------------------------------------------------------ #
|
| 146 |
+
# p_sample_step โ used by generate_beam (not for production) #
|
| 147 |
+
# ------------------------------------------------------------------ #
|
| 148 |
+
def p_sample_step(
|
| 149 |
+
self,
|
| 150 |
+
model,
|
| 151 |
+
x_t,
|
| 152 |
+
t,
|
| 153 |
+
condition,
|
| 154 |
+
beam_width = 3,
|
| 155 |
+
temperature = 1.0,
|
| 156 |
+
repetition_penalty = 1.2,
|
| 157 |
+
diversity_penalty = 0.3,
|
| 158 |
+
x0_hint = None,
|
| 159 |
+
):
|
| 160 |
+
with torch.no_grad():
|
| 161 |
+
if x_t.dim() == 1: x_t = x_t.unsqueeze(0)
|
| 162 |
+
if condition.dim() == 1: condition = condition.unsqueeze(0)
|
| 163 |
+
if t.dim() == 0: t = t.unsqueeze(0)
|
| 164 |
+
if t.shape[0] != x_t.shape[0]:
|
| 165 |
+
t = t.expand(x_t.shape[0])
|
| 166 |
+
|
| 167 |
+
# No inference_mode โ matches training convention
|
| 168 |
+
logits, _ = model(condition, x_t, t, x0_hint=x0_hint)
|
| 169 |
+
|
| 170 |
+
logits = logits / max(temperature, 1e-5)
|
| 171 |
+
|
| 172 |
+
if repetition_penalty != 1.0:
|
| 173 |
+
logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
|
| 174 |
+
if diversity_penalty > 0.0:
|
| 175 |
+
logits = apply_diversity_penalty(logits, diversity_penalty)
|
| 176 |
+
|
| 177 |
+
probs = F.softmax(logits, dim=-1)
|
| 178 |
+
B, L, V = probs.shape
|
| 179 |
+
|
| 180 |
+
topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
|
| 181 |
+
candidates = []
|
| 182 |
+
for k in range(beam_width):
|
| 183 |
+
next_tokens = topk_ids[:, :, k]
|
| 184 |
+
score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
|
| 185 |
+
candidates.append((next_tokens, score))
|
| 186 |
+
return candidates
|
| 187 |
+
|
| 188 |
+
# ------------------------------------------------------------------ #
|
| 189 |
+
# generate_beam โ kept for reference; NOT the correct D3PM method #
|
| 190 |
+
# ------------------------------------------------------------------ #
|
| 191 |
+
def generate_beam(
|
| 192 |
+
self,
|
| 193 |
+
model,
|
| 194 |
+
condition,
|
| 195 |
+
beam_width = 3,
|
| 196 |
+
num_steps = None,
|
| 197 |
+
temperature = None,
|
| 198 |
+
repetition_penalty = None,
|
| 199 |
+
diversity_penalty = None,
|
| 200 |
+
):
|
| 201 |
+
"""
|
| 202 |
+
WARNING: do NOT call this from app1.py for D3PM generation.
|
| 203 |
+
generate_beam() forces every position to the same top-k token
|
| 204 |
+
โ all-แน / all-rud repetition. Use generate() instead.
|
| 205 |
+
Kept only for experimental reference.
|
| 206 |
+
"""
|
| 207 |
+
temperature = temperature if temperature is not None else self.temperature
|
| 208 |
+
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
|
| 209 |
+
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
|
| 210 |
+
if num_steps is None:
|
| 211 |
+
num_steps = self.scheduler.num_timesteps
|
| 212 |
+
|
| 213 |
+
device = condition.device
|
| 214 |
+
if condition.dim() == 1: condition = condition.unsqueeze(0)
|
| 215 |
+
B, L = condition.shape
|
| 216 |
+
|
| 217 |
+
x_init = torch.full((B, L), fill_value=model.mask_token_id,
|
| 218 |
+
dtype=torch.long, device=device)
|
| 219 |
+
beams = [(x_init, 0.0)]
|
| 220 |
+
best_hint = None
|
| 221 |
+
|
| 222 |
+
for step in reversed(range(num_steps)):
|
| 223 |
+
t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
|
| 224 |
+
new_beams = []
|
| 225 |
+
for x_t, score in beams:
|
| 226 |
+
candidates = self.p_sample_step(
|
| 227 |
+
model, x_t, t_tensor, condition,
|
| 228 |
+
beam_width = beam_width,
|
| 229 |
+
temperature = temperature,
|
| 230 |
+
repetition_penalty = repetition_penalty,
|
| 231 |
+
diversity_penalty = diversity_penalty,
|
| 232 |
+
x0_hint = best_hint,
|
| 233 |
+
)
|
| 234 |
+
for tokens, new_score in candidates:
|
| 235 |
+
new_beams.append((tokens, score + new_score.item()))
|
| 236 |
+
|
| 237 |
+
new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
|
| 238 |
+
beams = new_beams[:beam_width]
|
| 239 |
+
best_hint = beams[0][0]
|
| 240 |
+
|
| 241 |
+
return beams[0][0] # (B, L)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# โโ Penalty helpers โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 245 |
+
|
| 246 |
+
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
|
| 247 |
+
"""Down-weight tokens already present in the sequence."""
|
| 248 |
+
for b in range(logits.shape[0]):
|
| 249 |
+
for token_id in set(prev_tokens[b].tolist()):
|
| 250 |
+
if token_id > 4:
|
| 251 |
+
logits[b, :, token_id] = logits[b, :, token_id] / penalty
|
| 252 |
+
return logits
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def apply_diversity_penalty(logits, penalty=0.3):
|
| 256 |
+
"""
|
| 257 |
+
Correct diversity penalty: suppress globally dominant tokens.
|
| 258 |
+
logits -= penalty * mean(logits, dim=1) [sequence dimension]
|
| 259 |
+
"""
|
| 260 |
+
global_mean = logits.mean(dim=1, keepdim=True) # [B, 1, V]
|
| 261 |
+
return logits - penalty * global_mean
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def top_k_filter(logits, k):
|
| 265 |
+
B, L, V = logits.shape
|
| 266 |
+
if k >= V: return logits
|
| 267 |
+
topk_vals, _ = torch.topk(logits, k, dim=-1)
|
| 268 |
+
return logits.masked_fill(logits < topk_vals[..., -1].unsqueeze(-1), float('-inf'))
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def batch_multinomial(probs):
|
| 272 |
+
B, L, V = probs.shape
|
| 273 |
+
flat = probs.view(B * L, V) + 1e-9
|
| 274 |
+
flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 275 |
+
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
diffusion/scheduler.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
scheduler.py โ Fixed & Upgraded
|
| 3 |
+
==================================
|
| 4 |
+
Changes:
|
| 5 |
+
1. T=64 (was 16). More timesteps = richer denoising curriculum per epoch.
|
| 6 |
+
2. alpha at t=0 is EXACTLY 1.0 โ fixes Bug 2 (final-step re-noise).
|
| 7 |
+
3. sample_timestep samples [0, T-1] including t=0, so model trains on
|
| 8 |
+
fully-clean inputs (learns the identity at t=0 explicitly).
|
| 9 |
+
"""
|
| 10 |
+
import torch, math
|
| 11 |
+
|
| 12 |
+
class OptimizedCosineScheduler:
|
| 13 |
+
def __init__(self, cfg, device=None):
|
| 14 |
+
self.num_timesteps = cfg['model']['diffusion_steps'] # 64
|
| 15 |
+
self.mask_token_id = cfg['diffusion']['mask_token_id']
|
| 16 |
+
self.device = device or torch.device('cpu')
|
| 17 |
+
self.alphas_cumprod = self._build_schedule().to(self.device)
|
| 18 |
+
|
| 19 |
+
def _build_schedule(self):
|
| 20 |
+
T = self.num_timesteps
|
| 21 |
+
t = torch.arange(T + 1, dtype=torch.float32)
|
| 22 |
+
f_t = torch.cos((t / T + 0.008) / 1.008 * math.pi / 2) ** 2
|
| 23 |
+
alphas_bar = f_t / f_t[0]
|
| 24 |
+
alphas_bar = alphas_bar[1:] # shape [T]
|
| 25 |
+
alphas_bar[0] = 1.0 # FIX: exact 1.0 at t=0
|
| 26 |
+
alphas_bar[-1] = alphas_bar[-1].clamp(max=0.001)
|
| 27 |
+
return alphas_bar
|
| 28 |
+
|
| 29 |
+
def sample_timestep(self, batch_size):
|
| 30 |
+
"""Uniform [0, T-1] โ includes t=0 so model sees clean inputs."""
|
| 31 |
+
return torch.randint(0, self.num_timesteps, (batch_size,))
|
| 32 |
+
|
| 33 |
+
def get_alpha(self, t):
|
| 34 |
+
return self.alphas_cumprod[t.to(self.alphas_cumprod.device).long()]
|
handler.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
|
| 3 |
+
from inference_api import predict
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EndpointHandler:
|
| 7 |
+
"""
|
| 8 |
+
Hugging Face Inference Endpoint handler.
|
| 9 |
+
Expects payload:
|
| 10 |
+
{
|
| 11 |
+
"inputs": "dharmo rakแนฃati rakแนฃitaแธฅ",
|
| 12 |
+
"parameters": {"temperature": 0.7, ...}
|
| 13 |
+
}
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, path: str = ""):
|
| 17 |
+
self.path = path
|
| 18 |
+
|
| 19 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 20 |
+
text = data.get("inputs", "")
|
| 21 |
+
params = data.get("parameters", {}) or {}
|
| 22 |
+
return predict(
|
| 23 |
+
text=text,
|
| 24 |
+
temperature=params.get("temperature", 0.7),
|
| 25 |
+
top_k=params.get("top_k", 40),
|
| 26 |
+
repetition_penalty=params.get("repetition_penalty", 1.2),
|
| 27 |
+
diversity_penalty=params.get("diversity_penalty", 0.0),
|
| 28 |
+
num_steps=params.get("num_steps", 64),
|
| 29 |
+
clean_output=params.get("clean_output", True),
|
| 30 |
+
)
|
inference.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inference.py
|
| 3 |
+
============
|
| 4 |
+
Correct D3PM inference for Sanskrit paraphrase generation.
|
| 5 |
+
|
| 6 |
+
The model's forward() takes CLEAN tgt and noises it internally.
|
| 7 |
+
So inference passes x0_estimate (starting all-[MASK]) as tgt each step,
|
| 8 |
+
letting the model noise it and then predict a cleaner version.
|
| 9 |
+
|
| 10 |
+
Also includes: robust checkpoint loading (auto-detects architecture
|
| 11 |
+
from saved weights โ no CONFIG mismatch crashes).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
import torch
|
| 16 |
+
import os, sys
|
| 17 |
+
import re
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
from torch.utils.data import DataLoader, Subset
|
| 20 |
+
|
| 21 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 22 |
+
from config import CONFIG
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# โโ Checkpoint loader โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 26 |
+
|
| 27 |
+
def _resolve_device(cfg_device: str) -> torch.device:
|
| 28 |
+
cfg_device = (cfg_device or "").lower()
|
| 29 |
+
if cfg_device == "cuda" and torch.cuda.is_available():
|
| 30 |
+
return torch.device("cuda")
|
| 31 |
+
if cfg_device == "mps" and torch.backends.mps.is_available():
|
| 32 |
+
return torch.device("mps")
|
| 33 |
+
if cfg_device in {"cpu", "cuda", "mps"}:
|
| 34 |
+
return torch.device("cpu")
|
| 35 |
+
if torch.cuda.is_available():
|
| 36 |
+
return torch.device("cuda")
|
| 37 |
+
if torch.backends.mps.is_available():
|
| 38 |
+
return torch.device("mps")
|
| 39 |
+
return torch.device("cpu")
|
| 40 |
+
|
| 41 |
+
def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
|
| 42 |
+
"""
|
| 43 |
+
Auto-detect architecture from checkpoint weight shapes,
|
| 44 |
+
then load. Never fails due to CONFIG vs checkpoint mismatch.
|
| 45 |
+
"""
|
| 46 |
+
import copy
|
| 47 |
+
from model.sanskrit_model import SanskritModel
|
| 48 |
+
|
| 49 |
+
cfg = copy.deepcopy(base_cfg)
|
| 50 |
+
state = torch.load(ckpt_path, map_location='cpu')
|
| 51 |
+
|
| 52 |
+
# d_model + vocab_size
|
| 53 |
+
ek = 'model.src_embed.token_emb.weight'
|
| 54 |
+
if ek in state:
|
| 55 |
+
vocab, d = state[ek].shape
|
| 56 |
+
cfg['model']['vocab_size'] = vocab
|
| 57 |
+
cfg['model']['d_model'] = d
|
| 58 |
+
cfg['model']['d_ff'] = d * 4
|
| 59 |
+
|
| 60 |
+
# n_layers
|
| 61 |
+
ids = {int(k.split('.')[2]) for k in state if k.startswith('model.encoder_blocks.')}
|
| 62 |
+
if ids:
|
| 63 |
+
cfg['model']['n_layers'] = max(ids) + 1
|
| 64 |
+
|
| 65 |
+
# max_seq_len
|
| 66 |
+
pk = 'model.src_embed.pos_enc.pe'
|
| 67 |
+
if pk in state:
|
| 68 |
+
cfg['model']['max_seq_len'] = state[pk].shape[1]
|
| 69 |
+
|
| 70 |
+
# n_heads
|
| 71 |
+
d = cfg['model']['d_model']
|
| 72 |
+
h = cfg['model'].get('n_heads', 6)
|
| 73 |
+
if d % h != 0:
|
| 74 |
+
h = next(x for x in [8, 6, 4, 2, 1] if d % x == 0)
|
| 75 |
+
cfg['model']['n_heads'] = h
|
| 76 |
+
|
| 77 |
+
print(f"๐ Detected: d_model={cfg['model']['d_model']}, "
|
| 78 |
+
f"n_layers={cfg['model']['n_layers']}, "
|
| 79 |
+
f"max_seq_len={cfg['model']['max_seq_len']}, "
|
| 80 |
+
f"n_heads={cfg['model']['n_heads']}")
|
| 81 |
+
|
| 82 |
+
model = SanskritModel(cfg).to(device)
|
| 83 |
+
raw_state = torch.load(ckpt_path, map_location=device)
|
| 84 |
+
model_state = model.state_dict()
|
| 85 |
+
filtered_state = {}
|
| 86 |
+
skipped_mismatch = []
|
| 87 |
+
for k, v in raw_state.items():
|
| 88 |
+
if k in model_state and hasattr(v, "shape") and hasattr(model_state[k], "shape"):
|
| 89 |
+
if tuple(v.shape) != tuple(model_state[k].shape):
|
| 90 |
+
skipped_mismatch.append((k, tuple(v.shape), tuple(model_state[k].shape)))
|
| 91 |
+
continue
|
| 92 |
+
filtered_state[k] = v
|
| 93 |
+
|
| 94 |
+
missing, unexpected = model.load_state_dict(filtered_state, strict=False)
|
| 95 |
+
|
| 96 |
+
# hint_gate may be absent in older checkpoints โ initialise safely
|
| 97 |
+
allowed = {'model.hint_gate.0.weight', 'model.hint_gate.0.bias'}
|
| 98 |
+
real_missing = [k for k in missing if k not in allowed]
|
| 99 |
+
if real_missing:
|
| 100 |
+
print(f"โ ๏ธ Missing keys: {real_missing[:3]} โฆ")
|
| 101 |
+
if unexpected:
|
| 102 |
+
print(f"โ ๏ธ Unexpected keys: {unexpected[:3]} โฆ")
|
| 103 |
+
if skipped_mismatch:
|
| 104 |
+
print(f"โ ๏ธ Shape-mismatched keys skipped: {len(skipped_mismatch)}")
|
| 105 |
+
|
| 106 |
+
# Enable compact-attention branch only when checkpoint actually provides it.
|
| 107 |
+
has_compact = any(".compact_out_proj.weight" in k for k in filtered_state.keys())
|
| 108 |
+
if has_compact and hasattr(model, "model") and hasattr(model.model, "decoder_blocks"):
|
| 109 |
+
for block in model.model.decoder_blocks:
|
| 110 |
+
if hasattr(block, "cross_attn") and hasattr(block.cross_attn, "use_compact"):
|
| 111 |
+
block.cross_attn.use_compact = True
|
| 112 |
+
print("โน๏ธ Compact cross-attention branch enabled from checkpoint.")
|
| 113 |
+
if hasattr(model.model, 'hint_gate') and 'model.hint_gate.0.weight' in missing:
|
| 114 |
+
with torch.no_grad():
|
| 115 |
+
w = model.model.hint_gate[0].weight
|
| 116 |
+
torch.nn.init.zeros_(model.model.hint_gate[0].bias)
|
| 117 |
+
torch.nn.init.eye_(w) if w.shape[0] == w.shape[1] \
|
| 118 |
+
else torch.nn.init.xavier_uniform_(w)
|
| 119 |
+
print("โน๏ธ hint_gate initialised to identity (not in checkpoint).")
|
| 120 |
+
|
| 121 |
+
print("โ
Model loaded.")
|
| 122 |
+
return model, cfg
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# โโ Core inference function (same path as validation) โโโโโโโโโโโโโโโโ
|
| 126 |
+
|
| 127 |
+
@torch.no_grad()
|
| 128 |
+
def run_inference(model, input_ids, cfg):
|
| 129 |
+
"""
|
| 130 |
+
Reverse diffusion sampling (clean path).
|
| 131 |
+
Uses cached reverse diffusion when available, otherwise model.generate().
|
| 132 |
+
"""
|
| 133 |
+
inf = cfg['inference']
|
| 134 |
+
model.eval()
|
| 135 |
+
kwargs = dict(
|
| 136 |
+
num_steps=inf['num_steps'],
|
| 137 |
+
temperature=inf['temperature'],
|
| 138 |
+
top_k=inf['top_k'],
|
| 139 |
+
repetition_penalty=inf.get('repetition_penalty', 1.2),
|
| 140 |
+
diversity_penalty=inf.get('diversity_penalty', 0.0),
|
| 141 |
+
)
|
| 142 |
+
if hasattr(model, "generate_cached"):
|
| 143 |
+
out = model.generate_cached(input_ids, **kwargs)
|
| 144 |
+
else:
|
| 145 |
+
out = model.generate(input_ids, **kwargs)
|
| 146 |
+
|
| 147 |
+
# Optional retry with stronger anti-repetition settings.
|
| 148 |
+
if inf.get("auto_retry_on_repetition", True):
|
| 149 |
+
repeat_threshold = float(inf.get("repeat_ratio_threshold", 0.40))
|
| 150 |
+
max_repeat_run = int(inf.get("max_repeat_run", 4))
|
| 151 |
+
if _mean_repeat_ratio(out) >= repeat_threshold:
|
| 152 |
+
retry_kwargs = dict(kwargs)
|
| 153 |
+
retry_kwargs["temperature"] = max(0.6, float(kwargs["temperature"]) - 0.1)
|
| 154 |
+
retry_kwargs["top_k"] = max(20, int(kwargs["top_k"]) - 10)
|
| 155 |
+
retry_kwargs["repetition_penalty"] = max(float(kwargs["repetition_penalty"]), 1.6)
|
| 156 |
+
retry_kwargs["diversity_penalty"] = max(float(kwargs["diversity_penalty"]), 0.3)
|
| 157 |
+
if hasattr(model, "generate_cached"):
|
| 158 |
+
retry = model.generate_cached(input_ids, **retry_kwargs)
|
| 159 |
+
else:
|
| 160 |
+
retry = model.generate(input_ids, **retry_kwargs)
|
| 161 |
+
if _mean_repeat_ratio(retry) < _mean_repeat_ratio(out):
|
| 162 |
+
out = retry
|
| 163 |
+
out = _dedup_repeated_ids(out, max_repeat_run=max_repeat_run)
|
| 164 |
+
|
| 165 |
+
return out
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _mean_repeat_ratio(ids_tensor: torch.Tensor) -> float:
|
| 169 |
+
if ids_tensor is None or ids_tensor.numel() == 0:
|
| 170 |
+
return 0.0
|
| 171 |
+
ratios = []
|
| 172 |
+
for row in ids_tensor:
|
| 173 |
+
ids = [int(x) for x in row.tolist() if int(x) > 4]
|
| 174 |
+
if len(ids) < 2:
|
| 175 |
+
ratios.append(0.0)
|
| 176 |
+
continue
|
| 177 |
+
repeats = sum(1 for i in range(1, len(ids)) if ids[i] == ids[i - 1])
|
| 178 |
+
ratios.append(repeats / max(1, len(ids) - 1))
|
| 179 |
+
return float(sum(ratios) / max(1, len(ratios)))
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _dedup_repeated_ids(ids_tensor: torch.Tensor, max_repeat_run: int = 4) -> torch.Tensor:
|
| 183 |
+
"""
|
| 184 |
+
Keep generation path unchanged, but clean extreme run-on token loops in final output ids.
|
| 185 |
+
"""
|
| 186 |
+
if ids_tensor is None or ids_tensor.numel() == 0:
|
| 187 |
+
return ids_tensor
|
| 188 |
+
cleaned_rows = []
|
| 189 |
+
for row in ids_tensor.tolist():
|
| 190 |
+
out = []
|
| 191 |
+
prev = None
|
| 192 |
+
run = 0
|
| 193 |
+
for tok in row:
|
| 194 |
+
if tok <= 4:
|
| 195 |
+
out.append(tok)
|
| 196 |
+
prev = tok
|
| 197 |
+
run = 1
|
| 198 |
+
continue
|
| 199 |
+
if tok == prev:
|
| 200 |
+
run += 1
|
| 201 |
+
if run > max_repeat_run:
|
| 202 |
+
continue
|
| 203 |
+
else:
|
| 204 |
+
run = 1
|
| 205 |
+
out.append(tok)
|
| 206 |
+
prev = tok
|
| 207 |
+
# Preserve original length for downstream decode assumptions.
|
| 208 |
+
if len(out) < len(row):
|
| 209 |
+
out.extend([1] * (len(row) - len(out)))
|
| 210 |
+
else:
|
| 211 |
+
out = out[:len(row)]
|
| 212 |
+
cleaned_rows.append(out)
|
| 213 |
+
return torch.tensor(cleaned_rows, dtype=ids_tensor.dtype, device=ids_tensor.device)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _decode_clean(tgt_tok, ids):
|
| 217 |
+
out = []
|
| 218 |
+
for x in ids:
|
| 219 |
+
if x in (1, 4) and out:
|
| 220 |
+
break
|
| 221 |
+
if x > 4:
|
| 222 |
+
out.append(x)
|
| 223 |
+
text = tgt_tok.decode(out).strip()
|
| 224 |
+
return _clean_repetition_text(text)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _clean_repetition_text(text: str, max_repeat_run: int = 3) -> str:
|
| 228 |
+
words = [w for w in text.split() if w.strip()]
|
| 229 |
+
if not words:
|
| 230 |
+
return text.strip()
|
| 231 |
+
cleaned = []
|
| 232 |
+
prev = None
|
| 233 |
+
run = 0
|
| 234 |
+
for w in words:
|
| 235 |
+
if w == prev:
|
| 236 |
+
run += 1
|
| 237 |
+
if run > max_repeat_run:
|
| 238 |
+
continue
|
| 239 |
+
else:
|
| 240 |
+
run = 1
|
| 241 |
+
cleaned.append(w)
|
| 242 |
+
prev = w
|
| 243 |
+
return " ".join(cleaned).strip()
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# โโ Cleanup heuristics from UI inference pipeline โโโโโโโโโโโโโโโโโโโโโ
|
| 247 |
+
|
| 248 |
+
_IAST_VOWELS = [
|
| 249 |
+
("ai", "เค"), ("au", "เค"),
|
| 250 |
+
("ฤ", "เค"), ("ฤซ", "เค"), ("ลซ", "เค"),
|
| 251 |
+
("แน", "เค"), ("แน", "เฅ "), ("แธท", "เค"), ("แธน", "เฅก"),
|
| 252 |
+
("a", "เค
"), ("i", "เค"), ("u", "เค"),
|
| 253 |
+
("e", "เค"), ("o", "เค"),
|
| 254 |
+
]
|
| 255 |
+
_IAST_MATRAS = [
|
| 256 |
+
("ai", "เฅ"), ("au", "เฅ"),
|
| 257 |
+
("ฤ", "เคพ"), ("ฤซ", "เฅ"), ("ลซ", "เฅ"),
|
| 258 |
+
("แน", "เฅ"), ("แน", "เฅ"), ("แธท", "เฅข"), ("แธน", "เฅฃ"),
|
| 259 |
+
("a", ""), ("i", "เคฟ"), ("u", "เฅ"),
|
| 260 |
+
("e", "เฅ"), ("o", "เฅ"),
|
| 261 |
+
]
|
| 262 |
+
_IAST_CONS = [
|
| 263 |
+
("kแนฃ", "เคเฅเคท"), ("jรฑ", "เคเฅเค"), ("tr", "เคคเฅเคฐ"),
|
| 264 |
+
("kh", "เค"), ("gh", "เค"), ("ch", "เค"), ("jh", "เค"),
|
| 265 |
+
("แนญh", "เค "), ("แธh", "เคข"), ("th", "เคฅ"), ("dh", "เคง"),
|
| 266 |
+
("ph", "เคซ"), ("bh", "เคญ"),
|
| 267 |
+
("แน
", "เค"), ("รฑ", "เค"), ("แนญ", "เค"), ("แธ", "เคก"),
|
| 268 |
+
("แน", "เคฃ"), ("ล", "เคถ"), ("แนฃ", "เคท"), ("แธฅ", "เค"),
|
| 269 |
+
("แน", "เค"), ("แน", "เค"),
|
| 270 |
+
("y", "๏ฟฝ๏ฟฝ๏ฟฝ"), ("r", "เคฐ"), ("l", "เคฒ"), ("v", "เคต"),
|
| 271 |
+
("s", "เคธ"), ("h", "เคน"),
|
| 272 |
+
("k", "เค"), ("g", "เค"), ("c", "เค"), ("j", "เค"),
|
| 273 |
+
("t", "เคค"), ("d", "เคฆ"), ("n", "เคจ"),
|
| 274 |
+
("p", "เคช"), ("b", "เคฌ"), ("m", "เคฎ"),
|
| 275 |
+
]
|
| 276 |
+
_PUNCT = {".": "เฅค", "|": "เฅค", "||": "เฅฅ", ",": ",", "?": "?", "!": "!"}
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def _iast_to_deva(text: str) -> str:
|
| 280 |
+
s = (text or "").lower()
|
| 281 |
+
out = []
|
| 282 |
+
i = 0
|
| 283 |
+
pending_consonant = False
|
| 284 |
+
|
| 285 |
+
def _match_any(pairs, pos):
|
| 286 |
+
for k, v in pairs:
|
| 287 |
+
if s.startswith(k, pos):
|
| 288 |
+
return k, v
|
| 289 |
+
return None, None
|
| 290 |
+
|
| 291 |
+
while i < len(s):
|
| 292 |
+
if s[i].isspace():
|
| 293 |
+
pending_consonant = False
|
| 294 |
+
out.append(s[i])
|
| 295 |
+
i += 1
|
| 296 |
+
continue
|
| 297 |
+
if s[i:i+2] == "||":
|
| 298 |
+
pending_consonant = False
|
| 299 |
+
out.append(_PUNCT["||"])
|
| 300 |
+
i += 2
|
| 301 |
+
continue
|
| 302 |
+
if s[i] in _PUNCT:
|
| 303 |
+
pending_consonant = False
|
| 304 |
+
out.append(_PUNCT[s[i]])
|
| 305 |
+
i += 1
|
| 306 |
+
continue
|
| 307 |
+
|
| 308 |
+
v_key, v_deva = _match_any(_IAST_VOWELS, i)
|
| 309 |
+
if v_key:
|
| 310 |
+
if pending_consonant:
|
| 311 |
+
_, v_matra = _match_any(_IAST_MATRAS, i)
|
| 312 |
+
out[-1] = out[-1] + (v_matra or "")
|
| 313 |
+
pending_consonant = False
|
| 314 |
+
else:
|
| 315 |
+
out.append(v_deva)
|
| 316 |
+
i += len(v_key)
|
| 317 |
+
continue
|
| 318 |
+
|
| 319 |
+
c_key, c_deva = _match_any(_IAST_CONS, i)
|
| 320 |
+
if c_key:
|
| 321 |
+
if pending_consonant:
|
| 322 |
+
out[-1] = out[-1] + "เฅ"
|
| 323 |
+
out.append(c_deva)
|
| 324 |
+
pending_consonant = True
|
| 325 |
+
i += len(c_key)
|
| 326 |
+
continue
|
| 327 |
+
|
| 328 |
+
out.append(s[i])
|
| 329 |
+
pending_consonant = False
|
| 330 |
+
i += 1
|
| 331 |
+
|
| 332 |
+
return "".join(out).strip()
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def _compute_cer(pred: str, ref: str) -> float:
|
| 336 |
+
if pred == ref:
|
| 337 |
+
return 0.0
|
| 338 |
+
if not pred or not ref:
|
| 339 |
+
return 1.0
|
| 340 |
+
m, n = len(pred), len(ref)
|
| 341 |
+
dp = list(range(n + 1))
|
| 342 |
+
for i in range(1, m + 1):
|
| 343 |
+
prev = dp[0]
|
| 344 |
+
dp[0] = i
|
| 345 |
+
for j in range(1, n + 1):
|
| 346 |
+
temp = dp[j]
|
| 347 |
+
cost = 0 if pred[i - 1] == ref[j - 1] else 1
|
| 348 |
+
dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
|
| 349 |
+
prev = temp
|
| 350 |
+
return dp[n] / max(m, n)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def _cleanup_thresholds(temperature: float, top_k: int):
|
| 354 |
+
temp = float(temperature)
|
| 355 |
+
k = max(1, int(top_k))
|
| 356 |
+
t_norm = max(0.0, min((temp - 0.4) / 0.6, 1.0))
|
| 357 |
+
k_norm = max(0.0, min((k - 20) / 80.0, 1.0))
|
| 358 |
+
diversity = 0.6 * t_norm + 0.4 * k_norm
|
| 359 |
+
cer_threshold = 0.10 + 0.18 * diversity
|
| 360 |
+
deva_ratio_threshold = 0.60 - 0.20 * diversity
|
| 361 |
+
return cer_threshold, deva_ratio_threshold
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _decode_with_cleanup(tgt_tok, ids, src_text: str, inf_cfg: dict):
|
| 365 |
+
model_out = _decode_clean(tgt_tok, ids)
|
| 366 |
+
rule_out = _iast_to_deva(src_text.strip())
|
| 367 |
+
deva_chars = sum(1 for ch in model_out if "\u0900" <= ch <= "\u097F")
|
| 368 |
+
deva_ratio = deva_chars / max(1, len(model_out))
|
| 369 |
+
cer = _compute_cer(model_out, rule_out)
|
| 370 |
+
cer_thr, ratio_thr = _cleanup_thresholds(
|
| 371 |
+
inf_cfg.get("temperature", 0.8),
|
| 372 |
+
inf_cfg.get("top_k", 40),
|
| 373 |
+
)
|
| 374 |
+
if deva_ratio < ratio_thr or len(model_out) > 2.0 * max(1, len(rule_out)) or cer > cer_thr:
|
| 375 |
+
return rule_out
|
| 376 |
+
return model_out
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# โโ Interactive demo โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 380 |
+
|
| 381 |
+
def interactive_demo(checkpoint=None, single_text=None):
|
| 382 |
+
from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
|
| 383 |
+
|
| 384 |
+
cfg = CONFIG
|
| 385 |
+
device = _resolve_device(cfg['training'].get('device', 'cpu'))
|
| 386 |
+
|
| 387 |
+
model_name = cfg['model_type']
|
| 388 |
+
has_neg = cfg['data']['include_negative_examples']
|
| 389 |
+
ckpt = checkpoint or f"results/{model_name}_neg_{has_neg}/best_model.pt"
|
| 390 |
+
|
| 391 |
+
if not os.path.exists(ckpt):
|
| 392 |
+
raise FileNotFoundError(f"No checkpoint at {ckpt} โ train first.")
|
| 393 |
+
|
| 394 |
+
model, cfg = load_model(ckpt, cfg, device)
|
| 395 |
+
model.eval()
|
| 396 |
+
|
| 397 |
+
src_tok = SanskritSourceTokenizer(
|
| 398 |
+
vocab_size=cfg['model'].get('src_vocab_size', 16000),
|
| 399 |
+
max_len=cfg['model']['max_seq_len'],
|
| 400 |
+
)
|
| 401 |
+
tgt_tok = SanskritTargetTokenizer(
|
| 402 |
+
vocab_size=cfg['model'].get('tgt_vocab_size', 16000),
|
| 403 |
+
max_len=cfg['model']['max_seq_len'],
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
print("\n" + "="*55)
|
| 407 |
+
print("Sanskrit D3PM Paraphrase โ type verse, get paraphrase")
|
| 408 |
+
print("="*55 + "\n")
|
| 409 |
+
|
| 410 |
+
while True:
|
| 411 |
+
try:
|
| 412 |
+
text = (single_text if single_text is not None else input("INPUT > ")).strip()
|
| 413 |
+
except (EOFError, KeyboardInterrupt):
|
| 414 |
+
break
|
| 415 |
+
if not text or text.lower() in ('quit', 'exit', 'q'):
|
| 416 |
+
break
|
| 417 |
+
|
| 418 |
+
ids = torch.tensor(
|
| 419 |
+
[src_tok.encode(text)[:cfg['model']['max_seq_len']]],
|
| 420 |
+
dtype=torch.long, device=device
|
| 421 |
+
)
|
| 422 |
+
out = run_inference(model, ids, cfg)
|
| 423 |
+
cleaned = _decode_with_cleanup(tgt_tok, out[0].tolist(), text, cfg["inference"])
|
| 424 |
+
print(f"PARAPHRASE โ {cleaned}\n")
|
| 425 |
+
if single_text is not None:
|
| 426 |
+
break
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
# โโ Batch evaluation โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 430 |
+
|
| 431 |
+
def batch_evaluate(sample_size=500, checkpoint=None):
|
| 432 |
+
from data.dataset import OptimizedSanskritDataset
|
| 433 |
+
from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
|
| 434 |
+
|
| 435 |
+
cfg = CONFIG
|
| 436 |
+
device = _resolve_device(cfg['training'].get('device', 'cpu'))
|
| 437 |
+
|
| 438 |
+
model_name = cfg['model_type']
|
| 439 |
+
has_neg = cfg['data']['include_negative_examples']
|
| 440 |
+
exp_dir = f"results/{model_name}_neg_{has_neg}"
|
| 441 |
+
ckpt = checkpoint or f"{exp_dir}/best_model.pt"
|
| 442 |
+
|
| 443 |
+
if not os.path.exists(ckpt):
|
| 444 |
+
raise FileNotFoundError(f"No checkpoint at {ckpt}")
|
| 445 |
+
|
| 446 |
+
model, cfg = load_model(ckpt, cfg, device)
|
| 447 |
+
model.eval()
|
| 448 |
+
|
| 449 |
+
src_tok = SanskritSourceTokenizer(
|
| 450 |
+
vocab_size=cfg['model'].get('src_vocab_size', 16000),
|
| 451 |
+
max_len=cfg['model']['max_seq_len'],
|
| 452 |
+
)
|
| 453 |
+
tgt_tok = SanskritTargetTokenizer(
|
| 454 |
+
vocab_size=cfg['model'].get('tgt_vocab_size', 16000),
|
| 455 |
+
max_len=cfg['model']['max_seq_len'],
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
def collate(batch):
|
| 459 |
+
return {
|
| 460 |
+
'input_ids': torch.stack([b['input_ids'].long() for b in batch]),
|
| 461 |
+
'target_text': [b['target_text'] for b in batch],
|
| 462 |
+
'input_text': [b['input_text'] for b in batch],
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
dataset = OptimizedSanskritDataset(
|
| 466 |
+
split='test',
|
| 467 |
+
max_len=cfg['model']['max_seq_len'],
|
| 468 |
+
cfg=cfg,
|
| 469 |
+
src_tokenizer=src_tok,
|
| 470 |
+
tgt_tokenizer=tgt_tok,
|
| 471 |
+
)
|
| 472 |
+
indices = list(range(min(sample_size, len(dataset))))
|
| 473 |
+
loader = DataLoader(
|
| 474 |
+
Subset(dataset, indices),
|
| 475 |
+
batch_size=cfg['training']['batch_size'],
|
| 476 |
+
shuffle=False, collate_fn=collate
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
all_preds, all_refs, all_inputs = [], [], []
|
| 480 |
+
print(f"โณ Generating {len(indices)} paraphrases โฆ")
|
| 481 |
+
|
| 482 |
+
for batch in tqdm(loader):
|
| 483 |
+
ids = batch['input_ids'].to(device)
|
| 484 |
+
out = run_inference(model, ids, cfg)
|
| 485 |
+
for i in range(out.size(0)):
|
| 486 |
+
all_preds.append(_decode_with_cleanup(
|
| 487 |
+
tgt_tok, out[i].tolist(), batch['input_text'][i], cfg["inference"]
|
| 488 |
+
))
|
| 489 |
+
all_refs.append(batch['target_text'][i].strip())
|
| 490 |
+
all_inputs.append(batch['input_text'][i].strip())
|
| 491 |
+
|
| 492 |
+
# Metrics
|
| 493 |
+
bleu_score, bert_f1 = 0.0, 0.0
|
| 494 |
+
try:
|
| 495 |
+
from nltk.translate.bleu_score import corpus_bleu
|
| 496 |
+
bleu_score = corpus_bleu(
|
| 497 |
+
[[r.split()] for r in all_refs],
|
| 498 |
+
[p.split() for p in all_preds]
|
| 499 |
+
)
|
| 500 |
+
except Exception:
|
| 501 |
+
pass
|
| 502 |
+
|
| 503 |
+
try:
|
| 504 |
+
import evaluate as hf_eval
|
| 505 |
+
res = hf_eval.load('bertscore').compute(
|
| 506 |
+
predictions=all_preds, references=all_refs, lang='hi'
|
| 507 |
+
)
|
| 508 |
+
bert_f1 = sum(res['f1']) / len(res['f1'])
|
| 509 |
+
except Exception:
|
| 510 |
+
pass
|
| 511 |
+
|
| 512 |
+
# Save
|
| 513 |
+
out_path = f"{exp_dir}/evaluation_results.txt"
|
| 514 |
+
pred_path = f"{exp_dir}/evaluation_predictions.jsonl"
|
| 515 |
+
with open(out_path, 'w', encoding='utf-8') as f:
|
| 516 |
+
f.write(f"Model : {model_name}\n")
|
| 517 |
+
f.write(f"Negatives: {has_neg}\n")
|
| 518 |
+
f.write(f"Steps : {cfg['inference']['num_steps']}\n")
|
| 519 |
+
f.write(f"Temp : {cfg['inference']['temperature']}\n")
|
| 520 |
+
f.write(f"RepPen : {cfg['inference']['repetition_penalty']}\n")
|
| 521 |
+
f.write(f"DivPen : {cfg['inference']['diversity_penalty']}\n")
|
| 522 |
+
f.write(f"BLEU : {bleu_score:.4f}\n")
|
| 523 |
+
f.write(f"BERTScore: {bert_f1:.4f}\n\n")
|
| 524 |
+
f.write("=== SAMPLES ===\n")
|
| 525 |
+
for i in range(min(20, len(all_preds))):
|
| 526 |
+
f.write(f"IN : {all_inputs[i]}\n")
|
| 527 |
+
f.write(f"REF : {all_refs[i]}\n")
|
| 528 |
+
f.write(f"PRED: {all_preds[i]}\n")
|
| 529 |
+
f.write("-" * 60 + "\n")
|
| 530 |
+
|
| 531 |
+
with open(pred_path, 'w', encoding='utf-8') as f:
|
| 532 |
+
for src, ref, pred in zip(all_inputs, all_refs, all_preds):
|
| 533 |
+
row = {"input": src, "reference": ref, "prediction": pred}
|
| 534 |
+
f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
| 535 |
+
|
| 536 |
+
print(f"\nโ
Results โ {out_path}")
|
| 537 |
+
print(f"๐๏ธ Saved predictions โ {pred_path}")
|
| 538 |
+
print(f"๐ BLEU: {bleu_score:.4f} | BERTScore: {bert_f1:.4f}")
|
| 539 |
+
return all_preds, all_refs
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
if __name__ == '__main__':
|
| 543 |
+
import argparse
|
| 544 |
+
p = argparse.ArgumentParser()
|
| 545 |
+
p.add_argument('--mode', choices=['demo', 'eval'], default='demo')
|
| 546 |
+
p.add_argument('--samples', type=int, default=500)
|
| 547 |
+
p.add_argument('--checkpoint', type=str, default=None)
|
| 548 |
+
p.add_argument('--text', type=str, default=None, help='Run one-shot demo input and exit')
|
| 549 |
+
args = p.parse_args()
|
| 550 |
+
|
| 551 |
+
if args.mode == 'demo':
|
| 552 |
+
interactive_demo(checkpoint=args.checkpoint, single_text=args.text)
|
| 553 |
+
else:
|
| 554 |
+
batch_evaluate(args.samples, checkpoint=args.checkpoint)
|
inference_api.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from typing import Dict, Any
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from config import CONFIG
|
| 9 |
+
from inference import _build_tokenizers, _resolve_device, load_model, run_inference
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
_STATE = {
|
| 13 |
+
"loaded": False,
|
| 14 |
+
"model": None,
|
| 15 |
+
"cfg": None,
|
| 16 |
+
"device": None,
|
| 17 |
+
"src_tok": None,
|
| 18 |
+
"tgt_tok": None,
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _read_model_settings() -> Dict[str, Any]:
|
| 23 |
+
if not os.path.exists("model_settings.json"):
|
| 24 |
+
return {}
|
| 25 |
+
try:
|
| 26 |
+
with open("model_settings.json", "r", encoding="utf-8") as f:
|
| 27 |
+
data = json.load(f)
|
| 28 |
+
return data if isinstance(data, dict) else {}
|
| 29 |
+
except Exception:
|
| 30 |
+
return {}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _load_once() -> None:
|
| 34 |
+
if _STATE["loaded"]:
|
| 35 |
+
return
|
| 36 |
+
|
| 37 |
+
settings = _read_model_settings()
|
| 38 |
+
cfg = copy.deepcopy(CONFIG)
|
| 39 |
+
cfg["model_type"] = os.environ.get(
|
| 40 |
+
"HF_MODEL_TYPE",
|
| 41 |
+
settings.get("model_type", "d3pm_cross_attention"),
|
| 42 |
+
)
|
| 43 |
+
cfg["data"]["include_negative_examples"] = (
|
| 44 |
+
os.environ.get(
|
| 45 |
+
"HF_INCLUDE_NEG",
|
| 46 |
+
str(settings.get("include_negative_examples", True)).lower(),
|
| 47 |
+
).lower()
|
| 48 |
+
== "true"
|
| 49 |
+
)
|
| 50 |
+
num_steps_raw = os.environ.get("HF_NUM_STEPS", settings.get("num_steps"))
|
| 51 |
+
if num_steps_raw is not None:
|
| 52 |
+
num_steps = int(num_steps_raw)
|
| 53 |
+
cfg["model"]["diffusion_steps"] = num_steps
|
| 54 |
+
cfg["inference"]["num_steps"] = num_steps
|
| 55 |
+
device = _resolve_device(cfg)
|
| 56 |
+
|
| 57 |
+
model, cfg = load_model("best_model.pt", cfg, device)
|
| 58 |
+
src_tok, tgt_tok = _build_tokenizers(cfg)
|
| 59 |
+
|
| 60 |
+
_STATE["model"] = model
|
| 61 |
+
_STATE["cfg"] = cfg
|
| 62 |
+
_STATE["device"] = device
|
| 63 |
+
_STATE["src_tok"] = src_tok
|
| 64 |
+
_STATE["tgt_tok"] = tgt_tok
|
| 65 |
+
_STATE["loaded"] = True
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _clean_text(text: str) -> str:
|
| 69 |
+
text = " ".join(text.split())
|
| 70 |
+
if not text:
|
| 71 |
+
return text
|
| 72 |
+
toks = text.split()
|
| 73 |
+
out = []
|
| 74 |
+
prev = None
|
| 75 |
+
run = 0
|
| 76 |
+
for tok in toks:
|
| 77 |
+
if tok == prev:
|
| 78 |
+
run += 1
|
| 79 |
+
else:
|
| 80 |
+
prev = tok
|
| 81 |
+
run = 1
|
| 82 |
+
if run <= 2:
|
| 83 |
+
out.append(tok)
|
| 84 |
+
s = " ".join(out)
|
| 85 |
+
s = s.replace(" เฅค", "เฅค").replace(" เฅฅ", "เฅฅ")
|
| 86 |
+
return " ".join(s.split())
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def predict(
|
| 90 |
+
text: str,
|
| 91 |
+
temperature: float = 0.7,
|
| 92 |
+
top_k: int = 40,
|
| 93 |
+
repetition_penalty: float = 1.2,
|
| 94 |
+
diversity_penalty: float = 0.0,
|
| 95 |
+
num_steps: int = 64,
|
| 96 |
+
clean_output: bool = True,
|
| 97 |
+
) -> Dict[str, Any]:
|
| 98 |
+
_load_once()
|
| 99 |
+
if not text or not text.strip():
|
| 100 |
+
return {"error": "empty input", "output": ""}
|
| 101 |
+
|
| 102 |
+
cfg = copy.deepcopy(_STATE["cfg"])
|
| 103 |
+
cfg["inference"]["temperature"] = float(temperature)
|
| 104 |
+
cfg["inference"]["top_k"] = int(top_k)
|
| 105 |
+
cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
|
| 106 |
+
cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
|
| 107 |
+
cfg["inference"]["num_steps"] = int(num_steps)
|
| 108 |
+
|
| 109 |
+
src_tok = _STATE["src_tok"]
|
| 110 |
+
tgt_tok = _STATE["tgt_tok"]
|
| 111 |
+
device = _STATE["device"]
|
| 112 |
+
|
| 113 |
+
input_ids = torch.tensor([src_tok.encode(text.strip())], dtype=torch.long, device=device)
|
| 114 |
+
out = run_inference(_STATE["model"], input_ids, cfg)
|
| 115 |
+
decoded_ids = [x for x in out[0].tolist() if x > 4]
|
| 116 |
+
raw = tgt_tok.decode(decoded_ids).strip()
|
| 117 |
+
output = _clean_text(raw) if clean_output else raw
|
| 118 |
+
|
| 119 |
+
return {
|
| 120 |
+
"input": text,
|
| 121 |
+
"output": output,
|
| 122 |
+
"raw_output": raw,
|
| 123 |
+
"config": {
|
| 124 |
+
"temperature": float(temperature),
|
| 125 |
+
"top_k": int(top_k),
|
| 126 |
+
"repetition_penalty": float(repetition_penalty),
|
| 127 |
+
"diversity_penalty": float(diversity_penalty),
|
| 128 |
+
"num_steps": int(num_steps),
|
| 129 |
+
"clean_output": bool(clean_output),
|
| 130 |
+
},
|
| 131 |
+
}
|
model/__init__.py
ADDED
|
File without changes
|
model/d3pm_model_cross_attention.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
d3pm_model_cross_attention.py โ Cross-Script + Generation-Fixed
|
| 3 |
+
=================================================================
|
| 4 |
+
INPUT : quote_text tokens (Roman script, src_vocab_size)
|
| 5 |
+
OUTPUT : quote_devanagari tokens (Devanagari script, tgt_vocab_size)
|
| 6 |
+
|
| 7 |
+
src_embed uses src_vocab_size (Roman BPE)
|
| 8 |
+
tgt_embed uses tgt_vocab_size (Devanagari BPE)
|
| 9 |
+
head outputs tgt_vocab_size (predict Devanagari tokens)
|
| 10 |
+
Weight tying: head <-> tgt_embed only (NOT src_embed)
|
| 11 |
+
|
| 12 |
+
Generation bugs fixed:
|
| 13 |
+
BUG 1 - tgt_pad_mask suppressed during inference
|
| 14 |
+
BUG 2 - q_sample skipped at t=0
|
| 15 |
+
BUG 3 - time embedding before hint_gate
|
| 16 |
+
BUG 4 - diversity penalty uses global mean not var
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
from diffusion.scheduler import OptimizedCosineScheduler
|
| 24 |
+
from diffusion.forward_process import AbsorbingForwardProcess
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SinusoidalPositionalEncoding(nn.Module):
|
| 28 |
+
def __init__(self, d_model, max_len=5000):
|
| 29 |
+
super().__init__()
|
| 30 |
+
pe = torch.zeros(max_len, d_model)
|
| 31 |
+
position = torch.arange(0, max_len).unsqueeze(1).float()
|
| 32 |
+
div_term = torch.exp(
|
| 33 |
+
torch.arange(0, d_model, 2).float() *
|
| 34 |
+
(-torch.log(torch.tensor(10000.0)) / d_model)
|
| 35 |
+
)
|
| 36 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 37 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 38 |
+
self.register_buffer("pe", pe.unsqueeze(0))
|
| 39 |
+
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
return x + self.pe[:, :x.size(1), :]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SanskritEmbeddings(nn.Module):
|
| 45 |
+
def __init__(self, vocab_size, d_model, max_seq_len):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.token_emb = nn.Embedding(vocab_size, d_model)
|
| 48 |
+
self.pos_enc = SinusoidalPositionalEncoding(d_model, max_seq_len)
|
| 49 |
+
self.token_embedding = self.token_emb
|
| 50 |
+
def forward(self, tokens):
|
| 51 |
+
return self.pos_enc(self.token_emb(tokens))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MultiHeadAttention(nn.Module):
|
| 55 |
+
def __init__(self, d_model, n_heads, dropout=0.1):
|
| 56 |
+
super().__init__()
|
| 57 |
+
assert d_model % n_heads == 0
|
| 58 |
+
self.d_model = d_model
|
| 59 |
+
self.n_heads = n_heads
|
| 60 |
+
self.head_dim = d_model // n_heads
|
| 61 |
+
self.q_proj = nn.Linear(d_model, d_model)
|
| 62 |
+
self.k_proj = nn.Linear(d_model, d_model)
|
| 63 |
+
self.v_proj = nn.Linear(d_model, d_model)
|
| 64 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 65 |
+
self.dropout = nn.Dropout(dropout)
|
| 66 |
+
|
| 67 |
+
def forward(self, q, k, v, mask=None):
|
| 68 |
+
B, Lq, _ = q.size()
|
| 69 |
+
Lk = k.size(1)
|
| 70 |
+
Q = self.q_proj(q).view(B, Lq, self.n_heads, self.head_dim).transpose(1, 2)
|
| 71 |
+
K = self.k_proj(k).view(B, Lk, self.n_heads, self.head_dim).transpose(1, 2)
|
| 72 |
+
V = self.v_proj(v).view(B, Lk, self.n_heads, self.head_dim).transpose(1, 2)
|
| 73 |
+
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
|
| 74 |
+
if mask is not None:
|
| 75 |
+
scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
|
| 76 |
+
attn = self.dropout(torch.softmax(scores, dim=-1))
|
| 77 |
+
out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, Lq, self.d_model)
|
| 78 |
+
return self.out_proj(out)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class EncoderBlock(nn.Module):
|
| 82 |
+
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.mha = MultiHeadAttention(d_model, n_heads, dropout)
|
| 85 |
+
self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
|
| 86 |
+
nn.Linear(d_ff, d_model), nn.Dropout(dropout))
|
| 87 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 88 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 89 |
+
def forward(self, x, pad_mask=None):
|
| 90 |
+
x = self.norm1(x + self.mha(x, x, x, mask=pad_mask))
|
| 91 |
+
return self.norm2(x + self.ff(x))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class DecoderBlock(nn.Module):
|
| 95 |
+
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
|
| 98 |
+
self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
|
| 99 |
+
self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
|
| 100 |
+
nn.Linear(d_ff, d_model), nn.Dropout(dropout))
|
| 101 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 102 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 103 |
+
self.norm3 = nn.LayerNorm(d_model)
|
| 104 |
+
def forward(self, x, memory, tgt_pad_mask=None, src_pad_mask=None):
|
| 105 |
+
x = self.norm1(x + self.self_attn(x, x, x, mask=tgt_pad_mask))
|
| 106 |
+
x = self.norm2(x + self.cross_attn(x, memory, memory, mask=src_pad_mask))
|
| 107 |
+
return self.norm3(x + self.ff(x))
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class D3PMCrossAttention(nn.Module):
|
| 111 |
+
def __init__(self, cfg):
|
| 112 |
+
super().__init__()
|
| 113 |
+
self.cfg = cfg
|
| 114 |
+
self.mask_token_id = cfg['diffusion']['mask_token_id']
|
| 115 |
+
d = cfg['model']['d_model']
|
| 116 |
+
nhead = cfg['model']['n_heads']
|
| 117 |
+
d_ff = cfg['model']['d_ff']
|
| 118 |
+
drop = cfg['model']['dropout']
|
| 119 |
+
seqlen = cfg['model']['max_seq_len']
|
| 120 |
+
nlayer = cfg['model']['n_layers']
|
| 121 |
+
src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
|
| 122 |
+
tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
|
| 123 |
+
|
| 124 |
+
# Separate embeddings: Roman src, Devanagari tgt
|
| 125 |
+
self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
|
| 126 |
+
self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)
|
| 127 |
+
|
| 128 |
+
self.scheduler = OptimizedCosineScheduler(cfg)
|
| 129 |
+
self.forward_process = AbsorbingForwardProcess(self.scheduler)
|
| 130 |
+
|
| 131 |
+
self.encoder_blocks = nn.ModuleList([EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
|
| 132 |
+
self.decoder_blocks = nn.ModuleList([DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
|
| 133 |
+
|
| 134 |
+
self.time_mlp = nn.Sequential(nn.Linear(1, d//4), nn.SiLU(), nn.Linear(d//4, d))
|
| 135 |
+
self.hint_gate = nn.Sequential(nn.Linear(d, d), nn.Sigmoid())
|
| 136 |
+
|
| 137 |
+
# Output head: predict Devanagari tokens, tied to tgt_embed
|
| 138 |
+
self.head = nn.Linear(d, tgt_vocab, bias=False)
|
| 139 |
+
self.head.weight = self.tgt_embed.token_embedding.weight
|
| 140 |
+
|
| 141 |
+
def forward(self, src, tgt, t, x0_hint=None, inference_mode=False):
|
| 142 |
+
PAD = 1
|
| 143 |
+
src_pad_mask = (src == PAD)
|
| 144 |
+
# BUG 1 FIX: no tgt mask during inference
|
| 145 |
+
tgt_pad_mask = None if inference_mode else (tgt == PAD)
|
| 146 |
+
|
| 147 |
+
# Encode Roman source
|
| 148 |
+
memory = self.src_embed(src)
|
| 149 |
+
for block in self.encoder_blocks:
|
| 150 |
+
memory = block(memory, pad_mask=src_pad_mask)
|
| 151 |
+
|
| 152 |
+
# BUG 2 FIX: skip q_sample at final step t=0
|
| 153 |
+
if inference_mode and (t == 0).all():
|
| 154 |
+
x_t_ids = tgt
|
| 155 |
+
else:
|
| 156 |
+
_, x_t_ids = self.forward_process.q_sample(tgt, t)
|
| 157 |
+
|
| 158 |
+
x = self.tgt_embed(x_t_ids)
|
| 159 |
+
|
| 160 |
+
# BUG 3 FIX: time embedding BEFORE hint gate
|
| 161 |
+
t_norm = t.float() / self.scheduler.num_timesteps
|
| 162 |
+
t_emb = self.time_mlp(t_norm.unsqueeze(-1))
|
| 163 |
+
x = x + t_emb.unsqueeze(1)
|
| 164 |
+
|
| 165 |
+
if x0_hint is not None:
|
| 166 |
+
hint_emb = self.tgt_embed(x0_hint)
|
| 167 |
+
gate = self.hint_gate(x) # time-aware gate
|
| 168 |
+
x = x + gate * hint_emb
|
| 169 |
+
|
| 170 |
+
for block in self.decoder_blocks:
|
| 171 |
+
x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)
|
| 172 |
+
|
| 173 |
+
return self.head(x), None
|
| 174 |
+
|
| 175 |
+
@torch.no_grad()
|
| 176 |
+
def generate(self, src, num_steps=None, temperature=0.8, top_k=50,
|
| 177 |
+
repetition_penalty=1.2, diversity_penalty=0.0):
|
| 178 |
+
if src.dim() == 1:
|
| 179 |
+
src = src.unsqueeze(0)
|
| 180 |
+
device = src.device
|
| 181 |
+
B, L = src.shape
|
| 182 |
+
T = self.scheduler.num_timesteps
|
| 183 |
+
steps = num_steps or T
|
| 184 |
+
step_size = max(1, T // steps)
|
| 185 |
+
timesteps = list(range(T - 1, -1, -step_size))
|
| 186 |
+
if timesteps[-1] != 0:
|
| 187 |
+
timesteps.append(0)
|
| 188 |
+
|
| 189 |
+
mask_id = self.mask_token_id
|
| 190 |
+
x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
|
| 191 |
+
hint = None
|
| 192 |
+
|
| 193 |
+
self.eval()
|
| 194 |
+
with torch.no_grad():
|
| 195 |
+
for step_idx, t_val in enumerate(timesteps):
|
| 196 |
+
t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 197 |
+
is_last = (step_idx == len(timesteps) - 1)
|
| 198 |
+
logits, _ = self.forward(src, x0_est, t, x0_hint=hint, inference_mode=True)
|
| 199 |
+
if repetition_penalty != 1.0:
|
| 200 |
+
logits = _apply_repetition_penalty(logits, x0_est, repetition_penalty)
|
| 201 |
+
if diversity_penalty > 0.0:
|
| 202 |
+
logits = _apply_diversity_penalty_fixed(logits, diversity_penalty) # BUG 4 FIX
|
| 203 |
+
logits = logits / max(temperature, 1e-5)
|
| 204 |
+
if top_k > 0:
|
| 205 |
+
logits = _top_k_filter(logits, top_k)
|
| 206 |
+
probs = F.softmax(logits, dim=-1)
|
| 207 |
+
x0_est = torch.argmax(probs, dim=-1) if is_last else _batch_multinomial(probs)
|
| 208 |
+
hint = x0_est
|
| 209 |
+
return x0_est
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class BaselineCrossAttention(nn.Module):
|
| 213 |
+
def __init__(self, cfg):
|
| 214 |
+
super().__init__()
|
| 215 |
+
d = cfg['model']['d_model']; nhead = cfg['model']['n_heads']
|
| 216 |
+
d_ff = cfg['model']['d_ff']; drop = cfg['model']['dropout']
|
| 217 |
+
seqlen = cfg['model']['max_seq_len']; nlayer = cfg['model']['n_layers']
|
| 218 |
+
src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
|
| 219 |
+
tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
|
| 220 |
+
self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
|
| 221 |
+
self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)
|
| 222 |
+
self.encoder_blocks = nn.ModuleList([EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
|
| 223 |
+
self.decoder_blocks = nn.ModuleList([DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
|
| 224 |
+
self.head = nn.Linear(d, tgt_vocab, bias=False)
|
| 225 |
+
self.head.weight = self.tgt_embed.token_embedding.weight
|
| 226 |
+
|
| 227 |
+
def forward(self, src, tgt, t=None, x0_hint=None):
|
| 228 |
+
PAD = 1
|
| 229 |
+
memory = self.src_embed(src)
|
| 230 |
+
for b in self.encoder_blocks: memory = b(memory, pad_mask=(src==PAD))
|
| 231 |
+
x = self.tgt_embed(tgt)
|
| 232 |
+
for b in self.decoder_blocks: x = b(x, memory, tgt_pad_mask=(tgt==PAD), src_pad_mask=(src==PAD))
|
| 233 |
+
return (self.head(x),)
|
| 234 |
+
|
| 235 |
+
@torch.no_grad()
|
| 236 |
+
def generate(self, src, max_len=None, start_token_id=2, **kwargs):
|
| 237 |
+
if max_len is None: max_len = src.size(1)
|
| 238 |
+
B, device = src.size(0), src.device
|
| 239 |
+
memory = self.src_embed(src)
|
| 240 |
+
for b in self.encoder_blocks: memory = b(memory, pad_mask=(src==1))
|
| 241 |
+
ys = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
|
| 242 |
+
for _ in range(max_len):
|
| 243 |
+
x = self.tgt_embed(ys)
|
| 244 |
+
for b in self.decoder_blocks: x = b(x, memory, tgt_pad_mask=None, src_pad_mask=(src==1))
|
| 245 |
+
ys = torch.cat([ys, torch.argmax(self.head(x)[:,-1,:], dim=-1, keepdim=True)], dim=1)
|
| 246 |
+
return ys[:, 1:max_len+1]
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# helpers
|
| 250 |
+
def _top_k_filter(logits, k):
|
| 251 |
+
B, L, V = logits.shape
|
| 252 |
+
if k >= V: return logits
|
| 253 |
+
topk_vals, _ = torch.topk(logits, k, dim=-1)
|
| 254 |
+
return logits.masked_fill(logits < topk_vals[..., -1].unsqueeze(-1), float('-inf'))
|
| 255 |
+
|
| 256 |
+
def _batch_multinomial(probs):
|
| 257 |
+
B, L, V = probs.shape
|
| 258 |
+
flat = probs.view(B*L, V) + 1e-9
|
| 259 |
+
return torch.multinomial(flat/flat.sum(-1,keepdim=True), 1).squeeze(-1).view(B, L)
|
| 260 |
+
|
| 261 |
+
def _apply_repetition_penalty(logits, prev_tokens, penalty):
|
| 262 |
+
for b in range(logits.shape[0]):
|
| 263 |
+
for tid in set(prev_tokens[b].tolist()):
|
| 264 |
+
if tid > 4: logits[b, :, tid] = logits[b, :, tid] / penalty
|
| 265 |
+
return logits
|
| 266 |
+
|
| 267 |
+
def _apply_diversity_penalty(logits, penalty): # legacy wrong version
|
| 268 |
+
return logits + penalty * logits.var(dim=-1, keepdim=True)
|
| 269 |
+
|
| 270 |
+
def _apply_diversity_penalty_fixed(logits, penalty): # correct version
|
| 271 |
+
return logits - penalty * logits.mean(dim=1, keepdim=True)
|
model/d3pm_model_encoder_decoder.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from diffusion.scheduler import OptimizedCosineScheduler
|
| 4 |
+
from diffusion.forward_process import AbsorbingForwardProcess
|
| 5 |
+
# Import shared classes to guarantee identical architectures
|
| 6 |
+
from model.d3pm_model_cross_attention import SanskritEmbeddings, EncoderBlock, MultiHeadAttention
|
| 7 |
+
class DecoderBlock(nn.Module):
|
| 8 |
+
def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
|
| 9 |
+
super().__init__()
|
| 10 |
+
self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
|
| 11 |
+
self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout) # โ restored
|
| 12 |
+
self.ff = nn.Sequential(
|
| 13 |
+
nn.Linear(d_model, d_ff),
|
| 14 |
+
nn.ReLU(),
|
| 15 |
+
nn.Dropout(dropout),
|
| 16 |
+
nn.Linear(d_ff, d_model),
|
| 17 |
+
nn.Dropout(dropout),
|
| 18 |
+
)
|
| 19 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 20 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 21 |
+
self.norm3 = nn.LayerNorm(d_model) # โ restored (for cross-attn residual)
|
| 22 |
+
|
| 23 |
+
def forward(self, x, memory, tgt_pad_mask=None):
|
| 24 |
+
# 1. Masked self-attention on target
|
| 25 |
+
x = self.norm1(x + self.self_attn(x, x, x, mask=tgt_pad_mask))
|
| 26 |
+
# 2. Cross-attention: queries from decoder, keys/values from encoder memory
|
| 27 |
+
x = self.norm2(x + self.cross_attn(x, memory, memory))
|
| 28 |
+
# 3. Feed-forward
|
| 29 |
+
return self.norm3(x + self.ff(x))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DecoderBlockNoCrossAttn(nn.Module):
|
| 33 |
+
"""Kept for reference โ NOT used by D3PMEncoderDecoder."""
|
| 34 |
+
def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
|
| 37 |
+
self.ff = nn.Sequential(
|
| 38 |
+
nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
|
| 39 |
+
nn.Linear(d_ff, d_model), nn.Dropout(dropout),
|
| 40 |
+
)
|
| 41 |
+
self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
|
| 42 |
+
|
| 43 |
+
def forward(self, x, tgt_pad_mask=None, causal_mask=None):
|
| 44 |
+
combined_mask = None
|
| 45 |
+
if tgt_pad_mask is not None and causal_mask is not None:
|
| 46 |
+
combined_mask = tgt_pad_mask | causal_mask
|
| 47 |
+
elif causal_mask is not None:
|
| 48 |
+
combined_mask = causal_mask
|
| 49 |
+
elif tgt_pad_mask is not None:
|
| 50 |
+
combined_mask = tgt_pad_mask
|
| 51 |
+
x = self.norm1(x + self.self_attn(x, x, x, mask=combined_mask))
|
| 52 |
+
return self.norm2(x + self.ff(x))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ============================================================
|
| 56 |
+
# 1. D3PM Encoder-Decoder Model
|
| 57 |
+
# ============================================================
|
| 58 |
+
class D3PMEncoderDecoder(nn.Module):
|
| 59 |
+
def __init__(self, cfg):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.cfg = cfg
|
| 62 |
+
self.mask_token_id = cfg['diffusion']['mask_token_id']
|
| 63 |
+
|
| 64 |
+
src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
|
| 65 |
+
tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
|
| 66 |
+
d_model = cfg['model']['d_model']
|
| 67 |
+
n_heads = cfg['model']['n_heads']
|
| 68 |
+
d_ff = cfg['model']['d_ff']
|
| 69 |
+
dropout = cfg['model']['dropout']
|
| 70 |
+
n_layers = cfg['model']['n_layers']
|
| 71 |
+
max_len = cfg['model']['max_seq_len']
|
| 72 |
+
|
| 73 |
+
self.src_embed = SanskritEmbeddings(src_vocab, d_model, max_len)
|
| 74 |
+
self.tgt_embed = SanskritEmbeddings(tgt_vocab, d_model, max_len)
|
| 75 |
+
|
| 76 |
+
self.scheduler = OptimizedCosineScheduler(cfg)
|
| 77 |
+
self.forward_process = AbsorbingForwardProcess(self.scheduler)
|
| 78 |
+
|
| 79 |
+
self.encoder_blocks = nn.ModuleList([
|
| 80 |
+
EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
|
| 81 |
+
])
|
| 82 |
+
# DecoderBlock now has cross-attention โ matches saved checkpoint
|
| 83 |
+
self.decoder_blocks = nn.ModuleList([
|
| 84 |
+
DecoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
|
| 85 |
+
])
|
| 86 |
+
|
| 87 |
+
self.time_mlp = nn.Sequential(
|
| 88 |
+
nn.Linear(1, d_model // 4), nn.SiLU(),
|
| 89 |
+
nn.Linear(d_model // 4, d_model),
|
| 90 |
+
)
|
| 91 |
+
self.head = nn.Linear(d_model, tgt_vocab)
|
| 92 |
+
self.head.weight = self.tgt_embed.token_embedding.weight
|
| 93 |
+
|
| 94 |
+
def forward(self, src, tgt, t, x0_hint=None):
|
| 95 |
+
src_pad_mask = (src == 1)
|
| 96 |
+
tgt_pad_mask = (tgt == 1)
|
| 97 |
+
|
| 98 |
+
# Encode source (Roman IAST)
|
| 99 |
+
memory = self.src_embed(src)
|
| 100 |
+
for block in self.encoder_blocks:
|
| 101 |
+
memory = block(memory, pad_mask=src_pad_mask)
|
| 102 |
+
|
| 103 |
+
# Corrupt target with forward diffusion
|
| 104 |
+
_, x_t_ids = self.forward_process.q_sample(tgt, t)
|
| 105 |
+
|
| 106 |
+
# Optionally blend in x0_hint (self-conditioning)
|
| 107 |
+
if x0_hint is not None:
|
| 108 |
+
hint_prob = 0.5
|
| 109 |
+
blend_mask = (torch.rand(x_t_ids.shape, device=x_t_ids.device) < hint_prob)
|
| 110 |
+
still_mask = (x_t_ids == self.mask_token_id)
|
| 111 |
+
x_t_ids = torch.where(blend_mask & still_mask, x0_hint, x_t_ids)
|
| 112 |
+
|
| 113 |
+
x = self.tgt_embed(x_t_ids)
|
| 114 |
+
t_emb = self.time_mlp(t.float().unsqueeze(-1)).unsqueeze(1)
|
| 115 |
+
x = x + t_emb.expand(-1, tgt.shape[1], -1)
|
| 116 |
+
|
| 117 |
+
# Decode with cross-attention over encoder memory
|
| 118 |
+
for block in self.decoder_blocks:
|
| 119 |
+
x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
|
| 120 |
+
|
| 121 |
+
return self.head(x), None
|
| 122 |
+
|
| 123 |
+
@torch.no_grad()
|
| 124 |
+
def generate(
|
| 125 |
+
self,
|
| 126 |
+
src,
|
| 127 |
+
num_steps = None,
|
| 128 |
+
temperature = 0.75,
|
| 129 |
+
top_k = 50,
|
| 130 |
+
repetition_penalty = 1.15,
|
| 131 |
+
diversity_penalty = 0.0,
|
| 132 |
+
):
|
| 133 |
+
"""
|
| 134 |
+
Iterative D3PM reverse diffusion โ same signature as
|
| 135 |
+
D3PMCrossAttention.generate() so SanskritModel.generate() works
|
| 136 |
+
identically for both model types.
|
| 137 |
+
"""
|
| 138 |
+
device = src.device
|
| 139 |
+
B, L = src.shape[0], self.cfg['model']['max_seq_len']
|
| 140 |
+
T = num_steps or self.scheduler.num_timesteps
|
| 141 |
+
mask_id = self.mask_token_id
|
| 142 |
+
pad_id = 1
|
| 143 |
+
|
| 144 |
+
x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
|
| 145 |
+
|
| 146 |
+
for step in range(T - 1, -1, -1):
|
| 147 |
+
t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
|
| 148 |
+
hint = x0_est.clone()
|
| 149 |
+
|
| 150 |
+
logits, _ = self.forward(src, x0_est, t_tensor, x0_hint=hint)
|
| 151 |
+
|
| 152 |
+
# Repetition penalty
|
| 153 |
+
if repetition_penalty != 1.0:
|
| 154 |
+
for b in range(B):
|
| 155 |
+
for tok in set(x0_est[b].tolist()):
|
| 156 |
+
if tok > pad_id:
|
| 157 |
+
logits[b, :, tok] /= repetition_penalty
|
| 158 |
+
|
| 159 |
+
# Diversity penalty (suppress common tokens)
|
| 160 |
+
if diversity_penalty > 0.0:
|
| 161 |
+
logits = logits - diversity_penalty * logits.mean(dim=1, keepdim=True)
|
| 162 |
+
|
| 163 |
+
# Temperature + top-k sampling
|
| 164 |
+
logits = logits / max(temperature, 1e-8)
|
| 165 |
+
if top_k > 0:
|
| 166 |
+
vals, _ = torch.topk(logits, top_k, dim=-1)
|
| 167 |
+
logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
|
| 168 |
+
|
| 169 |
+
probs = torch.softmax(logits, dim=-1)
|
| 170 |
+
# Only update positions that are still masked
|
| 171 |
+
still = (x0_est == mask_id)
|
| 172 |
+
sample = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(B, L)
|
| 173 |
+
x0_est = torch.where(still, sample, x0_est)
|
| 174 |
+
|
| 175 |
+
return x0_est
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ============================================================
|
| 179 |
+
# 2. Baseline Encoder-Decoder Model (unchanged)
|
| 180 |
+
# ============================================================
|
| 181 |
+
class BaselineEncoderDecoder(nn.Module):
|
| 182 |
+
def __init__(self, cfg):
|
| 183 |
+
super().__init__()
|
| 184 |
+
self.cfg = cfg
|
| 185 |
+
self.src_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
|
| 186 |
+
cfg['model']['max_seq_len'])
|
| 187 |
+
self.tgt_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
|
| 188 |
+
cfg['model']['max_seq_len'])
|
| 189 |
+
self.encoder_blocks = nn.ModuleList([
|
| 190 |
+
EncoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
|
| 191 |
+
cfg['model']['d_ff'], cfg['model']['dropout'])
|
| 192 |
+
for _ in range(cfg['model']['n_layers'])
|
| 193 |
+
])
|
| 194 |
+
self.decoder_blocks = nn.ModuleList([
|
| 195 |
+
DecoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
|
| 196 |
+
cfg['model']['d_ff'], cfg['model']['dropout'])
|
| 197 |
+
for _ in range(cfg['model']['n_layers'])
|
| 198 |
+
])
|
| 199 |
+
self.head = nn.Linear(cfg['model']['d_model'], cfg['model']['vocab_size'])
|
| 200 |
+
self.head.weight = self.tgt_embed.token_embedding.weight
|
| 201 |
+
|
| 202 |
+
def forward(self, src, tgt):
|
| 203 |
+
src_pad_mask, tgt_pad_mask = (src == 1), (tgt == 1)
|
| 204 |
+
memory = self.src_embed(src)
|
| 205 |
+
for block in self.encoder_blocks:
|
| 206 |
+
memory = block(memory, pad_mask=src_pad_mask)
|
| 207 |
+
x = self.tgt_embed(tgt)
|
| 208 |
+
for block in self.decoder_blocks:
|
| 209 |
+
x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
|
| 210 |
+
return self.head(x)
|
| 211 |
+
|
| 212 |
+
@torch.no_grad()
|
| 213 |
+
def generate(self, src, max_len=80, start_token_id=2):
|
| 214 |
+
batch_size, device = src.size(0), src.device
|
| 215 |
+
src_pad_mask = (src == 1)
|
| 216 |
+
memory = self.src_embed(src)
|
| 217 |
+
for block in self.encoder_blocks:
|
| 218 |
+
memory = block(memory, pad_mask=src_pad_mask)
|
| 219 |
+
ys = torch.ones(batch_size, 1, dtype=torch.long, device=device) * start_token_id
|
| 220 |
+
for _ in range(max_len):
|
| 221 |
+
x = self.tgt_embed(ys)
|
| 222 |
+
for block in self.decoder_blocks:
|
| 223 |
+
x = block(x, memory, tgt_pad_mask=None)
|
| 224 |
+
logits = self.head(x)
|
| 225 |
+
next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
|
| 226 |
+
ys = torch.cat([ys, next_token], dim=1)
|
| 227 |
+
return ys[:, 1:]
|
model/sanskrit_model.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
sanskrit_model.py โ Fixed
|
| 3 |
+
===========================
|
| 4 |
+
Added inference_mode parameter to forward() so reverse_process.py can
|
| 5 |
+
pass inference_mode=True without a TypeError.
|
| 6 |
+
|
| 7 |
+
The wrapper introspects each inner model's signature and only passes
|
| 8 |
+
kwargs that model actually accepts โ safe across all four architectures.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import inspect
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SanskritModel(nn.Module):
|
| 17 |
+
def __init__(self, cfg):
|
| 18 |
+
super().__init__()
|
| 19 |
+
model_type = cfg['model_type']
|
| 20 |
+
|
| 21 |
+
if model_type == 'd3pm_cross_attention':
|
| 22 |
+
from model.d3pm_model_cross_attention import D3PMCrossAttention
|
| 23 |
+
self.model = D3PMCrossAttention(cfg)
|
| 24 |
+
|
| 25 |
+
elif model_type == 'd3pm_encoder_decoder':
|
| 26 |
+
from model.d3pm_model_encoder_decoder import D3PMEncoderDecoder
|
| 27 |
+
self.model = D3PMEncoderDecoder(cfg)
|
| 28 |
+
|
| 29 |
+
elif model_type == 'baseline_cross_attention':
|
| 30 |
+
from model.d3pm_model_cross_attention import BaselineCrossAttention
|
| 31 |
+
self.model = BaselineCrossAttention(cfg)
|
| 32 |
+
|
| 33 |
+
elif model_type == 'baseline_encoder_decoder':
|
| 34 |
+
from model.d3pm_model_encoder_decoder import BaselineEncoderDecoder
|
| 35 |
+
self.model = BaselineEncoderDecoder(cfg)
|
| 36 |
+
|
| 37 |
+
else:
|
| 38 |
+
raise ValueError(f"Unknown model_type: {model_type}")
|
| 39 |
+
|
| 40 |
+
def forward(self, input_ids, target_ids, t, x0_hint=None, inference_mode=False):
|
| 41 |
+
"""
|
| 42 |
+
Forward pass. Introspects the inner model's signature so only
|
| 43 |
+
supported kwargs are passed โ works with all four architectures.
|
| 44 |
+
"""
|
| 45 |
+
sig = inspect.signature(self.model.forward).parameters
|
| 46 |
+
kwargs = {}
|
| 47 |
+
if 'x0_hint' in sig:
|
| 48 |
+
kwargs['x0_hint'] = x0_hint
|
| 49 |
+
if 'inference_mode' in sig:
|
| 50 |
+
kwargs['inference_mode'] = inference_mode
|
| 51 |
+
|
| 52 |
+
if 't' in sig:
|
| 53 |
+
return self.model(input_ids, target_ids, t, **kwargs)
|
| 54 |
+
else:
|
| 55 |
+
return self.model(input_ids, target_ids, **kwargs)
|
| 56 |
+
|
| 57 |
+
@torch.no_grad()
|
| 58 |
+
def generate(self, src, **kwargs):
|
| 59 |
+
sig = inspect.signature(self.model.generate).parameters
|
| 60 |
+
filtered = {k: v for k, v in kwargs.items() if k in sig}
|
| 61 |
+
return self.model.generate(src, **filtered)
|
model/tokenizer.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tokenizer.py โ Dual Tokenizer Fix
|
| 3 |
+
====================================
|
| 4 |
+
Two separate BPE tokenizers:
|
| 5 |
+
|
| 6 |
+
SanskritSourceTokenizer โ trained on quote_text (Roman/IAST script)
|
| 7 |
+
SanskritTargetTokenizer โ trained on quote_devanagari (Devanagari script)
|
| 8 |
+
|
| 9 |
+
WHY SEPARATE?
|
| 10 |
+
Roman Sanskrit and Devanagari are fundamentally different character sets.
|
| 11 |
+
Roman uses a-z + diacritics (~60 unique chars), Devanagari uses ฤ-เคน + matras
|
| 12 |
+
(~100+ unique chars). A shared BPE tokenizer wastes half its vocab on
|
| 13 |
+
character combos that never cross scripts, and forces the embedding table
|
| 14 |
+
to encode both scripts in one space โ confusing the model's cross-attention.
|
| 15 |
+
|
| 16 |
+
With separate tokenizers:
|
| 17 |
+
- src vocab captures Roman subwords cleanly (ฤ, ล, แนญ, แน etc.)
|
| 18 |
+
- tgt vocab captures Devanagari akshara clusters cleanly (เคเฅเคท, เคคเฅเคฐ, etc.)
|
| 19 |
+
- The model learns a true cross-script mapping in its cross-attention
|
| 20 |
+
|
| 21 |
+
SPECIAL TOKENS (same IDs in both):
|
| 22 |
+
[MASK] = 0 โ required by absorbing diffusion
|
| 23 |
+
[PAD] = 1
|
| 24 |
+
[UNK] = 2
|
| 25 |
+
[CLS] = 3
|
| 26 |
+
[SEP] = 4
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from tokenizers import Tokenizer
|
| 30 |
+
from tokenizers.models import BPE
|
| 31 |
+
from tokenizers.trainers import BpeTrainer
|
| 32 |
+
from tokenizers.pre_tokenizers import Whitespace
|
| 33 |
+
from datasets import load_dataset
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
SPECIAL_TOKENS = ["[MASK]", "[PAD]", "[UNK]", "[CLS]", "[SEP]"]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _build_bpe(texts, vocab_size):
|
| 41 |
+
"""Build a BPE tokenizer from an iterator of strings."""
|
| 42 |
+
tok = Tokenizer(BPE(unk_token="[UNK]"))
|
| 43 |
+
tok.pre_tokenizer = Whitespace()
|
| 44 |
+
trainer = BpeTrainer(
|
| 45 |
+
vocab_size=vocab_size,
|
| 46 |
+
special_tokens=SPECIAL_TOKENS, # [MASK] MUST be first โ id=0
|
| 47 |
+
min_frequency=2,
|
| 48 |
+
)
|
| 49 |
+
tok.train_from_iterator(texts, trainer)
|
| 50 |
+
return tok
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _validate(tok, name):
|
| 54 |
+
mask_id = tok.token_to_id("[MASK]")
|
| 55 |
+
pad_id = tok.token_to_id("[PAD]")
|
| 56 |
+
assert mask_id == 0, f"{name}: [MASK] must be id=0, got {mask_id}"
|
| 57 |
+
assert pad_id == 1, f"{name}: [PAD] must be id=1, got {pad_id}"
|
| 58 |
+
print(f"โ
{name}: [MASK]=0, [PAD]=1 confirmed. Vocab size={tok.get_vocab_size()}")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# โโ Source tokenizer (Roman/IAST Sanskrit) โโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 62 |
+
|
| 63 |
+
class SanskritSourceTokenizer:
|
| 64 |
+
"""
|
| 65 |
+
Tokenizer for quote_text โ Roman transliteration of Sanskrit.
|
| 66 |
+
Examples: "dharmo rakแนฃati rakแนฃitaแธฅ", "yatra nฤryastu pลซjyante"
|
| 67 |
+
"""
|
| 68 |
+
MODEL_PATH = "sanskrit_src_tokenizer.json"
|
| 69 |
+
|
| 70 |
+
def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
|
| 71 |
+
self.vocab_size = vocab_size
|
| 72 |
+
self.max_len = max_len
|
| 73 |
+
self.mask_token_id = 0
|
| 74 |
+
|
| 75 |
+
if Path(self.MODEL_PATH).exists():
|
| 76 |
+
print(f"๐ Loading source tokenizer from {self.MODEL_PATH} โฆ")
|
| 77 |
+
self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
|
| 78 |
+
else:
|
| 79 |
+
print("๐ Training source tokenizer on quote_text โฆ")
|
| 80 |
+
self._train(vocab_size, n_train_samples)
|
| 81 |
+
|
| 82 |
+
_validate(self.tokenizer, "SrcTokenizer")
|
| 83 |
+
|
| 84 |
+
def _train(self, vocab_size, n_samples):
|
| 85 |
+
dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
|
| 86 |
+
n = min(n_samples, len(dataset))
|
| 87 |
+
texts = [s["quote_text"] for s in dataset.select(range(n))
|
| 88 |
+
if s["quote_text"].strip()]
|
| 89 |
+
self.tokenizer = _build_bpe(texts, vocab_size)
|
| 90 |
+
self.tokenizer.save(self.MODEL_PATH)
|
| 91 |
+
print(f"โ
Source tokenizer trained on {len(texts)} Roman texts.")
|
| 92 |
+
|
| 93 |
+
def encode(self, text):
|
| 94 |
+
ids = self.tokenizer.encode(text).ids[:self.max_len]
|
| 95 |
+
pad = self.tokenizer.token_to_id("[PAD]")
|
| 96 |
+
ids += [pad] * max(0, self.max_len - len(ids))
|
| 97 |
+
return ids[:self.max_len]
|
| 98 |
+
|
| 99 |
+
def decode(self, ids):
|
| 100 |
+
clean = [i for i in ids if i > 4] # skip special tokens
|
| 101 |
+
return self.tokenizer.decode(clean)
|
| 102 |
+
|
| 103 |
+
def __len__(self):
|
| 104 |
+
return self.vocab_size
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# โโ Target tokenizer (Devanagari Sanskrit) โโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 108 |
+
|
| 109 |
+
class SanskritTargetTokenizer:
|
| 110 |
+
"""
|
| 111 |
+
Tokenizer for quote_devanagari โ Devanagari script.
|
| 112 |
+
Examples: "เคงเคฐเฅเคฎเฅ เคฐเคเฅเคทเคคเคฟ เคฐเคเฅเคทเคฟเคคเค", "เคฏเคคเฅเคฐ เคจเคพเคฐเฅเคฏเคธเฅเคคเฅ เคชเฅเคเฅเคฏเคจเฅเคคเฅ"
|
| 113 |
+
"""
|
| 114 |
+
MODEL_PATH = "sanskrit_tgt_tokenizer.json"
|
| 115 |
+
|
| 116 |
+
def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
|
| 117 |
+
self.vocab_size = vocab_size
|
| 118 |
+
self.max_len = max_len
|
| 119 |
+
self.mask_token_id = 0
|
| 120 |
+
|
| 121 |
+
if Path(self.MODEL_PATH).exists():
|
| 122 |
+
print(f"๐ Loading target tokenizer from {self.MODEL_PATH} โฆ")
|
| 123 |
+
self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
|
| 124 |
+
else:
|
| 125 |
+
print("๐ Training target tokenizer on quote_devanagari โฆ")
|
| 126 |
+
self._train(vocab_size, n_train_samples)
|
| 127 |
+
|
| 128 |
+
_validate(self.tokenizer, "TgtTokenizer")
|
| 129 |
+
|
| 130 |
+
def _train(self, vocab_size, n_samples):
|
| 131 |
+
dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
|
| 132 |
+
n = min(n_samples, len(dataset))
|
| 133 |
+
texts = [s["quote_devanagari"] for s in dataset.select(range(n))
|
| 134 |
+
if s["quote_devanagari"].strip()]
|
| 135 |
+
self.tokenizer = _build_bpe(texts, vocab_size)
|
| 136 |
+
self.tokenizer.save(self.MODEL_PATH)
|
| 137 |
+
print(f"โ
Target tokenizer trained on {len(texts)} Devanagari texts.")
|
| 138 |
+
|
| 139 |
+
def encode(self, text):
|
| 140 |
+
ids = self.tokenizer.encode(text).ids[:self.max_len]
|
| 141 |
+
pad = self.tokenizer.token_to_id("[PAD]")
|
| 142 |
+
ids += [pad] * max(0, self.max_len - len(ids))
|
| 143 |
+
return ids[:self.max_len]
|
| 144 |
+
|
| 145 |
+
def decode(self, ids):
|
| 146 |
+
clean = [i for i in ids if i > 4]
|
| 147 |
+
return self.tokenizer.decode(clean)
|
| 148 |
+
|
| 149 |
+
# Methods required by BERTScore
|
| 150 |
+
def build_inputs_with_special_tokens(self, token_ids):
|
| 151 |
+
return list(token_ids)
|
| 152 |
+
|
| 153 |
+
def get_vocab(self):
|
| 154 |
+
return {str(i): i for i in range(self.vocab_size)}
|
| 155 |
+
|
| 156 |
+
def convert_ids_to_tokens(self, ids):
|
| 157 |
+
return [str(i) for i in ids]
|
| 158 |
+
|
| 159 |
+
def __len__(self):
|
| 160 |
+
return self.vocab_size
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# โโ Legacy shared tokenizer (kept for backward compat) โโโโโโโโโโโโโโโ
|
| 164 |
+
|
| 165 |
+
class SanskritTokenizer:
|
| 166 |
+
"""
|
| 167 |
+
LEGACY: single shared tokenizer trained on BOTH scripts.
|
| 168 |
+
Still works but suboptimal โ use SanskritSourceTokenizer +
|
| 169 |
+
SanskritTargetTokenizer for the quote_text โ quote_devanagari task.
|
| 170 |
+
"""
|
| 171 |
+
MODEL_PATH = "sanskrit_tokenizer_m4pro.json"
|
| 172 |
+
|
| 173 |
+
def __init__(self, vocab_size=16000, max_len=80):
|
| 174 |
+
self.vocab_size = vocab_size
|
| 175 |
+
self.max_len = max_len
|
| 176 |
+
self.mask_token_id = 0
|
| 177 |
+
|
| 178 |
+
if Path(self.MODEL_PATH).exists():
|
| 179 |
+
print("๐ Loading shared tokenizer โฆ")
|
| 180 |
+
self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
|
| 181 |
+
else:
|
| 182 |
+
print("๐ Training shared tokenizer on both scripts โฆ")
|
| 183 |
+
self._train(vocab_size)
|
| 184 |
+
|
| 185 |
+
_validate(self.tokenizer, "SharedTokenizer")
|
| 186 |
+
|
| 187 |
+
def _train(self, vocab_size):
|
| 188 |
+
dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
|
| 189 |
+
n = min(50000, len(dataset))
|
| 190 |
+
texts = []
|
| 191 |
+
for s in dataset.select(range(n)):
|
| 192 |
+
if s["quote_text"].strip():
|
| 193 |
+
texts.append(s["quote_text"])
|
| 194 |
+
if s["quote_devanagari"].strip():
|
| 195 |
+
texts.append(s["quote_devanagari"])
|
| 196 |
+
self.tokenizer = _build_bpe(texts, vocab_size)
|
| 197 |
+
self.tokenizer.save(self.MODEL_PATH)
|
| 198 |
+
print(f"โ
Shared tokenizer trained ({len(texts)} texts).")
|
| 199 |
+
|
| 200 |
+
def encode(self, text):
|
| 201 |
+
ids = self.tokenizer.encode(text).ids[:self.max_len]
|
| 202 |
+
pad = self.tokenizer.token_to_id("[PAD]")
|
| 203 |
+
ids += [pad] * max(0, self.max_len - len(ids))
|
| 204 |
+
return ids[:self.max_len]
|
| 205 |
+
|
| 206 |
+
def decode(self, ids):
|
| 207 |
+
if ids and isinstance(ids[0], list):
|
| 208 |
+
raise TypeError("decode() got 2D list โ pass a 1D list.")
|
| 209 |
+
clean = [i for i in ids if i > 4]
|
| 210 |
+
return self.tokenizer.decode(clean)
|
| 211 |
+
|
| 212 |
+
def build_inputs_with_special_tokens(self, token_ids):
|
| 213 |
+
return list(token_ids)
|
| 214 |
+
|
| 215 |
+
def get_vocab(self):
|
| 216 |
+
return {str(i): i for i in range(self.vocab_size)}
|
| 217 |
+
|
| 218 |
+
def convert_ids_to_tokens(self, ids):
|
| 219 |
+
return [str(i) for i in ids]
|
| 220 |
+
|
| 221 |
+
def __len__(self):
|
| 222 |
+
return self.vocab_size
|
model/tokenizers.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tokenizer.py โ FINAL
|
| 3 |
+
=====================
|
| 4 |
+
Uses the original sanskrit_tokenizer_m4pro.json โ the exact one the model
|
| 5 |
+
was trained with. Hard-coded absolute path as primary, with fallbacks.
|
| 6 |
+
|
| 7 |
+
This tokenizer has NO </w> end-of-word markers and NO decoder set.
|
| 8 |
+
decode() returns space-separated BPE pieces โ this is the format the
|
| 9 |
+
model was trained and evaluated on (BERTScore 0.71). Do NOT add a decoder
|
| 10 |
+
or retrain: that would break alignment with the checkpoint.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from tokenizers import Tokenizer
|
| 14 |
+
from tokenizers.models import BPE
|
| 15 |
+
from tokenizers.trainers import BpeTrainer
|
| 16 |
+
from tokenizers.pre_tokenizers import Whitespace
|
| 17 |
+
from datasets import load_dataset
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
# Hard-coded absolute path โ update if you move the project
|
| 22 |
+
TOKENIZER_PATH = "/Users/bhsingh/Documents/Final_Paraphrase/sanskrit_tokenizer_m4pro.json"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def build_tokenizer(texts, vocab_size=16000):
|
| 26 |
+
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
|
| 27 |
+
tokenizer.pre_tokenizer = Whitespace()
|
| 28 |
+
trainer = BpeTrainer(
|
| 29 |
+
vocab_size=vocab_size,
|
| 30 |
+
special_tokens=["[MASK]", "[PAD]", "[UNK]", "[CLS]", "[SEP]"],
|
| 31 |
+
min_frequency=2,
|
| 32 |
+
)
|
| 33 |
+
tokenizer.train_from_iterator(texts, trainer)
|
| 34 |
+
return tokenizer
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SanskritTokenizer:
|
| 38 |
+
def __init__(self, vocab_size=16000, max_len=80):
|
| 39 |
+
self.vocab_size = vocab_size
|
| 40 |
+
self.max_len = max_len
|
| 41 |
+
self.mask_token_id = 0
|
| 42 |
+
|
| 43 |
+
script_dir = Path(__file__).resolve().parent
|
| 44 |
+
candidates = [
|
| 45 |
+
os.environ.get("SANSKRIT_TOKENIZER_PATH", ""),
|
| 46 |
+
TOKENIZER_PATH,
|
| 47 |
+
str(script_dir.parent / "sanskrit_tokenizer_m4pro.json"),
|
| 48 |
+
str(script_dir / "sanskrit_tokenizer_m4pro.json"),
|
| 49 |
+
str(Path.cwd() / "sanskrit_tokenizer_m4pro.json"),
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
self.model_path = None
|
| 53 |
+
for c in candidates:
|
| 54 |
+
if c and Path(c).exists():
|
| 55 |
+
self.model_path = c
|
| 56 |
+
break
|
| 57 |
+
|
| 58 |
+
if self.model_path:
|
| 59 |
+
print(f"๐ Loading tokenizer from: {self.model_path}")
|
| 60 |
+
self.tokenizer = Tokenizer.from_file(self.model_path)
|
| 61 |
+
self._validate_mask_token()
|
| 62 |
+
else:
|
| 63 |
+
print(f"โ ๏ธ Tokenizer not found at any candidate path.")
|
| 64 |
+
print(f" Expected: {TOKENIZER_PATH}")
|
| 65 |
+
print(" Retraining โ WARNING: output will not match existing checkpoint!")
|
| 66 |
+
self.model_path = TOKENIZER_PATH
|
| 67 |
+
self._train_tokenizer()
|
| 68 |
+
|
| 69 |
+
def _validate_mask_token(self):
|
| 70 |
+
mask_id = self.tokenizer.token_to_id("[MASK]")
|
| 71 |
+
assert mask_id == 0, f"[MASK] must be ID 0, got {mask_id}"
|
| 72 |
+
print("โ
[MASK] token confirmed at ID=0")
|
| 73 |
+
|
| 74 |
+
def _train_tokenizer(self):
|
| 75 |
+
dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
|
| 76 |
+
texts = []
|
| 77 |
+
for sample in dataset.select(range(50000)):
|
| 78 |
+
texts.extend([sample["quote_text"], sample["quote_devanagari"]])
|
| 79 |
+
tokenizer = build_tokenizer(texts, self.vocab_size)
|
| 80 |
+
tokenizer.save(self.model_path)
|
| 81 |
+
self.tokenizer = tokenizer
|
| 82 |
+
self._validate_mask_token()
|
| 83 |
+
print(f"โ
Tokenizer saved to: {self.model_path}")
|
| 84 |
+
|
| 85 |
+
def encode(self, text):
|
| 86 |
+
encoded = self.tokenizer.encode(text)
|
| 87 |
+
token_ids = encoded.ids[:self.max_len]
|
| 88 |
+
pad_id = self.tokenizer.token_to_id("[PAD]")
|
| 89 |
+
if len(token_ids) < self.max_len:
|
| 90 |
+
token_ids += [pad_id] * (self.max_len - len(token_ids))
|
| 91 |
+
return token_ids[:self.max_len]
|
| 92 |
+
|
| 93 |
+
def decode(self, ids):
|
| 94 |
+
if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
|
| 95 |
+
raise TypeError("decode() expects 1D list of IDs, not 2D.")
|
| 96 |
+
# Filter special tokens: 0=MASK 1=PAD 2=UNK 3=CLS 4=SEP
|
| 97 |
+
clean = [i for i in ids if isinstance(i, int) and i > 4]
|
| 98 |
+
if not clean:
|
| 99 |
+
return ""
|
| 100 |
+
return self.tokenizer.decode(clean, skip_special_tokens=True).strip()
|
| 101 |
+
|
| 102 |
+
def build_inputs_with_special_tokens(self, token_ids):
|
| 103 |
+
return list(token_ids)
|
| 104 |
+
|
| 105 |
+
def get_vocab(self):
|
| 106 |
+
return {str(i): i for i in range(self.vocab_size)}
|
| 107 |
+
|
| 108 |
+
def convert_ids_to_tokens(self, ids):
|
| 109 |
+
return [str(i) for i in ids]
|
| 110 |
+
|
| 111 |
+
def __len__(self):
|
| 112 |
+
return self.vocab_size
|
model_settings.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "d3pm_encoder_decoder",
|
| 3 |
+
"include_negative_examples": false,
|
| 4 |
+
"num_steps": 4
|
| 5 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.2
|
| 2 |
+
numpy>=1.24
|
| 3 |
+
tqdm>=4.66
|
| 4 |
+
datasets>=2.19
|
| 5 |
+
tokenizers>=0.15
|
| 6 |
+
scikit-learn>=1.3
|
sanskrit_src_tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sanskrit_tgt_tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|