File size: 8,556 Bytes
fc7d689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
from functools import partial

import equinox as eqx
import jax.numpy as jnp
import jax.random as jrn
import jax.tree_util as jtu
from jax import vmap
from tqdm import tqdm

from neural_fdm.models import AutoEncoderPiggy


def train_step_piggy(model, structure, optimizer, generator, opt_state, *, loss_fn, batch_size, key):
    """
    Run one optimization step for an autoencoder piggy model.

    The loss is differentiated twice on the same batch, once per setting of
    the last positional flag passed to `loss_fn`, and the two gradient trees
    are added leaf by leaf before a single optimizer update is applied.

    Parameters
    ----------
    model: `eqx.Module`
        The model to train.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    optimizer: `optax.GradientTransformation`
        The optimizer to use for training.
    generator: `PointGenerator`
        The data generator. It is mapped over one random key per sample.
    opt_state: `optax.GradientTransformationExtraArgs`
        The current optimizer state.
    loss_fn: `Callable`
        The loss function. It is called as
        `loss_fn(model, structure, x, True, flag)` and must return a
        `(loss, loss_vals)` pair.
    batch_size: `int`
        The number of samples to generate in each batch.
    key: `jax.random.PRNGKey`
        The random key.

    Returns
    -------
    loss_vals: `dict` of `float`
        The values of the loss terms from the second (flag=True) evaluation.
    model: `eqx.Module`
        The updated model.
    opt_state: `optax.GradientTransformationExtraArgs`
        The updated optimizer state.
    """
    # draw a fresh batch, one key per sample
    sample_keys = jrn.split(key, batch_size)
    batch = vmap(generator)(sample_keys)

    # a single differentiated loss serves both passes
    value_and_grad = eqx.filter_value_and_grad(loss_fn, has_aux=True)

    # gradients for main (last flag off)
    # NOTE(review): the two booleans are presumably `aux_data` and a
    # piggy-branch switch; confirm against the loss function signature.
    _, grads_main = value_and_grad(model, structure, batch, True, False)

    # gradients for piggy (last flag on); only these loss values are reported
    (_, loss_vals), grads_piggy = value_and_grad(model, structure, batch, True, True)

    # sum both gradient trees leaf by leaf
    grads = jtu.tree_map(
        lambda g_main, g_piggy: g_main + g_piggy,
        grads_main,
        grads_piggy,
    )

    # one optimizer update with the combined gradients
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)

    return loss_vals, model, opt_state


def train_step(model, structure, optimizer, generator, opt_state, *, loss_fn, batch_size, key):
    """
    Run one optimization step for an autoencoder model.

    A new batch is generated from `key`, the loss is differentiated with
    respect to the model, and the optimizer update is applied.

    Parameters
    ----------
    model: `eqx.Module`
        The model to train.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    optimizer: `optax.GradientTransformation`
        The optimizer to use for training.
    generator: `PointGenerator`
        The data generator. It is mapped over one random key per sample.
    opt_state: `optax.GradientTransformationExtraArgs`
        The current optimizer state.
    loss_fn: `Callable`
        The loss function. It is called as
        `loss_fn(model, structure, x, aux_data=True)` and must return a
        `(loss, loss_vals)` pair.
    batch_size: `int`
        The number of samples to generate in each batch.
    key: `jax.random.PRNGKey`
        The random key.

    Returns
    -------
    loss_vals: `dict` of `float`
        The values of the loss terms.
    model: `eqx.Module`
        The updated model.
    opt_state: `optax.GradientTransformationExtraArgs`
        The updated optimizer state.
    """
    # draw a fresh batch, one key per sample
    sample_keys = jrn.split(key, batch_size)
    batch = vmap(generator)(sample_keys)

    # differentiate the loss; the scalar loss itself is not needed here
    value_and_grad = eqx.filter_value_and_grad(loss_fn, has_aux=True)
    (_, loss_vals), grads = value_and_grad(model, structure, batch, aux_data=True)

    # turn gradients into parameter updates and apply them
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)

    return loss_vals, model, opt_state


def train_step_vae(model, structure, optimizer, generator, opt_state, *,
                   loss_fn, batch_size, key, beta):
    """Training step for VAE models with reparameterization sampling.

    Differs from train_step in two ways:
    1. Passes a PRNG key through the loss function for epsilon sampling.
    2. Accepts beta as a traced JAX float (not Python float) to avoid
       JIT recompilation at each step.

    Parameters
    ----------
    model: `eqx.Module`
        The model to train.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    optimizer: `optax.GradientTransformation`
        The optimizer to use for training.
    generator: `PointGenerator`
        The data generator. It is mapped over one random key per sample.
    opt_state: `optax.GradientTransformationExtraArgs`
        The current optimizer state.
    loss_fn: `functools.partial`
        The loss function. Only its `keywords` are read: the entries
        `"loss_params"` (a dict, default `{}`) and `"loss_fn"` (default
        `None`) are forwarded to `compute_loss_vae`. The dictionaries held
        by the partial are never modified.
    batch_size: `int`
        The number of samples to generate in each batch.
    key: `jax.random.PRNGKey`
        The random key. It is split in two: one half for data generation,
        the other for sampling inside the loss.
    beta : jnp.float32
        Current KL weight. Must be a JAX array (traced), not a Python float,
        to prevent JIT recompilation at every training step.

    Returns
    -------
    loss_vals: `dict` of `float`
        The values of the loss terms.
    model: `eqx.Module`
        The updated model.
    opt_state: `optax.GradientTransformationExtraArgs`
        The updated optimizer state.

    References
    ----------
    Kingma & Welling (2014): Reparameterization trick requires PRNG in forward pass.
    Fu et al. (2019): Beta varies per step via cyclical annealing.
    """
    # Imported locally, as in the rest of this module's VAE code paths
    from neural_fdm.losses import compute_loss_vae

    # Split key: one for data generation, one for VAE sampling
    data_key, model_key = jrn.split(key)

    # Sample fresh data
    keys = jrn.split(data_key, batch_size)
    x = vmap(generator)(keys)

    # Build the per-step loss parameters on private copies. Copying the nested
    # "vae" dict as well keeps beta (a tracer while this function is being
    # jitted) out of the configuration owned by the caller's partial.
    step_params = dict(loss_fn.keywords.get("loss_params", {}))
    vae_params = dict(step_params.get("vae", {}))
    vae_params["beta"] = beta
    step_params["vae"] = vae_params
    inner_loss_fn = loss_fn.keywords.get("loss_fn", None)

    # Loss wrapper that closes over the VAE key and the beta-augmented params
    def vae_loss_wrapper(model, structure, x, aux_data=True):
        return compute_loss_vae(
            model, structure, x, inner_loss_fn, step_params,
            aux_data=aux_data, key=model_key
        )

    val_grad_fn = eqx.filter_value_and_grad(vae_loss_wrapper, has_aux=True)
    (_, loss_vals), grads = val_grad_fn(model, structure, x, aux_data=True)

    # Apply updates
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)

    return loss_vals, model, opt_state


def train_model(model, structure, optimizer, generator, *, loss_fn, num_steps, batch_size, key, callback=None):
    """
    Train a model over a number of steps.

    The train step is picked from the model type: `VariationalAutoEncoder`
    instances use `train_step_vae`, `AutoEncoderPiggy` instances use
    `train_step_piggy`, and everything else uses `train_step`. The chosen
    step is bound to `loss_fn` and compiled with `eqx.filter_jit`.

    Parameters
    ----------
    model: `eqx.Module`
        The model to train.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    optimizer: `optax.GradientTransformation`
        The optimizer to use for training.
    generator: `PointGenerator`
        The data generator.
    loss_fn: `Callable`
        The loss function. For VAE models it must expose `keywords` (as a
        `functools.partial` does); the `"vae"` entry of its `"loss_params"`
        supplies `beta_max` (default 1.0), `cycle_length` (default
        `num_steps`) and `warmup_ratio` (default 0.5) for the beta schedule.
    num_steps: `int`
        The number of steps to train for (number of parameter updates).
    batch_size: `int`
        The number of samples to generate per batch.
    key: `jax.random.PRNGKey`
        The random key.
    callback: `Callable`, optional
        A callback function to call after each step.
        The callback function should take the following arguments:
        - model: `eqx.Module`
        - opt_state: `optax.GradientTransformationExtraArgs`
        - loss_vals: `dict` of `float`
        - step: `int`

    Returns
    -------
    model: `eqx.Module`
        The model after the last step.
    loss_history: `list`
        The loss values returned by every train step, in step order.
    """
    # initial optimizer state, built from the array leaves of the model
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    # detect VAE model
    from neural_fdm.variational import VariationalAutoEncoder
    use_vae = isinstance(model, VariationalAutoEncoder)

    # choose the train step implementation
    if use_vae:
        from neural_fdm.variational import compute_beta_schedule
        # beta-schedule settings live in the loss_fn partial keywords
        vae_settings = loss_fn.keywords.get("loss_params", {}).get("vae", {})
        beta_max = vae_settings.get("beta_max", 1.0)
        cycle_length = vae_settings.get("cycle_length", num_steps)
        warmup_ratio = vae_settings.get("warmup_ratio", 0.5)
        step_impl = train_step_vae
    elif isinstance(model, AutoEncoderPiggy):
        step_impl = train_step_piggy
    else:
        step_impl = train_step

    # bind the loss and compile
    compiled_step = eqx.filter_jit(partial(step_impl, loss_fn=loss_fn))

    # train
    loss_history = []
    for step in tqdm(range(num_steps)):

        # randomness
        # NOTE(review): the second half of the split is discarded and the
        # carried key is handed to the train step and then split again on the
        # next iteration. Left unchanged so seeded runs remain reproducible.
        key, _ = jrn.split(key)

        step_kwargs = {"batch_size": batch_size, "key": key}
        if use_vae:
            # beta goes in as a JAX array to avoid JIT recompilation
            step_kwargs["beta"] = jnp.float32(compute_beta_schedule(
                step, beta_max, cycle_length, warmup_ratio
            ))

        # train step
        loss_vals, model, opt_state = compiled_step(
            model, structure, optimizer, generator, opt_state,
            **step_kwargs,
        )

        # store loss values
        loss_history.append(loss_vals)

        # callback
        if callback:
            callback(model, opt_state, loss_vals, step)

    return model, loss_history