# world_model/wm/test/plot_sigma_curve.py
# Uploaded by t1an with huggingface_hub (commit f17ae24, verified).
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
# Add project root to sys.path
sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model")
from wm.model.diffusion.flow_matching import FlowMatchScheduler
def plot_sigma_curve():
    """Plot the flow-matching sigma schedule and training weights as PNGs.

    Configures a ``FlowMatchScheduler`` for training with 1000 steps and
    ``shift=5.0`` and writes two figures into the hard-coded results
    directory:

    * ``sigma_vs_index.png`` -- ``scheduler.sigmas`` against its array index,
      overlaid with an unshifted linear ramp from 1.0 to 0.0 for reference.
    * ``weight_vs_t.png`` -- ``scheduler.linear_timesteps_weights`` against
      ``scheduler.timesteps``.

    Returns:
        None. The path of each saved figure is printed to stdout.
    """
    results_dir = (
        "/storage/ice-shared/ae8803che/hxue/data/world_model"
        "/results/test_flow_matching"
    )

    scheduler = FlowMatchScheduler()
    # 1000 training steps; the shift is what bends the sigma curve away from
    # the straight line drawn as a reference below.
    scheduler.set_timesteps(num_inference_steps=1000, training=True, shift=5.0)

    sigmas = scheduler.sigmas.numpy()
    timesteps = scheduler.timesteps.numpy()
    # Take the x axis from the data itself so a scheduler that returns a
    # different number of sigmas cannot trigger a shape mismatch in plt.plot.
    t_indices = np.arange(len(sigmas))

    # Plot 1: sigma vs. array index. The array order is kept as stored;
    # NOTE(review): per the original author's note, index 0 is presumably
    # pure noise and the last index clean data -- confirm against the
    # scheduler implementation.
    sigma_fig = plt.figure(figsize=(10, 6))
    plt.plot(t_indices, sigmas, label='Sigma (Noise Level)', color='blue')
    # "\\sigma" keeps the literal backslash for mathtext without relying on
    # an invalid "\s" escape, while "\n" is still a real line break.
    plt.title("Sigma ($\\sigma$) vs Step Index (0-999)\nWan Shift=5.0, num_steps=1000")
    plt.xlabel("Index")
    plt.ylabel("Sigma Value (0=Clean, 1=Noise)")
    plt.grid(True, which='both', linestyle='--', alpha=0.5)

    # Reference line: the schedule without any shift.
    linear_sigmas = np.linspace(1.0, 0.0, len(sigmas))
    plt.plot(t_indices, linear_sigmas, 'r--', alpha=0.5, label='Linear (No Shift)')
    plt.legend()

    output_path = os.path.join(results_dir, "sigma_vs_index.png")
    os.makedirs(results_dir, exist_ok=True)
    plt.savefig(output_path)
    plt.close(sigma_fig)  # release the figure once it is on disk
    print(f"Plot saved to {output_path}")

    # Plot 2: training weight vs. training timestep.
    weight_fig = plt.figure(figsize=(10, 6))
    weights = scheduler.linear_timesteps_weights.numpy()
    plt.plot(timesteps, weights, color='green', label='Training Weight')
    plt.title("Training Weight vs Training Timestep ($t$)\nGaussian-like Weighting")
    # Raw string: "\in" is mathtext, not a Python escape sequence.
    plt.xlabel(r"Training Timestep ($t \in [0, 1000]$)")
    plt.ylabel("Weight Value")
    plt.grid(True, which='both', linestyle='--', alpha=0.5)
    plt.legend()

    weight_path = os.path.join(results_dir, "weight_vs_t.png")
    plt.savefig(weight_path)
    plt.close(weight_fig)
    print(f"Plot saved to {weight_path}")
# Script entry point: render both figures only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    plot_sigma_curve()