# Coconut-MNIST / experiments/appendix_learning_curves.py
# ymlin105
# feat: complete Hybrid SVD-CNN system with interactive app
# commit b25b9cb
# Appendix A – Learning Curves
import matplotlib.pyplot as plt
import pickle
import os
import numpy as np
from src import config
# --- Configuration ---
# Line colors for the plotted series: plot_curves draws the training curves
# in BLUE_DEEP and the validation curves in ORANGE (both hex RGB strings).
BLUE_DEEP = "#5E81AC"
ORANGE = "#D08770"
def plot_curves(history, title, filename):
    """Plot loss and accuracy curves side by side and save them as an image.

    Args:
        history: Mapping providing the keys ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'`` and ``'val_acc'``. Each value is plotted against
            epochs ``1..len(history['train_loss'])``.
        title: Figure-level title drawn above both panels.
        filename: Output file name; it is joined onto ``config.RESULTS_DIR``.

    Raises:
        KeyError: If ``history`` lacks one of the four expected keys.
    """
    epochs = range(1, len(history['train_loss']) + 1)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    try:
        # Loss
        ax1.plot(epochs, history['train_loss'], label='Train',
                 color=BLUE_DEEP, marker='o', markersize=4)
        ax1.plot(epochs, history['val_loss'], label='Val',
                 color=ORANGE, marker='s', markersize=4)
        ax1.set_title('Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Acc
        ax2.plot(epochs, history['train_acc'], label='Train',
                 color=BLUE_DEEP, marker='o', markersize=4)
        ax2.plot(epochs, history['val_acc'], label='Val',
                 color=ORANGE, marker='s', markersize=4)
        ax2.set_title('Accuracy')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Operate on the figure handle rather than pyplot's implicit
        # "current figure" so unrelated open figures cannot be affected.
        fig.suptitle(title, fontweight='bold')
        fig.tight_layout()

        out_path = os.path.join(config.RESULTS_DIR, filename)
        out_dir = os.path.dirname(out_path)
        if out_dir:
            # A fresh checkout may not have the results directory yet;
            # without this, savefig fails with FileNotFoundError.
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, dpi=300)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    print(f"Saved {filename}")
def main():
    """Render a learning-curve figure for each saved training history.

    For every (pickle file, plot title, output image) entry, the history is
    read from ``config.MODELS_DIR`` and handed to ``plot_curves``. Entries
    whose pickle file does not exist are reported and skipped.
    """
    jobs = (
        ('cnn_10class_history.pkl', 'MNIST 10-class',
         'fig_14_learning_curves.png'),
        ('cnn_fashion_history.pkl', 'Fashion-MNIST',
         'fig_15_learning_curves_fashion.png'),
    )

    for f_name, caption, image_name in jobs:
        history_path = os.path.join(config.MODELS_DIR, f_name)

        # Guard clause: nothing to plot when the history was never saved.
        if not os.path.exists(history_path):
            print(f"Skipping {f_name}: Not found.")
            continue

        # NOTE: unpickling can execute arbitrary code; only load history
        # files produced by this project's own training runs.
        with open(history_path, 'rb') as handle:
            curves = pickle.load(handle)
        plot_curves(curves, caption, image_name)


if __name__ == "__main__":
    main()