import numpy as np
from scipy.optimize import curve_fit
from scipy.stats import differential_entropy, entropy, kurtosis, norm, skew

# Default summary statistics: name -> callable applied to an array.
stats_config_default = {
    "sum": np.sum,
    "mean": np.mean,
    "std": np.std,
    "max": np.max,
    "min": np.min,
    "skew": skew,
    "kurtosis": kurtosis,
    "differential_entropy": differential_entropy,
}

# 1201 edges -> 1200 equal-width bins on [-1.5, 1.5].
weight_bins_default = np.linspace(-1.5, 1.5, 1201)
# 100 edges -> 99 equal-width bins on [0, 100].
sv_bins_default = np.linspace(0, 100, 100)
bins_dict_default = {
    "w_bins": weight_bins_default,
    "sv_bins": sv_bins_default,
}


def _adaptive_bins(strategy, data, value_range):
    """Return equally spaced histogram edges with a data-driven bin width.

    Args:
        strategy: ``"scott"`` (h = 3.49 * std * n^(-1/3)) or ``"fd"``
            (Freedman-Diaconis, h = 2 * IQR * n^(-1/3), falling back to
            Scott's width when the IQR is zero).
        data: Array-like sample; flattened before use.
        value_range: ``(lo, hi)`` span to cover, or a falsy value to use the
            data's min/max.

    Returns:
        ``np.arange(lo, hi + h, h)``.

    Raises:
        ValueError: Unknown strategy, missing or empty data, or a degenerate
            (zero / non-finite) bin width, e.g. for constant data.
    """
    if strategy not in ("scott", "fd"):
        raise ValueError(f"Unknown binning strategy: '{strategy}'. Choose 'fixed', 'scott', or 'fd'.")
    if data is None:
        raise ValueError(f"strategy='{strategy}' requires data to be provided")
    data = np.asarray(data).ravel()
    if data.size == 0:
        # 0 ** (-1/3) would otherwise raise an opaque ZeroDivisionError.
        raise ValueError(f"strategy='{strategy}' requires non-empty data")

    n_factor = data.size ** (-1 / 3)
    if strategy == "scott":
        h = 3.49 * np.std(data) * n_factor
    else:  # "fd"
        q75, q25 = np.percentile(data, [75, 25])
        iqr = q75 - q25
        h = 2.0 * iqr * n_factor if iqr > 0 else 3.49 * np.std(data) * n_factor

    if not np.isfinite(h) or h <= 0:
        # np.arange cannot build edges with a zero or NaN step.
        raise ValueError(
            f"strategy='{strategy}' produced a degenerate bin width ({h}); "
            "data may be constant or contain non-finite values"
        )
    lo, hi = value_range if value_range else (data.min(), data.max())
    return np.arange(lo, hi + h, h)


def make_weight_bins(strategy="fixed", data=None, n_bins=1200, value_range=(-1.5, 1.5)):
    """Return histogram edges for weight values.

    ``"fixed"`` yields ``n_bins + 1`` evenly spaced edges (``n_bins`` bins)
    over ``value_range``; ``"scott"``/``"fd"`` derive the width from ``data``.
    """
    if strategy == "fixed":
        return np.linspace(value_range[0], value_range[1], n_bins + 1)
    return _adaptive_bins(strategy, data, value_range)


def make_sv_bins(strategy="fixed", data=None, n_bins=100, value_range=(0, 100)):
    """Return histogram edges for singular values.

    Note: unlike ``make_weight_bins``, ``"fixed"`` returns ``n_bins`` edges
    (``n_bins - 1`` bins). This is kept so the default matches
    ``sv_bins_default``.
    """
    if strategy == "fixed":
        # np.linspace(0, 100, 100) matches sv_bins_default (100 edges = 99 bins)
        return np.linspace(value_range[0], value_range[1], n_bins)
    return _adaptive_bins(strategy, data, value_range)


def entropy_stat(h, centers):
    """Store a differential-entropy estimate of the histogram ``h["P_w"]``.

    The estimate is the Shannon entropy of the (normalised) bin masses plus
    ``log(bin width)``. The width is taken from the first two centres, so
    the bins are assumed to be equally spaced.
    """
    p = h["P_w"]
    bin_width = centers[1] - centers[0]
    h.update({"entropy": entropy(p) + np.log(bin_width)})


def kl_vs_standard_normal(h, centers):
    """Store KL(P_w || N(0, 1)), with N(0, 1) evaluated at the bin centres."""
    p = h["P_w"]
    q = norm.pdf(centers, 0, 1)
    h.update({"kl_vs_standard_normal": entropy(p, q)})


def kl_vs_empirical_normal(h, centers):
    """Store KL(P_w || N(h["mean"], h["std"])), evaluated at the bin centres."""
    mu, sigma = h["mean"], h["std"]
    p = h["P_w"]
    q = norm.pdf(centers, mu, sigma)
    h.update({"kl_vs_empirical_normal": entropy(p, q)})


def kl_normal_vs_standard(h, centers):
    """Store the closed-form KL(N(mu, sigma^2) || N(0, 1)).

    ``mu`` and ``sigma`` come from ``h["mean"]`` and ``h["std"]``.
    ``centers`` is unused; it is accepted for signature parity with the
    other metrics. The result is stored under its own key so it does not
    overwrite ``"kl_vs_empirical_normal"``.
    """
    mu, sigma = h["mean"], h["std"]
    h.update(
        {"kl_normal_vs_standard": 0.5 * (sigma**2 + mu**2 - 1 - np.log(sigma**2))}
    )


def fit_normal(h, centers, n_sigma=1.5):
    """Fit a Gaussian pdf to the histogram core and store ``fit_mu``/``fit_sigma``.

    Only bins whose centre lies within ``n_sigma * h["std"]`` of
    ``h["mean"]`` are used. Both results are NaN when the moments are NaN,
    ``std <= 0``, fewer than two bins are selected, or the fit does not
    converge.
    """
    p = h["P_w"]
    mu, sigma = h["mean"], h["std"]
    nan_result = {"fit_mu": np.nan, "fit_sigma": np.nan}
    if np.isnan(mu) or np.isnan(sigma) or sigma <= 0:
        h.update(nan_result)
        return

    mask = np.abs(centers - mu) <= n_sigma * sigma
    if mask.sum() < 2:
        h.update(nan_result)
        return

    def gaussian(x, mu, sigma):
        return norm.pdf(x, mu, sigma)

    try:
        popt, _ = curve_fit(gaussian, centers[mask], p[mask], p0=[mu, sigma])
    except RuntimeError:
        # curve_fit raises RuntimeError when the optimiser fails to converge.
        h.update(nan_result)
        return
    h.update({"fit_mu": popt[0], "fit_sigma": popt[1]})


def sv_mean(h, svd_array):
    """Compute mean of singular values."""
    h.update({"sv_mean": np.mean(svd_array)})


def sv_variance(h, svd_array):
    """Compute variance of singular values."""
    h.update({"sv_variance": np.var(svd_array)})


def sv_skewness(h, svd_array):
    """Compute skewness of singular values."""
    h.update({"sv_skewness": skew(svd_array)})


def sv_kurtosis_stat(h, svd_array):
    """Compute kurtosis of singular values."""
    h.update({"sv_kurtosis": kurtosis(svd_array)})


def sv_sum(h, svd_array):
    """Compute sum of singular values (Σσ)."""
    h.update({"sv_sum": np.sum(svd_array)})


def sv_sum_squares(h, svd_array):
    """Compute sum of squared singular values (Σσ²)."""
    h.update({"sv_sum_squares": np.sum(np.asarray(svd_array) ** 2)})


def _participation_ratio_value(svd_array):
    """Return PR = (Σσ)² / Σσ², or NaN when Σσ² is zero."""
    svd_array = np.asarray(svd_array)
    sum_sv = np.sum(svd_array)
    sum_sv2 = np.sum(svd_array**2)
    return (sum_sv**2) / sum_sv2 if sum_sv2 > 0 else np.nan


def participation_ratio(h, svd_array):
    """Compute participation ratio: PR = (Σσ)² / Σσ²."""
    h.update({"participation_ratio": _participation_ratio_value(svd_array)})


def normalized_participation_ratio(h, svd_array, d_head=None):
    """
    Compute normalized participation ratio: PR / d_head.

    ``d_head`` is resolved in order: the explicit argument, then
    ``h["d_head"]``, then the length of ``svd_array``. The result is NaN when
    PR is NaN or ``d_head <= 0``.
    """
    pr = _participation_ratio_value(svd_array)
    if d_head is None:
        d_head = h.get("d_head", len(svd_array))
    npr = pr / d_head if d_head > 0 and not np.isnan(pr) else np.nan
    h.update({"normalized_participation_ratio": npr})


def spectral_entropy(h, svd_array):
    """
    Compute spectral entropy: -Σ(p_i * log(p_i)) where p_i = σ_i² / Σσ².

    NaN when all singular values are zero.
    """
    sv2 = np.asarray(svd_array) ** 2
    sum_sv2 = np.sum(sv2)
    if sum_sv2 > 0:
        p = sv2 / sum_sv2
        p = p[p > 0]  # Remove zeros to avoid log(0)
        se = -np.sum(p * np.log(p))
    else:
        se = np.nan
    h.update({"spectral_entropy": se})


def condition_number(h, svd_array):
    """
    Compute condition number: σ_max / σ_min over the non-zero singular values.

    Values <= 1e-10 are treated as numerical zeros. Uses max/min, so the
    input does not need to be sorted. NaN when no value survives the filter.
    """
    svd_array = np.asarray(svd_array)
    sv_nonzero = svd_array[svd_array > 1e-10]
    cn = sv_nonzero.max() / sv_nonzero.min() if sv_nonzero.size > 0 else np.nan
    h.update({"condition_number": cn})


def stable_rank(h, svd_array):
    """Compute stable rank: Σσ² / σ_max² (order-independent; NaN if σ_max is 0)."""
    svd_array = np.asarray(svd_array)
    sum_sv2 = np.sum(svd_array**2)
    sv_max = svd_array.max() if svd_array.size > 0 else 0
    sr = sum_sv2 / (sv_max**2) if sv_max > 0 else np.nan
    h.update({"stable_rank": sr})


# Registries: metric name -> function(h, centers) that writes into h.
normality_metrics = {
    "entropy": entropy_stat,
    "fit_normal": fit_normal,
    "kl_vs_empirical_normal": kl_vs_empirical_normal,
}

# Registries: metric name -> function(h, svd_array) that writes into h.
singular_value_metrics = {
    "sv_mean": sv_mean,
    "sv_variance": sv_variance,
    "sv_skewness": sv_skewness,
    "sv_kurtosis": sv_kurtosis_stat,
    "sv_sum": sv_sum,
    "sv_sum_squares": sv_sum_squares,
    "participation_ratio": participation_ratio,
    "normalized_participation_ratio": normalized_participation_ratio,
    "spectral_entropy": spectral_entropy,
    "condition_number": condition_number,
    "stable_rank": stable_rank,
}

# Log-spaced early revisions followed by step1000 .. step143000 in steps of 1000.
PYTHIA_REVISIONS = [
    "step0",
    "step1",
    "step2",
    "step4",
    "step8",
    "step16",
    "step32",
    "step64",
    "step128",
    "step256",
    "step512",
] + [f"step{step}" for step in range(1000, 144000, 1000)]

PYTHIA_MODELS = [
    "pythia-70m-deduped",
    "pythia-160m-deduped",
    "pythia-410m-deduped",
    "pythia-1b-deduped",
    "pythia-1.4b-deduped",
    "pythia-2.8b-deduped",
    "pythia-6.9b-deduped",
    "pythia-12b-deduped",
]


def get_model_versions(model_name):
    """Return revision names for ``model_name`` (empty list if unknown).

    A fresh list is returned, so callers can mutate it without affecting
    ``PYTHIA_REVISIONS``.
    """
    if model_name in PYTHIA_MODELS:
        return list(PYTHIA_REVISIONS)
    return []