File size: 2,689 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import matplotlib.pyplot as plt
import os
import numpy as np

# Dataset tag embedded in every results file name
# (mse_results_<dataset_name>_<label>.txt).
dataset_name = "language_table"
# Absolute directory that the results files are read from and the
# comparison figures are written to.
results_dir = "/storage/ice-shared/ae8803che/hxue/data/world_model/results"

def load_results(label, directory=None, dataset=None):
    """Load one MSE results file and return its columns as parallel lists.

    The file read is ``mse_results_{dataset}_{label}.txt`` inside
    ``directory``. The first line is treated as a header and skipped. Each
    following line is split on commas into ``step, mean[, p25, p75]``.

    Args:
        label: Run identifier embedded in the file name.
        directory: Folder holding the results file. Defaults to the
            module-level ``results_dir``.
        dataset: Dataset name embedded in the file name. Defaults to the
            module-level ``dataset_name``.

    Returns:
        A tuple ``(steps, means, p25s, p75s)`` of four equally long lists.
        All four are empty when the file does not exist (a message is
        printed) or when it contains no data rows. Rows with fewer than two
        fields are ignored; rows with fewer than four fields reuse the mean
        for both percentiles.

    Raises:
        ValueError: If a field of a data row cannot be parsed as a number.
    """
    # Only fall back to the module globals when the caller gave no override.
    directory = results_dir if directory is None else directory
    dataset = dataset_name if dataset is None else dataset
    path = os.path.join(directory, f"mse_results_{dataset}_{label}.txt")

    steps = []
    means = []
    p25s = []
    p75s = []
    if not os.path.exists(path):
        print(f"File not found: {path}")
        return steps, means, p25s, p75s

    with open(path, 'r') as f:
        # Skip the header. The default keeps an empty file from raising
        # StopIteration.
        next(f, None)
        for line in f:
            parts = line.strip().split(',')
            if len(parts) < 2:
                # Blank or incomplete line: nothing to record.
                continue
            # Parse the whole row first so a bad field cannot leave the
            # four lists with different lengths.
            step = int(parts[0])
            mean = float(parts[1])
            if len(parts) >= 4:
                low, high = float(parts[2]), float(parts[3])
            else:
                # No percentile columns: collapse the band onto the mean.
                low = high = mean
            steps.append(step)
            means.append(mean)
            p25s.append(low)
            p75s.append(high)
    return steps, means, p25s, p75s

# 1. Compare inference-step settings (10, 20, 50, 100): one mean curve per
#    setting, each with a shaded 25th-75th percentile band.
# Each entry: (results-file label, legend text, line colour, marker).
step_series = [
    ("10steps", "10 Steps", 'r', 'x'),
    ("20steps", "20 Steps", 'b', 'd'),
    ("50steps", "50 Steps", 'g', 'o'),
    ("100steps", "100 Steps", 'm', 's'),
]

fig_steps, ax_steps = plt.subplots(figsize=(10, 6))
for run_label, legend_text, colour, mark in step_series:
    xs, mean_curve, band_low, band_high = load_results(run_label)
    if not xs:
        # Results for this setting are missing or empty; leave it off the plot.
        continue
    ax_steps.plot(xs, mean_curve, marker=mark, color=colour, label=legend_text)
    ax_steps.fill_between(xs, band_low, band_high, color=colour, alpha=0.1)

ax_steps.set_title("Comparison: Inference Steps (10, 20, 50, 100) with 25-75th Percentiles")
ax_steps.set_xlabel("Training Steps")
ax_steps.set_ylabel("Mean RGB MSE")
ax_steps.legend()
ax_steps.grid(True)
fig_steps.savefig(os.path.join(results_dir, "comparison_steps.png"))
print("Generated comparison_steps.png")

# 2. Effect of first-frame noise at 50 inference steps: the "50steps" run
#    against the "50steps_noise0.1" run, each with its 25-75th percentile band.
clean_run = load_results("50steps")
noisy_run = load_results("50steps_noise0.1")

# Index 0 of each result tuple is the list of steps; an empty list means the
# file was missing or had no rows, so the figure is only drawn when both exist.
if not (clean_run[0] and noisy_run[0]):
    print("Skipping noise comparison plot as data is not yet available.")
else:
    fig_noise, ax_noise = plt.subplots(figsize=(10, 6))
    noise_series = (
        (clean_run, 'o', 'b', "50 Steps (Clean)"),
        (noisy_run, '^', 'r', "50 Steps (Noise 0.1)"),
    )
    for (xs, mean_curve, band_low, band_high), mark, colour, legend_text in noise_series:
        ax_noise.plot(xs, mean_curve, marker=mark, color=colour, label=legend_text)
        ax_noise.fill_between(xs, band_low, band_high, color=colour, alpha=0.1)

    ax_noise.set_title("Effect of First-Frame Noise (50 Steps) with 25-75th Percentiles")
    ax_noise.set_xlabel("Training Steps")
    ax_noise.set_ylabel("Mean RGB MSE")
    ax_noise.legend()
    ax_noise.grid(True)
    fig_noise.savefig(os.path.join(results_dir, "comparison_50_vs_50noise.png"))
    print("Generated comparison_50_vs_50noise.png")