File size: 2,463 Bytes
b1e25b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import argparse


# Global typography for every figure created by this module: serif text
# (Georgia) with enlarged title, axis-label and tick-label sizes.
mpl.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Georgia'],
    'font.size': 20,
    'axes.titlesize': 20,
    'axes.labelsize': 18,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})
# legend.fontsize is left at matplotlib's default; it is not configured here.

def plot_two_loss_curves(
    csv_file1,
    csv_file2,
    title="Loss Comparison on Qwen3-8B",
    dataset1_name="Dataset1",
    dataset2_name="Dataset2",
    *,
    show_legend=False,
    svg_path='loss_comparison_qwen38b.svg',
    png_path='loss_comparison.png',
):
    """Plot the loss curves from two CSV files on one figure and save it.

    Both CSV files must contain a 'Step' column (x axis) and a 'Loss'
    column (y axis). The figure is written as SVG and as a 300-dpi PNG,
    displayed with ``plt.show()``, and then closed.

    Args:
        csv_file1: Path (or file-like object) of the first CSV file.
        csv_file2: Path (or file-like object) of the second CSV file.
        title: Title drawn above the plot.
        dataset1_name: Label attached to the first curve.
        dataset2_name: Label attached to the second curve.
        show_legend: When true, draw a legend showing the two dataset
            names. Defaults to False, which keeps the legend-free look.
        svg_path: Destination of the SVG output.
        png_path: Destination of the PNG output.

    Raises:
        ValueError: If either CSV lacks a 'Step' or a 'Loss' column.
    """
    # Read both files first, then validate, so a missing/unreadable file is
    # reported before any column complaint.
    df1 = pd.read_csv(csv_file1)
    df2 = pd.read_csv(csv_file2)

    for df, path in ((df1, csv_file1), (df2, csv_file2)):
        if 'Step' not in df.columns or 'Loss' not in df.columns:
            raise ValueError(f"Missing 'Step' or 'Loss' columns in {path}")

    fig = plt.figure(figsize=(12, 8))
    try:
        # The dataset names are attached as line labels so that a legend,
        # when requested, identifies each curve.
        plt.plot(df1['Step'], df1['Loss'],
                 color='#1f77b4', linewidth=2.5,
                 label=dataset1_name)
        plt.plot(df2['Step'], df2['Loss'],
                 color='#2ca02c', linewidth=2.5,
                 label=dataset2_name)

        plt.title(title, fontweight='bold')
        plt.xlabel('Steps', fontweight='bold')
        plt.ylabel('Loss', fontweight='bold')

        plt.grid(True, linestyle='--', alpha=0.7)

        if show_legend:
            # Explicit size so the legend does not depend on rcParams.
            plt.legend(fontsize=16)

        plt.tight_layout(pad=3.0)

        # Formats are passed explicitly so custom paths without the usual
        # extension still produce the expected file type.
        plt.savefig(svg_path, format='svg')
        plt.savefig(png_path, format='png', dpi=300)

        plt.show()
    finally:
        # Release the figure so repeated calls do not accumulate open
        # figures inside pyplot.
        plt.close(fig)

    # Report the files that were actually written.
    print(f"Saved: {svg_path}, {png_path}")


def main():
    """Command-line entry point: parse arguments and draw the comparison.

    Takes two positional CSV paths plus optional ``--title``,
    ``--dataset1-name`` and ``--dataset2-name`` values, and forwards them
    to ``plot_two_loss_curves``.
    """
    cli = argparse.ArgumentParser(
        description='Plot comparison of two training loss curves'
    )

    # Required positional inputs, registered in the order they are parsed.
    positional_specs = (
        ('csv_file1', 'Path to the first CSV file'),
        ('csv_file2', 'Path to the second CSV file'),
    )
    for arg_name, help_text in positional_specs:
        cli.add_argument(arg_name, help=help_text)

    # Optional flags as (flag, default, help) triples.
    option_specs = (
        ('--title', 'Training Loss Comparison', 'Title for the plot'),
        ('--dataset1-name', 'Original Dataset', 'Name for the first dataset'),
        ('--dataset2-name', 'Revised Dataset', 'Name for the second dataset'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, default=default_value, help=help_text)

    opts = cli.parse_args()

    plot_two_loss_curves(
        opts.csv_file1,
        opts.csv_file2,
        title=opts.title,
        dataset1_name=opts.dataset1_name,
        dataset2_name=opts.dataset2_name,
    )


# Run the command-line interface only when this file is executed as a
# script; importing the module does not trigger plotting.
if __name__ == "__main__":
    main()