import matplotlib.pyplot as plt import os import numpy as np results_dir = "/storage/ice-shared/ae8803che/hxue/data/world_model/results" dataset_name = "language_table" def load_results(label): path = os.path.join(results_dir, f"mse_results_{dataset_name}_{label}.txt") if not os.path.exists(path): print(f"File not found: {path}") return None steps = [] means = [] p25s = [] p75s = [] with open(path, 'r') as f: header = f.readline().strip().split(',') has_percentiles = "P25" in [h.strip() for h in header] for line in f: parts = line.strip().split(',') if not parts or len(parts) < 2: continue try: steps.append(int(parts[0])) means.append(float(parts[1])) if has_percentiles and len(parts) >= 4: p25s.append(float(parts[2])) p75s.append(float(parts[3])) else: # Fallback if percentiles missing p25s.append(float(parts[1])) p75s.append(float(parts[1])) except ValueError: continue # Sort by step idx = np.argsort(steps) return { 'steps': np.array(steps)[idx], 'means': np.array(means)[idx], 'p25s': np.array(p25s)[idx], 'p75s': np.array(p75s)[idx] } plt.figure(figsize=(12, 7)) configs = [ {"label": "10steps", "name": "10 Steps", "color": "red", "marker": "x"}, {"label": "20steps", "name": "20 Steps", "color": "green", "marker": "d"}, {"label": "50steps", "name": "50 Steps", "color": "blue", "marker": "o"}, {"label": "100steps", "name": "100 Steps", "color": "purple", "marker": "s"}, ] for config in configs: data = load_results(config["label"]) if data: plt.plot(data["steps"], data["means"], label=config["name"], color=config["color"], marker=config["marker"], linewidth=2, zorder=3) # Better visualization for error bars (lower alpha + thin boundary lines) if not np.allclose(data["p25s"], data["means"]): plt.fill_between(data["steps"], data["p25s"], data["p75s"], color=config["color"], alpha=0.08, zorder=2) # Add thin subtle lines for the boundaries of the error region plt.plot(data["steps"], data["p25s"], color=config["color"], linestyle=':', linewidth=0.8, alpha=0.4, zorder=2) plt.plot(data["steps"], data["p75s"], color=config["color"], linestyle=':', linewidth=0.8, alpha=0.4, zorder=2) plt.title(f"Comparison of Inference Steps - {dataset_name.replace('_', ' ').title()}", fontsize=14) plt.xlabel("Training Steps", fontsize=12) plt.ylabel("Mean RGB MSE (Full Trajectory)", fontsize=12) plt.legend(fontsize=10) plt.grid(True, linestyle='--', alpha=0.7) plt.yscale('log') # Log scale often helps see differences in small MSE values output_path = os.path.join(results_dir, "final_comparison_steps.png") plt.savefig(output_path, dpi=300, bbox_inches='tight') print(f"Plot saved to: {output_path}") # Also create a non-log version plt.yscale('linear') output_path_linear = os.path.join(results_dir, "final_comparison_steps_linear.png") plt.savefig(output_path_linear, dpi=300, bbox_inches='tight') print(f"Linear plot saved to: {output_path_linear}")