sparse-cafm / src /util /moment_based.py
leharris3's picture
Minimal HF Space deployment with gradio 5.x fix
0917e8d
import torch
import numpy as np
from torch import Tensor
from scipy.stats import moment
def calc_moment_based_stats(samples: np.ndarray) -> dict:
"""
Calculate a panel of moment-based stats for SPM samples with shape [B, N, M].
- Ref: https://gwyddion.net/documentation/user-guide-en/statistical-analysis.html#stat-quantities
"Moment based quantities are expressed using integrals of the
height distribution function with some powers of height.
They include the familiar quantities"
Returns
---
1. Average value
2. RMS roughnes (sq)
3. RMS Mean roughness (Sa)
4. Skew (Ssk)
5. Excess kurtosis
"""
assert len(samples.shape) == 3, f"Error: expected sample to have shape [B, N, M]"
stats = {
"mean" : 0,
"mean_rms" : 0,
"mean_roughness": 0,
"skewness" : 0,
"kurtosis" : 0,
}
for sample in samples:
if isinstance(sample, Tensor):
sample = sample.detach().cpu().numpy()
sample = sample.flatten()
# calculate central moments
mu_2 = moment(sample, moment=2)
mu_3 = moment(sample, moment=3)
mu_4 = moment(sample, moment=4)
average_value = sample.mean()
# σ = μ_2 ^ 1/2
rms = mu_2 ** .5
mean_rms = rms.mean()
mean_roughness = abs(sample - sample.mean()).mean()
# γ_1 = (μ_3) / (μ_2 ^ (3/2))
skewness = mu_3 / (mu_2 ** (3/2))
# γ_2 = (μ_4) / (μ_2 ^ 2) - 3
kurtosis = (mu_4 / (mu_2 ** 2)) - 3
sample_stats = {
"mean" : average_value,
"mean_rms" : mean_rms,
"mean_roughness": mean_roughness,
"skewness" : skewness,
"kurtosis" : kurtosis,
}
for k in stats:
stats[k] += sample_stats[k]
# calculate average stats
for k in stats:
stats[k] = (stats[k] / samples.shape[0])
return stats
def cal_moment_based_errs(pred: np.ndarray, target: np.ndarray) -> dict:
pred_stats = calc_moment_based_stats(pred)
target_stats = calc_moment_based_stats(target)
errs = {}
for k in pred_stats:
ps = pred_stats[k]
ts = target_stats[k]
err = abs(ps - ts)
errs[k] = err
return errs
if __name__ == "__main__":
fp = "/playpen/mufan/levi/tianlong-chen-lab/sparse-cafm/data/raw-data/3-12-25/A1 512_ Height_Backward_020.npy"
s1 = Tensor(np.load(fp)).unsqueeze(0).numpy()
s2 = s1 + s1.mean()
from pprint import pprint
pprint(calc_moment_based_stats(s1), indent=4)