File size: 2,253 Bytes
b47957e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import os
import re
import sys

import matplotlib.pyplot as plt
# Parse a training log and plot training / validation loss curves to
# loss_curve.png.
#
# Recognised line formats:
#   training:   "iter 100: loss 1.8030, time ..."
#   validation: "step 500: train loss 1.2915, val loss 1.5601"


def _number_after(parts, keyword, offset, cast):
    """Return ``cast`` applied to the token ``offset`` places after ``keyword``.

    Trailing separators such as ``:`` (after a step number) or ``,`` (after a
    loss value) are removed before conversion.

    Raises:
        ValueError: ``keyword`` is not a token of ``parts`` or the token is
            not a valid number for ``cast``.
        IndexError: the line ends before the requested token (e.g. a
            truncated final line in a log that is still being written).
    """
    token = parts[parts.index(keyword) + offset]
    return cast(token.replace(':', '').replace(',', ''))


def _parse_train_line(line):
    """Return ``(step, loss)`` from a training-iteration line, else ``None``."""
    if "iter" not in line or "loss" not in line or "time" not in line:
        return None
    parts = line.split()
    try:
        return (_number_after(parts, 'iter', 1, int),
                _number_after(parts, 'loss', 1, float))
    except (ValueError, IndexError):
        # Malformed or truncated line: skip it rather than abort the plot.
        return None


def _parse_val_line(line):
    """Return ``(step, val_loss)`` from an evaluation line, else ``None``."""
    if "step" not in line or "val loss" not in line:
        return None
    parts = line.split()
    try:
        # The value sits two tokens after "val" ("val loss X").
        return (_number_after(parts, 'step', 1, int),
                _number_after(parts, 'val', 2, float))
    except (ValueError, IndexError):
        return None


train_loss = []
val_loss = []
steps = []
val_steps = []

log_file_path = 'training.log'

# The log must be provided by the user; nothing is generated here.
if not os.path.isfile(log_file_path):
    print(f"File {log_file_path} not found. "
          "Please paste your training logs into this file first.",
          file=sys.stderr)
    sys.exit(1)

# errors='replace' keeps a stray non-UTF-8 byte from aborting the parse.
with open(log_file_path, 'r', encoding='utf-8', errors='replace') as f:
    for line in f:
        # The two checks are independent: a line that fails one pattern is
        # still offered to the other.
        train_point = _parse_train_line(line)
        if train_point is not None:
            step, loss = train_point
            steps.append(step)
            train_loss.append(loss)

        val_point = _parse_val_line(line)
        if val_point is not None:
            step, v_loss = val_point
            val_steps.append(step)
            val_loss.append(v_loss)

if not steps:
    print("No data parsed. Check log format.", file=sys.stderr)
    sys.exit(1)

plt.figure(figsize=(10, 6))
plt.plot(steps, train_loss, label='Train Loss', alpha=0.6)
if val_steps:
    plt.plot(val_steps, val_loss, label='Validation Loss', linewidth=3, color='red')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('RippleGPT Training Dynamics: Identifying Overfitting')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('loss_curve.png')
# Release the figure; the plot is only written to disk (no plt.show()) so the
# script also works in headless environments.
plt.close()
print("Plot saved to loss_curve.png")
|