UncheatableEval / test_fitting.py
Jellyfish042's picture
feat: replace linear scaling law with power law with offset
6725455
"""测试带偏置的幂律拟合功能"""
import numpy as np
from scipy.optimize import curve_fit
def power_law_with_offset(x, a, b, c):
    """Evaluate the offset power law ``y = a * x**b + c``.

    Args:
        x: Base value(s); a scalar or anything ``np.power`` accepts.
        a: Multiplicative coefficient.
        b: Exponent applied to ``x``.
        c: Additive offset.

    Returns:
        ``a * np.power(x, b) + c``, computed element-wise by NumPy.
    """
    scaled_power = a * np.power(x, b)
    return scaled_power + c
def _initial_power_law_guess(x_arr, y_arr):
    """Estimate ``(a, b)`` of ``y = a * x**b`` from a log-log linear fit.

    Only points with finite, strictly positive ``x`` and ``y`` take part,
    because ``log10`` is undefined elsewhere and a single NaN/-inf would
    poison the whole least-squares solution. When fewer than two distinct
    usable ``x`` values remain, the neutral guess ``(1.0, 1.0)`` is returned.
    """
    usable = (
        np.isfinite(x_arr) & np.isfinite(y_arr) & (x_arr > 0) & (y_arr > 0)
    )
    if np.unique(x_arr[usable]).size < 2:
        return 1.0, 1.0
    slope, intercept = np.polyfit(
        np.log10(x_arr[usable]), np.log10(y_arr[usable]), 1
    )
    return 10 ** intercept, slope


def fit_power_law_with_offset(x_values, y_values):
    """Fit ``y = a * x**b + c`` to the raw data.

    A plain power law fitted in log-log space provides the starting point
    ``(a, b, 0)`` for ``curve_fit``. If the non-linear fit fails, a message
    is printed and that starting point is returned as the result instead.

    Args:
        x_values: 1-D sequence of x samples.
        y_values: 1-D sequence of y samples, same length as ``x_values``.

    Returns:
        Tuple ``(params, raw_rmse, log_rmse, fit_x, fit_y)``:
        ``params`` is ``(a, b, c)`` (an ndarray from ``curve_fit``, or a
        tuple with ``c == 0`` on fallback); ``raw_rmse`` is the RMSE in the
        original space; ``log_rmse`` is the RMSE between ``log10`` of the
        observed and predicted values (NaN or inf when any of them is
        non-positive); ``fit_x`` holds 100 evenly spaced points from
        ``0.8 * min(x)`` to ``1.2 * max(x)`` and ``fit_y`` the model
        evaluated there.

    Raises:
        ValueError: If the inputs are empty or differ in shape.
    """
    x_arr = np.asarray(x_values, dtype=float)
    y_arr = np.asarray(y_values, dtype=float)
    if x_arr.size == 0 or x_arr.shape != y_arr.shape:
        raise ValueError(
            "x_values and y_values must be non-empty and of equal length"
        )

    a_init, b_init = _initial_power_law_guess(x_arr, y_arr)

    # Keep the try body to the one call that is expected to fail, so bugs in
    # the metric code below are not silently reported as a fitting failure.
    try:
        params, _ = curve_fit(
            power_law_with_offset,
            x_arr,
            y_arr,
            p0=[a_init, b_init, 0],
            maxfev=10000,
        )
    except (RuntimeError, ValueError, TypeError) as e:
        # RuntimeError: no convergence; ValueError: non-finite data;
        # TypeError: fewer data points than parameters.
        print(f"Fitting failed: {e}")
        # Fall back to the plain power-law estimate (offset fixed at 0).
        params = (a_init, b_init, 0)

    a, b, c = params
    y_pred = power_law_with_offset(x_arr, a, b, c)

    # RMSE in the original space.
    raw_rmse = np.sqrt(np.mean((y_arr - y_pred) ** 2))

    # RMSE in log10 space. Non-positive values make this NaN/inf by design;
    # suppress the accompanying NumPy runtime warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        log_residuals = np.log10(y_arr) - np.log10(y_pred)
        log_rmse = np.sqrt(np.mean(log_residuals ** 2))

    # Sample the fitted curve slightly beyond the observed x range.
    fit_x = np.linspace(x_arr.min() * 0.8, x_arr.max() * 1.2, 100)
    fit_y = power_law_with_offset(fit_x, a, b, c)

    return params, raw_rmse, log_rmse, fit_x, fit_y
if __name__ == "__main__":
    # Demo data: a synthetic "model size vs. compression rate" relation,
    # generated from the ground truth y = 50 * x^(-0.1) + 10.
    x_test = np.array([1, 3, 7, 13, 20, 30])
    noiseless_y = 50 * (x_test ** -0.1) + 10

    # Perturb the samples with reproducible Gaussian noise (sigma = 0.5).
    np.random.seed(42)
    noise = np.random.normal(0, 0.5, len(x_test))
    y_test = noiseless_y + noise

    print("测试数据:", f"x: {x_test}", f"y: {y_test}", "", sep="\n")

    # Run the fit; the sampled curve points are not needed for the report.
    params, raw_rmse, log_rmse, *_ = fit_power_law_with_offset(
        x_test.tolist(), y_test.tolist()
    )
    a, b, c = params

    # Consider the fit successful when the raw-space RMSE is below 2.0.
    if raw_rmse < 2.0:
        verdict = "拟合成功!"
    else:
        verdict = "拟合可能需要调整"

    report_lines = [
        "拟合结果:",
        f"a = {a:.4f}",
        f"b = {b:.4f}",
        f"c = {c:.4f}",
        f"Raw RMSE = {raw_rmse:.4f}",
        f"Log-RMSE = {log_rmse:.4f}",
        "",
        f"拟合公式: y = {a:.2f} * x^{b:.3f} + {c:.2f}",
        "",
        "真实参数: a=50, b=-0.1, c=10",
        verdict,
    ]
    for line in report_lines:
        print(line)