import torch
import matplotlib.pyplot as plt
import numpy as np
import os
import sys

# Add project root to sys.path so the `wm` package resolves when this file is
# run directly as a script.
sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model")
from wm.model.diffusion.flow_matching import FlowMatchScheduler

DEFAULT_OUTPUT_DIR = (
    "/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_flow_matching"
)


def _to_numpy(tensor):
    """Return a NumPy copy of a torch tensor, safe for CUDA / grad-tracking tensors."""
    return tensor.detach().cpu().numpy()


def plot_sigma_curve(num_steps=1000, shift=5.0, output_dir=DEFAULT_OUTPUT_DIR):
    """Plot the training sigma schedule and training weights of FlowMatchScheduler.

    Configures a ``FlowMatchScheduler`` in training mode and writes two PNGs
    into ``output_dir``:

    * ``sigma_vs_index.png``: ``scheduler.sigmas`` against step index, with
      an unshifted linear 1 -> 0 schedule overlaid for comparison.
    * ``weight_vs_t.png``: ``scheduler.linear_timesteps_weights`` against
      ``scheduler.timesteps``.

    Args:
        num_steps: Number of training timesteps passed to ``set_timesteps``.
        shift: Timestep shift passed to ``set_timesteps``.
        output_dir: Directory for the PNGs; created if missing.

    Returns:
        Tuple ``(sigma_path, weight_path)`` of the written files.
    """
    scheduler = FlowMatchScheduler()
    scheduler.set_timesteps(num_inference_steps=num_steps, training=True, shift=shift)

    # Arrays are plotted in stored order; nothing is reversed. Presumably
    # sigmas[0] is pure noise and sigmas[-1] is clean data (Wan-style
    # linspace(1, 0)), and timesteps = sigmas * 1000. TODO confirm against
    # FlowMatchScheduler.
    sigmas = _to_numpy(scheduler.sigmas)
    timesteps = _to_numpy(scheduler.timesteps)
    weights = _to_numpy(scheduler.linear_timesteps_weights)

    # Derive the index axis from the data rather than assuming exactly
    # `num_steps` entries, so a differently sized schedule cannot cause a
    # shape mismatch in plot().
    n_sigmas = len(sigmas)
    indices = np.arange(n_sigmas)
    linear_sigmas = np.linspace(1.0, 0.0, n_sigmas)

    os.makedirs(output_dir, exist_ok=True)

    # Plot 1: sigma vs step index, with the unshifted linear schedule for reference.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(indices, sigmas, label='Sigma (Noise Level)', color='blue')
    # Raw strings: "\s" / "\i" are invalid escape sequences in normal literals.
    ax.set_title(
        rf"Sigma ($\sigma$) vs Step Index (0-{n_sigmas - 1})"
        f"\nWan Shift={shift}, num_steps={num_steps}"
    )
    ax.set_xlabel("Index")
    ax.set_ylabel("Sigma Value (0=Clean, 1=Noise)")
    ax.grid(True, which='both', linestyle='--', alpha=0.5)
    ax.plot(indices, linear_sigmas, 'r--', alpha=0.5, label='Linear (No Shift)')
    ax.legend()
    sigma_path = os.path.join(output_dir, "sigma_vs_index.png")
    fig.savefig(sigma_path)
    plt.close(fig)  # release the figure; pyplot otherwise keeps it alive
    print(f"Plot saved to {sigma_path}")

    # Plot 2: training weight distribution over training timesteps.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(timesteps, weights, color='green', label='Training Weight')
    ax.set_title("Training Weight vs Training Timestep ($t$)\nGaussian-like Weighting")
    ax.set_xlabel(r"Training Timestep ($t \in [0, 1000]$)")
    ax.set_ylabel("Weight Value")
    ax.grid(True, which='both', linestyle='--', alpha=0.5)
    ax.legend()
    weight_path = os.path.join(output_dir, "weight_vs_t.png")
    fig.savefig(weight_path)
    plt.close(fig)
    print(f"Plot saved to {weight_path}")

    return sigma_path, weight_path


if __name__ == "__main__":
    plot_sigma_curve()