File size: 2,574 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""perturbation-conditioned flow map X_θ.

    X_θ(s, t, c_s, e_u) = c_s + (t - s) · v_θ(s, t, c_s, e_u),   0 ≤ s ≤ t ≤ 1
    ĉ1(u)              = X_θ(0, 1, c0, e_u)

v_θ is an mlp over [γ(s), γ(t), c_s, e_u], γ is a fourier time embedding.
the map is differentiable in e_u, which is what enables the inverse-design
jacobian ∇_{e_u} X_{0,1} used by reward guidance (Alg 4).
"""
from __future__ import annotations

import math

import torch
import torch.nn as nn


class FourierTime(nn.Module):
    """Fourier embedding of a scalar time.

    Each time value t is mapped to the row
    ``[t, sin(w_0 t), ..., sin(w_{K-1} t), cos(w_0 t), ..., cos(w_{K-1} t)]``
    with ``w_k = 2π · 2^k`` and ``K = n_freq``, so the output width is
    ``1 + 2 * n_freq``. With ``n_freq == 0`` the embedding is just ``t``.
    """

    def __init__(self, n_freq: int = 6):
        super().__init__()
        # registered as a buffer, not a parameter: it follows .to()/.cuda() and
        # is stored in the state_dict, but receives no gradient updates.
        self.register_buffer("freqs", 2 * math.pi * (2.0 ** torch.arange(n_freq)).float())

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t`` (any shape, flattened to N values) into an (N, 1 + 2*n_freq) tensor."""
        t = t.reshape(-1, 1)
        # (N, 1) * (K,) broadcasts to (N, K). Plain broadcasting is used instead
        # of freqs.view(1, -1) because the latter raises on an empty buffer
        # (n_freq == 0), where the -1 dimension is ambiguous.
        ang = t * self.freqs
        return torch.cat([t, torch.sin(ang), torch.cos(ang)], dim=-1)


class FlowMap(nn.Module):
    """Perturbation-conditioned flow map.

        X_θ(s, t, c_s, e_u) = c_s + (t - s) · v_θ(s, t, c_s, e_u)

    v_θ is an MLP over the concatenation ``[γ(s), γ(t), c_s, e_u]`` where γ is
    the ``FourierTime`` embedding. The last linear layer is zero-initialised,
    so a freshly constructed map returns ``c_s`` unchanged.

    Args:
        d_state: width of the state ``c_s`` (and of the output).
        d_pert: width of the perturbation embedding ``e_u``.
        hidden: width of every hidden layer.
        depth: number of hidden Linear+SiLU layers (values below 1 act like 1).
        n_freq: number of Fourier frequencies used for each time input.
        dropout: dropout probability applied after every hidden layer except
            the first; 0 adds no dropout modules.
    """

    def __init__(self, d_state: int, d_pert: int, hidden: int = 512, depth: int = 4,
                 n_freq: int = 6, dropout: float = 0.0):
        super().__init__()
        self.d_state = d_state
        self.d_pert = d_pert
        self.time = FourierTime(n_freq)
        t_dim = 1 + 2 * n_freq  # raw t plus one sin and one cos per frequency
        in_dim = 2 * t_dim + d_state + d_pert  # γ(s), γ(t), c_s, e_u
        layers = [nn.Linear(in_dim, hidden), nn.SiLU()]
        for _ in range(depth - 1):
            layers += [nn.Linear(hidden, hidden), nn.SiLU()]
            if dropout > 0:
                layers += [nn.Dropout(dropout)]
        layers += [nn.Linear(hidden, d_state)]
        self.net = nn.Sequential(*layers)
        # zero-init final layer => v≈0 at init => X≈identity (stable warm start)
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    @staticmethod
    def _as_time(x, like: torch.Tensor) -> torch.Tensor:
        """Coerce a float or tensor time to a tensor with ``like``'s dtype and device."""
        return torch.as_tensor(x, dtype=like.dtype, device=like.device)

    def velocity(self, s, t, c_s: torch.Tensor, e_u: torch.Tensor) -> torch.Tensor:
        """Evaluate v_θ(s, t, c_s, e_u).

        Args:
            s, t: Python floats, 0-dim tensors, or per-sample (b,) tensors.
            c_s: state batch of shape (b, d_state).
            e_u: perturbation embeddings of shape (b, d_pert).

        Returns:
            Tensor of shape (b, d_state).
        """
        batch = c_s.shape[0]
        s = self._as_time(s, c_s)
        t = self._as_time(t, c_s)
        # a shared scalar time is repeated across the batch so that its
        # embedding has one row per sample and can be concatenated below
        if s.dim() == 0:
            s = s.expand(batch)
        if t.dim() == 0:
            t = t.expand(batch)
        h = torch.cat([self.time(s), self.time(t), c_s, e_u], dim=-1)
        return self.net(h)

    def forward(self, s, t, c_s: torch.Tensor, e_u: torch.Tensor) -> torch.Tensor:
        """X_θ(s,t,c_s,e_u). s,t are floats, 0-dim tensors or (b,) tensors in [0,1]."""
        s = self._as_time(s, c_s)
        t = self._as_time(t, c_s)
        v = self.velocity(s, t, c_s, e_u)
        dt = t - s
        if dt.dim() >= 1:
            # per-sample step sizes become a column so they scale each row of v
            dt = dt.reshape(-1, 1)
        return c_s + dt * v

    def endpoint(self, c0: torch.Tensor, e_u: torch.Tensor) -> torch.Tensor:
        """ĉ1 = X_θ(0,1,c0,e_u)."""
        # new_zeros/new_ones inherit c0's dtype and device, so a half- or
        # double-precision model is not fed default-dtype (float32) times
        zeros = c0.new_zeros(c0.shape[0])
        ones = c0.new_ones(c0.shape[0])
        return self.forward(zeros, ones, c0, e_u)