"""Plot loss curves for UR-LLaDA, the LLaDA baseline, and the URM reproduction.

Draws the three hard-coded loss series below on a single figure and saves it
as ``loss_curves_100k.png`` in the current working directory.
"""

import matplotlib.pyplot as plt

# X positions of the recorded loss values. They are read as thousands of
# steps (10 -> 10k, ..., 100 -> 100k) so that the last point agrees with the
# "100k" in the plot title and in the output file name.
steps = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

# One loss value per entry of `steps` for each of the three compared runs.
ur_llada = [1.0500, 1.0182, 1.0198, 1.0066, 1.0264, 1.0085, 1.0255, 1.0304, 1.0075, 1.0402]
llada = [1.0669, 1.0600, 1.0537, 1.0298, 1.0715, 1.0673, 1.0642, 1.0550, 1.0575, 1.0695]
urm = [1.0689, 1.0587, 1.0408, 1.0708, 1.0335, 1.0035, 1.0338, 1.0131, 1.0178, 1.0366]

plt.figure(figsize=(10, 6))

# Each run gets its own marker, colour and line style so the curves stay
# distinguishable even in a greyscale print.
plt.plot(steps, ur_llada, marker='o', label='UR-LLaDA (Scheduler)', color='blue')
plt.plot(steps, llada, marker='s', label='LLaDA Baseline', color='red', linestyle='--')
plt.plot(steps, urm, marker='^', label='URM (Reproduction)', color='green', linestyle='-.')

# NOTE(review): the title says "Training Loss" while the y-axis says
# "Validation Loss" -- confirm which one the data actually is and align them.
plt.title('Training Loss over 100k Steps')
# The axis unit is x1k, not x10k: with values up to 100, "x10k" would mean
# 1M steps and contradict the 100k stated in the title and file name.
plt.xlabel('Steps (x1k)')
plt.ylabel('Validation Loss')
plt.grid(True, linestyle=':', alpha=0.6)
plt.legend()
# Fit labels and legend inside the canvas before writing the image.
plt.tight_layout()

plt.savefig('loss_curves_100k.png')
print("Plot saved to loss_curves_100k.png")