Re-upload project files (code + data)
Browse files. Re-uploaded all project files, including source code and datasets.
Ensured consistency between data, preprocessing, and model components.
- .gitattributes +2 -0
- analysis/AutoGPU.py +50 -0
- analysis/__init__.py +7 -0
- artifacts/highest_energy_digits.png +0 -0
- artifacts/llm_output_gpu0_seed1.json +20 -0
- artifacts/lowest_energy_digits.png +0 -0
- artifacts/median_energy_digits.png +0 -0
- artifacts/run.log +212 -0
- artifacts/samples_gpu0_seed1_perfect.png +3 -0
- artifacts/samples_gpu0_seed1_refined.png +3 -0
- artifacts/samples_gpu0_seed1_symbol.png +3 -0
- artifacts/samples_prof_gpu0_seed1_perfect.png +3 -0
- artifacts/samples_prof_gpu0_seed1_refined.png +3 -0
- artifacts/samples_prof_gpu0_seed1_symbol.png +3 -0
- artifacts/srtrbm_core_dynamics.pdf +0 -0
- artifacts/srtrbm_energy_data_vs_model_modern.pdf +0 -0
- artifacts/srtrbm_energy_diagnostics.pdf +0 -0
- artifacts/srtrbm_filters.png +3 -0
- artifacts/srtrbm_phase_diagram.pdf +0 -0
- correction/NO.py +75 -0
- correction/__init__.py +5 -0
- graphs/SrtrbmEnergy.py +177 -0
- graphs/SrtrbmMetrics.py +220 -0
- graphs/SrtrbmVisualization.py +116 -0
- graphs/__init__.py +29 -0
- llmeS/__init__.py +22 -0
- llmeS/client.py +92 -0
- llmeS/gateway.py +632 -0
- llmeS/hook.py +279 -0
- srtrbm_project_core.py +1600 -0
- stan.dgts +3 -0
- supplement/cluster.py +300 -0
- yaml/perception.yaml +53 -0
- zeta_mnist_hybrid.pt +3 -0
.gitattributes
CHANGED
|
@@ -39,3 +39,5 @@ artifacts/samples_prof_gpu0_seed1_perfect.png filter=lfs diff=lfs merge=lfs -tex
|
|
| 39 |
artifacts/samples_prof_gpu0_seed1_symbol.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
artifacts/srtrbm_filters.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
stan.dgts filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 39 |
artifacts/samples_prof_gpu0_seed1_symbol.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
artifacts/srtrbm_filters.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
stan.dgts filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
artifacts/samples_gpu0_seed1_refined.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
artifacts/samples_prof_gpu0_seed1_refined.png filter=lfs diff=lfs merge=lfs -text
|
analysis/AutoGPU.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pynvml import *
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class GPUEnergyTracker:
    """
    NVML-based GPU energy monitor.

    Polls the instantaneous board power via NVIDIA NVML and accumulates
    it over wall-clock time (trapezoidal rule) to estimate the total
    energy drawn since construction.

    Units:
        power  -- Watts
        energy -- Joules
    """

    def __init__(self, gpu_index=0):
        # NVML must be initialized before any device handle is queried.
        nvmlInit()
        self.handle = nvmlDeviceGetHandleByIndex(gpu_index)
        self.energy_joules = 0.0
        self.last_time = time.time()
        self.last_power = self._read_power()

    def _read_power(self):
        # NVML reports power draw in milliwatts; convert to Watts.
        return nvmlDeviceGetPowerUsage(self.handle) / 1000.0

    def step(self):
        """Sample power and integrate energy since the previous sample."""
        now = time.time()
        watts = self._read_power()
        elapsed = now - self.last_time
        # Trapezoidal rule: average of the two endpoint powers times dt.
        self.energy_joules += (self.last_power + watts) * elapsed / 2.0
        self.last_power = watts
        self.last_time = now

    def total_energy(self):
        """Return accumulated energy in Joules."""
        return self.energy_joules

    def current_power(self):
        """Return the most recently sampled power in Watts."""
        return self.last_power
|
analysis/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from analysis.AutoGPU import (
|
| 2 |
+
GPUEnergyTracker
|
| 3 |
+
)
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"GPUEnergyTracker",
|
| 7 |
+
]
|
artifacts/highest_energy_digits.png
ADDED
|
artifacts/llm_output_gpu0_seed1.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"regime": "learning",
|
| 3 |
+
"phase": "ordered",
|
| 4 |
+
"failures": [],
|
| 5 |
+
"scores": {
|
| 6 |
+
"temperature": 0.46,
|
| 7 |
+
"gibbs": 0.43
|
| 8 |
+
},
|
| 9 |
+
"risk": {
|
| 10 |
+
"stagnation": 0.05,
|
| 11 |
+
"collapse": 0.0,
|
| 12 |
+
"over_ordering": 0.3
|
| 13 |
+
},
|
| 14 |
+
"actions": {},
|
| 15 |
+
"confidence": 0.2719435542821884,
|
| 16 |
+
"analysis": "High image_similarity (0.8597) and visually preserved digit structure indicate modes are retained and refinement matches the reference manifold. Learning_signal (0.328) and heating/positive trend slopes show continued parameter change, so the system is still learning. Entropy and diversity are low (entropy 0.3107 overall, current entropy 0.1565; diversity 0.2076), and mixing is slow (mcmc_tau_int 27.65), but per the strict rules these do not imply collapse when structural similarity is high. Beta is high (current beta 4.386) and outputs are structured → ordered phase. Overall classification: the system is in a learning regime with ordered phase. Epistemic note: confidence is limited by conflicting signals (very high similarity and active learning vs low diversity and poor mixing); therefore confidence is substantial but constrained by those contradictory indicators. [OVERRIDDEN: high structural similarity]",
|
| 17 |
+
"reason": "json|low_confidence_scaled",
|
| 18 |
+
"evidence": 0.2719435542821884,
|
| 19 |
+
"llm_conf_raw": 0.75
|
| 20 |
+
}
|
artifacts/lowest_energy_digits.png
ADDED
|
artifacts/median_energy_digits.png
ADDED
|
artifacts/run.log
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Training: 100%|██████████████████████████████████████| 400/400 [04:51<00:00, 1.37it/s, T=0.923, flip=0.055, beta=4.386]
|
| 2 |
+
|
| 3 |
+
[GPU 0] Running AIS...
|
| 4 |
+
✔ AIS finished in 96.37 seconds
|
| 5 |
+
coefficient (kappa): 784
|
| 6 |
+
coefficient (kappa): 784
|
| 7 |
+
|
| 8 |
+
[LLM] Start | B=1200 | T=0.9234
|
| 9 |
+
[Debug] ESS=2.87 | k=13 | deltaE_band=[0.00,2.00] | in_band=111 | selected=13
|
| 10 |
+
[Debug] d_model=0.9904, d_llm=1.0000
|
| 11 |
+
[Step] Sample=101 | REJECT | d_model=0.990 | d_total=2.280 | p=0.102
|
| 12 |
+
[Debug] d_model=1.0123, d_llm=0.0065
|
| 13 |
+
[Step] Sample=542 | REJECT | d_model=1.012 | d_total=1.021 | p=0.360
|
| 14 |
+
[Debug] d_model=1.0161, d_llm=0.7791
|
| 15 |
+
[Step] Sample=922 | REJECT | d_model=1.016 | d_total=2.021 | p=0.133
|
| 16 |
+
[Debug] d_model=1.0165, d_llm=-1.0241
|
| 17 |
+
[Step] Sample=677 | ACCEPT | d_model=1.017 | d_total=-0.305 | p=1.000
|
| 18 |
+
[Debug] d_model=0.9680, d_llm=-0.9760
|
| 19 |
+
[Step] Sample=1044 | ACCEPT | d_model=0.968 | d_total=-0.291 | p=1.000
|
| 20 |
+
[Debug] d_model=1.0358, d_llm=1.0429
|
| 21 |
+
[Step] Sample=1052 | ACCEPT | d_model=1.036 | d_total=2.381 | p=0.092
|
| 22 |
+
[Debug] d_model=1.0585, d_llm=-1.0636
|
| 23 |
+
[Step] Sample=15 | ACCEPT | d_model=1.059 | d_total=-0.314 | p=1.000
|
| 24 |
+
[Debug] d_model=1.0721, d_llm=0.0744
|
| 25 |
+
[Step] Sample=385 | REJECT | d_model=1.072 | d_total=1.168 | p=0.311
|
| 26 |
+
[Debug] d_model=1.0757, d_llm=1.0757
|
| 27 |
+
[Step] Sample=641 | REJECT | d_model=1.076 | d_total=2.463 | p=0.085
|
| 28 |
+
[Debug] d_model=1.0759, d_llm=-1.0733
|
| 29 |
+
[Step] Sample=404 | ACCEPT | d_model=1.076 | d_total=-0.309 | p=1.000
|
| 30 |
+
[Debug] d_model=1.0795, d_llm=-0.3774
|
| 31 |
+
[Step] Sample=813 | REJECT | d_model=1.079 | d_total=0.593 | p=0.553
|
| 32 |
+
[Debug] d_model=1.0859, d_llm=1.0780
|
| 33 |
+
[Step] Sample=636 | REJECT | d_model=1.086 | d_total=2.477 | p=0.084
|
| 34 |
+
[Debug] d_model=0.9026, d_llm=-0.8987
|
| 35 |
+
[Step] Sample=218 | ACCEPT | d_model=0.903 | d_total=-0.257 | p=1.000
|
| 36 |
+
|
| 37 |
+
[LLM Summary]
|
| 38 |
+
LLM used : 13/1200
|
| 39 |
+
Accepted : 6/13
|
| 40 |
+
T : 0.9234
|
| 41 |
+
[LLM Done]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
Physical Refinement Diagnostics: 0.5296437740325928
|
| 45 |
+
Binder cumulant (1): 0.6467613577842712
|
| 46 |
+
Chi improvement: -0.12803667783737183
|
| 47 |
+
Binder improvement: 0.0007802844047546387
|
| 48 |
+
|
| 49 |
+
[LLM] Start | B=1200 | T=0.9234
|
| 50 |
+
[Debug] ESS=1.88 | k=13 | deltaE_band=[0.00,2.00] | in_band=100 | selected=13
|
| 51 |
+
[Debug] d_model=0.9798, d_llm=0.9763
|
| 52 |
+
[Step] Sample=73 | REJECT | d_model=0.980 | d_total=2.239 | p=0.107
|
| 53 |
+
[Debug] d_model=1.0320, d_llm=-0.7650
|
| 54 |
+
[Step] Sample=749 | ACCEPT | d_model=1.032 | d_total=0.045 | p=0.956
|
| 55 |
+
[Debug] d_model=0.9327, d_llm=0.9304
|
| 56 |
+
[Step] Sample=758 | REJECT | d_model=0.933 | d_total=2.133 | p=0.118
|
| 57 |
+
[Debug] d_model=1.0745, d_llm=-0.8507
|
| 58 |
+
[Step] Sample=991 | ACCEPT | d_model=1.075 | d_total=-0.023 | p=1.000
|
| 59 |
+
[Debug] d_model=0.9232, d_llm=-0.9211
|
| 60 |
+
[Step] Sample=68 | ACCEPT | d_model=0.923 | d_total=-0.265 | p=1.000
|
| 61 |
+
[Debug] d_model=0.9205, d_llm=0.9206
|
| 62 |
+
[Step] Sample=213 | REJECT | d_model=0.921 | d_total=2.108 | p=0.121
|
| 63 |
+
[Debug] d_model=0.9111, d_llm=0.9135
|
| 64 |
+
[Step] Sample=970 | REJECT | d_model=0.911 | d_total=2.090 | p=0.124
|
| 65 |
+
[Debug] d_model=0.9092, d_llm=0.9139
|
| 66 |
+
[Step] Sample=76 | REJECT | d_model=0.909 | d_total=2.088 | p=0.124
|
| 67 |
+
[Debug] d_model=0.8912, d_llm=-0.8984
|
| 68 |
+
[Step] Sample=368 | ACCEPT | d_model=0.891 | d_total=-0.268 | p=1.000
|
| 69 |
+
[Debug] d_model=1.1118, d_llm=1.1166
|
| 70 |
+
[Step] Sample=875 | ACCEPT | d_model=1.112 | d_total=2.552 | p=0.078
|
| 71 |
+
[Debug] d_model=1.1135, d_llm=-1.1141
|
| 72 |
+
[Step] Sample=533 | ACCEPT | d_model=1.113 | d_total=-0.324 | p=1.000
|
| 73 |
+
[Debug] d_model=1.1337, d_llm=1.1295
|
| 74 |
+
[Step] Sample=597 | REJECT | d_model=1.134 | d_total=2.591 | p=0.075
|
| 75 |
+
[Debug] d_model=1.1339, d_llm=1.1250
|
| 76 |
+
[Step] Sample=696 | REJECT | d_model=1.134 | d_total=2.585 | p=0.075
|
| 77 |
+
|
| 78 |
+
[LLM Summary]
|
| 79 |
+
LLM used : 13/1200
|
| 80 |
+
Accepted : 6/13
|
| 81 |
+
T : 0.9234
|
| 82 |
+
[LLM Done]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
Secondary magnetic susceptibility: 0.5303375720977783
|
| 86 |
+
Binder cumulant (2): 0.6477672457695007
|
| 87 |
+
Chi improvement: -0.12793707847595215
|
| 88 |
+
Binder improvement: 0.0005515813827514648
|
| 89 |
+
|
| 90 |
+
Running energy analysis ↓
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
Running diagnostics with the following calculations ↙
|
| 94 |
+
|
| 95 |
+
✔ Energy distribution finished in 149.96 seconds
|
| 96 |
+
✔ Energy landscape extremes finished in 0.06 seconds
|
| 97 |
+
✔ Phase diagram finished in 1.17 seconds
|
| 98 |
+
✔ RBM filters finished in 0.06 seconds
|
| 99 |
+
✔ Sample quality metrics finished in 12.30 seconds
|
| 100 |
+
|
| 101 |
+
Diagnostics completed in 164.88 seconds
|
| 102 |
+
|
| 103 |
+
GPU Energy Used : 41135.66 Joules
|
| 104 |
+
|
| 105 |
+
Device: cuda
|
| 106 |
+
, ↘
|
| 107 |
+
|
| 108 |
+
samples_gpu0_seed1_refined.png → Patches: torch.Size([1200, 1, 28, 28])
|
| 109 |
+
samples_gpu0_seed1_refined.png → Avg MSE: 0.1186484
|
| 110 |
+
Saved: samples_gpu0_seed1_perfect.png
|
| 111 |
+
|
| 112 |
+
Device: cuda
|
| 113 |
+
, ↘
|
| 114 |
+
|
| 115 |
+
samples_prof_gpu0_seed1_refined.png → Patches: torch.Size([1200, 1, 28, 28])
|
| 116 |
+
samples_prof_gpu0_seed1_refined.png → Avg MSE: 0.11937518
|
| 117 |
+
Saved: samples_prof_gpu0_seed1_perfect.png
|
| 118 |
+
|
| 119 |
+
DEBUG perfect_img_path: samples_prof_gpu0_seed1_perfect.png
|
| 120 |
+
[IMAGE LPIPS] 0.14028222858905792
|
| 121 |
+
|
| 122 |
+
[LLM RAW OUTPUT]
|
| 123 |
+
High image_similarity (0.86) and visibly preserved, sharp digit structure indicate the refinement preserves the reference manifold — a strong positive signal that modes are structurally captured. Diversity and entropy are low, and mixing (tau_int) is slow, but by the provided rules those alone do not imply failure. The system shows an active learning signal, positive temperature and beta trends, and heating — so weights appear to still be changing. Low entropy combined with high beta and high structural fidelity point to an ordered output regime rather than critical or disordered. Collapse risk is minimal because structural similarity and reconstruction quality are high, though low diversity raises a moderate risk of over-ordering (reduced micro-mode variety). Confidence is bounded by the evidence: similarity and learning flags are strong, but diversity/mixing provide limited contradictory information, so confidence is substantial but not maximal.
|
| 124 |
+
|
| 125 |
+
{
|
| 126 |
+
"analysis": "High image_similarity (0.8597) and visually preserved digit structure indicate modes are retained and refinement matches the reference manifold. Learning_signal (0.328) and heating/positive trend slopes show continued parameter change, so the system is still learning. Entropy and diversity are low (entropy 0.3107 overall, current entropy 0.1565; diversity 0.2076), and mixing is slow (mcmc_tau_int 27.65), but per the strict rules these do not imply collapse when structural similarity is high. Beta is high (current beta 4.386) and outputs are structured → ordered phase. Overall classification: the system is in a learning regime with ordered phase. Epistemic note: confidence is limited by conflicting signals (very high similarity and active learning vs low diversity and poor mixing); therefore confidence is substantial but constrained by those contradictory indicators.",
|
| 127 |
+
"regime": "learning",
|
| 128 |
+
"phase": "ordered",
|
| 129 |
+
"scores": {
|
| 130 |
+
"temperature": 0.92,
|
| 131 |
+
"gibbs": 0.86
|
| 132 |
+
},
|
| 133 |
+
"risk": {
|
| 134 |
+
"stagnation": 0.05,
|
| 135 |
+
"collapse": 0.02,
|
| 136 |
+
"over_ordering": 0.30
|
| 137 |
+
},
|
| 138 |
+
"confidence": 0.75
|
| 139 |
+
}
|
| 140 |
+
[SYS] EVIDENCE=0.272 | LLM=0.750 → FINAL=0.272
|
| 141 |
+
|
| 142 |
+
[LLM CLEAN RESULT]
|
| 143 |
+
{'regime': 'learning', 'phase': 'ordered', 'failures': [], 'scores': {'temperature': 0.46, 'gibbs': 0.43}, 'risk': {'stagnation': 0.05, 'collapse': 0.0, 'over_ordering': 0.3}, 'actions': {}, 'confidence': 0.2719435542821884, 'analysis': 'High image_similarity (0.8597) and visually preserved digit structure indicate modes are retained and refinement matches the reference manifold. Learning_signal (0.328) and heating/positive trend slopes show continued parameter change, so the system is still learning. Entropy and diversity are low (entropy 0.3107 overall, current entropy 0.1565; diversity 0.2076), and mixing is slow (mcmc_tau_int 27.65), but per the strict rules these do not imply collapse when structural similarity is high. Beta is high (current beta 4.386) and outputs are structured → ordered phase. Overall classification: the system is in a learning regime with ordered phase. Epistemic note: confidence is limited by conflicting signals (very high similarity and active learning vs low diversity and poor mixing); therefore confidence is substantial but constrained by those contradictory indicators. [OVERRIDDEN: high structural similarity]', 'reason': 'json|low_confidence_scaled', 'evidence': 0.2719435542821884, 'llm_conf_raw': 0.75}
|
| 144 |
+
[ANASIS] collapse=0.139 stagnation=0.320 healthy_align=0.540 → risk=0.406
|
| 145 |
+
[ANALYSIS SIGNAL] 0.406
|
| 146 |
+
[LLM OUTPUT SAVED] → llm_output_gpu0_seed1.json
|
| 147 |
+
|
| 148 |
+
[LLM RESULT]
|
| 149 |
+
|
| 150 |
+
{
|
| 151 |
+
"regime": "learning",
|
| 152 |
+
"phase": "ordered",
|
| 153 |
+
"failures": [],
|
| 154 |
+
"scores": {
|
| 155 |
+
"temperature": 0.46,
|
| 156 |
+
"gibbs": 0.43
|
| 157 |
+
},
|
| 158 |
+
"risk": {
|
| 159 |
+
"stagnation": 0.05,
|
| 160 |
+
"collapse": 0.0,
|
| 161 |
+
"over_ordering": 0.3
|
| 162 |
+
},
|
| 163 |
+
"actions": {},
|
| 164 |
+
"confidence": 0.2719435542821884,
|
| 165 |
+
"analysis": "High image_similarity (0.8597) and visually preserved digit structure
|
| 166 |
+
indicate modes are retained and refinement matches the reference manifold. Learning_signal
|
| 167 |
+
(0.328) and heating/positive trend slopes show continued parameter change, so the system
|
| 168 |
+
is still learning. Entropy and diversity are low (entropy 0.3107 overall, current entropy
|
| 169 |
+
0.1565; diversity 0.2076), and mixing is slow (mcmc_tau_int 27.65), but per the strict
|
| 170 |
+
rules these do not imply collapse when structural similarity is high. Beta is high
|
| 171 |
+
(current beta 4.386) and outputs are structured → ordered phase. Overall classification:
|
| 172 |
+
the system is in a learning regime with ordered phase. Epistemic note: confidence is
|
| 173 |
+
limited by conflicting signals (very high similarity and active learning vs low diversity
|
| 174 |
+
and poor mixing); therefore confidence is substantial but constrained by those
|
| 175 |
+
contradictory indicators. [OVERRIDDEN: high structural similarity]",
|
| 176 |
+
"reason": "json|low_confidence_scaled",
|
| 177 |
+
"evidence": 0.2719435542821884,
|
| 178 |
+
"llm_conf_raw": 0.75
|
| 179 |
+
}
|
| 180 |
+
Hybrid Thermodynamic RBM Final Results
|
| 181 |
+
Seed: 1 | GPU: 0
|
| 182 |
+
Final Temperature : 0.923444
|
| 183 |
+
Weight Norm : 68.468369
|
| 184 |
+
Spectral Gain : 74.144559
|
| 185 |
+
Train Log-Likelihood : -411.152252
|
| 186 |
+
Test Log-Likelihood : -410.739502
|
| 187 |
+
Train Pseudo-Likelihood : -0.678435
|
| 188 |
+
Test Pseudo-Likelihood : -0.678006
|
| 189 |
+
Reconstruction MSE : 0.015900
|
| 190 |
+
Reconstruction Accuracy : 0.980278
|
| 191 |
+
Mean Data Energy : -652.020264
|
| 192 |
+
Mean Model Energy : -632.642517
|
| 193 |
+
Energy Gap : 19.377747
|
| 194 |
+
GPU Energy Used : 41135.66 J
|
| 195 |
+
Log Partition (AIS) : 1063.172485
|
| 196 |
+
AIS ESS : 2483.00
|
| 197 |
+
AIS Log-Weight Variance : 1.972837
|
| 198 |
+
MCMC tau_int : 27.66
|
| 199 |
+
MCMC tau_std : 4.90
|
| 200 |
+
MCMC tau_max : 40.42
|
| 201 |
+
MCMC tau_min : 10.32
|
| 202 |
+
MCMC ESS : 359672.06
|
| 203 |
+
MCMC ESS / chain : 299.73
|
| 204 |
+
MCMC R-hat : 1.1702
|
| 205 |
+
MCMC ACF length : 25
|
| 206 |
+
Pixel Entropy : 0.310680
|
| 207 |
+
Sample Diversity : 0.207591
|
| 208 |
+
Mean Distribution L2 : 0.095660
|
| 209 |
+
Professional Samples File: samples_prof_gpu0_seed1.png
|
| 210 |
+
Generated Samples File : samples_gpu0_seed1.png
|
| 211 |
+
|
| 212 |
+
(srtrbm_env) gosachunt_uv0@GosachkPC:~/srtrbm_project$
|
artifacts/samples_gpu0_seed1_perfect.png
ADDED
|
Git LFS Details
|
artifacts/samples_gpu0_seed1_refined.png
ADDED
|
Git LFS Details
|
artifacts/samples_gpu0_seed1_symbol.png
ADDED
|
Git LFS Details
|
artifacts/samples_prof_gpu0_seed1_perfect.png
ADDED
|
Git LFS Details
|
artifacts/samples_prof_gpu0_seed1_refined.png
ADDED
|
Git LFS Details
|
artifacts/samples_prof_gpu0_seed1_symbol.png
ADDED
|
Git LFS Details
|
artifacts/srtrbm_core_dynamics.pdf
ADDED
|
Binary file (24 kB). View file
|
|
|
artifacts/srtrbm_energy_data_vs_model_modern.pdf
ADDED
|
Binary file (18.6 kB). View file
|
|
|
artifacts/srtrbm_energy_diagnostics.pdf
ADDED
|
Binary file (40.1 kB). View file
|
|
|
artifacts/srtrbm_filters.png
ADDED
|
Git LFS Details
|
artifacts/srtrbm_phase_diagram.pdf
ADDED
|
Binary file (37 kB). View file
|
|
|
correction/NO.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Refinement:
    """
    Post-hoc sample refinement around an RBM-style model.

    The wrapped model is expected to expose `temperature()`,
    `free_energy(v, T)`, `gibbs_chain(v, T, steps)`, and the parameters
    `W`, `hidden_bias`, and `visible_bias`.
    """

    def __init__(self, model):
        # Model whose energy function and parameters drive refinement.
        self.model = model

    @torch.no_grad()
    def mh_step(self, states, steps=6):
        """Metropolis-Hastings refinement using short Gibbs proposals."""
        T = self.model.temperature()
        v = states.clone()
        for _ in range(steps):
            candidate = self.model.gibbs_chain(v, T, steps=2)
            F_old = self.model.free_energy(v, T)
            F_new = self.model.free_energy(candidate, T)
            # exp(clamp(-dF/T, max=0)) == min(1, exp(-dF/T)), overflow-safe.
            p_accept = torch.exp(torch.clamp(-(F_new - F_old) / T, max=0))
            mask = torch.rand_like(p_accept) < p_accept
            mask = mask.view(-1, *([1] * (v.dim() - 1)))
            v = torch.where(mask, candidate, v)
        return v

    @torch.no_grad()
    def energy_guided_refine(self, states, steps=5):
        """Keep strictly-better proposals; otherwise accept with MH probability."""
        T = self.model.temperature()
        v = states.clone()
        for _ in range(steps):
            candidate = self.model.gibbs_chain(v, T, steps=2)
            F_old = self.model.free_energy(v, T)
            F_new = self.model.free_energy(candidate, T)
            improved = F_new < F_old
            p_accept = torch.exp(torch.clamp(-(F_new - F_old) / T, max=0))
            random_ok = torch.rand_like(p_accept) < p_accept
            # Greedy acceptance is a superset of the stochastic one when
            # dF < 0 (p == 1), so this matches plain MH on improvements.
            mask = (improved | random_ok).view(-1, *([1] * (v.dim() - 1)))
            v = torch.where(mask, candidate, v)
        return v

    @torch.no_grad()
    def soft_refine(self, states, steps=10):
        """Relax visibles toward their conditional means (0.7/0.3 blend)."""
        T = self.model.temperature()
        v = states.clone()
        for _ in range(steps):
            h = torch.bernoulli(
                torch.sigmoid((v @ self.model.W + self.model.hidden_bias) / T)
            )
            v_mean = torch.sigmoid((h @ self.model.W.T + self.model.visible_bias) / T)
            v = 0.7 * v + 0.3 * v_mean
        return v

    @torch.no_grad()
    def myra_refine(self, states):
        """Full pipeline: MH, then energy-guided, then soft refinement."""
        return self.soft_refine(
            self.energy_guided_refine(self.mh_step(states, steps=6), steps=5),
            steps=8,
        )
|
correction/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from correction.NO import Refinement
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"Refinement",
|
| 5 |
+
]
|
graphs/SrtrbmEnergy.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
from scipy.stats import gaussian_kde
|
| 5 |
+
from graphs.SrtrbmVisualization import save_digit_grid
|
| 6 |
+
|
| 7 |
+
# Modern plotting style
|
| 8 |
+
|
| 9 |
+
plt.style.use("seaborn-v0_8-whitegrid")
|
| 10 |
+
|
| 11 |
+
COLORS = {
|
| 12 |
+
"data": "#1f77b4", # deep blue
|
| 13 |
+
"model": "#ff7f0e", # orange
|
| 14 |
+
"mean_data": "#1f77b4",
|
| 15 |
+
"mean_model": "#ff7f0e"
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Energy ranking
|
| 20 |
+
|
| 21 |
+
@torch.no_grad()
def compute_energy_ranking(model, data):
    """Rank samples by free energy at the model's current temperature.

    Returns a pair of CPU tensors `(energies, order)` where `order`
    indexes `data` from lowest to highest free energy.
    """
    T = model.temperature()
    F = model.free_energy(data, T).detach()
    order = torch.argsort(F)
    return F.cpu(), order.cpu()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Energy landscape visualization
|
| 33 |
+
|
| 34 |
+
@torch.no_grad()
def visualize_energy_extremes(model, data, k=100):
    """Save grids of the k lowest-, median-, and highest-energy samples.

    Writes lowest_energy_digits.png, median_energy_digits.png, and
    highest_energy_digits.png to the working directory.
    """
    _, order = compute_energy_ranking(model, data)
    n = len(data)
    half = k // 2

    lowest = data[order[:k]]
    median = data[order[n // 2 - half: n // 2 + half]]
    highest = data[order[-k:]]

    save_digit_grid(lowest, "lowest_energy_digits.png")
    save_digit_grid(median, "median_energy_digits.png")
    save_digit_grid(highest, "highest_energy_digits.png")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Energy distribution analysis
|
| 50 |
+
|
| 51 |
+
@torch.no_grad()
def plot_data_vs_model_energy(model, data):
    """Compare free-energy distributions of data vs. model samples.

    Draws overlaid histograms with KDE curves and mean-value markers,
    saves the figure to srtrbm_energy_data_vs_model_modern.pdf, and
    returns the two mean energies and their absolute gap.
    """
    T = model.temperature()
    # Use the full dataset.
    n = int(data.shape[0])

    # Free energy of the data.
    F_data = model.free_energy(data[:n], T).detach().cpu().numpy()

    # Free energy of freshly drawn model samples (one chain per datum).
    drawn = model.generate_ensemble_samples(
        n_chains=n,
        steps=2000,
    )
    F_model = model.free_energy(drawn, T).detach().cpu().numpy()

    # Shared support for histograms and KDE curves.
    lo = min(F_data.min(), F_model.min())
    hi = max(F_data.max(), F_model.max())
    bins = np.linspace(lo, hi, 120)
    grid = np.linspace(lo, hi, 500)

    kde_data = gaussian_kde(F_data)
    kde_model = gaussian_kde(F_model)

    plt.figure(figsize=(7, 5))
    plt.hist(F_data, bins=bins, density=True, alpha=0.35,
             color=COLORS["data"], label="Data")
    plt.hist(F_model, bins=bins, density=True, alpha=0.35,
             color=COLORS["model"], label="Model")

    # Smooth density estimates on top of the histograms.
    plt.plot(grid, kde_data(grid), color=COLORS["data"], linewidth=2.5)
    plt.plot(grid, kde_model(grid), color=COLORS["model"], linewidth=2.5)

    # Mean-energy markers for each distribution.
    mean_data = F_data.mean()
    mean_model = F_model.mean()
    plt.axvline(mean_data, linestyle="--", linewidth=2,
                color=COLORS["mean_data"])
    plt.axvline(mean_model, linestyle=":", linewidth=2,
                color=COLORS["mean_model"])

    energy_gap = abs(mean_data - mean_model)

    plt.xlabel("Free Energy")
    plt.ylabel("Probability Density")
    plt.title(f"Free Energy Distribution (ΔF = {energy_gap:.2f})")
    plt.legend(frameon=False)
    plt.grid(alpha=0.25)
    plt.tight_layout()
    plt.savefig("srtrbm_energy_data_vs_model_modern.pdf", bbox_inches="tight")
    plt.close()

    return {
        "mean_data_energy": float(mean_data),
        "mean_model_energy": float(mean_model),
        "energy_gap": float(energy_gap),
    }
|
graphs/SrtrbmMetrics.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
|
| 5 |
+
# Modern plotting style
|
| 6 |
+
|
| 7 |
+
plt.style.use("seaborn-v0_8-whitegrid")
|
| 8 |
+
|
| 9 |
+
COLORS = {
|
| 10 |
+
"flip": "#1f77b4",
|
| 11 |
+
"smooth": "#d62728",
|
| 12 |
+
"sus": "#2ca02c"
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Sample quality diagnostics
|
| 17 |
+
|
| 18 |
+
@torch.no_grad()
def sample_quality_metrics(model, real_data, n_samples=5000, diversity_pairs=3000):
    """Compute quality diagnostics for model-generated samples.

    Returns a dict with:
        pixel_entropy -- mean Bernoulli entropy of per-pixel activations
        diversity     -- mean L1 distance between randomly paired samples
        mean_l2       -- scale-normalized L2 gap between the real and
                         generated mean images
    """
    drawn = model.generate_ensemble_samples(
        n_chains=n_samples,
        steps=2000,
    ).float()

    # Per-pixel Bernoulli entropy of the generated distribution
    # (clamped to avoid log(0)).
    p = drawn.mean(0)
    log_p = torch.log(p.clamp(min=1e-8))
    log_q = torch.log((1 - p).clamp(min=1e-8))
    entropy = -(p * log_p + (1 - p) * log_q).mean()

    # Diversity: mean absolute distance between random sample pairs.
    flat = drawn.view(n_samples, -1)
    dev = flat.device
    a = torch.randint(0, n_samples, (diversity_pairs,), device=dev)
    b = torch.randint(0, n_samples, (diversity_pairs,), device=dev)
    diversity = torch.abs(flat[a] - flat[b]).mean(1).mean().item()

    # Distance between mean images, normalized by sqrt of pixel count.
    real_mean = real_data.float().mean(0)
    gen_mean = drawn.mean(0)
    mean_l2 = torch.norm(real_mean - gen_mean) / np.sqrt(real_mean.numel())

    return {
        "pixel_entropy": entropy.item(),
        "diversity": diversity,
        "mean_l2": mean_l2.item(),
    }
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Phase transition detection
|
| 60 |
+
|
| 61 |
+
def detect_critical_beta(beta, flip):
    """Estimate the critical inverse temperature from a flip-rate curve.

    The flip rate is smoothed with a moving average and the susceptibility
    is taken as the negative gradient of the smoothed curve; the beta at
    the susceptibility peak is returned.  With fewer than 5 points there
    is not enough signal, so the mean beta is returned as a fallback.

    Args:
        beta: sequence of inverse-temperature values (any order).
        flip: flip-rate observed at each corresponding beta.

    Returns:
        Estimated critical beta (an element of ``beta``, or its mean).
    """
    beta = np.array(beta)
    flip = np.array(flip)

    if len(beta) < 5:
        return beta.mean()

    order = np.argsort(beta)
    beta = beta[order]
    flip = flip[order]

    # Moving-average smoothing.  Cap the window at the series length:
    # np.convolve(..., mode="same") returns max(len(a), len(v)) values,
    # so a window longer than the data (5-6 points here) would produce
    # a length mismatch and crash np.gradient below.
    window = min(7, len(flip))
    kernel = np.ones(window) / window
    flip_smooth = np.convolve(flip, kernel, mode="same")

    # Susceptibility ~ -d(flip)/d(beta); its peak marks the transition.
    susceptibility = -np.gradient(flip_smooth, beta)
    idx = np.argmax(susceptibility)
    return beta[idx]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Phase diagram visualization
|
| 86 |
+
|
| 87 |
+
def plot_flip_beta(
    model,
    title,
    filename,
    fig_size=(10, 8),
    density=True
):
    """Render a three-panel phase diagram of flip rate vs. spectral beta.

    Panels: (1) raw flip-rate observations (hexbin density or scatter)
    with the smoothed curve and estimated critical beta overlaid,
    (2) the smoothed flip rate alone, (3) the susceptibility (negative
    gradient of the smoothed curve).  The figure is written to
    ``filename``.  Requires at least 5 recorded (beta, flip) pairs;
    otherwise the function returns without plotting.

    Args:
        model: object exposing ``spectral_beta_hist`` and ``flip_hist``.
        title: title for the top panel.
        filename: output path for the saved figure.
        fig_size: matplotlib figure size, (width, height) in inches.
        density: if True, draw a hexbin density map; else a scatter plot.
    """
    beta = np.array(model.spectral_beta_hist)
    flip = np.array(model.flip_hist)

    if len(beta) < 5:
        return

    # Sort by beta so smoothing and the gradient use an ordered axis.
    order = np.argsort(beta)
    beta = beta[order]
    flip = flip[order]

    # Moving-average smoothing.  Cap the window at the series length:
    # np.convolve(..., mode="same") yields max(len(a), len(v)) values,
    # so a window of 9 on 5-8 points would break np.gradient below.
    window = min(9, len(flip))
    kernel = np.ones(window) / window
    flip_smooth = np.convolve(flip, kernel, mode="same")

    # Susceptibility and critical-beta estimate (location of its peak).
    susceptibility = -np.gradient(flip_smooth, beta)
    beta_c = beta[np.argmax(susceptibility)]

    fig, axes = plt.subplots(3, 1, figsize=fig_size)

    # Panel 1 — phase diagram (density map or raw scatter).
    if density:
        hb = axes[0].hexbin(
            beta,
            flip,
            gridsize=40,
            cmap="viridis",
            mincnt=1
        )
        fig.colorbar(
            hb,
            ax=axes[0],
            label="Density",
            fraction=0.045,
            pad=0.03
        )
    else:
        axes[0].scatter(
            beta,
            flip,
            s=20,
            alpha=0.35,
            color=COLORS["flip"]
        )

    axes[0].plot(
        beta,
        flip_smooth,
        linewidth=2.5,
        color=COLORS["smooth"],
        label="Smoothed flip rate"
    )
    axes[0].axvline(
        beta_c,
        linestyle="--",
        linewidth=2,
        color="black",
        label=f"β_c ≈ {beta_c:.2f}"
    )
    axes[0].set_xlabel(r"$\beta_{\mathrm{spectral}}$")
    axes[0].set_ylabel("Flip Rate")
    axes[0].set_title(title)
    axes[0].legend()

    # Panel 2 — smoothed flip rate.
    axes[1].plot(beta, flip_smooth, linewidth=2.5, color=COLORS["smooth"])
    axes[1].axvline(beta_c, linestyle="--", color="black")
    axes[1].set_ylabel("Smoothed Flip")

    # Panel 3 — susceptibility.
    axes[2].plot(beta, susceptibility, linewidth=2.5, color=COLORS["sus"])
    axes[2].axvline(beta_c, linestyle="--", color="black")
    axes[2].set_xlabel(r"$\beta_{\mathrm{spectral}}$")
    axes[2].set_ylabel("Susceptibility")

    # Layout
    fig.subplots_adjust(
        left=0.12,
        right=0.96,
        hspace=0.35
    )

    plt.savefig(filename, bbox_inches="tight")
    plt.close()
|
graphs/SrtrbmVisualization.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Image-saving helpers for SR-TRBM digit samples, RBM filters and
# fantasy particles; all outputs are PNG grids built with torchvision
# and written via PIL.
import torch
import numpy as np
import torchvision.utils as vutils
from PIL import Image
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Save grid of digits
|
| 8 |
+
|
| 9 |
+
@torch.no_grad()
def save_digit_grid(data, filename, n_row=20, img_shape=(28, 28)):
    """Save a batch of flattened images as a single PNG grid.

    Args:
        data: tensor of images, reshapeable to ``(-1, 1, *img_shape)``.
        filename: output PNG path.
        n_row: number of images per grid row.
        img_shape: (height, width) of each image; defaults to MNIST's
            28x28 (the previous hard-coded value, so callers are
            unaffected).
    """
    imgs = data.reshape(-1, 1, *img_shape).detach().cpu()

    grid = vutils.make_grid(
        imgs,
        nrow=n_row,
        padding=2,
        normalize=True
    )

    # Convert the normalized [0, 1] grid to uint8 HWC for PIL.
    img = (grid * 255).clamp(0, 255).byte()
    img = img.permute(1, 2, 0).numpy()

    # Collapse a single channel to 2-D so PIL writes grayscale.
    if img.shape[2] == 1:
        img = img[:, :, 0]

    Image.fromarray(img).save(filename)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# RBM filters visualization
|
| 31 |
+
|
| 32 |
+
@torch.no_grad()
def visualize_rbm_filters(
    model,
    filename="srtrbm_filters.png",
    n_filters=256,
    img_shape=(28, 28)
):
    """Save the first ``n_filters`` RBM weight columns as a PNG grid.

    Each column of ``model.W`` is min-max normalized independently so
    every filter uses the full grayscale range.

    Args:
        model: object exposing a weight matrix ``W`` of shape
            (n_visible, n_hidden) with n_visible == prod(img_shape).
        filename: output PNG path.
        n_filters: maximum number of filters to render (clamped to the
            number of hidden units).
        img_shape: (height, width) used to fold each filter column;
            defaults to the previous hard-coded 28x28.
    """
    W = model.W.detach().cpu()

    n_filters = min(n_filters, W.shape[1])

    # One row per filter (hidden unit).
    filters = W[:, :n_filters].T

    # Per-filter min-max normalization to [0, 1]; epsilon guards
    # against constant (zero-range) filters.
    min_vals = filters.min(dim=1, keepdim=True)[0]
    max_vals = filters.max(dim=1, keepdim=True)[0]
    filters = (filters - min_vals) / (max_vals - min_vals + 1e-8)

    filters = filters.reshape(-1, 1, *img_shape)

    # Arrange the filters in an approximately square grid.
    n_row = int(np.ceil(np.sqrt(n_filters)))

    grid = vutils.make_grid(
        filters,
        nrow=n_row,
        padding=2,
        normalize=False
    )

    img = (grid * 255).clamp(0, 255).byte()
    img = img.permute(1, 2, 0).numpy()

    if img.shape[2] == 1:
        img = img[:, :, 0]

    Image.fromarray(img).save(filename)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# Fantasy particles (model samples)
|
| 73 |
+
|
| 74 |
+
@torch.no_grad()
def visualize_fantasy_particles(
    model,
    filename="fantasy_particles.png",
    n_chains=400,
    steps=2000
):
    """Sample persistent-chain "fantasy particles" and save them as a grid.

    Args:
        model: object exposing ``generate_ensemble_samples``.
        filename: output PNG path.
        n_chains: number of parallel sampling chains.
        steps: Gibbs steps to run per chain before saving.
    """
    particles = model.generate_ensemble_samples(
        n_chains=n_chains,
        steps=steps
    )

    # Lay the particles out in an approximately square grid.
    save_digit_grid(particles, filename, n_row=int(np.sqrt(n_chains)))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# Training visual monitoring
|
| 94 |
+
|
| 95 |
+
@torch.no_grad()
def save_training_visuals(model, epoch):
    """Dump per-epoch monitoring images into the working directory.

    Writes three PNGs named after ``epoch``: a grid of fresh samples,
    the current RBM filters, and the fantasy particles.
    """
    epoch_samples = model.generate_ensemble_samples(n_chains=400, steps=3000)

    save_digit_grid(epoch_samples, f"samples_epoch_{epoch}.png", n_row=20)
    visualize_rbm_filters(model, f"filters_epoch_{epoch}.png")
    visualize_fantasy_particles(model, f"fantasy_epoch_{epoch}.png")
|
graphs/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Public API of the ``graphs`` package: energy diagnostics, phase/metric
# plots, and image-grid visualization helpers.
from graphs.SrtrbmEnergy import (
    compute_energy_ranking,
    visualize_energy_extremes,
    plot_data_vs_model_energy
)

from graphs.SrtrbmMetrics import (
    sample_quality_metrics,
    plot_flip_beta,
    detect_critical_beta
)

from graphs.SrtrbmVisualization import (
    save_digit_grid,
    visualize_rbm_filters,
    visualize_fantasy_particles
)

# Explicit public surface for ``from graphs import *``.
__all__ = [
    "compute_energy_ranking",
    "visualize_energy_extremes",
    "plot_data_vs_model_energy",
    "sample_quality_metrics",
    "plot_flip_beta",
    "detect_critical_beta",
    "save_digit_grid",
    "visualize_rbm_filters",
    "visualize_fantasy_particles",
]
|
llmeS/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Public API of the ``llmeS`` package: the provider-agnostic LLM client,
# the evaluation gateway, and the GPU-side energy/component utilities.
from llmeS.client import SafeBookClient

from llmeS.gateway import Evaluate
from llmeS.gateway import ANASIS

from llmeS.hook import find_connected_components_fast

from llmeS.hook import LLMEnergy
from llmeS.hook import LIES_gpu

from llmeS.hook import to_sparse_gpu


# Explicit public surface for ``from llmeS import *``.
__all__ = [
    "SafeBookClient",
    "Evaluate",
    "ANASIS",
    "find_connected_components_fast",
    "LLMEnergy",
    "to_sparse_gpu",
    "LIES_gpu",
]
|
llmeS/client.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GLOBAL LLM CLIENT (FULLY AGNOSTIC + SAFE)
#
# - Works with ANY backend (see BaseLLM below)
# - Never crashes the system: all backend errors become a None result
# - Normalizes output: str | dict | anything -> plain string

import os
import time
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BaseLLM:
    """
    Abstract LLM backend; users must subclass and implement ``generate``.

    Expected:
        generate(prompt, images=None) -> str | dict | None
        (None signals "no usable output"; dicts are normalized by the
        client via their "text"/"output" keys)
    """

    def __init__(self, api_key=None):
        # Stored for concrete subclasses; unused by the base class itself.
        self.api_key = api_key

    def generate(self, prompt, images=None):
        raise NotImplementedError
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SafeBookClient:
    """Fault-tolerant wrapper around an LLM backend.

    All failures are contained: any backend error yields ``None``
    instead of propagating, so callers can treat LLM output as a
    best-effort signal.
    """

    def __init__(self):
        # The backend reads its key from the environment; no key means
        # the (abstract) backend will simply never produce output.
        api_key = os.getenv("LLM_API_KEY")

        self.enabled = True

        try:
            self.backend = BaseLLM(api_key=api_key)
        except Exception as e:
            print("[LLM DISABLED]", str(e))
            self.backend = None
            self.enabled = False

    def _normalize(self, result):
        """Normalize ANY backend output → string (or None)."""
        if result is None:
            return None

        # string: pass through unchanged
        if isinstance(result, str):
            return result

        # dict with a recognizable text payload
        if isinstance(result, dict):
            if "text" in result:
                return result["text"]
            if "output" in result:
                return result["output"]

        # fallback: best-effort stringification
        try:
            return str(result)
        except Exception:
            return None

    def generate(self, prompt, images=None):
        """Call the backend with retries; return normalized text or None.

        Args:
            prompt: prompt text for the backend.
            images: optional list of images passed through verbatim.
        """
        if not self.enabled or self.backend is None:
            return None

        max_retries = 3
        delay = 1.5

        for attempt in range(max_retries):
            try:
                result = self.backend.generate(prompt, images=images)
                return self._normalize(result)

            except NotImplementedError:
                # The abstract base backend is still installed; retrying
                # can never succeed, so disable the client immediately
                # instead of sleeping through the whole retry loop.
                print("[LLM DISABLED] backend.generate is not implemented")
                self.enabled = False
                return None

            except Exception as e:
                if attempt == max_retries - 1:
                    print("[LLM ERROR]", str(e))
                    return None

                # Linear backoff between retries.
                time.sleep(delay * (attempt + 1))

        return None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# GLOBAL INSTANCE
# Shared module-level client; import this instead of constructing a new one.
client = SafeBookClient()
|
llmeS/gateway.py
ADDED
|
@@ -0,0 +1,632 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LLM GATEWAY (PROVIDER-AGNOSTIC INTERFACE)
|
| 2 |
+
#
|
| 3 |
+
# This module provides a high-level evaluation interface that integrates
|
| 4 |
+
# thermodynamic metrics with LLM-based semantic interpretation.
|
| 5 |
+
#
|
| 6 |
+
# IMPORTANT DESIGN PRINCIPLE:
|
| 7 |
+
# The system is strictly LLM-provider agnostic.
|
| 8 |
+
# No specific API (OpenAI, Anthropic, Gemini, etc.) is assumed.
|
| 9 |
+
#
|
| 10 |
+
# All LLM interaction is routed through a global client:
|
| 11 |
+
#
|
| 12 |
+
# raw = client.generate(prompt, images=[...])
|
| 13 |
+
#
|
| 14 |
+
# The client is expected to return raw text output only.
|
| 15 |
+
# Any provider-specific formatting, API calls, or response schemas
|
| 16 |
+
# must be handled inside the client backend implementation.
|
| 17 |
+
#
|
| 18 |
+
# This module does NOT depend on:
|
| 19 |
+
# - response objects
|
| 20 |
+
# - structured API outputs
|
| 21 |
+
# - provider-specific message formats
|
| 22 |
+
#
|
| 23 |
+
# Instead, it assumes:
|
| 24 |
+
# - raw string output from an LLM
|
| 25 |
+
# - JSON content embedded within that output
|
| 26 |
+
#
|
| 27 |
+
# The gateway is responsible for:
|
| 28 |
+
# 1. Constructing structured prompts from system metrics
|
| 29 |
+
# 2. Passing inputs (text + images) to the LLM client
|
| 30 |
+
# 3. Extracting and validating JSON from raw responses
|
| 31 |
+
# 4. Enforcing epistemic and structural constraints
|
| 32 |
+
#
|
| 33 |
+
# If no LLM backend is provided:
|
| 34 |
+
# - The system will gracefully degrade
|
| 35 |
+
# - A fallback result will be returned
|
| 36 |
+
#
|
| 37 |
+
# This design ensures full decoupling between:
|
| 38 |
+
# - thermodynamic learning (SR-TRBM)
|
| 39 |
+
# - semantic interpretation (LLM)
|
| 40 |
+
#
|
| 41 |
+
# Result:
|
| 42 |
+
# The system remains stable, extensible, and backend-independent.
|
| 43 |
+
|
| 44 |
+
import torch.nn.functional as F
from PIL import Image
import numpy as np
import warnings
import base64
import lpips
import torch
import os
import json
import math
import copy
import yaml

# LPIPS emits benign UserWarnings on load; silence them.
warnings.filterwarnings("ignore", category=UserWarning)

# Run perceptual metrics on GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Module-level LPIPS model (VGG backbone), loaded once at import time and
# shared by compute_lpips / compute_lpips_diversity below.
lpips_model = lpips.LPIPS(net='vgg').to(device)
lpips_model.eval()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def encode_image(path):
    """Read an image file and return it as a base64 PNG data URI."""
    with open(path, "rb") as fh:
        payload = fh.read()
    encoded = base64.b64encode(payload).decode("utf-8")
    return "data:image/png;base64," + encoded
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def compute_lpips(path1, path2):
    """LPIPS perceptual distance between two image files.

    Both images are loaded as RGB, resized to 256x256 and scaled to
    [-1, 1] before being fed to the shared ``lpips_model``.  Any failure
    (missing file, decode error, ...) returns the worst-case distance 1.0.
    """
    try:
        tensors = []
        for path in (path1, path2):
            pil_img = Image.open(path).convert("RGB").resize((256, 256))
            arr = np.array(pil_img).astype("float32") / 255.0
            # HWC float in [0, 1] -> NCHW in [-1, 1], on the shared device.
            tensor = torch.tensor(arr).permute(2, 0, 1).unsqueeze(0) * 2 - 1
            tensors.append(tensor.to(device))

        with torch.no_grad():
            dist = lpips_model(tensors[0], tensors[1])

        return float(dist.item())

    except Exception as eta:
        print("[LPIPS ERROR]", str(eta))
        return 1.0
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def preprocess(x):
    """Prepare a single image tensor for LPIPS: (1, 3, 256, 256) in [-1, 1].

    Accepts (H, W) or (C, H, W) input; grayscale input is tiled to three
    channels.  Values are assumed to be in [0, 1] and are rescaled.
    """
    if x.dim() == 2:
        x = x.unsqueeze(0)          # (H, W) -> (1, H, W)
    if x.shape[0] == 1:
        x = x.repeat(3, 1, 1)       # grayscale -> RGB

    resized = F.interpolate(
        x.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False
    )

    return resized * 2 - 1          # [0, 1] -> [-1, 1]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def compute_lpips_diversity(samples, k=1000):
    """Mean LPIPS distance over ``k`` random distinct sample pairs.

    Args:
        samples: tensor of images indexed along dim 0.
        k: number of random pairs to average over.

    Returns:
        float: average pairwise LPIPS distance; 0.0 when fewer than two
        samples are available (no distinct pair exists).
    """
    N = samples.shape[0]

    # With fewer than two samples the rejection loop below could never
    # find a distinct pair and would spin forever.
    if N < 2:
        return 0.0

    total = 0.0

    for _ in range(k):
        # Rejection-sample a pair of distinct indices.
        while True:
            first = np.random.randint(0, N)
            second = np.random.randint(0, N)
            if first != second:
                break

        img1 = preprocess(samples[first]).to(device)
        img2 = preprocess(samples[second]).to(device)

        with torch.no_grad():
            dist = lpips_model(img1, img2)

        total += dist.item()

    return float(total / (k + 1e-8))
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def load_core_principles(path="yaml/perception.yaml"):
    """Render the "core" section of a perception YAML file as prompt text.

    Each core entry contributes its title, description and rules; entries
    marked ``priority: hard_constraint`` get an explicit override notice
    appended.  Returns "" when the file is missing, unparseable, or empty.
    """
    try:
        with open(path, "r") as f:
            book = yaml.safe_load(f)
    except FileNotFoundError:
        print("[YAML NOT FOUND]", path)
        return ""
    except yaml.YAMLError as e1:
        print("[YAML PARSE ERROR]", str(e1))
        return ""

    # yaml.safe_load returns None for an empty document (and may return
    # scalars/lists for malformed ones); treat anything but a mapping as
    # "no content" instead of crashing on .get below.
    if not isinstance(book, dict):
        return ""

    core = book.get("core", {})

    blocks = []

    for _, val in core.items():
        title = val.get("title", "")
        desc = val.get("description", "")
        rules = val.get("rules", [])

        block = []

        if title:
            block.append(f"{title}:")

        if desc:
            block.append(desc.strip())

        for r in rules:
            block.append(f"- {r}")

        if val.get("priority") == "hard_constraint":
            block.append("This principle is a HARD CONSTRAINT and ought to override weaker signals.")

        blocks.append("\n".join(block))

    return "\n\n".join(blocks)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def _empty_result(reason, analysis=""):
|
| 170 |
+
return {
|
| 171 |
+
"regime": "unknown",
|
| 172 |
+
"phase": "unknown",
|
| 173 |
+
"failures": [],
|
| 174 |
+
"actions": {},
|
| 175 |
+
"confidence": 0.0,
|
| 176 |
+
"analysis": analysis.strip(),
|
| 177 |
+
"reason": reason
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def extract_json(raw):
    """Pull the outermost JSON object out of a raw LLM response.

    Takes the span from the first '{' to the last '}' and parses it,
    then maps the parsed fields onto the normalized result schema.
    On any failure a fallback result (via _empty_result) is returned.
    """
    raw = (raw or "").strip()

    start = raw.find("{")
    end = raw.rfind("}")
    if start == -1 or end == -1:
        return _empty_result("no_json", raw)

    candidate = raw[start:end + 1]

    try:
        data = json.loads(candidate)
    except json.JSONDecodeError as err:
        print("[JSON DECODE ERROR]", str(err))
        print("[RAW OUTPUT]", raw)
        return _empty_result("json_decode_error", raw)
    except TypeError as err:
        print("[TYPE ERROR]", str(err))
        return _empty_result("type_error", raw)

    # Default scores when the model omitted or nulled the field.
    scores = data.get("scores") or {"temperature": 0.0, "gibbs": 0.0}

    # A single "failure" string; "none" (or absent) means no failures.
    failure = data.get("failure", "none")
    failures = [] if failure == "none" else [data.get("failure")]

    return {
        "regime": data.get("regime", "unknown"),
        "phase": data.get("phase", "unknown"),
        "failures": failures,
        "scores": scores,
        "risk": data.get("risk", {}),
        "actions": {},
        "confidence": float(data.get("confidence", 0.0)),
        "analysis": data.get("analysis", ""),
        "reason": "json"
    }
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def extract_json_safe(raw):
    """Parse an LLM response into the normalized result schema.

    Always routes through extract_json so that a directly parseable
    JSON response and a JSON block embedded in prose yield the same
    normalized keys ("failures", "scores", "actions", "reason", ...).
    The previous fast path returned json.loads(raw) unnormalized,
    which gave downstream validation a different schema whenever the
    LLM followed the strict-output instructions exactly.
    """
    raw = (raw or "").strip()
    return extract_json(raw)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def load_object_context(perfect_image_path, metrics):
    """Attach object/domain context from a sidecar JSON file to ``metrics``.

    For ``foo.png`` the sidecar is ``foo_objects.json``.  The function
    always leaves ``metrics["objects"]`` and ``metrics["domain"]`` set
    (empty dicts when there is no path, no sidecar, or a load error), so
    downstream consumers can rely on the keys existing.  The previous
    early return on a falsy path left both keys unset, unlike every
    other path.
    """
    # Default to empty context; overwritten below on a successful load.
    metrics["objects"] = {}
    metrics["domain"] = {}

    if not perfect_image_path:
        return

    json_path = perfect_image_path.replace(".png", "_objects.json")
    if not os.path.exists(json_path):
        return

    try:
        with open(json_path, "r") as f:
            data = json.load(f)
    except (json.JSONDecodeError, OSError) as e3:
        print("[OBJECT LOAD ERROR]", str(e3))
        return

    metrics["objects"] = data.get("objects", {})
    metrics["domain"] = data.get("domain", {})
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def Evaluate(metrics_input, refined_img_path, perfect_img_path, client=None):
|
| 256 |
+
print("DEBUG perfect_img_path:", perfect_img_path)
|
| 257 |
+
|
| 258 |
+
if client is None:
|
| 259 |
+
return _empty_result("no_client")
|
| 260 |
+
|
| 261 |
+
metrics_nomadic = copy.deepcopy(metrics_input)
|
| 262 |
+
|
| 263 |
+
if "samples" in metrics_input:
|
| 264 |
+
metrics_nomadic["lpips_diversity"] = compute_lpips_diversity(metrics_input["samples"])
|
| 265 |
+
metrics_nomadic["diversity"] = min(1.0, metrics_nomadic["lpips_diversity"])
|
| 266 |
+
|
| 267 |
+
if refined_img_path and perfect_img_path:
|
| 268 |
+
lpips_score = compute_lpips(refined_img_path, perfect_img_path)
|
| 269 |
+
metrics_nomadic["image_similarity"] = 1.0 - lpips_score
|
| 270 |
+
|
| 271 |
+
print("[IMAGE LPIPS]", lpips_score)
|
| 272 |
+
|
| 273 |
+
metrics_nomadic = enforce_hierarchy(metrics_nomadic)
|
| 274 |
+
load_object_context(perfect_img_path, metrics_nomadic)
|
| 275 |
+
|
| 276 |
+
metrics_str = json.dumps(metrics_nomadic, indent=2)
|
| 277 |
+
|
| 278 |
+
CORE_BLOCK = load_core_principles("yaml/perception.yaml")
|
| 279 |
+
|
| 280 |
+
prompt_main = f"""
|
| 281 |
+
You are analyzing a thermodynamic learning system.
|
| 282 |
+
|
| 283 |
+
{CORE_BLOCK}
|
| 284 |
+
|
| 285 |
+
{metrics_str}
|
| 286 |
+
|
| 287 |
+
CRITICAL INTERPRETATION RULES:
|
| 288 |
+
|
| 289 |
+
- High image_similarity indicates that generated samples match the reference structure.
|
| 290 |
+
- This is a strong POSITIVE signal of successful learning.
|
| 291 |
+
|
| 292 |
+
- Low diversity alone is NOT a failure condition.
|
| 293 |
+
- Converged systems naturally produce similar outputs.
|
| 294 |
+
|
| 295 |
+
- Poor mixing (high tau_int) does NOT imply collapse.
|
| 296 |
+
|
| 297 |
+
STRICT CONSTRAINT:
|
| 298 |
+
|
| 299 |
+
- Do NOT classify mode collapse if image_similarity is high.
|
| 300 |
+
|
| 301 |
+
Even if:
|
| 302 |
+
- diversity is low
|
| 303 |
+
- entropy is low
|
| 304 |
+
- mixing is slow
|
| 305 |
+
|
| 306 |
+
PRIORITY RULE:
|
| 307 |
+
|
| 308 |
+
- Structural similarity OVERRIDES diversity-based signals.
|
| 309 |
+
|
| 310 |
+
- High reconstruction quality + high similarity = HEALTHY system.
|
| 311 |
+
|
| 312 |
+
- If image_similarity is high AND reconstruction quality is high: classify the system as "stable" or "ordered", not as failure.
|
| 313 |
+
|
| 314 |
+
---
|
| 315 |
+
|
| 316 |
+
Analyze visuals and metrics together.
|
| 317 |
+
|
| 318 |
+
IMAGE INTERPRETATION:
|
| 319 |
+
|
| 320 |
+
- High image_similarity → structures match (GOOD)
|
| 321 |
+
- Low image_similarity → structures differ
|
| 322 |
+
|
| 323 |
+
---
|
| 324 |
+
|
| 325 |
+
IMPORTANT DEFINITIONS:
|
| 326 |
+
|
| 327 |
+
- A "mode" is NOT a digit class.
|
| 328 |
+
- Each digit contains many stylistic variations (micro-modes).
|
| 329 |
+
- Structural differences define modes, not visual similarity.
|
| 330 |
+
|
| 331 |
+
---
|
| 332 |
+
|
| 333 |
+
COLLAPSE CRITERIA:
|
| 334 |
+
|
| 335 |
+
Only classify as "mode_collapse" if ALL of the following hold:
|
| 336 |
+
- repetitive structure
|
| 337 |
+
- low diversity
|
| 338 |
+
- structural degradation (low similarity)
|
| 339 |
+
- poor mixing
|
| 340 |
+
|
| 341 |
+
Do NOT classify collapse if:
|
| 342 |
+
- structure is preserved (high image_similarity)
|
| 343 |
+
- reconstruction quality is high
|
| 344 |
+
|
| 345 |
+
---
|
| 346 |
+
|
| 347 |
+
ADDITIONAL CLASSIFICATION TASK:
|
| 348 |
+
|
| 349 |
+
Classify the system into the following categories:
|
| 350 |
+
|
| 351 |
+
- regime: one of ["learning", "stable", "stagnant", "converged"]
|
| 352 |
+
- phase: one of ["ordered", "critical", "disordered"]
|
| 353 |
+
|
| 354 |
+
Guidelines:
|
| 355 |
+
|
| 356 |
+
- ordered: low entropy, high beta, structured outputs
|
| 357 |
+
- disordered: high entropy, weak structure
|
| 358 |
+
- critical: balance between structure and diversity
|
| 359 |
+
|
| 360 |
+
- learning: weights are still changing (delta_w noticeable)
|
| 361 |
+
- stable: learning slowed but system remains healthy
|
| 362 |
+
- stagnant: little to no progress
|
| 363 |
+
- converged: learning has effectively finished
|
| 364 |
+
|
| 365 |
+
Notes:
|
| 366 |
+
|
| 367 |
+
- Choose the closest matching category based on the evidence.
|
| 368 |
+
- Avoid leaving these fields undefined.
|
| 369 |
+
- When uncertain, select the most plausible category rather than abstaining.
|
| 370 |
+
|
| 371 |
+
---
|
| 372 |
+
|
| 373 |
+
OUTPUT INSTRUCTIONS:
|
| 374 |
+
|
| 375 |
+
- First explain your reasoning briefly.
|
| 376 |
+
- Then output a valid JSON block.
|
| 377 |
+
|
| 378 |
+
JSON ought to be the LAST part of your response.
|
| 379 |
+
|
| 380 |
+
RULES:
|
| 381 |
+
|
| 382 |
+
- Scores ought to be between 0 and 1
|
| 383 |
+
- Do NOT output discrete actions
|
| 384 |
+
- Higher score = stronger recommendation
|
| 385 |
+
"""
|
| 386 |
+
|
| 387 |
+
prompt_format = """
|
| 388 |
+
FORMAT:
|
| 389 |
+
|
| 390 |
+
<reasoning>
|
| 391 |
+
|
| 392 |
+
{
|
| 393 |
+
"analysis": "...",
|
| 394 |
+
|
| 395 |
+
"regime": "...",
|
| 396 |
+
"phase": "...",
|
| 397 |
+
|
| 398 |
+
"scores": {
|
| 399 |
+
"temperature": 0.0,
|
| 400 |
+
"gibbs": 0.0
|
| 401 |
+
},
|
| 402 |
+
|
| 403 |
+
"risk": {
|
| 404 |
+
"stagnation": 0.0,
|
| 405 |
+
"collapse": 0.0,
|
| 406 |
+
"over_ordering": 0.0
|
| 407 |
+
},
|
| 408 |
+
|
| 409 |
+
"confidence": 0.0
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
EPISTEMIC INTERPRETATION (MANDATORY):
|
| 413 |
+
|
| 414 |
+
- You ought to interpret confidence as bounded by evidence.
|
| 415 |
+
- Confidence is not an independent belief; it is a function of evidence.
|
| 416 |
+
|
| 417 |
+
- If your internal belief is high but evidence is limited, you ought to explain the discrepancy.
|
| 418 |
+
|
| 419 |
+
- You ought to explicitly reflect this relationship in your analysis:
|
| 420 |
+
|
| 421 |
+
- High belief + low evidence → constrained confidence
|
| 422 |
+
- Low evidence → weak epistemic justification
|
| 423 |
+
|
| 424 |
+
- Do NOT just output a number. You ought to justify confidence in terms of evidence.
|
| 425 |
+
|
| 426 |
+
FORMAL RULE:
|
| 427 |
+
confidence ≤ evidence
|
| 428 |
+
|
| 429 |
+
INTERPRETATION RULE:
|
| 430 |
+
confidence represents only what is supported by evidence, not what is intuitively likely.
|
| 431 |
+
|
| 432 |
+
STRICT OUTPUT RULES:
|
| 433 |
+
- Output ONLY one JSON object
|
| 434 |
+
- Do NOT include text after JSON
|
| 435 |
+
- Do NOT include markdown (no ```json)
|
| 436 |
+
- JSON ought to be directly parseable by json.loads()
|
| 437 |
+
"""
|
| 438 |
+
|
| 439 |
+
prompt = prompt_main + prompt_format
|
| 440 |
+
|
| 441 |
+
try:
|
| 442 |
+
img1 = encode_image(refined_img_path)
|
| 443 |
+
img2 = encode_image(perfect_img_path)
|
| 444 |
+
|
| 445 |
+
raw = client.generate(
|
| 446 |
+
prompt,
|
| 447 |
+
images=[img1, img2]
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
if raw is None or not str(raw).strip():
|
| 451 |
+
return _empty_result("empty_response")
|
| 452 |
+
|
| 453 |
+
if len(raw) < 2000:
|
| 454 |
+
print("\n[LLM RAW OUTPUT]\n", raw)
|
| 455 |
+
else:
|
| 456 |
+
print("\n[LLM RAW OUTPUT] (truncated)\n", raw[:2000])
|
| 457 |
+
|
| 458 |
+
result = extract_json_safe(raw)
|
| 459 |
+
result = validate_llm_output(result, metrics_nomadic)
|
| 460 |
+
|
| 461 |
+
if not result.get("analysis") or len(result.get("analysis", "").strip()) < 5:
|
| 462 |
+
result["analysis"] = "No meaningful analysis provided by LLM"
|
| 463 |
+
|
| 464 |
+
llm_conf = result.get("confidence", 0.0)
|
| 465 |
+
evidence = compute_evidence(metrics_nomadic)
|
| 466 |
+
|
| 467 |
+
final_conf = min(llm_conf, evidence)
|
| 468 |
+
|
| 469 |
+
sim = metrics_nomadic.get("image_similarity", 0.0)
|
| 470 |
+
|
| 471 |
+
if evidence < 0.1 and sim < 0.7:
|
| 472 |
+
final_conf = 0.0
|
| 473 |
+
result["reason"] = "no_evidence"
|
| 474 |
+
|
| 475 |
+
final_conf = max(0.0, min(1.0, final_conf))
|
| 476 |
+
result["confidence"] = final_conf
|
| 477 |
+
result["evidence"] = evidence
|
| 478 |
+
result["llm_conf_raw"] = llm_conf
|
| 479 |
+
|
| 480 |
+
print(f"[SYS] EVIDENCE={evidence:.3f} | LLM={llm_conf:.3f} → FINAL={final_conf:.3f}")
|
| 481 |
+
|
| 482 |
+
if result["confidence"] < 0.5:
|
| 483 |
+
for k in result.get("scores", {}):
|
| 484 |
+
result["scores"][k] = max(0.0, min(1.0, result["scores"][k] * 0.5))
|
| 485 |
+
if result.get("reason"):
|
| 486 |
+
result["reason"] += "|low_confidence_scaled"
|
| 487 |
+
else:
|
| 488 |
+
result["reason"] = "low_confidence_scaled"
|
| 489 |
+
|
| 490 |
+
print("\n[LLM CLEAN RESULT]\n", result)
|
| 491 |
+
|
| 492 |
+
if result.get("phase") == "unknown":
|
| 493 |
+
print("[WARNING] phase unresolved after parsing")
|
| 494 |
+
|
| 495 |
+
if result.get("regime") == "unknown":
|
| 496 |
+
print("[WARNING] regime unresolved after parsing")
|
| 497 |
+
|
| 498 |
+
return result
|
| 499 |
+
|
| 500 |
+
except (AttributeError, KeyError) as e4:
|
| 501 |
+
print("\n[RESPONSE FORMAT ERROR]", str(e4))
|
| 502 |
+
return _empty_result("response_format_error")
|
| 503 |
+
|
| 504 |
+
except RuntimeError as e4:
|
| 505 |
+
print("\n[RUNTIME ERROR]", str(e4))
|
| 506 |
+
return _empty_result("runtime_error")
|
| 507 |
+
|
| 508 |
+
except OSError as e4:
|
| 509 |
+
print("\n[IO ERROR]", str(e4))
|
| 510 |
+
return _empty_result("io_error")
|
| 511 |
+
|
| 512 |
+
except json.JSONDecodeError as e4:
|
| 513 |
+
print("\n[JSON ERROR]", str(e4))
|
| 514 |
+
return _empty_result("json_error")
|
| 515 |
+
|
| 516 |
+
except Exception as e4:
|
| 517 |
+
print("\n[UNEXPECTED ERROR]", type(e4).__name__, str(e4))
|
| 518 |
+
raise
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def compute_evidence(metrics):
    """Aggregate heuristic signals into a single evidence score.

    Averages image similarity, inverted std, flip rate, diversity and a
    scaled weight-change magnitude; missing metrics fall back to defaults
    (0.0 similarity/flip, 1.0 std, 0.5 diversity, 0.0 delta_w).
    """
    delta_w = metrics.get("delta_w", 0.0)
    signals = (
        metrics.get("image_similarity", 0.0),
        1 - metrics.get("std", 1.0),
        metrics.get("flip_rate", 0.0),
        metrics.get("diversity", 0.5),
        min(1.0, delta_w * 1000),
    )
    return sum(signals) / len(signals)
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def enforce_hierarchy(metrics):
    """Damp the diversity signal when structural similarity dominates.

    When image_similarity exceeds 0.8, diversity (if present) is scaled to
    30% so it cannot override the similarity evidence. Mutates and returns
    the same dict.
    """
    high_similarity = metrics.get("image_similarity", 0.0) > 0.8
    if high_similarity and "diversity" in metrics:
        metrics["diversity"] *= 0.3
    return metrics
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def validate_llm_output(result, metrics):
    """Override an LLM 'collapse' verdict when images are structurally similar.

    If image_similarity > 0.8 and the analysis text claims collapse (without
    explicitly negating it), the collapse risk is zeroed and the analysis is
    annotated with an override marker. Mutates and returns `result`.
    """
    if metrics.get("image_similarity", 0.0) > 0.8:
        analysis_text = str(result.get("analysis", "")).lower()
        claims_collapse = (
            "collapse" in analysis_text and "no collapse" not in analysis_text
        )
        if claims_collapse:
            result.setdefault("risk", {})
            result["risk"]["collapse"] = 0.0
            result["analysis"] += " [OVERRIDDEN: high structural similarity]"
    return result
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def _cosine(a, b):
|
| 560 |
+
dot = sum(x * y for x, y in zip(a, b))
|
| 561 |
+
na = math.sqrt(sum(x * x for x in a))
|
| 562 |
+
nb = math.sqrt(sum(x * x for x in b))
|
| 563 |
+
|
| 564 |
+
return dot / (na * nb + 1e-8)
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def _embed(text):
|
| 568 |
+
text_log = text.lower()
|
| 569 |
+
|
| 570 |
+
return [
|
| 571 |
+
text_log.count("collapse") + text_log.count("failed"),
|
| 572 |
+
text_log.count("stagnation") + text_log.count("plateau"),
|
| 573 |
+
text_log.count("no learning") + text_log.count("stopped"),
|
| 574 |
+
text_log.count("learning"),
|
| 575 |
+
text_log.count("improving"),
|
| 576 |
+
text_log.count("stable"),
|
| 577 |
+
text_log.count("healthy"),
|
| 578 |
+
text_log.count("blur") + text_log.count("noisy"),
|
| 579 |
+
text_log.count("diversity")
|
| 580 |
+
]
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def ANASIS(text, metrics=None):
    """Estimate a risk score in [0, 1] from free-form analysis text.

    Compares a keyword embedding of `text` against collapse / stagnation /
    healthy reference phrases, mixes in a diversity penalty from `metrics`,
    and returns an energy-style risk score (higher = riskier). Empty text
    yields the neutral 0.5.
    """
    if not text:
        return 0.5

    text_vec = _embed(text.lower())

    sim_collapse = _cosine(
        text_vec, _embed("model collapse failure unstable diverging system")
    )
    sim_stagnation = _cosine(
        text_vec, _embed("learning stagnation plateau no progress stopped system")
    )
    sim_healthy = _cosine(
        text_vec, _embed("stable improving healthy well-trained diverse system")
    )

    diversity_term = 0.0
    if metrics is not None:
        diversity_term = (1 - metrics.get("diversity", 0.5)) ** 2

    risk_signal = max(sim_collapse, sim_stagnation)
    energy = 0.5 * risk_signal + 0.3 * (1 - sim_healthy) + 0.2 * diversity_term
    energy = max(0.0, min(1.0, energy))

    # For sim_healthy == 0 this reduces to |1 - energy|, matching both
    # branches of the original formulation.
    consistency = max(0.0, min(1.0, abs(1 - (sim_healthy + energy))))

    final_score = max(0.0, min(1.0, energy - 0.1 * consistency))

    print(
        f"[ANASIS] collapse={sim_collapse:.3f} "
        f"stagnation={sim_stagnation:.3f} "
        f"healthy_align={consistency:.3f} "
        f"→ risk={final_score:.3f}"
    )

    return final_score
|
llmeS/hook.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LLM ENERGY MODULE (PROVIDER-AGNOSTIC SEMANTIC ENERGY)
|
| 2 |
+
#
|
| 3 |
+
# This module defines an LLM-based energy estimator that interprets
|
| 4 |
+
# structured visual representations as probabilistic signals.
|
| 5 |
+
#
|
| 6 |
+
# DESIGN PRINCIPLE:
|
| 7 |
+
# The LLM is treated as a black-box semantic energy oracle.
|
| 8 |
+
# No assumptions are made about the underlying provider.
|
| 9 |
+
#
|
| 10 |
+
# All interaction is performed through a global client:
|
| 11 |
+
#
|
| 12 |
+
# raw = client.generate(prompt)
|
| 13 |
+
#
|
| 14 |
+
# The client MUST return raw text output.
|
| 15 |
+
# Any API-specific logic is handled externally.
|
| 16 |
+
#
|
| 17 |
+
# This module is responsible for:
|
| 18 |
+
# - constructing structured prompts from sparse representations
|
| 19 |
+
# - extracting probability distributions from raw LLM output
|
| 20 |
+
# - validating and normalizing probability vectors
|
| 21 |
+
#
|
| 22 |
+
# FAILURE MODE:
|
| 23 |
+
# If the LLM backend fails or returns invalid output:
|
| 24 |
+
# - the function returns (None, error_code)
|
| 25 |
+
# - the main system continues without interruption
|
| 26 |
+
#
|
| 27 |
+
# ARCHITECTURAL ROLE:
|
| 28 |
+
# This module provides a semantic energy term that augments
|
| 29 |
+
# the thermodynamic model with external probabilistic signals.
|
| 30 |
+
|
| 31 |
+
import json
|
| 32 |
+
import torch
|
| 33 |
+
import numpy as np
|
| 34 |
+
import hashlib
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
|
| 37 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def extract_json_block(text: str):
    """Return the first balanced top-level {...} substring of `text`, or None.

    Simple brace-depth scan; braces inside string literals are not handled.
    """
    depth = 0
    start = None
    for pos, ch in enumerate(text):
        if ch == "{":
            if start is None:
                start = pos
            depth += 1
        elif ch == "}" and depth:
            depth -= 1
            if depth == 0:
                return text[start:pos + 1]
    return None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def upscale_to_64(grid_2d):
    """Bilinearly resize a tensor to 64x64.

    A raw HxW tensor gains batch/channel dims before interpolation; the
    result is squeezed back to a plain 64x64 tensor.
    """
    if grid_2d.dim() == 2:
        grid_2d = grid_2d[None, None]
    resized = F.interpolate(
        grid_2d, size=(64, 64), mode="bilinear", align_corners=False
    )
    return resized.squeeze(0).squeeze(0)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _component_direction_pca(coordinates_np):
|
| 64 |
+
if len(coordinates_np) < 3:
|
| 65 |
+
return [0.0, 0.0]
|
| 66 |
+
|
| 67 |
+
cov = np.cov(coordinates_np.T)
|
| 68 |
+
eigenvalues, eigenvectors = np.linalg.eig(cov)
|
| 69 |
+
main_vec = eigenvectors[:, np.argmax(eigenvalues)]
|
| 70 |
+
|
| 71 |
+
norm = np.linalg.norm(main_vec) + 1e-12
|
| 72 |
+
main_vec = main_vec / norm
|
| 73 |
+
|
| 74 |
+
if main_vec[0] < 0:
|
| 75 |
+
main_vec = -main_vec
|
| 76 |
+
|
| 77 |
+
return [float(main_vec[0]), float(main_vec[1])]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _component_features(coordinates_list):
|
| 81 |
+
coordinate = np.array(coordinates_list)
|
| 82 |
+
|
| 83 |
+
if coordinate.size == 0:
|
| 84 |
+
return {
|
| 85 |
+
"size": 0,
|
| 86 |
+
"bbox": [0, 0, 0, 0],
|
| 87 |
+
"center": [0.0, 0.0],
|
| 88 |
+
"direction": [0.0, 0.0],
|
| 89 |
+
"aspect_ratio": 1.0
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
ys = coordinate[:, 0]
|
| 93 |
+
xs = coordinate[:, 1]
|
| 94 |
+
|
| 95 |
+
y1, y2 = int(ys.min()), int(ys.max())
|
| 96 |
+
x1, x2 = int(xs.min()), int(xs.max())
|
| 97 |
+
|
| 98 |
+
center = [float(ys.mean()), float(xs.mean())]
|
| 99 |
+
direction = _component_direction_pca(coordinate)
|
| 100 |
+
|
| 101 |
+
aspect_ratio = (y2 - y1 + 1) / (x2 - x1 + 1 + 1e-6)
|
| 102 |
+
|
| 103 |
+
return {
|
| 104 |
+
"size": int(len(coordinates_list)),
|
| 105 |
+
"bbox": [y1, x1, y2, x2],
|
| 106 |
+
"center": center,
|
| 107 |
+
"direction": direction,
|
| 108 |
+
"aspect_ratio": float(aspect_ratio)
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def find_connected_components_fast(grid_2d, threshold=0.15):
    """Label connected components of `grid_2d > threshold`.

    Returns a list of components, each a list of [row, col] coordinates
    (scipy.ndimage.label default connectivity).
    """
    from scipy.ndimage import label

    mask = (grid_2d > threshold).detach().cpu().numpy()
    labeled_mask, n_components = label(mask)

    return [
        np.argwhere(labeled_mask == comp_id).tolist()
        for comp_id in range(1, n_components + 1)
    ]
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def to_sparse_gpu(grid_2d, threshold=0.15):
    """Serialize a 2-D intensity grid into a compact JSON description.

    Produces a JSON string with the grid shape, up to ~800 thresholded
    pixel samples as [row, col, value] triples, and per-connected-component
    features (downsampled point list, size, bbox, center, direction,
    aspect ratio) suitable for embedding in an LLM prompt.

    NOTE(review): despite the name, component analysis runs on CPU copies;
    only the thresholding happens on the module-level `device`.
    """
    grid_2d = grid_2d.to(device)
    mask = grid_2d > threshold

    # (row, col) coordinates of every pixel above threshold.
    coordinate_grow = torch.nonzero(mask, as_tuple=False)

    values = grid_2d[mask]

    coordinates_cpu = coordinate_grow.cpu().numpy()
    values_cpu = values.cpu().numpy()

    active_pixels = [
        [int(r), int(clm), round(float(v), 3)]
        for (r, clm), v in zip(coordinates_cpu, values_cpu)
    ]

    # Keep the serialized payload bounded: subsample to at most ~800 pixels.
    if len(active_pixels) > 800:
        step = max(1, len(active_pixels) // 800)
        active_pixels = active_pixels[::step]

    raw_components = find_connected_components_fast(grid_2d, threshold)

    comp_objs = []
    for comp in raw_components:
        feats = _component_features(comp)

        # Downsample each component's point list to at most ~50 points.
        step = max(1, len(comp) // 50)
        simplified = comp[::step]

        comp_objs.append({
            "points": simplified,
            "size": feats["size"],
            "bbox": feats["bbox"],
            "center": feats["center"],
            "direction": feats["direction"],
            "aspect_ratio": feats["aspect_ratio"]
        })

    return json.dumps({
        "size": list(grid_2d.shape),
        "pixels": active_pixels,
        "components": comp_objs
    })
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def hash_repr(s: str):
    """Hex MD5 digest of `s`, used as a cache key for image representations."""
    digest = hashlib.md5(s.encode())
    return digest.hexdigest()
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
_cache = {}
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def LLMEnergy(image_repr: str, client=None):
    """Query the LLM client for a 10-way digit distribution over `image_repr`.

    Returns (result, error): result is {"probs": [p0..p9]} with a normalized
    distribution on success, or None with one of the string error codes
    ("no_client", "empty_response", "client_error: ...", "json_not_found",
    "json_parse_error", "invalid_probs", "non_numeric_probs",
    "invalid_distribution"). Successful results are memoized in the
    module-level `_cache`, keyed by the MD5 of the representation.
    """
    if client is None:
        return None, "no_client"

    # Cache hit: identical representations always map to the same result.
    key = hash_repr(image_repr)
    if key in _cache:
        return _cache[key], None

    prompt = f"""
You MUST follow:
1. Use components to infer topology
2. Use pixels to infer intensity
3. Use component directions to infer stroke flow
4. Use aspect ratios to refine structural interpretation
5. Output calibrated probabilities

Return ONLY JSON:
{{"probs": [p0, p1, ..., p9]}}

Image:
{image_repr}
"""

    try:
        raw_output = client.generate(prompt)

        if raw_output is None or not str(raw_output).strip():
            return None, "empty_response"

        raw_output = raw_output.strip()

    except Exception as e:
        # Any backend failure degrades gracefully: the caller proceeds
        # without the semantic energy term.
        return None, f"client_error: {e}"

    # Try the whole response as JSON first, then fall back to extracting
    # the first balanced {...} block from surrounding prose.
    try:
        parsed_json = json.loads(raw_output)
    except json.JSONDecodeError:
        json_block = extract_json_block(raw_output)
        if json_block is None:
            return None, "json_not_found"
        try:
            parsed_json = json.loads(json_block)
        except json.JSONDecodeError:
            return None, "json_parse_error"

    probs = parsed_json.get("probs")

    if not isinstance(probs, list) or len(probs) != 10:
        return None, "invalid_probs"

    try:
        prob_vector = np.array(probs, dtype=np.float32)
    except ValueError:
        return None, "non_numeric_probs"

    total = prob_vector.sum()
    if total <= 0:
        return None, "invalid_distribution"

    # Renormalize so downstream entropy/energy math sees a true distribution.
    prob_vector /= total

    result = {"probs": prob_vector.tolist()}
    _cache[key] = result

    return result, None
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def LIES_gpu(prob_vector, alpha=5.0, beta=2.0, gamma=2.0):
    """Semantic energy of a digit distribution.

    energy = alpha * H(p) - beta * max(p) - gamma * (p_(1) - p_(2)):
    high entropy raises energy; confidence and top-2 margin lower it.
    Returns a Python float.
    """
    prob = torch.tensor(prob_vector, device=device)
    eps = 1e-12

    entropy = -(prob * torch.log(prob + eps)).sum()
    confidence = prob.max()

    best_two = torch.topk(prob, 2).values
    margin = best_two[0] - best_two[1]

    return (alpha * entropy - beta * confidence - gamma * margin).item()
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def process_digit(grid_2d, client=None):
    """Full pipeline for one digit image: resize, sparsify, query LLM, score.

    Accepts a tensor of 784 elements (treated as 28x28 and upscaled to
    64x64) or 4096 elements (64x64). Returns ({"probs": ..., "energy": ...},
    None) on success, or (None, error_code) when the LLM stage fails.

    Raises:
        ValueError: for any other input size.
    """
    grid_2d = grid_2d.to(device)

    n_elements = grid_2d.numel()
    if n_elements == 784:
        grid_2d = upscale_to_64(grid_2d.view(28, 28))
    elif n_elements == 4096:
        grid_2d = grid_2d.view(64, 64)
    else:
        raise ValueError("Unexpected input size")

    image_repr = to_sparse_gpu(grid_2d, threshold=0.15)

    result, err = LLMEnergy(image_repr, client)
    if err:
        return None, err

    return {"probs": result["probs"], "energy": LIES_gpu(result["probs"])}, None
|
srtrbm_project_core.py
ADDED
|
@@ -0,0 +1,1600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Self-Regulated Thermodynamic RBM (SR-TRBM)
|
| 2 |
+
#
|
| 3 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
#
|
| 5 |
+
# Copyright © 2026 Görkem Can Süleymanoğlu
|
| 6 |
+
#
|
| 7 |
+
# This implementation realizes a thermodynamically self-regulated
|
| 8 |
+
# energy-based model operating in a finite-time stochastic regime.
|
| 9 |
+
# The framework extends classical Restricted Boltzmann Machines by
|
| 10 |
+
# introducing endogenous control of sampling dynamics and bounded
|
| 11 |
+
# external energy corrections.
|
| 12 |
+
#
|
| 13 |
+
# Training is performed using Persistent Contrastive Divergence
|
| 14 |
+
# (PCD-k), where the negative phase is approximated by short-run
|
| 15 |
+
# Gibbs chains. To use conductance collapse and Gibbs freezing,
|
| 16 |
+
# the model employs a hybrid temperature mechanism:
|
| 17 |
+
#
|
| 18 |
+
# T = T_micro + T_macro
|
| 19 |
+
#
|
| 20 |
+
# The microscopic component follows a feedback control law based
|
| 21 |
+
# on the empirical flip rate r_t:
|
| 22 |
+
#
|
| 23 |
+
# λ_{t+1} = λ_t - η (r_t - c_t)
|
| 24 |
+
#
|
| 25 |
+
# where c_t is an exponentially smoothed reference, inducing
|
| 26 |
+
# closed-loop stabilization of stochastic transition intensity.
|
| 27 |
+
#
|
| 28 |
+
# The model incorporates a localized LLM-guided refinement mechanism
|
| 29 |
+
# for proposals within an uncertainty band (ΔE ∈ [0, 2]).
|
| 30 |
+
#
|
| 31 |
+
# LLM contributions are normalized via a running variance estimate:
|
| 32 |
+
# σ² ← EMA(ΔE_model²), ΔE_llm ← ΔE_llm / σ
|
| 33 |
+
#
|
| 34 |
+
# Designed for CUDA-enabled GPU execution with fixed seeds.
|
| 35 |
+
# Average runtime per seed: ~20–30 minutes on RTX-class GPUs.
|
| 36 |
+
# AIS estimation and multichain Gibbs sampling are dominant costs.
|
| 37 |
+
#
|
| 38 |
+
# Thermal monitoring is recommended (e.g., nvidia-smi).
|
| 39 |
+
# Multi-seed runs should be executed sequentially.
|
| 40 |
+
#
|
| 41 |
+
# Full license text is provided in the root LICENSE file.
|
| 42 |
+
|
| 43 |
+
from graphs import SrtrbmVisualization
|
| 44 |
+
from graphs import SrtrbmEnergy
|
| 45 |
+
from graphs import SrtrbmMetrics
|
| 46 |
+
|
| 47 |
+
from analysis.AutoGPU import GPUEnergyTracker
|
| 48 |
+
from correction.NO import Refinement
|
| 49 |
+
|
| 50 |
+
import torch
|
| 51 |
+
import torch.nn.functional as F
|
| 52 |
+
import torch.multiprocessing as mp
|
| 53 |
+
import torchvision.utils as vutils
|
| 54 |
+
from tqdm import tqdm
|
| 55 |
+
from PIL import Image
|
| 56 |
+
import numpy as np
|
| 57 |
+
import subprocess
|
| 58 |
+
import threading
|
| 59 |
+
import textwrap
|
| 60 |
+
import json
|
| 61 |
+
import math
|
| 62 |
+
import time
|
| 63 |
+
import sys
|
| 64 |
+
|
| 65 |
+
import matplotlib
|
| 66 |
+
|
| 67 |
+
matplotlib.use("Agg")
|
| 68 |
+
|
| 69 |
+
import matplotlib as mpl
|
| 70 |
+
|
| 71 |
+
import matplotlib.pyplot as plt
|
| 72 |
+
|
| 73 |
+
from sklearn.datasets import fetch_openml
|
| 74 |
+
|
| 75 |
+
from llmeS.hook import LLMEnergy, to_sparse_gpu, LIES_gpu
|
| 76 |
+
from llmeS.gateway import Evaluate, ANASIS
|
| 77 |
+
from llmeS.client import SafeBookClient
|
| 78 |
+
|
| 79 |
+
# Publication-quality matplotlib defaults (serif fonts, 300 dpi,
# TrueType fonts embedded in PDF/PS output).
_PLOT_STYLE = {
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
    "font.size": 11,
    "axes.titlesize": 12,
    "axes.labelsize": 11,
    "legend.fontsize": 9,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "figure.dpi": 300,
    "savefig.dpi": 300,
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
}
mpl.rcParams.update(_PLOT_STYLE)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Data Model
|
| 96 |
+
|
| 97 |
+
def load_mnist(device_exact):
    """Fetch MNIST (mnist_784) from OpenML and return it as a binary tensor.

    Pixel intensities are scaled to [0, 1] and then binarized: any strictly
    positive intensity maps to 1.0, exact zeros map to 0.0.
    NOTE(review): the binarization threshold is 0 (``> -0.0``), not the more
    common 0.5 — confirm this is intentional.

    Parameters
    ----------
    device_exact : torch.device or str
        Device the resulting tensor is moved to.
    """
    raw = fetch_openml("mnist_784", version=1, cache=True)
    pixels = raw.data.to_numpy(dtype="float32") / 255.0
    binary = (pixels > -0.0).astype("float32")
    return torch.tensor(binary).to(device_exact)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def susceptibility(samples):
    """Magnetic-susceptibility proxy: variance of the per-sample magnetization.

    Parameters
    ----------
    samples : torch.Tensor
        Binary {0, 1} samples of shape (N, D), or any (N, ...) image batch;
        trailing dimensions are flattened per sample.

    Returns
    -------
    float
        Population variance (``unbiased=False``) of the mean spin per sample.
    """
    spins = 2 * samples - 1  # map {0, 1} -> {-1, +1}
    # Flatten trailing dims so image batches (N, 1, 28, 28) behave exactly
    # like flat (N, D) input — consistent with binder_cumulant(), and needed
    # because callers pass 4-D tensors after .view(-1, 1, 28, 28).
    spins = spins.reshape(spins.size(0), -1)
    magnetization = spins.mean(dim=1)  # per-sample mean spin
    chi = magnetization.var(unbiased=False)
    return chi.item()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def binder_cumulant(samples):
    """Return the Binder cumulant U = 1 - <m^4> / (3 <m^2>^2) of the batch.

    Samples in {0, 1} are mapped to spins in {-1, +1}, flattened per sample,
    and the cumulant is computed from the per-sample magnetization moments.
    """
    flat_spins = (2 * samples - 1).reshape(samples.size(0), -1)
    magnetization = flat_spins.mean(dim=1)
    second_moment = (magnetization ** 2).mean()
    fourth_moment = (magnetization ** 4).mean()
    # Small epsilon guards against division by zero for fully ordered states.
    cumulant = 1 - fourth_moment / (3 * (second_moment ** 2) + 1e-12)
    return cumulant.item()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Model Point
|
| 131 |
+
|
| 132 |
+
class HybridThermodynamicRBM:
|
| 133 |
+
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
n_visible,
|
| 137 |
+
n_hidden,
|
| 138 |
+
device_type="cuda",
|
| 139 |
+
learning_rate=5e-4,
|
| 140 |
+
epochs=400,
|
| 141 |
+
gibbs_steps=1,
|
| 142 |
+
batch_size=128,
|
| 143 |
+
lambda_gain=0.01,
|
| 144 |
+
flip_smoothing=0.01,
|
| 145 |
+
energy_temp_scale=0.001,
|
| 146 |
+
weight_decay=1e-4,
|
| 147 |
+
fixed_temperature=None
|
| 148 |
+
):
|
| 149 |
+
self.v_bias = None
|
| 150 |
+
self.h_bias = None
|
| 151 |
+
|
| 152 |
+
self.analysis_history = []
|
| 153 |
+
|
| 154 |
+
self.abort_counter = getattr(self, "abort_counter", 0)
|
| 155 |
+
|
| 156 |
+
self.device = torch.device(device_type)
|
| 157 |
+
|
| 158 |
+
self.n_visible = n_visible
|
| 159 |
+
self.n_hidden = n_hidden
|
| 160 |
+
self.lr = learning_rate
|
| 161 |
+
self.epochs = epochs
|
| 162 |
+
self.gibbs_steps = gibbs_steps
|
| 163 |
+
self.batch_size = batch_size
|
| 164 |
+
|
| 165 |
+
# Thermodynamic parameters
|
| 166 |
+
|
| 167 |
+
self.lambda_gain = lambda_gain
|
| 168 |
+
self.flip_smoothing = flip_smoothing
|
| 169 |
+
self.energy_temp_scale = energy_temp_scale
|
| 170 |
+
self.weight_decay = weight_decay
|
| 171 |
+
|
| 172 |
+
self.fixed_temperature = fixed_temperature
|
| 173 |
+
|
| 174 |
+
self.command = 1.0
|
| 175 |
+
self.bias_decay = 0.0
|
| 176 |
+
self.energy_count = 0
|
| 177 |
+
self.lambda_llm = 1.29
|
| 178 |
+
|
| 179 |
+
self.llm_scale = 1.0
|
| 180 |
+
|
| 181 |
+
self.llm_scale_initialized = False
|
| 182 |
+
|
| 183 |
+
# Model parameters
|
| 184 |
+
|
| 185 |
+
self.coefficient_zeta = 0.05
|
| 186 |
+
|
| 187 |
+
self.W = torch.randn(n_visible, n_hidden, device=self.device) * self.coefficient_zeta
|
| 188 |
+
|
| 189 |
+
self.visible_bias = torch.zeros(n_visible, device=self.device)
|
| 190 |
+
self.hidden_bias = torch.zeros(n_hidden, device=self.device)
|
| 191 |
+
|
| 192 |
+
# Thermodynamic states
|
| 193 |
+
|
| 194 |
+
self.log_temperature = torch.tensor(0.0, device=self.device)
|
| 195 |
+
self.flip_reference = torch.tensor(0.0, device=self.device)
|
| 196 |
+
self.energy_avg = torch.tensor(0.0, device=self.device)
|
| 197 |
+
|
| 198 |
+
self.persistent_sampling = None
|
| 199 |
+
|
| 200 |
+
# Monitoring histories
|
| 201 |
+
|
| 202 |
+
self.flip_hist = []
|
| 203 |
+
self.c_hist = []
|
| 204 |
+
self.weight_norm_hist = []
|
| 205 |
+
self.beta_eff_hist = []
|
| 206 |
+
self.temp_hist = []
|
| 207 |
+
self.delta_w_hist = []
|
| 208 |
+
self.persistent_div_hist = []
|
| 209 |
+
self.T_micro_hist = []
|
| 210 |
+
self.T_macro_hist = []
|
| 211 |
+
self.F_data_hist = []
|
| 212 |
+
self.F_model_hist = []
|
| 213 |
+
self.F_gap_hist = []
|
| 214 |
+
self.true_beta_hist = []
|
| 215 |
+
self.spectral_beta_hist = []
|
| 216 |
+
|
| 217 |
+
# Temperature (Hybrid Rule)
|
| 218 |
+
|
| 219 |
+
def temperature(self):
|
| 220 |
+
if self.fixed_temperature is not None:
|
| 221 |
+
return torch.tensor(self.fixed_temperature, device=self.device, dtype=torch.float32)
|
| 222 |
+
|
| 223 |
+
micro = torch.exp(self.log_temperature)
|
| 224 |
+
macro = self.energy_temp_scale * self.energy_avg
|
| 225 |
+
|
| 226 |
+
T = micro + macro
|
| 227 |
+
|
| 228 |
+
return torch.clamp(T, min=1e-6, max=1e3)
|
| 229 |
+
|
| 230 |
+
# Energy
|
| 231 |
+
|
| 232 |
+
def raw_energy(self, v, h):
|
| 233 |
+
return - (v @ self.W * h).sum(1) \
|
| 234 |
+
- (v * self.visible_bias).sum(1) \
|
| 235 |
+
- (h * self.hidden_bias).sum(1)
|
| 236 |
+
|
| 237 |
+
def free_energy(self, v, T):
|
| 238 |
+
activation = (v @ self.W + self.hidden_bias) / T
|
| 239 |
+
return -(v * self.visible_bias).sum(1) / T \
|
| 240 |
+
- F.softplus(activation).sum(1)
|
| 241 |
+
|
| 242 |
+
@torch.no_grad()
|
| 243 |
+
def reconstruction_accuracy(self, data):
|
| 244 |
+
|
| 245 |
+
recon_prob = self.reconstruct(data)
|
| 246 |
+
recon_bin = (recon_prob > 0.5).float()
|
| 247 |
+
|
| 248 |
+
correct = (recon_bin == data).float().sum()
|
| 249 |
+
total = data.numel()
|
| 250 |
+
|
| 251 |
+
return (correct / total).item()
|
| 252 |
+
|
| 253 |
+
# Gibbs Sampling
|
| 254 |
+
|
| 255 |
+
@torch.no_grad()
|
| 256 |
+
def gibbs_chain(self, v, T, steps=None):
|
| 257 |
+
|
| 258 |
+
steps = steps if steps is not None else self.gibbs_steps
|
| 259 |
+
|
| 260 |
+
for _ in range(steps):
|
| 261 |
+
h = torch.bernoulli(torch.sigmoid((v @ self.W + self.hidden_bias) / T))
|
| 262 |
+
v = torch.bernoulli(torch.sigmoid((h @ self.W.T + self.visible_bias) / T))
|
| 263 |
+
|
| 264 |
+
return v
|
| 265 |
+
|
| 266 |
+
# Training
|
| 267 |
+
|
| 268 |
+
    def train(self, data, energy_tracker=None):
        """Train the RBM with persistent contrastive divergence (PCD-k).

        Parameters
        ----------
        data : torch.Tensor
            Binary training matrix of shape (N, n_visible); assumed to live
            on ``self.device`` — TODO confirm against callers.
        energy_tracker : object, optional
            If given, its ``step()`` method is called once per mini-batch.

        Side effects: updates ``W`` / ``visible_bias`` / ``hidden_bias`` in
        place, advances the micro/macro temperature controllers, and appends
        one entry per epoch to every ``*_hist`` monitoring list.
        Returns ``None``.
        """
        N = data.shape[0]

        # Persistent chain for PCD, initialized from Bernoulli(0.5) noise.
        persistent_v = torch.bernoulli(
            torch.full((self.batch_size, self.n_visible), 0.5, device=self.device)
        )

        epoch_bar = tqdm(range(self.epochs), desc="Training", leave=True)

        for _ in epoch_bar:

            v_neg_all = []
            W_before = self.W.clone()  # for the per-epoch ||dW|| diagnostic

            # Reshuffle the data every epoch.
            perm = torch.randperm(N, device=self.device)
            data = data[perm]

            flip_rates = []
            energy_gaps = []
            F_data_batches = []
            F_model_batches = []

            for algebraic in range(0, N, self.batch_size):

                T = self.temperature()

                v_pos = data[algebraic:algebraic + self.batch_size]
                # Drop the ragged final batch so shapes stay fixed.
                if v_pos.shape[0] != self.batch_size:
                    continue

                # Negative phase (Persistent Contrastive Divergence - PCD-k)
                v_prev = persistent_v
                v_neg = self.gibbs_chain(v_prev, T, steps=self.gibbs_steps)
                persistent_v = v_neg

                # Fraction of visible bits that changed in this transition —
                # this drives the microscopic temperature controller below.
                flip = (v_neg != v_prev).float().mean()
                flip_rates.append(flip)

                # Expectations: mean-field hidden probabilities at temperature T.
                h_pos = torch.sigmoid((v_pos @ self.W + self.hidden_bias) / T)
                h_neg = torch.sigmoid((v_neg @ self.W + self.hidden_bias) / T)

                # CD weight gradient with L2 weight decay.
                dW = (v_pos.T @ h_pos - v_neg.T @ h_neg) / self.batch_size
                dW -= self.weight_decay * self.W
                self.W += self.lr * dW

                # ℓ2 regularization on biases (bias decay) b_{t+1} = (1 - ηλ) b_t + η g_t
                db_v = (v_pos - v_neg).mean(0) - self.bias_decay * self.visible_bias
                db_h = (h_pos - h_neg).mean(0) - self.bias_decay * self.hidden_bias

                self.visible_bias += self.lr * db_v
                self.hidden_bias += self.lr * db_h

                # Free-energy gap between data and model batches.
                F_data_batch = self.free_energy(v_pos, T).mean()
                F_model_batch = self.free_energy(v_neg, T).mean()

                energy_gaps.append(F_data_batch - F_model_batch)

                F_data_batches.append(F_data_batch)
                F_model_batches.append(F_model_batch)

                v_neg_all.append(v_neg)

                if energy_tracker is not None:
                    energy_tracker.step()

            # Epoch statistics
            flip_epoch = torch.stack(flip_rates).mean()
            energy_gap_epoch = torch.stack(energy_gaps).mean()

            F_data_epoch = torch.stack(F_data_batches).mean().item()
            F_model_epoch = torch.stack(F_model_batches).mean().item()

            # Microscopic control: EMA reference c_t of the flip rate, then a
            # proportional correction of log T toward that reference.
            self.flip_reference = (
                (1 - self.flip_smoothing) * self.flip_reference
                + self.flip_smoothing * flip_epoch
            )

            error = flip_epoch - self.flip_reference

            self.log_temperature = self.command * (
                self.log_temperature -
                self.lambda_gain * error
            )

            # Macroscopic control (Exact Cesàro average of the energy gap).
            self.energy_count += 1

            self.energy_avg = self.energy_avg + (
                energy_gap_epoch.detach() - self.energy_avg
            ) / self.energy_count

            # Diagnostics
            current_T = self.temperature()

            T_micro = torch.exp(self.log_temperature).item()
            T_macro = (self.energy_temp_scale * self.energy_avg).item()

            weight_norm = torch.norm(self.W).item()

            # Power iteration to estimate the spectral norm of W.
            with torch.no_grad():
                u_vec = torch.randn(self.n_visible, 1, device=self.device)
                u_vec = u_vec / (torch.norm(u_vec) + 1e-8)

                for _ in range(self.gibbs_steps * 2):
                    v_vec = torch.matmul(self.W.T, u_vec)
                    v_vec = v_vec / (torch.norm(v_vec) + 1e-8)
                    u_vec = torch.matmul(self.W, v_vec)
                    u_vec = u_vec / (torch.norm(u_vec) + 1e-8)

                spectral_norm = torch.norm(torch.matmul(self.W, v_vec)).item()

            v_neg_all = torch.cat(v_neg_all, dim=0)

            # Effective inverse temperature from the spread of hidden inputs.
            activations = (v_neg_all @ self.W + self.hidden_bias) / current_T
            beta_eff = activations.std(dim=1).mean().item()

            spectral_beta = spectral_norm / current_T.item()

            true_beta = 1.0 / current_T.item()

            delta_w = torch.norm(self.W - W_before).item()

            # Store histories (one entry per epoch; lists stay aligned).
            self.flip_hist.append(flip_epoch.item())
            self.c_hist.append(self.flip_reference.item())
            self.weight_norm_hist.append(weight_norm)
            self.beta_eff_hist.append(beta_eff)
            self.temp_hist.append(current_T.item())
            self.delta_w_hist.append(delta_w)
            self.T_micro_hist.append(T_micro)
            self.T_macro_hist.append(T_macro)
            self.F_data_hist.append(F_data_epoch)
            self.F_model_hist.append(F_model_epoch)
            self.true_beta_hist.append(true_beta)
            self.spectral_beta_hist.append(spectral_beta)
            self.F_gap_hist.append(F_data_epoch - F_model_epoch)

            # Progress bar
            epoch_bar.set_postfix({
                "T": f"{current_T.item():.3f}",
                "flip": f"{flip_epoch.item():.3f}",
                "beta": f"{beta_eff:.3f}"
            })

            # Dispersion of the persistent chains (mode-collapse indicator).
            # NOTE(review): reconstructed indentation places this inside the
            # epoch loop so the history gets one entry per epoch — confirm
            # against the original layout.
            persistent_div = torch.var(persistent_v.float())

            self.persistent_div_hist.append(persistent_div.item())

        return None
|
| 435 |
+
|
| 436 |
+
# Sampling
|
| 437 |
+
|
| 438 |
+
    @torch.no_grad()
    def generate_ensemble_samples(
        self,
        n_chains=1200,
        steps=10000,
        energy_tracker=None,
    ):
        """
        Multichain Gibbs sampler with STABLE temperature (debug version).

        Runs ``n_chains`` parallel block-Gibbs chains for ``steps`` sweeps at
        the temperature read once at entry, and returns the final visible
        states of shape (n_chains, n_visible).
        """

        device = self.device
        # Temperature is frozen for the entire run.
        T_base = self.temperature().item()

        # Persistent init: reuse earlier chains when their shape matches.
        # NOTE(review): self.persistent_sampling is read here but never
        # written back by this method — confirm where it is updated.
        if (
            hasattr(self, "persistent_sampling")
            and self.persistent_sampling is not None
            and self.persistent_sampling.shape == (n_chains, self.n_visible)
        ):
            v = self.persistent_sampling
        else:
            v = torch.bernoulli(
                torch.rand(n_chains, self.n_visible, device=device)
            )

        for _ in range(steps):

            T = T_base

            # h sample
            h = torch.bernoulli(
                torch.sigmoid((v @ self.W + self.hidden_bias) / T)
            )

            # additional Gibbs steps (v then h), self.gibbs_steps times
            for _ in range(self.gibbs_steps * 1):
                v = torch.bernoulli(
                    torch.sigmoid((h @ self.W.T + self.visible_bias) / T)
                )
                h = torch.bernoulli(
                    torch.sigmoid((v @ self.W + self.hidden_bias) / T)
                )

            # final v update
            v = torch.bernoulli(
                torch.sigmoid((h @ self.W.T + self.visible_bias) / T)
            )

            if energy_tracker is not None:
                energy_tracker.step()

        return v
|
| 491 |
+
|
| 492 |
+
    @torch.no_grad()
    def llm_uncertainty_refinement(self, state, client=None):
        """
        LLM-guided refinement with physically consistent energy coupling.

        The refinement operates on locally uncertain proposals determined by the model.
        Only proposals with ΔE_model in the interval [0, 2] are considered, which
        corresponds to the uncertainty region of the energy landscape.

        The language model contributes as a bounded perturbation to the energy.
        Its influence is normalized and constrained relative to the model energy,
        ensuring that it cannot dominate the underlying distribution but can still
        provide meaningful local corrections.

        The same bit-flip is used during both selection and proposal stages to
        preserve consistency in the transition dynamics.

        Parameters
        ----------
        state : torch.Tensor
            Batch of binary samples, shape (batch, n_visible); each row is
            reshaped to 28x28 for the LLM representation, so n_visible must
            be 784 when ``client`` is given.
        client : object, optional
            LLM client forwarded to ``LLMEnergy``; when ``None`` only the
            model energy drives acceptance.

        Returns
        -------
        torch.Tensor
            A refined copy of ``state`` (the input tensor is not modified).
        """

        batch_size, dim = state.shape
        temperature = self.temperature()

        debug = getattr(self, "debug", True)

        if debug:
            print(f"\n[LLM] Start | B={batch_size} | T={temperature:.4f}", flush=True)

        refined_state = state.clone()

        # Memoize LLM evaluations so identical sparse representations are
        # scored only once per call.
        llm_cache = {}

        def get_llm_energy(repr_str):
            # Return the cached (result, error) pair or query the LLM once.
            if repr_str in llm_cache:
                return llm_cache[repr_str]

            consequence, error = LLMEnergy(repr_str, client)
            llm_cache[repr_str] = (consequence, error)

            return consequence, error

        # Generate single-bit proposals: one random flip position per sample.
        flip_indices = torch.randint(0, dim, (batch_size,), device=state.device)

        state_prop = state.clone()
        state_prop[torch.arange(batch_size), flip_indices] = \
            1 - state_prop[torch.arange(batch_size), flip_indices]

        # Signed energy difference (no absolute value)
        E_cur = self.free_energy(state, temperature)
        E_prop = self.free_energy(state_prop, temperature)
        delta_E = E_prop - E_cur

        # Adaptive selection size using effective sample size (ESS)
        weights = torch.softmax(-E_cur / (temperature + 1e-8), dim=0)

        ess = 1.0 / torch.sum(weights ** 2)
        ess_ratio = ess.item() / batch_size

        k = int(batch_size * (1.0 - ess_ratio))
        k = min(batch_size // 4, max(1, int(k / 90)))  # The number 90 is an arbitrary choice.

        # Select uncertain samples where ΔE ∈ [0, 2]

        low, high = 0.0, 2.0

        mask = (delta_E >= low) & (delta_E <= high)
        band_indices = torch.nonzero(mask).squeeze(1)

        # Proposals closest to the band center are treated as most uncertain.
        center = 1.0

        if band_indices.numel() >= k:
            band_delta = delta_E[band_indices]
            _, order = torch.sort(torch.abs(band_delta - center))
            selected_indices = band_indices[order[:k]]
        else:
            # Not enough in-band proposals: fall back to the k closest overall.
            all_dist = torch.abs(delta_E - center)
            _, order = torch.sort(all_dist)
            selected_indices = order[:k]

        selected_flips = flip_indices[selected_indices]

        if debug:
            print(
                f"[Debug] ESS={ess.item():.2f} | k={k} | "
                f"deltaE_band=[{low:.2f},{high:.2f}] | "
                f"in_band={band_indices.numel()} | selected={selected_indices.numel()}"
            )

        accepted_count = 0
        llm_usage_count = 0

        for integer, ideal in enumerate(selected_indices):

            current_sample = refined_state[ideal]
            proposed_sample = current_sample.clone()

            # Re-apply exactly the flip used during the selection stage.
            flip_idx = selected_flips[integer].item()
            proposed_sample[flip_idx] = 1 - proposed_sample[flip_idx]

            E_model_cur = self.free_energy(current_sample.unsqueeze(0), temperature)[0]
            E_model_prop = self.free_energy(proposed_sample.unsqueeze(0), temperature)[0]

            delta_model = E_model_prop - E_model_cur
            delta_total = delta_model

            if client is not None:

                current_repr = to_sparse_gpu(current_sample.view(28, 28))
                proposed_repr = to_sparse_gpu(proposed_sample.view(28, 28))

                current_llm, err_cur = get_llm_energy(current_repr)
                proposed_llm, err_prop = get_llm_energy(proposed_repr)

                # Blend the LLM term in only when both evaluations succeeded.
                if not (err_cur or err_prop or current_llm is None or proposed_llm is None):

                    llm_usage_count += 1

                    probs_cur = current_llm["probs"]
                    probs_prop = proposed_llm["probs"]

                    E_llm_cur = LIES_gpu(probs_cur)
                    E_llm_prop = LIES_gpu(probs_prop)

                    delta_llm = torch.tensor(
                        E_llm_prop - E_llm_cur,
                        device=temperature.device
                    )

                    # Normalize relative to model energy scale:
                    # σ² ← EMA(ΔE_model²), ΔE_llm ← ΔE_llm / σ
                    scale_est = torch.abs(delta_model.detach()) + 1e-6

                    if not self.llm_scale_initialized:
                        self.llm_scale = scale_est ** 2
                        self.llm_scale_initialized = True
                    else:
                        self.llm_scale = 0.97 * self.llm_scale + 0.03 * (delta_model ** 2)

                    scale = torch.sqrt(self.llm_scale + 1e-6)

                    delta_llm = delta_llm / scale

                    # Consistent bound: the LLM term can never exceed the
                    # normalized magnitude of the model term.
                    bound = torch.abs(delta_model.detach()) / scale + 1e-6
                    delta_llm = torch.clamp(delta_llm, -bound, bound)

                    delta_total = delta_model + self.lambda_llm * delta_llm

                    if debug:
                        print(
                            f"[Debug] d_model={delta_model.item():.4f}, "
                            f"d_llm={delta_llm.item():.4f}"
                        )

            # Numerical stability for exponentiation
            delta_total = torch.clamp(delta_total, -10.0, 10.0)

            # Metropolis acceptance: min(1, exp(-ΔE_total)).
            accept_prob = torch.minimum(
                torch.tensor(1.0, device=delta_total.device),
                torch.exp(-delta_total)
            )

            rand_val = torch.rand(1, device=temperature.device)

            if rand_val < accept_prob:
                refined_state[ideal] = proposed_sample
                accepted_count += 1
                accepted = True
            else:
                accepted = False

            if debug:
                status = "ACCEPT" if accepted else "REJECT"
                print(
                    f"[Step] Sample={ideal} | {status} | "
                    f"d_model={delta_model.item():.3f} | "
                    f"d_total={delta_total.item():.3f} | "
                    f"p={accept_prob.item():.3f}"
                )

        if debug:
            print("\n[LLM Summary]")
            print(f"LLM used : {llm_usage_count}/{batch_size}")
            print(f"Accepted : {accepted_count}/{k}")
            print(f"T : {temperature:.4f}")
            print("[LLM Done]\n", flush=True)

        return refined_state
|
| 678 |
+
|
| 679 |
+
@torch.no_grad()
|
| 680 |
+
def save_ensemble_samples(
|
| 681 |
+
self,
|
| 682 |
+
filename="samples_ensemble.png",
|
| 683 |
+
n_display=100,
|
| 684 |
+
steps=10000,
|
| 685 |
+
energy_tracker=None,
|
| 686 |
+
client=None
|
| 687 |
+
):
|
| 688 |
+
samples = self.generate_ensemble_samples(
|
| 689 |
+
n_chains=n_display,
|
| 690 |
+
steps=steps,
|
| 691 |
+
energy_tracker=energy_tracker
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
samples_original = samples.clone().detach()
|
| 695 |
+
|
| 696 |
+
samples = self.llm_uncertainty_refinement(samples, client=client)
|
| 697 |
+
|
| 698 |
+
refiner = Refinement(self)
|
| 699 |
+
|
| 700 |
+
samples = refiner.myra_refine(samples)
|
| 701 |
+
|
| 702 |
+
samples = samples.view(-1, 1, 28, 28)
|
| 703 |
+
original_tmp = samples_original.view(-1, 1, 28, 28)
|
| 704 |
+
|
| 705 |
+
chi_final = susceptibility(samples)
|
| 706 |
+
chi_raw = susceptibility(original_tmp)
|
| 707 |
+
|
| 708 |
+
binder_final = binder_cumulant(samples)
|
| 709 |
+
binder_raw = binder_cumulant(original_tmp)
|
| 710 |
+
|
| 711 |
+
print("\nPhysical Refinement Diagnostics:", chi_final)
|
| 712 |
+
|
| 713 |
+
print("Binder cumulant (1):", binder_final)
|
| 714 |
+
print("Chi improvement:", chi_final - chi_raw)
|
| 715 |
+
|
| 716 |
+
print("Binder improvement:", binder_final - binder_raw)
|
| 717 |
+
|
| 718 |
+
grid = vutils.make_grid(
|
| 719 |
+
samples,
|
| 720 |
+
nrow=50,
|
| 721 |
+
padding=2
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
grid = (grid * 255).clamp(0, 255).byte()
|
| 725 |
+
nd_arr = grid.permute(1, 2, 0).cpu().numpy()
|
| 726 |
+
|
| 727 |
+
Image.fromarray(nd_arr).save(
|
| 728 |
+
filename.replace(".png", "_refined.png")
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
grid_0 = vutils.make_grid(
|
| 732 |
+
original_tmp,
|
| 733 |
+
nrow=50,
|
| 734 |
+
padding=2
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
grid_0 = (grid_0 * 255).clamp(0, 255).byte()
|
| 738 |
+
nd_arr1 = grid_0.permute(1, 2, 0).cpu().numpy()
|
| 739 |
+
|
| 740 |
+
Image.fromarray(nd_arr1).save(
|
| 741 |
+
filename.replace(".png", "_symbol.png")
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
    @torch.no_grad()
    def ensemble_diagnostics(
        self,
        n_chains=1200,
        steps=10000,
        burn_in_ratio=0.2,
        thinning=1,
    ):
        """Run fresh Gibbs chains and report MCMC mixing diagnostics.

        Returns a dict with integrated autocorrelation time statistics
        (mean/std/max/min over chains), total effective sample size,
        split-chain R-hat, mean per-chain ESS, and the number of ACF lags
        actually evaluated.
        """
        device = self.device
        temperature = self.temperature()

        # Initialize chains
        visible = torch.bernoulli(
            torch.full((n_chains, self.n_visible), 0.5, device=device)
        )

        energy_trace = []

        # MCMC sampling: one block-Gibbs sweep per step, recording the free
        # energy of every chain.
        for step in range(steps):
            hidden = torch.bernoulli(
                torch.sigmoid((visible @ self.W + self.hidden_bias) / temperature)
            )
            visible = torch.bernoulli(
                torch.sigmoid((hidden @ self.W.T + self.visible_bias) / temperature)
            )

            energy_trace.append(self.free_energy(visible, temperature).detach())

        energy_trace = torch.stack(energy_trace)  # (time, chains)

        # Burn-in removal (skipped entirely when it would leave < 10 samples).
        burn = int(burn_in_ratio * steps)
        if burn >= steps - 10:
            burn = 0

        energy_trace = energy_trace[burn:]

        # Thinning
        energy_trace = energy_trace[::thinning]

        n_steps, n_chains = energy_trace.shape

        # Not enough data safeguard
        max_lag = n_steps // 2
        if max_lag < 2:
            return {
                "tau_int": float("nan"),
                "tau_std": float("nan"),
                "tau_max": float("nan"),
                "tau_min": float("nan"),
                "ess": float("nan"),
                "r_hat": float("nan"),
                "n_eff_per_chain": float("nan"),
                "autocorr_len": 0
            }

        # Center per chain
        centered = energy_trace - energy_trace.mean(dim=0, keepdim=True)
        var = torch.var(centered, dim=0, unbiased=True)

        # Degenerate case: (near-)constant traces have no meaningful ACF.
        if torch.mean(var) < 1e-12:
            return {
                "tau_int": float("inf"),
                "tau_std": float("nan"),
                "tau_max": float("inf"),
                "tau_min": float("inf"),
                "ess": 0.0,
                "r_hat": float("inf"),
                "n_eff_per_chain": 0.0,
                "autocorr_len": 0
            }

        # Lag construction (corrected: includes dense small lags)
        small_lags = torch.arange(1, min(20, max_lag + 1), device=device)

        log_lags = torch.logspace(
            0, math.log10(max_lag), steps=10, device=device
        ).long()

        lag_values = torch.unique(torch.cat([small_lags, log_lags])).sort()[0]

        samples_per_lag = 10000
        acf_list = []

        # Stochastic ACF estimation: Monte-Carlo estimate of rho(lag) using
        # randomly drawn time offsets instead of the full sum.
        for lag in lag_values:
            if lag == 0:
                # NOTE(review): lag_values starts at 1, so this branch looks
                # unreachable — confirm.
                acf_list.append(torch.ones(n_chains, device=device))
                continue

            if lag >= n_steps:
                break

            idealization = torch.randint(0, n_steps - lag, (samples_per_lag,), device=device)

            x_t = centered[idealization]
            x_t_lag = centered[idealization + lag]

            rho = (x_t * x_t_lag).mean(dim=0) / (var + 1e-12)

            acf_list.append(rho)

        acf_tensor = torch.stack(acf_list)

        # Integrated autocorrelation time (IPS-style): accumulate 2*rho per
        # chain while rho stays positive.
        # NOTE(review): acf_tensor[0] holds the lag-1 estimate, and the loop
        # below starts at index 1, so the lag-1 term is skipped — confirm
        # this is intended.
        tau_per_chain = torch.ones(n_chains, device=device)
        active = torch.ones(n_chains, dtype=torch.bool, device=device)

        for k in range(1, acf_tensor.shape[0]):
            rho = acf_tensor[k]

            positive = rho > 0
            update_mask = active & positive

            tau_per_chain[update_mask] += 2.0 * rho[update_mask]

            active = active & positive

            if not active.any():
                break

        tau_mean = tau_per_chain.mean().item()
        tau_std = tau_per_chain.std().item()
        tau_max = tau_per_chain.max().item()
        tau_min = tau_per_chain.min().item()

        # Effective sample size
        ess_per_chain = n_steps / (tau_per_chain + 1e-8)
        total_ess = torch.sum(ess_per_chain).item()

        # R-hat (split chains): each chain is halved, doubling the chain count.
        half = n_steps // 2

        if half < 10:
            r_hat = float("nan")
        else:
            first_half = energy_trace[:half]
            second_half = energy_trace[half:2 * half]

            split = torch.cat([first_half, second_half], dim=1)

            n_half = split.shape[0]

            chain_means = split.mean(dim=0)

            # Between-chain (B) and within-chain (W) variance estimates.
            B = n_half * torch.var(chain_means, unbiased=True)
            W = torch.mean(torch.var(split, dim=0, unbiased=True))

            var_hat = (1 - 1 / n_half) * W + (1 / n_half) * B

            r_hat = torch.sqrt(var_hat / (W + 1e-12)).item()

        return {
            "tau_int": tau_mean,
            "tau_std": tau_std,
            "tau_max": tau_max,
            "tau_min": tau_min,
            "ess": float(total_ess),
            "r_hat": float(r_hat),
            "n_eff_per_chain": float(ess_per_chain.mean().item()),
            "autocorr_len": int(acf_tensor.shape[0])
        }
|
| 908 |
+
|
| 909 |
+
@torch.no_grad()
|
| 910 |
+
def save_professional_samples(
|
| 911 |
+
self,
|
| 912 |
+
filename="samples_professional.png",
|
| 913 |
+
n_display=100,
|
| 914 |
+
steps=10000,
|
| 915 |
+
energy_tracker=None,
|
| 916 |
+
client=None
|
| 917 |
+
):
|
| 918 |
+
samples = self.generate_ensemble_samples(
|
| 919 |
+
n_chains=n_display,
|
| 920 |
+
steps=steps,
|
| 921 |
+
energy_tracker=energy_tracker
|
| 922 |
+
)
|
| 923 |
+
|
| 924 |
+
samples_original = samples.clone().detach()
|
| 925 |
+
|
| 926 |
+
samples = self.llm_uncertainty_refinement(samples, client=client)
|
| 927 |
+
|
| 928 |
+
refiner = Refinement(self)
|
| 929 |
+
|
| 930 |
+
samples = refiner.myra_refine(samples)
|
| 931 |
+
|
| 932 |
+
samples = samples.view(-1, 1, 28, 28)
|
| 933 |
+
original_tmp = samples_original.view(-1, 1, 28, 28)
|
| 934 |
+
|
| 935 |
+
diagnostics = self.ensemble_diagnostics()
|
| 936 |
+
|
| 937 |
+
chi_final = susceptibility(samples)
|
| 938 |
+
chi_raw = susceptibility(original_tmp)
|
| 939 |
+
|
| 940 |
+
binder_final = binder_cumulant(samples)
|
| 941 |
+
binder_raw = binder_cumulant(original_tmp)
|
| 942 |
+
|
| 943 |
+
print("\nSecondary magnetic susceptibility:", chi_final)
|
| 944 |
+
|
| 945 |
+
print("Binder cumulant (2):", binder_final)
|
| 946 |
+
print("Chi improvement:", chi_final - chi_raw)
|
| 947 |
+
|
| 948 |
+
print("Binder improvement:", binder_final - binder_raw)
|
| 949 |
+
|
| 950 |
+
grid = vutils.make_grid(
|
| 951 |
+
samples,
|
| 952 |
+
nrow=50,
|
| 953 |
+
padding=2
|
| 954 |
+
)
|
| 955 |
+
|
| 956 |
+
grid = (grid * 255).clamp(0, 255).byte()
|
| 957 |
+
nd_arr = grid.permute(1, 2, 0).cpu().numpy()
|
| 958 |
+
|
| 959 |
+
Image.fromarray(nd_arr).save(
|
| 960 |
+
filename.replace(".png", "_refined.png")
|
| 961 |
+
)
|
| 962 |
+
|
| 963 |
+
grid1 = vutils.make_grid(
|
| 964 |
+
original_tmp,
|
| 965 |
+
nrow=50,
|
| 966 |
+
padding=2
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
grid1 = (grid1 * 255).clamp(0, 255).byte()
|
| 970 |
+
nd_arr2 = grid1.permute(1, 2, 0).cpu().numpy()
|
| 971 |
+
|
| 972 |
+
Image.fromarray(nd_arr2).save(
|
| 973 |
+
filename.replace(".png", "_symbol.png")
|
| 974 |
+
)
|
| 975 |
+
|
| 976 |
+
return diagnostics
|
| 977 |
+
|
| 978 |
+
# Ais Calculation
|
| 979 |
+
|
| 980 |
+
    @torch.no_grad()
    def ais_log_partition(self, n_runs=1000, n_intermediate=2000, energy_tracker=None):
        """Estimate log Z with Annealed Importance Sampling (AIS).

        Anneals from a uniform Bernoulli(0.5) base distribution
        (log Z_0 = (nv + nh) * log 2) to the full model along a quadratic
        beta schedule, running one block-Gibbs transition per intermediate
        distribution.

        Returns
        -------
        tuple(float, float, float)
            (log-partition estimate, variance of the log importance
            weights, effective sample size of the normalized weights).
        """

        device = self.device
        T = self.temperature().item()

        W = self.W
        bv = self.visible_bias
        bh = self.hidden_bias

        nv, nh = self.n_visible, self.n_hidden

        # Quadratic annealing schedule: denser spacing near beta = 0.
        betas = torch.linspace(0.0, 1.0, n_intermediate, device=device) ** 2

        # Base model: uniform Bernoulli(0.5)
        logZ0 = (nv + nh) * math.log(2.0)

        v = torch.bernoulli(torch.full((n_runs, nv), 0.5, device=device))
        h = torch.bernoulli(torch.full((n_runs, nh), 0.5, device=device))

        log_weights = torch.zeros(n_runs, device=device)

        for k in range(1, n_intermediate):
            beta_prev = betas[k - 1]
            beta_curr = betas[k]

            # energy = E(v, h) / T with E = -v·Wh - bv·v - bh·h.
            energy = (- (v @ W * h).sum(1)
                      - (v * bv).sum(1)
                      - (h * bh).sum(1)
                      ) / T
            # SINGLE temperature scaling point
            # NOTE(review): standard AIS accumulates -(Δβ)·E in the log
            # weights; given the definition of `energy` above, the sign used
            # here is the opposite convention — confirm.

            log_weights += (beta_curr - beta_prev) * energy

            # One Gibbs transition targeting the current intermediate model.
            h_prob = torch.sigmoid(beta_curr * (v @ W + bh) / T)
            h = torch.bernoulli(h_prob)

            v_prob = torch.sigmoid(beta_curr * (h @ W.T + bv) / T)
            v = torch.bernoulli(v_prob)

            if energy_tracker is not None:
                energy_tracker.step()

        # Log-mean-exp of the importance weights on top of the base log Z.
        logZ = logZ0 + torch.logsumexp(log_weights, dim=0) - math.log(n_runs)

        # Weight diagnostics: normalized weights and their ESS.
        log_w = log_weights - torch.max(log_weights).detach()

        w = torch.exp(log_w)
        w = w / torch.sum(w)
        ess = 1.0 / torch.sum(w ** 2)

        log_weight_var = torch.var(log_weights)

        return logZ.item(), log_weight_var.item(), ess.item()
|
| 1035 |
+
|
| 1036 |
+
# Log-Likelihood
|
| 1037 |
+
|
| 1038 |
+
@torch.no_grad()
def log_likelihood(self, data, log_Z_2ox):
    """Return the mean log-likelihood of `data` given an externally
    estimated log partition function `log_Z_2ox`.

    Uses log p(v) = -F(v) - log Z, where F is the model free energy
    evaluated at the current temperature.
    """
    temp = self.temperature()
    free_en = self.free_energy(data, temp)
    per_sample = -(free_en + log_Z_2ox)
    return per_sample.mean().item()
|
| 1044 |
+
|
| 1045 |
+
@torch.no_grad()
def pseudo_likelihood(self, data):
    """Mean per-dimension log pseudo-likelihood of binary `data`:
    average over i and over the batch of log P(v_i | v_{-i}).

    Args:
        data: (N, D) binary visible vectors.

    Returns:
        Scalar float: mean log pseudo-likelihood per visible unit.
    """
    T = self.temperature()
    W = self.W
    b = self.visible_bias
    h_bias = self.hidden_bias

    N, D = data.shape

    # hidden pre-activation
    wx = (data @ W + h_bias) / T # (N, H)

    log_probs = []

    for interaction in range(D):
        v_i = data[:, interaction] # (N,)
        W_i = W[interaction] # (H,)

        # flip effect
        # (1 - 2*v_i) is +1 when v_i = 0 and -1 when v_i = 1, so wx_flip is
        # the hidden pre-activation with bit i toggled.
        delta = (1 - 2 * v_i).unsqueeze(1) * W_i / T
        wx_flip = wx + delta

        # free energy difference
        # NOTE(review): `term` is softplus(flipped) - softplus(current), i.e.
        # it is measured relative to the CURRENT value of v_i, so the logit
        # built below has opposite orientation for v_i = 1 vs v_i = 0 —
        # confirm this matches the intended logit of P(v_i = 1 | rest).
        term = torch.sum(
        F.softplus(wx_flip) - F.softplus(wx),
        dim=1
        )

        logits = (b[interaction] / T) + term

        # log P(v_i | rest)
        # BCE-with-logits returns -log P, hence the leading minus sign.
        log_prob_i = -F.binary_cross_entropy_with_logits(
        logits,
        v_i,
        reduction='none'
        )

        log_probs.append(log_prob_i)

    # average over dimensions
    log_prob = torch.stack(log_probs, dim=1).mean(dim=1)

    # NOTE(review): leftover debug print — it outputs the dimensionality D,
    # not a computed kappa coefficient; consider removing or renaming.
    print("coefficient (kappa):", D)

    # average over batch
    return log_prob.mean().item()
|
| 1091 |
+
|
| 1092 |
+
@torch.no_grad()
def reconstruct(self, v):
    """One deterministic mean-field up-down pass: visible -> hidden
    probabilities -> reconstructed visible probabilities.

    Args:
        v: (N, n_visible) batch of visible vectors.

    Returns:
        (N, n_visible) tensor of reconstructed visible probabilities.
    """
    temp = self.temperature()

    hidden_prob = torch.sigmoid((v @ self.W + self.hidden_bias) / temp)
    visible_prob = torch.sigmoid((hidden_prob @ self.W.T + self.visible_bias) / temp)

    return visible_prob
|
| 1101 |
+
|
| 1102 |
+
|
| 1103 |
+
# Multi-GPU Worker
|
| 1104 |
+
|
| 1105 |
+
# Multi-GPU Worker

def worker(Gpu_Id, seed_round, consequences):
    """Full per-GPU experiment pipeline.

    Trains a HybridThermodynamicRBM on MNIST, estimates log Z with AIS,
    computes likelihood/reconstruction/thermodynamic diagnostics, renders
    sample grids and PDF plots, shells out to the clustering script, runs
    the LLM evaluation, and appends one result record to `consequences`
    (a multiprocessing manager list shared with the parent process).

    Args:
        Gpu_Id: CUDA device index this worker is pinned to.
        seed_round: RNG seed for this run.
        consequences: shared list collecting one result dict per worker.
    """
    client = SafeBookClient()
    # Pin the process to its GPU and seed both CPU and CUDA RNGs.
    torch.cuda.set_device(Gpu_Id)
    torch.manual_seed(seed_round)
    torch.cuda.manual_seed_all(seed_round)
    energy_tracker = GPUEnergyTracker(Gpu_Id)

    device_warm = torch.device(f"cuda:{Gpu_Id}")

    # Data
    data = load_mnist(device_warm)

    # First 60k samples form the training split, the remainder the test split.
    train_data_um = data[:60000]
    test_data_um = data[60000:]

    last_model = HybridThermodynamicRBM(
        n_visible=784,
        n_hidden=512,
        device_type=f"cuda:{Gpu_Id}",
        fixed_temperature=None
    )

    train_result = last_model.train(
        train_data_um,
        energy_tracker=energy_tracker
    )

    # train() signals a hard failure by returning the sentinel string "ABORT".
    if train_result == "ABORT":
        print("[WORKER] Training aborted early")
        return

    def dots_spinner(stop_event_choose, elastic_label):
        """
        Lightweight terminal spinner used to indicate that a diagnostic
        computation is currently running.
        """
        frames = [" ", ". ", ".. ", "..."]

        integer = 0

        while not stop_event_choose.is_set():
            sys.stdout.write(f"\r▶ {elastic_label}{frames[integer % len(frames)]}")
            sys.stdout.flush()
            time.sleep(0.4)

            integer += 1

    # --- AIS estimate of log Z, with a spinner on a background thread ---
    print(f"\n[GPU {Gpu_Id}] Running AIS...")

    stop_event = threading.Event()

    spinner_thread = threading.Thread(target=dots_spinner, args=(stop_event, "AIS"))
    spinner_thread.start()

    ais_start = time.time()

    try:
        log_Z_3ox, ais_var, ais_ess = last_model.ais_log_partition(
            n_runs=8000,
            n_intermediate=12000,
            energy_tracker=energy_tracker
        )
    except Exception as emerald:
        print(f"\n✖ AIS failed: {emerald}")

        # Downstream code tolerates NaNs, so a failed AIS does not kill the run.
        log_Z_3ox, ais_var, ais_ess = float("nan"), float("nan"), float("nan")

    ais_duration = time.time() - ais_start

    stop_event.set()
    spinner_thread.join()

    sys.stdout.write(f"\r✔ AIS finished in {ais_duration:.2f} seconds\n")

    # --- Likelihood metrics on both splits ---
    train_ll = last_model.log_likelihood(train_data_um, log_Z_3ox)
    Test_ll = last_model.log_likelihood(test_data_um, log_Z_3ox)
    Train_pl = last_model.pseudo_likelihood(train_data_um)
    Test_pl = last_model.pseudo_likelihood(test_data_um)

    # Reconstruction
    reconstruction = last_model.reconstruct(test_data_um)

    reconstruction_mse = torch.mean(
        (test_data_um - reconstruction) ** 2
    ).item()

    recon_acc = last_model.reconstruction_accuracy(test_data_um)

    # Thermodynamic diagnostics
    final_temperature = last_model.temperature().item()
    weight_norm = torch.norm(last_model.W).item()

    # ||W|| / T — used below as a spectral-gain style scalar diagnostic.
    gain_spectral = weight_norm / final_temperature

    sample_filename = f"samples_gpu{Gpu_Id}_seed{seed_round}.png"
    prof_filename = f"samples_prof_gpu{Gpu_Id}_seed{seed_round}.png"

    # Simple ensemble save
    last_model.save_ensemble_samples(
        filename=sample_filename,
        n_display=1200,
        steps=10000,
        energy_tracker=energy_tracker,
        client=client
    )

    # Professional ensemble + diagnostics
    diagnostics = last_model.save_professional_samples(
        filename=prof_filename,
        n_display=1200,
        steps=10000,
        energy_tracker=energy_tracker,
        client=client
    )

    print("\nRunning energy analysis ↓\n")

    # Each entry: (display name, thunk). Results are captured by name below.
    analysis_steps = [

        ("Energy distribution",
        lambda: SrtrbmEnergy.plot_data_vs_model_energy(last_model, train_data_um)),

        ("Energy landscape extremes",
        lambda: SrtrbmEnergy.visualize_energy_extremes(last_model, train_data_um)),

        ("Phase diagram",
        lambda: SrtrbmMetrics.plot_flip_beta(
            last_model,
            "SR-TRBM Phase Diagram",
            filename="srtrbm_phase_diagram.pdf"
        )),

        ("RBM filters",
        lambda: SrtrbmVisualization.visualize_rbm_filters(
            last_model,
            filename="srtrbm_filters.png",
            n_filters=256
        )),

        ("Sample quality metrics",
        lambda: SrtrbmMetrics.sample_quality_metrics(last_model, train_data_um)),
    ]

    quality = None
    energy_stats = None

    print("\nRunning diagnostics with the following calculations ↙\n")

    total_start = time.time()

    # Run each diagnostic with its own spinner; a failure in one step is
    # logged and does not abort the others.
    for name, fn in analysis_steps:

        stop_event = threading.Event()
        spinner_thread = threading.Thread(target=dots_spinner, args=(stop_event, name))
        spinner_thread.start()
        start = time.time()

        try:
            result_cache = fn()
        except Exception as emerald:

            print(f"\n✖ {name} failed:", emerald)

            result_cache = None

        duration = time.time() - start

        stop_event.set()
        spinner_thread.join()

        sys.stdout.write(f"\r✔ {name} finished in {duration:.2f} seconds\n")

        # Two steps return values we keep; the rest only produce files.
        if name == "Energy distribution" and result_cache is not None:
            energy_stats = result_cache
        if name == "Sample quality metrics":
            quality = result_cache

    total_time = time.time() - total_start

    print(f"\nDiagnostics completed in {total_time:.2f} seconds\n")

    # Fall back to NaN-filled dicts so the result record below always has
    # every key, even when a diagnostic step failed.
    if quality is None:
        quality = {
            "pixel_entropy": float("nan"),
            "diversity": float("nan"),
            "mean_l2": float("nan")
        }

    if energy_stats is None:
        energy_stats = {
            "mean_data_energy": float("nan"),
            "mean_model_energy": float("nan"),
            "energy_gap": float("nan")
        }

    gpu_energy = energy_tracker.total_energy()

    print(f"GPU Energy Used : {gpu_energy:.2f} Joules\n")

    # Single flat record consumed by the __main__ summary printer.
    consequences.append({
        "seed": seed_round,
        "gpu": Gpu_Id,
        "gpu_energy": gpu_energy,
        "temperature": final_temperature,
        "weight_norm": weight_norm,
        "spectral_gain": gain_spectral,
        "train_log_likelihood": train_ll,
        "test_log_likelihood": Test_ll,
        "train_pseudo_likelihood": Train_pl,
        "test_pseudo_likelihood": Test_pl,
        "reconstruction_mse": reconstruction_mse,
        "reconstruction_accuracy": recon_acc,
        "mean_data_energy": energy_stats["mean_data_energy"],
        "mean_model_energy": energy_stats["mean_model_energy"],
        "energy_gap": energy_stats["energy_gap"],
        "logZ": log_Z_3ox,
        "ais_ess": ais_ess,
        "pixel_entropy": quality["pixel_entropy"],
        "diversity": quality["diversity"],
        "mean_l2": quality["mean_l2"],
        "ais_log_weight_variance": ais_var,
        "mcmc_tau_int": diagnostics["tau_int"],
        "mcmc_tau_std": diagnostics["tau_std"], # heterogeneity
        "mcmc_tau_max": diagnostics["tau_max"], # worst-case chain
        "mcmc_tau_min": diagnostics["tau_min"], # fastest chain
        "mcmc_ess": diagnostics["ess"],
        "mcmc_ess_per_chain": diagnostics["n_eff_per_chain"],
        "mcmc_r_hat": diagnostics["r_hat"],
        "mcmc_acf_len": diagnostics["autocorr_len"],
        "prof_sample_file": prof_filename,
        "sample_file": sample_filename
    })

    # --- Core-dynamics figure (3 stacked panels) ---
    plt.figure(figsize=(10, 12))

    # Flip rate
    plt.subplot(3, 1, 1)
    plt.plot(last_model.flip_hist, linewidth=2, label="Flip rate")
    plt.plot(last_model.c_hist, linestyle="--", linewidth=2, label="Adaptive reference")
    plt.title("Microscopic Flip Dynamics")
    plt.ylabel("Flip Rate")
    plt.grid(alpha=0.3)
    plt.legend()

    # Temperature evolution
    plt.subplot(3, 1, 2)
    plt.plot(last_model.temp_hist, linewidth=2)
    plt.title(r"Global Temperature $T$")
    plt.ylabel("Temperature")
    plt.grid(alpha=0.3)

    # Micro vs Macro
    plt.subplot(3, 1, 3)
    plt.plot(last_model.T_micro_hist, linewidth=2, label=r"$T_{micro}$")
    plt.plot(last_model.T_macro_hist, linewidth=2, label=r"$T_{macro}$")
    plt.title("Micro–Macro Temperature Decomposition")
    plt.xlabel("Epoch")
    plt.ylabel("Temperature Components")
    plt.grid(alpha=0.3)
    plt.legend()

    plt.tight_layout()
    plt.savefig("srtrbm_core_dynamics.pdf", bbox_inches="tight")
    plt.close()

    # Energy & Spectral Diagnostics
    plt.figure(figsize=(10, 14))

    # Weight norm
    plt.subplot(4, 1, 1)
    plt.plot(last_model.weight_norm_hist, linewidth=2)
    plt.title("Weight Norm Evolution")
    plt.ylabel(r"$||W||$")
    plt.grid(alpha=0.3)

    # Effective beta
    plt.subplot(4, 1, 2)
    plt.plot(last_model.beta_eff_hist, linewidth=2)
    plt.title(r"Effective Inverse Temperature $\beta_{\mathrm{eff}}$")
    plt.ylabel(r"$\beta_{\mathrm{eff}}$")
    plt.grid(alpha=0.3)

    # Free energy comparison
    plt.subplot(4, 1, 3)
    plt.plot(last_model.F_data_hist, linewidth=2, label=r"$F_{data}$")
    plt.plot(last_model.F_model_hist, linewidth=2, label=r"$F_{model}$")
    plt.title("Free Energy: Data vs Model")
    plt.ylabel("Free Energy")
    plt.grid(alpha=0.3)
    plt.legend()

    # Spectral beta
    plt.subplot(4, 1, 4)
    plt.plot(last_model.spectral_beta_hist)

    plt.axhline(1.0, linestyle="--")

    plt.title(r"Spectral Inverse Temperature $\beta_{spectral}$")
    plt.xlabel("Epoch")
    plt.ylabel(r"$\beta_{spectral}$")
    plt.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig("srtrbm_energy_diagnostics.pdf", bbox_inches="tight")
    plt.close()

    # File names produced by the refinement step inside the save_* calls above.
    sample_refined = sample_filename.replace(".png", "_refined.png")
    prof_refined = prof_filename.replace(".png", "_refined.png")

    # Fresh sample set used only for the susceptibility (phase) estimate.
    samples_for_phase = last_model.generate_ensemble_samples(n_chains=1200, steps=10000)
    samples_for_phase = samples_for_phase.view(-1, 1, 28, 28)

    chi = susceptibility(samples_for_phase)

    # Structured metric summary passed to the LLM evaluator.
    llm_metrics = {
        "temperature": float(final_temperature),
        "spectral_gain": float(gain_spectral),
        "energy_gap": float(energy_stats["energy_gap"]),

        "quality": {
            "diversity": float(quality["diversity"]),
            "entropy": float(quality["pixel_entropy"])
        },

        "sampling": {
            "mcmc_tau_int": float(diagnostics["tau_int"])
        },

        "trend": {
            "temp_slope": float(last_model.temp_hist[-1] - last_model.temp_hist[0]),
            "beta_slope": float(last_model.beta_eff_hist[-1] - last_model.beta_eff_hist[0])
        },

        "phase": {
            "susceptibility": float(chi)
        },

        "history": {
            "gap_trend": float(last_model.F_gap_hist[-1] - last_model.F_gap_hist[0]),
            "entropy_trend": float(last_model.persistent_div_hist[-1] - last_model.persistent_div_hist[0]),
            "temp_trend": float(last_model.temp_hist[-1] - last_model.temp_hist[0]),
            "beta_trend": float(last_model.beta_eff_hist[-1] - last_model.beta_eff_hist[0]),

            "gap_std": float(np.std(last_model.F_gap_hist)),
            "entropy_std": float(np.std(last_model.persistent_div_hist)),
            "temp_std": float(np.std(last_model.temp_hist)),

            "current": {
                "entropy": float(last_model.persistent_div_hist[-1]),
                "temperature": float(last_model.temp_hist[-1]),
                "beta": float(last_model.beta_eff_hist[-1]),
                "gap": float(last_model.F_gap_hist[-1])
            },

            "learning_signal": float(last_model.delta_w_hist[-1]),

            "stagnation": bool(last_model.delta_w_hist[-1] < 1e-4),
            "learning_active": bool(last_model.delta_w_hist[-1] > 1e-4),

            "cooling": bool(last_model.temp_hist[-1] < last_model.temp_hist[0]),
            "heating": bool(last_model.temp_hist[-1] > last_model.temp_hist[0]),

            "trend_signature": [
                float(np.sign(last_model.F_gap_hist[-1] - last_model.F_gap_hist[0])),
                float(np.sign(last_model.persistent_div_hist[-1] - last_model.persistent_div_hist[0])),
                float(np.sign(last_model.temp_hist[-1] - last_model.temp_hist[0]))
            ]
        }
    }

    # External clustering pass over both refined sample grids
    # (cluster.py reads the "_refined" image and writes a "_perfect" one).
    subprocess.run([
        "python3",
        "supplement/cluster.py",
        sample_refined,
        sample_refined.replace("_refined.png", "_perfect.png")
    ])

    subprocess.run([
        "python3",
        "supplement/cluster.py",
        prof_refined,
        prof_refined.replace("_refined.png", "_perfect.png")
    ])

    prof_perfect = prof_refined.replace("_refined.png", "_perfect.png")

    output_json = f"llm_output_gpu{Gpu_Id}_seed{seed_round}.json"

    # LLM evaluation over the metrics + the two image files.
    consequence = Evaluate(
        llm_metrics,
        prof_refined,
        prof_perfect,
        client=client
    )

    analysis_signal = ANASIS(consequence.get("analysis", ""))

    print(f"[ANALYSIS SIGNAL] {analysis_signal:.3f}")

    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(consequence, f, indent=2, ensure_ascii=False)

    print(f"[LLM OUTPUT SAVED] → {output_json}")

    print("\n[LLM RESULT]\n")

    json_text = json.dumps(consequence, indent=2, ensure_ascii=False)

    # Pretty-print the LLM result with soft wrapping and a light typing effect.
    for line in json_text.split("\n"):
        wrapped_lines = textwrap.wrap(line, width=90) or [""]

        for word_line in wrapped_lines:
            print(word_line)
            sys.stdout.flush()
            time.sleep(0.005)
|
| 1533 |
+
|
| 1534 |
+
|
| 1535 |
+
# Main Section

if __name__ == "__main__":
    # "spawn" is required for CUDA use inside child processes.
    mp.set_start_method("spawn", force=True)
    available_gpus = torch.cuda.device_count()

    # One worker process per seed, each pinned to its own GPU.
    seeds = [1]

    if available_gpus < len(seeds):
        raise RuntimeError("Not enough GPUs available.")

    # Manager list so workers can append results across process boundaries.
    manager = mp.Manager()
    results = manager.list()

    processes = []

    for gpu_id, seed in enumerate(seeds):
        p = mp.Process(
            target=worker,
            args=(gpu_id, seed, results)
        )

        p.start()

        processes.append(p)

    for p in processes:
        p.join()

    print("Hybrid Thermodynamic RBM Final Results")

    # Print one formatted summary block per worker result record.
    for result in list(results):
        print(f"Seed: {result['seed']} | GPU: {result['gpu']}")
        print(f"Final Temperature : {result['temperature']:.6f}")
        print(f"Weight Norm : {result['weight_norm']:.6f}")
        print(f"Spectral Gain : {result['spectral_gain']:.6f}")
        print(f"Train Log-Likelihood : {result['train_log_likelihood']:.6f}")
        print(f"Test Log-Likelihood : {result['test_log_likelihood']:.6f}")
        print(f"Train Pseudo-Likelihood : {result['train_pseudo_likelihood']:.6f}")
        print(f"Test Pseudo-Likelihood : {result['test_pseudo_likelihood']:.6f}")
        print(f"Reconstruction MSE : {result['reconstruction_mse']:.6f}")
        print(f"Reconstruction Accuracy : {result['reconstruction_accuracy']:.6f}")
        print(f"Mean Data Energy : {result['mean_data_energy']:.6f}")
        print(f"Mean Model Energy : {result['mean_model_energy']:.6f}")
        print(f"Energy Gap : {result['energy_gap']:.6f}")
        print(f"GPU Energy Used : {result['gpu_energy']:.2f} J")
        print(f"Log Partition (AIS) : {result['logZ']:.6f}")
        print(f"AIS ESS : {result['ais_ess']:.2f}")
        print(f"AIS Log-Weight Variance : {result['ais_log_weight_variance']:.6f}")
        print(f"MCMC tau_int : {result['mcmc_tau_int']:.2f}")
        print(f"MCMC tau_std : {result['mcmc_tau_std']:.2f}")
        print(f"MCMC tau_max : {result['mcmc_tau_max']:.2f}")
        print(f"MCMC tau_min : {result['mcmc_tau_min']:.2f}")
        print(f"MCMC ESS : {result['mcmc_ess']:.2f}")
        print(f"MCMC ESS / chain : {result['mcmc_ess_per_chain']:.2f}")
        print(f"MCMC R-hat : {result['mcmc_r_hat']:.4f}")
        print(f"MCMC ACF length : {result['mcmc_acf_len']}")
        print(f"Pixel Entropy : {result['pixel_entropy']:.6f}")
        print(f"Sample Diversity : {result['diversity']:.6f}")
        print(f"Mean Distribution L2 : {result['mean_l2']:.6f}")
        print(f"Professional Samples File: {result['prof_sample_file']}")
        print(f"Generated Samples File : {result['sample_file']}")

        print()
|
| 1599 |
+
|
| 1600 |
+
# MYRA: A Feedback-Controlled Thermodynamic RBM for Hybrid Intelligence
|
stan.dgts
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7e46c191e61e65b2c38514ccbd3d25519528ed02c7a30598e7b95171674bbe3f
|
| 3 |
+
size 188641741
|
supplement/cluster.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import cv2
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
# cluster.py runs as a standalone subprocess, so it selects its own device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a residual connection.

    The identity path is passed through a 1x1 projection (conv + BN) only
    when the channel count or stride would otherwise make the shapes
    incompatible; otherwise it is the plain identity.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()

        self.first_convolution = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.first_normalization = nn.BatchNorm2d(out_ch)

        self.second_convolution = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.second_normalization = nn.BatchNorm2d(out_ch)

        # Project the shortcut only when shapes would not line up.
        needs_projection = (in_ch != out_ch) or (stride != 1)
        if needs_projection:
            self.residual_projection = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.residual_projection = nn.Identity()

    def forward(self, features):
        shortcut = self.residual_projection(features)

        hidden = F.relu(self.first_normalization(self.first_convolution(features)))
        hidden = self.second_normalization(self.second_convolution(hidden))

        return F.relu(hidden + shortcut)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class MNISTEmbeddingModel(nn.Module):
    """ResNet-style MNIST encoder.

    Produces L2-normalized 256-d embeddings; classification logits are the
    cosine similarities between the embedding and the L2-normalized rows of
    the (bias-free) classification layer.
    """

    def __init__(self):
        super().__init__()

        self.first_residual_stage = ResidualBlock(1, 32, stride=1)
        self.second_residual_stage = ResidualBlock(32, 64, stride=2)
        self.third_residual_stage = ResidualBlock(64, 128, stride=2)

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.embedding_projection_layer = nn.Linear(128, 256)
        self.classification_layer = nn.Linear(256, 10, bias=False)

    def _embed(self, inputs):
        # Shared trunk: residual stages -> global pool -> normalized projection.
        feats = self.first_residual_stage(inputs)
        feats = self.second_residual_stage(feats)
        feats = self.third_residual_stage(feats)

        feats = torch.flatten(self.global_pool(feats), 1)

        return F.normalize(self.embedding_projection_layer(feats), dim=1)

    def forward(self, xcode):
        embedding = self._embed(xcode)

        # Cosine-similarity logits against normalized class directions.
        class_directions = F.normalize(self.classification_layer.weight, dim=1)
        return embedding @ class_directions.T

    def get_embedding(self, x):
        return self._embed(x)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
model = MNISTEmbeddingModel().to(device)

# Checkpoint may store the weights either under "model_state" or at top level.
ckpt = torch.load("zeta_mnist_hybrid.pt", map_location=device, weights_only=False)

if "model_state" in ckpt:
    raw_state = ckpt["model_state"]
else:
    raw_state = ckpt

# Translate legacy checkpoint parameter names to the descriptive attribute
# names used by MNISTEmbeddingModel / ResidualBlock above.
mapped = {}

for k, v in raw_state.items():
    nk = k

    nk = nk.replace("block1", "first_residual_stage")
    nk = nk.replace("block2", "second_residual_stage")
    nk = nk.replace("block3", "third_residual_stage")

    nk = nk.replace("conv1", "first_convolution")
    nk = nk.replace("conv2", "second_convolution")

    nk = nk.replace("bn1", "first_normalization")
    nk = nk.replace("bn2", "second_normalization")

    nk = nk.replace("shortcut", "residual_projection")

    nk = nk.replace("embedding", "embedding_projection_layer")

    # FIX: map the numbered fully-connected names BEFORE the bare "fc" alias.
    # Replacing "fc" first turned "fc1"/"fc2" into "classification_layer1"/
    # "classification_layer2", which made the two numbered rules unreachable
    # and would break strict loading of checkpoints that use fc1/fc2 keys.
    nk = nk.replace("fc1", "embedding_projection_layer")
    nk = nk.replace("fc2", "classification_layer")
    nk = nk.replace("fc", "classification_layer")

    mapped[nk] = v

# strict=True: any unmapped or leftover key is a hard error.
model.load_state_dict(mapped, strict=True)
model.eval()
|
| 124 |
+
|
| 125 |
+
# Load the digit reference set (LFS-tracked torch archive with images/labels).
data = torch.load("stan.dgts", map_location="cpu")
images = data["images"].float()
labels = data["labels"]

# Normalize to [0, 1] if the archive stores 0-255 pixel values.
if images.max() > 1:
    images /= 255.0

reference_bank = []

# Number of prototype images kept per digit class.
SAMPLES_PER_CLASS = 6

for d in range(10):
    cls = images[labels == d]

    # Keep the SAMPLES_PER_CLASS images closest (MSE) to the class mean image.
    center = cls.mean(dim=0, keepdim=True)
    dists = ((cls - center) ** 2).mean(dim=(1, 2))

    best = torch.argsort(dists)[:SAMPLES_PER_CLASS]
    reference_bank.append(cls[best].to(device))

# (stray status print left as-is)
print(", ↘\n")

reference_embeddings = []

# Precompute embeddings of the prototypes; one (SAMPLES_PER_CLASS, 256)
# tensor per digit class.
with torch.no_grad():
    for d in range(10):
        emb = model.get_embedding(reference_bank[d].unsqueeze(1))
        reference_embeddings.append(emb)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def normalize_digit(patchX):
    """MNIST-style normalization of a single grayscale patch.

    Crops the above-mean foreground to its bounding box, rescales it to
    20x20 with bilinear interpolation, and centers it on a 28x28 zero
    canvas. If no pixel exceeds the patch mean, the patch is returned
    unchanged. Input is array-like; output is a numpy array.
    """
    pixels = torch.tensor(patchX).float()

    foreground = pixels > pixels.mean()

    # Nothing brighter than the mean -> nothing to crop.
    if foreground.sum() == 0:
        return pixels.numpy()

    coords = foreground.nonzero()
    top, left = coords.min(dim=0).values
    bottom, right = coords.max(dim=0).values + 1

    glyph = pixels[top:bottom, left:right]

    glyph = F.interpolate(
        glyph.unsqueeze(0).unsqueeze(0),
        size=(20, 20),
        mode='bilinear',
        align_corners=False,
    ).squeeze()

    canvas = torch.zeros(28, 28)
    canvas[4:24, 4:24] = glyph

    return canvas.numpy()
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def process_image(input_paths, output_paths):
    """Scan a grayscale page for digit patches, classify each one, and write
    a cleaned reconstruction assembled from the closest reference exemplars.

    input_paths: path of the source image (any PIL-readable format).
    output_paths: path where the reconstructed grayscale image is saved.

    Relies on module-level state: `model`, `device`, `reference_bank`,
    `reference_embeddings`, and `normalize_digit`.
    """
    image = Image.open(input_paths).convert("L")
    image = np.array(image).astype(np.float32) / 255.0

    # NOTE(review): the sliding-window bound uses 29 while the crop below is
    # 28 pixels. Kept as-is so the set of extracted patches is unchanged —
    # confirm whether 28 was intended.
    patch_size = 29
    step = 30

    patches = []
    coordinates = []

    for y1 in range(0, image.shape[0] - patch_size + 1, step):
        for x1 in range(0, image.shape[1] - patch_size + 1, step):

            patch = image[y1:y1 + 28, x1:x1 + 28]

            # Skip near-empty windows (background).
            if patch.mean() < 0.005:
                continue

            patch = normalize_digit(patch)

            patches.append(patch)
            coordinates.append((y1, x1))

    if len(patches) == 0:
        print(f"{input_paths} → No valid patches")
        return

    # (P, 28, 28) -> (P, 1, 28, 28) on the compute device.
    patches = torch.from_numpy(np.stack(patches)).unsqueeze(1).to(device)
    print(f"{input_paths} → Patches:", patches.shape)

    # Single classifier pass over every extracted patch.
    with torch.no_grad():
        logits_variable = model(patches)

    probs = torch.softmax(logits_variable, dim=1)
    conf, predictions = torch.max(probs, dim=1)

    canvas = np.zeros(image.shape, dtype=np.float32)
    mse_values = []

    for integer, (y1, x1) in enumerate(coordinates):

        top2 = torch.topk(probs[integer], 2)

        # Low-confidence or ambiguous predictions fall back to embedding
        # similarity against the per-class reference banks.
        if conf[integer] < 0.8 or (top2.values[0] - top2.values[1]) < 0.2:
            patch_tensor = patches[integer].unsqueeze(0)

            with torch.no_grad():
                emb_patch = model.get_embedding(patch_tensor)

            sims = []
            for domino in range(10):
                emb_bank = reference_embeddings[domino]
                sims_matrix = torch.matmul(emb_patch, emb_bank.T)
                # logsumexp pooling over the bank (temperature 0.1): a soft
                # max over the class's exemplar similarities.
                sim = torch.logsumexp(sims_matrix / 0.1, dim=1).item()
                sims.append(sim)

            digit = int(np.argmax(sims))
        else:
            digit = predictions[integer].item()

        # Pick the bank exemplar closest (pixel MSE) to the observed patch.
        bank = reference_bank[digit]
        patch = patches[integer]  # (1,28,28)

        distance = ((patch - bank) ** 2).mean(dim=(1, 2))

        best_idx = torch.argmin(distance)

        best_img = bank[best_idx].cpu().numpy()

        # Binarize the exemplar, then drop connected components with area
        # <= 20 px to remove speckle noise.
        _, bw = cv2.threshold(best_img, 0.2, 1.0, cv2.THRESH_BINARY)

        num_labels, elastic, stats, _ = cv2.connectedComponentsWithStats((bw * 255).astype(np.uint8))

        clean = np.zeros_like(best_img)

        for into in range(1, num_labels):
            if stats[into, cv2.CC_STAT_AREA] > 20:
                clean[elastic == into] = 1.0

        best_img = clean

        # Fix: run both embedding calls under no_grad — this is pure
        # inference, so autograd bookkeeping is wasted work here.
        with torch.no_grad():
            emb_patch = model.get_embedding(patch.unsqueeze(0))
            emb_best = model.get_embedding(
                torch.from_numpy(best_img).to(device).unsqueeze(0).unsqueeze(0)
            )

        sim = torch.matmul(emb_patch, emb_best.T).item()

        # Map the similarity score from [-1, 1] to a blend weight in [0, 1]:
        # high similarity → trust the clean exemplar more.
        alpha = (sim + 1) / 2

        original = patch.squeeze(0).cpu().numpy()

        blended = alpha * best_img + (1 - alpha) * original

        canvas[y1:y1 + 28, x1:x1 + 28] = blended

        # Fix: `original` was recomputed identically here — reuse the value
        # computed above for the blend.
        mse = ((original - best_img) ** 2).mean()
        mse_values.append(mse)

    print(f"{input_paths} → Avg MSE:", sum(mse_values) / len(mse_values))

    output = (canvas * 255).astype(np.uint8)
    Image.fromarray(output).save(output_paths)

    print(f"Saved: {output_paths}\n")
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
if __name__ == "__main__":
    # CLI entry point: expects exactly an input image path and an output path.
    args = sys.argv[1:]
    if len(args) != 2:
        print("Usage: python3 cluster.py <input> <output>")
        sys.exit(1)

    input_path, output_path = args
    process_image(input_path, output_path)
|
yaml/perception.yaml
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
core:
|
| 2 |
+
|
| 3 |
+
image_semantics:
|
| 4 |
+
title: "IMAGE SEMANTICS (X vs Y)"
|
| 5 |
+
|
| 6 |
+
description: >
|
| 7 |
+
X represents the refined model output, reflecting what the system has learned.
|
| 8 |
+
Y represents the perfect reference, defining the target structural manifold.
|
| 9 |
+
|
| 10 |
+
rules:
|
| 11 |
+
- "X = refined output (learned structural expression)."
|
| 12 |
+
- "Y = perfect reference (ground-truth structural manifold)."
|
| 13 |
+
- "Y defines the expected diversity and structural space."
|
| 14 |
+
- "Evaluation depends on whether distinct structural configurations in Y are preserved as distinct representations in X."
|
| 15 |
+
|
| 16 |
+
structural_mode_principle:
|
| 17 |
+
priority: "hard_constraint"
|
| 18 |
+
title: "STRUCTURAL MODE PRINCIPLE"
|
| 19 |
+
|
| 20 |
+
description: >
|
| 21 |
+
Modes are defined by probabilistic structural equivalence, not superficial visual similarity.
|
| 22 |
+
|
| 23 |
+
rules:
|
| 24 |
+
- "Each sample consists of local structural elements (strokes, curves, segments, junctions)."
|
| 25 |
+
- "Two samples belong to the same mode if their elements align with high probability across the structure."
|
| 26 |
+
- "Consistent structural differences imply different modes."
|
| 27 |
+
- "Minor noise or insignificant variations do not define new modes."
|
| 28 |
+
- "Visual similarity is a weak signal and must not override structural or statistical evidence."
|
| 29 |
+
- "Mode identity depends on the distribution and organization of structural elements."
|
| 30 |
+
|
| 31 |
+
x_y_reasoning:
|
| 32 |
+
title: "STRUCTURAL COVERAGE REASONING"
|
| 33 |
+
|
| 34 |
+
description: >
|
| 35 |
+
Determine whether the learned representation (X) sufficiently captures the structural diversity defined by the reference manifold (Y).
|
| 36 |
+
|
| 37 |
+
rules:
|
| 38 |
+
- "If structures in X align with those in Y across diverse regions, modes are captured."
|
| 39 |
+
- "If multiple distinct structures in Y collapse into fewer representations in X, mode collapse or under-representation is likely."
|
| 40 |
+
- "If structural distinctions in Y are preserved in X, multiple modes are present."
|
| 41 |
+
- "If diversity in X is significantly lower than in Y, this indicates loss of modes."
|
| 42 |
+
|
| 43 |
+
refined_projection_principle:
|
| 44 |
+
title: "REFINED AS STRUCTURAL PROJECTION"
|
| 45 |
+
|
| 46 |
+
description: >
|
| 47 |
+
Refinement acts as a projection that reveals the true structural organization of samples.
|
| 48 |
+
|
| 49 |
+
rules:
|
| 50 |
+
- "Refinement exposes fine-grained structural variation."
|
| 51 |
+
- "Samples that appear similar before refinement may differ structurally after refinement."
|
| 52 |
+
- "If samples converge to similar structures after refinement, they likely belong to the same mode."
|
| 53 |
+
- "If they remain structurally distinct, they represent different modes."
|
zeta_mnist_hybrid.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:81f10f7049d6b7f0721eed7f0d16b0af15f741f6ea363a6baf30d3a62ae26d14
|
| 3 |
+
size 1370347
|