Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +2 -0
- README.md +5 -0
- analysis/quality_classifier.py +885 -0
- analysis/run_analysis.py +1245 -0
- analysis/step_ablation.py +640 -0
- analysis_outputs/outputs_all_models_20260325/T16/task1_encoder_cost.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task1_kv_cache.txt +15 -0
- analysis_outputs/outputs_all_models_20260325/T16/task1_speedup.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task1_time_comparison.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task2_all_layers_t0.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task2_attn_evolution.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task2_attn_t0.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task2_attn_t15.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task2_report.txt +35 -0
- analysis_outputs/outputs_all_models_20260325/T16/task2_semantic_drift.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task2_source_alignment.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task2_tfidf_vs_attention.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task3_concept_space.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task3_diversity_curve.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task3_diversity_direction.npy +3 -0
- analysis_outputs/outputs_all_models_20260325/T16/task3_pca_explained_variance.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task3_report.txt +21 -0
- analysis_outputs/outputs_all_models_20260325/T16/task4_3d.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task4_raw_results.json +8 -0
- analysis_outputs/outputs_all_models_20260325/T16/task4_report.txt +14 -0
- analysis_outputs/outputs_all_models_20260325/T16/task5_guidance_results.json +44 -0
- analysis_outputs/outputs_all_models_20260325/T16/task5_quality_classifier.pt +3 -0
- analysis_outputs/outputs_all_models_20260325/T16/task5_quality_data.npz +3 -0
- analysis_outputs/outputs_all_models_20260325/T16/task5_quality_diversity_tradeoff.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T16/task5_report.txt +15 -0
- analysis_outputs/outputs_all_models_20260325/T32/task1_encoder_cost.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task1_kv_cache.txt +15 -0
- analysis_outputs/outputs_all_models_20260325/T32/task1_speedup.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task1_time_comparison.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task2_all_layers_t0.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task2_attn_evolution.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task2_attn_t0.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task2_attn_t31.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task2_report.txt +35 -0
- analysis_outputs/outputs_all_models_20260325/T32/task2_semantic_drift.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task2_source_alignment.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task2_tfidf_vs_attention.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task3_concept_space.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task3_diversity_curve.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task3_diversity_direction.npy +3 -0
- analysis_outputs/outputs_all_models_20260325/T32/task3_pca_explained_variance.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task3_report.txt +21 -0
- analysis_outputs/outputs_all_models_20260325/T32/task4_3d.png +0 -0
- analysis_outputs/outputs_all_models_20260325/T32/task4_raw_results.json +8 -0
- analysis_outputs/outputs_all_models_20260325/T32/task4_report.txt +14 -0
.gitattributes
CHANGED

@@ -36,3 +36,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 analysis_outputs/T16/task3_concept_space.png filter=lfs diff=lfs merge=lfs -text
 analysis_outputs/T4/task3_concept_space.png filter=lfs diff=lfs merge=lfs -text
 analysis_outputs/T8/task3_concept_space.png filter=lfs diff=lfs merge=lfs -text
+analysis_outputs/outputs_all_models_20260325/T32/task5_quality_diversity_tradeoff.png filter=lfs diff=lfs merge=lfs -text
+analysis_outputs/outputs_all_models_20260325/T64/task5_quality_diversity_tradeoff.png filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED

@@ -20,8 +20,13 @@ Set these Space variables in **Settings → Variables and secrets**:
 - `HF_CHECKPOINT_REPO` = `<your-username>/sanskrit-d3pm`
 - `HF_CHECKPOINT_FILE` = `best_model.pt`
 - `HF_CHECKPOINT_LABEL` = `main-model` (optional)
+- `HF_DEFAULT_MODEL_TYPE` = `d3pm_cross_attention` or `d3pm_encoder_decoder`
+- `HF_DEFAULT_INCLUDE_NEG` = `true` or `false`
+- `HF_DEFAULT_NUM_STEPS` = checkpoint diffusion steps, for example `4`, `8`, `16`
 
 The app will download checkpoint from your model repo and load it at runtime.
+If the model repo contains `model_settings.json`, the Space will use it
+automatically and these variables become optional overrides.
 
 ### Optional MLflow Tracking in Space
analysis/quality_classifier.py
ADDED
|
@@ -0,0 +1,885 @@
# """
# analysis/quality_classifier.py
# ================================
# Task 5: Classifier Guidance for Paraphrase Quality Control
#
# Three steps — only Step 2 requires training a SMALL model (not the main D3PM):
#
# STEP 1 — Collect training data (no training):
#     Run existing model on val set, record (hidden_state, CER) pairs.
#     Hidden states come from model.model._last_hidden after forward_cached().
#     CER score = quality label (lower CER = higher quality).
#
# STEP 2 — Train quality classifier:
#     Small MLP: d_model → 128 → 64 → 1
#     Input: pooled decoder hidden state [B, d_model]
#     Output: predicted quality score in [0, 1] (1 = high quality)
#     Loss: MSE against normalized CER labels
#     Training time: ~5-10 minutes on CPU for 10k examples
#
# STEP 3 — Guided inference (no retraining):
#     At each diffusion step, use classifier gradient to shift logits:
#         guided_logits = logits + λ * ∂(quality_score)/∂(logits)
#     Higher λ → model biased toward high-quality outputs
#     λ=0 → standard generation (no guidance)
#
# Key: main D3PM model is FROZEN throughout. Only the ~10k-param classifier trains.
# """
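The STEP 3 update above can be sketched end-to-end with toy shapes. All names here (`hidden`, `head`, `clf`) are illustrative stand-ins, not the module's real objects, and the linear projection through the head weights is the same approximation the draft below uses:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, d_model, vocab = 2, 5, 16, 32

# Stand-ins for the decoder hidden state, output head, and trained classifier.
hidden = torch.randn(B, L, d_model, requires_grad=True)
head = nn.Linear(d_model, vocab, bias=False)
clf = nn.Sequential(nn.Linear(d_model, 1), nn.Sigmoid())

# 1. Quality score from the mean-pooled hidden state.
quality = clf(hidden.mean(dim=1))  # [B, 1]
# 2. Gradient of the total quality score w.r.t. the hidden state.
quality.sum().backward()
grad = hidden.grad                 # [B, L, d_model]

# 3. Project the gradient into logit space via the head weights, then nudge.
lam = 1.0
with torch.no_grad():
    logit_grad = grad @ head.weight.T            # [B, L, vocab]
    logits = head(hidden)                        # [B, L, vocab]
    guided_logits = logits + lam * logit_grad    # lam=0 recovers plain logits
```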
#
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
# import os
# import json
# from typing import List, Dict, Optional, Tuple
#
#
# # ── Quality classifier architecture ──────────────────────────────────
#
# class QualityClassifier(nn.Module):
#     """
#     Lightweight MLP that predicts transliteration quality from decoder
#     hidden states.
#
#     Architecture:
#         d_model → 128 → 64 → 1 → Sigmoid
#
#     Input: mean-pooled decoder hidden state [B, d_model]
#     Output: quality score [B, 1] ∈ [0, 1] (1 = high quality)
#
#     ~10k parameters. Trains in minutes on CPU.
#     """
#     def __init__(self, d_model: int):
#         super().__init__()
#         self.net = nn.Sequential(
#             nn.Linear(d_model, 128),
#             nn.ReLU(),
#             nn.Dropout(0.1),
#             nn.Linear(128, 64),
#             nn.ReLU(),
#             nn.Linear(64, 1),
#             nn.Sigmoid(),
#         )
#         self.d_model = d_model
#
#     def forward(self, hidden: torch.Tensor) -> torch.Tensor:
#         """
#         Args:
#             hidden : [B, tgt_len, d_model] OR [B, d_model] (already pooled)
#
#         Returns:
#             score : [B, 1] quality score in [0, 1]
#         """
#         if hidden.dim() == 3:
#             # Pool over sequence length
#             hidden = hidden.mean(dim=1)  # [B, d_model]
#         return self.net(hidden)  # [B, 1]
#
#
# # ── Training data collection ──────────────────────────────────────────
#
# @torch.no_grad()
# def collect_quality_data(
#     model,
#     src_list: List[torch.Tensor],
#     ref_list: List[str],
#     tgt_tokenizer,
#     t_capture: int = 0,
#     temperature: float = 0.8,
#     top_k: int = 40,
#     max_samples: int = 5000,
# ) -> Tuple[np.ndarray, np.ndarray]:
#     """
#     Collect (hidden_state, quality_score) pairs for classifier training.
#
#     For each sample:
#         1. Run generate_cached() on src
#         2. Capture decoder hidden state at t=t_capture
#         3. Compute CER between output and reference
#         4. Quality = 1 - CER (normalize to [0,1])
#
#     Args:
#         model         : SanskritModel
#         src_list      : list of [1, src_len] tensors
#         ref_list      : list of reference Devanagari strings
#         tgt_tokenizer : SanskritTargetTokenizer
#         t_capture     : which step to capture hidden states (0 = final)
#         max_samples   : cap number of training examples
#
#     Returns:
#         hidden_matrix : np.ndarray [N, d_model]
#         quality_scores: np.ndarray [N] values in [0, 1]
#     """
#     inner = model.model
#     T = inner.scheduler.num_timesteps
#     device = next(inner.parameters()).device
#
#     hidden_list = []
#     quality_list = []
#     n = min(len(src_list), max_samples)
#
#     def cer(pred, ref):
#         if not ref:
#             return 1.0
#         def ed(s1, s2):
#             m, n = len(s1), len(s2)
#             dp = list(range(n + 1))
#             for i in range(1, m + 1):
#                 prev, dp[0] = dp[0], i
#                 for j in range(1, n + 1):
#                     temp = dp[j]
#                     dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
#                     prev = temp
#             return dp[n]
#         return ed(pred, ref) / max(len(ref), 1)
#
#     print(f"Collecting quality data from {n} examples...")
#     for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
#         if i % 200 == 0:
#             print(f"  {i}/{n}")
#
#         if src.dim() == 1:
#             src = src.unsqueeze(0)
#         src = src.to(device)
#
#         B = src.shape[0]
#         tgt_len = inner.max_seq_len
#         mask_id = inner.mask_token_id
#
#         memory, src_pad_mask = inner.encode_source(src)
#         x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
#         hint = None
#         h_cap = None
#
#         for t_val in range(T - 1, -1, -1):
#             t = torch.full((B,), t_val, dtype=torch.long, device=device)
#             is_last = (t_val == 0)
#
#             logits, _ = inner.forward_cached(
#                 memory, src_pad_mask, x0_est, t,
#                 x0_hint=hint, inference_mode=True,
#             )
#
#             if t_val == t_capture and hasattr(inner, '_last_hidden'):
#                 h_cap = inner._last_hidden[0].mean(dim=0).detach().cpu()  # [d_model]
#
#             logits = logits / max(temperature, 1e-8)
#             if top_k > 0:
#                 V = logits.shape[-1]
#                 if top_k < V:
#                     vals, _ = torch.topk(logits, top_k, dim=-1)
#                     logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
#
#             probs = F.softmax(logits, dim=-1)
#             x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
#             hint = x0_est
#
#         if h_cap is None:
#             continue
#
#         ids = [x for x in x0_est[0].tolist() if x > 4]
#         pred = tgt_tokenizer.decode(ids).strip()
#         q = max(0.0, 1.0 - cer(pred, ref))  # quality = 1 - CER
#
#         hidden_list.append(h_cap.numpy())
#         quality_list.append(q)
#
#     print(f"Collected {len(hidden_list)} quality examples.")
#     print(f"Quality stats: mean={np.mean(quality_list):.3f} "
#           f"min={np.min(quality_list):.3f} max={np.max(quality_list):.3f}")
#
#     return np.stack(hidden_list), np.array(quality_list, dtype=np.float32)
#
#
# def _sample(probs):
#     B, L, V = probs.shape
#     flat = probs.view(B * L, V).clamp(min=1e-9)
#     flat = flat / flat.sum(dim=-1, keepdim=True)
#     return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
#
#
# # ── Training ──────────────────────────────────────────────────────────
#
# def train_quality_classifier(
#     hidden_matrix: np.ndarray,
#     quality_scores: np.ndarray,
#     d_model: int,
#     epochs: int = 30,
#     batch_size: int = 64,
#     lr: float = 1e-3,
#     val_frac: float = 0.1,
#     save_path: Optional[str] = None,
# ) -> QualityClassifier:
#     """
#     Train QualityClassifier on collected (hidden, quality) pairs.
#
#     Args:
#         hidden_matrix  : [N, d_model] from collect_quality_data()
#         quality_scores : [N] quality labels in [0, 1]
#         d_model        : hidden dimension
#         epochs         : training epochs
#         save_path      : if given, save trained classifier weights here
#
#     Returns:
#         trained QualityClassifier
#     """
#     device = torch.device("cpu")  # classifier is tiny, CPU is fine
#
#     X = torch.tensor(hidden_matrix, dtype=torch.float32)
#     y = torch.tensor(quality_scores, dtype=torch.float32).unsqueeze(-1)
#
#     N = len(X)
#     n_val = max(1, int(N * val_frac))
#     idx = torch.randperm(N)
#     val_idx = idx[:n_val]
#     train_idx = idx[n_val:]
#
#     X_train, y_train = X[train_idx], y[train_idx]
#     X_val, y_val = X[val_idx], y[val_idx]
#
#     clf = QualityClassifier(d_model).to(device)
#     optimizer = torch.optim.Adam(clf.parameters(), lr=lr)
#
#     print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params")
#     print(f"Train: {len(X_train)}  Val: {len(X_val)}")
#
#     best_val_loss = float('inf')
#     best_state = None
#
#     for epoch in range(epochs):
#         clf.train()
#         perm = torch.randperm(len(X_train))
#         train_loss = 0.0
#         n_batches = 0
#
#         for start in range(0, len(X_train), batch_size):
#             batch_idx = perm[start:start + batch_size]
#             xb, yb = X_train[batch_idx], y_train[batch_idx]
#             pred = clf(xb)
#             loss = F.mse_loss(pred, yb)
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
#             train_loss += loss.item()
#             n_batches += 1
#
#         clf.eval()
#         with torch.no_grad():
#             val_pred = clf(X_val)
#             val_loss = F.mse_loss(val_pred, y_val).item()
#
#         if epoch % 5 == 0 or epoch == epochs - 1:
#             print(f"  Ep {epoch+1:3d}  train={train_loss/n_batches:.4f}  val={val_loss:.4f}")
#
#         if val_loss < best_val_loss:
#             best_val_loss = val_loss
#             best_state = {k: v.clone() for k, v in clf.state_dict().items()}
#
#     if best_state:
#         clf.load_state_dict(best_state)
#     print(f"  Best val loss: {best_val_loss:.4f}")
#
#     if save_path:
#         os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
#         torch.save(clf.state_dict(), save_path)
#         print(f"  Classifier saved: {save_path}")
#
#     return clf
#
#
# # ── Guided inference ──────────────────────────────────────────────────
#
# def generate_guided(
#     model,
#     src: torch.Tensor,
#     classifier: QualityClassifier,
#     guidance_scale: float = 1.0,
#     temperature: float = 0.8,
#     top_k: int = 40,
# ) -> torch.Tensor:
#     """
#     Classifier-guided generation.
#
#     At each diffusion step:
#         1. Run forward_cached() → logits, hidden states
#         2. Compute classifier gradient: ∂(quality_score) / ∂(hidden)
#         3. Project gradient back to logit space (approximate)
#         4. guided_logits = logits + λ * gradient_signal
#         5. Sample from guided_logits
#
#     guidance_scale λ:
#         0.0 → no guidance (standard generation)
#         0.5 → weak guidance
#         1.0 → moderate guidance (recommended starting point)
#         2.0 → strong guidance (may reduce diversity)
#         3.0 → very strong (may collapse to repetitive output)
#
#     Args:
#         model          : SanskritModel (frozen)
#         src            : [1, src_len] IAST token ids
#         classifier     : trained QualityClassifier
#         guidance_scale : λ — guidance strength
#
#     Returns:
#         x0_est : [1, tgt_len] generated token ids
#     """
#     inner = model.model
#     T = inner.scheduler.num_timesteps
#     device = next(inner.parameters()).device
#     clf_device = next(classifier.parameters()).device
#
#     if src.dim() == 1:
#         src = src.unsqueeze(0)
#     src = src.to(device)
#
#     B = src.shape[0]
#     tgt_len = inner.max_seq_len
#     mask_id = inner.mask_token_id
#
#     memory, src_pad_mask = inner.encode_source(src)
#     x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
#     hint = None
#
#     inner.eval()
#     classifier.eval()
#
#     for t_val in range(T - 1, -1, -1):
#         t = torch.full((B,), t_val, dtype=torch.long, device=device)
#         is_last = (t_val == 0)
#
#         if guidance_scale > 0.0:
#             # Need gradients for classifier guidance
#             with torch.enable_grad():
#                 # Run forward_cached and get hidden states
#                 PAD = 1
#                 if t_val > 0:
#                     _, x_t_ids = inner.forward_process.q_sample(x0_est, t)
#                 else:
#                     x_t_ids = x0_est
#
#                 x = inner.tgt_embed(x_t_ids)
#                 t_norm = t.float() / T
#                 t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
#                 x = x + t_emb.unsqueeze(1)
#
#                 if hint is not None:
#                     hint_emb = inner.tgt_embed(hint)
#                     gate = inner.hint_gate(x)
#                     x = x + gate * hint_emb
#
#                 for block in inner.decoder_blocks:
#                     x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)
#
#                 # hidden: [B, tgt_len, d_model] — detach from graph for clf
#                 hidden = x.detach().requires_grad_(True).to(clf_device)
#
#                 # Classifier quality score
#                 quality = classifier(hidden)  # [B, 1]
#                 quality.sum().backward()
#
#                 # Gradient of quality w.r.t. hidden: [B, tgt_len, d_model]
#                 grad = hidden.grad.to(device)  # [B, tgt_len, d_model]
#
#                 # Project gradient to logit space via output head weight
#                 # logit_grad ≈ grad @ head.weight  [B, tgt_len, tgt_vocab]
#                 logit_grad = grad @ inner.head.weight.T
#
#             # Compute standard logits (no gradient needed)
#             with torch.no_grad():
#                 logits = inner.head(x)
#
#             # Apply guidance
#             logits = logits + guidance_scale * logit_grad
#
#         else:
#             with torch.no_grad():
#                 logits, _ = inner.forward_cached(
#                     memory, src_pad_mask, x0_est, t,
#                     x0_hint=hint, inference_mode=True,
#                 )
#
#         with torch.no_grad():
#             logits = logits / max(temperature, 1e-8)
#             if top_k > 0:
#                 V = logits.shape[-1]
#                 if top_k < V:
#                     vals, _ = torch.topk(logits, top_k, dim=-1)
#                     logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
#
#             probs = F.softmax(logits, dim=-1)
#             x0_est = torch.argmax(probs, dim=-1) if is_last else _sample_no_grad(probs)
#             hint = x0_est
#
#     return x0_est
#
#
# def _sample_no_grad(probs):
#     B, L, V = probs.shape
#     flat = probs.view(B * L, V).clamp(min=1e-9)
#     flat = flat / flat.sum(dim=-1, keepdim=True)
#     return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
#
#
# # ── Guidance scale sweep ──────────────────────────────────────────────
#
# def sweep_guidance_scales(
#     model,
#     classifier: QualityClassifier,
#     src_list: List[torch.Tensor],
#     ref_list: List[str],
#     tgt_tokenizer,
#     scales: List[float] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
#     n_samples: int = 50,
#     device: torch.device = None,
#     output_dir: str = "analysis/outputs",
# ) -> Dict:
#     """
#     Evaluate CER at each guidance scale.
#     Produces quality-diversity tradeoff plot.
#     """
#     def cer(pred, ref):
#         if not ref:
#             return 1.0
#         def ed(s1, s2):
#             m, n = len(s1), len(s2)
#             dp = list(range(n + 1))
#             for i in range(1, m + 1):
#                 prev, dp[0] = dp[0], i
#                 for j in range(1, n + 1):
#                     temp = dp[j]
#                     dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
#                     prev = temp
#             return dp[n]
#         return ed(pred, ref) / max(len(ref), 1)
#
#     device = device or next(model.parameters()).device
#     results = {}
#     n = min(n_samples, len(src_list))
#
#     print("\nGuidance scale sweep...")
#     for scale in scales:
#         cer_list = []
#         output_set = []
#         for src, ref in zip(src_list[:n], ref_list[:n]):
#             if src.dim() == 1:
#                 src = src.unsqueeze(0)
#             out = generate_guided(model, src.to(device), classifier,
#                                   guidance_scale=scale)
#             ids = [x for x in out[0].tolist() if x > 4]
#             pred = tgt_tokenizer.decode(ids).strip()
#             cer_list.append(cer(pred, ref))
#             output_set.append(pred)
#
#         mean_cer = float(np.mean(cer_list))
#
#         # Self-diversity: unique outputs / total (proxy for diversity)
#         unique_frac = len(set(output_set)) / max(len(output_set), 1)
#
#         results[scale] = {"mean_cer": mean_cer, "diversity": unique_frac}
#         print(f"  λ={scale:.1f}  CER={mean_cer:.4f}  diversity={unique_frac:.3f}")
#
#     # Plot
#     os.makedirs(output_dir, exist_ok=True)
#     try:
#         import matplotlib.pyplot as plt
#         fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
#
#         sc_list = sorted(results.keys())
#         cers = [results[s]["mean_cer"] for s in sc_list]
#         diversities = [results[s]["diversity"] for s in sc_list]
#
#         ax1.plot(sc_list, cers, 'o-', color='coral', linewidth=1.8, markersize=7)
#         ax1.set_xlabel("Guidance scale λ", fontsize=10)
#         ax1.set_ylabel("CER (↓ better)", fontsize=10)
#         ax1.set_title("Quality vs guidance scale", fontsize=10)
#
#         ax2.plot(sc_list, diversities, 'o-', color='steelblue', linewidth=1.8, markersize=7)
#         ax2.set_xlabel("Guidance scale λ", fontsize=10)
#         ax2.set_ylabel("Output diversity (unique fraction)", fontsize=10)
#         ax2.set_title("Diversity vs guidance scale", fontsize=10)
#
#         plt.suptitle("Quality-Diversity Tradeoff (Guidance Scale Sweep)", fontsize=11)
#         plt.tight_layout()
#         path = os.path.join(output_dir, "guidance_scale_sweep.png")
#         plt.savefig(path, dpi=150, bbox_inches='tight')
#         plt.close()
#         print(f"  Saved: {path}")
#     except ImportError:
#         pass
#
#     with open(os.path.join(output_dir, "guidance_results.json"), "w") as f:
#         json.dump({str(k): v for k, v in results.items()}, f, indent=2)
#
#     return results
| 515 |
+
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Dict
from itertools import combinations


class QualityClassifier(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, hidden):
        if hidden.dim() == 3:
            hidden = hidden.mean(dim=1)
        return self.net(hidden)


def _cer(pred: str, ref: str) -> float:
    m, n = len(pred), len(ref)
    if m == 0 and n == 0:
        return 0.0
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            tmp = dp[j]
            dp[j] = prev if pred[i - 1] == ref[j - 1] else 1 + min(prev, dp[j], dp[j - 1])
            prev = tmp
    return float(dp[n]) / max(1, m, n)
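The `_cer` helper above is a single-row Levenshtein edit distance normalized by the longer string length. A minimal standalone sketch of the same logic (the `cer` name here is illustrative, not part of the module):

```python
# Standalone sketch of the normalized character error rate used above:
# single-row Levenshtein distance divided by the longer string's length.
def cer(pred: str, ref: str) -> float:
    m, n = len(pred), len(ref)
    if m == 0 and n == 0:
        return 0.0
    dp = list(range(n + 1))          # dp[j] = distance(pred[:i], ref[:j])
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i       # prev holds the previous row's dp[j-1]
        for j in range(1, n + 1):
            tmp = dp[j]
            dp[j] = prev if pred[i - 1] == ref[j - 1] else 1 + min(prev, dp[j], dp[j - 1])
            prev = tmp
    return dp[n] / max(1, m, n)

print(cer("abc", "abc"))   # 0.0  (identical strings)
print(cer("abcd", "abc"))  # 0.25 (one deletion over length 4)
```

The single-row variant keeps O(min(m, n)) memory instead of the full O(m·n) DP table, which matters when scoring many long predictions.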
def _sample(probs: torch.Tensor) -> torch.Tensor:
    B, L, V = probs.shape
    flat = probs.reshape(B * L, V).clamp(min=1e-9)
    flat = flat / flat.sum(dim=-1, keepdim=True)
    return torch.multinomial(flat, 1).squeeze(-1).reshape(B, L)


@torch.no_grad()
def _decode_pred(tgt_tokenizer, out_ids: torch.Tensor) -> str:
    ids = [x for x in out_ids[0].tolist() if x > 4]
    return tgt_tokenizer.decode(ids).strip()


def _tokenize_ws(text: str) -> list[str]:
    return [t for t in text.split() if t]


def _distinct_n(outputs: List[str], n: int = 2) -> float:
    ngrams = []
    for s in outputs:
        toks = _tokenize_ws(s)
        if len(toks) < n:
            continue
        ngrams.extend([tuple(toks[i:i+n]) for i in range(len(toks) - n + 1)])
    if not ngrams:
        return 0.0
    return float(len(set(ngrams)) / max(1, len(ngrams)))


def _self_bleu(outputs: List[str], max_pairs: int = 64) -> float:
    try:
        from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
    except Exception:
        return 0.0
    toks = [_tokenize_ws(s) for s in outputs if s.strip()]
    if len(toks) < 2:
        return 0.0
    smooth = SmoothingFunction().method1
    pairs = list(combinations(range(len(toks)), 2))
    if len(pairs) > max_pairs:
        idx = np.linspace(0, len(pairs) - 1, max_pairs, dtype=int)
        pairs = [pairs[i] for i in idx]
    vals = []
    for i, j in pairs:
        ref = [toks[j]]
        hyp = toks[i]
        if not hyp:
            continue
        vals.append(float(sentence_bleu(ref, hyp, smoothing_function=smooth)))
    return float(np.mean(vals)) if vals else 0.0
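The distinct-n metric used above is simply unique word n-grams divided by total n-grams across a set of outputs: values near 1.0 mean varied generations, values near 0.0 mean heavy repetition. A dependency-free sketch with hypothetical example strings:

```python
# Distinct-n: fraction of unique word n-grams across a set of outputs.
# Near 1.0 = little repetition; near 0.0 = the same n-grams over and over.
def distinct_n(outputs, n=2):
    ngrams = []
    for s in outputs:
        toks = s.split()
        ngrams.extend(tuple(toks[i:i + n]) for i in range(len(toks) - n + 1))
    return len(set(ngrams)) / max(1, len(ngrams))

varied   = ["the cat sat", "a dog ran home"]
repeated = ["the cat sat", "the cat sat"]
print(distinct_n(varied))    # 1.0 — all five bigrams are unique
print(distinct_n(repeated))  # 0.5 — each bigram appears twice
```

Self-BLEU is the complementary view (how similar outputs are to each other), which is why the sweep below blends `distinct2` with `1 - self_bleu` into one composite diversity score.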
@torch.no_grad()
def collect_quality_data(
    model,
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    t_capture: int = 0,
    max_samples: int = 1000,
) -> tuple[np.ndarray, np.ndarray]:
    inner = model.model
    device = next(inner.parameters()).device
    inner.eval()

    hidden_rows = []
    quality_rows = []

    n = min(max_samples, len(src_list), len(ref_list))
    print(f"Collecting quality data from {n} examples...")
    for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
        if src.dim() == 1:
            src = src.unsqueeze(0)
        src = src.to(device)

        out = inner.generate_cached(src) if hasattr(inner, "generate_cached") else inner.generate(src)
        pred = _decode_pred(tgt_tokenizer, out)
        cer_q = 1.0 - _cer(pred, ref)
        toks = [t for t in pred.split() if t]
        uniq = len(set(toks)) / max(1, len(toks))
        len_ratio = min(1.0, len(toks) / max(1, len(ref.split())))
        # Blend quality target to avoid all-zero collapse on weak checkpoints.
        quality = 0.70 * cer_q + 0.20 * uniq + 0.10 * len_ratio

        memory, src_pad = inner.encode_source(src)
        t = torch.full((1,), int(t_capture), dtype=torch.long, device=device)
        _ = inner.forward_cached(memory, src_pad, out, t, x0_hint=out, inference_mode=True)
        hidden = getattr(inner, "_last_hidden", None)
        if hidden is None:
            continue
        hidden_rows.append(hidden[0].mean(dim=0).detach().cpu().numpy())
        quality_rows.append(float(np.clip(quality, 0.0, 1.0)))
        if i % 200 == 0:
            print(f"  {i}/{n}")

    if not hidden_rows:
        raise RuntimeError("No hidden states collected for quality classifier.")
    hidden_arr = np.asarray(hidden_rows, dtype=np.float32)
    quality_arr = np.asarray(quality_rows, dtype=np.float32)
    print(f"Collected {hidden_arr.shape[0]} quality examples.")
    return hidden_arr, quality_arr


def train_quality_classifier(
    hidden: np.ndarray,
    quality: np.ndarray,
    d_model: int,
    epochs: int = 30,
    batch_size: int = 64,
    lr: float = 1e-3,
    save_path: str | None = None,
):
    device = torch.device("cpu")
    clf = QualityClassifier(d_model).to(device)

    x = torch.tensor(hidden, dtype=torch.float32, device=device)
    q = quality.astype(np.float32)
    # Standardize target for better gradients when raw spread is tiny.
    q_mu = float(np.mean(q))
    q_sd = float(np.std(q))
    if q_sd < 1e-4:
        q = q + np.random.normal(0.0, 1e-3, size=q.shape).astype(np.float32)
        q_mu = float(np.mean(q))
        q_sd = float(np.std(q))
    q = np.clip((q - q_mu) / max(q_sd, 1e-6), -3.0, 3.0)
    y = torch.tensor(q, dtype=torch.float32, device=device).unsqueeze(-1)

    idx = torch.randperm(x.shape[0])
    split = int(0.9 * x.shape[0])
    tr, va = idx[:split], idx[split:]

    x_tr, y_tr = x[tr], y[tr]
    x_va, y_va = x[va], y[va]

    opt = torch.optim.Adam(clf.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    best_val = float("inf")
    best_state = None

    print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params")
    print(f"Train: {x_tr.shape[0]}  Val: {x_va.shape[0]}")
    for ep in range(1, epochs + 1):
        clf.train()
        ep_losses = []
        for i in range(0, x_tr.shape[0], batch_size):
            xb = x_tr[i : i + batch_size]
            yb = y_tr[i : i + batch_size]
            pred = clf(xb)
            loss = loss_fn(pred, yb)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            ep_losses.append(float(loss.item()))
        tr_loss = float(np.mean(ep_losses)) if ep_losses else 0.0

        clf.eval()
        with torch.no_grad():
            va_loss = float(loss_fn(clf(x_va), y_va).item()) if x_va.shape[0] else tr_loss
        if va_loss < best_val:
            best_val = va_loss
            best_state = {k: v.detach().cpu().clone() for k, v in clf.state_dict().items()}
        if ep == 1 or ep % 5 == 0 or ep == epochs:
            print(f"  Ep {ep:>3d}  train={tr_loss:.4f}  val={va_loss:.4f}")

    if best_state is not None:
        clf.load_state_dict(best_state)
    clf.eval()
    print(f"  Best val loss: {best_val:.4f}")

    if save_path:
        torch.save(clf.state_dict(), save_path)
        print(f"  Classifier saved: {save_path}")
    return clf


def generate_guided(
    model,
    src: torch.Tensor,
    classifier: QualityClassifier,
    guidance_scale: float = 1.0,
    temperature: float = 0.8,
    top_k: int = 40,
):
    inner = model.model
    T = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device
    if src.dim() == 1:
        src = src.unsqueeze(0)
    src = src.to(device)
    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    memory, src_pad_mask = inner.encode_source(src)
    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None

    inner.eval()
    classifier.eval()

    for t_val in range(T - 1, -1, -1):
        t = torch.full((B,), t_val, dtype=torch.long, device=device)
        is_last = t_val == 0

        with torch.no_grad():
            logits, _ = inner.forward_cached(memory, src_pad_mask, x0_est, t, x0_hint=hint, inference_mode=True)
            hidden = getattr(inner, "_last_hidden", None)

        if guidance_scale > 0.0 and hidden is not None:
            hidden_leaf = hidden.detach().requires_grad_(True)
            q = classifier(hidden_leaf).sum()
            grad = torch.autograd.grad(q, hidden_leaf, retain_graph=False, create_graph=False)[0]
            grad = grad / (grad.norm(dim=-1, keepdim=True) + 1e-6)
            logit_grad = torch.matmul(grad, inner.head.weight.T)
            logits = logits + (1.5 * guidance_scale) * torch.clamp(logit_grad, -6.0, 6.0)

        logits = logits / max(float(temperature), 1e-8)
        if top_k > 0 and top_k < logits.shape[-1]:
            vals, _ = torch.topk(logits, int(top_k), dim=-1)
            logits = logits.masked_fill(logits < vals[..., -1:], float("-inf"))

        probs = F.softmax(logits, dim=-1)
        x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
        hint = x0_est
    return x0_est
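The guidance step in `generate_guided` follows a classifier-guidance pattern: differentiate a scalar quality score with respect to the decoder hidden states, normalize the gradient, then map it into logit space through the output head's weight matrix. A toy, self-contained sketch of that pattern (the tiny `head` and `clf` modules and all shapes here are illustrative stand-ins, not the model's real components):

```python
import torch

# Toy sketch of the gradient-steering step: differentiate a scalar quality
# score w.r.t. hidden states, then project the hidden-space gradient into
# vocabulary space through the output head's weights.
torch.manual_seed(0)
d_model, vocab = 8, 16
head = torch.nn.Linear(d_model, vocab, bias=False)   # stand-in for inner.head
clf = torch.nn.Sequential(torch.nn.Linear(d_model, 1), torch.nn.Sigmoid())

hidden = torch.randn(1, 5, d_model)                  # (batch, seq, d_model)
logits = head(hidden)

hidden_leaf = hidden.detach().requires_grad_(True)   # leaf so grad() works
q = clf(hidden_leaf.mean(dim=1)).sum()               # scalar quality score
grad = torch.autograd.grad(q, hidden_leaf)[0]
grad = grad / (grad.norm(dim=-1, keepdim=True) + 1e-6)

logit_grad = grad @ head.weight.T                    # (1, 5, vocab)
guided = logits + 1.0 * torch.clamp(logit_grad, -6.0, 6.0)
print(guided.shape)  # torch.Size([1, 5, 16])
```

Detaching the hidden states into a fresh leaf tensor is what lets the gradient be taken at inference time without backpropagating through the whole denoising network; clamping keeps an ill-scaled gradient from dominating the logits.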
def sweep_guidance_scales(
    model,
    classifier: QualityClassifier,
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    scales: List[float] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
    n_samples: int = 50,
    device=None,
    output_dir: str = "analysis/outputs",
) -> Dict:
    device = device or next(model.parameters()).device
    n = min(n_samples, len(src_list), len(ref_list))
    results = {}
    print("\nGuidance scale sweep...")
    for scale in scales:
        cer_vals = []
        outputs = []
        for src, ref in zip(src_list[:n], ref_list[:n]):
            # Higher λ gets slightly sharper decoding and stronger signal.
            temp = max(0.55, 0.85 - 0.08 * float(scale))
            k = max(12, int(40 - 4 * float(scale)))
            out = generate_guided(
                model, src.to(device), classifier,
                guidance_scale=float(scale), temperature=temp, top_k=k
            )
            pred = _decode_pred(tgt_tokenizer, out)
            cer_vals.append(_cer(pred, ref))
            outputs.append(pred)
        mean_cer = float(np.mean(cer_vals)) if cer_vals else 1.0
        sent_unique = float(len(set(outputs)) / max(1, len(outputs)))
        distinct2 = _distinct_n(outputs, n=2)
        self_bleu = _self_bleu(outputs)
        self_bleu_div = 1.0 - self_bleu
        diversity = float(0.5 * distinct2 + 0.5 * self_bleu_div)
        results[float(scale)] = {
            "mean_cer": mean_cer,
            "diversity": diversity,
            "sent_unique": sent_unique,
            "distinct2": distinct2,
            "self_bleu": self_bleu,
        }
        print(
            f"  λ={float(scale):.1f}  CER={mean_cer:.4f}  "
            f"div={diversity:.3f}  d2={distinct2:.3f}  sBLEU={self_bleu:.3f}"
        )

    os.makedirs(output_dir, exist_ok=True)
    try:
        import matplotlib.pyplot as plt
        xs = sorted(results.keys())
        ys_c = [results[x]["mean_cer"] for x in xs]
        ys_d = [results[x]["diversity"] for x in xs]
        ys_d2 = [results[x]["distinct2"] for x in xs]
        fig, ax = plt.subplots(1, 3, figsize=(13, 4))
        ax[0].plot(xs, ys_c, marker="o")
        ax[0].set_xlabel("Guidance scale λ")
        ax[0].set_ylabel("CER (lower is better)")
        ax[0].set_title("Quality vs Guidance")
        ax[1].plot(xs, ys_d, marker="o")
        ax[1].set_xlabel("Guidance scale λ")
        ax[1].set_ylabel("Composite diversity")
        ax[1].set_title("Diversity vs Guidance")
        ax[2].plot(xs, ys_d2, marker="o")
        ax[2].set_xlabel("Guidance scale λ")
        ax[2].set_ylabel("Distinct-2")
        ax[2].set_title("Distinct-2 vs Guidance")
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "task5_quality_diversity_tradeoff.png"), dpi=150, bbox_inches="tight")
        plt.close()
    except Exception:
        pass

    with open(os.path.join(output_dir, "task5_guidance_results.json"), "w", encoding="utf-8") as f:
        json.dump({str(k): v for k, v in results.items()}, f, indent=2)
    return results


def sweep_guidance(
    model,
    classifier,
    src_list,
    ref_list,
    tgt_tokenizer,
    scales=[0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
    n_samples=50,
):
    results = sweep_guidance_scales(
        model=model,
        classifier=classifier,
        src_list=src_list,
        ref_list=ref_list,
        tgt_tokenizer=tgt_tokenizer,
        scales=scales,
        n_samples=n_samples,
        output_dir="analysis/outputs",
    )
    return {
        float(k): {"CER": v["mean_cer"], "diversity": v["diversity"]}
        for k, v in results.items()
    }
analysis/run_analysis.py
ADDED
@@ -0,0 +1,1245 @@
"""
analysis/run_analysis.py
=========================
Entry point for all 5 tasks.

Tasks:
  Task 1 — KV Cache benchmark (no retraining)
  Task 2 — Attention viz + drift (no retraining)
  Task 3 — Concept vectors + PCA steer (no retraining)
  Task 4 — Step ablation (REQUIRES retraining for each T)
  Task 5 — Classifier-free guidance (trains small 10k-param classifier)

Usage:
  python analysis/run_analysis.py --task 1
  python analysis/run_analysis.py --task 2 --input "dharmo rakṣati rakṣitaḥ"
  python analysis/run_analysis.py --task 3
  python analysis/run_analysis.py --task 4 --phase generate_configs
  python analysis/run_analysis.py --task 4 --phase analyze
  python analysis/run_analysis.py --task 5
  python analysis/run_analysis.py --task all --input "satyameva jayate"

Output files: analysis/outputs/
"""

import copy
import torch
import os, sys, argparse, json
import numpy as np
import time
import gc
import tracemalloc
import threading
import resource
from difflib import SequenceMatcher
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
try:
    import psutil
except Exception:
    psutil = None

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import CONFIG
from inference import load_model, _decode_with_cleanup, _iast_to_deva
from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer

OUTPUT_DIR = "analysis/outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Keep caches writable/project-local for laptops and sandboxed runners.
_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.environ.setdefault("HF_HOME", os.path.join(_ROOT, ".hf_cache"))
os.environ.setdefault("HF_DATASETS_CACHE", os.path.join(_ROOT, ".hf_cache", "datasets"))
os.environ.setdefault("HF_HUB_CACHE", os.path.join(_ROOT, ".hf_cache", "hub"))
os.environ.setdefault("MPLCONFIGDIR", os.path.join(_ROOT, ".mplconfig"))
for _p in [
    os.environ["HF_HOME"],
    os.environ["HF_DATASETS_CACHE"],
    os.environ["HF_HUB_CACHE"],
    os.environ["MPLCONFIGDIR"],
]:
    os.makedirs(_p, exist_ok=True)


def _process_mem_mb() -> float:
    if psutil is not None:
        try:
            return float(psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024))
        except Exception:
            pass
    # Linux fallback: /proc/self/statm current RSS pages.
    try:
        with open("/proc/self/statm", "r", encoding="utf-8") as f:
            parts = f.read().strip().split()
        if len(parts) >= 2:
            rss_pages = int(parts[1])
            page_size = os.sysconf("SC_PAGE_SIZE")
            return float(rss_pages * page_size / (1024 * 1024))
    except Exception:
        pass
    # Unix fallback: max RSS from resource (platform-dependent units).
    try:
        ru = float(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
        # Heuristic: macOS tends to return bytes, Linux tends KB.
        if ru > 10_000_000:
            return ru / (1024 * 1024)
        return ru / 1024.0
    except Exception:
        return 0.0
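The fallback chain in `_process_mem_mb` can be exercised without psutil. A minimal sketch of the two non-psutil paths, assuming a Unix-like platform (the `rss_mb` name is illustrative): `/proc/self/statm` gives the current RSS in pages on Linux, while `resource.getrusage` gives peak RSS in kilobytes on Linux but bytes on macOS, hence the unit heuristic.

```python
import os
import resource

# Sketch of the RSS fallbacks: current RSS from /proc/self/statm on Linux,
# else peak RSS from resource (KB on Linux, bytes on macOS).
def rss_mb() -> float:
    try:
        with open("/proc/self/statm") as f:
            pages = int(f.read().split()[1])   # field 2 = resident pages
        return pages * os.sysconf("SC_PAGE_SIZE") / (1024 * 1024)
    except OSError:
        ru = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        # Values above ~10 GB-as-KB are almost certainly bytes (macOS).
        return ru / (1024 * 1024) if ru > 10_000_000 else ru / 1024.0

print(rss_mb() > 0)  # True on any Unix-like system
```

Note the trade-off: the statm path reports *current* RSS (it can go down), while `ru_maxrss` is a high-water mark and never decreases within a process.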


# ── Shared loader ─────────────────────────────────────────────────────

def infer_model_type_from_checkpoint(ckpt_path: str) -> str:
    name = ckpt_path.lower()
    if "ablation_results/t" in name or "d3pm_cross_attention" in name:
        return "d3pm_cross_attention"
    if "d3pm_encoder_decoder" in name:
        return "d3pm_encoder_decoder"
    if "baseline_cross_attention" in name:
        return "baseline_cross_attention"
    if "baseline_encoder_decoder" in name:
        return "baseline_encoder_decoder"
    return CONFIG["model_type"]


def infer_include_negative_from_checkpoint(ckpt_path: str) -> bool:
    name = ckpt_path.lower()
    if "_neg_true" in name:
        return True
    if "_neg_false" in name:
        return False
    if "ablation_results/t" in name:
        return False
    return CONFIG["data"]["include_negative_examples"]

def load_everything(cfg, device, ckpt_override=None):
    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    candidates = [
        f"results7/{model_name}_neg_{has_neg}/best_model.pt",
        f"results/{model_name}_neg_{has_neg}/best_model.pt",
        f"results7/{model_name}_neg_True/best_model.pt",
        f"results/{model_name}_neg_True/best_model.pt",
        f"results7/{model_name}_neg_False/best_model.pt",
        f"results/{model_name}_neg_False/best_model.pt",
        "ablation_results/T4/best_model.pt",
        "ablation_results/T8/best_model.pt",
    ]
    ckpt = ckpt_override if ckpt_override else next((p for p in candidates if os.path.exists(p)), None)
    # Guard the None case: os.path.exists(None) would raise TypeError.
    if ckpt is None or not os.path.exists(ckpt):
        raise FileNotFoundError(f"No checkpoint found. Checked: {candidates}")
    model, cfg = load_model(ckpt, cfg, device)
    model.eval()
    src_tok = SanskritSourceTokenizer(
        vocab_size=cfg['model'].get('src_vocab_size', 500),
        max_len=cfg['model']['max_seq_len'])
    tgt_tok = SanskritTargetTokenizer(
        vocab_size=cfg['model'].get('tgt_vocab_size', 500),
        max_len=cfg['model']['max_seq_len'])
    return model, src_tok, tgt_tok, cfg


def load_val_data(cfg, src_tok, tgt_tok, n=500):
    """Load validation set as (src_tensors, ref_strings, input_strings)."""
    from data.dataset import OptimizedSanskritDataset
    from torch.utils.data import Subset
    from sklearn.model_selection import train_test_split

    dataset = OptimizedSanskritDataset(
        'train', max_len=cfg['model']['max_seq_len'],
        cfg=cfg, src_tokenizer=src_tok, tgt_tokenizer=tgt_tok)
    total = min(cfg['data']['dataset_size'], len(dataset))
    _, val_idx = train_test_split(list(range(total)), train_size=0.8, random_state=42)
    val_idx = val_idx[:n]

    src_list, ref_list, inp_list = [], [], []
    for i in val_idx:
        item = dataset[i]
        src_list.append(item['input_ids'].unsqueeze(0))
        ref_list.append(item['target_text'])
        inp_list.append(item['input_text'])
    return src_list, ref_list, inp_list


def _generate_ids_compat(model, src, num_steps=None, temperature=0.8, top_k=40,
                         repetition_penalty=1.2, diversity_penalty=0.0):
    kwargs = dict(temperature=temperature, top_k=top_k)
    if num_steps is not None:
        kwargs["num_steps"] = int(num_steps)
    if repetition_penalty is not None:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    if diversity_penalty is not None:
        kwargs["diversity_penalty"] = float(diversity_penalty)
    try:
        return model.generate(src, **kwargs)
    except TypeError:
        # Some model variants expose reduced generate() kwargs.
        slim = {k: kwargs[k] for k in ["temperature", "top_k", "num_steps"] if k in kwargs}
        try:
            return model.generate(src, **slim)
        except TypeError:
            return model.generate(src)


def _decode_ids(tgt_tok, out_ids, src_text=None, inf_cfg=None):
    ids = []
    for x in out_ids[0].tolist():
        # stop at PAD/SEP once decoding started
        if x in (1, 4) and ids:
            break
        if x > 4:
            ids.append(x)
    if src_text is not None and inf_cfg is not None:
        txt = _decode_with_cleanup(tgt_tok, ids, src_text, inf_cfg)
    else:
        txt = tgt_tok.decode(ids).strip()
    return txt, ids
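The token filtering in `_decode_ids` assumes ids 0 through 4 are reserved special tokens (the inline comment suggests PAD=1 and SEP=4 in this codebase): real tokens are kept, and decoding stops at the first PAD/SEP once at least one real token has been collected. A dependency-free sketch of just that filter (the `filter_ids` name is illustrative):

```python
# Sketch of the special-token filtering: ids 0-4 are reserved; stop at
# PAD (1) / SEP (4) only after at least one real token was collected, so
# leading special tokens don't terminate decoding prematurely.
def filter_ids(raw):
    ids = []
    for x in raw:
        if x in (1, 4) and ids:   # stop at PAD/SEP after decoding started
            break
        if x > 4:                 # keep only non-special token ids
            ids.append(x)
    return ids

print(filter_ids([2, 7, 9, 4, 8, 1]))  # [7, 9] — stops at the first SEP
print(filter_ids([4, 7, 1]))           # [7]    — a leading SEP is skipped
```

The `and ids` guard is the subtle part: without it, a sequence that opens with padding would decode to an empty string.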
|
| 202 |
+
|
| 203 |
+
def _cer(a: str, b: str) -> float:
    m, n = len(a), len(b)
    if m == 0 and n == 0:
        return 0.0
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            tmp = dp[j]
            dp[j] = prev if a[i-1] == b[j-1] else 1 + min(prev, dp[j], dp[j-1])
            prev = tmp
    return float(dp[n]) / max(1, m, n)


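`_cer` is Levenshtein edit distance with a single rolling row (O(len(b)) memory), normalized by the longer string. A standalone copy of the same recurrence with spot checks on known distances:

```python
def cer(a: str, b: str) -> float:
    # Levenshtein distance with O(len(b)) memory, normalized by max length.
    m, n = len(a), len(b)
    if m == 0 and n == 0:
        return 0.0
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i          # prev holds dp[i-1][j-1]
        for j in range(1, n + 1):
            tmp = dp[j]                 # dp[i-1][j] before overwrite
            dp[j] = prev if a[i-1] == b[j-1] else 1 + min(prev, dp[j], dp[j-1])
            prev = tmp
    return float(dp[n]) / max(1, m, n)

print(cer("abc", "abc"))  # 0.0 — identical strings
print(cer("abc", "abd"))  # one substitution over length 3
```

The `max(1, m, n)` denominator both avoids division by zero and caps the score at 1.0 for a fully mismatched pair.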
# ── Task 1 ────────────────────────────────────────────────────────────

def run_task1(model, src_tok, device):
    print("\n" + "="*65)
    print(" TASK 1 — KV Cache Benchmark")
    print("="*65)
    src_vocab = model.model.src_embed.token_emb.weight.shape[0]
    src_lens = [16, 32, 64]
    n_runs = 3
    has_cached = hasattr(model, "generate_cached")
    if not has_cached:
        print(" Compatibility mode: generate_cached() unavailable; running standard benchmark only.")

    def _timeit(fn, runs=n_runs):
        vals = []
        for _ in range(runs):
            t0 = time.perf_counter()
            fn()
            vals.append(time.perf_counter() - t0)
        return float(np.mean(vals))

    def _trace_peak_bytes(fn, repeat=8):
        gc.collect()
        tracemalloc.start()
        for _ in range(max(1, int(repeat))):
            fn()
        _, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        return int(peak)

    def _torch_cpu_mem_bytes(fn):
        try:
            from torch.profiler import profile, ProfilerActivity
            with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=False) as prof:
                fn()
            mem = 0
            for ev in prof.key_averages():
                try:
                    mem += max(0, int(getattr(ev, "self_cpu_memory_usage", 0)))
                except Exception:
                    pass
            return int(mem)
        except Exception:
            return 0

    results = {}
    for L in src_lens:
        src = torch.randint(5, src_vocab, (1, L), device=device)
        t_std = _timeit(lambda: _generate_ids_compat(model, src, temperature=0.8, top_k=40))

        if has_cached:
            t_cache = _timeit(
                lambda: model.generate_cached(
                    src, num_steps=64, temperature=0.8, top_k=40,
                    repetition_penalty=1.2, diversity_penalty=0.0
                )
            )
            speedup = t_std / max(t_cache, 1e-9)
        else:
            t_cache = t_std
            speedup = 1.0

        # Encoder cost estimate: one encode_source pass vs one cached step.
        if hasattr(model.model, "encode_source") and hasattr(model.model, "forward_cached"):
            memory, src_pad = model.model.encode_source(src)
            x = torch.full((1, L), model.model.mask_token_id, dtype=torch.long, device=device)
            t = torch.full((1,), max(0, model.model.scheduler.num_timesteps - 1), dtype=torch.long, device=device)
            t_enc = _timeit(lambda: model.model.encode_source(src))
            t_step = _timeit(lambda: model.model.forward_cached(memory, src_pad, x, t, x0_hint=None, inference_mode=True))
            encoder_pct = (t_enc / max(t_enc + t_step, 1e-9)) * 100.0
        else:
            encoder_pct = 0.0

        results[L] = dict(
            standard_s=t_std,
            cached_s=t_cache,
            speedup=speedup,
            encoder_pct=encoder_pct,
        )
        print(f" src_len={L:>3d} standard={t_std:.3f}s cached={t_cache:.3f}s speedup={speedup:.2f}x encoder%={encoder_pct:.1f}")

    # Memory profiling (GPU preferred, CPU/MPS fallback via process RSS delta).
    mem_note = "N/A"
    mem_red = None
    if torch.cuda.is_available() and str(device).startswith("cuda"):
        L = 64
        src = torch.randint(5, src_vocab, (1, L), device=device)
        torch.cuda.reset_peak_memory_stats(device)
        _ = _generate_ids_compat(model, src, temperature=0.8, top_k=40)
        m_std = torch.cuda.max_memory_allocated(device)
        torch.cuda.reset_peak_memory_stats(device)
        _ = model.generate_cached(src, num_steps=64, temperature=0.8, top_k=40,
                                  repetition_penalty=1.2, diversity_penalty=0.0)
        m_cache = torch.cuda.max_memory_allocated(device)
        mem_red = 100.0 * (m_std - m_cache) / max(m_std, 1)
        mem_note = f"GPU peak alloc reduction: {mem_red:.1f}% @ src_len=64"
        print(f" Memory reduction: {mem_note}")
    elif has_cached and _process_mem_mb() > 0.0:
        L = 64
        src = torch.randint(5, src_vocab, (1, L), device=device)

        def _peak_rss_while(fn, poll_s=0.01):
            done = {"v": False}
            peak = {"v": _process_mem_mb()}

            def _poll():
                while not done["v"]:
                    peak["v"] = max(peak["v"], _process_mem_mb())
                    time.sleep(poll_s)

            th = threading.Thread(target=_poll, daemon=True)
            gc.collect()
            base = _process_mem_mb()
            th.start()
            try:
                fn()
            finally:
                done["v"] = True
                th.join(timeout=0.1)
            gc.collect()
            return base, peak["v"], max(0.0, peak["v"] - base)

        b_std, p_std, d_std = _peak_rss_while(
            lambda: _generate_ids_compat(model, src, temperature=0.8, top_k=40)
        )
        b_c, p_c, d_c = _peak_rss_while(
            lambda: model.generate_cached(
                src, num_steps=64, temperature=0.8, top_k=40,
                repetition_penalty=1.2, diversity_penalty=0.0
            )
        )
        if d_std > 0.0:
            mem_red = 100.0 * (d_std - d_c) / d_std
            mem_note = (
                f"RSS peak reduction: {mem_red:.1f}% @ src_len=64 "
                f"(std_peak={p_std:.1f}MB, cache_peak={p_c:.1f}MB)"
            )
        else:
            # Secondary fallback: Python allocator peak (always available).
            peak_std = _trace_peak_bytes(
                lambda: _generate_ids_compat(model, src, temperature=0.8, top_k=40), repeat=10
            )
            peak_cache = _trace_peak_bytes(
                lambda: model.generate_cached(src, num_steps=64, temperature=0.8, top_k=40,
                                              repetition_penalty=1.2, diversity_penalty=0.0),
                repeat=10
            )
            if peak_std >= 256 * 1024:
                mem_red = 100.0 * (peak_std - peak_cache) / peak_std
                mem_note = (
                    f"Py alloc peak reduction: {mem_red:.1f}% @ src_len=64 "
                    f"(std={peak_std/1024**2:.1f}MB, cache={peak_cache/1024**2:.1f}MB)"
                )
            else:
                cpu_std = _torch_cpu_mem_bytes(
                    lambda: _generate_ids_compat(model, src, temperature=0.8, top_k=40)
                )
                cpu_cache = _torch_cpu_mem_bytes(
                    lambda: model.generate_cached(src, num_steps=64, temperature=0.8, top_k=40,
                                                  repetition_penalty=1.2, diversity_penalty=0.0)
                )
                if cpu_std > 0:
                    mem_red = 100.0 * (cpu_std - cpu_cache) / max(cpu_std, 1)
                    mem_note = (
                        f"Torch CPU mem-event reduction: {mem_red:.1f}% @ src_len=64 "
                        f"(std={cpu_std/1024**2:.1f}MB, cache={cpu_cache/1024**2:.1f}MB)"
                    )
                else:
                    mem_note = "Memory estimate unavailable (RSS/tracemalloc/torch-profiler flat)"
        print(f" Memory reduction: {mem_note}")
    elif has_cached:
        # Final fallback (CPU-safe): Python allocation peak via tracemalloc.
        # This does not include all native tensor allocator memory, but still
        # gives a consistent relative signal when psutil/CUDA stats are absent.
        L = 64
        src = torch.randint(5, src_vocab, (1, L), device=device)
        peak_std = _trace_peak_bytes(
            lambda: _generate_ids_compat(model, src, temperature=0.8, top_k=40), repeat=10
        )
        peak_cache = _trace_peak_bytes(
            lambda: model.generate_cached(src, num_steps=64, temperature=0.8, top_k=40,
                                          repetition_penalty=1.2, diversity_penalty=0.0),
            repeat=10
        )
        # Ignore extremely small peaks; they are noise for tensor-heavy paths.
        if peak_std >= 256 * 1024:
            mem_red = 100.0 * (peak_std - peak_cache) / peak_std
            mem_note = (
                f"Py alloc peak reduction: {mem_red:.1f}% @ src_len=64 "
                f"(std={peak_std/1024**2:.1f}MB, cache={peak_cache/1024**2:.1f}MB)"
            )
        else:
            cpu_std = _torch_cpu_mem_bytes(
                lambda: _generate_ids_compat(model, src, temperature=0.8, top_k=40)
            )
            cpu_cache = _torch_cpu_mem_bytes(
                lambda: model.generate_cached(src, num_steps=64, temperature=0.8, top_k=40,
                                              repetition_penalty=1.2, diversity_penalty=0.0)
            )
            if cpu_std > 0:
                mem_red = 100.0 * (cpu_std - cpu_cache) / max(cpu_std, 1)
                mem_note = (
                    f"Torch CPU mem-event reduction: {mem_red:.1f}% @ src_len=64 "
                    f"(std={cpu_std/1024**2:.1f}MB, cache={cpu_cache/1024**2:.1f}MB)"
                )
            else:
                mem_note = "Py alloc peak too small/noisy to estimate (no psutil/CUDA profiler)"
        print(f" Memory reduction: {mem_note}")
    else:
        mem_note = "Profiler unavailable (cached path missing)"

    # Subtask graphs
    lens = sorted(results.keys())
    std_vals = [results[L]["standard_s"] for L in lens]
    cache_vals = [results[L]["cached_s"] for L in lens]
    speed_vals = [results[L]["speedup"] for L in lens]
    enc_vals = [results[L]["encoder_pct"] for L in lens]

    plt.figure(figsize=(7, 4))
    plt.plot(lens, std_vals, marker="o", label="standard")
    plt.plot(lens, cache_vals, marker="o", label="cached")
    plt.xlabel("Source length")
    plt.ylabel("Time (s)")
    plt.title("Task1: Generation Time (Standard vs Cached)")
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task1_time_comparison.png"), dpi=150, bbox_inches="tight")
    plt.close()

    plt.figure(figsize=(7, 4))
    plt.plot(lens, speed_vals, marker="o")
    plt.xlabel("Source length")
    plt.ylabel("Speedup (x)")
    plt.title("Task1: KV-Cache Speedup")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task1_speedup.png"), dpi=150, bbox_inches="tight")
    plt.close()

    plt.figure(figsize=(7, 4))
    plt.plot(lens, enc_vals, marker="o")
    plt.xlabel("Source length")
    plt.ylabel("Encoder cost (%)")
    plt.title("Task1: Encoder Cost Share")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task1_encoder_cost.png"), dpi=150, bbox_inches="tight")
    plt.close()

    path = os.path.join(OUTPUT_DIR, "task1_kv_cache.txt")
    with open(path, "w") as f:
        f.write("TASK 1 — KV CACHE BENCHMARK\n" + "="*40 + "\n\n")
        f.write(f"has_generate_cached={has_cached}\n")
        f.write(f"memory_profile={mem_note}\n\n")
        f.write(f"{'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
                f"{'speedup':>8} {'encoder%':>9}\n")
        for src_len, r in results.items():
            f.write(f"{src_len:>8} {r['standard_s']:>12.3f} {r['cached_s']:>10.3f} "
                    f"{r['speedup']:>7.2f}x {r['encoder_pct']:>8.1f}%\n")
        f.write("\nSaved graphs:\n")
        f.write(" - task1_time_comparison.png\n")
        f.write(" - task1_speedup.png\n")
        f.write(" - task1_encoder_cost.png\n")
    print(f" Saved: {path}")


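Task 1's CPU-safe fallback measures peak Python-allocator usage with `tracemalloc`, the same start/loop/peak/stop shape as `_trace_peak_bytes` above. The pattern in isolation, with an arbitrary 1 MB allocation standing in for a generation call:

```python
import gc
import tracemalloc

def trace_peak_bytes(fn, repeat=3):
    # Peak Python-allocator usage observed across `repeat` calls of fn.
    gc.collect()                     # settle before tracing
    tracemalloc.start()
    for _ in range(max(1, int(repeat))):
        fn()
    _, peak = tracemalloc.get_traced_memory()   # (current, peak) in bytes
    tracemalloc.stop()
    return int(peak)

peak = trace_peak_bytes(lambda: bytearray(1_000_000))
print(peak >= 1_000_000)  # the 1 MB buffer shows up in the traced peak
```

As the original comments note, this sees only Python-level allocations, not native tensor-allocator memory, so it is a relative signal rather than an absolute measurement.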
# ── Task 2 ────────────────────────────────────────────────────────────

def run_task2(model, src_tok, tgt_tok, device, input_text, cfg, corpus_inputs=None):
    print("\n" + "="*65)
    print(" TASK 2 — Attention Visualization + Semantic Drift")
    print("="*65)
    print(f" Input: {input_text}")
    if not hasattr(model.model, 'encode_source'):
        print(" Compatibility mode: attention hooks unavailable; running semantic-drift-only analysis.")
        src_ids = src_tok.encode(input_text)
        src = torch.tensor([src_ids], dtype=torch.long, device=device)
        # Keep steps <= scheduler horizon for this checkpoint to avoid backend aborts.
        t_sched = int(getattr(getattr(model.model, "scheduler", object()), "num_timesteps", 64))
        # Stability guard for some checkpoints/backends: keep sweep moderate.
        t_max = min(t_sched, 64)
        candidates = [t_max, 48, 32, 24, 16, 8, 4, 1]
        step_list = []
        seen = set()
        for s in candidates:
            s = max(1, min(int(s), t_max))
            if s not in seen:
                step_list.append(s)
                seen.add(s)
        outs = {}
        for s in step_list:
            out = _generate_ids_compat(model, src, num_steps=s, temperature=0.8, top_k=40)
            txt, _ = _decode_ids(
                tgt_tok, out,
                src_text=input_text,
                inf_cfg=cfg.get("inference", {"temperature": 0.8, "top_k": 40})
            )
            outs[s] = txt
        final = outs[1]
        drift = [(_cer(outs[s], final), s) for s in step_list]
        # Plot drift
        xs = [s for _, s in drift]
        ys = [c for c, _ in drift]
        plt.figure(figsize=(8, 4))
        plt.plot(xs, ys, marker='o')
        plt.gca().invert_xaxis()
        plt.xlabel("Generation steps")
        plt.ylabel("CER to 1-step output")
        plt.title("Task2 Semantic Drift (Compatibility Mode)")
        plt.tight_layout()
        plot_path = os.path.join(OUTPUT_DIR, "task2_semantic_drift.png")
        plt.savefig(plot_path, dpi=150, bbox_inches="tight")
        plt.close()
        report = os.path.join(OUTPUT_DIR, "task2_report.txt")
        with open(report, "w", encoding="utf-8") as f:
            f.write("TASK 2 — COMPATIBILITY REPORT\n")
            f.write("="*40 + "\n")
            f.write("Cross-attention capture unavailable for this checkpoint.\n")
            f.write(f"Input: {input_text}\n")
            f.write(f"Reference final (1 step): {final}\n\n")
            for cer_v, s in drift:
                f.write(f"steps={s:>3d} CER_to_final={cer_v:.4f} output={outs[s][:120]}\n")
        print(f" Output(final@1): {final}")
        print(f" Report: {report}")
        print(f" Saved: {plot_path}")
        return

    src_ids = src_tok.encode(input_text)
    src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)

    from analysis.attention_viz import (
        AttentionCapture,
        compute_trajectory_metrics,
        analyze_token_stability,
        tfidf_attention_correlation,
    )

    # Attention capture
    print(" Capturing attention weights...")
    capturer = AttentionCapture(model)
    step_weights, step_outputs_ids = capturer.run(src_tensor)

    def _decode_tensor_ids(t):
        out = []
        for x in t[0].tolist():
            if x in (1, 4) and out:
                break
            if x > 4:
                out.append(x)
        raw_txt = tgt_tok.decode(out).strip()
        clean_txt = _decode_with_cleanup(
            tgt_tok, out, input_text, cfg.get("inference", {"temperature": 0.8, "top_k": 40})
        )
        return raw_txt, clean_txt, out

    decoded = {}
    decoded_raw = {}
    for t_val, ids_t in step_outputs_ids.items():
        raw_txt, clean_txt, ids = _decode_tensor_ids(ids_t)
        decoded_raw[t_val] = (raw_txt, ids)
        decoded[t_val] = (clean_txt, ids)
    final_step = min(decoded.keys())
    final_out, final_ids = decoded[final_step]
    final_out_raw = decoded_raw[final_step][0]
    src_labels = []
    for sid in src_ids[:20]:
        tok = src_tok.decode([sid]).strip()
        src_labels.append(tok if tok else f"id{sid}")
    tgt_labels = [f"y{i}" for i in range(min(20, len(final_ids)))]
    print(f" Output: {final_out}")

    # Heatmap t=max, layer 0
    first_t = max(step_weights.keys())
    w_first = step_weights[first_t][0][0]
    w0 = step_weights[0][0][0]
    n_src = min(len(src_labels), w_first.shape[1], 20)
    n_tgt = min(len(tgt_labels), w_first.shape[0], 20)
    plt.figure(figsize=(max(8, n_src * 0.35), max(6, n_tgt * 0.3)))
    plt.imshow(w_first[:n_tgt, :n_src], aspect="auto", cmap="YlOrRd")
    plt.xticks(range(n_src), src_labels[:n_src], rotation=45, ha="right", fontsize=8)
    plt.yticks(range(n_tgt), tgt_labels[:n_tgt], fontsize=8)
    plt.title(f"Attention t={first_t} Layer 0")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, f"task2_attn_t{first_t}.png"), dpi=150, bbox_inches="tight")
    plt.close()

    plt.figure(figsize=(max(8, n_src * 0.35), max(6, n_tgt * 0.3)))
    plt.imshow(w0[:n_tgt, :n_src], aspect="auto", cmap="YlOrRd")
    plt.xticks(range(n_src), src_labels[:n_src], rotation=45, ha="right", fontsize=8)
    plt.yticks(range(n_tgt), tgt_labels[:n_tgt], fontsize=8)
    plt.title("Attention t=0 Layer 0")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task2_attn_t0.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # All layers at t=0
    layers = step_weights[0]
    n_layers = len(layers)
    n_cols = min(4, n_layers)
    n_rows = (n_layers + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3.2))
    axes = np.array(axes).reshape(-1)
    for i, layer_w in enumerate(layers):
        ax = axes[i]
        w = layer_w[0][:n_tgt, :n_src]
        ax.imshow(w, aspect="auto", cmap="YlOrRd")
        ax.set_title(f"Layer {i}", fontsize=9)
        ax.set_xticks([])
        ax.set_yticks([])
    for i in range(n_layers, len(axes)):
        axes[i].axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task2_all_layers_t0.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Attention evolution for src[0] -> tgt[0]
    t_vals_desc = sorted(step_weights.keys(), reverse=True)
    evo = []
    for t_val in t_vals_desc:
        w = step_weights[t_val][0][0]
        evo.append(float(w[0, 0]) if w.shape[0] > 0 and w.shape[1] > 0 else 0.0)
    plt.figure(figsize=(10, 3.5))
    plt.plot(range(len(t_vals_desc)), evo, marker="o")
    plt.xlabel("Captured step index (T→0)")
    plt.ylabel("Attention weight")
    plt.title("Attention Evolution (src0→tgt0)")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task2_attn_evolution.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Drift (CER to final across steps) on RAW decoded trajectory to expose true diffusion.
    t_vals = sorted(decoded.keys(), reverse=True)
    cer_vals = [_cer(decoded_raw[t][0], final_out_raw) for t in t_vals]
    plt.figure(figsize=(8, 4))
    plt.plot(t_vals, cer_vals, marker="o")
    plt.gca().invert_xaxis()
    plt.xlabel("Diffusion step")
    plt.ylabel("CER to final")
    plt.title("Task2 Semantic Drift")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task2_semantic_drift.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Source alignment proxy (avg attention on source positions at t=0, last layer)
    last_layer_t0 = step_weights[0][-1][0]
    src_align = last_layer_t0.mean(axis=0)[:n_src]
    plt.figure(figsize=(8, 3))
    plt.bar(np.arange(len(src_align)), src_align)
    plt.xticks(range(n_src), src_labels[:n_src], rotation=45, ha="right", fontsize=8)
    plt.title("Source Alignment Importance (t=0, last layer)")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task2_source_alignment.png"), dpi=150, bbox_inches="tight")
    plt.close()

    stability = analyze_token_stability(step_weights)
    n_locked = sum(1 for v in stability.values() if v == "LOCKED")
    n_flex = sum(1 for v in stability.values() if v == "FLEXIBLE")
    tfidf_info = tfidf_attention_correlation(input_text, step_weights, corpus_texts=corpus_inputs)
    tfidf_corr = tfidf_info.get("corr")
    tfidf_status = tfidf_info.get("status", "UNKNOWN")
    traj = compute_trajectory_metrics(
        step_outputs_ids,
        tgt_tok,
        reference_text=_iast_to_deva(input_text),
    )
    # Keep trajectory semantic scoring on raw decoded text to avoid masking drift.
    ref_text = _iast_to_deva(input_text)
    for row in traj:
        t_cur = row["step"]
        raw_txt = decoded_raw.get(t_cur, ("", []))[0]
        if raw_txt:
            sim = max(0.0, 1.0 - _cer(raw_txt, ref_text))
            row["text"] = raw_txt
            row["bert"] = sim
            row["drift"] = 1.0 - sim

    # TF-IDF vs attention graph (subtask visualization)
    tfidf_vec = np.asarray(tfidf_info.get("tfidf_scores", []), dtype=np.float32)
    attn_vec = np.asarray(tfidf_info.get("attn_scores", []), dtype=np.float32)
    labels = list(tfidf_info.get("tokens", []))
    m = min(len(tfidf_vec), len(attn_vec), len(labels), 20)
    if m > 0:
        x = np.arange(m)
        plt.figure(figsize=(8, 3.5))
        tf_part = tfidf_vec[:m]
        at_part = attn_vec[:m]
        tf_norm = tf_part / (np.max(np.abs(tf_part)) + 1e-9)
        at_norm = at_part / (np.max(np.abs(at_part)) + 1e-9)
        w = 0.4
        plt.bar(x - w/2, tf_norm, width=w, label="tfidf(norm)")
        plt.bar(x + w/2, at_norm, width=w, label="attn(norm)")
        plt.xlabel("Source token")
        plt.ylabel("Normalized score")
        plt.title("Task2: TF-IDF vs Attention Stability")
        plt.xticks(x, labels[:m], rotation=45, ha="right", fontsize=8)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, "task2_tfidf_vs_attention.png"), dpi=150, bbox_inches="tight")
        plt.close()

    lock_in_t = next((t for t, c in zip(t_vals[::-1], cer_vals[::-1]) if c <= 0.05), t_vals[-1])
    if tfidf_corr is not None and abs(float(tfidf_corr)) < 0.10:
        tfidf_status = "WEAK"
    has_semantic = any(float(r.get("bert", 0.0)) > 0.05 for r in traj)
    # Degeneracy score on final output
    toks = [t for t in final_out.split() if t]
    uniq_ratio = len(set(toks)) / max(1, len(toks))
    degenerate = (len(toks) >= 8 and uniq_ratio < 0.35)

    # Small multi-sample stability check (prevents overclaim from one example)
    multi_scores = []
    if corpus_inputs:
        sample_texts = [s for s in corpus_inputs[:8] if isinstance(s, str) and s.strip()]
        for txt in sample_texts:
            src_i = torch.tensor([src_tok.encode(txt)], dtype=torch.long, device=device)
            out_i = _generate_ids_compat(model, src_i, num_steps=min(16, cfg.get("inference", {}).get("num_steps", 16)),
                                         temperature=0.8, top_k=40)
            pred_i, _ = _decode_ids(tgt_tok, out_i)
            multi_scores.append(max(0.0, 1.0 - _cer(pred_i, _iast_to_deva(txt))))
    multi_sem = float(np.mean(multi_scores)) if multi_scores else 0.0

    quality_status = (
        "VALID"
        if len(final_out.strip()) > 0 and n_flex + n_locked > 0 and has_semantic and not degenerate and multi_sem >= 0.05
        else "WEAK"
    )
    report = os.path.join(OUTPUT_DIR, "task2_report.txt")
    with open(report, "w", encoding="utf-8") as f:
        f.write("TASK 2 — ATTENTION + DRIFT REPORT\n" + "=" * 50 + "\n\n")
        f.write(f"Input : {input_text}\n")
        f.write(f"Output: {final_out}\n\n")
        f.write(f"Captured steps: {len(t_vals)}\n")
        f.write(f"Analysis quality: {quality_status}\n")
        f.write(f"Final output uniq-ratio: {uniq_ratio:.3f}\n")
        f.write(f"Degenerate output: {'YES' if degenerate else 'NO'}\n")
        f.write(f"Multi-sample semantic score (n<={len(multi_scores)}): {multi_sem:.4f}\n")
        f.write(f"Lock-in step (CER<=0.05): t={lock_in_t}\n")
        f.write(f"Locked tokens: {n_locked} Flexible tokens: {n_flex}\n")
        corr_txt = f"{tfidf_corr:.4f}" if tfidf_corr is not None else "N/A"
        f.write(f"TF-IDF vs attention stability corr: {corr_txt}\n")
        f.write(f"TF-IDF status: {tfidf_status}\n\n")
        f.write("Saved graphs:\n")
        f.write(" - task2_attn_t*.png / task2_all_layers_t0.png\n")
        f.write(" - task2_attn_evolution.png\n")
        f.write(" - task2_semantic_drift.png\n")
        f.write(" - task2_source_alignment.png\n")
        f.write(" - task2_tfidf_vs_attention.png\n\n")
        f.write("Step trajectory (first 10 rows)\n")
        f.write("-" * 60 + "\n")
        for row in traj[:10]:
            f.write(f"t={row['step']:>3d} bert={row['bert']:.4f} drift={row['drift']:.4f} text={row['text'][:60]}\n")

    print(f" Lock-in timestep: t={lock_in_t}")
    print(f" Locked/Flexible: {n_locked}/{n_flex}")
    corr_txt = f"{tfidf_corr:.4f}" if tfidf_corr is not None else "N/A"
    print(f" TF-IDF corr: {corr_txt} ({tfidf_status})")
    print(f" Report: {report}")


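run_task2's degeneracy check flags long outputs dominated by repeated tokens (unique-token ratio below 0.35 on 8+ tokens). The same heuristic as a standalone function on plain strings (the function name is illustrative):

```python
def is_degenerate(text, min_tokens=8, min_uniq_ratio=0.35):
    # Flag long outputs whose unique-token ratio collapses (repetition loops).
    toks = [t for t in text.split() if t]
    uniq_ratio = len(set(toks)) / max(1, len(toks))
    return len(toks) >= min_tokens and uniq_ratio < min_uniq_ratio

print(is_degenerate("la la la la la la la la la"))  # 9 tokens, 1 unique -> True
print(is_degenerate("a b c d e f g h i"))           # all unique -> False
print(is_degenerate("la la la"))                    # too short to judge -> False
```

The `min_tokens` floor keeps short outputs from being flagged: a three-word answer can legitimately repeat a word without being a repetition loop.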
# ── Task 3 ────────────────────────────────────────────────────────────

def run_task3(model, src_tok, tgt_tok, device, src_list, ref_list, n_samples=500):
    print("\n" + "="*65)
    print(" TASK 3 — Concept Vectors + PCA Steering")
    print("="*65)
    if not hasattr(model.model, 'encode_source'):
        print(" Compatibility mode: using output-token statistics for PCA concept proxy.")
        # Keep compatibility run lightweight/stable on constrained backends.
        n = min(60, len(src_list))
        feats, lens = [], []
        for i, src in enumerate(src_list[:n]):
            out = _generate_ids_compat(model, src.to(device), num_steps=8, temperature=0.8, top_k=40)
            txt, ids = _decode_ids(tgt_tok, out)
            arr = np.array(ids[:64] + [0] * max(0, 64 - len(ids[:64])), dtype=np.float32)
            feats.append(arr)
            lens.append(len(txt))
        from sklearn.decomposition import PCA
        X = np.stack(feats)
        pca = PCA(n_components=min(10, X.shape[0]-1, X.shape[1]))
        Z = pca.fit_transform(X)
        plt.figure(figsize=(6, 5))
        sc = plt.scatter(Z[:, 0], Z[:, 1] if Z.shape[1] > 1 else np.zeros_like(Z[:, 0]),
                         c=lens, cmap="viridis", s=14)
        plt.colorbar(sc, label="Output length")
        plt.title("Task3 Concept Proxy Space (Compatibility Mode)")
        plt.tight_layout()
        img = os.path.join(OUTPUT_DIR, "task3_concept_space.png")
        plt.savefig(img, dpi=150, bbox_inches="tight")
        plt.close()
        rep = os.path.join(OUTPUT_DIR, "task3_report.txt")
        corr = float(np.corrcoef(Z[:, 0], np.array(lens))[0, 1]) if len(lens) > 2 else 0.0
        with open(rep, "w", encoding="utf-8") as f:
            f.write("TASK 3 — COMPATIBILITY REPORT\n")
            f.write("="*40 + "\n")
            f.write("Hidden-state capture unavailable; used output-token vector proxy.\n")
            f.write(f"Samples: {n}\n")
            f.write(f"PC1-length correlation: {corr:.4f}\n")
        print(f" Saved: {img}")
        print(f" Report: {rep}")
        return

    from analysis.concept_vectors import (
        collect_hidden_states, fit_pca, find_diversity_direction, generate_diversity_spectrum
    )

    # Collect hidden states from val set
    n = min(max(1, int(n_samples)), len(src_list))
    print(f" Collecting hidden states from {n} examples...")
    hidden, texts, lengths = collect_hidden_states(
        model, src_list[:n], tgt_tok, t_capture=0, max_samples=n
    )

    # Fit PCA + find diversity direction
    pca = fit_pca(hidden, n_components=min(50, n-1))
    direction = find_diversity_direction(hidden, lengths, pca)
    proj = pca.transform(hidden)
    corr = float(np.corrcoef(proj[:, 0], np.array(lengths))[0, 1]) if len(lengths) > 2 else 0.0
    if not np.isfinite(corr):
        corr = 0.0
    best_pc = 0

    # Plot concept space
    plt.figure(figsize=(8, 6))
    sc = plt.scatter(proj[:, 0], proj[:, 1] if proj.shape[1] > 1 else np.zeros_like(proj[:, 0]),
                     c=lengths, cmap="viridis", s=14)
    plt.colorbar(sc, label="Output diversity proxy")
    plt.title("Task3 Concept Space")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task3_concept_space.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Subtask graph: explained variance by PCA components
    ev = pca.explained_variance_ratio_
    k = min(20, len(ev))
    plt.figure(figsize=(8, 3.5))
    plt.bar(np.arange(k), ev[:k])
    plt.xlabel("PC index")
    plt.ylabel("Explained variance ratio")
    plt.title("Task3: PCA Explained Variance (Top Components)")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task3_pca_explained_variance.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Generate diversity spectrum on multiple seeds for more stable conclusions
    seed_k = min(5, len(src_list))
    uniq_list = []
    sem_list = []
    all_spectra = []
    for i in range(seed_k):
        src_i = src_list[i]
        spec_i = generate_diversity_spectrum(
            model, src_i.to(device), direction, tgt_tok,
            alphas=[-2.0, -1.0, 0.0, 1.0, 2.0]
        )
        all_spectra.append(spec_i)
        spec_items = sorted(spec_i.items())
        spec_texts = [t for _, t in spec_items]
        uniq_list.append(len(set(spec_texts)) / max(1, len(spec_texts)))
        pivot = spec_texts[2] if len(spec_texts) >= 3 else (spec_texts[0] if spec_texts else "")
        sims = [SequenceMatcher(None, txt, pivot).ratio() for txt in spec_texts if txt]
        sem_list.append(float(np.mean(sims)) if sims else 0.0)
    uniq_ratio = float(np.mean(uniq_list)) if uniq_list else 0.0
    semantic_stability = float(np.mean(sem_list)) if sem_list else 0.0
    steering_valid = (abs(corr) >= 0.20) and (uniq_ratio >= 0.55) and (semantic_stability >= 0.40)
    # use first seed spectrum for visualization table
    spectrum = all_spectra[0] if all_spectra else {}

    # Subtask graph: alpha vs decoded length
    a_vals = sorted(spectrum.keys())
    l_vals = [len(spectrum[a]) for a in a_vals] if spectrum else []
    plt.figure(figsize=(7, 3.5))
    plt.plot(a_vals, l_vals, marker="o")
    plt.xlabel("Steering alpha")
    plt.ylabel("Output length")
    plt.title("Task3: Diversity Steering Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task3_diversity_curve.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Save diversity direction + results
    np.save(os.path.join(OUTPUT_DIR, "task3_diversity_direction.npy"), direction)

    report = os.path.join(OUTPUT_DIR, "task3_report.txt")
    with open(report, "w", encoding="utf-8") as f:
        f.write("TASK 3 — CONCEPT VECTORS + PCA STEERING\n" + "="*50 + "\n\n")
        f.write(f"PCA: {pca.n_components_} components, "
                f"{pca.explained_variance_ratio_.sum()*100:.1f}% variance\n")
        f.write(f"Diversity PC: {best_pc} (|r|={corr:.3f} with diversity proxy)\n\n")
        f.write(f"Direction validity: {'VALID' if steering_valid else 'WEAK'}\n")
        f.write(f"Spectrum unique ratio (mean over {seed_k} seeds): {uniq_ratio:.3f}\n")
        f.write(f"Spectrum semantic stability (mean over {seed_k} seeds): {semantic_stability:.3f}\n\n")
        f.write("Saved graphs:\n")
        f.write(" - task3_concept_space.png\n")
        f.write(" - task3_pca_explained_variance.png\n")
        f.write(" - task3_diversity_curve.png\n\n")
        f.write("Diversity spectrum:\n")
        for alpha, text in sorted(spectrum.items()):
|
| 913 |
+
f.write(f" alpha={alpha:+.1f} → {text}\n")
|
| 914 |
+
print(f" Report: {report}")
|
| 915 |
+
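The spectrum-validity metrics used above (unique ratio plus `SequenceMatcher`-based stability against the α=0 pivot) can be sanity-checked in isolation. A minimal sketch on a toy spectrum; the decoded texts below are made-up stand-ins, not model outputs:

```python
from difflib import SequenceMatcher
from statistics import mean

# Toy diversity spectrum: steering alpha -> decoded text (invented outputs).
spectrum = {
    -2.0: "dharma protects the protector",
    -1.0: "dharma protects its protector",
    0.0: "dharma protects those who protect it",
    1.0: "the one who guards dharma is guarded in turn",
    2.0: "who guards dharma, him dharma guards",
}

texts = [t for _, t in sorted(spectrum.items())]

# Unique ratio: fraction of distinct outputs across the alpha sweep.
uniq_ratio = len(set(texts)) / max(1, len(texts))

# Semantic stability: mean character-overlap similarity to the alpha=0
# pivot (index 2 after sorting by alpha), mirroring run_task3.
pivot = texts[2]
semantic_stability = mean(SequenceMatcher(None, t, pivot).ratio() for t in texts if t)

print(f"unique ratio: {uniq_ratio:.2f}")  # 1.00 — all five outputs distinct
print(f"semantic stability: {semantic_stability:.2f}")
```

A fully collapsed model would score a unique ratio near 0.2 here (one text repeated five times) with stability 1.0, which is why the validity check requires both metrics to pass.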


# ── Task 4 ────────────────────────────────────────────────────────────

def run_task4(phase, model, src_tok, tgt_tok, device, cfg,
              src_list, ref_list, n_samples=200):
    print("\n" + "="*65)
    print(f" TASK 4 — Step Ablation (phase={phase})")
    print("="*65)

    import analysis.step_ablation as step_ablation

    # Legacy API
    has_legacy = all(hasattr(step_ablation, fn) for fn in [
        "generate_ablation_configs", "run_ablation_analysis", "plot_ablation_3d"
    ])

    # New API
    has_new = hasattr(step_ablation, "run_task4")

    if phase == "generate_configs":
        if has_legacy:
            print(" Generating ablation configs...")
            step_ablation.generate_ablation_configs(output_dir="ablation_configs")
            print("\n NEXT STEPS:")
            print("   1. bash ablation_configs/train_all.sh")
            print("   2. python analysis/run_analysis.py --task 4 --phase analyze")
            return
        print(" This step_ablation version does not expose config generation helpers.")
        print(" Use your latest ablation training script/config pipeline directly.")
        return

    if phase == "analyze":
        existing = [T for T in [4, 8, 16, 32, 64]
                    if os.path.exists(f"ablation_results/T{T}/best_model.pt")]
        only_t = os.environ.get("TASK4_ONLY_T")
        if only_t and only_t.isdigit():
            t_req = int(only_t)
            existing = [T for T in existing if T == t_req]
        if not existing:
            print(" No ablation models found at ablation_results/T*/best_model.pt")
            return
        print(f" Found models for T={existing}")

        if has_legacy:
            results = step_ablation.run_ablation_analysis(
                ablation_dir="ablation_results", base_cfg=cfg,
                src_list=src_list[:200], ref_list=ref_list[:200],
                tgt_tokenizer=tgt_tok, device=device,
                output_dir=OUTPUT_DIR)
            step_ablation.plot_ablation_3d(
                results, save_path=os.path.join(OUTPUT_DIR, "task4_ablation_3d.png"))
        elif has_new:
            from inference import load_model as _load_model
            models = {}
            for T in existing:
                ckpt = f"ablation_results/T{T}/best_model.pt"
                cfg_t = copy.deepcopy(cfg)
                cfg_t["model"]["diffusion_steps"] = T
                cfg_t["inference"]["num_steps"] = T
                m_t, _ = _load_model(ckpt, cfg_t, device)
                m_t.eval()
                models[T] = m_t
            knee_t = step_ablation.run_task4(
                models, src_list[:n_samples], ref_list[:n_samples], tgt_tok,
                output_dir=OUTPUT_DIR, n_samples=n_samples)
            print(f" New pipeline suggested optimal T={knee_t}")
        else:
            print(" Unsupported step_ablation API; please sync analysis/step_ablation.py")
            return

        # Optional adversarial robustness (legacy helper only)
        if hasattr(step_ablation, "run_adversarial_test"):
            print("\n Running adversarial robustness test...")
            inp_texts = [src_tok.decode([x for x in s[0].tolist() if x > 4])
                         for s in src_list[:50]]
            step_ablation.run_adversarial_test(
                model, src_tok, tgt_tok,
                test_inputs=inp_texts, test_refs=ref_list[:50],
                device=device, output_dir=OUTPUT_DIR)

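The `has_legacy` / `has_new` dispatch above is plain `hasattr` feature detection against whichever `step_ablation` module is checked out. A minimal sketch of the same pattern on a stub module (the stub and `detect_api` helper are illustrative, not part of the codebase):

```python
import types

# Legacy entry points probed by run_task4.
LEGACY_FNS = ["generate_ablation_configs", "run_ablation_analysis", "plot_ablation_3d"]

def detect_api(mod):
    """Classify a step_ablation-like module as 'legacy', 'new', or 'unsupported'."""
    if all(hasattr(mod, fn) for fn in LEGACY_FNS):
        return "legacy"
    if hasattr(mod, "run_task4"):
        return "new"
    return "unsupported"

# Stub module exposing only the new-style entry point.
stub = types.ModuleType("step_ablation_stub")
stub.run_task4 = lambda *args, **kwargs: None

print(detect_api(stub))                       # new
print(detect_api(types.ModuleType("empty")))  # unsupported
```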
# ── Task 5 ────────────────────────────────────────────────────────────

def run_task5(model, src_tok, tgt_tok, device, cfg, src_list, ref_list, task5_samples=500):
    print("\n" + "="*65)
    print(" TASK 5 — Classifier-Free Guidance")
    print("="*65)
    if not hasattr(model.model, 'encode_source'):
        print(" Compatibility mode: classifier-guidance unavailable; sweeping decoding controls.")
        n = min(100, int(task5_samples), len(src_list), len(ref_list))
        lambdas = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
        results = []
        for lam in lambdas:
            rep_pen = 1.0 + 0.15 * lam
            cer_vals, uniq_vals = [], []
            for src, ref in zip(src_list[:n], ref_list[:n]):
                out = _generate_ids_compat(
                    model, src.to(device), num_steps=8, temperature=0.8, top_k=40,
                    repetition_penalty=rep_pen, diversity_penalty=0.0
                )
                txt, ids = _decode_ids(tgt_tok, out)
                cer_vals.append(_cer(txt, ref))
                uniq_vals.append(len(set(ids)) / max(1, len(ids)))
            results.append((lam, float(np.mean(cer_vals)), float(np.mean(uniq_vals))))
            print(f"  λ={lam:.1f}  CER={results[-1][1]:.4f}  diversity={results[-1][2]:.3f}")
        # Subtask graph: quality-diversity tradeoff
        x = [r[1] for r in results]
        y = [r[2] for r in results]
        labels = [r[0] for r in results]
        plt.figure(figsize=(6, 4))
        plt.plot(x, y, marker="o")
        for xi, yi, la in zip(x, y, labels):
            plt.text(xi, yi, f"λ={la:.1f}", fontsize=8)
        plt.xlabel("CER (lower is better)")
        plt.ylabel("Diversity")
        plt.title("Task5: Quality-Diversity Tradeoff")
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, "task5_quality_diversity_tradeoff.png"), dpi=150, bbox_inches="tight")
        plt.close()
        rep = os.path.join(OUTPUT_DIR, "task5_report.txt")
        with open(rep, "w", encoding="utf-8") as f:
            f.write("TASK 5 — COMPATIBILITY REPORT\n")
            f.write("="*40 + "\n")
            f.write("Guidance classifier path unavailable; λ mapped to repetition penalty.\n\n")
            for lam, cer_v, div_v in results:
                f.write(f"lambda={lam:.1f}  CER={cer_v:.4f}  diversity={div_v:.3f}\n")
            f.write("\nSaved graphs:\n")
            f.write("  - task5_quality_diversity_tradeoff.png\n")
        print(f"  Report: {rep}")
        return

    try:
        from analysis.quality_classifier import (
            QualityClassifier, collect_quality_data,
            train_quality_classifier, sweep_guidance_scales)
    except Exception:
        print(" Quality-classifier API mismatch; using compatibility sweep.")
        n = min(50, int(task5_samples), len(src_list))
        scales = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
        results = []
        for lam in scales:
            rep_pen = 1.0 + 0.2 * lam
            cer_vals, uniq_vals = [], []
            for src, ref in zip(src_list[:n], ref_list[:n]):
                out = _generate_ids_compat(
                    model, src.to(device), num_steps=8, temperature=0.8, top_k=40,
                    repetition_penalty=rep_pen, diversity_penalty=0.0
                )
                txt, ids = _decode_ids(tgt_tok, out)
                cer_vals.append(_cer(txt, ref))
                uniq_vals.append(len(set(ids)) / max(1, len(ids)))
            results.append((lam, float(np.mean(cer_vals)), float(np.mean(uniq_vals))))
            print(f"  λ={lam:.1f}  CER={results[-1][1]:.4f}  diversity={results[-1][2]:.3f}")
        # Subtask graph: quality-diversity tradeoff
        x = [r[1] for r in results]
        y = [r[2] for r in results]
        labels = [r[0] for r in results]
        plt.figure(figsize=(6, 4))
        plt.plot(x, y, marker="o")
        for xi, yi, la in zip(x, y, labels):
            plt.text(xi, yi, f"λ={la:.1f}", fontsize=8)
        plt.xlabel("CER (lower is better)")
        plt.ylabel("Diversity")
        plt.title("Task5: Quality-Diversity Tradeoff")
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, "task5_quality_diversity_tradeoff.png"), dpi=150, bbox_inches="tight")
        plt.close()
        rep = os.path.join(OUTPUT_DIR, "task5_report.txt")
        with open(rep, "w", encoding="utf-8") as f:
            f.write("TASK 5 — COMPATIBILITY REPORT\n")
            f.write("="*40 + "\n")
            f.write("Guidance classifier path unavailable; λ mapped to repetition penalty.\n\n")
            for lam, cer_v, div_v in results:
                f.write(f"lambda={lam:.1f}  CER={cer_v:.4f}  diversity={div_v:.3f}\n")
            f.write("\nSaved graphs:\n")
            f.write("  - task5_quality_diversity_tradeoff.png\n")
        print(f"  Report: {rep}")
        return

    clf_path = os.path.join(OUTPUT_DIR, "task5_quality_classifier.pt")
    d_model = cfg['model']['d_model']

    # Step 1: collect or load training data
    data_path = os.path.join(OUTPUT_DIR, "task5_quality_data.npz")
    if os.path.exists(data_path):
        print(" Loading cached quality data...")
        data = np.load(data_path)
        hidden = data["hidden"]
        quality = data["quality"]
    else:
        print(" Collecting quality data (this takes a few minutes)...")
        n = min(int(task5_samples), len(src_list))
        hidden, quality = collect_quality_data(
            model, src_list[:n], ref_list[:n], tgt_tok,
            t_capture=0, max_samples=n)
        np.savez(data_path, hidden=hidden, quality=quality)
        print(f" Saved quality data: {data_path}")

    # Step 2: train or load classifier
    if os.path.exists(clf_path):
        print(f" Loading cached classifier: {clf_path}")
        clf = QualityClassifier(d_model)
        clf.load_state_dict(torch.load(clf_path, map_location='cpu'))
        clf.eval()
    else:
        print(" Training quality classifier...")
        clf = train_quality_classifier(
            hidden, quality, d_model=d_model,
            epochs=30, batch_size=64, lr=1e-3,
            save_path=clf_path)
        clf.eval()

    # Step 3: guidance scale sweep
    print("\n Guidance scale sweep (λ ∈ {0.0, 0.5, 1.0, 1.5, 2.0, 3.0})...")
    n_sweep = min(80, int(task5_samples), len(src_list))
    results = sweep_guidance_scales(
        model, clf, src_list[:n_sweep], ref_list[:n_sweep],
        tgt_tok, scales=[0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
        n_samples=n_sweep, device=device, output_dir=OUTPUT_DIR)

    # Find optimal scale (quality + anti-collapse diversity)
    def _score(s):
        r = results[s]
        return r["mean_cer"] - 0.05 * r.get("diversity", 0.0)
    best_scale = min(results, key=_score)
    print(f"\n Optimal guidance scale: λ={best_scale:.1f} "
          f"CER={results[best_scale]['mean_cer']:.4f}")

    report = os.path.join(OUTPUT_DIR, "task5_report.txt")
    with open(report, "w") as f:
        f.write("TASK 5 — CLASSIFIER-FREE GUIDANCE\n" + "="*50 + "\n\n")
        f.write(f"Classifier params: {sum(p.numel() for p in clf.parameters())}\n")
        f.write(f"Training samples : {len(hidden)}\n\n")
        f.write("Guidance scale sweep:\n")
        f.write(f"  {'λ':>6} {'CER':>8} {'diversity':>10} {'d2':>6} {'sBLEU':>8}\n")
        f.write("  " + "-"*52 + "\n")
        for s in sorted(results.keys()):
            r = results[s]
            marker = "  ← optimal" if s == best_scale else ""
            f.write(
                f"  {s:>6.1f} {r['mean_cer']:>8.4f} {r['diversity']:>10.3f} "
                f"{r.get('distinct2', 0.0):>6.3f} {r.get('self_bleu', 0.0):>8.3f}{marker}\n"
            )
    print(f"  Report: {report}")

# ── Main ──────────────────────────────────────────────────────────────

def main():
    global OUTPUT_DIR

    parser = argparse.ArgumentParser()
    parser.add_argument("--task",
                        choices=["1", "2", "3", "4", "5", "all"], default="all")
    parser.add_argument("--input",
                        default="dharmo rakṣati rakṣitaḥ",
                        help="IAST input text for Task 2")
    parser.add_argument("--phase",
                        choices=["generate_configs", "analyze"], default="analyze",
                        help="Task 4 phase: generate_configs (before training) or analyze (after)")
    parser.add_argument("--checkpoint", default=None,
                        help="Optional explicit checkpoint path")
    parser.add_argument("--output_dir", default="analysis/outputs",
                        help="Output directory for reports/figures")
    parser.add_argument("--task4_samples", type=int, default=50,
                        help="Samples for Task 4 dry/full evaluation")
    parser.add_argument("--task3_samples", type=int, default=500,
                        help="Samples for Task 3 hidden-state collection")
    parser.add_argument("--task5_samples", type=int, default=500,
                        help="Samples for Task 5 classifier data + sweep")
    args = parser.parse_args()

    OUTPUT_DIR = args.output_dir
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    cfg = copy.deepcopy(CONFIG)
    if args.checkpoint:
        cfg["model_type"] = infer_model_type_from_checkpoint(args.checkpoint)
        cfg["data"]["include_negative_examples"] = infer_include_negative_from_checkpoint(args.checkpoint)
        ckpt_name = os.path.basename(os.path.dirname(args.checkpoint))
        if ckpt_name.startswith("T") and ckpt_name[1:].isdigit():
            t_val = int(ckpt_name[1:])
            cfg["model"]["diffusion_steps"] = t_val
            cfg["inference"]["num_steps"] = t_val

    requested = cfg["training"]["device"]
    if requested == "mps" and not torch.backends.mps.is_available():
        requested = "cpu"
    elif requested == "cuda" and not torch.cuda.is_available():
        requested = "cpu"
    cfg["training"]["device"] = requested
    device = torch.device(requested)

    print("Loading model and tokenizers...")
    model, src_tok, tgt_tok, cfg = load_everything(cfg, device, ckpt_override=args.checkpoint)

    # Load val data for tasks that need corpus/context (Tasks 2, 3, 4, 5)
    needs_data = args.task in ("2", "3", "4", "5", "all")
    if needs_data:
        print("Loading validation data...")
        src_list, ref_list, inp_list = load_val_data(cfg, src_tok, tgt_tok, n=500)
    else:
        src_list, ref_list, inp_list = [], [], []

    tasks = (["1", "2", "3", "4", "5"] if args.task == "all"
             else [args.task])

    for task in tasks:
        if task == "1":
            run_task1(model, src_tok, device)
        elif task == "2":
            run_task2(model, src_tok, tgt_tok, device, args.input, cfg, corpus_inputs=inp_list)
        elif task == "3":
            run_task3(model, src_tok, tgt_tok, device, src_list, ref_list, n_samples=args.task3_samples)
        elif task == "4":
            run_task4(args.phase, model, src_tok, tgt_tok, device, cfg,
                      src_list, ref_list, n_samples=args.task4_samples)
        elif task == "5":
            run_task5(
                model, src_tok, tgt_tok, device, cfg, src_list, ref_list,
                task5_samples=args.task5_samples
            )

    print(f"\n{'='*65}")
    print(f" All outputs saved to: {OUTPUT_DIR}/")
    print("="*65)


if __name__ == "__main__":
    main()
analysis/step_ablation.py
ADDED
@@ -0,0 +1,640 @@
# """
# analysis/step_ablation.py
# ==========================
# Task 4: Semantic Robustness — Ablation of Diffusion Steps vs Meaning Preservation
#
# Two-phase workflow (retraining IS required for different T values):
#
#   PHASE 1 — Generate configs + train (run once per T value):
#     python analysis/step_ablation.py --phase generate_configs
#     # Creates configs: ablation_configs/config_T4.py ... config_T64.py
#     # Then train each: MODEL_TYPE=d3pm_cross_attention python train.py (for each config)
#
#   PHASE 2 — Analyze trained models (no retraining needed):
#     python analysis/step_ablation.py --phase analyze
#     # Loads each trained model, generates 200 paraphrases, computes CER
#     # Produces 3D plot: X=steps, Y=generation_speed, Z=CER
#
# Why retraining is needed:
#   A model trained with T=128 learns to denoise from x_t ~ Uniform[0, 128].
#   Running it with T=4 means the model only sees t ∈ {0, 1, 2, 3} — timestep
#   scales it was never trained on, so outputs are meaningless.
#   You must train a separate model for each T value.
#
# Also implements an adversarial robustness test (no retraining):
#   Takes your existing T=128 model and tests whether corrupted IAST
#   inputs (typos, character swaps) cause proportional output degradation.
# """
#
# import torch
# import torch.nn.functional as F
# import numpy as np
# import os
# import sys
# import time
# import json
# import copy
# from typing import List, Dict, Optional
#
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
#
#
# # ── Phase 1: Config generation ────────────────────────────────────────
#
# T_VALUES = [4, 8, 16, 32, 64]
#
# def generate_ablation_configs(base_config_path: str = "config.py",
#                               output_dir: str = "ablation_configs"):
#     """
#     Generate one config file per T value.
#     Each config is a copy of the base config with diffusion_steps changed.
#
#     After running this, train each model:
#         for T in 4 8 16 32 64; do
#             cp ablation_configs/config_T${T}.py config.py
#             python train.py
#             mv results7/d3pm_cross_attention_neg_False \
#                ablation_results/T${T}
#         done
#     """
#     os.makedirs(output_dir, exist_ok=True)
#
#     # Read base config
#     with open(base_config_path, "r") as f:
#         base_src = f.read()
#
#     for T in T_VALUES:
#         # Replace diffusion_steps and num_steps
#         cfg_src = base_src
#         cfg_src = cfg_src.replace(
#             '"diffusion_steps": 128',
#             f'"diffusion_steps": {T}'
#         )
#         cfg_src = cfg_src.replace(
#             "'diffusion_steps': 128",
#             f"'diffusion_steps': {T}"
#         )
#         cfg_src = cfg_src.replace(
#             '"num_steps": 128',
#             f'"num_steps": {T}'
#         )
#         cfg_src = cfg_src.replace(
#             "'num_steps': 128",
#             f"'num_steps': {T}"
#         )
#         out_path = os.path.join(output_dir, f"config_T{T}.py")
#         with open(out_path, "w") as f:
#             f.write(f"# Ablation config: T={T} diffusion steps\n")
#             f.write(cfg_src)
#         print(f"  Wrote: {out_path}")
#
#     # Write a shell script to train all
#     shell_script = os.path.join(output_dir, "train_all.sh")
#     with open(shell_script, "w") as f:
#         f.write("#!/bin/bash\n")
#         f.write("# Run this script to train all ablation models\n\n")
#         for T in T_VALUES:
#             f.write(f"echo '=== Training T={T} ==='\n")
#             f.write(f"cp {output_dir}/config_T{T}.py config.py\n")
#             f.write(f"python train.py\n")
#             f.write(f"mkdir -p ablation_results/T{T}\n")
#             f.write(f"cp -r results7/d3pm_cross_attention_neg_False/best_model.pt "
#                     f"ablation_results/T{T}/best_model.pt\n")
#             f.write(f"cp -r results7/d3pm_cross_attention_neg_False/train.log "
#                     f"ablation_results/T{T}/train.log\n\n")
#     os.chmod(shell_script, 0o755)
#     print(f"\nTraining script: {shell_script}")
#     print(f"Run: bash {shell_script}")
#
#
# # ── Phase 2: Analysis (after models are trained) ──────────────────────
#
# def compute_cer(pred: str, ref: str) -> float:
#     if not ref:
#         return 1.0
#
#     def edit_distance(s1, s2):
#         m, n = len(s1), len(s2)
#         dp = list(range(n + 1))
#         for i in range(1, m + 1):
#             prev, dp[0] = dp[0], i
#             for j in range(1, n + 1):
#                 temp = dp[j]
#                 dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
#                 prev = temp
#         return dp[n]
#
#     return edit_distance(pred, ref) / max(len(ref), 1)
#
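The commented-out `compute_cer` above computes character error rate via a rolling one-row Levenshtein table. An uncommented, standalone version of the same routine, runnable as-is:

```python
def compute_cer(pred: str, ref: str) -> float:
    """Character error rate: edit distance from pred to ref, over ref length."""
    if not ref:
        return 1.0
    m, n = len(pred), len(ref)
    # dp[j] holds the edit distance between pred[:i] and ref[:j],
    # updated row by row so only O(n) memory is needed.
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            temp = dp[j]
            dp[j] = prev if pred[i - 1] == ref[j - 1] else 1 + min(prev, dp[j], dp[j - 1])
            prev = temp
    return dp[n] / max(len(ref), 1)

print(f"{compute_cer('kitten', 'sitting'):.4f}")  # 0.4286 — 3 edits / 7 reference chars
```

Note the convention at the boundary: an empty reference yields CER 1.0 regardless of the prediction, which keeps the metric bounded for degenerate samples.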
#
# def evaluate_model(
#     model,
#     src_list: List[torch.Tensor],
#     ref_list: List[str],
#     tgt_tokenizer,
#     n_samples: int = 200,
#     temperature: float = 0.8,
#     top_k: int = 40,
# ) -> Dict:
#     """
#     Generate n_samples outputs and compute CER + generation speed.
#
#     Returns dict with:
#         mean_cer        : average CER over samples
#         generation_s    : total wall-clock seconds for all generations
#         speed_per_sample: seconds per sample
#         cer_list        : per-sample CER values
#     """
#     device = next(model.parameters()).device
#     n = min(n_samples, len(src_list))
#     cer_list = []
#
#     start = time.perf_counter()
#     for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
#         if src.dim() == 1:
#             src = src.unsqueeze(0)
#
#         with torch.no_grad():
#             if hasattr(model.model, 'generate_cached'):
#                 out = model.model.generate_cached(
#                     src.to(device), temperature=temperature, top_k=top_k
#                 )
#             else:
#                 out = model.generate(
#                     src.to(device), temperature=temperature, top_k=top_k
#                 )
#
#         ids = [x for x in out[0].tolist() if x > 4]
#         pred = tgt_tokenizer.decode(ids).strip()
#         cer = compute_cer(pred, ref)
#         cer_list.append(cer)
#
#     elapsed = time.perf_counter() - start
#
#     return {
#         "mean_cer": float(np.mean(cer_list)),
#         "std_cer": float(np.std(cer_list)),
#         "generation_s": elapsed,
#         "speed_per_sample": elapsed / max(n, 1),
#         "cer_list": cer_list,
#         "n_samples": n,
#     }
#
#
# def run_ablation_analysis(
#     ablation_dir: str = "ablation_results",
#     base_cfg: dict = None,
#     src_list: List[torch.Tensor] = None,
#     ref_list: List[str] = None,
#     tgt_tokenizer = None,
#     device: torch.device = None,
#     output_dir: str = "analysis/outputs",
# ) -> Dict:
#     """
#     Load each trained model and evaluate.
#     Produces results dict and 3D plot.
#
#     Expects ablation_results/T{N}/best_model.pt for each T in T_VALUES.
#     """
#     from inference import load_model
#
#     results = {}
#     for T in T_VALUES:
#         ckpt = os.path.join(ablation_dir, f"T{T}", "best_model.pt")
#         if not os.path.exists(ckpt):
#             print(f"  SKIP T={T}: no checkpoint at {ckpt}")
#             continue
#
#         print(f"\nEvaluating T={T}...")
#         cfg_T = copy.deepcopy(base_cfg)
#         cfg_T['model']['diffusion_steps'] = T
#         cfg_T['inference']['num_steps'] = T
#
#         model, cfg_T = load_model(ckpt, cfg_T, device)
#         model.eval()
#
#         metrics = evaluate_model(
#             model, src_list, ref_list, tgt_tokenizer, n_samples=200
#         )
#         results[T] = metrics
#         print(f"  T={T}  CER={metrics['mean_cer']:.4f}  "
#               f"speed={metrics['speed_per_sample']:.3f}s/sample")
#
#         del model
#
#     # Save results
#     os.makedirs(output_dir, exist_ok=True)
#     results_path = os.path.join(output_dir, "ablation_results.json")
#     with open(results_path, "w") as f:
#         json.dump({str(k): {kk: vv for kk, vv in v.items() if kk != 'cer_list'}
#                    for k, v in results.items()}, f, indent=2)
#     print(f"\nResults saved: {results_path}")
#
#     return results
#
#
# def plot_ablation_3d(
#     results: Dict,
#     save_path: Optional[str] = None,
# ):
#     """
#     3D plot: X=diffusion_steps, Y=generation_speed (s/sample), Z=CER.
#     Also produces a 2D summary plot.
#     """
#     try:
#         import matplotlib.pyplot as plt
#         from mpl_toolkits.mplot3d import Axes3D
#     except ImportError:
#         print("matplotlib is required: pip install matplotlib")
#         return
#
#     T_list = sorted(results.keys())
#     cers = [results[T]["mean_cer"] for T in T_list]
#     speeds = [results[T]["speed_per_sample"] for T in T_list]
#
#     # ── 3D plot ───────────────────────────────────────────────────────
#     fig = plt.figure(figsize=(14, 5))
#
#     ax3d = fig.add_subplot(121, projection='3d')
#     ax3d.scatter(T_list, speeds, cers, c=cers, cmap='RdYlGn_r', s=80)
#     for T, s, c in zip(T_list, speeds, cers):
#         ax3d.text(T, s, c, f"T={T}", fontsize=8)
#     ax3d.set_xlabel("Diffusion steps T", fontsize=9)
#     ax3d.set_ylabel("Speed (s/sample)", fontsize=9)
#     ax3d.set_zlabel("CER (↓ better)", fontsize=9)
#     ax3d.set_title("T vs speed vs CER", fontsize=10)
#
#     # ── 2D CER vs T (find the knee) ──────────────────────────────────
#     ax2d = fig.add_subplot(122)
#     ax2d.plot(T_list, cers, 'o-', linewidth=1.8, color='coral', markersize=7)
#     for T, c in zip(T_list, cers):
#         ax2d.annotate(f"{c:.3f}", (T, c), textcoords="offset points",
#                       xytext=(0, 8), fontsize=8, ha='center')
#
#     # Find knee: largest CER drop per unit T (elbow method)
#     if len(T_list) >= 3:
#         drops = [cers[i] - cers[i+1] for i in range(len(cers)-1)]
#         knee_i = int(np.argmax(drops))
#         knee_T = T_list[knee_i + 1]
#         ax2d.axvline(knee_T, color='steelblue', linestyle='--', linewidth=1.2,
#                      label=f"Knee at T={knee_T}")
#         ax2d.legend(fontsize=9)
#
|
| 283 |
+
# ax2d.set_xlabel("Diffusion steps T", fontsize=10)
|
| 284 |
+
# ax2d.set_ylabel("CER (lower = better)", fontsize=10)
|
| 285 |
+
# ax2d.set_title("CER vs diffusion steps", fontsize=10)
|
| 286 |
+
# ax2d.set_ylim(0, max(cers) * 1.1)
|
| 287 |
+
#
|
| 288 |
+
# plt.tight_layout()
|
| 289 |
+
# if save_path:
|
| 290 |
+
# os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
|
| 291 |
+
# plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 292 |
+
# print(f"Saved: {save_path}")
|
| 293 |
+
# else:
|
| 294 |
+
# plt.show()
|
| 295 |
+
# plt.close()
|
| 296 |
+
#
|
| 297 |
+
#
|
| 298 |
+
# # ── Adversarial robustness test (no retraining needed) ───────────────
|
| 299 |
+
#
|
| 300 |
+
# def corrupt_iast(text: str, corruption_rate: float = 0.05) -> str:
|
| 301 |
+
# """
|
| 302 |
+
# Introduce random corruption into IAST text:
|
| 303 |
+
# - Character swap (adjacent chars swapped)
|
| 304 |
+
# - Character deletion
|
| 305 |
+
# - Random character insertion
|
| 306 |
+
#
|
| 307 |
+
# Models rate as 5% to 20% corruption to test robustness.
|
| 308 |
+
# """
|
| 309 |
+
# import random
|
| 310 |
+
# chars = list(text)
|
| 311 |
+
# n_corrupt = max(1, int(len(chars) * corruption_rate))
|
| 312 |
+
#
|
| 313 |
+
# for _ in range(n_corrupt):
|
| 314 |
+
# op = random.choice(['swap', 'delete', 'insert'])
|
| 315 |
+
# pos = random.randint(0, len(chars) - 1)
|
| 316 |
+
#
|
| 317 |
+
# if op == 'swap' and pos < len(chars) - 1:
|
| 318 |
+
# chars[pos], chars[pos+1] = chars[pos+1], chars[pos]
|
| 319 |
+
# elif op == 'delete' and len(chars) > 1:
|
| 320 |
+
# chars.pop(pos)
|
| 321 |
+
# elif op == 'insert':
|
| 322 |
+
# chars.insert(pos, random.choice('abcdeimnostu'))
|
| 323 |
+
#
|
| 324 |
+
# return "".join(chars)
|
| 325 |
+
#
|
| 326 |
+
#
|
| 327 |
+
# @torch.no_grad()
|
| 328 |
+
# def run_adversarial_test(
|
| 329 |
+
# model,
|
| 330 |
+
# src_tokenizer,
|
| 331 |
+
# tgt_tokenizer,
|
| 332 |
+
# test_inputs: List[str],
|
| 333 |
+
# test_refs: List[str],
|
| 334 |
+
# corruption_rates: List[float] = [0.0, 0.05, 0.10, 0.15, 0.20],
|
| 335 |
+
# device: torch.device = None,
|
| 336 |
+
# output_dir: str = "analysis/outputs",
|
| 337 |
+
# ) -> Dict:
|
| 338 |
+
# """
|
| 339 |
+
# Test if CER degrades proportionally with IAST corruption.
|
| 340 |
+
# Uses existing trained model — no retraining.
|
| 341 |
+
# """
|
| 342 |
+
# device = device or next(model.parameters()).device
|
| 343 |
+
# results = {}
|
| 344 |
+
#
|
| 345 |
+
# print("\nAdversarial robustness test...")
|
| 346 |
+
# for rate in corruption_rates:
|
| 347 |
+
# cer_list = []
|
| 348 |
+
# for text, ref in zip(test_inputs, test_refs):
|
| 349 |
+
# corrupted = corrupt_iast(text, rate)
|
| 350 |
+
# ids = src_tokenizer.encode(corrupted)
|
| 351 |
+
# src = torch.tensor([ids], dtype=torch.long, device=device)
|
| 352 |
+
#
|
| 353 |
+
# if hasattr(model.model, 'generate_cached'):
|
| 354 |
+
# out = model.model.generate_cached(src)
|
| 355 |
+
# else:
|
| 356 |
+
# out = model.generate(src)
|
| 357 |
+
#
|
| 358 |
+
# pred_ids = [x for x in out[0].tolist() if x > 4]
|
| 359 |
+
# pred = tgt_tokenizer.decode(pred_ids).strip()
|
| 360 |
+
# cer_list.append(compute_cer(pred, ref))
|
| 361 |
+
#
|
| 362 |
+
# mean_cer = float(np.mean(cer_list))
|
| 363 |
+
# results[rate] = mean_cer
|
| 364 |
+
# print(f" corruption={rate*100:.0f}% → CER={mean_cer:.4f}")
|
| 365 |
+
#
|
| 366 |
+
# # Save + plot
|
| 367 |
+
# os.makedirs(output_dir, exist_ok=True)
|
| 368 |
+
# try:
|
| 369 |
+
# import matplotlib.pyplot as plt
|
| 370 |
+
# fig, ax = plt.subplots(figsize=(8, 4))
|
| 371 |
+
# rates = [r * 100 for r in corruption_rates]
|
| 372 |
+
# cers = [results[r] for r in corruption_rates]
|
| 373 |
+
# ax.plot(rates, cers, 'o-', linewidth=1.8, color='steelblue', markersize=7)
|
| 374 |
+
# ax.set_xlabel("IAST corruption rate (%)", fontsize=11)
|
| 375 |
+
# ax.set_ylabel("CER", fontsize=11)
|
| 376 |
+
# ax.set_title("Model robustness to IAST input corruption", fontsize=11)
|
| 377 |
+
# ax.set_ylim(0, max(cers) * 1.2)
|
| 378 |
+
# plt.tight_layout()
|
| 379 |
+
# plt.savefig(os.path.join(output_dir, "adversarial_robustness.png"),
|
| 380 |
+
# dpi=150, bbox_inches='tight')
|
| 381 |
+
# plt.close()
|
| 382 |
+
# print(f" Saved: {output_dir}/adversarial_robustness.png")
|
| 383 |
+
# except ImportError:
|
| 384 |
+
# pass
|
| 385 |
+
#
|
| 386 |
+
# with open(os.path.join(output_dir, "adversarial_results.json"), "w") as f:
|
| 387 |
+
# json.dump({str(k): v for k, v in results.items()}, f, indent=2)
|
| 388 |
+
#
|
| 389 |
+
# return results
|
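The corruption routine in the commented-out block above can be exercised on its own. Below is a minimal runnable sketch of the same swap/delete/insert scheme; the `seed` parameter is an addition (not in the original, which uses the global `random` state) so that results are reproducible:

```python
import random

def corrupt_iast(text: str, corruption_rate: float = 0.05, seed: int = 0) -> str:
    """Randomly swap, delete, or insert characters at the given rate."""
    # Seeded RNG is an assumption for reproducibility; the original uses module-level random.
    rng = random.Random(seed)
    chars = list(text)
    n_corrupt = max(1, int(len(chars) * corruption_rate))
    for _ in range(n_corrupt):
        op = rng.choice(['swap', 'delete', 'insert'])
        pos = rng.randint(0, len(chars) - 1)
        if op == 'swap' and pos < len(chars) - 1:
            chars[pos], chars[pos + 1] = chars[pos + 1], chars[pos]
        elif op == 'delete' and len(chars) > 1:
            chars.pop(pos)
        elif op == 'insert':
            chars.insert(pos, rng.choice('abcdeimnostu'))
    return "".join(chars)

# Each op changes the length by at most 1, so |len(out) - len(text)| <= n_corrupt.
sample = corrupt_iast("dharmo raksati raksitah", corruption_rate=0.20, seed=1)
```

Since every operation touches one position, the output length stays within `n_corrupt` of the input length, which makes the routine easy to sanity-check.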
"""
analysis/task4_pipeline.py
================================
Correct Task 4 Pipeline:

PHASE 1 → Evaluate all models
PHASE 2 → Analyze + detect optimal T

NO early decision making.
"""

import torch
import numpy as np
import time
import os
import json
from typing import Dict, List
from difflib import SequenceMatcher
from collections import Counter


# ─────────────────────────────────────────────
# Load Metrics
# ─────────────────────────────────────────────

def load_metrics():
    try:
        from bert_score import score as bert_score
    except Exception:
        bert_score = None
    from nltk.translate.bleu_score import sentence_bleu
    try:
        from sentence_transformers import SentenceTransformer, util
        st_model = SentenceTransformer('all-MiniLM-L6-v2')
        return bert_score, st_model, util, sentence_bleu
    except Exception:
        # Offline-safe fallback: skip sentence-transformer similarity.
        return bert_score, None, None, sentence_bleu


# ─────────────────────────────────────────────
# PHASE 1 — Evaluate ALL models
# ─────────────────────────────────────────────

def evaluate_all_models(models: Dict[int, object],
                        src_list,
                        ref_list,
                        tgt_tokenizer,
                        n_samples=200,
                        output_dir: str = "analysis/outputs"):

    bert_score_fn, st_model, util, bleu_fn = load_metrics()

    results = {}

    print("\n=== PHASE 1: Evaluating ALL models ===")

    for T, model in sorted(models.items()):
        print(f"\nEvaluating T={T}...")

        device = next(model.parameters()).device
        preds, refs = [], []

        start = time.perf_counter()

        for src, ref in zip(src_list[:n_samples], ref_list[:n_samples]):
            if src.dim() == 1:
                src = src.unsqueeze(0)

            with torch.no_grad():
                if hasattr(model, "model") and hasattr(model.model, "generate_cached"):
                    out = model.model.generate_cached(src.to(device))
                else:
                    # Fallback for wrappers that only expose top-level generate.
                    out = model.generate(src.to(device))

            ids = [x for x in out[0].tolist() if x > 4]
            pred = tgt_tokenizer.decode(ids).strip()

            preds.append(pred)
            refs.append(ref)

        elapsed = time.perf_counter() - start

        # BERTScore (fallback to lexical similarity if unavailable/offline)
        try:
            if bert_score_fn is not None:
                _, _, F1 = bert_score_fn(preds, refs, lang="hi", verbose=False)
                bert_f1 = float(F1.mean())
            else:
                raise RuntimeError("bertscore unavailable")
        except Exception:
            bert_f1 = float(np.mean([SequenceMatcher(None, p, r).ratio() for p, r in zip(preds, refs)]))

        # Sentence similarity (distinct from BERT fallback)
        if st_model is not None:
            emb_p = st_model.encode(preds, convert_to_tensor=True)
            emb_r = st_model.encode(refs, convert_to_tensor=True)
            sim = util.cos_sim(emb_p, emb_r).diagonal().mean().item()
        else:
            # Token-overlap F1 proxy (different behavior from char-level similarity)
            f1s = []
            for p, r in zip(preds, refs):
                pt = [t for t in p.split() if t]
                rt = [t for t in r.split() if t]
                if not pt or not rt:
                    f1s.append(0.0)
                    continue
                cp, cr = Counter(pt), Counter(rt)
                inter = sum((cp & cr).values())
                prec = inter / max(1, len(pt))
                rec = inter / max(1, len(rt))
                f1s.append(2 * prec * rec / max(1e-9, prec + rec))
            sim = float(np.mean(f1s)) if f1s else 0.0
            if not np.isfinite(sim):
                sim = float(np.mean([SequenceMatcher(None, p, r).ratio() for p, r in zip(preds, refs)]))

        # BLEU
        bleu_scores = [
            bleu_fn([r.split()], p.split())
            for p, r in zip(preds, refs)
        ]

        results[T] = {
            "bertscore_f1": bert_f1,
            "semantic_sim": sim,
            "bleu": float(np.mean(bleu_scores)),
            "speed_per_sample": elapsed / max(1, len(preds))
        }

        print(f" BERTScore: {bert_f1:.4f}")
        print(f" Sim: {sim:.4f}")
        print(f" BLEU: {results[T]['bleu']:.4f}")
        print(f" Speed: {results[T]['speed_per_sample']:.4f}s")

    # Save raw results
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "task4_raw_results.json"), "w") as f:
        json.dump(results, f, indent=2)

    return results


# ─────────────────────────────────────────────
# PHASE 2 — Analyze results (Knee Detection)
# ─────────────────────────────────────────────

def analyze_results(results: Dict):
    print("\n=== PHASE 2: Analysis ===")

    T_list = sorted(results.keys())
    scores = [results[T]["bertscore_f1"] for T in T_list]

    gains = [scores[i+1] - scores[i] for i in range(len(scores)-1)]

    print("\nMarginal Gains:")
    for i, g in enumerate(gains):
        print(f" T{T_list[i]} → T{T_list[i+1]}: +{g:.4f}")

    # Robust utility selection (quality + semantics + speed regularizer)
    bvals = np.array([results[T]["bertscore_f1"] for T in T_list], dtype=np.float32)
    svals = np.array([results[T]["semantic_sim"] for T in T_list], dtype=np.float32)
    tvals = np.array([results[T]["speed_per_sample"] for T in T_list], dtype=np.float32)
    b_norm = (bvals - bvals.min()) / max(1e-9, (bvals.max() - bvals.min()))
    s_norm = (svals - svals.min()) / max(1e-9, (svals.max() - svals.min()))
    t_norm = (tvals - tvals.min()) / max(1e-9, (tvals.max() - tvals.min()))
    utility = 0.50 * b_norm + 0.30 * s_norm - 0.20 * t_norm
    knee_T = T_list[int(np.argmax(utility))]

    print(f"\n✅ Optimal T (semantic-speed tradeoff): {knee_T}")

    return knee_T, gains


# ─────────────────────────────────────────────
# 3D Plot (BERTScore)
# ─────────────────────────────────────────────

def plot_3d(results, output_dir: str = "analysis/outputs"):
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D

    T_list = sorted(results.keys())

    X = T_list
    Y = [results[T]["speed_per_sample"] for T in T_list]
    Z = [results[T]["bertscore_f1"] for T in T_list]

    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(X, Y, Z)

    for x, y, z in zip(X, Y, Z):
        ax.text(x, y, z, f"T={x}", fontsize=8)

    ax.set_xlabel("Diffusion Steps")
    ax.set_ylabel("Speed")
    ax.set_zlabel("BERTScore")

    plt.title("3D Tradeoff: Steps vs Speed vs Quality")

    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, "task4_3d.png"))
    plt.close()

    print("Saved 3D plot")


# ─────────────────────────────────────────────
# FINAL RUNNER
# ─────────────────────────────────────────────

def run_task4(models, src_list, ref_list, tgt_tokenizer,
              output_dir: str = "analysis/outputs", n_samples: int = 200):

    # Phase 1: Evaluate all
    results = evaluate_all_models(
        models, src_list, ref_list, tgt_tokenizer, n_samples=n_samples, output_dir=output_dir
    )

    # Phase 2: Analyze
    knee_T, gains = analyze_results(results)

    # Plot
    plot_3d(results, output_dir=output_dir)

    # Save detailed report
    report_path = os.path.join(output_dir, "task4_report.txt")
    with open(report_path, "w") as f:
        f.write("TASK 4 — SEMANTIC ROBUSTNESS ABLATION\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Optimal diffusion steps = {knee_T}\n\n")
        f.write(f"{'T':>6} {'BERT-F1':>10} {'SEM_SIM':>10} {'BLEU':>8} {'sec/sample':>12}\n")
        f.write(" " + "-" * 56 + "\n")
        for T in sorted(results.keys()):
            r = results[T]
            f.write(
                f"{T:>6} {r['bertscore_f1']:>10.4f} {r['semantic_sim']:>10.4f} "
                f"{r['bleu']:>8.4f} {r['speed_per_sample']:>12.4f}\n"
            )
        f.write("\nMarginal gains (BERT-F1):\n")
        for i, g in enumerate(gains):
            t0 = sorted(results.keys())[i]
            t1 = sorted(results.keys())[i + 1]
            f.write(f" T{t0} -> T{t1}: {g:+.4f}\n")
        f.write("\nSaved plots/files:\n")
        f.write(" - task4_3d.png\n")
        f.write(" - task4_raw_results.json\n")

    return knee_T
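The optimal-T selection in `analyze_results` above reduces to min-max normalization of three metrics followed by a weighted argmax. A condensed, standalone sketch using the same 0.50/0.30/0.20 weights (the `demo` numbers below are hypothetical, for illustration only):

```python
import numpy as np

def select_optimal_T(results: dict) -> int:
    """Pick T maximizing 0.5*quality + 0.3*similarity - 0.2*time (min-max normalized)."""
    T_list = sorted(results)

    def norm(key):
        # Min-max normalize one metric across all T values.
        v = np.array([results[T][key] for T in T_list], dtype=np.float32)
        return (v - v.min()) / max(1e-9, float(v.max() - v.min()))

    utility = (0.50 * norm("bertscore_f1")
               + 0.30 * norm("semantic_sim")
               - 0.20 * norm("speed_per_sample"))
    return T_list[int(np.argmax(utility))]

# Hypothetical metric values, for illustration only.
demo = {
    8:  {"bertscore_f1": 0.20, "semantic_sim": 0.05, "speed_per_sample": 0.4},
    16: {"bertscore_f1": 0.26, "semantic_sim": 0.06, "speed_per_sample": 0.9},
    32: {"bertscore_f1": 0.27, "semantic_sim": 0.06, "speed_per_sample": 1.8},
}
# → 16: T=32 gains little quality but doubles the per-sample time, so its
# speed penalty outweighs the quality edge under these weights.
```

Because every metric is rescaled to [0, 1] before weighting, the selection is invariant to each metric's units, which is what lets BERTScore, cosine similarity, and wall-clock time share one utility.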
analysis_outputs/outputs_all_models_20260325/T16/task1_encoder_cost.png ADDED

analysis_outputs/outputs_all_models_20260325/T16/task1_kv_cache.txt ADDED
@@ -0,0 +1,15 @@
+TASK 1 — KV CACHE BENCHMARK
+========================================
+
+has_generate_cached=True
+memory_profile=Torch CPU mem-event reduction: 30.4% @ src_len=64 (std=2143.0MB, cache=1492.1MB)
+
+src_len   standard(s)   cached(s)   speedup   encoder%
+16        0.893         0.571       1.56x     40.0%
+32        0.751         0.509       1.48x     42.3%
+64        1.141         0.822       1.39x     40.7%
+
+Saved graphs:
+ - task1_time_comparison.png
+ - task1_speedup.png
+ - task1_encoder_cost.png
analysis_outputs/outputs_all_models_20260325/T16/task1_speedup.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task1_time_comparison.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_all_layers_t0.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_attn_evolution.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_attn_t0.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_attn_t15.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_report.txt ADDED
@@ -0,0 +1,35 @@
+TASK 2 — ATTENTION + DRIFT REPORT
+==================================================
+
+Input : dharmo rakṣati rakṣitaḥ
+Output: धर्मो रक्षति रक्षितः
+
+Captured steps: 16
+Analysis quality: WEAK
+Final output uniq-ratio: 1.000
+Degenerate output: NO
+Multi-sample semantic score (n<=8): 0.1471
+Lock-in step (CER<=0.05): t=0
+Locked tokens: 38   Flexible tokens: 42
+TF-IDF vs attention stability corr: 0.9294
+TF-IDF status: OK
+
+Saved graphs:
+ - task2_attn_t*.png / task2_all_layers_t0.png
+ - task2_attn_evolution.png
+ - task2_semantic_drift.png
+ - task2_source_alignment.png
+ - task2_tfidf_vs_attention.png
+
+Step trajectory (first 10 rows)
+------------------------------------------------------------
+t= 15 bert=0.0475 drift=0.9525 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+t= 14 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+t= 13 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+t= 12 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+t= 11 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+t= 10 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+t=  9 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+t=  8 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+t=  7 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
+t=  6 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
analysis_outputs/outputs_all_models_20260325/T16/task2_semantic_drift.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_source_alignment.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_tfidf_vs_attention.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task3_concept_space.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task3_diversity_curve.png ADDED

analysis_outputs/outputs_all_models_20260325/T16/task3_diversity_direction.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:250b81c1f8cc9537240873d00539df1e8a30e6c07b260d4c05df23fb32c704d6
+size 4224

analysis_outputs/outputs_all_models_20260325/T16/task3_pca_explained_variance.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task3_report.txt ADDED
@@ -0,0 +1,21 @@
+TASK 3 — CONCEPT VECTORS + PCA STEERING
+==================================================
+
+PCA: 50 components, 74.8% variance
+Diversity PC: 0 (|r|=0.325 with diversity proxy)
+
+Direction validity: WEAK
+Spectrum unique ratio (mean over 5 seeds): 1.000
+Spectrum semantic stability (mean over 5 seeds): 0.312
+
+Saved graphs:
+ - task3_concept_space.png
+ - task3_pca_explained_variance.png
+ - task3_diversity_curve.png
+
+Diversity spectrum:
+alpha=-2.0 → बले वेध विवर् धान वीर्य वीर्य धिं सिंहा भि̱ सन वस्तु वेध वै वेध वस्तु सन सन सिंहा सिंहा वीर्य वीर्य वस्तु सन रुते प्रभवति मन वेध बले बले र्वृ प्रपूजयेत् युगा मलि धान तुल वीर्य वीर्य वीर्य वीर्य वीर्य वीर्य धान तुल कालेन युगा वेध बले वेध वेध च्छे ष्मस् यस्या काष्ठा ज्ञप्त अर्णव धिं धिं वस्तु धिं सन तया सन सन देवाः देवाः स्वातन्त्र अर्णव मह वस्तु मुष् सन धिं धिं धिं विक्र त्र मह हस्ते च्छे मह
+alpha=-1.0 → बले र् अ तुल वीर्य वीर्य गुरु सिंहा सन सन विलेप वै वै वै गतस्य वेध सन सिंहा सिंहा स्य स्य । सन वै वै वै बले बले बले बले र् अ अ तुल तुल वीर्य वीर्य वीर्य वीर्य वीर्य वीर्य तुल तुल तुल ् बले वेध दिव्यां मान वै अप्सु सन ॥ ॥ वस्तु सिंहा सन सन विक्र सन स काष्ठा सन सन सन कार सन सन सन सन भ बल ु सिंहा सन सिंहा सन म् म् सन
+alpha=+0.0 → बले र् अ तुल वीर्य वीर्य स्य सिंहा सन सन पितो वै वै वै दक्षिणां सन सन सिंहा सिंहा स्य स्य स्य सन गतस्य वै वै ॥ बले बले र् र् अ अ । तुल वीर्य वीर्य वीर्य वीर्य वीर्य तुल तुल तुल तुल अ स बले बले वै वै ॥ ॥ ॥ सन सन सिंहा स सन सन सन सन सन सन सन सन सन सन ॥ ॥ सन सन शतैः ॥ सिंहा सिंहा द सिंहा सन त् सन
+alpha=+1.0 → बले र् अ अ विशुद्धं स्य स्य सिंहा सिंहा सन गतस्य वै वै वै वेत्ति सन सन सिंहा स्य स्य स्य स्य सन वै वै स मल बले बले र् र् व अ अ तुल वीर्य वीर्य वीर्य स्य वीर्य स्य तुल ानु अ अ । र् व ॥ वै वै सन द ॥ ॥ सिंहा सिंहा ॥ सं सन ॥ ॥ व ॥ ॥ हेम सन सन व ॥ ै ॥ वै भ न न ॥ मित्रो सिंहा सन
+alpha=+2.0 → आविश र् अ किंचिद् वर स्य स्य सिंहा सं निमे ञ् सं वै वै ञ् सन कृपा सिंहा स्य स्य स्य स्य फणा ञ् वै ौ जिह्व बले मानाः र् र् वराय अ माने वर विशुद्धं स्य स्य स्य – वर विशुद्धं व वर अ कृपा ॥ परम् ॥ कश्चि वै ॥ ञ् ञ् सं स्य स्य तम् व प्रवर्तन्ते कर्मसु परम् वर ते ॥ व ञ् ॥ ॥ सं द ॥ ॥ वर न्द ̱व ॥ व व ै
analysis_outputs/outputs_all_models_20260325/T16/task4_3d.png ADDED

analysis_outputs/outputs_all_models_20260325/T16/task4_raw_results.json ADDED
@@ -0,0 +1,8 @@
+{
+  "16": {
+    "bertscore_f1": 0.25743605086845023,
+    "semantic_sim": 0.05798209163692987,
+    "bleu": 0.0007454091523007641,
+    "speed_per_sample": 0.9068318999983603
+  }
+}
analysis_outputs/outputs_all_models_20260325/T16/task4_report.txt ADDED
@@ -0,0 +1,14 @@
+TASK 4 — SEMANTIC ROBUSTNESS ABLATION
+==================================================
+
+Optimal diffusion steps = 16
+
+     T    BERT-F1    SEM_SIM     BLEU   sec/sample
+ --------------------------------------------------------
+    16     0.2574     0.0580   0.0007       0.9068
+
+Marginal gains (BERT-F1):
+
+Saved plots/files:
+ - task4_3d.png
+ - task4_raw_results.json
analysis_outputs/outputs_all_models_20260325/T16/task5_guidance_results.json ADDED
@@ -0,0 +1,44 @@
+{
+  "0.0": {
+    "mean_cer": 0.8335914296177765,
+    "diversity": 0.8084225118773136,
+    "sent_unique": 1.0,
+    "distinct2": 0.6240506329113924,
+    "self_bleu": 0.00720560915676511
+  },
+  "0.5": {
+    "mean_cer": 0.8361858372849987,
+    "diversity": 0.7997218378718688,
+    "sent_unique": 1.0,
+    "distinct2": 0.6060126582278481,
+    "self_bleu": 0.0065689824841105166
+  },
+  "1.0": {
+    "mean_cer": 0.8390361847911715,
+    "diversity": 0.7978319711295725,
+    "sent_unique": 1.0,
+    "distinct2": 0.6009493670886076,
+    "self_bleu": 0.005285424829462745
+  },
+  "1.5": {
+    "mean_cer": 0.8457771777829102,
+    "diversity": 0.8134699633307632,
+    "sent_unique": 1.0,
+    "distinct2": 0.6306962025316456,
+    "self_bleu": 0.0037562758701191663
+  },
+  "2.0": {
+    "mean_cer": 0.8530737908495466,
+    "diversity": 0.828318481566094,
+    "sent_unique": 1.0,
+    "distinct2": 0.6604430379746835,
+    "self_bleu": 0.003806074842495409
+  },
+  "3.0": {
+    "mean_cer": 0.8772574230238586,
+    "diversity": 0.829961794478179,
+    "sent_unique": 1.0,
+    "distinct2": 0.6686708860759494,
+    "self_bleu": 0.008747297119591432
+  }
+}
analysis_outputs/outputs_all_models_20260325/T16/task5_quality_classifier.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4053d24514f08c475662f69a5b01d1577cd1f79837df69ac2175705310e9a23
+size 561505

analysis_outputs/outputs_all_models_20260325/T16/task5_quality_data.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:840494704113872c8e53e3627e5666c8af940ed683515ec37894dd3091a14684
+size 164512

analysis_outputs/outputs_all_models_20260325/T16/task5_quality_diversity_tradeoff.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task5_report.txt ADDED
@@ -0,0 +1,15 @@
+TASK 5 — CLASSIFIER-FREE GUIDANCE
+==================================================
+
+Classifier params: 139521
+Training samples : 40
+
+Guidance scale sweep:
+   λ      CER   diversity      d2   sBLEU
+----------------------------------------------------
+ 0.0   0.8336       0.808   0.624   0.007   ← optimal
+ 0.5   0.8362       0.800   0.606   0.007
+ 1.0   0.8390       0.798   0.601   0.005
+ 1.5   0.8458       0.813   0.631   0.004
+ 2.0   0.8531       0.828   0.660   0.004
+ 3.0   0.8773       0.830   0.669   0.009
analysis_outputs/outputs_all_models_20260325/T32/task1_encoder_cost.png ADDED

analysis_outputs/outputs_all_models_20260325/T32/task1_kv_cache.txt ADDED
@@ -0,0 +1,15 @@
+TASK 1 — KV CACHE BENCHMARK
+========================================
+
+has_generate_cached=True
+memory_profile=Torch CPU mem-event reduction: 31.1% @ src_len=64 (std=4287.2MB, cache=2953.9MB)
+
+src_len   standard(s)   cached(s)   speedup   encoder%
+16        1.914         1.165       1.64x     39.6%
+32        1.542         0.891       1.73x     42.1%
+64        2.096         1.475       1.42x     42.7%
+
+Saved graphs:
+ - task1_time_comparison.png
+ - task1_speedup.png
+ - task1_encoder_cost.png
analysis_outputs/outputs_all_models_20260325/T32/task1_speedup.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task1_time_comparison.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_all_layers_t0.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_attn_evolution.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_attn_t0.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_attn_t31.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_report.txt ADDED
@@ -0,0 +1,35 @@
+TASK 2 — ATTENTION + DRIFT REPORT
+==================================================
+
+Input : dharmo rakṣati rakṣitaḥ
+Output: धर्मो रक्षति रक्षितः
+
+Captured steps: 32
+Analysis quality: WEAK
+Final output uniq-ratio: 1.000
+Degenerate output: NO
+Multi-sample semantic score (n<=8): 0.0627
+Lock-in step (CER<=0.05): t=0
+Locked tokens: 75   Flexible tokens: 5
+TF-IDF vs attention stability corr: -0.0869
+TF-IDF status: WEAK
+
+Saved graphs:
+ - task2_attn_t*.png / task2_all_layers_t0.png
+ - task2_attn_evolution.png
+ - task2_semantic_drift.png
+ - task2_source_alignment.png
+ - task2_tfidf_vs_attention.png
+
+Step trajectory (first 10 rows)
+------------------------------------------------------------
+t= 31 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+t= 30 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+t= 29 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+t= 28 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+t= 27 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+t= 26 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+t= 25 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+t= 24 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+t= 23 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
+t= 22 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
analysis_outputs/outputs_all_models_20260325/T32/task2_semantic_drift.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_source_alignment.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_tfidf_vs_attention.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task3_concept_space.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task3_diversity_curve.png ADDED

analysis_outputs/outputs_all_models_20260325/T32/task3_diversity_direction.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e547306c9469858deaa9985c30d31639c8a9f8104e8addd83afa88fa0264831
+size 4224

analysis_outputs/outputs_all_models_20260325/T32/task3_pca_explained_variance.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task3_report.txt ADDED
@@ -0,0 +1,21 @@
+TASK 3 — CONCEPT VECTORS + PCA STEERING
+==================================================
+
+PCA: 50 components, 94.6% variance
+Diversity PC: 0 (|r|=-0.530 with diversity proxy)
+
+Direction validity: WEAK
+Spectrum unique ratio (mean over 5 seeds): 0.840
+Spectrum semantic stability (mean over 5 seeds): 0.234
+
+Saved graphs:
+- task3_concept_space.png
+- task3_pca_explained_variance.png
+- task3_diversity_curve.png
+
+Diversity spectrum:
+alpha=-2.0 → ेन श्रे श्रे ेन श्रे अण्ड व्याः श्रे तन्त्रा ॥ ॥ ॥ व्याः व्याः व्याः तद्वद् तद्वद् तद्वद् ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ तद्वद् ॥ ॥ ॥ ॥ ॥ व्याः व्याः व्याः ॥ ॥ राजन्य व्याः व्याः व्याः ॥ व्याः व्याः ॥ ॥ काम्य ॥ ॥ ॥ व्याः ॥ तद्वद् ॥ ॥ ॥ ॥ ॥ तन्त्रा तन्त्रा ॥ ॥ ॥ ॥ व्याः ॥ ॥ ॥ ॥ ॥ युधम् तद्वद् युधम् ॥
+alpha=-1.0 → श्रे श्रे श्रे ेन श्रे श्रे श्रे श्रे अण्ड तन्त्रा व्याः ॥ अण्ड अण्ड तन्त्रा व्याः तद्वद् ॥ व्याः ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ अण्ड ॥ ॥ ॥ व्याः ॥ व्याः नो̍ ॥ ॥ ॥ ॥ ॥ व्याः व्याः अण्ड ॥ ॥ तन्त्रा ॥ ॥ तद्वद् युधम् रोमा शम्भु ॥ धूमं तन्त्रा ॥ तन्त्रा ॥ व्याः ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥
+alpha=+0.0 → अण्ड श्रे करः श्रे तन्त्रा करः करः तन्त्रा श्रे अण्ड अण्ड अण्ड ॥ श्रे तद्वद् अण्ड ॥ ॥ अण्ड ॥ ॥ ॥ ॥ ॥ ॥ ॥ अण्ड ॥ ॥ ॥ ॥ अण्ड ॥ ॥ ॥ ॥ ॥ ॥ राजन्य तन्त्रा नो̍ ॥ ॥ ॥ ॥ ॥ व्याः ॥ अण्ड ॥ काम्य ॥ ॥ ॥ ॥ ॥ शम्भु धूमं तन्त्रा तन्त्रा ेन ॥ काम्य ॥ ॥ करः तन्त्रा ॥ अण्ड ॥ अण्ड ॥ विनिर्जित्य ॥ ॥ ॥ तन्त्रा अण्ड तद्वद् करः
+alpha=+1.0 → माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण
+alpha=+2.0 → माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण
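The alpha spectrum in task3_report.txt is the output of steering along the saved diversity direction. A minimal sketch of that mechanism, assuming the usual latent-steering recipe (`decode` is a hypothetical stand-in for the project's decoder, and a random vector stands in for `task3_diversity_direction.npy`):

```python
import numpy as np

# Hedged sketch of PCA-direction steering: shift a latent along a unit-norm
# diversity direction by alpha before decoding. All names here are
# illustrative, not the project's actual API.
rng = np.random.default_rng(0)
direction = rng.standard_normal(512)       # stand-in for task3_diversity_direction.npy
direction /= np.linalg.norm(direction)     # unit-norm steering vector

latent = rng.standard_normal(512)
for alpha in (-2.0, -1.0, 0.0, 1.0, 2.0):  # the alphas shown in the spectrum
    steered = latent + alpha * direction   # move along the diversity PC
    # text = decode(steered)               # decoding step omitted in this sketch
```

At `alpha=0.0` the latent is unchanged, which is why that row is the baseline sample in the spectrum.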
analysis_outputs/outputs_all_models_20260325/T32/task4_3d.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task4_raw_results.json ADDED
@@ -0,0 +1,8 @@
+{
+  "32": {
+    "bertscore_f1": 0.04221478336089375,
+    "semantic_sim": 0.0011696306429548563,
+    "bleu": 3.0458312005937454e-233,
+    "speed_per_sample": 1.8451481468771818
+  }
+}
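The table in task4_report.txt below is a rounded rendering of this JSON. A minimal sketch of producing that table from the raw results (the JSON literal is copied from the diff above; the formatting widths are assumptions, not the project's exact code):

```python
import json

# Sketch: render the task4 summary table from task4_raw_results.json.
raw = json.loads("""{
  "32": {
    "bertscore_f1": 0.04221478336089375,
    "semantic_sim": 0.0011696306429548563,
    "bleu": 3.0458312005937454e-233,
    "speed_per_sample": 1.8451481468771818
  }
}""")

print(f"{'T':<6}{'BERT-F1':>9}{'SEM_SIM':>9}{'BLEU':>9}{'sec/sample':>12}")
for t, m in sorted(raw.items(), key=lambda kv: int(kv[0])):
    # Four-decimal rounding reproduces the report's 0.0422 / 0.0012 / 0.0000 / 1.8451 row.
    print(f"{t:<6}{m['bertscore_f1']:>9.4f}{m['semantic_sim']:>9.4f}"
          f"{m['bleu']:>9.4f}{m['speed_per_sample']:>12.4f}")
```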
analysis_outputs/outputs_all_models_20260325/T32/task4_report.txt ADDED
@@ -0,0 +1,14 @@
+TASK 4 — SEMANTIC ROBUSTNESS ABLATION
+==================================================
+
+Optimal diffusion steps = 32
+
+T     BERT-F1   SEM_SIM   BLEU     sec/sample
+--------------------------------------------------------
+32    0.0422    0.0012    0.0000   1.8451
+
+Marginal gains (BERT-F1):
+
+Saved plots/files:
+- task4_3d.png
+- task4_raw_results.json