"""
Half-Life Regularizer for FDRA Oscillators

This implements the exact mathematical regularizer from the Cursor instructions:

## Regularizer 1: Log-Uniform Half-Life Prior (primary)

Target distribution: p(τ) ∝ 1/τ for τ ∈ [τ_min, τ_max]
This gives equal mass per temporal decade (log scale).

Loss:
    z_i = log(τ_i)
    μ = mean(z_i)
    σ² = mean((z_i - μ)²)
    
    μ* = (log(τ_min) + log(τ_max)) / 2
    σ²* = (log(τ_max) - log(τ_min))² / 12
    
    L_HL = α*(μ - μ*)² + β*(σ² - σ²*)²

## Regularizer 2: Long-Tail Survival Constraint (supporting)

Ensure existence of long-range oscillators:
    s_i = σ(k * (τ_i - γ*L))
    tail_mass = mean(s_i)
    L_tail = max(0, ρ - tail_mass)²

Where:
    γ = 0.5 (fraction of full context)
    ρ = 0.05 (minimum fraction of oscillators)
    k = 10.0 (sigmoid sharpness)

## Regularizer 3: Tau Bounds Constraint (CRITICAL FIX)

The moment-matching loss (L_HL) can be satisfied by a pathological bimodal
distribution with taus outside [tau_min, tau_max]. This creates oscillators
that are either useless (tau << 1) or extreme (tau >> L).

L_bounds = mean((k_b * relu(tau_min - tau_i))^2) + mean((k_b * relu(tau_i - tau_max))^2)

where k_b is the bound sharpness (config.bound_sharpness).

## Combined Loss

L_total = L_task + λ1 * L_HL + λ2 * L_tail + λ3 * L_bounds

Authors: Half-Life Regularization Implementation
Date: 2026-01-22
"""

import numpy as np
from typing import Dict, Tuple, Any
from dataclasses import dataclass
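

# Illustrative sketch (not used by the classes below): Regularizer 3's bound
# penalty in isolation. The helper name and default values are hypothetical;
# the formula mirrors compute_bounds_loss further down.
def _bounds_penalty_sketch(taus: np.ndarray, tau_min: float = 1.0,
                           tau_max: float = 4096.0, k: float = 5.0) -> float:
    """L_bounds = mean((k*relu(tau_min - tau))^2) + mean((k*relu(tau - tau_max))^2)."""
    below = np.maximum(0.0, tau_min - taus)   # violation below the lower bound
    above = np.maximum(0.0, taus - tau_max)   # violation above the upper bound
    return float(np.mean((k * below) ** 2) + np.mean((k * above) ** 2))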


@dataclass
class HalfLifeRegularizerConfig:
    """Configuration for half-life regularization."""
    
    # Task parameters
    sequence_length: int = 4096      # L - max sequence length
    tau_min: float = 1.0             # Minimum target half-life
    tau_max: float = 4096.0          # Maximum target half-life (= L)
    
    # Log-Uniform Prior coefficients
    alpha: float = 1.0               # Weight for mean constraint
    beta: float = 1.0                # Weight for variance constraint
    
    # Long-Tail Survival coefficients
    gamma: float = 0.5               # Fraction of full context for long-range
    rho: float = 0.05                # Minimum fraction of long-range oscillators
    k: float = 10.0                  # Sigmoid sharpness
    
    # Overall loss weights
    lambda1: float = 0.01            # Weight for L_HL in total loss
    lambda2: float = 0.01            # Weight for L_tail in total loss
    
    # NEW: Tau bound constraint (prevents pathological distributions)
    lambda3: float = 0.1             # Weight for L_bounds
    bound_sharpness: float = 5.0     # Sharpness of soft bound penalties


class HalfLifeRegularizer:
    """
    Half-Life Regularizer for FDRA Oscillator Banks.
    
    Prevents decay parameter collapse by regularizing the half-life
    distribution toward a log-uniform target.
    
    Usage:
        config = HalfLifeRegularizerConfig()
        regularizer = HalfLifeRegularizer(config)
        
        # During training:
        lambdas = oscillator_bank.lambdas
        loss, metrics = regularizer.compute(lambdas)
        
        # Add to total loss:
        total_loss = task_loss + loss
        
        # Log metrics:
        log(metrics)
    """
    
    def __init__(self, config: HalfLifeRegularizerConfig):
        self.config = config
        
        # Pre-compute target statistics
        z_min = np.log(config.tau_min)
        z_max = np.log(config.tau_max)
        
        # Target mean in log space (center of [z_min, z_max])
        self.mu_star = (z_min + z_max) / 2.0
        
        # Target variance in log space (variance of uniform on [z_min, z_max])
        self.sigma2_star = (z_max - z_min) ** 2 / 12.0
        
        # Long-range threshold
        self.tau_threshold = config.gamma * config.sequence_length
        
    def lambdas_to_half_lives(self, lambdas: np.ndarray) -> np.ndarray:
        """
        Convert decay parameters to half-lives.
        
        τ_i = ln(0.5) / ln(λ_i)
        
        Args:
            lambdas: Decay parameters, shape (N,)
            
        Returns:
            taus: Half-lives, shape (N,)
        """
        # Clamp to avoid numerical issues
        safe_lambdas = np.clip(lambdas, 1e-10, 1.0 - 1e-10)
        taus = np.log(0.5) / np.log(safe_lambdas)
        return taus
    
    def compute_log_uniform_loss(
        self, 
        lambdas: np.ndarray
    ) -> Tuple[float, Dict[str, float]]:
        """
        Compute Log-Uniform Half-Life Prior loss.
        
        L_HL = α*(μ - μ*)² + β*(σ² - σ²*)²
        
        Args:
            lambdas: Decay parameters, shape (N,)
            
        Returns:
            loss: Scalar loss value
            metrics: Dictionary of component metrics
        """
        # Compute half-lives and log half-lives
        taus = self.lambdas_to_half_lives(lambdas)
        z = np.log(taus)
        
        # Current statistics
        mu = np.mean(z)
        sigma2 = np.var(z)
        
        # Compute loss components
        mean_loss = self.config.alpha * (mu - self.mu_star) ** 2
        var_loss = self.config.beta * (sigma2 - self.sigma2_star) ** 2
        
        loss = mean_loss + var_loss
        
        metrics = {
            "log_tau_mean": float(mu),
            "log_tau_var": float(sigma2),
            "log_tau_target_mean": float(self.mu_star),
            "log_tau_target_var": float(self.sigma2_star),
            "mean_deviation": float(abs(mu - self.mu_star)),
            "var_deviation": float(abs(sigma2 - self.sigma2_star)),
            "log_uniform_loss": float(loss),
        }
        
        return float(loss), metrics
    
    def compute_long_tail_loss(
        self, 
        lambdas: np.ndarray
    ) -> Tuple[float, Dict[str, float]]:
        """
        Compute Long-Tail Survival Constraint loss.
        
        s_i = σ(k * (τ_i - γ*L))
        tail_mass = mean(s_i)
        L_tail = max(0, ρ - tail_mass)²
        
        Args:
            lambdas: Decay parameters, shape (N,)
            
        Returns:
            loss: Scalar loss value
            metrics: Dictionary of component metrics
        """
        # Compute half-lives
        taus = self.lambdas_to_half_lives(lambdas)
        
        # Sigmoid for soft thresholding (with numerical stability)
        # s_i ≈ 1 if τ_i > threshold, ≈ 0 otherwise
        x = self.config.k * (taus - self.tau_threshold)
        x = np.clip(x, -500, 500)  # Prevent overflow
        s = 1.0 / (1.0 + np.exp(-x))
        
        # Fraction of oscillators in long-tail regime
        tail_mass = np.mean(s)
        
        # Loss: penalize if tail_mass < rho
        deficit = max(0, self.config.rho - tail_mass)
        loss = deficit ** 2
        
        # Count actual long-range oscillators (hard threshold)
        n_long_range = np.sum(taus > self.tau_threshold)
        frac_long_range = n_long_range / len(taus)
        
        metrics = {
            "tail_mass": float(tail_mass),
            "tail_target": float(self.config.rho),
            "tail_deficit": float(deficit),
            "n_long_range": int(n_long_range),
            "frac_long_range": float(frac_long_range),
            "tau_threshold": float(self.tau_threshold),
            "long_tail_loss": float(loss),
        }
        
        return float(loss), metrics
    
    def compute_bounds_loss(
        self, 
        lambdas: np.ndarray
    ) -> Tuple[float, Dict[str, float]]:
        """
        Compute tau bounds constraint loss.
        
        CRITICAL FIX: The moment-matching loss alone can be satisfied by
        a pathological bimodal distribution with taus outside [tau_min, tau_max].
        
        This loss penalizes taus below tau_min or above tau_max:
        
        L_bounds = mean((k * relu(tau_min - tau_i))^2) + mean((k * relu(tau_i - tau_max))^2)
        
        where k = config.bound_sharpness sets how steeply violations are penalized.
        """
        taus = self.lambdas_to_half_lives(lambdas)
        k = self.config.bound_sharpness
        
        # Soft lower bound: penalize tau < tau_min
        below_min = np.maximum(0, self.config.tau_min - taus)
        lower_penalty = np.mean((k * below_min) ** 2)
        
        # Soft upper bound: penalize tau > tau_max
        above_max = np.maximum(0, taus - self.config.tau_max)
        upper_penalty = np.mean((k * above_max) ** 2)
        
        loss = lower_penalty + upper_penalty
        
        n_below_min = np.sum(taus < self.config.tau_min)
        n_above_max = np.sum(taus > self.config.tau_max)
        
        metrics = {
            "bounds_loss": float(loss),
            "lower_bound_penalty": float(lower_penalty),
            "upper_bound_penalty": float(upper_penalty),
            "n_below_tau_min": int(n_below_min),
            "n_above_tau_max": int(n_above_max),
            "frac_in_bounds": float(1 - (n_below_min + n_above_max) / len(taus)),
        }
        
        return float(loss), metrics
    
    def compute(self, lambdas: np.ndarray) -> Tuple[float, Dict[str, Any]]:
        """
        Compute total half-life regularization loss.
        
        L_reg = λ1 * L_HL + λ2 * L_tail + λ3 * L_bounds   (the task loss is added by the caller)
        
        CRITICAL: L_bounds prevents the pathological case where moment-matching
        is satisfied by a bimodal distribution with taus outside [tau_min, tau_max].
        
        Args:
            lambdas: Decay parameters, shape (N,)
            
        Returns:
            loss: Total regularization loss
            metrics: All component metrics
        """
        # Compute component losses
        log_uniform_loss, log_uniform_metrics = self.compute_log_uniform_loss(lambdas)
        long_tail_loss, long_tail_metrics = self.compute_long_tail_loss(lambdas)
        bounds_loss, bounds_metrics = self.compute_bounds_loss(lambdas)
        
        # Weighted combination (bounds loss is CRITICAL)
        total_loss = (
            self.config.lambda1 * log_uniform_loss + 
            self.config.lambda2 * long_tail_loss +
            self.config.lambda3 * bounds_loss
        )
        
        # Compute half-life distribution for logging
        taus = self.lambdas_to_half_lives(lambdas)
        
        metrics = {
            "total_regularization_loss": float(total_loss),
            "log_uniform_component": float(self.config.lambda1 * log_uniform_loss),
            "long_tail_component": float(self.config.lambda2 * long_tail_loss),
            "bounds_component": float(self.config.lambda3 * bounds_loss),
            "tau_min": float(np.min(taus)),
            "tau_max": float(np.max(taus)),
            "tau_mean": float(np.mean(taus)),
            "tau_median": float(np.median(taus)),
            **log_uniform_metrics,
            **long_tail_metrics,
            **bounds_metrics,
        }
        
        return float(total_loss), metrics
    
    def compute_gradient(
        self, 
        lambdas: np.ndarray, 
        epsilon: float = 1e-5
    ) -> np.ndarray:
        """
        Compute gradient of regularization loss w.r.t. lambdas.
        
        Uses finite differences for simplicity.
        In a real implementation, this would use autodiff.
        
        Args:
            lambdas: Decay parameters, shape (N,)
            epsilon: Perturbation size
            
        Returns:
            grad: Gradient, shape (N,)
        """
        grad = np.zeros_like(lambdas)
        
        for i in range(len(lambdas)):
            # Positive perturbation
            lambdas_plus = lambdas.copy()
            lambdas_plus[i] += epsilon
            loss_plus, _ = self.compute(lambdas_plus)
            
            # Negative perturbation
            lambdas_minus = lambdas.copy()
            lambdas_minus[i] -= epsilon
            loss_minus, _ = self.compute(lambdas_minus)
            
            # Central difference
            grad[i] = (loss_plus - loss_minus) / (2 * epsilon)
        
        return grad
    
    def diagnose(self, lambdas: np.ndarray) -> str:
        """
        Generate diagnostic string for current half-life distribution.
        
        Args:
            lambdas: Decay parameters
            
        Returns:
            Diagnostic string
        """
        loss, metrics = self.compute(lambdas)
        taus = self.lambdas_to_half_lives(lambdas)
        
        lines = [
            "=" * 60,
            "HALF-LIFE REGULARIZER DIAGNOSTICS",
            "=" * 60,
            "",
            "Current Distribution:",
            f"  τ range: [{metrics['tau_min']:.1f}, {metrics['tau_max']:.1f}]",
            f"  τ mean: {metrics['tau_mean']:.1f}",
            f"  τ median: {metrics['tau_median']:.1f}",
            "",
            "Target Distribution:",
            f"  τ range: [{self.config.tau_min}, {self.config.tau_max}]",
            f"  log(τ) target mean: {self.mu_star:.3f}",
            f"  log(τ) target var: {self.sigma2_star:.3f}",
            "",
            "Log-Uniform Prior:",
            f"  log(τ) mean: {metrics['log_tau_mean']:.3f} (target: {metrics['log_tau_target_mean']:.3f})",
            f"  log(τ) var: {metrics['log_tau_var']:.3f} (target: {metrics['log_tau_target_var']:.3f})",
            f"  Mean deviation: {metrics['mean_deviation']:.3f}",
            f"  Var deviation: {metrics['var_deviation']:.3f}",
            f"  Loss: {metrics['log_uniform_loss']:.6f}",
            "",
            "Long-Tail Survival:",
            f"  Threshold: τ > {metrics['tau_threshold']:.1f}",
            f"  Long-range count: {metrics['n_long_range']}/{len(lambdas)} ({metrics['frac_long_range']:.1%})",
            f"  Tail mass (soft): {metrics['tail_mass']:.3f} (target: {metrics['tail_target']:.3f})",
            f"  Loss: {metrics['long_tail_loss']:.6f}",
            "",
            "Total Regularization Loss:",
            f"  Log-uniform component: {metrics['log_uniform_component']:.6f}",
            f"  Long-tail component: {metrics['long_tail_component']:.6f}",
            f"  Bounds component: {metrics['bounds_component']:.6f}",
            f"  Total: {metrics['total_regularization_loss']:.6f}",
            "",
        ]
        
        # Add half-life histogram
        lines.append("Half-Life Histogram (log scale):")
        bins = np.logspace(np.log10(self.config.tau_min), np.log10(self.config.tau_max), 11)
        hist, _ = np.histogram(taus, bins=bins)
        for i, count in enumerate(hist):
            bar = "█" * count
            lines.append(f"  [{bins[i]:7.1f}, {bins[i+1]:7.1f}): {count:2d} {bar}")
        
        lines.append("")
        lines.append("=" * 60)
        
        return "\n".join(lines)
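

# Sketch: the inverse of lambdas_to_half_lives, as used when seeding a bank
# with chosen half-lives (see simulate_collapse_and_recovery below). Solving
# lambda**tau = 0.5 for lambda gives lambda = 0.5 ** (1 / tau). The helper
# name is illustrative only.
def _half_lives_to_lambdas(taus: np.ndarray) -> np.ndarray:
    """Map target half-lives to decay parameters; inverts lambdas_to_half_lives."""
    return np.power(0.5, 1.0 / taus)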


def simulate_collapse_and_recovery():
    """
    Simulate the half-life collapse problem and demonstrate regularization.
    
    This shows:
    1. Initial log-uniform distribution (good)
    2. Simulated collapse to short half-lives (bad, mimics training at scale)
    3. Regularization gradient direction (recovery)
    """
    print("=" * 70)
    print("HALF-LIFE COLLAPSE AND REGULARIZATION DEMONSTRATION")
    print("=" * 70)
    
    config = HalfLifeRegularizerConfig(
        sequence_length=4096,
        tau_min=1.0,
        tau_max=4096.0,
        lambda1=0.01,
        lambda2=0.01
    )
    
    regularizer = HalfLifeRegularizer(config)
    
    # --- Initial Distribution (good) ---
    print("\n1. INITIAL DISTRIBUTION (Log-Uniform)")
    print("-" * 60)
    
    n_oscillators = 32
    log_taus_init = np.linspace(np.log(1.0), np.log(4096.0), n_oscillators)
    taus_init = np.exp(log_taus_init)
    lambdas_init = np.power(0.5, 1.0 / taus_init)
    
    loss_init, metrics_init = regularizer.compute(lambdas_init)
    print(f"   Half-lives: [{metrics_init['tau_min']:.1f}, {metrics_init['tau_max']:.1f}]")
    print(f"   Regularization loss: {loss_init:.6f}")
    print(f"   Long-range oscillators: {metrics_init['n_long_range']}/{n_oscillators}")
    
    # --- Collapsed Distribution (bad) ---
    print("\n2. COLLAPSED DISTRIBUTION (Training at Scale)")
    print("-" * 60)
    print("   Simulating what happens during GPT-2 scale training...")
    
    # All half-lives collapse to < 10 steps
    taus_collapsed = np.random.uniform(2, 10, n_oscillators)
    lambdas_collapsed = np.power(0.5, 1.0 / taus_collapsed)
    
    loss_collapsed, metrics_collapsed = regularizer.compute(lambdas_collapsed)
    print(f"   Half-lives: [{metrics_collapsed['tau_min']:.1f}, {metrics_collapsed['tau_max']:.1f}]")
    print(f"   Regularization loss: {loss_collapsed:.6f} ({loss_collapsed/loss_init:.0f}x initial)")
    print(f"   Long-range oscillators: {metrics_collapsed['n_long_range']}/{n_oscillators}")
    
    # --- Regularization Gradient ---
    print("\n3. REGULARIZATION GRADIENT ANALYSIS")
    print("-" * 60)
    
    grad = regularizer.compute_gradient(lambdas_collapsed)
    
    print("   Gradient direction indicates how to adjust λ_i to reduce loss:")
    print("   (Negative gradient → increase λ → longer half-life)")
    print()
    
    # Show gradient for first few oscillators
    for i in range(min(5, n_oscillators)):
        tau_i = taus_collapsed[i]
        grad_i = grad[i]
        direction = "→ increase τ" if grad_i < 0 else "→ decrease τ"
        print(f"   Osc {i}: τ={tau_i:.1f}, grad={grad_i:+.4f} {direction}")
    
    print(f"   ... ({n_oscillators - 5} more)")
    print(f"\n   Mean gradient magnitude: {np.mean(np.abs(grad)):.4f}")
    
    # --- After One Regularization Step ---
    print("\n4. AFTER REGULARIZATION STEP")
    print("-" * 60)
    
    lr = 1.0  # Learning rate
    lambdas_reg = lambdas_collapsed - lr * grad
    lambdas_reg = np.clip(lambdas_reg, 0.01, 0.9999)  # Keep valid
    
    loss_reg, metrics_reg = regularizer.compute(lambdas_reg)
    
    print(f"   Half-lives: [{metrics_reg['tau_min']:.1f}, {metrics_reg['tau_max']:.1f}]")
    print(f"   Regularization loss: {loss_reg:.6f} ({loss_reg/loss_collapsed:.1%} of collapsed)")
    print(f"   Long-range oscillators: {metrics_reg['n_long_range']}/{n_oscillators}")
    
    # --- Summary ---
    print("\n5. SUMMARY")
    print("-" * 60)
    print(f"""
   State              | Loss      | τ range         | Long-range
   -------------------|-----------|-----------------|------------
   Initial (good)     | {loss_init:.6f} | [{metrics_init['tau_min']:.1f}, {metrics_init['tau_max']:.1f}] | {metrics_init['n_long_range']}/{n_oscillators}
   Collapsed (bad)    | {loss_collapsed:.6f} | [{metrics_collapsed['tau_min']:.1f}, {metrics_collapsed['tau_max']:.1f}] | {metrics_collapsed['n_long_range']}/{n_oscillators}
   After 1 reg step   | {loss_reg:.6f} | [{metrics_reg['tau_min']:.1f}, {metrics_reg['tau_max']:.1f}] | {metrics_reg['n_long_range']}/{n_oscillators}
""")
    
    print("=" * 70)
    print("CONCLUSION:")
    print("  The regularizer provides gradients that push collapsed half-lives")
    print("  back toward a log-uniform distribution spanning the full context.")
    print("=" * 70)
    
    return {
        "initial": {"loss": loss_init, "metrics": metrics_init},
        "collapsed": {"loss": loss_collapsed, "metrics": metrics_collapsed},
        "regularized": {"loss": loss_reg, "metrics": metrics_reg},
    }
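

# Sketch: Regularizer 2's soft tail mass in isolation. With k = 10 the
# sigmoid acts as a near-hard threshold at gamma * L; parameter names mirror
# HalfLifeRegularizerConfig, but the helper itself is illustrative.
def _tail_mass_sketch(taus: np.ndarray, L: int = 4096,
                      gamma: float = 0.5, k: float = 10.0) -> float:
    """Mean of sigmoid(k * (tau - gamma * L)) over the oscillator bank."""
    x = np.clip(k * (taus - gamma * L), -500, 500)  # avoid exp overflow
    return float(np.mean(1.0 / (1.0 + np.exp(-x))))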


if __name__ == "__main__":
    simulate_collapse_and_recovery()