| | import re |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import pandas as pd |
| | import argparse |
| |
|
# An optionally signed decimal with an optional exponent: 0.5, 5., .5, -3, 1.2e-05.
# The previous character class ``[\d\.]+`` truncated a trailing scientific-notation
# value to its mantissa, dropped lines containing signs/exponents in earlier
# fields, and captured trailing periods that ``float()`` then rejected.
_NUMBER = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
_LOSS_LINE = re.compile(
    r'Epoch (\d+), Step (\d+), '
    + f'Total: ({_NUMBER}), Flow: ({_NUMBER}), Proj: ({_NUMBER})'
)


def _parse_loss_records(lines):
    """Extract one record per line that contains an epoch/step/loss entry.

    Args:
        lines: Iterable of strings (for example an open text file).

    Returns:
        A list of dicts with keys ``epoch``, ``step`` (ints) and ``total``,
        ``flow``, ``proj`` (floats), in the order the entries were encountered.
        Lines that do not contain a well-formed entry are skipped.
    """
    records = []
    for line in lines:
        match = _LOSS_LINE.search(line)
        if match is None:
            continue
        epoch, step, total, flow, proj = match.groups()
        records.append({
            'epoch': int(epoch),
            'step': int(step),
            'total': float(total),
            'flow': float(flow),
            'proj': float(proj),
        })
    return records


def main(argv=None):
    """Parse a training log, print loss statistics and save a loss plot.

    Args:
        argv: Optional list of command-line arguments. When ``None`` (the
            default) the arguments are taken from ``sys.argv``, exactly as
            before.

    Raises:
        OSError: If the log file cannot be opened (e.g. it does not exist).
    """
    parser = argparse.ArgumentParser(description='Parse training log and plot losses.')
    parser.add_argument('--log_file', type=str, default='logs/train_222379.out', help='Path to the log file')
    parser.add_argument('--output', type=str, default='loss_plot.png', help='Path of the plot image to write')
    args = parser.parse_args(argv)

    # Undecodable bytes are replaced rather than raised on: only the ASCII
    # loss entries matter, and a stray byte should not abort the whole parse.
    with open(args.log_file, 'r', encoding='utf-8', errors='replace') as f:
        data = _parse_loss_records(f)

    if not data:
        print("No valid log lines found.")
        return

    df = pd.DataFrame(data)

    # Summary statistics for each loss column.
    stats = df[['total', 'flow', 'proj']].agg(['mean', 'std', 'min', 'max'])
    print("--- Loss Statistics ---")
    print(stats)

    # The x-axis is the record index, i.e. one tick per parsed log entry,
    # not the raw ``step`` value from the log.
    fig = plt.figure(figsize=(12, 6))
    plt.plot(df.index, df['total'], label='Total Loss', alpha=0.8, linewidth=1)
    plt.plot(df.index, df['flow'], label='Flow Loss', alpha=0.8, linewidth=1)
    plt.plot(df.index, df['proj'], label='Proj Loss', alpha=0.8, linewidth=1)

    plt.xlabel('Logging Steps')
    plt.ylabel('Loss Value')
    plt.title('Training Losses over Time')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()

    plot_path = args.output
    plt.savefig(plot_path, dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"\nPlot successfully generated and saved to {plot_path}")
| |
|
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
| |
|