"""Plot loss curves for UR-LLaDA, the LLaDA baseline, and the URM reproduction.

Running this module as a script writes ``loss_curves_100k.png`` to the
current working directory and prints where the plot was saved.
"""

import matplotlib.pyplot as plt

# Checkpoint positions in units of 1k steps (10 -> 10k, ..., 100 -> 100k),
# which is what makes the "over 100k Steps" title consistent.
steps = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

# Loss value recorded at each entry of ``steps``; every series must have
# exactly one value per checkpoint.
ur_llada = [1.0500, 1.0182, 1.0198, 1.0066, 1.0264, 1.0085, 1.0255, 1.0304, 1.0075, 1.0402]
llada = [1.0669, 1.0600, 1.0537, 1.0298, 1.0715, 1.0673, 1.0642, 1.0550, 1.0575, 1.0695]
urm = [1.0689, 1.0587, 1.0408, 1.0708, 1.0335, 1.0035, 1.0338, 1.0131, 1.0178, 1.0366]


def plot_loss_curves(output_path="loss_curves_100k.png"):
    """Plot the three loss series against ``steps`` and save the figure.

    Args:
        output_path: Destination file for the rendered figure. The format is
            inferred by matplotlib from the file extension.

    Raises:
        ValueError: If any loss series does not have one value per entry
            in ``steps``.
    """
    # Fail with a clear message instead of a cryptic matplotlib shape error.
    for name, values in (("ur_llada", ur_llada), ("llada", llada), ("urm", urm)):
        if len(values) != len(steps):
            raise ValueError(
                f"{name} has {len(values)} points but steps has {len(steps)}"
            )

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(steps, ur_llada, marker='o', label='UR-LLaDA (Scheduler)', color='blue')
        plt.plot(steps, llada, marker='s', label='LLaDA Baseline', color='red', linestyle='--')
        plt.plot(steps, urm, marker='^', label='URM (Reproduction)', color='green', linestyle='-.')

        # NOTE(review): the title says "Training Loss" but the y-axis says
        # "Validation Loss" — confirm which split these numbers come from and
        # make the two labels agree.
        plt.title('Training Loss over 100k Steps')
        # Step values run 10..100 for a 100k-step run, so one unit is 1k steps.
        plt.xlabel('Steps (x1k)')
        plt.ylabel('Validation Loss')
        plt.grid(True, linestyle=':', alpha=0.6)
        plt.legend()
        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"Plot saved to {output_path}")


if __name__ == "__main__":
    plot_loss_curves()