| | import torch |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import os |
| | import sys |
| |
|
| | |
| | sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model") |
| | from wm.model.diffusion.flow_matching import FlowMatchScheduler |
| |
|
_DEFAULT_OUTPUT_DIR = "/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_flow_matching"


def _to_numpy(tensor):
    """Return a torch tensor as a NumPy array (detached and moved to CPU first)."""
    # .numpy() alone raises for CUDA tensors and tensors that require grad.
    return tensor.detach().cpu().numpy()


def _save_sigma_plot(sigmas, shift, num_inference_steps, output_path):
    """Plot sigma against step index, with a linear 1 -> 0 reference ramp, and save it.

    Args:
        sigmas: 1-D array of sigma values, one per step index.
        shift: Shift value shown in the title.
        num_inference_steps: Step count shown in the title.
        output_path: File path the PNG is written to.
    """
    indices = np.arange(len(sigmas))
    # Reference ramp sized to match the schedule, so both curves share an x-axis
    # even when the scheduler's length differs from the requested step count.
    linear_sigmas = np.linspace(1.0, 0.0, len(sigmas))

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(indices, sigmas, label='Sigma (Noise Level)', color='blue')
        # "\\sigma" (not "\sigma") avoids an invalid-escape warning; the text is identical.
        ax.set_title(
            "Sigma ($\\sigma$) vs Step Index "
            f"(0-{len(sigmas) - 1})\nWan Shift={shift}, num_steps={num_inference_steps}"
        )
        ax.set_xlabel("Index")
        ax.set_ylabel("Sigma Value (0=Clean, 1=Noise)")
        ax.grid(True, which='both', linestyle='--', alpha=0.5)
        ax.plot(indices, linear_sigmas, 'r--', alpha=0.5, label='Linear (No Shift)')
        ax.legend()
        fig.savefig(output_path)
    finally:
        # Release the figure; otherwise pyplot keeps it alive for the process lifetime.
        plt.close(fig)
    print(f"Plot saved to {output_path}")


def _save_weight_plot(timesteps, weights, output_path):
    """Plot per-timestep training weights against training timestep and save it.

    Args:
        timesteps: 1-D array of training timesteps (x-axis).
        weights: 1-D array of weights, one per timestep (y-axis).
        output_path: File path the PNG is written to.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(timesteps, weights, color='green', label='Training Weight')
        ax.set_title("Training Weight vs Training Timestep ($t$)\nGaussian-like Weighting")
        # Raw string: "\in" would otherwise be an invalid escape sequence.
        ax.set_xlabel(r"Training Timestep ($t \in [0, 1000]$)")
        ax.set_ylabel("Weight Value")
        ax.grid(True, which='both', linestyle='--', alpha=0.5)
        ax.legend()
        fig.savefig(output_path)
    finally:
        plt.close(fig)
    print(f"Plot saved to {output_path}")


def plot_sigma_curve(num_inference_steps=1000, shift=5.0, output_dir=_DEFAULT_OUTPUT_DIR):
    """Plot the FlowMatchScheduler sigma schedule and its training weights.

    Builds a ``FlowMatchScheduler``, configures it with
    ``set_timesteps(num_inference_steps, training=True, shift=shift)``, and
    writes two PNGs into *output_dir*:

    * ``sigma_vs_index.png``: ``scheduler.sigmas`` against step index, plus a
      linear 1.0 -> 0.0 reference ramp of the same length.
    * ``weight_vs_t.png``: ``scheduler.linear_timesteps_weights`` against
      ``scheduler.timesteps``.

    Calling with no arguments reproduces the original hard-coded configuration
    (1000 steps, shift 5.0, the shared results directory).

    Args:
        num_inference_steps: Number of steps passed to ``set_timesteps``.
        shift: Timestep shift passed to ``set_timesteps``.
        output_dir: Directory for the plots; created if it does not exist.
    """
    scheduler = FlowMatchScheduler()
    scheduler.set_timesteps(num_inference_steps=num_inference_steps, training=True, shift=shift)

    sigmas = _to_numpy(scheduler.sigmas)
    timesteps = _to_numpy(scheduler.timesteps)

    os.makedirs(output_dir, exist_ok=True)
    _save_sigma_plot(
        sigmas, shift, num_inference_steps, os.path.join(output_dir, "sigma_vs_index.png")
    )
    # Read the weights only after the sigma plot is saved, matching the original
    # order: if this attribute is missing, the first plot still gets written.
    weights = _to_numpy(scheduler.linear_timesteps_weights)
    _save_weight_plot(timesteps, weights, os.path.join(output_dir, "weight_vs_t.png"))
| |
|
if __name__ == "__main__":
    # Script entry point: render both diagnostic plots with the default configuration.
    plot_sigma_curve()
| |
|