"""Parse a training log and plot its loss curves.

Scans a log file for lines such as::

    Epoch 70, Step 0, Total: 1.2872, Flow: 1.0775, Proj: 0.2097, Teacher: 0.0000

prints summary statistics (mean, std, min, max) of the Total/Flow/Proj losses
and saves a line plot of them to a PNG file.
"""
import argparse
import re

import numpy as np
import pandas as pd

# One captured float, in the forms Python/PyTorch print them: plain decimals
# (1.2872, .5), signed values, scientific notation (3.5e-05) and the
# non-finite tokens nan/inf. The previous `[\d\.]+` silently skipped any line
# containing one of the latter forms.
_FLOAT = r'([-+]?(?:(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|(?i:nan|inf(?:inity)?)))'

# Groups: epoch, step, total, flow, proj (the Teacher value is ignored).
_LINE_PATTERN = re.compile(
    rf'Epoch (\d+), Step (\d+), Total: {_FLOAT}, Flow: {_FLOAT}, Proj: {_FLOAT}'
)

_COLUMNS = ['epoch', 'step', 'total', 'flow', 'proj']
_LOSS_COLUMNS = ['total', 'flow', 'proj']
_LOSS_LABELS = {'total': 'Total Loss', 'flow': 'Flow Loss', 'proj': 'Proj Loss'}


def _parse_log(log_path):
    """Return a DataFrame with one row per loss line found in *log_path*.

    Columns are ``epoch`` and ``step`` (int) followed by ``total``, ``flow``
    and ``proj`` (float), in file order. Lines that do not match are skipped.
    If nothing matches, the DataFrame is empty but still has those columns.
    """
    rows = []
    # errors='replace': a stray non-UTF-8 byte (e.g. progress-bar output in a
    # cluster log) must not abort parsing of the whole file.
    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            match = _LINE_PATTERN.search(line)
            if match is None:
                continue
            epoch, step, total, flow, proj = match.groups()
            rows.append({
                'epoch': int(epoch),
                'step': int(step),
                'total': float(total),
                'flow': float(flow),
                'proj': float(proj),
            })
    return pd.DataFrame(rows, columns=_COLUMNS)


def _plot_losses(df, plot_path):
    """Plot each loss column against its row position and save to *plot_path*."""
    # Imported here so the parsing code is usable without a plotting backend.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        for column in _LOSS_COLUMNS:
            ax.plot(df.index, df[column], label=_LOSS_LABELS[column], alpha=0.8, linewidth=1)
        ax.set_xlabel('Logging Steps')
        ax.set_ylabel('Loss Value')
        ax.set_title('Training Losses over Time')
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.7)
        fig.tight_layout()
        fig.savefig(plot_path, dpi=300)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)


def main(argv=None):
    """Parse the training log, print loss statistics and save a loss plot.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (for tests
            and programmatic use). ``None`` keeps normal CLI behaviour.
    """
    parser = argparse.ArgumentParser(description='Parse training log and plot losses.')
    parser.add_argument('--log_file', type=str, default='logs/train_222379.out', help='Path to the log file')
    parser.add_argument('--output', type=str, default='loss_plot.png', help='Path of the PNG plot to write')
    args = parser.parse_args(argv)

    df = _parse_log(args.log_file)
    if df.empty:
        print("No valid log lines found.")
        return

    # Mean, std, min and max of every loss column (std is NaN for one row).
    stats = df[_LOSS_COLUMNS].agg(['mean', 'std', 'min', 'max'])
    print("--- Loss Statistics ---")
    print(stats)

    _plot_losses(df, args.output)
    print(f"\nPlot successfully generated and saved to {args.output}")


if __name__ == '__main__':
    main()