File size: 3,467 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import matplotlib.pyplot as plt
import os
import numpy as np

# Location of the mse_results_*.txt files and of the generated plots.
# Defaults to the original cluster path; set WM_RESULTS_DIR to run elsewhere.
results_dir = os.environ.get(
    "WM_RESULTS_DIR",
    "/storage/ice-shared/ae8803che/hxue/data/world_model/results",
)
# Dataset name embedded in the results file names and in the plot title.
# Can be overridden with WM_DATASET_NAME.
dataset_name = os.environ.get("WM_DATASET_NAME", "language_table")

def load_results(label, directory=None, dataset=None):
    """Load the MSE-vs-training-step results for one inference-step setting.

    Reads ``mse_results_<dataset>_<label>.txt`` from *directory*. The first
    line is treated as a comma-separated header. Each following row is parsed
    as ``step, mean[, p25, p75]``. The percentile columns are read only when
    the header contains a ``P25`` column and the row has at least four fields.
    Otherwise both bounds fall back to the mean.

    Args:
        label: Suffix identifying the configuration in the file name
            (e.g. ``"10steps"``).
        directory: Directory containing the results file. Defaults to the
            module-level ``results_dir``.
        dataset: Dataset name used in the file name. Defaults to the
            module-level ``dataset_name``.

    Returns:
        ``None`` (after printing a message) if the file does not exist.
        Otherwise a dict with numpy arrays ``'steps'``, ``'means'``,
        ``'p25s'`` and ``'p75s'`` of equal length, sorted by ascending step.
        Rows with fewer than two fields, or with any unparseable value, are
        skipped as a whole.
    """
    if directory is None:
        directory = results_dir
    if dataset is None:
        dataset = dataset_name

    path = os.path.join(directory, f"mse_results_{dataset}_{label}.txt")
    if not os.path.exists(path):
        print(f"File not found: {path}")
        return None

    steps, means, p25s, p75s = [], [], [], []

    with open(path, 'r', encoding='utf-8') as f:
        header = [h.strip() for h in f.readline().split(',')]
        has_percentiles = "P25" in header

        for line in f:
            parts = line.strip().split(',')
            # str.split always yields at least one element, so only the
            # length needs checking.
            if len(parts) < 2:
                continue
            # Parse the entire row before appending anything. A bad value in a
            # later column must not leave the four lists with different lengths.
            try:
                step = int(parts[0])
                mean = float(parts[1])
                if has_percentiles and len(parts) >= 4:
                    p25 = float(parts[2])
                    p75 = float(parts[3])
                else:
                    # Fallback if percentiles are missing: collapse the band onto the mean.
                    p25 = p75 = mean
            except ValueError:
                continue
            steps.append(step)
            means.append(mean)
            p25s.append(p25)
            p75s.append(p75)

    # Sort by step. A stable sort keeps duplicate steps in file order.
    idx = np.argsort(steps, kind="stable")
    return {
        'steps': np.array(steps)[idx],
        'means': np.array(means)[idx],
        'p25s': np.array(p25s)[idx],
        'p75s': np.array(p75s)[idx],
    }

plt.figure(figsize=(12, 7))

# Plot style per inference-step setting: (number of steps, line color, marker).
# Each spec expands into the file label, the legend name and the style kwargs.
_step_styles = (
    (10, "red", "x"),
    (20, "green", "d"),
    (50, "blue", "o"),
    (100, "purple", "s"),
)
configs = [
    {"label": f"{n_steps}steps", "name": f"{n_steps} Steps", "color": colour, "marker": mark}
    for n_steps, colour, mark in _step_styles
]

# Draw one mean curve per configuration. Where percentile data exists, also
# draw a faint P25-P75 band with dotted boundary lines.
for config in configs:
    data = load_results(config["label"])
    # Skip missing files and files with no parseable rows, so no empty
    # series (and no empty legend entry) is drawn.
    if data is None or data["steps"].size == 0:
        continue

    plt.plot(data["steps"], data["means"], label=config["name"],
             color=config["color"], marker=config["marker"], linewidth=2, zorder=3)

    # load_results fills both bounds with the mean when percentiles are absent.
    # Draw the band whenever EITHER bound differs from the mean, not only P25.
    has_band = not (np.allclose(data["p25s"], data["means"])
                    and np.allclose(data["p75s"], data["means"]))
    if has_band:
        # Better visualization for error bars (lower alpha + thin boundary lines)
        plt.fill_between(data["steps"], data["p25s"], data["p75s"],
                         color=config["color"], alpha=0.08, zorder=2)
        # Add thin subtle lines for the boundaries of the error region
        plt.plot(data["steps"], data["p25s"], color=config["color"],
                 linestyle=':', linewidth=0.8, alpha=0.4, zorder=2)
        plt.plot(data["steps"], data["p75s"], color=config["color"],
                 linestyle=':', linewidth=0.8, alpha=0.4, zorder=2)

plt.title(f"Comparison of Inference Steps - {dataset_name.replace('_', ' ').title()}", fontsize=14)
plt.xlabel("Training Steps", fontsize=12)
plt.ylabel("Mean RGB MSE (Full Trajectory)", fontsize=12)
# Only add a legend when at least one labelled series was plotted.
# Otherwise matplotlib warns that no labelled artists were found
# (e.g. when every results file is missing).
if plt.gca().get_legend_handles_labels()[0]:
    plt.legend(fontsize=10)
plt.grid(True, linestyle='--', alpha=0.7)
plt.yscale('log')  # Log scale often helps see differences in small MSE values

output_path = os.path.join(results_dir, "final_comparison_steps.png")
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"Plot saved to: {output_path}")

# Also create a non-log version by re-saving the same figure with a linear y-axis.
plt.yscale('linear')
output_path_linear = os.path.join(results_dir, "final_comparison_steps_linear.png")
plt.savefig(output_path_linear, dpi=300, bbox_inches='tight')
print(f"Linear plot saved to: {output_path_linear}")

# Release the figure now that both images have been written.
plt.close()