# File size: 2,929 Bytes
# 4edc9aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

# --- 1. The Neural Vector Field ---
# A simple MLP that takes (x, y, t) and outputs velocity (vx, vy)
class VectorField(nn.Module):
    """Time-conditioned MLP mapping a point and a time to a velocity.

    The point coordinates are concatenated with the time value and passed
    through two Tanh hidden layers; the output has one velocity component
    per input coordinate.

    Args:
        dim: Number of coordinates per point. The default of 2 gives the
            original (x, y, t) -> (vx, vy) network.
        hidden_dim: Width of both hidden layers. Defaults to 64.
    """

    def __init__(self, dim: int = 2, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden_dim),  # +1 for the time input
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, dim),  # Output: one velocity per coordinate
        )

    def forward(self, x, t):
        """Return the predicted velocity at points ``x`` and time ``t``.

        Args:
            x: Tensor of shape (batch, dim).
            t: Time value(s). Accepted forms are a Python/NumPy scalar, a
                0-D tensor or any single-element tensor (shared by the whole
                batch), a 1-D tensor of length ``batch``, or a tensor of
                shape (batch, 1).

        Returns:
            Tensor of shape (batch, dim) holding the predicted velocities.
        """
        # Align dtype/device with x so that torch.cat does not promote the
        # input (e.g. to float64) and break the Linear layers.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)

        if t.numel() == 1:
            # A single time value applies to every point in the batch.
            t = t.reshape(1, 1).expand(x.shape[0], 1)
        elif t.dim() == 1:
            t = t.reshape(-1, 1)

        xt = torch.cat([x, t], dim=1)
        return self.net(xt)

# --- 2. Setup Data and Training ---
# Velocity network plus the Adam optimizer (learning rate 1e-3) that trains it.
model = VectorField()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

# Target distribution: Two Gaussian blobs centered at (-2, -2) and (2, 2)
def sample_data(batch_size, centers=None, std=0.5):
    """Sample points from an equal-weight mixture of isotropic Gaussians.

    Each sample picks one center uniformly at random and adds Gaussian
    noise with standard deviation ``std`` to it.

    Args:
        batch_size: Number of points to draw.
        centers: Optional (num_centers, dim) array-like or tensor of blob
            centers. Defaults to the two blobs at (-2, -2) and (2, 2).
        std: Standard deviation of the noise around each center.
            Defaults to 0.5.

    Returns:
        Tensor of shape (batch_size, dim).

    Raises:
        ValueError: If ``centers`` is not a non-empty 2-D collection.
    """
    if centers is None:
        centers = [[-2., -2.], [2., 2.]]
    centers = torch.as_tensor(centers, dtype=torch.get_default_dtype())
    if centers.dim() != 2 or centers.shape[0] == 0:
        raise ValueError("centers must be a non-empty (num_centers, dim) collection")

    # Draw the component indices first and the noise second; this order
    # keeps the random stream identical to the hard-coded two-blob version.
    indices = torch.randint(0, centers.shape[0], (batch_size,))
    noise = torch.randn(batch_size, centers.shape[1], device=centers.device) * std
    return centers[indices] + noise

# Source distribution: Standard Gaussian centered at (0, 0)
def sample_source(batch_size):
    """Draw ``batch_size`` 2-D points from a standard normal distribution.

    Returns:
        Tensor of shape (batch_size, 2).
    """
    shape = (batch_size, 2)
    return torch.randn(shape)

# --- 3. The Flow Matching Training Loop ---
# Each step regresses the network onto the velocity of a straight-line path
# between a noise sample and a data sample, evaluated at a random time.
print("Training Flow Matching Model...")
batch_size = 256
for step in range(2000):
    # Path endpoints: x0 is noise (t = 0), x1 is data (t = 1).
    x0 = sample_source(batch_size)
    x1 = sample_data(batch_size)

    # One time per sample, drawn uniformly from [0, 1); the (batch, 1) shape
    # broadcasts over both coordinates below.
    t = torch.rand(batch_size, 1)

    # The linear path x_t = (1 - t) * x0 + t * x1 has the constant
    # velocity x1 - x0, which is the regression target.
    target_velocity = x1 - x0
    x_t = (1 - t) * x0 + t * x1

    # Mean squared error between the predicted and target velocities.
    pred_velocity = model(x_t, t)
    loss = (pred_velocity - target_velocity).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report progress every 500 steps (including step 0).
    if not step % 500:
        print(f"Step {step}: Loss = {loss.item():.4f}")

# --- 4. Inference: Solving the ODE ---
# We solve dx/dt = v(x, t) from t = 0 to t = 1 with a fixed-step Euler solver.
print("\nSampling (solving ODE)...")

# Fix the number of steps as an integer and derive dt from it. Building the
# time grid with np.arange and a float step can yield one step too many or
# too few, which would stop the integration short of, or push it past, t = 1.
num_steps = 100
dt = 1.0 / num_steps

with torch.no_grad():
    x = sample_source(1000)  # Start from noise

    for step_index in range(num_steps):
        t_step = step_index * dt  # Left endpoint of the current Euler interval
        t_tensor = torch.full((x.shape[0], 1), t_step)
        velocity = model(x, t_tensor)
        x = x + velocity * dt  # Euler update

# Visualization: scatter the integrated samples and save the figure to disk.
final_samples = x.numpy()
plt.figure(figsize=(6, 6))
plt.scatter(final_samples[:, 0], final_samples[:, 1], s=10, alpha=0.6, label="Generated")
plt.title("Flow Matching Output (Approx. Data Dist.)")
plt.legend()  # Without this call the scatter label above is never rendered
plt.grid(True)
plt.tight_layout()
plt.savefig("flow_matching_output.png")
plt.close()