| | import matplotlib.pyplot as plt |
| | import os |
| | import numpy as np |
| |
|
| | |
# Folder that holds the mse_results_<dataset>_<label>.txt inputs; the
# comparison plots below are written into this same folder.
results_dir: str = "/storage/ice-shared/ae8803che/hxue/data/world_model/results"
# Dataset name embedded in every results file name.
dataset_name: str = "language_table"
| |
|
def load_results(label, directory=None, dataset=None):
    """Load an MSE-vs-training-step curve for one evaluation run.

    Reads ``mse_results_<dataset>_<label>.txt`` from ``directory``. The
    first line is treated as a header and skipped. Each remaining line is
    comma-separated as ``step,mean[,p25,p75]``. When fewer than four columns
    are present, the mean is reused as both percentiles, so a shaded band
    collapses to the mean curve.

    Args:
        label: Run suffix used in the file name, e.g. ``"50steps"``.
        directory: Folder containing the result files. Defaults to the
            module-level ``results_dir``.
        dataset: Dataset name used in the file name. Defaults to the
            module-level ``dataset_name``.

    Returns:
        A tuple ``(steps, means, p25s, p75s)`` of parallel lists in file
        order. All four lists are empty when the file does not exist or
        has no data rows.

    Raises:
        ValueError: If a data row's step is not an integer or one of its
            value columns is not numeric.
    """
    # Fall back to the module-level settings only when not supplied.
    directory = results_dir if directory is None else directory
    dataset = dataset_name if dataset is None else dataset
    path = os.path.join(directory, f"mse_results_{dataset}_{label}.txt")

    steps, means, p25s, p75s = [], [], [], []
    if not os.path.exists(path):
        print(f"File not found: {path}")
        return steps, means, p25s, p75s

    with open(path, "r", encoding="utf-8") as f:
        # Skip the header row. The default keeps an empty file from raising
        # StopIteration and aborting the whole script.
        next(f, None)
        for line in f:
            parts = line.strip().split(",")
            if len(parts) < 2:
                # Blank lines and rows without a mean column are ignored.
                continue
            mean = float(parts[1])
            steps.append(int(parts[0]))
            means.append(mean)
            if len(parts) >= 4:
                p25s.append(float(parts[2]))
                p75s.append(float(parts[3]))
            else:
                # No percentile columns: use a zero-width band at the mean.
                p25s.append(mean)
                p75s.append(mean)
    return steps, means, p25s, p75s
| |
|
| | |
# Figure 1: mean RGB MSE vs. training step for each number of inference
# steps, with the 25th-75th percentile band shaded around each curve.
steps_plot_name = "comparison_steps.png"
plt.figure(figsize=(10, 6))
colors = ['r', 'b', 'g', 'm']
markers = ['x', 'd', 'o', 's']
labels = ["10steps", "20steps", "50steps", "100steps"]
names = ["10 Steps", "20 Steps", "50 Steps", "100 Steps"]

for label, name, color, marker in zip(labels, names, colors, markers):
    s, m, p25, p75 = load_results(label)
    # load_results returns empty lists for missing/empty files; skip those runs.
    if s:
        plt.plot(s, m, marker=marker, color=color, label=name)
        plt.fill_between(s, p25, p75, color=color, alpha=0.1)

plt.title("Comparison: Inference Steps (10, 20, 50, 100) with 25-75th Percentiles")
plt.xlabel("Training Steps")
plt.ylabel("Mean RGB MSE")
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(results_dir, steps_plot_name))
# Release the figure so it does not linger in pyplot's global figure registry.
plt.close()
print(f"Generated {steps_plot_name}")
| |
|
| | |
# Figure 2: 50-step inference without vs. with noise (0.1) on the first
# frame, each with its 25th-75th percentile band.
s_clean, m_clean, p25_clean, p75_clean = load_results("50steps")
s_noise, m_noise, p25_noise, p75_noise = load_results("50steps_noise0.1")

noise_plot_name = "comparison_50_vs_50noise.png"
# Only plot when both runs have data; otherwise the comparison is meaningless.
if s_clean and s_noise:
    plt.figure(figsize=(10, 6))
    plt.plot(s_clean, m_clean, marker='o', color='b', label="50 Steps (Clean)")
    plt.fill_between(s_clean, p25_clean, p75_clean, color='b', alpha=0.1)

    plt.plot(s_noise, m_noise, marker='^', color='r', label="50 Steps (Noise 0.1)")
    plt.fill_between(s_noise, p25_noise, p75_noise, color='r', alpha=0.1)

    plt.title("Effect of First-Frame Noise (50 Steps) with 25-75th Percentiles")
    plt.xlabel("Training Steps")
    plt.ylabel("Mean RGB MSE")
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(results_dir, noise_plot_name))
    # Release the figure so it does not linger in pyplot's global figure registry.
    plt.close()
    print(f"Generated {noise_plot_name}")
else:
    print("Skipping noise comparison plot as data is not yet available.")
| |
|