File size: 2,771 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
import sys

# Add project root to sys.path
sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model")
from wm.model.diffusion.flow_matching import FlowMatchScheduler

def plot_sigma_curve(
    *,
    num_steps=1000,
    shift=5.0,
    output_dir="/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_flow_matching",
):
    """Plot the shifted flow-matching sigma schedule and its training weights.

    Configures a ``FlowMatchScheduler`` in training mode and saves two PNGs
    into ``output_dir``:

    * ``sigma_vs_index.png`` -- ``scheduler.sigmas`` against the step index,
      with an unshifted linear schedule (1 -> 0) overlaid for reference so the
      non-linearity introduced by ``shift`` is visible.
    * ``weight_vs_t.png`` -- ``scheduler.linear_timesteps_weights`` against
      ``scheduler.timesteps``.

    Args:
        num_steps: Number of steps passed to ``set_timesteps``
            (default 1000, matching the original script).
        shift: Timestep shift passed to ``set_timesteps`` (default 5.0).
        output_dir: Directory the plots are written to; created if missing.

    Returns:
        Tuple ``(sigma_path, weight_path)`` with the paths of the saved images.
    """
    scheduler = FlowMatchScheduler()
    scheduler.set_timesteps(num_inference_steps=num_steps, training=True, shift=shift)

    # NOTE(review): the original author noted that timesteps == sigmas * 1000
    # and that sigmas[0] is pure noise while sigmas[-1] is (near) clean data.
    # That comes from FlowMatchScheduler, which is not visible here -- confirm there.
    sigmas = scheduler.sigmas.cpu().numpy()
    timesteps = scheduler.timesteps.cpu().numpy()
    weights = scheduler.linear_timesteps_weights.cpu().numpy()

    # Size the x-axis from the schedule itself instead of a hard-coded 1000 so
    # the plot does not fail with a shape mismatch if the length differs.
    n = len(sigmas)
    step_indices = np.arange(n)

    os.makedirs(output_dir, exist_ok=True)

    # Plot 1: sigma vs step index, with the unshifted linear schedule as reference.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(step_indices, sigmas, label='Sigma (Noise Level)', color='blue')
    # "\\sigma" is escaped on purpose: a bare "\s" is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+), while "\n" must stay a real newline.
    ax.set_title(
        f"Sigma ($\\sigma$) vs Step Index (0-{n - 1})\n"
        f"Wan Shift={shift}, num_steps={num_steps}"
    )
    ax.set_xlabel("Index")
    ax.set_ylabel("Sigma Value (0=Clean, 1=Noise)")
    ax.grid(True, which='both', linestyle='--', alpha=0.5)
    ax.plot(step_indices, np.linspace(1.0, 0.0, n), 'r--', alpha=0.5, label='Linear (No Shift)')
    ax.legend()

    sigma_path = os.path.join(output_dir, "sigma_vs_index.png")
    fig.savefig(sigma_path)
    plt.close(fig)  # release the figure; the script creates more than one
    print(f"Plot saved to {sigma_path}")

    # Plot 2: training weight vs training timestep.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(timesteps, weights, color='green', label='Training Weight')
    ax.set_title("Training Weight vs Training Timestep ($t$)\nGaussian-like Weighting")
    # Raw string: "\in" is a LaTeX command, not a Python escape.
    ax.set_xlabel(r"Training Timestep ($t \in [0, 1000]$)")
    ax.set_ylabel("Weight Value")
    ax.grid(True, which='both', linestyle='--', alpha=0.5)
    ax.legend()

    weight_path = os.path.join(output_dir, "weight_vs_t.png")
    fig.savefig(weight_path)
    plt.close(fig)
    print(f"Plot saved to {weight_path}")

    return sigma_path, weight_path

if __name__ == "__main__":
    plot_sigma_curve()