| | import matplotlib.pyplot as plt |
| | import os |
| | import numpy as np |
| |
|
# Dataset whose evaluation results are plotted; the name is embedded in the
# result file names and shown in the figure title.
dataset_name = "language_table"

# Directory that holds the MSE result text files and receives the saved figures.
results_dir = "/storage/ice-shared/ae8803che/hxue/data/world_model/results"
| |
|
def load_results(label, directory=None, dataset=None):
    """Load the MSE statistics stored for one run label.

    Reads ``mse_results_{dataset}_{label}.txt`` from *directory*. The first
    line is treated as a comma-separated header; every following line is a
    comma-separated row of the form ``step,mean[,p25,p75]``.

    Percentile columns are read only when the header contains a column named
    ``P25`` and the row has at least four fields. Otherwise both percentiles
    are set to the row's mean.

    Rows with fewer than two fields, or with any field that cannot be parsed
    as a number, are skipped as a whole so the returned arrays always have
    the same length.

    Args:
        label: Run label used in the result file name (e.g. ``"10steps"``).
        directory: Directory containing the result file. Defaults to the
            module-level ``results_dir``.
        dataset: Dataset name used in the result file name. Defaults to the
            module-level ``dataset_name``.

    Returns:
        ``None`` when the file does not exist (a message is printed),
        otherwise a dict with the NumPy arrays ``'steps'``, ``'means'``,
        ``'p25s'`` and ``'p75s'``, all ordered by ascending step.
    """
    if directory is None:
        directory = results_dir
    if dataset is None:
        dataset = dataset_name

    path = os.path.join(directory, f"mse_results_{dataset}_{label}.txt")
    if not os.path.exists(path):
        print(f"File not found: {path}")
        return None

    steps = []
    means = []
    p25s = []
    p75s = []

    with open(path, 'r') as f:
        header = [h.strip() for h in f.readline().strip().split(',')]
        has_percentiles = "P25" in header

        for line in f:
            parts = line.strip().split(',')
            if len(parts) < 2:
                continue

            # Parse the complete row before touching any list: appending
            # while parsing would leave the lists with different lengths
            # when a later field turns out to be malformed.
            try:
                step = int(parts[0])
                mean = float(parts[1])
                if has_percentiles and len(parts) >= 4:
                    p25 = float(parts[2])
                    p75 = float(parts[3])
                else:
                    # No percentile data for this row: collapse the band
                    # onto the mean.
                    p25 = p75 = mean
            except ValueError:
                continue

            steps.append(step)
            means.append(mean)
            p25s.append(p25)
            p75s.append(p75)

    # Rows may be stored in any order; sort everything by step.
    idx = np.argsort(steps)
    return {
        'steps': np.array(steps)[idx],
        'means': np.array(means)[idx],
        'p25s': np.array(p25s)[idx],
        'p75s': np.array(p75s)[idx],
    }
| |
|
# A single figure holds the curves of every configuration.
plt.figure(figsize=(12, 7))

# Each entry maps a result-file label to its legend name and line style.
configs = [
    {"label": "10steps", "name": "10 Steps", "color": "red", "marker": "x"},
    {"label": "20steps", "name": "20 Steps", "color": "green", "marker": "d"},
    {"label": "50steps", "name": "50 Steps", "color": "blue", "marker": "o"},
    {"label": "100steps", "name": "100 Steps", "color": "purple", "marker": "s"},
]

for config in configs:
    data = load_results(config["label"])
    if data is None:
        # No result file for this configuration; load_results has already
        # printed which path was missing.
        continue

    # Mean curve, drawn above the percentile band.
    plt.plot(data["steps"], data["means"], label=config["name"],
             color=config["color"], marker=config["marker"], linewidth=2, zorder=3)

    # load_results sets both percentiles to the mean when a file carries no
    # percentile data. Compare BOTH arrays against the mean: testing only the
    # 25th percentile would drop the band for a run whose p25 coincides with
    # the mean while its p75 does not.
    band_is_flat = (np.allclose(data["p25s"], data["means"])
                    and np.allclose(data["p75s"], data["means"]))
    if not band_is_flat:
        # Shaded interquartile band.
        plt.fill_between(data["steps"], data["p25s"], data["p75s"],
                         color=config["color"], alpha=0.08, zorder=2)

        # Thin dotted outlines along the lower and upper edge of the band.
        plt.plot(data["steps"], data["p25s"], color=config["color"],
                 linestyle=':', linewidth=0.8, alpha=0.4, zorder=2)
        plt.plot(data["steps"], data["p75s"], color=config["color"],
                 linestyle=':', linewidth=0.8, alpha=0.4, zorder=2)
| |
|
# Axis labels and title; the dataset name is shown with spaces and title case.
plt.xlabel("Training Steps", fontsize=12)
plt.ylabel("Mean RGB MSE (Full Trajectory)", fontsize=12)
title_text = f"Comparison of Inference Steps - {dataset_name.replace('_', ' ').title()}"
plt.title(title_text, fontsize=14)

# Legend, dashed grid, and a logarithmic y-axis for the first saved figure.
plt.legend(fontsize=10)
plt.grid(True, linestyle='--', alpha=0.7)
plt.yscale('log')
| |
|
# Both images are written next to the result files.
output_path = os.path.join(results_dir, "final_comparison_steps.png")
output_path_linear = os.path.join(results_dir, "final_comparison_steps_linear.png")

# Options shared by both exports.
save_options = {"dpi": 300, "bbox_inches": 'tight'}

# First export: the figure with whatever y-axis scale is active at this point.
plt.savefig(output_path, **save_options)
print(f"Plot saved to: {output_path}")

# Second export: the same figure after switching the y-axis to a linear scale.
plt.yscale('linear')
plt.savefig(output_path_linear, **save_options)
print(f"Linear plot saved to: {output_path_linear}")
| |
|