# CHG / chg_package / examples / basic_example.py
# guanwencan's picture
# Upload 9 files
# 69bf174 verified
"""
Basic example demonstrating CHG algorithm usage
"""
import numpy as np
import matplotlib.pyplot as plt
from chg_algorithm import CHG, CHGOptimizer
def basic_regression_example():
    """Demonstrate basic 1-D regression with CHG.

    Fits a CHG model to 50 noisy samples of sin(x), prints the mean squared
    error of the predictive mean against the noise-free function on a test
    grid together with the log marginal likelihood on the training data, and
    plots the prediction with a +/- 2 standard deviation band.
    """
    print("=== Basic CHG Regression Example ===")

    # Generate synthetic 1D data for visualization
    np.random.seed(42)
    X_train = np.random.uniform(-3, 3, (50, 1))
    y_train = np.sin(X_train.flatten()) + 0.1 * np.random.randn(50)

    # The test grid [-4, 4] is wider than the training range [-3, 3], so the
    # plot also shows how the model behaves when extrapolating.
    X_test = np.linspace(-4, 4, 100).reshape(-1, 1)
    y_true = np.sin(X_test.flatten())

    # Initialize and fit CHG model
    model = CHG(input_dim=1, hidden_dim=16, num_heads=2)
    pred_mean, pred_var = model.fit_predict(X_train, y_train, X_test)

    # Clamp at zero before the square root: a predictive variance that comes
    # out slightly negative through floating-point round-off would otherwise
    # become NaN and silently punch holes in the confidence band.
    pred_std = np.sqrt(np.maximum(pred_var, 0.0))

    # Print metrics
    mse = np.mean((pred_mean - y_true) ** 2)
    print(f"Mean Squared Error: {mse:.4f}")
    print(f"Log Marginal Likelihood: {model.log_marginal_likelihood(X_train, y_train):.4f}")

    # Visualization
    plt.figure(figsize=(10, 6))
    plt.scatter(X_train.flatten(), y_train, alpha=0.6, label='Training Data')
    plt.plot(X_test.flatten(), y_true, 'r-', label='True Function')
    plt.plot(X_test.flatten(), pred_mean, 'b-', label='CHG Prediction')
    # +/- 2 standard deviations, i.e. roughly a 95% interval if the
    # predictive distribution is Gaussian.
    plt.fill_between(
        X_test.flatten(),
        pred_mean - 2 * pred_std,
        pred_mean + 2 * pred_std,
        alpha=0.2,
        label='95% Confidence',
    )
    plt.xlabel('Input')
    plt.ylabel('Output')
    plt.title('CHG Gaussian Process Regression')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
def optimization_example():
    """Show CHG parameter optimization on a 2-D quadratic target.

    Runs 20 optimizer steps on 80 noisy samples of the squared norm of the
    input, records the log marginal likelihood after every step, plots the
    recorded curve and prints its final value.
    """
    print("\n=== CHG Optimization Example ===")

    # Synthetic training set: target is the squared norm of each 2-D point
    # plus Gaussian noise. The two random draws stay in this order so the
    # seeded data is reproducible.
    np.random.seed(123)
    inputs = np.random.randn(80, 2)
    squared_norm = np.sum(inputs**2, axis=1)
    targets = squared_norm + 0.5 * np.random.randn(80)

    # Model plus the optimizer that updates it
    chg_model = CHG(input_dim=2, hidden_dim=20, num_heads=3)
    chg_optimizer = CHGOptimizer(chg_model, learning_rate=0.01)

    print("Optimizing CHG parameters...")
    lml_trace = []  # log marginal likelihood after each optimizer step
    for step_index in range(20):
        chg_optimizer.step(inputs, targets)
        current_lml = chg_model.log_marginal_likelihood(inputs, targets)
        lml_trace.append(current_lml)
        # Report every fifth epoch (0, 5, 10, 15)
        if not step_index % 5:
            print(f"Epoch {step_index:2d}: Log Marginal Likelihood = {current_lml:.4f}")

    # Plot optimization progress
    plt.figure(figsize=(8, 5))
    plt.plot(lml_trace, 'b-', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Log Marginal Likelihood')
    plt.title('CHG Optimization Progress')
    plt.grid(True, alpha=0.3)
    plt.show()

    final_lml = lml_trace[-1]
    print(f"Final Log Marginal Likelihood: {final_lml:.4f}")
def uncertainty_quantification_example():
    """Demonstrate uncertainty quantification capabilities.

    Fits a CHG model to 60 noisy samples of 0.5*x**3 - x, five of which are
    shifted to act as outliers, prints summary statistics of the predictive
    standard deviation, and draws two panels: the prediction with a +/- 2
    standard deviation band, and the standard deviation itself with the
    highest-uncertainty regions shaded.
    """
    print("\n=== Uncertainty Quantification Example ===")

    # Generate noisy data with outliers
    np.random.seed(456)
    X_train = np.random.uniform(-2, 2, (60, 1))
    y_clean = 0.5 * X_train.flatten()**3 - X_train.flatten()

    # Add noise and some outliers: five distinct points are shifted by +/-4
    noise = 0.2 * np.random.randn(60)
    outlier_idx = np.random.choice(60, 5, replace=False)
    noise[outlier_idx] += np.random.choice([-2, 2], 5) * 2  # Add outliers
    y_train = y_clean + noise

    # The test grid [-3, 3] is wider than the training range [-2, 2]
    X_test = np.linspace(-3, 3, 80).reshape(-1, 1)
    x_grid = X_test.flatten()

    # Fit CHG model
    model = CHG(input_dim=1, hidden_dim=12, num_heads=2)
    pred_mean, pred_var = model.fit_predict(X_train, y_train, X_test)

    # Clamp at zero before the square root: a variance that is slightly
    # negative through round-off would become NaN, which would in turn make
    # the percentile threshold below NaN and empty the high-uncertainty mask.
    pred_std = np.sqrt(np.maximum(pred_var, 0.0))

    # Analyze uncertainties. The mask marks test points above the 75th
    # percentile of pred_std, so by construction it selects about a quarter
    # of them.
    high_uncertainty_idx = pred_std > np.percentile(pred_std, 75)
    print(f"Percentage of high-uncertainty predictions: {np.mean(high_uncertainty_idx)*100:.1f}%")
    print(f"Average prediction uncertainty: {np.mean(pred_std):.4f}")
    print(f"Maximum prediction uncertainty: {np.max(pred_std):.4f}")

    # Visualization
    plt.figure(figsize=(12, 5))

    # Left panel: predictive mean with a +/- 2 standard deviation band
    plt.subplot(1, 2, 1)
    plt.scatter(X_train.flatten(), y_train, alpha=0.7, c='red', label='Training Data (with outliers)')
    plt.plot(x_grid, pred_mean, 'b-', linewidth=2, label='CHG Prediction')
    plt.fill_between(
        x_grid,
        pred_mean - 2 * pred_std,
        pred_mean + 2 * pred_std,
        alpha=0.3,
        label='95% Confidence',
    )
    plt.xlabel('Input')
    plt.ylabel('Output')
    plt.title('CHG Predictions with Uncertainty')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Right panel: predictive standard deviation
    plt.subplot(1, 2, 2)
    plt.plot(x_grid, pred_std, 'g-', linewidth=2)
    # Pass the full arrays together with `where=` instead of boolean-indexed
    # arrays. Indexing collapses separate high-uncertainty stretches into one
    # polygon, so the fill would be drawn straight across the low-uncertainty
    # gap between them.
    plt.fill_between(
        x_grid,
        0,
        pred_std,
        where=high_uncertainty_idx,
        alpha=0.4,
        color='orange',
        label='High Uncertainty Regions',
    )
    plt.xlabel('Input')
    plt.ylabel('Prediction Uncertainty (σ)')
    plt.title('Uncertainty Estimation')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    # Run every example in sequence; each one prints its own banner and
    # shows its own figure(s).
    for run_example in (
        basic_regression_example,
        optimization_example,
        uncertainty_quantification_example,
    ):
        run_example()

    print("\n=== All Examples Completed ===")